import math

import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load
# Load the CUDA kernel as a python module
minimal_attn = load(name='minimal_attn', sources=['main.cpp', 'flash.cu'], extra_cuda_cflags=['-O2'])
# Use small model params, otherwise slower than manual attention. See caveats in README.
batch_size = 16
n_head = 12
seq_len = 64
head_embd = 64
q = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
k = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
v = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
print('=== profiling manual attention ===')
# Our minimal flash attention aims to be faster than this by avoiding HBM read/writes of N^2 matrices.
def manual_attn(q, k, v):
    att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
    att = F.softmax(att, dim=-1)
    y = att @ v
    return y
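To make the N^2 point concrete with the shapes above: att is a (16, 12, 64, 64) tensor, i.e. 16 × 12 × 64 × 64 ≈ 786K floats (about 3 MB in fp32) that standard attention materializes in HBM and then reads back for the softmax and the att @ v product. The N × N factor is what grows quadratically with sequence length, and that round trip is exactly what the fused kernel below avoids.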
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    manual_result = manual_attn(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
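The natural counterpart (not shown in the excerpt) is to profile the extension the same way and sanity-check its output against manual_attn. A minimal sketch, assuming main.cpp binds the CUDA implementation as minimal_attn.forward(q, k, v):

print('=== profiling minimal flash attention ===')
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    minimal_result = minimal_attn.forward(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

# Loose tolerance: the two implementations accumulate in different orders.
print('attn values sanity check:', torch.allclose(minimal_result, manual_result, rtol=0, atol=1e-02))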
const int B = Q.size(0); const int nh = Q.size(1);
const int N = Q.size(2); const int d = Q.size(3);
const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);
const float softmax_scale = 1.0 / sqrt(d);
// Initialize O, l, m to HBM
auto O = torch::zeros_like(Q);
auto l = torch::zeros({B, nh, N});
auto m = torch::full({B, nh, N}, -INFINITY);
torch::Device device(torch::kCUDA);
l = l.to(device); m = m.to(device);
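The host function then has to size the shared-memory tile and launch the kernel. A minimal sketch of that step, assuming Bc = Br = 32 are set earlier in forward() and that one thread block is launched per (batch, head) pair with Bc threads; the exact lines are an assumption, not code quoted above:

// Shared memory per block: tiles Qi, Kj, Vj (Bc * d floats each) plus the score tile S (Bc * Br floats)
const int sram_size = (3 * Bc * d * sizeof(float)) + (Bc * Br * sizeof(float));
int max_sram_size;
cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
printf("Max shared memory: %d, requested shared memory: %d\n", max_sram_size, sram_size);

dim3 grid_dim(B, nh);   // so gridDim.y == nh inside the kernel
dim3 block_dim(Bc);     // one thread per row of the tile

forward_kernel<<<grid_dim, block_dim, sram_size>>>(
    Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
    N, d, Tc, Tr, Bc, Br, softmax_scale,
    l.data_ptr<float>(), m.data_ptr<float>(), O.data_ptr<float>());
return O;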
__global__ void forward_kernel(const float* Q, const float* K, const float* V,
                               const int N, const int d, const int Tc, const int Tr,
                               const int Bc, const int Br, const float softmax_scale,
                               float* l, float* m, float* O) {
    int tx = threadIdx.x;
    int bx = blockIdx.x; int by = blockIdx.y;  // batch and head index
    // Offset into Q,K,V,O,l,m - different for each batch and head
    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = nh
    int lm_offset = (bx * gridDim.y * N) + (by * N);  // offset for l and m
    // Define SRAM for Q,K,V,S
    extern __shared__ float sram[];
    int tile_size = Bc * d;  // size of Qi, Kj, Vj
    float* Qi = sram;
    float* Kj = &sram[tile_size];
    float* Vj = &sram[tile_size * 2];
    float* S = &sram[tile_size * 3];
    for (int j = 0; j < Tc; j++) {
        // Load Kj, Vj to SRAM
        for (int x = 0; x < d; x++) {
            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
        }
        __syncthreads();  // such that the inner loop can use the correct Kj, Vj
        for (int i = 0; i < Tr; i++) {
            // Load Qi to SRAM, l and m to registers
            for (int x = 0; x < d; x++) {
                Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
            }
            float row_m_prev = m[lm_offset + (Br * i) + tx];
            float row_l_prev = l[lm_offset + (Br * i) + tx];
            // S = QK^T, row_m = rowmax(S)
            float row_m = -INFINITY;
            for (int y = 0; y < Bc; y++) {
                float sum = 0;
                for (int x = 0; x < d; x++) {
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                S[(Bc * tx) + y] = sum;

                if (sum > row_m)
                    row_m = sum;
            }
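The excerpt stops after computing the score tile. For completeness, here is a hedged sketch of how the rest of the inner loop typically finishes the FlashAttention-1 update: exponentiate against the tile's row maximum, merge the running (m, l) statistics with the new ones, and rescale the output accumulator. The names extend the code above, but the exact lines are an assumed continuation, not a quote:

            // Sketch (assumed continuation): P = exp(S - row_m), row_l = rowsum(P)
            float row_l = 0;
            for (int y = 0; y < Bc; y++) {
                S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
                row_l += S[(Bc * tx) + y];
            }

            // Merge the running statistics with this tile's statistics
            float row_m_new = max(row_m_prev, row_m);
            float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev)
                            + (__expf(row_m - row_m_new) * row_l);

            // Rescale the previous partial output and add this tile's P * Vj contribution
            for (int x = 0; x < d; x++) {
                float pv = 0;  // dot product of row tx of P with column x of Vj
                for (int y = 0; y < Bc; y++) {
                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
                }
                O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new)
                    * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x])
                    +  (__expf(row_m - row_m_new) * pv));
            }

            // Write the updated running statistics back to HBM
            m[lm_offset + (Br * i) + tx] = row_m_new;
            l[lm_offset + (Br * i) + tx] = row_l_new;
        }
        __syncthreads();  // so the next j iteration does not overwrite Kj, Vj too early
    }
}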
Each batch contains nh heads, and each head holds N × d elements, so the first bx batches occupy bx × nh × N × d elements in total (i.e. bx * gridDim.y * N * d). Within its batch, the current head starts at by × N × d (i.e. by * N * d). Adding the two gives qkv_offset, the position in global memory where this (batch, head) slice of Q, K, V, and O begins.
lm_offset is analogous: in l and m each batch contains nh heads and each head holds N elements, so the first bx batches occupy bx × nh × N elements (i.e. bx * gridDim.y * N), and the current head starts at by × N within its batch (i.e. by * N).
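As a concrete check with the benchmark shapes above (nh = 12, N = 64, d = 64): the block handling bx = 1, by = 2 gets qkv_offset = 1 × 12 × 64 × 64 + 2 × 64 × 64 = 49152 + 8192 = 57344, and lm_offset = 1 × 12 × 64 + 2 × 64 = 768 + 128 = 896.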