Optimizing matrix multiplication on an RTX 3050
I was working on mytorch, a mock tensor library for the GPU, and I decided to spend some time optimizing the matrix multiplication function. After all, matrix multiplication is the most expensive operation in deep learning, and optimizing batched matrix multiplication is something of a rite of passage in high-performance computing.
If you’re writing a production tensor library like PyTorch, your code needs to be fast in many different settings:
- on CPUs and GPUs of various architectures
- with tensors (batches of matrices) of various sizes
- with matrices in row- or column-major order
- with matrices that might not be contiguous in memory.
That’s a lot of engineering. I decided to start by optimizing a single task: I want to see how fast I can perform a batched matrix multiplication on two 32-bit float tensors of shape (100, 1000, 1000), in row-major order, on an RTX 3050 GPU.
Useful reading
I drew inspiration from a couple of terrific blog posts.
Some benchmarks
CUDA provides a highly optimized library for the Basic Linear Algebra Subprograms (BLAS), called cuBLAS. Matrix multiplication (GEMM) is one of the standard BLAS operations, and cuBLAS provides batched versions of it. Using cuBLAS, our computation takes 68-72 ms. (As always, the benchmark varies slightly from run to run, and somewhat more from day to day.)
My first attempt, with no optimizations, took 840 ms: about 12 times slower than cuBLAS.
As of this writing, I’ve sped the calculation up to 72 ms, achieving parity (or very close to it) with cuBLAS. I’m still working on further improvements.
Baseline analysis
The RTX 3050 has 18 streaming multiprocessors with 128 cores each, for a total of 2304 cores. It runs at a boost clock speed of 1.47 GHz, for a total of about 3.39 trillion core-cycles per second. Each core can do one fused multiply-add (two floating-point operations) per clock cycle, so the chip can achieve about 6.77 TFLOPS.
The matrix multiplication requires 2*100*1000*1000*1000 = 200 billion floating-point operations: each of the 100 million output entries needs a multiply and an add for each of its 1000 terms.
Assuming 6.77 TFLOPS (and zero overhead), that gives 200 billion / 6.77 trillion per second ≈ 30 ms for the batched matrix multiplication.
I’ll take 30 ms as the theoretical best time.
Of course, this doesn’t take into account the time required for memory access or any other computational overhead.
In my experiments, cuBLAS takes just over twice this estimated theoretical best time.
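The estimate above is simple arithmetic, but it is easy to drop a factor of two somewhere, so here is a small C++ sketch that reproduces it. The constants are the RTX 3050 figures quoted in this section; the function names are my own and purely illustrative.

```cpp
#include <cassert>
#include <cmath>

// RTX 3050 figures quoted in the text.
constexpr double kSMs = 18;
constexpr double kCoresPerSM = 128;
constexpr double kBoostHz = 1.47e9;
constexpr double kFlopsPerCoreCycle = 2.0;  // one fused multiply-add = 2 FLOPs

// Peak throughput: cores x clock x FLOPs per core per cycle.
constexpr double peak_flops() {
  return kSMs * kCoresPerSM * kBoostHz * kFlopsPerCoreCycle;
}

// A (batch, m, k) x (batch, k, n) product does batch*m*n*k multiply-adds.
constexpr double matmul_flops(double batch, double m, double n, double k) {
  return 2.0 * batch * m * n * k;
}

// Lower bound on runtime in milliseconds, ignoring memory and all overhead.
constexpr double best_time_ms(double flops) {
  return flops / peak_flops() * 1e3;
}
```

For the benchmark shape, `best_time_ms(matmul_flops(100, 1000, 1000, 1000))` comes out at about 29.5 ms, which the text rounds to 30 ms.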
A simple first attempt
I put together a simple CUDA C++ implementation of batched matrix multiplication on PyTorch-style tensors of arbitrary dimension, shape, and stride.
I assigned each thread to compute a single entry of the output tensor (so our example spins up 100 million threads), and I packed 1024 threads per block, the maximum allowed (for a total of about 98,000 blocks).
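To make that thread-to-entry mapping concrete, here is a CPU-side C++ sketch, assuming contiguous row-major (batch, row, col) tensors rather than the general shapes and strides the real kernel handles. `naive_entry` plays the role of one thread’s body and `tid` stands in for `blockIdx.x * blockDim.x + threadIdx.x`; both names are hypothetical, not mytorch’s.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// What one thread of the naive kernel computes: turn the flat thread index
// back into (batch, row, col), then write a single dot product of length K.
void naive_entry(const std::vector<float>& a, const std::vector<float>& b,
                 std::vector<float>& res, std::size_t tid, std::size_t batch,
                 std::size_t M, std::size_t K, std::size_t N) {
  if (tid >= batch * M * N) return;  // threads past the end do nothing
  std::size_t bi = tid / (M * N);
  std::size_t row = (tid / N) % M;
  std::size_t col = tid % N;
  float sum = 0.0f;
  for (std::size_t k = 0; k < K; k++) {
    sum += a[(bi * M + row) * K + k] * b[(bi * K + k) * N + col];
  }
  res[(bi * M + row) * N + col] = sum;
}

// Number of blocks needed so that every output entry gets a thread.
std::size_t num_blocks(std::size_t entries, std::size_t threads_per_block) {
  return (entries + threads_per_block - 1) / threads_per_block;
}
```

Looping `tid` over `num_blocks(entries, 1024) * 1024` values emulates the whole grid; for the benchmark shape, `num_blocks` gives 97,657 blocks, the “about 98,000” above.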
The calculation took ~840 ms: 28 times the theoretical best time, and 12 times slower than cuBLAS.
Looks like we have some optimizing to do.
First thoughts
The kernel is probably memory-bound, not compute-bound. It requires a total of 200 billion floating-point operations, but 800 GB of memory reads. (At each step in the matrix-multiplication loop, the kernel reads one float from each matrix and does one multiplication and one addition: two FLOPs and 8 bytes.) Memory reads are much more expensive than compute: the chip can compute about 6.8 TFLOPS, but memory bandwidth is only about 168 GB/s. If all 800 GB really came from DRAM, the reads alone would take almost 5 seconds; the fact that the kernel finishes in 840 ms means the caches are serving most of them. Even so, I still expect memory to be the bottleneck.
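The same reasoning can be phrased as a roofline-style comparison: a kernel whose arithmetic intensity (FLOPs per byte of DRAM traffic) is below the chip’s ratio of peak FLOPS to memory bandwidth is limited by memory, not compute. This C++ sketch uses only the numbers quoted in this post; the names are illustrative.

```cpp
#include <cassert>
#include <cmath>

// Figures quoted in the text for the RTX 3050.
constexpr double kPeakFlops = 6.77e12;    // FLOP/s
constexpr double kDramBandwidth = 168e9;  // bytes/s

// Naive inner loop: per step, one float from each input (8 bytes)
// and one multiply plus one add (2 FLOPs).
constexpr double kNaiveIntensity = 2.0 / 8.0;  // FLOP per byte

// Machine balance: FLOPs the chip can do per byte it can fetch from DRAM.
// Kernels with lower intensity than this are memory-bound.
constexpr double machine_balance() { return kPeakFlops / kDramBandwidth; }

// Time to move `bytes` at full DRAM bandwidth, in milliseconds.
constexpr double dram_time_ms(double bytes) {
  return bytes / kDramBandwidth * 1e3;
}

// Time to do `flops` at peak throughput, in milliseconds.
constexpr double compute_time_ms(double flops) {
  return flops / kPeakFlops * 1e3;
}
```

The naive kernel’s 0.25 FLOP/byte is far below the balance point of about 40, and `dram_time_ms(8e11)` comes out at roughly 4,760 ms versus under 30 ms of pure compute.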
The natural idea is to try thread coarsening or tiling. I want to load some input data once and compute on it many times. If one thread is responsible for computing a small box of matrix entries, rather than just one, that thread can use its memory access more efficiently. And by using shared memory (low-latency memory which is shared among all threads in a block) I can arrange for several threads to collaborate using data that is only loaded (from slow global memory) once.
But first I want to pick some low-hanging fruit.
Cutting down on shape-and-stride calculations
The kernel includes logic to handle arbitrary shapes and strides. I think this logic is imposing a lot of unnecessary cost.
To start with, the shape-and-stride calculation is being done in the kernel: each of my 100 million threads is reading the shapes and strides of the input tensors and computing the index of the one entry it needs to access. That’s a lot of repeated calculation. Worse yet, the shapes and strides are stored in vectors whose length (the dimension of the tensor) is unknown at compile time. This means that the vectors are stored in global memory, resulting in a lot of unnecessary memory access. (Actually, since these values are accessed so often, they are probably stored in a low-level cache…)
Most matrix multiplication in practice works on batched matrices with a simple structure: the matrices are contiguous in memory, in row-major order; and the “matrix dimensions” are the two dimensions with the smallest strides. So, I’ll try to optimize a simple case first: a batched matrix product of two three-dimensional (batch, row, col) tensors.
In any case, I wrote a new kernel, matmul_3d(), that assumes its inputs are contiguous three-dimensional tensors and accepts the shape directly as function arguments. The result: from 840 ms down to 800 ms.
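To make the difference concrete, here is a C++ sketch of the two ways of locating an element. `strided_offset` is roughly the per-thread work a fully general kernel has to do (a loop over shape and stride arrays whose length is only known at runtime); `contiguous3d_offset` is what the specialized 3D case reduces to. Both functions are illustrations, not code from mytorch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// General case: unravel a flat index over `shape` (last dimension fastest),
// then dot the coordinates with `strides`. The loop length is the tensor's
// dimension, which is only known at runtime.
std::size_t strided_offset(std::size_t flat,
                           const std::vector<std::size_t>& shape,
                           const std::vector<std::size_t>& strides) {
  std::size_t offset = 0;
  for (std::size_t d = shape.size(); d-- > 0;) {
    offset += (flat % shape[d]) * strides[d];
    flat /= shape[d];
  }
  return offset;
}

// Special case: a contiguous row-major (batch, rows, cols) tensor.
// The strides are implied by the shape, so no arrays need to be read.
std::size_t contiguous3d_offset(std::size_t b, std::size_t i, std::size_t j,
                                std::size_t rows, std::size_t cols) {
  return (b * rows + i) * cols + j;
}
```

For a contiguous tensor the two functions agree; the general one is what every thread was paying for, even when the answer was just `(b * rows + i) * cols + j`.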
OK, let’s try to get more economies of scale.
Profiling
NVIDIA provides a powerful profiler, Nsight Compute (ncu). The profiler shows all sorts of metrics, including memory throughput, cache (L1 and L2) throughput, compute throughput, occupancy, and workload statistics. It even offers suggestions for optimization.
In my experience, at this stage, it’s best to focus on writing efficient code. The profiler is a great tool later on, when I have specific questions about resource usage.
For example, the profiler’s first suggestion is that I change the number of threads per block to increase occupancy. That’s not the biggest priority at this point, and in fact, if I make the change the profiler suggests, performance gets worse.
So: I’m going to ignore ncu’s advice and get back to writing solid code. The big bottleneck is memory access, so that’s where I’m going to focus. But the profiler will come back later on.
Improving memory efficiency
To start with, I’m going to make two improvements to the kernel.
- Load inputs into shared memory, in batches, and
- make each thread responsible for more than one output entry.
I’m going to make configurable parameters for:
- TM and TN: each thread computes a TM-by-TN set of output values;
- TPBM and TPBN: each block has TPBM-by-TPBN threads; and
- BK: the chunk size of the multiplication loop.
This last parameter needs some explanation. Each block of threads is responsible for computing a (TM*TPBM)-by-(TN*TPBN) submatrix of the output matrix. To do this, it needs TM*TPBM full rows of the first input and TN*TPBN full columns of the second; then it loops over the columns of the first input (and the rows of the second).
There might not be enough room in shared memory for all the rows and columns that need to be loaded. So, instead of being loaded in full, they will be loaded in blocks of BK. In other words: matrix multiplication involves a summation over the intermediate dimension; we will break that summation into chunks of size BK, and compute partial sums one chunk at a time.
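The chunked summation can be checked on the CPU. This C++ sketch computes one output entry the way a block does: stage BK elements of a row of A and a column of B into small buffers (standing in for shared memory), padding with zeros past the end of K, and accumulate one partial sum per chunk. The function name and the use of `std::array` as a stand-in for shared memory are my own.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t BK = 16;

// Dot product of a row of A (length K) with a column of B (row length n),
// computed BK elements at a time. Entries past K are staged as zeros, which
// add nothing to the sum, so K need not be a multiple of BK.
float chunked_dot(const float* a_row, const float* b_col, std::size_t n,
                  std::size_t K) {
  float sum = 0.0f;
  for (std::size_t k0 = 0; k0 < K; k0 += BK) {
    std::array<float, BK> a_chunk{}, b_chunk{};  // stand-ins for A_s / B_s
    for (std::size_t j = 0; j < BK; j++) {
      bool in_range = k0 + j < K;
      a_chunk[j] = in_range ? a_row[k0 + j] : 0.0f;
      b_chunk[j] = in_range ? b_col[(k0 + j) * n] : 0.0f;
    }
    for (std::size_t j = 0; j < BK; j++) {
      sum += a_chunk[j] * b_chunk[j];  // partial sum for this chunk
    }
  }
  return sum;
}
```

Because the padded entries are zeros, they contribute nothing, which is why a tiled kernel can zero-fill the out-of-range parts of its tiles instead of special-casing the last chunk.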
Improving memory access patterns
Global memory is stored in DRAM; shared memory (and the L1 cache) is stored in SRAM. DRAM reads memory in consecutive 32-byte chunks; if I don’t use all 32 bytes, I’m wasting bandwidth. SRAM is divided into 32 banks, each 4 bytes wide (so, for example, bank 0 is responsible for byte addresses 0, 1, 2, 3 modulo 128). In a single read, SRAM can read any 4-byte word from each of its 32 banks, independently. So we want each thread in a warp to read from a different bank (threads that read the very same word are fine: that value is broadcast). If multiple threads request different words from the same bank, the result is a “bank conflict”: the SRAM has to perform multiple physical reads before the result can be returned.
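Both rules are simple enough to compute. This C++ sketch takes the byte address each of a warp’s 32 threads touches and reports how many distinct 32-byte DRAM sectors the warp hits, and the worst-case number of distinct 4-byte words requested from any one shared-memory bank. It models only the sector and bank sizes described above and ignores caching and accesses wider than 4 bytes; the names are hypothetical.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>

constexpr std::size_t kWarp = 32;
constexpr std::size_t kSectorBytes = 32;  // DRAM access granularity
constexpr std::size_t kBanks = 32;        // shared-memory banks
constexpr std::size_t kBankWidth = 4;     // bytes per bank word

using WarpAddrs = std::array<std::uint64_t, kWarp>;  // byte address per thread

// Distinct 32-byte sectors touched by one warp-wide access.
std::size_t sectors_touched(const WarpAddrs& addr) {
  std::set<std::uint64_t> sectors;
  for (auto x : addr) sectors.insert(x / kSectorBytes);
  return sectors.size();
}

// Worst-case conflict degree: the most distinct words mapped to one bank.
// Threads reading the same word share a broadcast, so duplicates don't count.
// 1 means conflict-free.
std::size_t bank_conflict_degree(const WarpAddrs& addr) {
  std::array<std::set<std::uint64_t>, kBanks> words_per_bank;
  for (auto x : addr) {
    std::uint64_t word = x / kBankWidth;
    words_per_bank[word % kBanks].insert(word);
  }
  std::size_t worst = 0;
  for (const auto& s : words_per_bank) worst = std::max(worst, s.size());
  return worst;
}

// Thread t reads float number (start + t * stride).
WarpAddrs float_pattern(std::size_t start, std::size_t stride) {
  WarpAddrs addr{};
  for (std::size_t t = 0; t < kWarp; t++) addr[t] = 4 * (start + t * stride);
  return addr;
}
```

For example, `float_pattern(0, 1)` (32 consecutive floats) touches 4 sectors with no bank conflicts, while `float_pattern(0, 32)` (one float per row of a 32-wide array) touches 32 sectors and needs 32 separate shared-memory reads.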
In both situations, a good pattern is for the 32 threads in a warp to access consecutive floats in memory:
```cuda
int idx = threadIdx.x;  // thread i touches float i: a warp covers 32 consecutive floats
data[idx] ...
```
First, I’ll make sure each row of threads in a block is 32 threads wide (at least when the matrices have 32 or more columns); this means each warp is exactly one row.
Now let’s plan how to arrange memory and threads. As far as memory:
- The input tensors are already laid out contiguously in row-major order; we can’t change that;
- The result tensor is also in row-major order; we can’t touch it either;
- But the “shared” tensors (copies of tiles of input tensors that reside in shared memory) can be arranged how we like.
And as far as thread arrangement, we have to decide how to divide up each of these three operations among threads in the block:
- Copy input tensor a into shared memory;
- Copy input tensor b into shared memory;
- Perform the “multiplication loop” to compute entries of output tensor.
Let’s start with the entries of the output tensor. Conceptually it looks something like the following.
```cuda
float cml_sum = 0.0f;
for (int loop_idx = 0; ... ) {
  cml_sum += a_shared[row][loop_idx] * b_shared[loop_idx][col];
}
result[row][col] += cml_sum;
```
A natural choice is to have consecutive threads operate on the same “row” and consecutive “col” values. This way the global memory writes in the last line are efficient, with all 32 threads in the warp writing to one 128-byte line of global memory (or two lines, if the alignment isn’t right). Assuming b_shared is stored in row-major order, the shared memory reads inside the loop are good as well: all threads read the same entry from a_shared, which is efficient (it’s called “broadcasting”), and the 32 threads read 32 consecutive entries of b_shared.
As for copying global ‘a’ and ‘b’ into shared ‘a_shared’ and ‘b_shared’: it’s the same idea. I store ‘a_shared’ and ‘b_shared’ in row-major order, so data that is contiguous in ‘a’ is also contiguous in ‘a_shared’. Then I arrange for all the threads in the block to handle consecutive floats, one float each.
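Here is a C++ sketch of that cooperative copy’s index math, with the block simulated by a loop over thread ids. It assumes the “one float per thread per pass, strided by the block size” loop used in the kernel at the end of this post; the helper names are hypothetical. It checks that every tile element is copied exactly once, and counts how many separate stretches of global memory one warp reads in a pass.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simulate `threads` threads copying a rows x cols tile: thread t copies
// tile elements t, t + threads, t + 2*threads, ...
// Returns how many times each tile element was copied.
std::vector<int> copy_coverage(std::size_t rows, std::size_t cols,
                               std::size_t threads) {
  std::vector<int> hits(rows * cols, 0);
  for (std::size_t t = 0; t < threads; t++)
    for (std::size_t idx = t; idx < rows * cols; idx += threads) hits[idx]++;
  return hits;
}

// For one pass, count the contiguous runs of global offsets read by the
// 32 threads starting at `warp_start`, when tile rows of width `cols` come
// from a row-major matrix with row length `ld`. 1 means one unbroken stretch.
// Assumes all 32 indices fall inside the tile.
std::size_t warp_runs(std::size_t pass, std::size_t warp_start,
                      std::size_t cols, std::size_t ld, std::size_t threads) {
  std::size_t runs = 0, prev = 0;
  for (std::size_t lane = 0; lane < 32; lane++) {
    std::size_t idx = pass * threads + warp_start + lane;  // tile element
    std::size_t global = (idx / cols) * ld + idx % cols;   // matrix offset
    if (lane == 0 || global != prev + 1) runs++;
    prev = global;
  }
  return runs;
}
```

With a 128-by-16 tile (the A tile shape for the parameter values chosen below) a warp spans two tile rows, so `warp_runs` reports two 64-byte stretches; with a 128-wide tile it reports one 128-byte stretch.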
Thread-local results in registers
The kernel is still suffering from global memory writes inside a tight loop: if you look back up at the “conceptual” code, the line `result[row][col] += cml_sum;` entails a read and a write to slow global DRAM.
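Better: store per-thread results in registers until the calculation is done. In pseudocode:
```cuda
float tmp[TM][TN] = {};  // per-thread accumulators, kept in registers
for (k ...) {
  for (m_ctr = 0; m_ctr < TM; m_ctr++) {
    for (n_ctr = 0; n_ctr < TN; n_ctr++) {
      int m = m_ctr * TPBM + thread_row;  // this thread's rows are TPBM apart
      int n = n_ctr * TPBN + thread_col;  // this thread's cols are TPBN apart
      tmp[m_ctr][n_ctr] += A_s[m][k] * B_s[k][n];
    }
  }
}
// only now touch global memory: one write per output entry
for (m_ctr = 0; m_ctr < TM; m_ctr++) {
  for (n_ctr = 0; n_ctr < TN; n_ctr++) {
    result[...] = tmp[m_ctr][n_ctr];
  }
}
```
Parameter tuning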
At this point I have a bunch of different parameters:
- the dimensions of a per-thread tile, TM and TN;
- the number of threads per block in the row and column dimensions, TPBM and TPBN;
- the chunk size BK of the loop over $k$.
After some manual experimentation with these parameters, I settled on the following values.
```cuda
#define TM (8)
#define TN (8)
#define TPBM (16)
#define TPBN (16)
#define BM (TM * TPBM)
#define BN (TN * TPBN)
#define BK (16)
```
My batched matrix multiplication is now at 107 ms. That’s still about 1.5x the time cuBLAS takes, but I have one more trick up my sleeve.
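Before moving on, here is a CPU emulation in C++ of the scheme these parameters describe, which is a cheap way to check the index math without a GPU. It reuses the values above and mirrors the thread-to-output mapping of the kernel shown at the end of this post: thread (r, c) of a block owns rows r, r + TPBM, r + 2*TPBM, ... and columns c, c + TPBN, ... of the block’s BM-by-BN tile, accumulates them in a private `tmp` array, and writes each entry once. Shared memory and the BK chunking are not modeled, and `emulate_tiled` is a hypothetical name. The `static_assert`s spell out what the parameters imply: 256 threads per block and 16 KB of shared memory for the two tiles.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

#define TM (8)
#define TN (8)
#define TPBM (16)
#define TPBN (16)
#define BM (TM * TPBM)
#define BN (TN * TPBN)
#define BK (16)

static_assert(TPBM * TPBN == 256, "threads per block");
static_assert((BM * BK + BK * BN) * sizeof(float) == 16 * 1024,
              "shared memory for A_s and B_s");
static_assert(BK % 4 == 0 && BN % 4 == 0, "float4 loads need multiples of 4");

// Emulate every thread of every block for one matrix product
// C = A * B, with A (M x K), B (K x N), C (M x N), all row-major.
void emulate_tiled(const std::vector<float>& A, const std::vector<float>& B,
                   std::vector<float>& C, std::size_t M, std::size_t K,
                   std::size_t N) {
  for (std::size_t row_base = 0; row_base < M; row_base += BM)    // blockIdx.y
    for (std::size_t col_base = 0; col_base < N; col_base += BN)  // blockIdx.z
      for (std::size_t t = 0; t < TPBM * TPBN; t++) {             // threadIdx.x
        std::size_t tr = t / TPBN, tc = t % TPBN;
        float tmp[TM * TN] = {0};  // this thread's register tile
        for (std::size_t k = 0; k < K; k++)
          for (std::size_t m = 0; m < TM; m++) {
            std::size_t row = row_base + m * TPBM + tr;
            float a = row < M ? A[row * K + k] : 0.0f;  // zero padding
            for (std::size_t n = 0; n < TN; n++) {
              std::size_t col = col_base + n * TPBN + tc;
              float b = col < N ? B[k * N + col] : 0.0f;
              tmp[m * TN + n] += a * b;
            }
          }
        // one write per owned entry, skipping entries past the matrix edge
        for (std::size_t m = 0; m < TM; m++)
          for (std::size_t n = 0; n < TN; n++) {
            std::size_t row = row_base + m * TPBM + tr;
            std::size_t col = col_base + n * TPBN + tc;
            if (row < M && col < N) C[row * N + col] = tmp[m * TN + n];
          }
      }
}
```

Comparing against a naive triple loop on shapes that are not multiples of 128 (say a 130 × 20 matrix times a 20 × 131 matrix) exercises both the strided ownership pattern and the edge handling.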
Vectorized loads
GPUs offer a single instruction to read 128 bits (16 bytes, or 4 floats) at once. The catch is that the read address has to be 16-byte aligned. If you promise the compiler that your address is 16-byte aligned, the compiler will give you vectorized reads and writes. One easy way to do this, syntactically, is with CUDA’s float4 type: a struct of four floats (x, y, z, w), aligned to a multiple of 16 bytes.
Here’s some mock code to show how to do it: if you’re copying a bunch of floats
```cuda
float *A, *B;
for (int i = 0; i < N; i++) {
  A[i] = B[i];
}
```
you can simply cast the pointers from float* to float4*, like so:
```cuda
// gentle reminder that things need to be aligned (both pointers!)
assert((uintptr_t) A % 16 == 0);
assert((uintptr_t) B % 16 == 0);
// and that I'm not handling edge effects
assert(N % 4 == 0);
for (int i = 0; i < N; i += 4) {
  *reinterpret_cast<float4*>(&A[i]) = *reinterpret_cast<float4*>(&B[i]);
}
```
In other words, we’ve replaced $N$ assignments float = float with $N/4$ assignments float4 = float4.
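The snippet above deliberately skips the edge cases. For reference, here is a host-side C++ sketch of the usual way to handle them: copy in 16-byte chunks while the addresses allow it, then finish with a scalar loop for the leftover floats. Plain C++ has no float4, so the sketch defines its own 16-byte aligned stand-in, `Vec4`, and moves data with `std::memcpy`; whether that becomes a single vector load and store is up to the compiler. To keep it short, it only takes the chunked path when both pointers are 16-byte aligned, and otherwise copies everything as scalars.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Stand-in for CUDA's float4: four floats, 16-byte aligned.
struct alignas(16) Vec4 { float x, y, z, w; };

bool aligned16(const void* p) {
  return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
}

// Copy n floats from src to dst. If both pointers are 16-byte aligned,
// move four floats at a time; the remaining n % 4 floats (or all of them,
// if either pointer is misaligned) are copied one by one.
void copy_floats(float* dst, const float* src, std::size_t n) {
  std::size_t i = 0;
  if (aligned16(dst) && aligned16(src)) {
    for (; i + 4 <= n; i += 4) {
      Vec4 v;
      std::memcpy(&v, src + i, sizeof v);  // 16 bytes in
      std::memcpy(dst + i, &v, sizeof v);  // 16 bytes out
    }
  }
  for (; i < n; i++) dst[i] = src[i];  // scalar tail
}
```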
And here’s what the change looks like in my kernel.
There’s all sorts of messy indexing going on thanks to tiling in various dimensions,
but notice how simple it is to change plain float loads to vectorized float4 loads:
I really just multiply the step size by 4 and hit the pointers with reinterpret_cast.
Before:
```cuda
// A_s[i, j] = A[a_row + i, k + j]
for (size_t a_idx = threadIdx.x; a_idx < BM * BK; a_idx += blockDim.x) {
  size_t i = a_idx / BK;
  size_t j = a_idx % BK;
  if (k0 + j < K && i + a_row_base < M) {
    A_s[i * BK + j] = A[i * (a.shape[2]) + k0 + j];
  } else {
    A_s[i * BK + j] = 0;
  }
}
```
After:
```cuda
// A_s[i, j] = A[a_row + i, k + j]
// Each thread now moves 4 floats per step. Assumes K % 4 == 0 and BK % 4 == 0,
// so j is a multiple of 4 and a float4 never straddles the end of a row.
for (size_t a_idx = 4 * threadIdx.x; a_idx < BM * BK; a_idx += 4 * blockDim.x) {
  size_t i = a_idx / BK;
  size_t j = a_idx % BK;
  if (k0 + j < K && i + a_row_base < M) {
    *reinterpret_cast<float4*>(&A_s[i * BK + j]) =
        *reinterpret_cast<float4*>(&A[i * (a.shape[2]) + k0 + j]);
  } else {
    A_s[i * BK + j] = 0;
    A_s[i * BK + j + 1] = 0;
    A_s[i * BK + j + 2] = 0;
    A_s[i * BK + j + 3] = 0;
  }
}
```
Of course, there’s some additional boilerplate to check bounds and alignment. NVIDIA guarantees that cudaMalloc returns memory aligned to at least 256 bytes, but we still need to worry about the dimensions of the matrix (which could be arbitrary): the row lengths must be multiples of 4 so that every row starts on a 16-byte boundary, and we’d better choose the parameter $BK$ to be a multiple of 4 as well. I’m deliberately running this initial benchmark on 1000-by-1000 matrices (1000 is a multiple of 4) so I don’t have to worry about these edge effects.
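Collected in one place, the conditions from this paragraph might look like the host-side check below, which a launcher could use to pick between a float4 kernel and a scalar one. This is a C++ sketch; `can_use_float4` and its parameters are hypothetical names, not part of mytorch. It relies on one fact worth spelling out: if a matrix starts on a 16-byte boundary and its row length is a multiple of 4 floats, then every row (and every matrix in a contiguous batch) also starts on a 16-byte boundary.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Can the tiled kernel load A (M x K) and B (K x N) tiles with float4?
// Every float4 starts at a row start plus a multiple of 4 floats, so we need:
//  - both base pointers 16-byte aligned (memory straight from cudaMalloc is),
//  - row lengths K (for A) and N (for B) divisible by 4, so every row start
//    stays 16-byte aligned,
//  - tile widths bk and bn divisible by 4, so tile offsets do too.
// Under these conditions a float4 also never straddles the end of a row.
bool can_use_float4(const float* a, const float* b, std::size_t K,
                    std::size_t N, std::size_t bk, std::size_t bn) {
  auto aligned16 = [](const float* p) {
    return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
  };
  return aligned16(a) && aligned16(b) && K % 4 == 0 && N % 4 == 0 &&
         bk % 4 == 0 && bn % 4 == 0;
}
```

For the benchmark shape (K = N = 1000, BK = 16, BN = 128) on memory from cudaMalloc, the check passes; a row length of 1001 or a BK of 18 would send a launcher to the scalar path.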
Vectorized loads from global into shared memory bring the runtime down to 72 ms: we’ve achieved parity with cuBLAS (on a good day).
The existing kernel, in code
Here it is.
```cuda
__global__ void matmul_tiled_2(ContiguousTensor3d_Device a,
                               ContiguousTensor3d_Device b,
                               ContiguousTensor3d_Device res) {
  // blocks: (batch, row, col)
  // threads: (n_threads, 1, 1) -- I will manage indexing myself
  size_t M = a.shape[1], K = a.shape[2], N = b.shape[2];
  size_t a_row_base = BM * blockIdx.y;
  size_t b_col_base = BN * blockIdx.z;
  size_t batch_idx = blockIdx.x;
  float *A =
      a.data + batch_idx * a.shape[1] * a.shape[2] + a_row_base * a.shape[2];
  float *B = b.data + batch_idx * b.shape[1] * b.shape[2] + b_col_base;
  float *RES = res.data + batch_idx * res.shape[1] * res.shape[2] +
               a_row_base * res.shape[2] + b_col_base;
  // __align__(16) so the float4 stores into shared memory below are aligned
  __shared__ __align__(16) float A_s[BM * BK];
  __shared__ __align__(16) float B_s[BK * BN];
  // store results in registers
  // each thread will be responsible for TM * TN entries of C...
  // TM rows and TN cols, the rows strided every TPBM, the cols strided every
  // TPBN
  float tmp[TM * TN] = {0};
  size_t this_thread_row = threadIdx.x / TPBN;
  size_t this_thread_col = threadIdx.x % TPBN;
  for (size_t k0 = 0; k0 < K; k0 += BK) {
    // load a and b
    // A_s[i, j] = A[a_row + i, k + j]
    // values to fill: BM * BK
    // threads: TPB
    // want: blockDim.x div. by BK
    for (size_t a_idx = 4 * threadIdx.x; a_idx < BM * BK; a_idx += 4 * blockDim.x) {
      size_t i = a_idx / BK;
      size_t j = a_idx % BK;
      if (k0 + j < K && i + a_row_base < M) {
        *reinterpret_cast<float4*>(&A_s[i * BK + j]) =
            *reinterpret_cast<float4*>(&A[i * (a.shape[2]) + k0 + j]);
      } else {
        A_s[i * BK + j] = 0;
        A_s[i * BK + j + 1] = 0;
        A_s[i * BK + j + 2] = 0;
        A_s[i * BK + j + 3] = 0;
      }
    }
    // B_s[i, j] = B[k + i, b_col + j]
    for (size_t b_idx = 4 * threadIdx.x; b_idx < BK * BN; b_idx += 4 * blockDim.x) {
      size_t i = b_idx / BN;
      size_t j = b_idx % BN;
      if (k0 + i < K && j + b_col_base < N) {
        *reinterpret_cast<float4*>(&B_s[i * BN + j]) =
            *reinterpret_cast<float4*>(&B[(k0 + i) * b.shape[2] + j]);
      } else {
        B_s[i * BN + j] = 0;
        B_s[i * BN + j + 1] = 0;
        B_s[i * BN + j + 2] = 0;
        B_s[i * BN + j + 3] = 0;
      }
    }
    __syncthreads();
    for (size_t k = 0; k < BK; k++) {
      float B_reg[TN];
      for (size_t col_counter = 0; col_counter < TN; col_counter++) {
        size_t col = col_counter * TPBN + this_thread_col;
        B_reg[col_counter] = B_s[k * BN + col];
      }
      for (size_t row_counter = 0; row_counter < TM; row_counter++) {
        // filling in row a_row_base + row_counter * TPBM + this_thread_row of C
        // which is row row_counter of tmp
        size_t row = row_counter * TPBM + this_thread_row;
        float a_val = A_s[row * BK + k];
        for (size_t col_counter = 0; col_counter < TN; col_counter++) {
          // col b_col_base + col_counter * TPBN + this_thread_col of C
          // which is col col_counter of tmp
          float b_val = B_reg[col_counter];
          tmp[row_counter * TN + col_counter] += a_val * b_val;
        }
      }
    }
    __syncthreads();
  }
  // now tmp[row * TN + col] goes into RES[(row * TPBM + this_thread_row) *
  // res.shape[2] + (col * TPBN + this_thread_col)]
  for (size_t row_counter = 0; row_counter < TM; row_counter++) {
    size_t row = row_counter * TPBM + this_thread_row;
    for (size_t col_counter = 0; col_counter < TN; col_counter++) {
      size_t col = col_counter * TPBN + this_thread_col;
      // Skip entries past the edge of C. When M or N is not a multiple of
      // BM or BN (1000 vs. 128 here), the last tiles hang over the edge, and
      // writing there would clobber entries owned by other blocks.
      if (a_row_base + row < M && b_col_base + col < N) {
        RES[row * res.shape[2] + col] = tmp[row_counter * TN + col_counter];
      }
    }
  }
}
```
More optimizations
This is still a work in progress. The most promising next optimizations are…
- double-buffering, and
- a thorough parameter sweep.