flashattention学习笔记

之前面ailab是把我送走的最后一个问题就是flashattention的原理,现学习一下

资料

讲解视频:虽然说播放量只有6000多,但我感觉算是讲的很好了,就是在softmax部分,稍微有一点点没看懂。

参考的代码:一开场一个pybind把我愉悦送走了

原论文:一眼没看

公式推导

写在前面

  • 这份代码的实现,整体上没有什么大问题,但是在block_dim的设计上,即一个线程块中应该有多少个线程部分,设计的有些不合理。按照代码的设计方案,每个线程块中的线程数等于Bc即32,这会导致每个线程负责的计算量太大。

    以(2048,16,32)的输入为例子(即一个句子有2048个单词,每个单词有512维数据,共分为了16个头,每个头处理32维数据),比较合理的grid_dim和block_dim设置如下:

    1. 首先确定好Bc的值,这部分可以通过共享内存大小计算,这里假定Bc为64,即子矩阵的行数为64,一次将64*512维矩阵拿进共享内存做计算。
    2. grid_dim设置为(2048/Bc,16),这样,每个grid处理的是一个头对应的一小部份矩阵(大小为Bc*32)
    3. block_dim设置为Bc*n,n表示将32维的数据分为了几份,即一个线程处理(1行,32/n)列的Q*K的转置个元素的计算

    这份代码实际上处理的是多个batch,因此,grid设置应为(batch_size*2048/BC,16),block设置不变。

  • 由于原代码一个线程处理一整行,所以m和l的计算显得很简单,实际上需要通过thread_idx去考虑处理同一行的多个线程中最大的m和l。万幸的是,公式没有变化。
    具体流程为:

    1. 局部统计量计算:每个线程计算其负责的16个元素的 局部最大值 和 局部指数和
    2. 全局最大值 ( m ) 归约:通过共享内存和Warp级归约,跨线程协作获取整行的全局最大值
    3. 全局指数和 ( l ) 归约:基于已知的全局最大值 ( m ),重新计算指数并归约求和
    4. Softmax归一化与输出:使用归约后的 ( m ) 和 ( l ) 计算标准化注意力权重

代码解析

main.cpp

首先看一下main.cpp,这里主要是用了pybind把cpp代码打包成python可以直接用的module,第一个参数是预定义宏,第二个参数名可以随便起,后面的m.def 中,第一个参数 forward是python中的函数名,PyTorch 的包装函数将 C++ 实现的 forward 函数转换为 Python 可调用的形式,处理了张量类型、梯度等转换,第三个forward是说明。

bench.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import math

import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load

# Load the CUDA kernel as a python module
minimal_attn = load(name='minimal_attn', sources=['main.cpp', 'flash.cu'], extra_cuda_cflags=['-O2'])

# Use small model params, otherwise slower than manual attention. See caveats in README.
batch_size = 16
n_head = 12
seq_len = 64
head_embd = 64

q = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
k = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
v = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()

print('=== profiling manual attention ===')

# Our minimal flash attention aims to be faster than this by avoiding HBM read/writes of N^2 matrices.
def manual_attn(q, k, v):
att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
att = F.softmax(att, dim=-1)
y = att @ v
return y

