<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="/feed.xml" rel="self" type="application/atom+xml" /><link href="/" rel="alternate" type="text/html" /><updated>2026-04-20T23:17:06+00:00</updated><id>/feed.xml</id><title type="html">Brian Lawrence</title><subtitle>doing stuff with computers</subtitle><entry><title type="html">Optimizing matrix multiplication</title><link href="/opt/" rel="alternate" type="text/html" title="Optimizing matrix multiplication" /><published>2026-04-01T00:00:00+00:00</published><updated>2026-04-01T00:00:00+00:00</updated><id>/opt</id><content type="html" xml:base="/opt/"><![CDATA[<h1 id="optimizing-matrix-multiplication-on-an-rtx-3050">Optimizing matrix multiplication on an RTX 3050</h1>

<p>I was working on <a href="https://github.com/brian-lawrence-math/mytorch-draft">mytorch</a>, a mock tensor library for GPUs,
and I decided to spend some time optimizing its matrix multiplication function.
After all, matrix multiplication is the most expensive operation in deep learning,
and optimizing batched matrix multiplication is something of a rite of passage in high-performance computing.</p>

<p>If you’re writing a production tensor library like PyTorch,
your code needs to be fast in many different settings:</p>
<ul>
  <li>on CPUs and GPUs of various architectures</li>
  <li>with tensors (batches of matrices) of various sizes</li>
  <li>with matrices in row- or column-major order</li>
  <li>with matrices that might not be contiguous in memory.</li>
</ul>

<p>That’s a lot of engineering.  I decided to start by optimizing a single task:
I want to see how fast I can perform a batched matrix
multiplication on two 32-bit float tensors of shape (100, 1000, 1000), 
in row-major order, on an RTX 3050 GPU.</p>
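<p>For concreteness, here is the computation being benchmarked, written as a hypothetical host-side C++ reference. It is a sketch for checking results, not code from mytorch, and <code class="language-plaintext highlighter-rouge">batched_matmul_ref</code> and <code class="language-plaintext highlighter-rouge">example_entry</code> are names made up for this example.
It spells out the row-major indexing: in a contiguous tensor of shape (batch, M, K), entry [p][i][k] lives at offset (p*M + i)*K + k.
The functions are <code class="language-plaintext highlighter-rouge">constexpr</code> only so that the tiny worked example can be checked at compile time.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">// Hypothetical reference implementation (host C++, not the GPU kernel).
// All tensors are contiguous and row-major:
//   a: (batch, M, K), b: (batch, K, N), c: (batch, M, N)
constexpr void batched_matmul_ref(const float* a, const float* b, float* c,
                                  long long batch, long long M, long long K, long long N) {
  for (long long p = 0; p != batch; ++p)
    for (long long i = 0; i != M; ++i)
      for (long long j = 0; j != N; ++j) {
        float sum = 0.0f;
        for (long long k = 0; k != K; ++k)
          sum += a[(p * M + i) * K + k] * b[(p * K + k) * N + j];  // a[p][i][k] * b[p][k][j]
        c[(p * M + i) * N + j] = sum;                               // c[p][i][j]
      }
}

// Tiny worked example: two batches of 2x2 matrices; returns entry c[p][i][j].
constexpr float example_entry(int p, int i, int j) {
  const float a[8] = {1, 2, 3, 4,   0, 1, 1, 0};
  const float b[8] = {5, 6, 7, 8,   2, 0, 0, 3};
  float c[8] = {};
  batched_matmul_ref(a, b, c, 2, 2, 2, 2);
  return c[(p * 2 + i) * 2 + j];
}</code></pre></figure>

<p>In the worked example, <code class="language-plaintext highlighter-rouge">example_entry(0, 0, 0)</code> is 1*5 + 2*7 = 19.</p>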

<h2 id="useful-reading">Useful reading</h2>

<p>I drew inspiration from a couple of terrific
<a href="https://www.aleksagordic.com/blog/matmul">blog</a>
<a href="https://siboehm.com/articles/22/CUDA-MMM">posts</a>.</p>

<h2 id="some-benchmarks">Some benchmarks</h2>

<p>CUDA provides a highly optimized library for the Basic Linear Algebra Subprograms (BLAS), called cuBLAS.
Batched matrix multiplication is one of the standard BLAS operations.
Using cuBLAS, our computation takes 68-72 ms.
(As always, the benchmark varies slightly from run to run, and somewhat more from day to day.)</p>

<p>My first attempt, with no optimizations, took 840 ms, about 12 times slower than cuBLAS.</p>

<p>As of this writing, I’ve sped the calculation up to 72 ms,
achieving parity (or very close to it) with cuBLAS.
I’m working to get further improvement.</p>

<h2 id="baseline-analysis">Baseline analysis</h2>

<p>The RTX 3050 has 18 streaming multiprocessors with 128 cores each, for a total of 2304 cores.
It runs at a boost clock speed of 1.47 GHz, for a total of about 3.39 trillion core-cycles per second.
Each core can do one fused multiply-add (two floating-point operations) per clock cycle,
so the chip can achieve about 6.77 TFLOPS.</p>

<p>The matrix multiplication requires 2*100*1000*1000*1000 = 200 billion floating-point operations.
At 6.774 TFLOPS (and with zero overhead), that takes about 30 ms (29.5 ms, to be precise).
I’ll take 30 ms as the theoretical best time.
Of course, this doesn’t take into account the time required for memory access
or any other computational overhead.</p>
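<p>The arithmetic in the last two paragraphs can be double-checked with a few compile-time constants. This is a hypothetical sketch that uses only the figures quoted in this post: 18 SMs, 128 cores per SM, a 1.47 GHz boost clock, and two FLOPs per core per cycle. The names are made up for illustration.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">// Back-of-the-envelope peak throughput, using the figures quoted above.
constexpr long long kCores = 18 * 128;       // 18 SMs x 128 FP32 cores = 2304 cores
constexpr double kBoostClockHz = 1.47e9;     // 1.47 GHz boost clock
constexpr double kFlopsPerCycle = 2.0;       // one fused multiply-add per core per cycle

constexpr double kPeakFlops = kCores * kBoostClockHz * kFlopsPerCycle;  // about 6.77e12

// One multiply and one add for every (batch, row, col, k) step.
constexpr long long matmul_flops(long long batch, long long M, long long N, long long K) {
  return 2 * batch * M * N * K;
}

// Lower bound on runtime in milliseconds, ignoring memory traffic and all other overhead.
constexpr double ideal_time_ms(long long flops, double peak_flops) {
  return 1000.0 * flops / peak_flops;
}</code></pre></figure>

<p>For the benchmark shapes, <code class="language-plaintext highlighter-rouge">ideal_time_ms(matmul_flops(100, 1000, 1000, 1000), kPeakFlops)</code> comes to about 29.5 ms.</p>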

<p>In my experiments, cuBLAS takes just over twice this estimated theoretical best time.</p>

<h2 id="a-simple-first-attempt">A simple first attempt</h2>

<p>I put together a simple CUDA C++ implementation of batched matrix multiplication on PyTorch-style tensors,
with arbitrary numbers of dimensions, shapes and strides.</p>

<p>I assigned each thread to compute a single entry of the output tensor
(so our example will spin up 100 million threads),
and I packed 1024 threads per block, the maximum value (for a total of about 98,000 blocks).</p>
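<p>Here is the launch arithmetic as a hypothetical sketch. Because 10^8 is not a multiple of 1024, the last block is only partly used, so the kernel has to skip threads whose index falls past the end. The post doesn’t show how the original kernel maps a thread to its output entry, so <code class="language-plaintext highlighter-rouge">unflatten</code> below is one plausible mapping, and it is an assumption.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">// Launch-size arithmetic for the naive one-thread-per-output-entry kernel.
constexpr long long ceil_div(long long a, long long b) { return (a + b - 1) / b; }

constexpr long long kOutputs = 100LL * 1000 * 1000;                  // 1e8 entries, one thread each
constexpr long long kThreadsPerBlock = 1024;                         // the per-block maximum
constexpr long long kBlocks = ceil_div(kOutputs, kThreadsPerBlock);  // 97,657 blocks

// The last block is only partly used, so the kernel must skip idx >= kOutputs.
constexpr long long kIdleThreads = kBlocks * kThreadsPerBlock - kOutputs;

// Hypothetical mapping from a thread's flat index
// (idx = blockIdx.x * blockDim.x + threadIdx.x on the GPU)
// to (batch, row, col) in an output of shape (batch, M, N).
struct Coord { long long batch, row, col; };

constexpr Coord unflatten(long long idx, long long M, long long N) {
  return Coord{idx / (M * N), (idx / N) % M, idx % N};
}</code></pre></figure>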

<p>The calculation took about 840 ms: 28 times the theoretical best time, and 12 times slower than cuBLAS.</p>

<p>Looks like we have some optimizing to do.</p>

<h2 id="first-thoughts">First thoughts</h2>

<p>The kernel is probably memory-bound, not compute-bound.
It requires a total of 200 billion floating-point operations,
but 800 GB of memory reads.
(At each step in the matrix-multiplication loop,
the kernel reads one float from each matrix and does one multiplication and one addition:
two FLOPs and 8 bytes.)
Memory reads are much more expensive than compute:
the chip can compute about 6.8 TFLOPS, but its memory bandwidth is only about 168 GB/s.
Caching helps a lot (it has to: streaming 800 GB from DRAM at 168 GB/s would take almost 5 seconds,
far longer than the 840 ms measured above),
but I still expect memory to be the bottleneck.</p>
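<p>The numbers in the previous paragraph can be checked with a few lines of arithmetic. This hypothetical sketch also computes the arithmetic intensity (FLOPs per byte read) the chip would need to be compute-bound: peak FLOPS divided by memory bandwidth, using the 6.774 TFLOPS and 168 GB/s figures from this post.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">// Naive kernel: each of the batch*M*N*K inner-loop steps reads one float of a and one of b.
constexpr long long naive_bytes(long long batch, long long M, long long N, long long K) {
  return batch * M * N * K * 2 * 4;   // 2 floats x 4 bytes per step
}

constexpr double kPeakFlops = 6.774e12;         // FP32 peak estimated above
constexpr double kDramBytesPerSecond = 168e9;   // memory bandwidth quoted above

// Seconds needed to stream that many bytes from DRAM with no cache hits at all.
constexpr double dram_seconds(long long bytes) { return bytes / kDramBytesPerSecond; }

// FLOPs per byte in the naive loop, versus FLOPs per byte needed to keep the cores busy.
constexpr double kNaiveIntensity = 2.0 / 8.0;                                 // 0.25 FLOP/byte
constexpr double kComputeBoundIntensity = kPeakFlops / kDramBytesPerSecond;   // about 40 FLOP/byte</code></pre></figure>

<p>The naive loop sits about 160 times below that threshold, so reusing data once it has been loaded matters far more than raw compute here.</p>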

<p>The natural idea is to try thread coarsening or tiling.
I want to load some input data once and compute on it many times.
If one thread is responsible for computing a small box of matrix entries, rather than just one,
that thread can use its memory access more efficiently.
And by using shared memory
(low-latency memory which is shared among all threads in a block)
I can arrange for several threads to collaborate using data that is only loaded (from slow global memory) once.</p>
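<p>To put a number on the benefit of tiling, here is a hypothetical sketch that counts global-memory traffic when each thread block computes a BM-by-BN tile of the output and reads the rows of a and columns of b it needs exactly once. It ignores caches and counts partial tiles at the edges as full tiles, so it is an idealized estimate rather than a prediction.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">// Idealized global-memory traffic when each thread block computes a BM x BN tile of the
// output and reads the BM rows of a and the BN columns of b it needs exactly once.
// Caches are ignored; partial tiles at the edges are counted as if they were full.
constexpr long long ceil_div(long long a, long long b) { return (a + b - 1) / b; }

constexpr long long tiled_bytes(long long batch, long long M, long long N, long long K,
                                long long BM, long long BN) {
  long long blocks = batch * ceil_div(M, BM) * ceil_div(N, BN);
  long long floats_per_block = BM * K + K * BN;   // BM rows of a plus BN columns of b
  return blocks * floats_per_block * 4;           // 4 bytes per float
}</code></pre></figure>

<p>With 1-by-1 tiles this reproduces the 800 GB of the naive kernel. With 128-by-128 tiles it drops to about 6.6 GB, roughly 120 times less.</p>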

<p>But first I want to pick some low-hanging fruit.</p>

<h2 id="cutting-down-on-shape-and-stride-calculations">Cutting down on shape-and-stride calculations</h2>

<p>The kernel includes logic to handle arbitrary shapes and strides.
I think this logic is imposing a lot of unnecessary cost.</p>

<p>To start with, the shape-and-stride calculation is being done in the kernel:
each of my 100 million threads is reading the shapes and strides of the input tensors
and computing the index of the one entry it needs to access.
That’s a lot of repeated calculation.
Worse yet, the shapes and strides are stored in vectors
whose length (the dimension of the tensor) is unknown at compile time.
This means that the vectors are stored in global memory, resulting in
a lot of unnecessary memory access.
(Actually, since these values are accessed so often, 
they are probably stored in a low-level cache…)</p>

<p>Most matrix multiplication in practice works on batched matrices with a simple structure:
the matrices are contiguous in memory, in row-major order;
and the “matrix dimensions” are the two dimensions with the smallest strides.
So, I’ll try to optimize a simple case first:
a batched matrix product of two three-dimensional (batch, row, col) tensors.</p>

<p>So I wrote a new kernel, <code class="language-plaintext highlighter-rouge">matmul_3d()</code>,
that assumes its inputs are contiguous three-dimensional tensors
and takes the shape directly as an argument.
The result: from 840 ms down to 800 ms.</p>
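<p>To see where the savings come from, here is a hypothetical host-side comparison of the two ways of computing an element’s offset. The generic version loops over per-dimension index and stride arrays whose length is only known at run time. The specialized version is a fixed formula with no loop and no array reads. Neither function is the mytorch code; the names are made up for this sketch.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">// Generic: walk index/stride arrays whose length (ndim) is only known at run time.
constexpr long long strided_offset(const long long* index, const long long* strides, int ndim) {
  long long offset = 0;
  for (int d = 0; d != ndim; ++d) offset += index[d] * strides[d];
  return offset;
}

// Specialized: a contiguous row-major (batch, row, col) tensor with M rows and K columns.
constexpr long long contiguous_offset_3d(long long b, long long i, long long k,
                                         long long M, long long K) {
  return (b * M + i) * K + k;
}

// For a contiguous tensor the strides are (M * K, K, 1), so both give the same offset.
constexpr bool offsets_agree(long long b, long long i, long long k, long long M, long long K) {
  const long long index[3] = {b, i, k};
  const long long strides[3] = {M * K, K, 1};
  return strided_offset(index, strides, 3) == contiguous_offset_3d(b, i, k, M, K);
}</code></pre></figure>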

<p>OK, let’s try to get more economies of scale.</p>

<h2 id="profiling">Profiling</h2>
<p>Nvidia provides a powerful profiler, Nsight Compute (<code class="language-plaintext highlighter-rouge">ncu</code> on the command line).
The profiler shows all sorts of metrics, including memory throughput, cache (L1 and L2) throughput, compute throughput, occupancy and workload statistics…
It even offers suggestions for optimization.</p>

<p>In my experience, at this stage, 
it’s best to focus on writing efficient code.
The profiler is a great tool later on, when I have specific questions about resource usage.</p>

<p>Indeed, this profiler’s first suggestion is that I change the number of threads per block 
to increase occupancy.
That’s not the biggest priority at this point – and indeed,
if I make the change the profiler suggests, performance gets worse.</p>

<p>So: I’m going to ignore ncu’s advice and get back to writing solid code.
The big bottleneck is memory access, so that’s where I’m going to focus.
But the profiler will come back later on.</p>

<h2 id="improving-memory-efficiency">Improving memory efficiency</h2>

<p>To start with, I’m going to make two improvements to the kernel.</p>
<ul>
  <li>Load inputs into shared memory, in batches, and</li>
  <li>make each thread responsible for more than one output entry.</li>
</ul>

<p>I’m going to make configurable parameters for:</p>
<ul>
  <li>TM and TN – these determine how many output values each thread will calculate;</li>
  <li>TPBM and TPBN – these determine how many threads there are in a block, along the rows and columns; and</li>
  <li>BK – the multiplication loop size.</li>
</ul>

<p>This last parameter needs some explanation.
Each block of threads is responsible for computing a (TM*TPBM) by (TN*TPBN) submatrix
of the output matrix.
To do this it needs the corresponding TM*TPBM full rows of the first input
and TN*TPBN full columns of the second;
then it loops over the columns of the first input (and the rows of the second).</p>

<p>There might not be enough room in shared memory for all the rows and columns that need to be loaded.
So, instead of being loaded in full, they will be loaded in blocks of BK.
In other words: matrix multiplication involves a summation over the intermediate dimension;
we will break that summation into chunks of size BK,
and compute partial sums one chunk at a time.</p>
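<p>The chunked summation can be checked on the host. This is a hypothetical sketch, not the kernel. <code class="language-plaintext highlighter-rouge">matmul_chunked</code> walks the output in BM-by-BN tiles the way thread blocks do, splits the k-sum into chunks of BK, and adds each chunk’s partial sum into the result. It skips out-of-range indices, which has the same effect as zero-padding the shared tiles. Small integer matrices keep the comparison exact, and the sizes are deliberately not multiples of the tile sizes.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">// Hypothetical host-side helpers. All matrices are row-major:
// A is M x K, B is K x N, C is M x N.
constexpr void matmul_naive(const int* A, const int* B, int* C, int M, int K, int N) {
  for (int i = 0; i != M; ++i)
    for (int j = 0; j != N; ++j) {
      int sum = 0;
      for (int k = 0; k != K; ++k) sum += A[i * K + k] * B[k * N + j];
      C[i * N + j] = sum;
    }
}

// Same product, organized like the tiled kernel: BM x BN output tiles, with the k-sum
// split into chunks of BK. Each chunk adds a partial sum into C.
constexpr void matmul_chunked(const int* A, const int* B, int* C, int M, int K, int N,
                              int BM, int BN, int BK) {
  for (int e = 0; e != M * N; ++e) C[e] = 0;
  for (int i0 = 0; M > i0; i0 += BM)            // one "thread block" per output tile
    for (int j0 = 0; N > j0; j0 += BN)
      for (int k0 = 0; K > k0; k0 += BK)        // one shared-memory chunk at a time
        for (int i = i0; i != i0 + BM; ++i)
          for (int j = j0; j != j0 + BN; ++j) {
            if (i >= M || j >= N) continue;     // outside the matrix: nothing to compute
            int partial = 0;
            for (int k = k0; k != k0 + BK; ++k)
              if (K > k) partial += A[i * K + k] * B[k * N + j];  // padded entries count as 0
            C[i * N + j] += partial;
          }
}

// Compare both on a 5x7 times 7x6 example.
constexpr bool chunked_matches_naive(int BM, int BN, int BK) {
  int A[5 * 7] = {}, B[7 * 6] = {}, C1[5 * 6] = {}, C2[5 * 6] = {};
  for (int e = 0; e != 5 * 7; ++e) A[e] = e % 7 - 3;
  for (int e = 0; e != 7 * 6; ++e) B[e] = (e * 5) % 11 - 5;
  matmul_naive(A, B, C1, 5, 7, 6);
  matmul_chunked(A, B, C2, 5, 7, 6, BM, BN, BK);
  for (int e = 0; e != 5 * 6; ++e)
    if (C1[e] != C2[e]) return false;
  return true;
}</code></pre></figure>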

<h2 id="improving-memory-access-patterns">Improving memory access patterns</h2>

<p>Global memory is stored in DRAM; shared memory (and the L1 cache) are stored in SRAM.
DRAM reads memory in consecutive 32-byte chunks; if I don’t use all 32 bytes, I’m wasting bandwidth.
SRAM is divided into 32 banks, each 4 bytes wide (so, for example, bank 0 is responsible for byte addresses 0, 1, 2, 3 modulo 128).
In a single read, SRAM can read any 4-byte word from each of its 32 banks, independently.
So we want each thread in a warp to read from a different bank.
If multiple threads request different words from the same bank, the result is a “bank conflict”:
the SRAM will have to perform multiple physical reads before the result can be returned.
(Threads that request the very same word are fine: that word is broadcast to all of them.)</p>

<p>In both situations, a good pattern is for the 32 threads in a warp to access consecutive floats in memory:</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="kt">int</span> <span class="n">idx</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>     <span class="c1">// consecutive threads get consecutive indices</span>
<span class="kt">float</span> <span class="n">x</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="n">idx</span><span class="p">];</span>  <span class="c1">// so a warp reads 32 consecutive floats</span></code></pre></figure>
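<p>Both situations can be modelled with a little arithmetic. The hypothetical helpers below take one warp in which lane t accesses the float at index base + t * stride. They count the 32-byte sectors touched in global memory and the worst-case bank conflict in shared memory. These are simplified models (4-byte banks, plain 32-byte sectors), not an exact description of the hardware.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">// Hypothetical models of one warp (32 lanes); lane t accesses float index base + t * stride.

// Global memory: count the distinct 32-byte sectors the warp touches.
constexpr int sectors_touched(long long base_float, long long stride_floats) {
  int count = 0;
  for (int t = 0; t != 32; ++t) {
    long long sector = 4 * (base_float + t * stride_floats) / 32;
    bool seen = false;
    for (int u = 0; u != t; ++u)
      if (4 * (base_float + u * stride_floats) / 32 == sector) seen = true;
    if (!seen) ++count;
  }
  return count;
}

// Shared memory: 32 banks, each 4 bytes wide, so float index f lives in bank f % 32.
// Lanes reading the very same word are served by a broadcast, so only distinct words
// landing in the same bank cost extra reads. Returns the worst n-way conflict.
constexpr int conflict_degree(long long base_float, long long stride_floats) {
  int worst = 0;
  for (int bank = 0; bank != 32; ++bank) {
    int distinct_words = 0;
    for (int t = 0; t != 32; ++t) {
      long long word = base_float + t * stride_floats;
      if (word % 32 != bank) continue;
      bool repeat = false;
      for (int u = 0; u != t; ++u)
        if (base_float + u * stride_floats == word) repeat = true;
      if (!repeat) ++distinct_words;
    }
    if (distinct_words > worst) worst = distinct_words;
  }
  return worst;   // 1 means a single physical read serves the whole warp
}</code></pre></figure>

<p>With stride 1 the warp touches 4 sectors (5 if the start is off by one float) and has no bank conflicts. With stride 32 every lane hits bank 0, a 32-way conflict. A stride of 33 (a row padded by one float) is conflict-free again.</p>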

<p>First, I’ll make sure the number of rows in each block of threads is 32 (at least when the matrices have &gt;= 32 rows);
this means each warp is exactly one row.</p>

<p>Now let’s plan how to arrange memory and threads.  As for memory:</p>
<ul>
  <li>The input tensors are already laid out contiguously in row-major order, and we can’t change that;</li>
  <li>The result tensor is also in row-major order; we can’t touch it either;</li>
  <li>But the “shared” tensors (copies of tiles of input tensors that reside in shared memory) can be arranged how we like.</li>
</ul>

<p>As for thread arrangement, we have to decide how to divide each of these three operations among the threads in the block:</p>
<ul>
  <li>Copy input tensor a into shared memory;</li>
  <li>Copy input tensor b into shared memory;</li>
  <li>Perform the “multiplication loop” to compute entries of output tensor.</li>
</ul>

<p>Let’s start with the entries of the output tensor.  Conceptually it looks something like the following.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="kt">float</span> <span class="n">cml_sum</span> <span class="o">=</span> <span class="mf">0.0</span><span class="n">f</span><span class="p">;</span>
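<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">loop_idx</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">loop_idx</span> <span class="o">&lt;</span> <span class="n">BK</span><span class="p">;</span> <span class="n">loop_idx</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>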
    <span class="n">cml_sum</span> <span class="o">+=</span> <span class="n">a_shared</span><span class="p">[</span><span class="n">row</span><span class="p">][</span><span class="n">loop_idx</span><span class="p">]</span> <span class="o">*</span> <span class="n">b_shared</span><span class="p">[</span><span class="n">loop_idx</span><span class="p">][</span><span class="n">col</span><span class="p">];</span>
<span class="p">}</span>
<span class="n">result</span><span class="p">[</span><span class="n">row</span><span class="p">][</span><span class="n">col</span><span class="p">]</span> <span class="o">+=</span> <span class="n">cml_sum</span><span class="p">;</span></code></pre></figure>

<p>A natural choice is to have consecutive threads operate on the same “row” and consecutive values of “col”.
This way the global memory writes at line 5 are efficient: all 32 threads in the warp
write to one 128-byte line of global memory (or two lines, if the alignment isn’t right).
Assuming b_shared is stored in row-major order, the shared memory reads in line 3 are good as well.
All threads read the same entry from a_shared, which is efficient (it’s called “broadcasting”),
and the 32 threads read 32 consecutive entries of b_shared.</p>

<p>As for copying global ‘a’ and ‘b’ into shared ‘a_shared’ and ‘b_shared’: it’s the same idea.
I store ‘a_shared’ and ‘b_shared’ in row-major order, so data that is contiguous in ‘a’ 
is also contiguous in ‘a_shared’.
Then I arrange for all the threads in the block to handle consecutive floats, one float each.</p>

<h2 id="thread-local-results-in-registers">Thread-local results in registers</h2>

<p>The kernel is still suffering from global memory writes inside a tight loop:
if you look back up at the “conceptual” code, the line</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="n">result</span><span class="p">[</span><span class="n">row</span><span class="p">][</span><span class="n">col</span><span class="p">]</span> <span class="o">+=</span> <span class="n">cml_sum</span><span class="p">;</span></code></pre></figure>

<p>entails a read and a write to slow global DRAM.</p>

<p>Better: store per-thread results in registers until the calculation is done.
In pseudocode:</p>

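<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="kt">float</span> <span class="n">tmp</span><span class="p">[</span><span class="n">TM</span><span class="p">][</span><span class="n">TN</span><span class="p">]</span> <span class="o">=</span> <span class="p">{};</span>  <span class="c1">// per-thread accumulators, held in registers</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k0</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k0</span> <span class="o">&lt;</span> <span class="n">K</span><span class="p">;</span> <span class="n">k0</span> <span class="o">+=</span> <span class="n">BK</span><span class="p">)</span> <span class="p">{</span>
    <span class="c1">// load the next chunk of A and B into A_s and B_s, then __syncthreads()</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">BK</span><span class="p">;</span> <span class="n">k</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">m_ctr</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">m_ctr</span> <span class="o">&lt;</span> <span class="n">TM</span><span class="p">;</span> <span class="n">m_ctr</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
            <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">n_ctr</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">n_ctr</span> <span class="o">&lt;</span> <span class="n">TN</span><span class="p">;</span> <span class="n">n_ctr</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
                <span class="c1">// m, n: the row of A_s and column of B_s this thread owns for (m_ctr, n_ctr)</span>
                <span class="n">tmp</span><span class="p">[</span><span class="n">m_ctr</span><span class="p">][</span><span class="n">n_ctr</span><span class="p">]</span> <span class="o">+=</span> <span class="n">A_s</span><span class="p">[</span><span class="n">m</span><span class="p">][</span><span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">B_s</span><span class="p">[</span><span class="n">k</span><span class="p">][</span><span class="n">n</span><span class="p">];</span>
            <span class="p">}</span>
        <span class="p">}</span>
    <span class="p">}</span>
    <span class="c1">// __syncthreads() again before the next chunk overwrites A_s and B_s</span>
<span class="p">}</span>

<span class="c1">// only after the whole loop: one global write per output entry</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">m_ctr</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">m_ctr</span> <span class="o">&lt;</span> <span class="n">TM</span><span class="p">;</span> <span class="n">m_ctr</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">n_ctr</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">n_ctr</span> <span class="o">&lt;</span> <span class="n">TN</span><span class="p">;</span> <span class="n">n_ctr</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">result</span><span class="p">[...]</span> <span class="o">=</span> <span class="n">tmp</span><span class="p">[</span><span class="n">m_ctr</span><span class="p">][</span><span class="n">n_ctr</span><span class="p">];</span>
    <span class="p">}</span>
<span class="p">}</span></code></pre></figure>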
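<p>The pseudocode leaves open which entries each thread owns. This hypothetical sketch uses the layout described in a comment in the full kernel later in this post (rows strided every TPBM, columns strided every TPBN) together with the parameter values chosen in the next section. It checks that the threads’ register tiles cover a block’s BM-by-BN output tile exactly once. The helper names are made up for illustration.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">// Assumed layout: thread (tr, tc) owns output rows tr + m * TPBM and columns
// tc + n * TPBN of its block's BM x BN tile, for m in [0, TM) and n in [0, TN).
constexpr int TM = 8, TN = 8, TPBM = 16, TPBN = 16;
constexpr int BM = TM * TPBM, BN = TN * TPBN;

// Output row and column held by thread (tr, tc) in accumulator tmp[m][n].
constexpr int owned_row(int tr, int m) { return tr + m * TPBM; }
constexpr int owned_col(int tc, int n) { return tc + n * TPBN; }

// True if every entry of the tile is owned by exactly one (thread, accumulator) pair,
// i.e. the register tiles cover the block tile with no gaps and no overlaps.
constexpr bool register_tiles_partition_block() {
  int owners[BM * BN] = {};
  for (int tr = 0; tr != TPBM; ++tr)
    for (int tc = 0; tc != TPBN; ++tc)
      for (int m = 0; m != TM; ++m)
        for (int n = 0; n != TN; ++n)
          ++owners[owned_row(tr, m) * BN + owned_col(tc, n)];
  for (int e = 0; e != BM * BN; ++e)
    if (owners[e] != 1) return false;
  return true;
}</code></pre></figure>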

<h2 id="parameter-tuning">Parameter tuning</h2>

<p>At this point I have a bunch of different parameters:</p>
<ul>
  <li>the dimensions TM, TN of each thread’s tile;</li>
  <li>the number of threads per block in the row and column dimensions (TPBM, TPBN);</li>
  <li>the number of values of $k$ in each chunk of the loop (BK).</li>
</ul>

<p>After some manual experimentation with these parameters, I settled on the following values.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="cp">#define TM (8)
#define TN (8)
#define TPBM (16)
#define TPBN (16)
#define BM (TM * TPBM)
#define BN (TN * TPBN)
#define BK (16)</span></code></pre></figure>
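<p>One quick way to sanity-check a parameter choice is to compute what it implies for each block. This hypothetical sketch restates the values above as typed constants rather than macros. It then derives the threads per block, the accumulators each thread keeps in registers, the shared memory used by A_s and B_s, and the grid size for the benchmark shapes.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">// The tuned parameters above, restated as constants.
constexpr int TM = 8, TN = 8, TPBM = 16, TPBN = 16, BK = 16;
constexpr int BM = TM * TPBM;   // 128 output rows per thread block
constexpr int BN = TN * TPBN;   // 128 output columns per thread block

// What these imply for each thread block.
constexpr int kThreadsPerBlock = TPBM * TPBN;                          // 256 threads = 8 warps
constexpr int kAccumulatorsPerThread = TM * TN;                        // 64 floats of tmp per thread
constexpr int kSharedBytes = (BM * BK + BK * BN) * int(sizeof(float)); // A_s plus B_s = 16 KiB
constexpr int kFmasPerThreadPerChunk = TM * TN * BK;                   // work between refills

// Grid size for the (100, 1000, 1000) benchmark: ceil(1000 / 128) = 8 tiles per side.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }
constexpr int kBlocks = 100 * ceil_div(1000, BM) * ceil_div(1000, BN); // 6400 blocks</code></pre></figure>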

<p>My batched matrix multiplication is now at 107 ms.
That’s still about 1.5x the time cuBLAS takes, but I have one more trick up my sleeve.</p>

<h2 id="vectorized-loads">Vectorized loads</h2>

<p>NVIDIA GPUs offer a single instruction to read 128 bits (= 16 bytes = 4 floats) at once.
The catch is that the read address has to be 16-byte aligned.
If you promise the compiler that your address is 16-byte aligned,
the compiler will give you vectorized reads and writes.
One easy way to do this, syntactically, is with the <code class="language-plaintext highlighter-rouge">float4</code> type:
a struct of four consecutive floats (<code class="language-plaintext highlighter-rouge">x</code>, <code class="language-plaintext highlighter-rouge">y</code>, <code class="language-plaintext highlighter-rouge">z</code>, <code class="language-plaintext highlighter-rouge">w</code>)
that is aligned to a multiple of 16 bytes.</p>

<p>Here’s some mock code to show how to do it: if you’re copying a bunch of floats</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="kt">float</span> <span class="o">*</span><span class="n">A</span><span class="p">,</span> <span class="o">*</span><span class="n">B</span><span class="p">;</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">i</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span><span class="p">[</span><span class="n">i</span><span class="p">];</span>
<span class="p">}</span></code></pre></figure>

<p>you can simply cast the pointers from <code class="language-plaintext highlighter-rouge">float*</code> to <code class="language-plaintext highlighter-rouge">float4*</code>, like so:</p>

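<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="c1">// gentle reminder that both pointers need to be 16-byte aligned</span>
<span class="n">assert</span><span class="p">((</span><span class="kt">uintptr_t</span><span class="p">)</span> <span class="n">A</span> <span class="o">%</span> <span class="mi">16</span> <span class="o">==</span> <span class="mi">0</span><span class="p">);</span>
<span class="n">assert</span><span class="p">((</span><span class="kt">uintptr_t</span><span class="p">)</span> <span class="n">B</span> <span class="o">%</span> <span class="mi">16</span> <span class="o">==</span> <span class="mi">0</span><span class="p">);</span>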

<span class="c1">// and that I'm not handling edge effects</span>
<span class="n">assert</span><span class="p">(</span><span class="n">N</span> <span class="o">%</span> <span class="mi">4</span> <span class="o">==</span> <span class="mi">0</span><span class="p">);</span>

<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span> <span class="n">i</span> <span class="o">+=</span> <span class="mi">4</span><span class="p">)</span> <span class="p">{</span>
    <span class="o">*</span><span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="n">float4</span><span class="o">*&gt;</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o">=</span> <span class="o">*</span><span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="n">float4</span><span class="o">*&gt;</span><span class="p">(</span><span class="o">&amp;</span><span class="n">B</span><span class="p">[</span><span class="n">i</span><span class="p">]);</span>
<span class="p">}</span></code></pre></figure>

<p>In other words, we’ve replaced $N$ assignments <code class="language-plaintext highlighter-rouge">float = float</code>
with $N/4$ assignments <code class="language-plaintext highlighter-rouge">float4 = float4</code>.</p>

<p>And here’s what the change looks like in my kernel.
There’s all sorts of messy indexing going on thanks to tiling in various dimensions,
but notice how simple it is to change plain <code class="language-plaintext highlighter-rouge">float</code> loads to vectorized <code class="language-plaintext highlighter-rouge">float4</code> loads:
I really just multiply the step size by 4 and hit the pointers with <code class="language-plaintext highlighter-rouge">reinterpret_cast</code>.</p>

<p>Before:</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="c1">// A_s[i, j] = A[a_row + i, k + j]</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">a_idx</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">a_idx</span> <span class="o">&lt;</span> <span class="n">BM</span> <span class="o">*</span> <span class="n">BK</span><span class="p">;</span> <span class="n">a_idx</span> <span class="o">+=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
  <span class="kt">size_t</span> <span class="n">i</span> <span class="o">=</span> <span class="n">a_idx</span> <span class="o">/</span> <span class="n">BK</span><span class="p">;</span>
  <span class="kt">size_t</span> <span class="n">j</span> <span class="o">=</span> <span class="n">a_idx</span> <span class="o">%</span> <span class="n">BK</span><span class="p">;</span>
  <span class="k">if</span> <span class="p">(</span><span class="n">k0</span> <span class="o">+</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">K</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">+</span> <span class="n">a_row_base</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span> <span class="o">+</span> <span class="n">k0</span> <span class="o">+</span> <span class="n">j</span><span class="p">];</span>
  <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
    <span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
  <span class="p">}</span>
<span class="p">}</span></code></pre></figure>

<p>After:</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="c1">// A_s[i, j] = A[a_row + i, k + j]</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">a_idx</span> <span class="o">=</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">a_idx</span> <span class="o">&lt;</span> <span class="n">BM</span> <span class="o">*</span> <span class="n">BK</span><span class="p">;</span> <span class="n">a_idx</span> <span class="o">+=</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
  <span class="kt">size_t</span> <span class="n">i</span> <span class="o">=</span> <span class="n">a_idx</span> <span class="o">/</span> <span class="n">BK</span><span class="p">;</span>
  <span class="kt">size_t</span> <span class="n">j</span> <span class="o">=</span> <span class="n">a_idx</span> <span class="o">%</span> <span class="n">BK</span><span class="p">;</span>
  <span class="k">if</span> <span class="p">(</span><span class="n">k0</span> <span class="o">+</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">K</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">+</span> <span class="n">a_row_base</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">)</span> <span class="p">{</span>
    <span class="o">*</span><span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="n">float4</span><span class="o">*&gt;</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span><span class="p">])</span> <span class="o">=</span> <span class="o">*</span><span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="n">float4</span><span class="o">*&gt;</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span> <span class="o">+</span> <span class="n">k0</span> <span class="o">+</span> <span class="n">j</span><span class="p">]);</span>
  <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
    <span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
    <span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
    <span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
    <span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
  <span class="p">}</span>
<span class="p">}</span></code></pre></figure>

<p>Of course, there’s some additional boilerplate to check bounds and alignment.
Nvidia guarantees that cudaMalloc returns 256-byte aligned memory,
but we still need to worry about the dimensions of the matrix (which could be arbitrary).
Row $i$ of a matrix with $K$ columns starts $4iK$ bytes past the start of the matrix,
so every row is 16-byte aligned only when $K$ is a multiple of 4.
We’d also better make sure we choose the parameter $BK$ to be a multiple of 4.
I’m deliberately running this initial benchmark on 1000-by-1000 matrices
so I don’t have to worry about edge effects.</p>
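<p>The alignment argument can be written as a hypothetical host-side check. It mirrors the indexing in the kernel, where each <code class="language-plaintext highlighter-rouge">float4</code> load starts at column k0 + j, with k0 a multiple of BK and j a multiple of 4. The check assumes the batch offset (a multiple of M*K floats) and the cudaMalloc base are already taken care of. The same check, applied to N, covers the second input. The function names are made up for this sketch.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">// Byte offset of element (row, col) of a contiguous row-major float matrix with K columns,
// measured from the (256-byte-aligned) start of the matrix.
constexpr long long byte_offset(long long row, long long col, long long K) {
  return 4 * (row * K + col);
}

// Columns k0 + j are multiples of 4 whenever BK is, so what remains is the start of
// each row. Since 4 * row * K mod 16 repeats every 4 rows, checking rows 0..3 suffices.
constexpr bool float4_loads_aligned(long long K, long long BK) {
  if (BK % 4 != 0) return false;
  for (long long row = 0; row != 4; ++row)
    if (byte_offset(row, 0, K) % 16 != 0) return false;
  return true;
}</code></pre></figure>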

<p>Vectorized loads from global into shared memory bring the runtime down to 72 ms:
we’ve achieved parity with cuBLAS (on a good day).</p>

<h2 id="the-existing-kernel-in-code">The existing kernel, in code</h2>

<p>Here it is.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">matmul_tiled_2</span><span class="p">(</span><span class="n">ContiguousTensor3d_Device</span> <span class="n">a</span><span class="p">,</span>
                               <span class="n">ContiguousTensor3d_Device</span> <span class="n">b</span><span class="p">,</span>
                               <span class="n">ContiguousTensor3d_Device</span> <span class="n">res</span><span class="p">)</span> <span class="p">{</span>
  <span class="c1">// blocks: (batch, row, col)</span>
  <span class="c1">// threads: (n_threads, 1, 1)  -- I will manage indexing myself</span>

  <span class="kt">size_t</span> <span class="n">M</span> <span class="o">=</span> <span class="n">a</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">K</span> <span class="o">=</span> <span class="n">a</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">N</span> <span class="o">=</span> <span class="n">b</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">];</span>
  <span class="kt">size_t</span> <span class="n">a_row_base</span> <span class="o">=</span> <span class="n">BM</span> <span class="o">*</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
  <span class="kt">size_t</span> <span class="n">b_col_base</span> <span class="o">=</span> <span class="n">BN</span> <span class="o">*</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">z</span><span class="p">;</span>
  <span class="kt">size_t</span> <span class="n">batch_idx</span> <span class="o">=</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
  <span class="kt">float</span> <span class="o">*</span><span class="n">A</span> <span class="o">=</span>
      <span class="n">a</span><span class="p">.</span><span class="n">data</span> <span class="o">+</span> <span class="n">batch_idx</span> <span class="o">*</span> <span class="n">a</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">a</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">a_row_base</span> <span class="o">*</span> <span class="n">a</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">];</span>
  <span class="kt">float</span> <span class="o">*</span><span class="n">B</span> <span class="o">=</span> <span class="n">b</span><span class="p">.</span><span class="n">data</span> <span class="o">+</span> <span class="n">batch_idx</span> <span class="o">*</span> <span class="n">b</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">b</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">b_col_base</span><span class="p">;</span>
  <span class="kt">float</span> <span class="o">*</span><span class="n">RES</span> <span class="o">=</span> <span class="n">res</span><span class="p">.</span><span class="n">data</span> <span class="o">+</span> <span class="n">batch_idx</span> <span class="o">*</span> <span class="n">res</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">res</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span>
               <span class="n">a_row_base</span> <span class="o">*</span> <span class="n">res</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">b_col_base</span><span class="p">;</span>

  <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">A_s</span><span class="p">[</span><span class="n">BM</span> <span class="o">*</span> <span class="n">BK</span><span class="p">];</span>
  <span class="n">__shared__</span> <span class="kt">float</span> <span class="n">B_s</span><span class="p">[</span><span class="n">BK</span> <span class="o">*</span> <span class="n">BN</span><span class="p">];</span>

  <span class="c1">// store results in registers</span>
  <span class="c1">// each thread will be responsible for TM * TN entries of C...</span>
  <span class="c1">// TM rows and TN cols, the rows strided every TPBM, the cols strided every</span>
  <span class="c1">// TPBN</span>
  <span class="kt">float</span> <span class="n">tmp</span><span class="p">[</span><span class="n">TM</span> <span class="o">*</span> <span class="n">TN</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="mi">0</span><span class="p">};</span>

  <span class="kt">size_t</span> <span class="n">this_thread_row</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">/</span> <span class="n">TPBN</span><span class="p">;</span>
  <span class="kt">size_t</span> <span class="n">this_thread_col</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">%</span> <span class="n">TPBN</span><span class="p">;</span>

  <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">k0</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k0</span> <span class="o">&lt;</span> <span class="n">K</span><span class="p">;</span> <span class="n">k0</span> <span class="o">+=</span> <span class="n">BK</span><span class="p">)</span> <span class="p">{</span>
    <span class="c1">// load a and b</span>
    <span class="c1">// A_s[i, j] = A[a_row + i, k + j]</span>

    <span class="c1">// values to fill: BM * BK</span>
    <span class="c1">// threads: TPB</span>
    <span class="c1">// want: blockDim.x div. by BK</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">a_idx</span> <span class="o">=</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">a_idx</span> <span class="o">&lt;</span> <span class="n">BM</span> <span class="o">*</span> <span class="n">BK</span><span class="p">;</span> <span class="n">a_idx</span> <span class="o">+=</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
      <span class="kt">size_t</span> <span class="n">i</span> <span class="o">=</span> <span class="n">a_idx</span> <span class="o">/</span> <span class="n">BK</span><span class="p">;</span>
      <span class="kt">size_t</span> <span class="n">j</span> <span class="o">=</span> <span class="n">a_idx</span> <span class="o">%</span> <span class="n">BK</span><span class="p">;</span>
      <span class="k">if</span> <span class="p">(</span><span class="n">k0</span> <span class="o">+</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">K</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">+</span> <span class="n">a_row_base</span> <span class="o">&lt;</span> <span class="n">M</span><span class="p">)</span> <span class="p">{</span>
        <span class="o">*</span><span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="n">float4</span><span class="o">*&gt;</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span><span class="p">])</span> <span class="o">=</span> <span class="o">*</span><span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="n">float4</span><span class="o">*&gt;</span><span class="p">(</span><span class="o">&amp;</span><span class="n">A</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span> <span class="o">+</span> <span class="n">k0</span> <span class="o">+</span> <span class="n">j</span><span class="p">]);</span>
      <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
        <span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
        <span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
        <span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
        <span class="n">A_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
      <span class="p">}</span>
    <span class="p">}</span>

    <span class="c1">// B_s[i, j] = B[k + i, b_col + j]</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">b_idx</span> <span class="o">=</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">b_idx</span> <span class="o">&lt;</span> <span class="n">BK</span> <span class="o">*</span> <span class="n">BN</span><span class="p">;</span> <span class="n">b_idx</span> <span class="o">+=</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
      <span class="kt">size_t</span> <span class="n">i</span> <span class="o">=</span> <span class="n">b_idx</span> <span class="o">/</span> <span class="n">BN</span><span class="p">;</span>
      <span class="kt">size_t</span> <span class="n">j</span> <span class="o">=</span> <span class="n">b_idx</span> <span class="o">%</span> <span class="n">BN</span><span class="p">;</span>
      <span class="k">if</span> <span class="p">(</span><span class="n">k0</span> <span class="o">+</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">K</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">+</span> <span class="n">b_col_base</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
        <span class="o">*</span><span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="n">float4</span><span class="o">*&gt;</span><span class="p">(</span><span class="o">&amp;</span><span class="n">B_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BN</span> <span class="o">+</span> <span class="n">j</span><span class="p">])</span> <span class="o">=</span> <span class="o">*</span><span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="n">float4</span><span class="o">*&gt;</span><span class="p">(</span><span class="o">&amp;</span><span class="n">B</span><span class="p">[(</span><span class="n">k0</span> <span class="o">+</span> <span class="n">i</span><span class="p">)</span> <span class="o">*</span> <span class="n">b</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">j</span><span class="p">]);</span>
      <span class="p">}</span> <span class="k">else</span> <span class="p">{</span>
        <span class="n">B_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BN</span> <span class="o">+</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
        <span class="n">B_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BN</span> <span class="o">+</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
        <span class="n">B_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BN</span> <span class="o">+</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
        <span class="n">B_s</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">BN</span> <span class="o">+</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
      <span class="p">}</span>
    <span class="p">}</span>

    <span class="n">__syncthreads</span><span class="p">();</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">BK</span><span class="p">;</span> <span class="n">k</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
      <span class="kt">float</span> <span class="n">B_reg</span><span class="p">[</span><span class="n">TN</span><span class="p">];</span>

      <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">col_counter</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">col_counter</span> <span class="o">&lt;</span> <span class="n">TN</span><span class="p">;</span> <span class="n">col_counter</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="kt">size_t</span> <span class="n">col</span> <span class="o">=</span> <span class="n">col_counter</span> <span class="o">*</span> <span class="n">TPBN</span> <span class="o">+</span> <span class="n">this_thread_col</span><span class="p">;</span>
        <span class="n">B_reg</span><span class="p">[</span><span class="n">col_counter</span><span class="p">]</span> <span class="o">=</span> <span class="n">B_s</span><span class="p">[</span><span class="n">k</span> <span class="o">*</span> <span class="n">BN</span> <span class="o">+</span> <span class="n">col</span><span class="p">];</span>
      <span class="p">}</span>

      <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">row_counter</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">row_counter</span> <span class="o">&lt;</span> <span class="n">TM</span><span class="p">;</span> <span class="n">row_counter</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
        <span class="c1">// filling in row a_row_base + row_counter * TPBM + this_thread_row of C</span>
        <span class="c1">// which is row row_counter of tmp</span>
        <span class="kt">size_t</span> <span class="n">row</span> <span class="o">=</span> <span class="n">row_counter</span> <span class="o">*</span> <span class="n">TPBM</span> <span class="o">+</span> <span class="n">this_thread_row</span><span class="p">;</span>
        <span class="kt">float</span> <span class="n">a_val</span> <span class="o">=</span> <span class="n">A_s</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">BK</span> <span class="o">+</span> <span class="n">k</span><span class="p">];</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">col_counter</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">col_counter</span> <span class="o">&lt;</span> <span class="n">TN</span><span class="p">;</span> <span class="n">col_counter</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
          <span class="c1">// col b_col_base + col_counter * TPBN + this_thread_col of C</span>
          <span class="c1">// which is col col_counter of tmp</span>
          <span class="kt">float</span> <span class="n">b_val</span> <span class="o">=</span> <span class="n">B_reg</span><span class="p">[</span><span class="n">col_counter</span><span class="p">];</span>

          <span class="n">tmp</span><span class="p">[</span><span class="n">row_counter</span> <span class="o">*</span> <span class="n">TN</span> <span class="o">+</span> <span class="n">col_counter</span><span class="p">]</span> <span class="o">+=</span> <span class="n">a_val</span> <span class="o">*</span> <span class="n">b_val</span><span class="p">;</span>
        <span class="p">}</span>
      <span class="p">}</span>
    <span class="p">}</span>
    <span class="n">__syncthreads</span><span class="p">();</span>
  <span class="p">}</span>
  <span class="c1">// now tmp[row * TN + col] goes into RES[(row * TPBM + this_thread_row) *</span>
  <span class="c1">// res.shape[2] + (col * TPBN + this_thread_col)]</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">row_counter</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">row_counter</span> <span class="o">&lt;</span> <span class="n">TM</span><span class="p">;</span> <span class="n">row_counter</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="kt">size_t</span> <span class="n">row</span> <span class="o">=</span> <span class="n">row_counter</span> <span class="o">*</span> <span class="n">TPBM</span> <span class="o">+</span> <span class="n">this_thread_row</span><span class="p">;</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">size_t</span> <span class="n">col_counter</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">col_counter</span> <span class="o">&lt;</span> <span class="n">TN</span><span class="p">;</span> <span class="n">col_counter</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
      <span class="kt">size_t</span> <span class="n">col</span> <span class="o">=</span> <span class="n">col_counter</span> <span class="o">*</span> <span class="n">TPBN</span> <span class="o">+</span> <span class="n">this_thread_col</span><span class="p">;</span>
      <span class="c1">// edge tiles: skip entries that fall outside the M x N result</span>
      <span class="k">if</span> <span class="p">(</span><span class="n">a_row_base</span> <span class="o">+</span> <span class="n">row</span> <span class="o">&lt;</span> <span class="n">M</span> <span class="o">&amp;&amp;</span> <span class="n">b_col_base</span> <span class="o">+</span> <span class="n">col</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">RES</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">res</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">col</span><span class="p">]</span> <span class="o">=</span> <span class="n">tmp</span><span class="p">[</span><span class="n">row_counter</span> <span class="o">*</span> <span class="n">TN</span> <span class="o">+</span> <span class="n">col_counter</span><span class="p">];</span>
      <span class="p">}</span>
    <span class="p">}</span>
  <span class="p">}</span>
<span class="p">}</span></code></pre></figure>
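
<p>For reference, here is a host-side sketch (plain C++, no CUDA runtime calls) of how the tiling parameters in the kernel above fit together and determine the launch configuration. The constraints are read off the kernel’s indexing: each thread owns <code>TM</code> rows strided by <code>TPBM</code> and <code>TN</code> columns strided by <code>TPBN</code>, so <code>BM = TM * TPBM</code>, <code>BN = TN * TPBN</code>, and a block has <code>TPBM * TPBN</code> threads; the <code>float4</code> stores into <code>A_s</code> and <code>B_s</code> need <code>BK</code> and <code>BN</code> to be multiples of 4 (the aligned global loads also need K and N to be multiples of 4, which holds for 1000). The <code>TileConfig</code> and <code>LaunchDims</code> structs and the example values are hypothetical, and 48 KiB is the default limit on statically allocated shared memory per block. A checker like this is one way to filter candidates for a parameter sweep.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">#include &lt;cassert&gt;
#include &lt;cstddef&gt;

// Hypothetical tile configuration mirroring the kernel's template parameters.
struct TileConfig {
  std::size_t BM, BN, BK;  // block tile: BM x BK slice of A, BK x BN slice of B
  std::size_t TM, TN;      // outputs per thread: TM rows x TN columns
  std::size_t TPBM, TPBN;  // threads per block along rows and columns
};

// Checks the relationships the kernel's indexing relies on.
bool is_valid(const TileConfig&amp; c) {
  std::size_t threads = c.TPBM * c.TPBN;
  std::size_t smem_bytes = (c.BM * c.BK + c.BK * c.BN) * sizeof(float);
  return c.BM == c.TM * c.TPBM &amp;&amp;           // TM rows per thread, strided by TPBM
         c.BN == c.TN * c.TPBN &amp;&amp;           // TN cols per thread, strided by TPBN
         c.BK % 4 == 0 &amp;&amp; c.BN % 4 == 0 &amp;&amp;  // float4 stores into A_s and B_s
         threads &lt;= 1024 &amp;&amp;                // hardware limit on threads per block
         smem_bytes &lt;= 48 * 1024;           // default static shared memory limit
}

struct LaunchDims {
  std::size_t grid_x, grid_y, grid_z;  // batch, row tiles of A, column tiles of B
  std::size_t block_x;                 // threads per block
};

// Grid layout used by the kernel: blockIdx.x = batch, .y = row tile, .z = column tile.
LaunchDims launch_dims(const TileConfig&amp; c, std::size_t batch, std::size_t M, std::size_t N) {
  assert(is_valid(c));
  return {batch, (M + c.BM - 1) / c.BM, (N + c.BN - 1) / c.BN, c.TPBM * c.TPBN};
}
</code></pre></figure>

<p>For example, with BM = BN = 64, BK = 8, TM = TN = 4 and TPBM = TPBN = 16 (illustrative values, not necessarily the ones used here), a batch of 100 matrices of size 1000 × 1000 needs a 100 × 16 × 16 grid of 256-thread blocks. Since 16 × 64 = 1024 &gt; 1000, the last row and column tiles overhang the matrix, which is why the tile loads are bounds-checked.</p>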

<h1 id="more-optimizations">More optimizations</h1>

<p>This is still a work in progress.  The most promising next optimizations are…</p>
<ul>
  <li>double-buffering, and</li>
  <li>a thorough parameter sweep.</li>
</ul>]]></content><author><name></name></author><summary type="html"><![CDATA[Optimizing matrix multiplication on an RTX 3050]]></summary></entry><entry><title type="html">Speed-running Integer Factorization with AI</title><link href="/quadratic-sieve/" rel="alternate" type="text/html" title="Speed-running Integer Factorization with AI" /><published>2026-04-01T00:00:00+00:00</published><updated>2026-04-01T00:00:00+00:00</updated><id>/quadratic-sieve</id><content type="html" xml:base="/quadratic-sieve/"><![CDATA[<p><a href="https://github.com/brian-lawrence-math/quadratic-sieve">The repo.</a></p>

<p>The <a href="https://en.wikipedia.org/wiki/Quadratic_sieve">quadratic sieve</a> is a fast integer factorization algorithm:
if a positive integer $N$ is the product of two large prime numbers \(N = pq\),
then quadratic sieve will find the two factors reasonably quickly.
“Reasonably” is the key word here: if \(N\) is an \(n\)-bit integer
(so \(N \approx 2^n\)), then the runtime of quadratic sieve is something like \(2^{\sqrt{n}}\).
This is much faster than trial division (runtime \(\sqrt{N} \approx 2^{n/2}\))
but still not even close to a polynomial-time algorithm.</p>

<p>Integer factorization is relevant because, if you can factor large integers,
you can break the <a href="https://en.wikipedia.org/wiki/RSA_cryptosystem">RSA cryptosystem</a>.
In fact, it’s precisely because of quadratic sieve and <a href="https://en.wikipedia.org/wiki/General_number_field_sieve">algorithms like it</a>
that many systems have switched over from RSA to <a href="https://en.wikipedia.org/wiki/Elliptic-curve_cryptography">elliptic curve cryptography</a>
in recent years.</p>

<p>Anyway, I’m going to implement quadratic sieve on an Nvidia RTX 3050,
with some help from OpenAI’s Codex.
I want to see what size of integer I can factor in a reasonable amount of time
(let’s say one minute or so).
This will be a good excuse for me to explore both</p>
<ul>
  <li>the challenges of adapting an interesting algorithm for parallelization, and</li>
  <li>what an AI agent can do.</li>
</ul>

<p>This is a work in progress: I’ve gone through a couple of iterations
with Codex, and I’ve already factored some pretty big numbers,
but I’m still experimenting to see just how much I can get the AI to do.</p>

<h2 id="summary-what-coding-agents-can-and-cant-do">Summary: what coding agents can and can’t do</h2>

<p>Before I get to the juicy details, here’s what I’ve learned from working with Codex.</p>

<p>Codex is very, very good at basic coding tasks.
It never forgets semicolons.  Its code always compiles.
It knows how to organize a project, 
how to start by building something simple that works and make iterative improvements from there.</p>

<p>It’s very good at reading and writing code. It quickly spots the sort of edge cases and off-by-one errors that used to cause me so much grief.</p>

<p>Codex also knows, in broad strokes, how the quadratic sieve algorithm is supposed to work.
“First you choose a factor base, then you find relations, then you solve a binary matrix.”</p>

<p>So working with Codex is a tremendous productivity boost.</p>

<p>But Codex is reluctant to do any sort of experimentation or analysis, beyond “write code and see if it works”.
Throughout the project, I had to give it two kinds of prompts:</p>
<ul>
  <li>Requests for empirical data:
Even though I told Codex I was interested in optimizing performance,
the agent was reluctant to run experiments and gather data.
For example:
    <ul>
      <li>The agent gave up on a small factorization task because it only found half the necessary number of relations.
But the target was 60 seconds, and the program only ran for about 1 second.
The solution – to double or triple the size of the search space – would have been obvious
if Codex had measured the runtime.</li>
      <li>On a longer-running program, the agent put a bunch of effort into optimizing a trivially small code path
(the is_squarefree() function).
This function took one millisecond out of a total of about 4 seconds of execution.
Again, a misdirected effort that could have been prevented by some simple timing.
In fact, it’s even worse: (1) Codex already had timing data in context and just had to look at it, and 
(2) even without experimental data, it should have been clear that is_squarefree() just wasn’t the bottleneck here.</li>
    </ul>
  </li>
  <li>Theoretical analysis (asymptotic estimates, etc.).
Codex can do this sort of thing pretty well if I hold its hand,
but by default Codex would rather code than think.
    <ul>
      <li>When the search could not find enough relations, Codex suggested decreasing the factor base bound \(B\).
In fact the solution is precisely the opposite: making \(B\) bigger makes relations more plentiful.</li>
      <li>It’s important to choose the right space in which to search for relations;
doing this right gives a huge performance boost.
It takes a bit of theoretical work to figure out the search space.
(It’s nothing terribly complicated; I explain the basics further down in this post.)
Codex can do every step of the calculation, but I couldn’t get it to do the full analysis with any degree of autonomy.</li>
    </ul>
  </li>
</ul>

<p>In summary: Codex is a terrific coder and a great productivity boost.
If I want to get more out of it, the next challenge is to get it to divide its efforts
among implementation, empirical data-gathering, and theoretical analysis.</p>

<h2 id="the-algorithm">The algorithm</h2>

<p>OK, with that summary out of the way, let’s get back to a brief overview of quadratic sieve.
I want to give you a sense of</p>
<ul>
  <li>the different parameter choices and how they affect performance, and</li>
  <li>how amenable things are to parallelization on GPU.</li>
</ul>

<p>If you like to read code, take a look at a basic <a href="https://github.com/brian-lawrence-math/quadratic-sieve/blob/main/python/qs.py">Python implementation</a>.</p>

<p>Just to fix ideas, imagine that $N = pq$ is a 100- to 200-bit integer (roughly 30 to 60 decimal digits),
and the primes $p$ and $q$ are about the same size.
With these parameters, we’d like things to run in a few seconds on a GPU.</p>

<p>The algorithm has two steps (not counting pre- and post-processing, which should be “fast”).</p>

<ul>
  <li>Preprocessing: Choose a factor base bound $B$ (for us, maybe $1000$, $10000$, $100000$) and a search plan for the next step.</li>
  <li>Search for “relations”.  A relation is a pair $(k, a)$ such that
\[
\text{all the prime factors of } a^2 - kN    \text{ are less than } B.
\]
In other words:
\[ a^2 \equiv \text{product of small primes (mod } N\text{).} \]
The number of relations you need to find is a few more than the number of primes less than $B$
(which is about $B / \log B$).</li>
  <li>Solve a linear system: Combine some of the relations to get a handful of perfect-square relations of the form
\[ a^2 \equiv b^2 \text{(mod } N \text{).} \]
This step amounts to solving a large binary matrix.</li>
  <li>Postprocessing: For each perfect-square relation, you know that 
\[ a^2 - b^2 = (a + b) (a - b) \]
is a multiple of $N$.<br />
Compute
\[ \operatorname{gcd}(a+b, N) \]
using the Euclidean algorithm.
There is a 50% chance that this gcd will be a nontrivial factor of $N$ (either $p$ or $q$), 
in which case you are done.
If not, repeat with more perfect-square relations until you find a factor.</li>
</ul>
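
<p>Here is a toy run through these steps, with numbers small enough to check by hand.
Take $N = 8051 = 83 \cdot 97$, $k = 1$ and $a = 90$.
In this lucky case a single relation is already a perfect square, so the linear-algebra step has nothing to do:
\[ 90^2 - 8051 = 8100 - 8051 = 49 = 7^2. \]
So $a = 90$ and $b = 7$ satisfy $a^2 \equiv b^2 \text{ (mod } N\text{)}$, and the Euclidean algorithm gives
\[ \operatorname{gcd}(90 + 7, 8051) = 97, \qquad \operatorname{gcd}(90 - 7, 8051) = 83. \]
For realistic sizes a single relation is almost never a square on its own, which is why you collect many relations and combine them.</p>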

<h3 id="asymptotics">Asymptotics</h3>

<p>In choosing $B$, there is a tradeoff between the search for relations and solving the linear system.
If $B$ is too small, relations will be hard to find (it’s a very rare number that is only divisible by 2, 3 or 5!).
But the larger $B$ is, the larger the matrix that needs to be solved in the solving step.</p>

<p>Here are some impressionistic asymptotics: I’m ignoring some important log factors and constants
so I can paint a clear picture in your head.
The goal is not to get precise runtime estimates (those come from experiment!)
but to give you a mental model of the tradeoffs.</p>

<p>Suppose $B$ is a $b$-bit integer, and suppose $N \approx B^k$ (in this section $k$ is just this exponent, not the multiplier $k$ in $a^2 - kN$).  In other words:
\[  B \approx 2^b \]
\[  N \approx 2^n \]
\[  n = bk.  \]</p>

<p>Then the search for relations will take something like $2^{2k}$ time – the smaller $B$ is compared to $N$, the rarer these relations are.</p>

<p>Solving the linear system will take something like $2^{3b}$ time – you’re solving a matrix of size (approximately) $B$,
and solving a matrix (at least by the naive algorithm) takes cubic time.</p>

<p>Now you can see where the $2^{\sqrt{n}}$ asymptotic comes from.  Suppose you have $2^t$ total time budget.
You want to divide it close to evenly between search and solve;
let’s say you allocate $2^t$ for each step.
(Oops!  Did I just double your budget?
Don’t worry, this is already a sloppy estimate, one more factor of 2 is no big deal.)
This means you want
\[ k = t/2 \]
\[ b = t/3 \]
so you can factor a number of up to
\[ n = bk = t^2 / 6 \]
bits.</p>
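
<p>To put rough numbers on this (with the same caveat that constants and log factors are ignored): a budget of $2^{30}$ operations per step gives
\[ k = 30/2 = 15, \qquad b = 30/3 = 10, \qquad n = bk = 150, \]
that is, a factor base bound around $B \approx 2^{10} \approx 1000$ and a number $N$ of about 150 bits.
Inverting $n = t^2/6$ gives $t = \sqrt{6n}$, which is the $2^{\sqrt{n}}$-type growth quoted at the start.</p>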

<p>In practice, this means:</p>
<ul>
  <li>The algorithm has two phases, search and solve, which can be run, timed, and optimized independently.</li>
  <li>Increasing $B$ makes the search phase faster and the solve phase slower.</li>
  <li>The solve phase runtime depends only on the size of $B$, not on $N$.  For example, with $B$ in the tens of thousands, the solve phase might take a few seconds on a CPU.</li>
</ul>

<h3 id="more-analysis-the-search-space">More analysis: the search space</h3>

<p>Like I said before, we want 
\[ a^2 - k N \]
to have a factorization into lots of small factors.
Of course, this is most likely to happen if $a^2 - kN$ is small!
So the natural thing to do is to pick some smallish integer $k$, and then search for $a$ in some interval around $\sqrt{kN}$.
In other words, we’ll look at
\[ a = x + \left \lfloor \sqrt{kN} \right \rfloor \]
where $k$ is a small positive integer, and $x$ is a small integer (positive or negative).</p>
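
<p>Computing $\lfloor \sqrt{kN} \rfloor$ is worth doing in exact integer arithmetic: a <code>double</code> carries only 53 bits of precision, which is not enough to pin down the floor for a 100- to 200-bit $kN$. Here is a minimal sketch, assuming toy sizes where $kN &lt; 2^{62}$ so plain 64-bit arithmetic suffices (a real implementation at the sizes above would use a big-integer type). The names <code>isqrt</code> and <code>q_value</code> are illustrative, not taken from the repo.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">#include &lt;cassert&gt;
#include &lt;cstdint&gt;

// floor(sqrt(n)) by integer Newton iteration: exact, no floating point.
std::uint64_t isqrt(std::uint64_t n) {
  if (n &lt; 2) return n;
  std::uint64_t x = n;
  std::uint64_t y = (x + 1) / 2;
  while (y &lt; x) {  // the iterates decrease until they reach the floor
    x = y;
    y = (x + n / x) / 2;
  }
  return x;
}

// Q(x) = a^2 - kN, where a = x + floor(sqrt(kN)).
// Toy-size assumption: kN &lt; 2^62 and |x| small, so everything fits in 64 bits.
std::int64_t q_value(std::uint64_t k, std::uint64_t N, std::int64_t x) {
  assert(N != 0 &amp;&amp; k &lt;= ((std::uint64_t{1} &lt;&lt; 62) - 1) / N);
  std::uint64_t kN = k * N;
  std::int64_t a = x + static_cast&lt;std::int64_t&gt;(isqrt(kN));
  return a * a - static_cast&lt;std::int64_t&gt;(kN);
}
</code></pre></figure>

<p>For example, with $N = 8051$ and $k = 1$ we get $\lfloor \sqrt{N} \rfloor = 89$, and $x = 1$ gives $a = 90$ and $a^2 - N = 49$.</p>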

<p>What sorts of $k$ and $x$ should we search over?
I think it’s pretty clear that we should figure out which $k$ and $x$ will make
\[ a^2 - k N \]
be the smallest, and target those first.</p>

<p>With a little bit of algebra, you can find that
\[ a^2 - k N \approx 2 x \sqrt{kN} \]
when $x$ is small.  So the obvious strategy is to pick some bound $M$ and search all pairs $k$ and $x$
for which
\[ -M &lt; 2x \sqrt{kN} &lt; M. \]</p>
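
<p>Here is a minimal sketch of enumerating that search space. For each multiplier $k$, the condition $|2x\sqrt{kN}| &lt; M$ becomes $|x| &lt; M / (2\sqrt{kN})$, and since that bound shrinks as $k$ grows, the loop over $k$ can stop as soon as no nonzero $x$ qualifies. The <code>SearchRange</code> and <code>plan_search</code> names are hypothetical; the sketch uses <code>double</code> because only the interval endpoints matter here, not exact values of $a^2 - kN$.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">#include &lt;cassert&gt;
#include &lt;cmath&gt;
#include &lt;cstdint&gt;
#include &lt;vector&gt;

// One multiplier k together with the x values to search: x_min &lt;= x &lt;= x_max.
struct SearchRange {
  std::uint64_t k;
  std::int64_t x_min, x_max;
};

// All (k, x) with -M &lt; 2 x sqrt(kN) &lt; M, grouped by k.
// The allowed |x| shrinks as k grows, so stop at the first k with no nonzero x.
std::vector&lt;SearchRange&gt; plan_search(double N, double M) {
  assert(N &gt; 0 &amp;&amp; M &gt; 0);
  std::vector&lt;SearchRange&gt; plan;
  for (std::uint64_t k = 1;; ++k) {
    double half_width = M / (2.0 * std::sqrt(static_cast&lt;double&gt;(k) * N));
    // largest integer strictly below half_width
    auto x_max = static_cast&lt;std::int64_t&gt;(std::ceil(half_width)) - 1;
    if (x_max &lt; 1) break;
    plan.push_back({k, -x_max, x_max});
  }
  return plan;
}
</code></pre></figure>

<p>Compared with fixing a single $k$ and widening the $x$ interval, this spends the search budget on the pairs whose $a^2 - kN$ are expected to be smallest.</p>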

<p>At least, I think this is obvious.  Codex doesn’t.
Codex wants to pick a single value of $k$ and then search over $x$ in an ever-growing interval.
I wanted to see if I could get Codex to do the analysis my way;
after all, the calculations I just did are well within its abilities!
It took a surprising amount of coaxing, but in the end we got a 5-10x speedup on the search step.</p>

<p>OK, enough of theory, let’s get down to the metal.</p>

<h2 id="parallelizing-the-search-and-an-interesting-tradeoff">Parallelizing the search, and an interesting tradeoff</h2>

<p>Both big steps (search and solve) could potentially benefit from parallelization,
but search is the natural first target.
After all, the search is what they call “embarrassingly parallel”:
each pair $(k, x)$ either is a hit or it isn’t.
On the other hand, imagine trying to solve a matrix – think row reduction or Gauss-Jordan elimination –
in parallel.
(<a href="https://en.wikipedia.org/wiki/Gaussian_elimination">Wikipedia</a> has a nice animation.)
Maybe each thread can take responsibility for reducing a single row,
but each pivot row will need to be broadcast to the threads, one after another.
And that’s not to mention some of the more complicated <a href="https://en.wikipedia.org/wiki/Block_Wiedemann_algorithm">optimizations</a> 
that might be useful in our sparse setting.</p>
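
<p>To make the sequential structure concrete, here is a minimal CPU sketch of Gauss-Jordan elimination over GF(2) (arithmetic mod 2), where each row is the bit-packed exponent-parity vector of one relation. The outer loop over pivot columns has to run in order; the inner loop that XORs the pivot row into every other row is the part that could be spread across threads. To recover which relations multiply to a perfect square, the sketch assumes the caller appends one identity bit per relation after the first <code>n_cols</code> columns: a row whose first <code>n_cols</code> bits end up zero then records, in its identity bits, a subset of relations whose exponent vectors sum to zero mod 2. This is a textbook dense approach, not the block Wiedemann method linked above, and the names are illustrative.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">#include &lt;cassert&gt;
#include &lt;cstddef&gt;
#include &lt;cstdint&gt;
#include &lt;utility&gt;
#include &lt;vector&gt;

using Row = std::vector&lt;std::uint64_t&gt;;  // bit j of a row lives in word j / 64

// Gauss-Jordan elimination over GF(2) on the first n_cols columns of `rows`.
// Bits beyond n_cols (e.g. identity bits) are carried along by the XORs.
// Returns the rank; rows[rank..] are then zero on the first n_cols columns.
std::size_t gf2_eliminate(std::vector&lt;Row&gt;&amp; rows, std::size_t n_cols) {
  for (const Row&amp; r : rows) assert(r.size() * 64 &gt;= n_cols);
  std::size_t rank = 0;
  for (std::size_t col = 0; col &lt; n_cols &amp;&amp; rank &lt; rows.size(); ++col) {
    std::size_t word = col / 64;
    std::uint64_t bit = std::uint64_t{1} &lt;&lt; (col % 64);
    // Sequential part: find a pivot row with a 1 in this column.
    std::size_t pivot = rank;
    while (pivot &lt; rows.size() &amp;&amp; !(rows[pivot][word] &amp; bit)) ++pivot;
    if (pivot == rows.size()) continue;  // no pivot in this column
    std::swap(rows[rank], rows[pivot]);
    // Parallelizable part: every other row with a 1 here XORs in the pivot row.
    for (std::size_t r = 0; r &lt; rows.size(); ++r) {
      if (r != rank &amp;&amp; (rows[r][word] &amp; bit)) {
        for (std::size_t w = 0; w &lt; rows[r].size(); ++w) rows[r][w] ^= rows[rank][w];
      }
    }
    ++rank;
  }
  return rank;
}
</code></pre></figure>

<p>For example, for four relations whose values factor as $2$, $3$, $2 \cdot 3$ and $5$ over the primes 2, 3, 5 (with four identity bits appended), the elimination finds rank 3, and the leftover zero row selects the first three relations: $2 \cdot 3 \cdot 6 = 36 = 6^2$.</p>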

<p>So let’s focus on doing the search in parallel.</p>

<p>At first glance, this seems easy enough.
Each thread takes responsibility for a single $a^2 - k N$.
The thread does a series of trial divisions (by 2, 3, 5, 7, …),
until it reaches the bound $B$ – at which point it either accepts or rejects.</p>
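
<p>As a sketch of what each thread does in this version, here is the per-candidate trial division written as plain C++ (in a kernel it would be a <code>__device__</code> function called once per $x$). It assumes $a^2 - kN$ fits in 64 bits, and it records the sign as an extra factor $-1$, since $a^2 - kN$ is negative for $x &lt; 0$. The parity vector it fills in is the exponent vector mod 2 that the linear-algebra step works with. Names are illustrative, not taken from the repo.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">#include &lt;cassert&gt;
#include &lt;cstddef&gt;
#include &lt;cstdint&gt;
#include &lt;vector&gt;

// Trial-divides value = a^2 - kN by every prime in the factor base.
// Returns true if value is B-smooth, i.e. nothing is left over.
// parity[0] is the exponent of -1 (the sign); parity[i + 1] is the
// exponent of primes[i]; both are reduced mod 2.
bool trial_divide(std::int64_t value, const std::vector&lt;std::uint64_t&gt;&amp; primes,
                  std::vector&lt;std::uint8_t&gt;&amp; parity) {
  assert(value != 0);
  parity.assign(primes.size() + 1, 0);
  std::uint64_t m = static_cast&lt;std::uint64_t&gt;(value);
  if (value &lt; 0) {
    parity[0] = 1;
    m = 0 - m;  // |value|, computed in unsigned arithmetic to avoid overflow
  }
  for (std::size_t i = 0; i &lt; primes.size() &amp;&amp; m &gt; 1; ++i) {
    while (m % primes[i] == 0) {
      m /= primes[i];
      parity[i + 1] ^= 1;
    }
  }
  return m == 1;
}
</code></pre></figure>

<p>With the factor base $\{2, 3, 5, 7\}$ and $N = 8051$, the candidate $x = 1$ ($a = 90$, value $49$) passes, while $x = 0$ ($a = 89$, value $-130 = -2 \cdot 5 \cdot 13$) is rejected because of the 13.</p>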

<p>But it turns out this simple approach introduces a whole bunch of duplicated computation!
To see why, think about a single prime factor, like 2027.
Your algorithm ends up testing each of these numbers individually for divisibility by 2027.
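A much faster approach is to “sieve”:
it turns out that $a^2 - kN$ is a multiple of 2027 exactly when $a$ falls into one of (at most) two residue classes mod 2027,
so once you have found those, you can step through the interval 2027 at a time and
mark off all the $a$ that make $a^2 - k N$ a multiple of 2027.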
(As usual, <a href="https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes">Wikipedia</a> has a terrific animation.
And notice that this particular optimization predates the invention of the GPU by quite some years.)</p>

<p>Now imagine you’re searching, say, one million numbers.
You can either do one million trial divisions (in parallel, one number per thread),
or you can assign a single thread responsibility for finding all multiples of 2027
(parallelizing the work one prime per thread).
If you use the sieving approach, instead of one million trial divisions,
you find the (at most two) starting points once, and then you “mark” about $1000000/2027 \approx 500$
values for each one.  Should be an easy win!</p>
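
<p>For concreteness, here is a minimal sequential sketch of sieving one interval for a single $k$. It assumes toy sizes (everything fits in 64 bits), takes $s = \lfloor \sqrt{kN} \rfloor$ from the caller, and finds the square roots of $kN$ mod $p$ by brute force, which is only sensible for a toy factor base (real implementations use an algorithm such as Tonelli-Shanks). Rather than just marking, it divides each hit by $p$ as many times as possible, so a candidate is smooth exactly when its cofactor drops to 1. The function name and signature are hypothetical.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp">#include &lt;cassert&gt;
#include &lt;cstdint&gt;
#include &lt;vector&gt;

// Sieves a = s + x for x in [0, len), where s = floor(sqrt(kN)).
// Returns the x for which |a^2 - kN| factors completely over `primes`.
std::vector&lt;std::uint64_t&gt; sieve_interval(std::uint64_t kN, std::uint64_t s,
                                          std::uint64_t len,
                                          const std::vector&lt;std::uint64_t&gt;&amp; primes) {
  assert(s * s &lt;= kN &amp;&amp; kN &lt; (s + 1) * (s + 1));
  std::vector&lt;std::uint64_t&gt; rem(len);  // cofactor not yet explained by the sieve
  for (std::uint64_t x = 0; x &lt; len; ++x) {
    std::uint64_t a = s + x;
    rem[x] = a * a &gt;= kN ? a * a - kN : kN - a * a;
  }
  for (std::uint64_t p : primes) {
    // Toy root finding: every r in [0, p) with r^2 = kN (mod p).
    for (std::uint64_t r = 0; r &lt; p; ++r) {
      if (r * r % p != kN % p) continue;
      // First x &gt;= 0 with s + x = r (mod p), then every p-th x after it.
      for (std::uint64_t x = (r + p - s % p) % p; x &lt; len; x += p) {
        while (rem[x] != 0 &amp;&amp; rem[x] % p == 0) rem[x] /= p;
      }
    }
  }
  std::vector&lt;std::uint64_t&gt; smooth;
  for (std::uint64_t x = 0; x &lt; len; ++x)
    if (rem[x] == 1) smooth.push_back(x);
  return smooth;
}
</code></pre></figure>

<p>On the toy example $kN = 8051$, $s = 89$, with primes up to 7, the only hit among the first ten $x$ values is $x = 1$ ($a = 90$, $a^2 - kN = 49$).</p>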

<p>But…</p>

<p>Computation is fast; memory is slow.
We’ve just gone from a pure-compute regime where each thread can
(more or less) do the full computation using only its own registers
to a random-access regime where each thread is making its own individual, unpredictable
writes to global memory.</p>

<p>On a GPU (at least most Nvidia GPUs are like this), threads are organized in “warps” of 32.
All the threads in a warp execute the same instruction in lockstep,
on the same streaming multiprocessor, with access to the same low-latency shared memory.
(Actually, I’m simplifying things somewhat.
Threads are organized into “warps” of 32 and “blocks” of a larger size –
the programmer chooses the block size, up to a maximum of 1024 threads.
Threads in a warp execute at the same time; threads in a block share memory.)</p>

<p>Anyway, imagine we parallelize the work “one prime per thread”.
So our imaginary 2027 thread is sharing a warp (and memory etc.) with other threads responsible for different primes:
2029, 2039, 2053, and so forth.
At each iteration of the loop, each thread will compute the next multiple of its own prime,
and then issue a write request to… somewhere in this global array of 1 million items.
The write address won’t be cached, because who could have predicted it?</p>

<p>Even worse, these 32 writes will widely spread across memory.
Writes to DRAM on a GPU come in 128-byte transactions, 
and if all 32 threads write to the same 128-byte block of memory, 
the writes can be “coalesced” into a single transaction.
But if each thread writes to multiples of a different prime, 
the chip will be forced to handle 32 different transactions of 128 bytes each.</p>
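<p>A rough way to see the coalescing problem is to count how many distinct 128-byte segments a single warp’s 32 writes land in. This Python sketch is a toy address model, not a hardware simulation: it assumes 4-byte array elements, uses the 128-byte transaction size quoted above, and has each “one prime per thread” thread write the 100th multiple of its prime. All names in it are made up for illustration.</p>

```python
WARP_SIZE = 32
SEGMENT_BYTES = 128   # transaction size quoted in the text
ELEM_BYTES = 4        # assumption: one 4-byte word per x value


def is_prime(n):
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True


def primes_from(start, count):
    """The first `count` primes that are >= start."""
    out, n = [], start
    while len(out) < count:
        if is_prime(n):
            out.append(n)
        n += 1
    return out


def segments_touched(indices):
    """Distinct 128-byte segments hit when thread t writes element indices[t]."""
    return len({(i * ELEM_BYTES) // SEGMENT_BYTES for i in indices})


if __name__ == "__main__":
    # "One x per thread": neighboring threads write neighboring elements.
    contiguous = [1024 + t for t in range(WARP_SIZE)]
    # "One prime per thread": each thread writes a multiple of its own prime.
    scattered = [100 * p for p in primes_from(2027, WARP_SIZE)]
    print(segments_touched(contiguous), segments_touched(scattered))  # 1 vs 32
```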

<p>And finally, to protect against race conditions (what if two different prime threads
try to write to the same spot at the same time?)
we’ll need to use atomic operations – introducing further inefficiency.</p>

<p>Empirically, switching from “one $x$ per thread” trial division to “one prime per thread” sieving
leads to a substantial (5-10x) slowdown in the search phase.</p>

<h3 id="a-more-efficient-solution">A more efficient solution</h3>

<p>A hybrid approach, a sort of “data-local sieving,” turns out to give the best of both worlds.</p>

<p>A block of threads collaboratively loads a number of $x$ values (1024 values works well) 
into high-bandwidth shared memory.
Once the values have been loaded, the threads will run the sieving procedure on these 1024 values.</p>

<figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="c1">// load 1024 x values into shared memory</span>

<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">p_idx</span> <span class="o">=</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span> <span class="n">p_idx</span> <span class="o">&lt;</span> <span class="n">n_primes</span><span class="p">;</span> <span class="n">p_idx</span> <span class="o">+=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span><span class="p">)</span> <span class="p">{</span>
    <span class="kt">int</span> <span class="n">p</span> <span class="o">=</span> <span class="n">primes</span><span class="p">[</span><span class="n">p_idx</span><span class="p">];</span>
    <span class="c1">// x_start = first multiple of p in this interval</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x_start</span><span class="p">;</span> <span class="n">x</span> <span class="o">&lt;</span> <span class="n">x_end</span><span class="p">;</span> <span class="n">x</span> <span class="o">+=</span> <span class="n">p</span><span class="p">)</span> <span class="p">{</span>
        <span class="c1">// label x as a multiple of p</span>
    <span class="p">}</span>
<span class="p">}</span>

<span class="c1">// write those 1024 x values back into global memory</span></code></pre></figure>

<p>As far as memory is concerned, this approach requires only one DRAM read and one write per $x$.  Clearly, this is the best possible.</p>

<p>As for computation, we keep most of the benefits of sieving.
Remember: the total computation per prime $p$ is <code class="language-plaintext highlighter-rouge">num_x</code> with the naive algorithm,
versus $num_x / p$ with sieving – so sieving gives a factor of $p$ savings.
This algorithm processes 1024 values of $x$ at a cost of
\[ 1 + 1024 / p: \]
the initial calculation of <code class="language-plaintext highlighter-rouge">x_start</code> is required once, 
but the inner loop strides through the $x$ values in steps of $p$.
The total cost for all $num_x$ values is
\[ num_x \left ( \frac{1}{1024} + \frac{1}{p} \right ). \]
In words:</p>
<ul>
  <li>When $p &lt; 1024$, the algorithm is almost as compute-efficient as sieving: its cost is at most $1 + p/1024 &lt; 2$ times that of pure sieving.</li>
  <li>When $p &gt; 1024$, the algorithm is between 512 and 1024 times faster than the naive algorithm.
That’s pretty good.</li>
</ul>
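<p>The cost model is easy to tabulate. This small Python sketch (the function name is hypothetical) compares the per-$x$ cost of the hybrid sieve, $1/1024 + 1/p$, with the naive cost of 1 and the pure-sieving cost of $1/p$, exactly as in the formula above; it says nothing about memory traffic.</p>

```python
TILE = 1024


def cost_per_x(p, tile=TILE):
    """Per-x cost for prime p under the text's cost model: (naive, sieve, hybrid)."""
    naive = 1.0                    # one trial division per x
    sieve = 1.0 / p                # one mark per multiple of p
    hybrid = 1.0 / tile + 1.0 / p  # one x_start per tile, plus the marks
    return naive, sieve, hybrid


if __name__ == "__main__":
    for p in (37, 101, 1009, 2027, 100003):
        naive, sieve, hybrid = cost_per_x(p)
        print(f"p = {p:6d}: {naive / hybrid:7.1f}x faster than naive, "
              f"{hybrid / sieve:5.2f}x the cost of pure sieving")
```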

<p>One more note: From experiments, it turns out to be most efficient
to run the naive algorithm for very small primes ($p &lt; 32$),
and then switch to this hybrid memory-local sieve for $p &gt; 32$.</p>

<h2 id="whats-next">What’s next?</h2>

<p>We’ve made some pretty good optimizations to the search phase.
Next I want to see what we can do about two other phases:</p>
<ul>
  <li>solve, and</li>
  <li>preprocessing.</li>
</ul>

<p>The solve phase involves row-reducing a binary matrix.
It’s not the most natural candidate for parallelization –
too much memory interaction between rows – 
but I’m doing some experiments to see if I can get speedup on a GPU.</p>
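<p>For readers who haven’t seen it, here is a minimal sketch of row reduction over GF(2), not the solver used in this project. Each row of the binary matrix is stored as a Python integer, so adding two rows mod 2 is a single XOR. The sketch only computes the rank; the actual solve phase needs the linear dependencies between rows, but the elimination loop has the same shape. Note that clearing one pivot column can touch every other row, which is the memory interaction mentioned above.</p>

```python
def gf2_rank(rows):
    """Rank of a binary matrix over GF(2), by Gaussian elimination.

    Each row is a Python int whose bits are the matrix entries, so adding
    two rows mod 2 is a single XOR.
    """
    rows = list(rows)
    rank = 0
    ncols = max((r.bit_length() for r in rows), default=0)
    for col in reversed(range(ncols)):
        bit = 1 << col
        # Find a row at or below `rank` with a 1 in this column to use as pivot.
        pivot = next((i for i in range(rank, len(rows)) if rows[i] & bit), None)
        if pivot is None:
            continue
        rows[rank], rows[pivot] = rows[pivot], rows[rank]
        # Clear this column from every other row: any row may need the pivot row.
        for i in range(len(rows)):
            if i != rank and rows[i] & bit:
                rows[i] ^= rows[rank]
        rank += 1
    return rank


if __name__ == "__main__":
    # The third row is the XOR of the first two, so the rank is 2.
    print(gf2_rank([0b110, 0b011, 0b101]))  # prints 2
```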

<p>The preprocessing phase is also surprisingly time-consuming
(~20 sec for a 192-bit $N$).
Here again there is room to optimize by factoring out some
parallelization-friendly parts for execution on GPU.</p>]]></content><author><name></name></author><category term="jekyll" /><category term="update" /><summary type="html"><![CDATA[The repo.]]></summary></entry><entry><title type="html">Understanding the Adam optimizer</title><link href="/adam/" rel="alternate" type="text/html" title="Understanding the Adam optimizer" /><published>2026-04-01T00:00:00+00:00</published><updated>2026-04-01T00:00:00+00:00</updated><id>/adam</id><content type="html" xml:base="/adam/"><![CDATA[<p><a href="https://github.com/brian-lawrence-math/mnist">The repo.</a></p>

<p>In this post I’ll offer my own somewhat contrarian explanation
of why the Adam optimizer works.</p>

<p>Then I’ll demonstrate my explanation with some <a href="https://github.com/brian-lawrence-math/mnist">experiments</a> on
a simple proof-of-concept optimizer I made up, called GradSign.</p>

<h1 id="the-adam-optimizer">The Adam optimizer</h1>

<p><a href="https://arxiv.org/pdf/1412.6980">Adam</a> is a widely-used optimization algorithm 
that tends to perform very well on deep learning tasks.</p>

<p>Adam is often explained as an extension of stochastic gradient descent (SGD):
sample one batch, compute the loss and its gradient,
and smooth the result out by taking an exponential moving average of the gradient.
Then there’s a step that the standard explanations sort of glide over –
something about a second moment (i.e. a variance estimate) for the gradient,
something about “adaptive choice of the learning rate” –
and then you take your step… and magically end up with a good optimizer.</p>

<p>(A quick Google search turns up plenty of explanations along these lines:
for example, <a href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Adam">here</a>, <a href="https://medium.com/@weidagang/demystifying-the-adam-optimizer-in-machine-learning-4401d162cb9e">here</a>, <a href="https://optimization.cbe.cornell.edu/index.php?title=Adam">here</a>…)</p>

<p>I’d like to offer a different take on Adam – less calculus, more statistics.
My take will suggest a different toy model of an optimizer:
not stochastic gradient descent, but an algorithm that only looks at the sign
(not the magnitude) of each gradient.
I call my optimizer GradSign.
I’ll test it out on a simple machine learning task: building an MNIST classifier.</p>

<p>Spoiler alert: No, GradSign doesn’t outperform Adam.
In my small experiment, my optimizer performs comparably to Adam,
but SGD (surprisingly) outperforms them both.</p>

<p>I hope reading this inspires you to tinker and explore.</p>

<p>This post has two parts:</p>
<ul>
  <li>My reinterpretation of Adam</li>
  <li>An explanation of the new optimizer</li>
</ul>

<p>You can see experimental results with the new optimizer on <a href="https://github.com/brian-lawrence-math/mnist">Github</a>.</p>

<h2 id="adam-reexplained">Adam reexplained</h2>

<p>The Adam optimizer works by keeping first- and second-moment estimates
(exponential moving averages) of the gradient of each parameter;
at each step, those estimates are used to determine the change to that parameter.</p>

<p>We will consider a single parameter, one of millions or billions in a large model:
Adam works on each parameter independently, so we don’t need to worry about anything else.</p>

<p>I will follow the notation from Algorithm 1 of the <a href="https://arxiv.org/pdf/1412.6980">original paper</a>.</p>

<p>The algorithm depends on three main hyperparameters, with the following suggested values:
\[  \alpha = 0.001  \]
\[  \beta_1 = 0.9  \]
\[  \beta_2 = 0.999.  \]
Here $\alpha$ is the learning rate.
(We will see that, even though $\alpha$ is called a “learning rate”,
it does not have the same units as the learning rate in SGD.)
The hyperparameters $\beta_1$ and $\beta_2$ determine the timescales for 
the two exponential moving averages.
(With the suggested values, the first moment is averaged with a decay time of 10 iteration steps,
and the second with a decay time of 1000.)</p>

<p>Let $g_t$ be the gradient of our favorite parameter at timestep $t$.
(Remember, the gradient at each timestep depends on both the parameter values –
which are updated at each step – and the random choice of a fresh batch of data.)
The exponential moving average of the first moment (mean, $m$) and second moment (uncentered variance, $v$)
are computed recursively:
\[ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \]
\[ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2. \]
Finally, the parameter update is computed as
\[ - \alpha \frac{m_t}{\sqrt{v_t} + \epsilon}. \]</p>
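<p>As a sanity check on these formulas, here is a minimal single-parameter Adam step in Python, written directly from the recurrences above; the helper names are hypothetical. Like the formulas here, it omits the bias-correction step of Algorithm 1. Feeding it a constant gradient shows the approximate decay times $1/(1-\beta)$ and previews the fact that the update size does not depend on the gradient’s scale.</p>

```python
import math


def adam_step(m, v, g, alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single parameter, using the recurrences above.

    Returns the new moment estimates and the update to add to the parameter.
    Bias correction from Algorithm 1 is deliberately omitted, as in the text.
    """
    m = beta1 * m + (1 - beta1) * g          # EMA of the gradient
    v = beta2 * v + (1 - beta2) * g * g      # EMA of the squared gradient
    update = -alpha * m / (math.sqrt(v) + eps)
    return m, v, update


def decay_time(beta):
    """Approximate characteristic timescale (in steps) of an EMA with factor beta."""
    return 1.0 / (1.0 - beta)


if __name__ == "__main__":
    print(decay_time(0.9), decay_time(0.999))  # roughly 10 and 1000 steps
    # Feed a constant gradient at two very different scales.
    for scale in (0.002, 2000.0):
        m = v = 0.0
        for _ in range(10_000):
            m, v, update = adam_step(m, v, scale)
        print(scale, update)  # both close to -alpha = -0.001
```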

<p>Let us unroll the recurrences.
(We will also assume, as a mathematical convenience, that the gradients go infinitely far back in time.
In reality one has to worry about how to initialize $m_t$ and $v_t$,
but after the first few optimizer steps, our assumption will be a reasonable one.)
\[ m_t = (1 - \beta_1) \sum_{i = 0}^{\infty} \beta_1^i g_{t - i} \]
\[ v_t = (1 - \beta_2) \sum_{i = 0}^{\infty} \beta_2^i g_{t - i}^2 \]
These are exponential moving averages – weighted averages of past values of $g_t$ (or its square),
the relevance of past values decaying over time with a characteristic time scale
(a physicist might say “half-life”) depending on the $\beta$’s.</p>
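<p>The unrolled form assumes an infinite history. Algorithm 1 of the paper starts from $m_0 = v_0 = 0$, so in practice the sum stops after $t$ terms and its weights add up to $1 - \beta^t$ rather than 1; that is the factor the paper’s bias-correction step divides out. This short Python sketch (hypothetical helper names) checks that the recursion and the truncated sum agree.</p>

```python
import random


def ema_recursive(gs, beta):
    """m_t = beta * m_{t-1} + (1 - beta) * g_t, starting from m_0 = 0."""
    m = 0.0
    for g in gs:
        m = beta * m + (1 - beta) * g
    return m


def ema_unrolled(gs, beta):
    """(1 - beta) * sum_i beta**i * g_{t-i}, truncated at the first gradient."""
    t = len(gs)
    return (1 - beta) * sum(beta**i * gs[t - 1 - i] for i in range(t))


if __name__ == "__main__":
    random.seed(0)
    gs = [random.gauss(0.0, 1.0) for _ in range(500)]
    print(ema_recursive(gs, 0.9), ema_unrolled(gs, 0.9))  # agree up to rounding
    # With a constant gradient of 1, the truncated weights sum to 1 - beta**t.
    print(ema_recursive([1.0] * 10, 0.9), 1 - 0.9**10)
```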

<p>I will make three simplifications to make the analysis easier.
First, I will ignore the $\epsilon \approx 10^{-8}$ that is added to the denominator of
\[ \text{update} = - \alpha \frac{m_t}{\sqrt{v_t} + \epsilon} \]
for numerical stability.
The $\epsilon$ is just there to make sure the algorithm does something sensible when the gradient vanishes.</p>

<p>Second, I will replace the exponentially weighted average (with parameters $\beta$)
with a simple unweighted average (“moving window”) over the past $n$ terms (where $n$ is the window size).
This would be a computational disaster to implement –
we would have to store $n$ values per moment per parameter, rather than just $1$ –
but it will help us build a clean mental model.</p>

<p>Finally, I will assume $\beta_1 = \beta_2$, or in other words that the two moving averages
are computed over the same size of window $n_1 = n_2$.
Unlike the first two simplifications, assuming $\beta_1 = \beta_2$ really does change the behavior of the optimizer in a meaningful way.
I’ll come back to this later, because I want to understand this simplified model first.</p>

<p>So, let’s say $\beta_1 = \beta_2 = 0.99$, so now $m_t$ and $v_t$ are averages
over the past 100 iterations:
\[ m_t = \frac{\sum_{i=0}^{99} g_{t-i}}{100}  \]
\[ v_t = \frac{\sum_{i=0}^{99} g_{t-i}^2}{100}  \]
and
\[ \text{update} = - \alpha \frac{ \sum g_{t-i} / 100 } { \sqrt{ \sum g_{t-i}^2 / 100 } }. \]
But now that fraction is something we can understand:
it is nothing more than the <em>cosine similarity</em> between the vectors
\[ \hat{g} = (g_{t-99}, g_{t-98}, \ldots, g_{t-1}, g_t) \]
and 
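\[ \hat{1} = (1, 1, \ldots, 1, 1).  \]
Indeed, the cosine similarity is
\[ \frac{\hat{g} \cdot \hat{1}}{\|\hat{g}\| \, \|\hat{1}\|} = \frac{\sum g_{t-i}}{\sqrt{\sum g_{t-i}^2} \, \sqrt{100}} = \frac{\sum g_{t-i} / 100}{\sqrt{\sum g_{t-i}^2 / 100}}. \]</p>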

<p>In other words, if we call $\theta$ the angle between those two vectors 
(my apologies, this is not the same as the $\theta$ in the paper),
then the update to our parameter is simply
\[ \text{update} = - \alpha \cos \theta.  \]</p>
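<p>The identity is easy to check numerically. This Python sketch (hypothetical names) computes the windowed ratio $m_t / \sqrt{v_t}$ with plain averages over the window, as in the simplified model, and compares it with the cosine similarity between the gradient history and the all-ones vector.</p>

```python
import math
import random


def windowed_adam_ratio(gs):
    """m / sqrt(v), with both moments taken as plain averages over the window."""
    n = len(gs)
    m = sum(gs) / n
    v = sum(g * g for g in gs) / n
    return m / math.sqrt(v)


def cosine_with_ones(gs):
    """Cosine similarity between the gradient history and the all-ones vector."""
    dot = sum(gs)                                 # g . 1
    norm_g = math.sqrt(sum(g * g for g in gs))
    norm_1 = math.sqrt(len(gs))
    return dot / (norm_g * norm_1)


if __name__ == "__main__":
    random.seed(1)
    gs = [random.gauss(0.3, 1.0) for _ in range(100)]   # noisy, mostly positive
    print(windowed_adam_ratio(gs), cosine_with_ones(gs))  # the same number
    print(windowed_adam_ratio([1000 * g for g in gs]))    # unchanged by rescaling
```

<p>Rescaling every gradient by a positive constant leaves the result unchanged, while a negative constant only flips its sign.</p>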

<p>We immediately see:</p>
<ul>
  <li>$\alpha$ is the largest possible update to our parameter, and</li>
  <li>the size of the update is determined by how close the different $g_t$’s are to each other, rather than the size of $g_t$.
 (The Adam paper calls this a “signal-to-noise ratio”.)</li>
</ul>

<p>In fact, the cosine similarity is invariant under multiplying all the $g_t$’s by the same positive constant.
(Contrast this to classical gradient descent, 
where the step size is a product of learning rate and gradient, 
and you have to do lots of extra work to make sure gradients in different parts of the network
have the same scale.)</p>

<h2 id="thinking-statistically">Thinking statistically</h2>

<p>When we update our parameter (we’re just focusing on one parameter, remember? the rest will come along for the ride)
our goal is to make the loss decrease.
Calculus teaches us that the gradient (a first derivative) determines whether the loss will go up or down,
but let me reframe the question statistically:
<em>How confident are we that making this change will decrease the loss?</em></p>

<p>In the statistical framing, there are two sources of noise we need to worry about:</p>
<ul>
  <li>Sample randomness: each gradient is computed from only a small batch of data.</li>
  <li>The gradient landscape: the slope might be negative now, but if our step size is too large we may overshoot and end up climbing back uphill.</li>
</ul>

<p>Looking at the past $n$ steps can protect us against both types of noise!
We’re trying to evaluate a statistical hypothesis, like
“decreasing this parameter will result in a lower loss against the next randomly chosen batch”.
Clearly, a natural statistic is “on how many of the last $n$ batches was this gradient positive?”
If the gradient has consistent positive values across batches, we can expect the gradient to have 
a positive value on the next batch as well.</p>

<p>Similarly, we want to know if the loss landscape is bumpy or smooth.
We can think of the past $n$ parameter updates as a sort of random walk across this landscape
(not uniformly random of course, but governed by a complicated stochastic process).
We’re about to take another step in this random walk.
If the gradient has been consistent in the past,
we can have more confidence that our gradient will remain positive through the full length of our next step.</p>

<p>So, instead of thinking of an optimizer step as a gradient update,
I think of it as a statistical confidence test:
how confident are we that the loss will keep decreasing along this step
on the next (yet unseen) batch,
at the current parameter position, the updated position, and everywhere in between?</p>

<p>If this is what’s going on with Adam, then maybe the size of the gradient doesn’t matter at all.
Maybe all that matters is: 
<em>At how many of the past $n$ timesteps was the gradient positive,
and at how many was it negative?</em></p>
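<p>To make the question concrete, here is a hypothetical sign-only version of the windowed update: count the positive and negative gradients over the last $n$ steps and step by $-\alpha$ times their normalized difference. This is only an illustration of the idea in Python, not GradSign itself; GradSign, described below, replaces the window with a quantized exponential moving average.</p>

```python
def sign_vote(gs):
    """(#positive - #negative) / n over the window: a number in [-1, 1].

    Only the signs of the gradients matter; magnitudes are ignored, and
    gradients that are exactly zero count as neither.
    """
    pos = sum(1 for g in gs if g > 0)
    neg = sum(1 for g in gs if g < 0)
    return (pos - neg) / len(gs)


def sign_update(gs, alpha=0.001):
    """Hypothetical sign-only analogue of the -alpha * cos(theta) update."""
    return -alpha * sign_vote(gs)


if __name__ == "__main__":
    window = [0.2, 5.0, -0.1, 0.3, 1e-6, 0.7, -2.0, 0.4]
    print(sign_vote(window))  # 6 positive, 2 negative -> 0.5
```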

<p>We’ll turn this idea into a new optimizer algorithm soon,
but first I want to wrap up a couple of loose ends.</p>

<h2 id="the-role-of-beta_1-and-beta_2">The role of $\beta_1$ and $\beta_2$</h2>

<p>Earlier on, we made the simplifying assumption that $\beta_1 = \beta_2$.
I told you that we were simplifying away something important, but I didn’t tell you what.
Now it’s time to fix that.</p>

<p>In the real world, some good values for $\beta_1$ and $\beta_2$ are 
\[  \beta_1 = 0.9  \]
\[  \beta_2 = 0.999.  \]
In other words,
the first moment estimate (the numerator) averages the gradient over the last 
10 or so timesteps,
while the second moment (the denominator) averages its square over the last 1000.</p>

<p>Normally, you might think, this won’t make a big difference.
But it makes a big difference when the gradient is <em>sparse</em>.</p>

<p>Imagine a parameter that is usually unimportant: its gradient is close to zero.
But every so often, the parameter becomes very important, and its gradient gets big.
(You might imagine that in a large, complex LLM, this one parameter is responsible for
learning one particular thing – and that one thing only rarely shows up in the training data.)</p>

<p>The role of $\beta_2$ is to remember that this gradient has a track record of sudden spikes.
When a gradient has spiked in the past, we don’t want to make updates based on small gradients.
But we also don’t want to continue making updates based on a gradient
from many steps back.
The solution is to make $\beta_2$ large (remember the spike and slow down learning for 1000 steps)
but keep $\beta_1$ small (stop making updates 10 steps after the spike).</p>
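<p>A small simulation shows this effect. The Python sketch below (hypothetical helper names, and no bias correction, as in the simplified formulas earlier) feeds Adam a small, steady gradient for 3000 steps, then a single spike, then the small gradient again. With $\beta_1 = 0.9$ and $\beta_2 = 0.999$, the update grows at the spike, but 100 steps later the small-gradient updates are much smaller than they were before it: $v_t$ still remembers the spike, while $m_t$ has already forgotten it.</p>

```python
import math


def adam_step(m, v, g, alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """Simplified single-parameter Adam step (no bias correction)."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    return m, v, -alpha * m / (math.sqrt(v) + eps)


def simulate_spike(quiet_steps=3000, small=0.01, spike=10.0, after=100):
    """Updates just before, at, and `after` steps past a one-off gradient spike."""
    m = v = 0.0
    before = 0.0
    for _ in range(quiet_steps):
        m, v, before = adam_step(m, v, small)
    m, v, at_spike = adam_step(m, v, spike)
    later = at_spike
    for _ in range(after):
        m, v, later = adam_step(m, v, small)
    return before, at_spike, later


if __name__ == "__main__":
    before, at_spike, later = simulate_spike()
    print(before, at_spike, later)
```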

<h2 id="adam-and-memory">Adam and memory</h2>

<p>The Adam optimizer stores two floating-point values (the moment estimates $m_t$ and $v_t$)
per parameter.
While the forward and backward passes can often be computed in 16-bit precision,
the Adam optimizer state is usually kept in 32 bits for each of $m_t$ and $v_t$.
(Naively storing optimizer state in 16-bit precision can lead to numerical instability and degrade training performance.)
In a typical training run, memory usage is dominated by per-parameter costs:
4 bytes for the (full-precision) master copy of the parameter, 
2 for the 16-bit downcasted parameter, 2 for the gradient, and 8 for the optimizer state –
so the optimizer is responsible for about half the total memory usage.</p>
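<p>The arithmetic is easy to script. This Python sketch (names hypothetical) uses the per-parameter byte counts from this paragraph, ignores activations and any other memory that does not scale with the parameter count, and also shows what a 1-byte optimizer state, like the one in the next section, would give.</p>

```python
# Per-parameter memory, in bytes, for the mixed-precision setup described above.
MODEL_BYTES = {
    "32-bit master copy of the parameter": 4,
    "16-bit copy of the parameter": 2,
    "16-bit gradient": 2,
}
ADAM_STATE_BYTES = 8  # 32-bit m_t plus 32-bit v_t


def per_param_bytes(optimizer_bytes):
    """Total bytes per parameter, given the optimizer's per-parameter state size."""
    return sum(MODEL_BYTES.values()) + optimizer_bytes


if __name__ == "__main__":
    total = per_param_bytes(ADAM_STATE_BYTES)
    print(total, ADAM_STATE_BYTES / total)   # 16 bytes, optimizer share 0.5
    n_params = 1_000_000_000                 # a hypothetical 1B-parameter model
    print(n_params * total / 1e9, "GB")      # 16.0 GB, not counting activations
    print(per_param_bytes(1))                # 9 bytes with a 1-byte optimizer state
```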

<p>Wouldn’t it be great if we could use less?</p>

<p>Adam uses 8 bytes per parameter.  Here is an optimizer that uses just 1.</p>

<h2 id="gradsign">GradSign</h2>

<p>GradSign is a simple proof-of-concept optimizer that:</p>
<ul>
  <li>only uses the sign (+ or -) of the gradient of each parameter, and</li>
  <li>only keeps one byte of optimizer data, per parameter.</li>
</ul>

<p>The update code is as follows:</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">named_params</span><span class="p">:</span>
        <span class="n">new_count</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="p">(</span><span class="n">p</span><span class="p">.</span><span class="n">grad</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">grad_counts</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">-=</span> <span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">grad_counts</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">+</span> <span class="mi">4</span><span class="p">)</span> <span class="o">//</span> <span class="mi">8</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">grad_counts</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">+=</span> <span class="mi">8</span> <span class="o">*</span> <span class="n">new_count</span>

        <span class="n">p</span><span class="p">.</span><span class="n">data</span> <span class="o">-=</span> <span class="bp">self</span><span class="p">.</span><span class="n">lr</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">grad_counts</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">/</span> <span class="mf">64.0</span><span class="p">)</span></code></pre></figure>

<p>The dictionary <code class="language-plaintext highlighter-rouge">self.grad_counts</code> stores, for each parameter,
a quantized exponential moving average of the sign of the gradient
over the past several timesteps.
The moving average has a characteristic timescale of 8 timesteps
(in other words, $\beta = 0.875$).
The quantized moving average is scaled to be between $-64$ and $64$, 
so that it fits within a signed 8-bit integer.</p>

<p>I chose these parameters to control quantization error in the decay term:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>self.grad_counts[n] -= (self.grad_counts[n] + 4) // 8
</code></pre></div></div>
<p>The decay term will be nonzero as soon as the value <code class="language-plaintext highlighter-rouge">self.grad_counts</code> exceeds $\pm 4$.
Each update to the average causes the count to change by $\pm 8$, 
so the <code class="language-plaintext highlighter-rouge">grad_counts</code> parameter cannot get stuck in the no-decay region.</p>

<h2 id="experiment">Experiment</h2>

<p>The experiment code and detailed results are posted on <a href="https://github.com/brian-lawrence-math/mnist">Github</a>.</p>

<p>In summary, I run GradSign on 1000 batches of 32 samples each,
which amounts to a single pass through just over half of the dataset.
The resulting model achieves over 98% accuracy.</p>]]></content><author><name></name></author><summary type="html"><![CDATA[The repo.]]></summary></entry></feed>