Releases: DefTruth/CUDA-Learn-Notes
🎉FA2/HGEMM SMEM Swizzle
What's Changed
- [FA2] split-q + tiling-qk D=512 performance🎉 by @DefTruth in #178
- [FA2] split-q + tiling-qk D=512 performance🎉 by @DefTruth in #179
- [FA2] split-q + tiling-qk D=512 performance🎉 by @DefTruth in #180
- [Doc] Refactor README.md to improve readability✔️ by @DefTruth in #181
- [Doc] Refactor README.md for better readability✔️ by @DefTruth in #182
- [FA2] flash-attn-mma 3080/L20/4090 bench✔️ by @DefTruth in #183
- [FA2] flash-attn-mma 3080/L20/4090 bench✔️ by @DefTruth in #184
- [FA2] fa2/hgemm manually smem swizzle🎉 by @DefTruth in #185
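Nsight Compute (ncu) profile of flash_attn_mma_stages_split_q_tiling_qk_swizzle_kernel: with the manual SMEM swizzle, shared-memory bank conflicts for LDSM (ldmatrix) are 0.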
void flash_attn_mma_stages_split_q_tiling_qk_swizzle_kernel<512, 16, 8, 16, 8, 1, 8, 1, 1, 16, 1, 64, 2, 0, 0, 8>(__half *, __half *, __half *, __half *, int, int) (8, 48, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
Section: Command line profiler metrics
------------------------------------------------------------------ ----------- ------------
Metric Name Metric Unit Metric Value
------------------------------------------------------------------ ----------- ------------
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 0
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 0
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 0
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 0
------------------------------------------------------------------ ----------- ------------
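To see why a swizzle can bring LDSM bank conflicts to zero, the sketch below models a generic XOR swizzle on 16-byte chunks. It is an illustration only, not necessarily the exact mapping used in #185. Assumptions: a half-precision SMEM tile with 64 halves (128 bytes) per row, 32 banks of 4 bytes (so one 16-byte chunk covers 4 banks), and one ldmatrix phase in which 8 threads each read one 16-byte row segment from rows 0..7 at the same logical column.

```python
from collections import Counter

# Assumption: 64 fp16 values (128 bytes) per SMEM row.
ROW_BYTES = 128
# ldmatrix: each thread address points at one 16-byte row segment.
CHUNK_BYTES = 16
# 8 chunks of 16 bytes = 32 banks x 4 bytes.
CHUNKS_PER_ROW = ROW_BYTES // CHUNK_BYTES


def chunk_index(row, chunk, swizzle):
    """Physical 16-byte chunk inside a row; the XOR swizzle permutes chunks per row."""
    return chunk ^ (row % CHUNKS_PER_ROW) if swizzle else chunk


def smem_offset(row, chunk, swizzle):
    """Byte offset of logical (row, chunk) in the tile."""
    return row * ROW_BYTES + chunk_index(row, chunk, swizzle) * CHUNK_BYTES


def ldsm_wavefronts(chunk, swizzle):
    """Wavefronts for 8 threads reading logical `chunk` of rows 0..7.

    Each 16-byte chunk falls into one group of 4 banks. Distinct addresses in
    the same group serialize, so the largest group count is the number of
    wavefronts (1 means conflict-free).
    """
    groups = Counter(
        (smem_offset(row, chunk, swizzle) // CHUNK_BYTES) % CHUNKS_PER_ROW
        for row in range(8)
    )
    return max(groups.values())


if __name__ == "__main__":
    for swz in (False, True):
        label = "swizzle" if swz else "naive  "
        print(label, [ldsm_wavefronts(c, swz) for c in range(CHUNKS_PER_ROW)])
```

Without the swizzle, every row starts at a multiple of 128 bytes, so all 8 addresses of a phase hit the same 4 banks (8 wavefronts). XOR-ing the chunk index with `row % 8` spreads them over all 32 banks (1 wavefront), and because XOR with a fixed value is a permutation, each row still holds all 8 of its chunks.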
Full Changelog: v2.6.11...v2.6.12
📚 Split Q + QK Fine-grained Tiling
What's Changed
📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2 O(4xBrxd) SRAM, Headdim -> 1024)
Currently, for small-scale attention (B<=4, H<=48, SeqLen<=8192), these kernels can run faster than the official FA2/SDPA on some devices. For example, on an NVIDIA RTX 3080 Laptop, 📚 Split Q + Fully Shared QKV SMEM reaches 55 TFLOPS (D=64), almost ~1.5x 🎉 faster than FA2. On an NVIDIA L20, 📚 Split Q + QK Fine-grained Tiling reaches 81 TFLOPS (D=512), almost ~1.4x 🎉 faster than SDPA (EFFICIENT_ATTENTION). However, a performance gap remains for large-scale attention. Performance is continuously being optimized; stay tuned for updates ~
- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop), faster than FA2~🎉🎉
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
-------------------------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
torch(unfused): ['-0.00514603 ', '0.05783081 ', '-0.00026727 '], time:20.999861ms, TFLOPS:6.67 (+0.00%)
mma(split-kv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:5.120730ms, TFLOPS:27.36 (+310.10%)
mma(split-kv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:5.004287ms, TFLOPS:28.00 (+2.33%)
mma(split-q+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:3.462291ms, TFLOPS:40.47 (+44.54%)
mma(split-q+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:3.658915ms, TFLOPS:38.30
mma(split-q+share-qkv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.551699ms, TFLOPS:54.91 (+35.69%)
mma(split-q+share-qkv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.532172ms, TFLOPS:55.34 (+0.77%)
mma(split-q+share-kv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.776575ms, TFLOPS:50.46
mma(split-q+share-kv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.596927ms, TFLOPS:53.96
(flash): ['-0.00516129 ', '0.05783081 ', '-0.00027728 '], time:3.776550ms, TFLOPS:37.10
----------------------------------------------------------------------------------------------------------------------------------
- Example: B=1, H=8, N=8192, D=512 (NVIDIA RTX 3080 Laptop). FA2 does not support D=512; QK Tiling is faster than SDPA~🎉🎉
python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D 512 # NVIDIA RTX 3080 Laptop, Faster than SDPA
------------------------------------------B=1, H=8, N=8192, D=512, Warmup: 1, Iters: 10-------------------------------------------
mma(split-q+tiling-qk+stage1): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:48.775554ms, TFLOPS:22.60 (+0.00%)
mma(split-q+tiling-qk+stage2): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:47.503424ms, TFLOPS:23.20 (+2.68%)
(sdpa): ['-0.00438309 ', '0.02174377 ', '-0.01551056 '], time:66.486573ms, TFLOPS:16.58
----------------------------------------------------------------------------------------------------------------------------------
- Example: B=1, H=48, N=16384, D=512 (NVIDIA L20). FA2 does not support D=512; QK Tiling is faster than SDPA~🎉🎉
python3 flash_attn_mma.py --B 1 --H 48 --D 512 --N 16384 --show-all --check --iters 10 --sdpa
-----------------------------------------B=1, H=48, N=16384, D=512, Warmup: 1, Iters: 10------------------------------------------
mma(split-q+tiling-qk+stage1): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:387.384224ms, TFLOPS:68.28 (+0.00%)
mma(split-q+tiling-qk+stage2): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:325.593209ms, TFLOPS:81.24 (+18.98%)
(sdpa): ['0.00790405 ', '-0.02330017 ', '0.00875854 '], time:452.067018ms, TFLOPS:58.51
----------------------------------------------------------------------------------------------------------------------------------
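The TFLOPS column can be cross-checked from the timings with a simple FLOP count. The sketch below counts 4·B·H·N²·D FLOPs for the two matmuls (QKᵀ and PV, 2 FLOPs per multiply-add) plus a small per-score term for the softmax. The value of 5 FLOPs per score is an assumption chosen because it reproduces the numbers in the logs above; it is not taken from flash_attn_mma.py.

```python
def attention_flops(B, H, N, D, softmax_flops_per_score=5):
    """Approximate FLOPs of one non-causal attention forward pass.

    QK^T and PV are each N*N*D multiply-adds per head (2 FLOPs each), i.e.
    4*N*N*D in total; softmax_flops_per_score (assumed 5) covers the
    max/exp/sum work on each of the N*N scores.
    """
    return B * H * N * N * (4 * D + softmax_flops_per_score)


def tflops(B, H, N, D, time_ms):
    """Throughput in TFLOPS for a measured kernel time in milliseconds."""
    return attention_flops(B, H, N, D) / (time_ms * 1e-3) / 1e12


if __name__ == "__main__":
    # Timings copied from the benchmark logs above.
    print(f"{tflops(1, 8, 8192, 64, 2.532172):.2f}")       # split-q+share-qkv+stage2, RTX 3080 Laptop
    print(f"{tflops(1, 8, 8192, 64, 3.776550):.2f}")       # (flash), RTX 3080 Laptop
    print(f"{tflops(1, 48, 16384, 512, 325.593209):.2f}")  # split-q+tiling-qk+stage2, L20
```

This prints 55.34, 37.10 and 81.24, matching the logged values; the ~1.5x speedup over FA2 quoted above is simply 55.34 / 37.10.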
- 📚 Split Q + Fully Shared QKV SMEM (1/4 SRAM vs FA2)
// Q, K and V fully share the same shared memory, and Q is prefetched s2r (SMEM to registers),
// which improves block occupancy and reduces Q SMEM IO access.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
- 📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2 O(4xBrxd) SRAM, Headdim -> 1024)
// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (head dimension) up to 1024. Performance optimization is ongoing, stay tuned for updates ~
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
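A rough SMEM budget shows why MMA-level tiling of Q and K lets D grow to 1024. The sketch follows the comment above (Q and K sub-tiles of 64 × kMmaAtomK each, plus a V chunk of kMmaAtomK × d) and, for contrast, counts full-width tiles of Q, K and V. Assumptions: Br = Bc = 64, kMmaAtomK = 16, 2-byte fp16 elements, a single pipeline stage, and no padding; these shapes are illustrative, not the kernel's exact layout.

```python
HALF_BYTES = 2  # fp16 element size


def tiling_qk_smem_bytes(d, Br=64, Bc=64, k_atom=16):
    """SMEM with Q/K tiled along d at MMA granularity plus a k_atom x d V chunk.

    Assumption: one stage, no padding.
    """
    q_tile = Br * k_atom   # [Br, kMmaAtomK], independent of d
    k_tile = Bc * k_atom   # [Bc, kMmaAtomK], independent of d
    v_tile = k_atom * d    # [kMmaAtomK, d], grows linearly with d
    return (q_tile + k_tile + v_tile) * HALF_BYTES


def full_tile_smem_bytes(d, Br=64, Bc=64):
    """SMEM when Q [Br, d], K [Bc, d] and V [Bc, d] are staged as full-width tiles."""
    return (Br * d + 2 * Bc * d) * HALF_BYTES


if __name__ == "__main__":
    for d in (64, 256, 512, 1024):
        print(f"d={d:5d}  tiling-qk: {tiling_qk_smem_bytes(d) / 1024:6.1f} KiB"
              f"  full tiles: {full_tile_smem_bytes(d) / 1024:6.1f} KiB")
```

At d = 1024 this gives 36 KiB for the tiled layout versus 384 KiB for full-width tiles. The latter is far beyond the 99 KB per-block shared-memory limit of the compute capability 8.6/8.9 GPUs benchmarked here (RTX 3080, L20, 4090), while the former fits comfortably.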
Full Changelog: v2.6.10...v2.6.11
📚FA2: QK Fine-grained Tiling
What's Changed
- [FA2] hotfix flash-attn-mma smem size setting✔️ by @DefTruth in #170
- [FA2] reorder grid layout, boost 5~10% TFLOPS✔️ by @DefTruth in #171
- [FA2] optimize block tiling for headdim >= 128✔️ by @DefTruth in #172
- [FA2] flash-attn-mma tiling-qk for large d⚡️ by @DefTruth in #173
- [FA2] fix tiling-qk misaligned address✔️ by @DefTruth in #174
- [README] Refactor README.md✔️ by @DefTruth in #175
- [README] Refactor README✔️ by @DefTruth in #176
📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2 O(4xBrxd) SRAM, Headdim -> 1024)
// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (head dimension) up to 1024. Performance optimization is ongoing, stay tuned for updates ~
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
Full Changelog: v2.6.9...v2.6.10
FA2 Fully Shared QKV SMEM🎉
What's Changed
- [FA2] Update flash-attn-mma shared-kv/qkv🎉 by @DefTruth in #163
- [FA2] Update flash-attn-mma shared-kv/qkv🎉 by @DefTruth in #164
- [FA2] Update flash-attn-mma shared-qkv🎉 by @DefTruth in #165
- [FA2] Update flash-attn-mma shared-kv🎉 by @DefTruth in #166
- [FA2] Update flash-attn-mma split-kv/q🎉 by @DefTruth in #167
- [FA2] Update flash-attn-mma shared-qkv🎉 by @DefTruth in #168
- [FA2] flash-attn-mma get rid of transpose-k✔️ by @DefTruth in #169
Full Changelog: v2.6.8...v2.6.9
FA2 Fully Shared QKV SMEM🎉
What's Changed
- [FA2] Release flash-attn-mma split-kv/q🎉 by @DefTruth in #161
- [FA2] Release flash-attn-mma shared-kv/qkv🎉 by @DefTruth in #162
I have also implemented FlashAttention-2 using pure MMA PTX instructions, supporting features such as Multi-Stages, Tile MMA, Tile Warp, Fully Shared QKV SMEM, Prefetch Q s2r, Collective Store, etc. Currently, for small-scale attention (B<=4, H<=48, SeqLen<=8192), it can run faster than the official FA2 on some devices, for example, an NVIDIA RTX 3080 Laptop.
- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1617, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
mma(split-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.586338ms, TFLOPS:25.08
mma(split-kv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.326223ms, TFLOPS:26.31
mma(split-q+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:3.834152ms, TFLOPS:36.54
mma(split-q+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:4.328346ms, TFLOPS:32.37
mma(split-q+share-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.636528ms, TFLOPS:53.15
mma(split-q+share-qkv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.594471ms, TFLOPS:54.01
mma(split-q+share-qkv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.574611ms, TFLOPS:54.42
(flash): ['0.01963806 ', '0.0145874 ', '-0.02593994 '], time:3.764462ms, TFLOPS:37.22
-----------------------------------------------------------------------------------------------------------------------
However, for large-scale attention computations, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to flash-attention-mma⚡️⚡️ for more details.
Tensor Cores | Loop over Seqlen/Headdim | Tile Block (Br, Bc) | MMA (m16n8k16) |
---|---|---|---|
✔️ | ✔️ | ✔️ | ✔️ |
Pack LDST (128 bits) | SMEM Padding | Copy Async | Tile MMA (More Threads) |
✔️ | ✔️ | ✔️ | ✔️ |
Tile Warp (More Values) | Multi Stages (1/2) | Collective Store (Shfl) | Split KV/Q |
✔️ | ✔️ | ✔️ | ✔️ |
Shared KV SMEM | Fully Shared QKV SMEM | Prefetch Q s2r | SMEM/Block Swizzle |
✔️ | ✔️ | ✔️ | ? |
Both the Split KV and Split Q implementations are provided in flash-attention-mma⚡️⚡️ for performance comparison. The Split KV method, which splits all of Q, K and V across MMA (warps), is slower than the Split Q policy, which splits only Q across MMA (warps) while every MMA (warp) keeps access to all of K and V.
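One way to see the difference is to count how many warps hold pieces of a single row of S = QKᵀ, since softmax needs a max and a sum over the whole row. The sketch below models only tile ownership for the two layouts drawn below: Split KV with a 2 (QP) × 4 (KV) warp grid in which each warp computes 2 × 2 m16n8k16 tiles, and Split Q with 4 warps that each own 16 query rows across all 64 key columns. The helper names are hypothetical and do not correspond to functions in the repo.

```python
M, N_TILE = 16, 8   # m16n8k16: each MMA produces a 16 x 8 tile of S
BR, BC = 64, 64     # block tile of S = Q @ K^T


def split_kv_owner(row, col):
    """Warp id owning S[row, col] in the Split KV layout (2 QP x 4 KV warps)."""
    warp_qp = row // (2 * M)        # each QP warp covers 2 MMA tiles = 32 rows
    warp_kv = col // (2 * N_TILE)   # each KV warp covers 2 MMA tiles = 16 cols
    return warp_kv * 2 + warp_qp    # same MMA 0..7 numbering as the diagram


def split_q_owner(row, col):
    """Warp id owning S[row, col] in the Split Q layout (4 warps, 16 rows each)."""
    return row // M


def warps_per_row(owner):
    """Number of distinct warps that hold part of each row of S."""
    return [len({owner(r, c) for c in range(BC)}) for r in range(BR)]


if __name__ == "__main__":
    print("split-kv:", set(warps_per_row(split_kv_owner)))
    print("split-q :", set(warps_per_row(split_q_owner)))
```

Under Split KV every row is spread over 4 warps, so the row max/sum for softmax has to be combined through SMEM; under Split Q each row lives in exactly one warp, so the reduction can stay inside that warp with warp shuffles.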
- 📚 Split KV (Basic, FlashAttention-1)
// Split QKV across MMA(Warps) using naive matmul MMA&Warp tiling policy.
// case: The layout of 8 MMA(2x4) [after] kWarpTileSeqLenQxkWarpTileSeqLenK(2x2) -> 32x2,32x2=64x64:
// | [64,64] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 |
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
__global__ void
flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
half* K, // [B, H, D, N] K^T transposed
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
- 📚 Split Q (Faster, FlashAttention-2)
// Split Q across MMA (warps) while every MMA (warp) keeps access to all of KV,
// in order to reduce communication between warps via SMEM and warp shuffle.
// case: MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
// | 64x64 | warp_KV 0 |
// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
__global__ void
flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
half* K, // [B, H, D, N] K^T transposed
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
- 📚 Split Q + Shared KV SMEM (Faster+)
// K and V share the same shared memory, which improves block occupancy.
__global__ void
flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
half* K,
half* V,
half* O,
int QKV_seqlen);
- 📚 Split Q + Fully Shared QKV SMEM (Faster++)
// Q, K and V fully share the same shared memory, and Q is prefetched s2r (SMEM to registers), which improves block occupancy.
__global__ void
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
half* K,
half* V,
half* O,
int QKV_seqlen);
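Because each warp in the Split Q kernels owns complete rows of Q, it can stream over all KV blocks and keep the softmax statistics for its rows locally, using the FlashAttention-2 online-softmax recurrence (running max m, running sum l, unnormalized accumulator, one normalization at the end). The scalar sketch below applies that recurrence to a single query row in pure Python and checks it against a naive softmax(qKᵀ)V. The KV block size is an arbitrary assumption, the 1/sqrt(d) scale is omitted, and this illustrates the math only, not the MMA kernel.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention_row(q, K, V):
    """Reference: softmax(q K^T) V for one query row (1/sqrt(d) scale omitted)."""
    scores = [dot(q, k) for k in K]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    denom = sum(p)
    d = len(V[0])
    return [sum(p[j] * V[j][c] for j in range(len(V))) / denom for c in range(d)]


def online_attention_row(q, K, V, block=4):
    """Stream K/V in blocks, keeping running max m, running sum l and accumulator acc."""
    d = len(V[0])
    m, l = -math.inf, 0.0
    acc = [0.0] * d
    for start in range(0, len(K), block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = [dot(q, k) for k in Kb]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)  # rescales earlier partial results; exp(-inf) = 0.0 on the first block
        p = [math.exp(x - m_new) for x in s]
        l = l * scale + sum(p)
        acc = [a * scale + sum(pj * v[c] for pj, v in zip(p, Vb)) for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]  # normalize once at the end


if __name__ == "__main__":
    random.seed(0)
    n, d = 10, 8  # 10 keys with block=4 also exercises a partial last block
    q = [random.uniform(-1, 1) for _ in range(d)]
    K = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
    V = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
    ref = naive_attention_row(q, K, V)
    out = online_attention_row(q, K, V, block=4)
    print(max(abs(a - b) for a, b in zip(ref, out)))
```

The printed difference is at floating-point round-off level, which is why no warp needs another warp's partial row statistics under Split Q.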
Full Changelog: v2.6.7...v2.6.8
🎉FA2 MMA Split KV/Q
What's Changed
- [FlashAttention] Update flash-attention-mma 0.0.1 🎉 by @DefTruth in #159
- [FA2] Release flash-attn-mma split-kv/q🎉 by @DefTruth in #160
Full Changelog: v2.6.6...v2.6.7
🎉flash-attention-mma 0.0.1
What's Changed
- [HGEMM] CuTe HGEMM debug Makefile target by @DefTruth in #154
- [Softmax] Update Online Softmax bindings by @DefTruth in #155
- [FlashAttention] Refactor toy-flash-attn codes part-1 by @DefTruth in #156
- [Bug]Fix typo by @wjj19950828 in #157
- [FlashAttention] Release flash-attention-mma 0.0.1 🎉 by @DefTruth in #158
New Contributors
- @wjj19950828 made their first contribution in #157
Full Changelog: v2.6.5...v2.6.6
⚡️⚡️toy-hgemm library
What's Changed
Full Changelog: v2.6.3...v2.6.4
toy-hgemm library
What's Changed
Full Changelog: v2.6.2...v2.6.3