with torch.autograd.profiler.profile(use_cuda=True) as prof:
manual_result = manual_attn(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('=== profiling minimal flash attention === ')

with torch.autograd.profiler.profile(use_cuda=True) as prof:
minimal_result = minimal_attn.forward(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('attn values sanity check:', torch.allclose(minimal_result, manual_result, rtol=0, atol=1e-02))

这部分代码一点难度没有,无非就是用手写的方式实现了下attention。

需要学习下:

1
2
3
4
5
6
7
8
# Load the CUDA kernel as a python module
minimal_attn = load(name='minimal_attn', sources=['main.cpp', 'flash.cu'], extra_cuda_cflags=['-O2'])

···

with torch.autograd.profiler.profile(use_cuda=True) as prof:
minimal_result = minimal_attn.forward(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

这样的调用方法

flash.cu

flashattention的核心代码都在这里

定义变量

这一块的代码最好是结合python代码中对变量的定义和对函数的调用来看

1
2
3
4
5
6
7
8
batch_size = 16
n_head = 12
seq_len = 64
head_embd = 64

q = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
k = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
v = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()

根据这部分可以看出,QKV变量是按照(batch_size, n_head, seq_len, head_embd)的格式定义的,具体的值也写在上面了

然后跳过核函数,直接看调用的代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
const int Bc = 32; const int Br = 32;

const int B = Q.size(0); const int nh = Q.size(1);
const int N = Q.size(2); const int d = Q.size(3);

const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);
const float softmax_scale = 1.0 / sqrt(d);

// Initialize O, l, m to HBM
auto O = torch::zeros_like(Q);
auto l = torch::zeros({B, nh, N});
auto m = torch::full({B, nh, N}, -INFINITY);
torch::Device device(torch::kCUDA);
l = l.to(device); m = m.to(device);

Bc和Br应该是每次载入多少行、列的子矩阵进入sram,按照b站视频的说法,具体载入多少行,应该是符合192k/4*什么的一个公式的,但在这里省略了计算,直接设置了一个定值

然后是B、nh、N、d四个变量,这里对应前面python代码,可以看出分别对应的是batch_size, n_head, seq_len, head_embd,对应的具体值为16,12,64,64,这表示一次处理16个句子(batch_size),一个句子有64个单词
(seq_len),每个单词的维数是12*64(n_head*head_embd)

  • Bc:列方向的分块大小,每个块处理K和V的Bc个元素。
  • Br:行方向的分块大小,每个块处理Q的Br个元素。
  • Tc:列方向的总块数,等于ceil(N / Bc)。
  • Tr:行方向的总块数,等于ceil(N / Br)。

接下来使用const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);来计算了行和列方向的总块数。

然后的O、l、m就需要费点心思去理解了,这里的命名严格遵循了原论文(好评!):

  • l:存储每个位置i的累加分母(exp(s_i - m_i)的和)。
  • m:存储每个位置i的当前最大值,用于数值稳定。
  • O:最终的输出,形状和Q相同。

这些主要是softmax部分需要使用的,l就是softmax的分母,m是为了防止浮点数溢出,softmax时分子分母统一除以m

计算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// Calculate SRAM size needed per block
const int sram_size = (3 * Bc * d * sizeof(float)) + (Bc * Br * sizeof(float));
int max_sram_size;
cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
printf("Max shared memory: %d, requested shared memory: %d \\n", max_sram_size, sram_size);

dim3 grid_dim(B, nh); // batch_size x num_heads
dim3 block_dim(Bc); // Bc threads per block

forward_kernel<<<grid_dim, block_dim, sram_size>>>(
Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
N, d, Tc, Tr, Bc, Br, softmax_scale,
l.data_ptr<float>(), m.data_ptr<float>(), O.data_ptr<float>()
);
return O;

对于sram中的(3 * Bc * d * sizeof(float)),表示的是QKV的块,bc是分块行数,d是特征维度。Bc * Br * sizeof(float)表示的是Q*V后的矩阵S

然后就进入核函数,每个线程块处理一个头

核函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
void forward_kernel(const float* Q, const float* K, const float* V, const int N, const int d,
const int Tc, const int Tr, const int Bc, const int Br, const float softmax_scale,
float* l, float *m, float* O) {
int tx = threadIdx.x;
int bx = blockIdx.x; int by = blockIdx.y; // batch and head index

// Offset into Q,K,V,O,l,m - different for each batch and head
int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); // gridDim.y = nh
int lm_offset = (bx * gridDim.y * N) + (by * N); // offset for l and m

// Define SRAM for Q,K,V,S
extern __shared__ float sram[];
int tile_size = Bc * d; // size of Qi, Kj, Vj
float* Qi = sram;
float* Kj = &sram[tile_size];
float* Vj = &sram[tile_size * 2];
float* S = &sram[tile_size * 3];

for (int j = 0; j < Tc; j++) {

// Load Kj, Vj to SRAM
for (int x = 0; x < d; x++) {
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
}
__syncthreads(); // such that the inner loop can use the correct Kj, Vj

for (int i = 0; i < Tr; i++) {

// Load Qi to SRAM, l and m to registers
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];

// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;

if (sum > row_m)
row_m = sum;
}

// P = exp(S - row_m), row_l = rowsum(P)
float row_l = 0;
for (int y = 0; y < Bc; y++) {
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
row_l += S[(Bc * tx) + y];
}

// Compute new m and l
float row_m_new = max(row_m_prev, row_m);
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

// Write O, l, m to HBM
for (int x = 0; x < d; x++) {
float pv = 0; // Pij * Vj
for (int y = 0; y < Bc; y++) {
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
+ (__expf(row_m - row_m_new) * pv));
}
m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;
}
__syncthreads(); // otherwise, thread can use the wrong Kj, Vj in inner loop
}
}


