In this post I’ll offer my own somewhat contrarian explanation of why the Adam optimizer works.

Then I’ll demonstrate my explanation with some experiments on a simple proof-of-concept optimizer I made up, called GradSign.

The Adam optimizer

Adam is a widely-used optimization algorithm that tends to perform very well on deep learning tasks.

Adam is often explained as an extension of stochastic gradient descent (SGD): sample one batch, compute the loss and its gradient, and smooth the result out by taking an exponential moving average of the gradient. Then there’s a step that the standard explanations sort of glide over – something about a second moment (i.e. a variance estimate) for the gradient, something about “adaptive choice of the learning rate” – and then you take your step… and magically end up with a good optimizer.

(A quick Google search turns up plenty of explanations along these lines: for example, here, here, here…)

I’d like to offer a different take on Adam – less calculus, more statistics. My take will suggest a different toy model of an optimizer: not stochastic gradient descent, but an algorithm that only looks at the sign (not the magnitude) of each gradient. I call my optimizer GradSign. I’ll test it out on a simple machine learning task: building an MNIST classifier.

Spoiler alert: No, GradSign doesn’t outperform Adam. In my small experiment, my optimizer performs comparably to Adam, but SGD (surprisingly) outperforms them both.

I hope reading this inspires you to tinker and explore.

This post has two parts:

  • My reinterpretation of Adam
  • An explanation of the new optimizer

You can see experimental results with the new optimizer on GitHub.

Adam reexplained

The Adam optimizer works by keeping first- and second-moment estimates (exponential moving averages) of the gradient of each parameter; at each step, those estimates determine the change to that parameter.

We will consider a single parameter, one of millions or billions in a large model: Adam works on each parameter independently, so we don’t need to worry about anything else.

I will follow the notation from Algorithm 1 of the original paper.

The algorithm depends on three parameters, with the following suggested values: \[ \alpha = 0.001 \] \[ \beta_1 = 0.9 \] \[ \beta_2 = 0.999. \] Here $\alpha$ is the learning rate. (We will see that, even though $\alpha$ is called a “learning rate”, it does not have the same units as the learning rate in SGD.) The parameters $\beta_1$ and $\beta_2$ determine the timescales for the two exponential moving averages. (With the suggested values, the first moment is averaged with a decay time of 10 iteration steps, and the second with a decay time of 1000.)

Let $g_t$ be the gradient of our favorite parameter at timestep $t$. (Remember, the gradient at each timestep depends on both the parameter values – which are updated at each step – and the random choice of a fresh batch of data.) The exponential moving average of the first moment (mean, $m$) and second moment (uncentered variance, $v$) are computed recursively: \[ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \] \[ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2. \] Finally, the parameter update is computed as \[ - \alpha \frac{m_t}{\sqrt{v_t} + \epsilon}. \]
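As a quick sketch (my own, not from the post's repo), here is what one Adam step for a single parameter looks like in code. Note that the paper also applies bias-correction terms to $m_t$ and $v_t$; those are omitted here to match the formulas above.

```python
# A minimal single-parameter Adam step, following the recurrences above.
# (The paper's bias-correction terms are omitted, as in the text.)
def adam_step(theta, g, m, v, alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    m = beta1 * m + (1 - beta1) * g       # first-moment EMA
    v = beta2 * v + (1 - beta2) * g ** 2  # second-moment EMA
    theta = theta - alpha * m / (v ** 0.5 + eps)
    return theta, m, v
```

In a real optimizer this would loop over all parameters, carrying $m$ and $v$ for each one from step to step.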

Let us unroll the recurrences. (We will also assume, as a mathematical convenience, that the gradients go infinitely far back in time. In reality one has to worry about how to initialize $m_t$ and $v_t$, but after the first few optimizer steps, our assumption will be a reasonable one.) \[ m_t = (1 - \beta_1) \sum_{i = 0}^{\infty} \beta_1^i g_{t - i} \] \[ v_t = (1 - \beta_2) \sum_{i = 0}^{\infty} \beta_2^i g_{t - i}^2 \] These are exponential moving averages – weighted averages of past values of $g_t$ (or its square), the relevance of past values decaying over time with a characteristic time scale (a physicist might say “half-life”) depending on the $\beta$’s.

I will make three simplifications to make the analysis easier. First, I will ignore the $\epsilon \approx 10^{-8}$ that is thrown into the denominator of \[ \text{update} = - \alpha \frac{m_t}{\sqrt{v_t} + \epsilon} \] for numerical stability reasons. The $\epsilon$ is just there to make sure the algorithm does something sensible when the gradient vanishes.

Second, I will replace the exponentially weighted average (with parameters $\beta$) with a simple unweighted average (a “moving window”) over the past $n$ terms (where $n$ is the window size). This would be a computational disaster to implement, since we would have to store $n$ values per moment per parameter rather than just $1$, but it will help us build a clean mental model.

Finally, I will assume $\beta_1 = \beta_2$, or in other words that the two moving averages are computed over windows of the same size, $n_1 = n_2$. Unlike the first two simplifications, assuming $\beta_1 = \beta_2$ really does change the behavior of the optimizer in a meaningful way. I’ll come back to this later, because I want to understand this simplified model first.

Let’s say $\beta_1 = \beta_2 = 0.99$, so that $m_t$ and $v_t$ are averages over the past 100 iterations: \[ m_t = \frac{\sum_{i=0}^{99} g_{t-i}}{100} \] \[ v_t = \frac{\sum_{i=0}^{99} g_{t-i}^2}{100} \] and \[ \text{update} = - \alpha \frac{ \sum g_{t-i} / 100 } { \sqrt{ \sum g_{t-i}^2 / 100 } }. \] Now that fraction is something we can understand: it is nothing more than the cosine similarity between the vectors \[ \hat{g} = (g_{t-99}, g_{t-98}, \ldots, g_{t-1}, g_t) \] and \[ \hat{1} = (1, 1, \ldots, 1, 1)! \]
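The identity is easy to verify numerically. Here is a quick check of my own (random gradients, window of 100):

```python
import math
import random

# Check: with a flat window of size n, m / sqrt(v) equals the cosine
# similarity between the gradient history and the all-ones vector.
random.seed(0)
n = 100
g = [random.gauss(0, 1) for _ in range(n)]

m = sum(g) / n                   # windowed first moment
v = sum(x * x for x in g) / n    # windowed second moment
ratio = m / math.sqrt(v)

# cos(theta) = (g . 1) / (|g| * |1|), where |1| = sqrt(n)
cos = sum(g) / (math.sqrt(sum(x * x for x in g)) * math.sqrt(n))

print(abs(ratio - cos))  # ~0 up to floating-point error
```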

In other words, if we call $\theta$ the angle between those two vectors (my apologies, this is not the same as the $\theta$ in the paper), then the update to our parameter is simply \[ \text{update} = - \alpha \cos \theta. \]

We immediately see:

  • $\alpha$ is the largest possible update to our parameter, and
  • the size of the update is determined by how close the different $g_t$’s are to each other, rather than the size of $g_t$. (The Adam paper calls this a “signal-to-noise ratio”.)

In fact, the cosine similarity is invariant under scaling the $g_t$’s. (Contrast this to classical gradient descent, where the step size is a product of learning rate and gradient, and you have to do lots of extra work to make sure gradients in different parts of the network have the same scale.)
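To see the scale invariance concretely, here is a tiny illustration of my own: multiplying every gradient in the window by the same constant leaves the update unchanged.

```python
import math

# The ratio m / sqrt(v) does not change if every gradient in the window
# is multiplied by the same positive constant.
def adam_like_update(grads):
    m = sum(grads) / len(grads)
    v = sum(x * x for x in grads) / len(grads)
    return m / math.sqrt(v)

g = [0.3, -0.1, 0.4, 0.2]
print(adam_like_update(g))
print(adam_like_update([1000.0 * x for x in g]))  # same value, up to float error
```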

Thinking statistically

When we update our parameter (we’re just focusing on one parameter, remember? the rest will come along for the ride) our goal is to make the loss decrease. Calculus teaches us that the gradient (a first derivative) determines whether the loss will go up or down, but let me reframe the question statistically: How confident are we that making this change will decrease the loss?

In the statistical framing, there are two sources of noise we need to worry about:

  • Sample randomness: each gradient is computed from only a small batch of data.
  • The gradient landscape: the slope might be negative now, but if our step size is too large we may overshoot and end up climbing back uphill.

Looking at the past $n$ steps can protect us against both types of noise! We’re trying to evaluate a statistical hypothesis, like “decreasing this parameter will result in a lower loss against the next randomly chosen batch”. Clearly, a natural statistic is “on how many of the last $n$ batches was this gradient positive?” If the gradient has consistent positive values across batches, we can expect the gradient to have a positive value on the next batch as well.
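That natural statistic, written out as a sketch (illustrative only, not yet an optimizer):

```python
# On what fraction of the last n batches was this parameter's
# gradient positive?
def sign_consistency(recent_grads):
    positives = sum(1 for g in recent_grads if g > 0)
    return positives / len(recent_grads)

print(sign_consistency([0.5, 0.1, 0.9, -0.2, 0.3]))  # 0.8
```

A value near 1 (or near 0) means the gradient's sign has been consistent across batches; a value near 0.5 means it has been flip-flopping.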

Similarly, we want to know if the loss landscape is bumpy or smooth. We can think of the past $n$ parameter updates as a sort of random walk across this landscape (not uniformly random of course, but governed by a complicated stochastic process). We’re about to take another step in this random walk. If the gradient has been consistent in the past, we can have more confidence that our gradient will remain positive through the full length of our next step.

So, instead of thinking of an optimizer step as a gradient update, I think of it as a statistical confidence test: how confident are we that the gradient will keep its sign on the next (yet unseen) batch, at the current parameter position, the updated position, and everywhere in between, so that the step actually decreases the loss?

If this is what’s going on with Adam, then maybe the size of the gradient doesn’t matter at all. Maybe all that matters is: At how many of the past $n$ timesteps was the gradient positive, and at how many was it negative?

We’ll turn this idea into a new optimizer algorithm soon, but first I want to wrap up a couple of loose ends.

The role of $\beta_1$ and $\beta_2$

Earlier on, we made the simplifying assumption that $\beta_1 = \beta_2$. I told you that we were simplifying away something important, but I didn’t tell you what. Now it’s time to fix that.

In the real world, some good values for $\beta_1$ and $\beta_2$ are \[ \beta_1 = 0.9 \] \[ \beta_2 = 0.999. \] In other words, the first moment estimate (the numerator) averages the gradient over the last 10 or so timesteps, while the second moment (the denominator) averages its square over the last 1000.

You might think this wouldn’t make much of a difference. But it matters a great deal when the gradient is sparse.

Imagine a parameter that is usually unimportant: its gradient is close to zero. But every so often, the parameter becomes very important, and its gradient gets big. (You might imagine that in a large, complex LLM, this one parameter is responsible for learning one particular thing – and that one thing only rarely shows up in the training data.)

The role of $\beta_2$ is to remember that this gradient has a track record of sudden spikes. When a gradient has spiked in the past, we don’t want to make updates based on small gradients. But we also don’t want to continue making updates based on a gradient from many steps back. The solution is to make $\beta_2$ large (remember the spike and slow down learning for 1000 steps) but keep $\beta_1$ small (stop making updates 10 steps after the spike).
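A toy simulation (my own, with made-up numbers) makes this memory effect visible: after a single spike, the slow second moment keeps the denominator large long after the fast first moment has forgotten the spike.

```python
# After one large gradient spike, v (decay 0.999) remembers it for
# ~1000 steps, while m (decay 0.9) forgets it within tens of steps,
# so subsequent updates stay suppressed.
def adam_updates(beta1, beta2, grads, eps=1e-8):
    m = v = 0.0
    out = []
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        out.append(m / (v ** 0.5 + eps))
    return out

grads = [100.0] + [0.01] * 200  # one spike, then consistently tiny gradients
updates = adam_updates(0.9, 0.999, grads)
print(updates[50])  # still far below 1: m has forgotten the spike, v has not
```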

Adam and memory

The Adam optimizer stores two floating-point values (the moment estimates $m_t$ and $v_t$) per parameter. While the forward and backward pass can often be computed in 16-bit precision, the Adam optimizer state requires 32 bits for each of $m_t$ and $v_t$. (Storing optimizer state in 16-bit precision leads to numerical instability and degrades training performance.) In a typical training run, memory usage is dominated by per-parameter costs: 4 bytes for the (full-precision) master copy of the parameter, 2 for the 16-bit downcasted parameter, 2 for the gradient, and 8 for the optimizer state – so the optimizer is responsible for about half the total memory usage.
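The byte accounting above, spelled out (assuming the mixed-precision setup just described):

```python
# Bytes per parameter in a typical mixed-precision training run.
master_param = 4  # fp32 master copy of the parameter
param_16bit = 2   # 16-bit downcast used in forward/backward
grad_16bit = 2    # 16-bit gradient
adam_state = 8    # fp32 m and v (4 bytes each)

total = master_param + param_16bit + grad_16bit + adam_state
print(total, adam_state / total)  # 16 bytes total; optimizer state is half
```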

Wouldn’t it be great if we could use less?

Adam uses 8 bytes per parameter. Here is an optimizer that uses just 1.

GradSign

GradSign is a simple proof-of-concept optimizer that:

  • only uses the sign (+ or -) of the gradient of each parameter, and
  • only keeps one byte of optimizer data, per parameter.

The update code is as follows:

def step(self):
    for n, p in self.named_params:
        # Sign of this parameter's gradient: +1 if positive, else -1.
        new_count = 2 * (p.grad > 0) - 1
        # Decay the running count toward zero (integer EMA, beta = 7/8)...
        self.grad_counts[n] -= (self.grad_counts[n] + 4) // 8
        # ...then mix in the latest sign.
        self.grad_counts[n] += 8 * new_count

        # grad_counts lies in roughly [-64, 64], so each step is at most lr.
        p.data -= self.lr * (self.grad_counts[n] / 64.0)

The dictionary self.grad_counts stores, for each parameter, a quantized exponential moving average of the sign of the gradient over the past several timesteps. The moving average has a characteristic timescale of 8 timesteps (in other words, $\beta = 0.875$). The quantized moving average is scaled to be between $-64$ and $64$, so that it fits within a signed 8-bit integer.

I chose these parameters to control quantization error in the decay term

self.grad_counts[n] -= (self.grad_counts[n] + 4) // 8

The decay term will be nonzero as soon as the value self.grad_counts exceeds $\pm 4$. Each update to the average causes the count to change by $\pm 8$, so the grad_counts parameter cannot get stuck in the no-decay region.
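We can check both claims directly by simulating the worst case (a sketch of my own): a gradient that is positive at every single step.

```python
# Simulate the quantized sign-EMA when the gradient is positive at every
# step: the count climbs to a fixed point of 60 and never leaves the
# signed 8-bit range claimed above.
count = 0
history = []
for _ in range(1000):
    count -= (count + 4) // 8  # decay term
    count += 8 * 1             # new sign is always +1
    history.append(count)

print(max(history))  # 60, comfortably within [-64, 64]
```

By symmetry, an always-negative gradient drives the count to the corresponding negative fixed point, so the state always fits in a signed 8-bit integer.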

Experiment

The experiment code and detailed results are posted on GitHub.

In summary, I run GradSign on 1000 batches of 32 samples each, which amounts to a single pass through just over half of the dataset. The resulting model achieves over 98% accuracy.