The FlashAttention CUDA Kernel Line by Line
Flash Attention is a memory-efficient algorithm for computing attention in transformers. Let's break down the CUDA implementation block by block.
The core idea of Flash Attention is to process attention in tiles while preserving numerical stability by tracking running row maxima and partial sums. Because the full N x N attention matrix is never materialized, the algorithm is memory-efficient, yet it produces results mathematically equivalent to standard attention.
First we need a small helper macro. CEIL_DIV computes the ceiling of an integer division, which the kernel uses whenever a size is not an exact multiple of a tile dimension.
#include <cuda.h>
#include <cuda_runtime.h>
#define CEIL_DIV(x, y) ((x) >= 0 ? (((x) + (y) - 1) / (y)) : ((x) / (y)))
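For example, CEIL_DIV(65, 32) evaluates to 3, whereas plain integer division 65 / 32 truncates to 2; the kernel relies on this rounding-up behaviour so that partially filled tiles are still covered.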
Now we need the kernel declaration:
template <const int Br, const int Bc>
__global__ void flash_attn_kernel(float* Q, float* K, float* V,
                                  int N, int d, int Tr, int Tc,
                                  float scale, float* l, float* m, float* O)
This declares a CUDA kernel with template parameters Br and Bc, the block (tile) sizes along the row and column dimensions. The function takes:
- Q, K, V: Query, Key, and Value matrices
- N: Sequence length
- d: Head dimension
- Tr, Tc: Number of blocks in rows and columns
- scale: Scaling factor for attention scores
- l, m: Running row sums and row maxima used by the numerically stable (online) softmax
- O: Output matrix
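Before looking at the kernel body, it helps to see how these arguments might be produced on the host. The sketch below is an assumption for illustration (the helper name, the struct, and any concrete tile sizes are not part of the original code); the one requirement the kernel itself imposes is that every row of m starts at -INFINITY and every row of l at 0, since they are read as the running maximum and sum on the first key/value tile.

#include <math.h>

// Hypothetical host-side helper: derive the scalar kernel arguments.
struct LaunchScalars { int Tr, Tc; float scale; };

LaunchScalars prepare_scalars(int N, int d, int Br, int Bc) {
    LaunchScalars s;
    s.Tr = CEIL_DIV(N, Br);              // number of query tiles
    s.Tc = CEIL_DIV(N, Bc);              // number of key/value tiles
    s.scale = 1.0f / sqrtf((float)d);    // standard 1/sqrt(d) scaling of Q K^T
    return s;
}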
Thread and Block Index Setup
int tx = threadIdx.x;
int bx = blockIdx.x;
int by = blockIdx.y;
int qkv_off = (bx * gridDim.y * N * d) + (by * N * d);
int lm_off = (bx * gridDim.y * N) + (by * N);
This section reads the thread index (tx) and the block indices (bx, by). The gridDim.y factor in qkv_off implies that bx selects the batch and by the attention head, so qkv_off is the start of this (batch, head) slice inside Q, K and V, and lm_off is the start of the matching rows inside the auxiliary arrays l and m. Every address computed later in the kernel is relative to these two offsets.
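In other words, the indexing assumes Q, K and V are stored contiguously as [batch, heads, N, d] (and l, m as [batch, heads, N]). A hedged sketch of the implied addressing, with qkv_index being a hypothetical helper rather than part of the kernel:

// Assuming a contiguous [batch, heads, N, d] layout and a (batch, heads) grid:
__device__ __forceinline__ int qkv_index(int b, int h, int heads,
                                         int N, int d, int row, int col) {
    // equals qkv_off + row * d + col when b = blockIdx.x and h = blockIdx.y
    return ((b * heads + h) * N + row) * d + col;
}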
Shared Memory Setup
extern __shared__ float smem[];
float* Qi = smem;
float* Kj = Qi + Br * d;
float* Vj = Kj + Bc * d;
float* Sij = Vj + Bc * d;
float* Oi = Sij + Br * Bc;
float* li = Oi + Br * d;
float* li_new = li + Br;
float* mi = li_new + Br;
float* mi_new = mi + Br;
float* mij_dash = mi_new + Br;
This section partitions the shared memory into different regions for:
- Qi: Current block of Query matrix
- Kj: Current block of Key matrix
- Vj: Current block of Value matrix
- Sij: Attention scores
- Oi: Output accumulator
- li, li_new, mi, mi_new, mij_dash: Softmax computation variables
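Because the kernel carves all of this out of dynamic shared memory, the host must pass a byte count that matches the layout exactly. A hedged launch sketch, assuming one block per (batch, head) pair and Br * Bc threads per block (both consistent with the indexing in the kernel; the launcher itself is not shown in the original, and it reuses the names from the earlier host sketch):

constexpr int Br = 32, Bc = 32;                   // assumed tile sizes
size_t smem_bytes = sizeof(float) *
    (2 * Br * d      // Qi and Oi
   + 2 * Bc * d      // Kj and Vj
   + Br * Bc         // Sij
   + 5 * Br);        // li, li_new, mi, mi_new, mij_dash

dim3 grid(batch, heads);                          // one block per (batch, head)
dim3 block(Br * Bc);                              // one thread per score-tile element
flash_attn_kernel<Br, Bc><<<grid, block, smem_bytes>>>(
    Q, K, V, N, d, Tr, Tc, scale, l, m, O);

If smem_bytes ever exceeds the default 48 KB per-block limit, the kernel's cudaFuncAttributeMaxDynamicSharedMemorySize attribute has to be raised via cudaFuncSetAttribute before launch.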
Main Processing Loops
The kernel uses two nested loops:
- An outer loop over Key/Value blocks (j)
- An inner loop over Query blocks (i)
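Before diving into the individual fragments, here is a rough schematic of the control flow they fit into (braces and synchronization points as implied by the excerpts below, not copied verbatim from the source):

for (int j = 0; j < Tc; j++) {            // outer loop over key/value tiles
    // load Kj, Vj into shared memory, then __syncthreads()
    for (int i = 0; i < Tr; i++) {        // inner loop over query tiles
        // load Qi, Oi and the running statistics li, mi, then __syncthreads()
        // compute the Br x Bc score tile Sij = scale * Qi * Kj^T
        // per-row online softmax: exponentiate Sij, update mi_new, li_new
        // rescale Oi, add Sij * Vj, and write the tile back to O
        // store mi_new, li_new back to m and l for the next tile
    }
}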
Key-Value Loading
for (int j = 0; j < Tc; j++) {
    int loads_per_thread = CEIL_DIV(d, Br);
    for (int e = 0; e < loads_per_thread; e++) {
        int idx = e * (Br * Bc) + tx;
        if (idx < Bc * d) {
            int row = idx / d;
            int col = idx % d;
            if (j * Bc + row < N) {
                Kj[row * d + col] = K[qkv_off + (j * Bc + row) * d + col];
                Vj[row * d + col] = V[qkv_off + (j * Bc + row) * d + col];
            }
        }
    }
    __syncthreads();
This loads the current tile of K and V into shared memory, with bounds checking so that rows past the end of the sequence are never read. The __syncthreads() afterwards guarantees the whole tile is in place before any thread starts using it.
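For concreteness, with hypothetical sizes Br = Bc = 32 and d = 64, the block has 32 * 32 = 1024 threads while the K/V tile holds 32 * 64 = 2048 values, so loads_per_thread = CEIL_DIV(64, 32) = 2 and thread tx copies tile elements tx and 1024 + tx. The fixed stride of Br * Bc keeps consecutive threads on consecutive addresses, so the global loads coalesce.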
Still inside this outer loop, the kernel iterates over the query tiles (the i loop running from 0 to Tr). Let's break down that inner loop, which processes one block of queries at a time:
Loading Query and Output Data
for (int e = 0; e < loads_per_thread; e++) {
    int idx = e * (Br * Bc) + tx;
    if (idx < Br * d) {
        int row = idx / d;
        int col = idx % d;
        if (i * Br + row < N) {
            Qi[row * d + col] = Q[qkv_off + (i * Br + row) * d + col];
            Oi[row * d + col] = O[qkv_off + (i * Br + row) * d + col];
        }
    }
}
This loads the current tile of the Query matrix and of the Output accumulator into shared memory. A thread may load more than one element when d is larger than Bc, since the Qi tile then holds more values (Br * d) than the block has threads (Br * Bc); the fixed stride of Br * Bc again keeps the accesses coalesced. The bounds check prevents out-of-bounds reads when N isn't a multiple of the block size.
Thread Position Calculation and Max/Sum Loading
int s_row = tx / Bc;
int s_col = tx % Bc;
if (s_col == 0) {
    mi[s_row] = m[lm_off + (i * Br) + s_row];
    li[s_row] = l[lm_off + (i * Br) + s_row];
}
__syncthreads();
Each thread computes its position within the Br x Bc score tile: s_row is the query row and s_col the key column it owns. The first thread in each row (s_col == 0) loads that row's running maximum and sum, which the numerically stable softmax below needs.
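This mapping also tells us the block must be launched with Br * Bc threads, one per element of the score tile. With the assumed Br = Bc = 32, for example, thread tx = 70 lands at s_row = 2, s_col = 6.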
Computing Attention Scores
float acc = 0.f;
for (int k = 0; k < d; k++)
    acc += Qi[s_row * d + k] * Kj[s_col * d + k];
acc *= scale;
Sij[s_row * Bc + s_col] = acc;
// Make sure the whole Br x Bc tile is written before the row-wise pass below reads it.
__syncthreads();
This computes the scaled dot-product attention scores between the Query and Key tiles; each thread produces one element of the score tile Sij. The __syncthreads() that follows guarantees the full tile is visible before a single thread per row walks across it in the next step.
Numerically Stable Softmax Computation
if (s_col == 0) {
    float row_m = -INFINITY, row_l = 0.f;
    // Find max for numerical stability
    for (int c = 0; c < Bc; c++) {
        float val = Sij[s_row * Bc + c];
        if (val > row_m) {
            row_m = val;
        }
    }
    // Compute exponentials and sum
    for (int c = 0; c < Bc; c++) {
        float exp_val = expf(Sij[s_row * Bc + c] - row_m);
        Sij[s_row * Bc + c] = exp_val;
        row_l += exp_val;
    }
    // Update running max and sum
    mij_dash[s_row] = row_m;
    mi_new[s_row] = max(mi[s_row], row_m);
    li_new[s_row] = expf(mi[s_row] - mi_new[s_row]) * li[s_row] +
                    expf(row_m - mi_new[s_row]) * row_l;
}
// Publish the exponentiated tile and the new per-row statistics to the other threads.
__syncthreads();
This implements the numerically stable softmax algorithm. For each row:
- Find the maximum value for numerical stability
- Compute exponentials shifted by that maximum (exp(value - max)), overwriting Sij in place
- Sum those exponentials to get the tile's contribution to the row sum
- Update the running maximum and sum for the online softmax computation
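As a scalar reference for what the last three assignments do per row, here is a sketch in plain C (the function and parameter names are illustrative; m_prev/l_prev correspond to mi/li, row_m/row_l to mij_dash and the local row sum):

#include <math.h>

// Online-softmax running statistics update for one row and one new tile.
static void online_softmax_update(float m_prev, float l_prev,
                                  float row_m, float row_l,
                                  float* m_new, float* l_new) {
    *m_new = fmaxf(m_prev, row_m);               // new running maximum
    *l_new = expf(m_prev - *m_new) * l_prev      // rescale the old sum
           + expf(row_m - *m_new) * row_l;       // fold in the new tile's sum
}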
Output Computation and Update
for (int col = s_col; col < d; col += Bc) {
    float acc = 0.f;
    for (int c = 0; c < Bc; c++)
        acc += Sij[s_row * Bc + c] * Vj[c * d + col];
    int global_row = (i * Br) + s_row;
    if (global_row < N) {
        Oi[s_row * d + col] = (1 / li_new[s_row]) *
            ((li[s_row] * expf(mi[s_row] - mi_new[s_row]) * Oi[s_row * d + col]) +
             (expf(mij_dash[s_row] - mi_new[s_row]) * acc));
        O[qkv_off + global_row * d + col] = Oi[s_row * d + col];
    }
}
This final section:
- Computes the matrix multiplication between the softmax probabilities and the Value matrix
- Updates the output using the online softmax algorithm
- Each thread may compute multiple output elements (when d > Bc)
- The update formula rescales the previously accumulated output to the new running maximum before adding the current tile's contribution, which keeps the accumulation numerically stable while results are still partial (spelled out right after this list)
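Written out with the kernel's names, each output element is updated per key/value tile as:

Oi_new = ( li * exp(mi - mi_new) * Oi_old + exp(mij_dash - mi_new) * (row of Sij · column of Vj) ) / li_new

The old accumulator is rescaled to the new running maximum, the current tile's contribution gets its own correction factor, and the whole thing is renormalized by the new running sum li_new.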
State Update for Next Iteration
m[lm_off + (i * Br) + s_row] = mi_new[s_row];
l[lm_off + (i * Br) + s_row] = li_new[s_row];
Finally, we write the updated running maximum and sum back to global memory so that the next key/value tile (the next iteration of the outer j loop) picks up where this one left off. This hand-off is what makes the online softmax correct across tiles.
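To sanity-check the kernel, it helps to compare its output against a naive, memory-hungry CPU implementation of the same math. The sketch below is an assumption for testing purposes only, not part of the original code; it handles one (batch, head) slice with the same row-major [N, d] layout and the same 1/sqrt(d) scaling:

#include <math.h>
#include <stdlib.h>

// Naive O(N^2) reference: O = softmax(Q K^T / sqrt(d)) V for one slice.
// Q, K, V, O are row-major [N, d]. For illustration and testing only.
static void attention_reference(const float* Q, const float* K, const float* V,
                                float* O, int N, int d) {
    float scale = 1.0f / sqrtf((float)d);
    float* scores = (float*)malloc(sizeof(float) * N);
    for (int r = 0; r < N; r++) {
        // Scores for row r against every key, tracking the row maximum.
        float row_max = -INFINITY;
        for (int c = 0; c < N; c++) {
            float s = 0.f;
            for (int k = 0; k < d; k++) s += Q[r * d + k] * K[c * d + k];
            scores[c] = s * scale;
            if (scores[c] > row_max) row_max = scores[c];
        }
        // Stable softmax over the row.
        float row_sum = 0.f;
        for (int c = 0; c < N; c++) {
            scores[c] = expf(scores[c] - row_max);
            row_sum += scores[c];
        }
        // Weighted sum of the values.
        for (int k = 0; k < d; k++) {
            float acc = 0.f;
            for (int c = 0; c < N; c++) acc += scores[c] * V[c * d + k];
            O[r * d + k] = acc / row_sum;
        }
    }
    free(scores);
}

Running both versions on the same random Q, K, V and comparing element-wise (a tolerance around 1e-4 for float32 is typical) is usually enough to catch indexing or synchronization mistakes.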