1
2
int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = nh
int lm_offset = (bx * gridDim.y * N) + (by * N); // offset for l and m

核函数的grid_dim是按照batch_size x num_heads排列的,block_dim是Bc个线程。具体到代码中,grid_dim为(16,12),Bc为32,因此grid共有(16,12)个,每个grid中有32个线程。

因此:

  • bx:blockIdx.x,表示当前线程块处理的 批次索引(例如 bx=0 表示第一个批次)。
  • by:blockIdx.y,表示当前线程块处理的 注意力头索引(例如 by=0 表示第一个注意力头)。
  • gridDim.y:CUDA 网格的 y 维度,等于注意力头的数量 nh。
  • N 是序列长度,
  • d 是每个头的特征维度

每个批次包含 nh 个头,每个头包含 N × d 个元素,前 bx 个批次占用的总元素数为:
bx × nh × N × d
(即 bx * gridDim.y * N * d)。当前头在批次内的起始位置为:
by × N × d
(即 by * N * d)。所以该头的数据在全局内存中对应的位置为:qkv_offset。

lm_offset与之类似,每个批次包含 nh 个头,每个头包含 N 个元素。
前 bx 个批次占用的总元素数为:
bx × nh × N
(即 bx * gridDim.y * N)。当前头在批次内的起始位置为:
by × N
(即 by * N)。


1
2
3
4
5
6
7
// Define SRAM for Q,K,V,S
extern __shared__ float sram[];
int tile_size = Bc * d; // size of Qi, Kj, Vj
float* Qi = sram;
float* Kj = &sram[tile_size];
float* Vj = &sram[tile_size * 2];
float* S = &sram[tile_size * 3];

这一部分开内存,没什么好说的


1
2
3
4
for (int x = 0; x < d; x++) {
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
}

这一部分是把K和V加载到共享内存Kj和Vj中,加载完后进行后续计算

接下来涉及到的是lm的处理:

在注意力机制的 在线 Softmax 计算中,l 和 m 是用于数值稳定的关键变量:

  • m:当前行(或块)的最大值(max),用于防止 exp 函数溢出。

  • l:当前行(或块)的指数和(sum(exp)),用于归一化(Softmax 的分母)。

它们的 迭代更新逻辑 如下(假设分块计算):

初始值:m_prev = -inf, l_prev = 0(初始状态)。

每次处理一个块:
用当前块的局部最大值 m_new 和局部指数和 l_new,结合之前的 m_prev 和 l_prev,更新全局的 m 和 l。

1
2
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];

从全局内存中读取当前行(对应线程 tx)的上一轮最大值 row_m_prev 和指数和 row_l_prev。

1
2
3
4
5
6
7
8
9
10
11
12
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;

if (sum > row_m)
row_m = sum;
}

目的:计算当前 Q 块(Qi)和 K 块(Kj)的乘积矩阵 S = QK^T 的局部块,并求当前块的行最大值 row_m。

逻辑:
Qi 是 Q 的当前块(行数为 Br),Kj 是 K 的当前块(行数为 Bc)。
每个线程计算 Q 的一行(tx 对应行)与 K 的所有列(y 遍历)的点积。
row_m 记录当前行在块内的最大值。

然后计算新的row_l(Softmax 的分母),更新全局m和l
float row_m_new = max(row_m_prev, row_m); float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);这部分是更新的关键,通过将当前块的局部最大值 row_m 和指数和 row_l 与之前的全局值 row_m_prev 和 row_l_prev 合并,得到新的全局值 row_m_new 和 row_l_new。

  • 新最大值:row_m_new = max(old_max, current_max)
  • 新指数和:
    row_l_new = exp(old_max - new_max) * old_sum + exp(current_max - new_max) * current_sum(通过数值稳定公式合并新旧结果)。

最后将各个部分写会hbm 完成计算。


flashattention学习笔记
http://zzsy.me/2025/02/26/flashattention学习笔记/
作者
yuanyuan
发布于
2025年2月26日
更新于
2025年2月28日
许可协议