
Releases: DefTruth/CUDA-Learn-Notes

🎉FA2/HGEMM SMEM Swizzle

25 Dec 05:52
bdd361a

What's Changed

  • [FA2] split-q + tiling-qk D=512 performance🎉 by @DefTruth in #178
  • [FA2] split-q + tiling-qk D=512 performance🎉 by @DefTruth in #179
  • [FA2] split-q + tiling-qk D=512 performance🎉 by @DefTruth in #180
  • [Doc] Refactor README.md to improve readability✔️ by @DefTruth in #181
  • [Doc] Refactor README.md for better readability✔️ by @DefTruth in #182
  • [FA2] flash-attn-mma 3080/L20/4090 bench✔️ by @DefTruth in #183
  • [FA2] flash-attn-mma 3080/L20/4090 bench✔️ by @DefTruth in #184
  • [FA2] fa2/hgemm manually smem swizzle🎉 by @DefTruth in #185

flash_attn_mma_stages_split_q_tiling_qk_swizzle_kernel

void flash_attn_mma_stages_split_q_tiling_qk_swizzle_kernel<512, 16, 8, 16, 8, 1, 8, 1, 1, 16, 1, 64, 2, 0, 0, 8>(__half *, __half *, __half *, __half *, int, int) (8, 48, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
    Section: Command line profiler metrics
    ------------------------------------------------------------------ ----------- ------------
    Metric Name                                                        Metric Unit Metric Value
    ------------------------------------------------------------------ ----------- ------------
    sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg                        0
    sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max                        0
    sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min                        0
    sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum                        0
    ------------------------------------------------------------------ ----------- ------------
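
The zero `ldsm` bank-conflict counts above are what an SMEM swizzle is meant to achieve. As a rough illustration only (the exact swizzle used by the kernel is not shown here), the sketch below models one `ldmatrix` phase: 8 threads each read a 16-byte row segment of an FP16 tile whose rows are 64 halves (128 bytes) wide. It assumes 32 banks of 4 bytes and a generic XOR swizzle `chunk ^ (row % 8)` over 16-byte chunks; `swizzle_chunk`, `phase_addrs` and `conflict_degree` are hypothetical helpers.

```python
# Toy model of SMEM bank conflicts for one ldmatrix phase (8 threads x 16 bytes).
# Assumptions: 32 banks x 4 bytes, FP16 tile, 16-byte (8-half) chunk per thread.
BANKS = 32
BANK_BYTES = 4
CHUNK_BYTES = 16  # one ldmatrix row address covers 8 halves = 16 bytes


def swizzle_chunk(row: int, chunk: int, chunks_per_row: int = 8) -> int:
    """Hypothetical XOR swizzle: permute 16-byte chunks inside a row by row index."""
    return chunk ^ (row % chunks_per_row)


def phase_addrs(rows, chunk, row_bytes, swizzle=False):
    """Byte address each thread passes to ldmatrix when reading `chunk` of `rows`."""
    addrs = []
    for r in rows:
        c = swizzle_chunk(r, chunk) if swizzle else chunk
        addrs.append(r * row_bytes + c * CHUNK_BYTES)
    return addrs


def conflict_degree(addrs):
    """Distinct 4-byte words hitting the busiest bank (1 == conflict-free)."""
    words_per_bank = {}
    for a in addrs:
        for word in range(a // BANK_BYTES, (a + CHUNK_BYTES) // BANK_BYTES):
            words_per_bank.setdefault(word % BANKS, set()).add(word)
    return max(len(words) for words in words_per_bank.values())


rows = range(8)  # 8 consecutive tile rows, same logical column chunk
print("plain   :", conflict_degree(phase_addrs(rows, 0, row_bytes=128)))
print("padded  :", conflict_degree(phase_addrs(rows, 0, row_bytes=128 + 16)))
print("swizzled:", conflict_degree(phase_addrs(rows, 0, row_bytes=128, swizzle=True)))
```

In this model the plain layout sends all 8 threads to the same 4 banks, while both the padded and the swizzled layouts spread them over all 32 banks. Padding (the "SMEM Padding" approach) spends 16 extra bytes per row; the swizzle only permutes addresses and costs no extra SMEM.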

Full Changelog: v2.6.11...v2.6.12

📚 Split Q + QK Fine-grained Tiling

23 Dec 02:41
d474791

What's Changed

  • [FA2] split-q + tiling-qk D=512 performance🎉 by @DefTruth in #177

📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2 O(4xBrxd) SRAM, Headdim -> 1024)

Currently, for small-scale attention (B<=4, H<=48, SeqLen<=8192), these kernels can run faster than the official FA2/SDPA on some devices. For example, on an NVIDIA RTX 3080 Laptop, 📚 Split Q + Fully Shared QKV SMEM reaches 55 TFLOPS (D=64), roughly 1.5x 🎉 faster than FA2. Moreover, on an NVIDIA L20, 📚 Split Q + QK Fine-grained Tiling reaches 81 TFLOPS (D=512), roughly 1.4x 🎉 faster than SDPA (EFFICIENT_ATTENTION). However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~

  • Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop), Faster than FA2~🎉🎉
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
-------------------------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
                  torch(unfused): ['-0.00514603 ', '0.05783081  ', '-0.00026727 '], time:20.999861ms, TFLOPS:6.67 (+0.00%)
            mma(split-kv+stage1): ['-0.00511169 ', '0.05795288  ', '-0.00029612 '], time:5.120730ms, TFLOPS:27.36 (+310.10%)
            mma(split-kv+stage2): ['-0.00511169 ', '0.05795288  ', '-0.00029612 '], time:5.004287ms, TFLOPS:28.00 (+2.33%)
             mma(split-q+stage1): ['-0.00511169 ', '0.05795288  ', '-0.00029612 '], time:3.462291ms, TFLOPS:40.47 (+44.54%)
             mma(split-q+stage2): ['-0.00511169 ', '0.05795288  ', '-0.00029612 '], time:3.658915ms, TFLOPS:38.30
   mma(split-q+share-qkv+stage1): ['-0.00511169 ', '0.05795288  ', '-0.00029612 '], time:2.551699ms, TFLOPS:54.91 (+35.69%)
   mma(split-q+share-qkv+stage2): ['-0.00511169 ', '0.05795288  ', '-0.00029612 '], time:2.532172ms, TFLOPS:55.34 (+0.77%)
    mma(split-q+share-kv+stage1): ['-0.00511169 ', '0.05795288  ', '-0.00029612 '], time:2.776575ms, TFLOPS:50.46
    mma(split-q+share-kv+stage2): ['-0.00511169 ', '0.05795288  ', '-0.00029612 '], time:2.596927ms, TFLOPS:53.96
                         (flash): ['-0.00516129 ', '0.05783081  ', '-0.00027728 '], time:3.776550ms, TFLOPS:37.10
----------------------------------------------------------------------------------------------------------------------------------
  • Example: B=1, H=8, N=8192, D=512 (NVIDIA RTX 3080 Laptop), FA2 not supported, QK Tiling Faster than SDPA~🎉🎉
python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D 512 # NVIDIA RTX 3080 Laptop, Faster than SDPA
------------------------------------------B=1, H=8, N=8192, D=512, Warmup: 1, Iters: 10-------------------------------------------
   mma(split-q+tiling-qk+stage1): ['-0.00433731 ', '0.02165222  ', '-0.01544189 '], time:48.775554ms, TFLOPS:22.60 (+0.00%)
   mma(split-q+tiling-qk+stage2): ['-0.00433731 ', '0.02165222  ', '-0.01544189 '], time:47.503424ms, TFLOPS:23.20 (+2.68%)
                          (sdpa): ['-0.00438309 ', '0.02174377  ', '-0.01551056 '], time:66.486573ms, TFLOPS:16.58
----------------------------------------------------------------------------------------------------------------------------------
  • Example: B=1, H=48, N=16384, D=512 (NVIDIA L20), FA2 not supported, QK Tiling Faster than SDPA~🎉🎉
python3 flash_attn_mma.py --B 1 --H 48 --D 512 --N 16384 --show-all --check --iters 10 --sdpa
-----------------------------------------B=1, H=48, N=16384, D=512, Warmup: 1, Iters: 10------------------------------------------
   mma(split-q+tiling-qk+stage1): ['0.0079422   ', '-0.02334595 ', '0.00881958  '], time:387.384224ms, TFLOPS:68.28 (+0.00%)
   mma(split-q+tiling-qk+stage2): ['0.0079422   ', '-0.02334595 ', '0.00881958  '], time:325.593209ms, TFLOPS:81.24 (+18.98%)
                          (sdpa): ['0.00790405  ', '-0.02330017 ', '0.00875854  '], time:452.067018ms, TFLOPS:58.51
----------------------------------------------------------------------------------------------------------------------------------
  • 📚 Split Q + Fully Shared QKV SMEM (1/4 SRAM vs FA2)
// Q, K, and V fully share the same shared memory, and Q is prefetched s2r (SMEM to registers),
// which improves block occupancy and reduces Q SMEM IO access.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
  • 📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2 O(4xBrxd) SRAM, Headdim -> 1024)
// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (head dimension) up to 1024. Performance optimization is ongoing; stay tuned for updates ~
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
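
The two bullets above make SRAM claims: fully shared QKV uses 1/4 of FA2's SMEM, and QK fine-grained tiling drops to O(16xd). The back-of-the-envelope sketch below assumes FP16, Br = Bc = 64, kMmaAtomK = 16, a single pipeline stage, no padding, an FA2-style baseline that keeps four Br x d tiles (the "4xBrxd" in the heading), and a tiling-qk V tile of kMmaAtomK x d. The helper names are hypothetical and the real kernels may allocate differently.

```python
# Rough SMEM footprint (bytes) under the stated assumptions; not the kernels' exact sizing.
HALF = 2  # bytes per FP16 element


def fa2_style_smem(d, br=64):
    # Four Br x d tiles (the "O(4 x Br x d)" baseline in the heading).
    return 4 * br * d * HALF


def shared_qkv_smem(d, br=64, bc=64):
    # One buffer reused by Q (before it is prefetched to registers), then K, then V.
    return max(br, bc) * d * HALF


def tiling_qk_smem(d, br=64, bc=64, k_atom=16):
    q = br * k_atom * HALF   # Q slice: Br x kMmaAtomK, independent of d
    k = bc * k_atom * HALF   # K slice: Bc x kMmaAtomK, independent of d
    v = k_atom * d * HALF    # V tile: assumed kMmaAtomK x d, grows with d
    return q + k + v


for d in (64, 128, 256, 512, 1024):
    print(f"d={d:5d}  fa2-style={fa2_style_smem(d) // 1024:4d} KiB  "
          f"shared-qkv={shared_qkv_smem(d) // 1024:4d} KiB  "
          f"tiling-qk={tiling_qk_smem(d) // 1024:3d} KiB")
```

Under these assumptions, at d = 1024 only the tiling-qk estimate (36 KiB) stays under the roughly 100 KB of SMEM per SM on GPUs such as the RTX 3080 and L20, which is consistent with the claim that this approach extends D up to 1024.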

Full Changelog: v2.6.10...v2.6.11

📚FA2: QK Fine-grained Tiling

22 Dec 07:58
697e06f

What's Changed

  • [FA2] hotfix flash-attn-mma smem size setting✔️ by @DefTruth in #170
  • [FA2] reorder grid layout, boost 5~10% TFLOPS✔️ by @DefTruth in #171
  • [FA2] optimize block tiling for headdim >= 128✔️ by @DefTruth in #172
  • [FA2] flash-attn-mma tiling-qk for large d⚡️ by @DefTruth in #173
  • [FA2] fix tiling-qk misaligned address✔️ by @DefTruth in #174
  • [README] Refactor README.md✔️ by @DefTruth in #175
  • [README] Refactor README✔️ by @DefTruth in #176

📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2 O(4xBrxd) SRAM, Headdim -> 1024)

// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (head dimension) up to 1024. Performance optimization is ongoing; stay tuned for updates ~
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
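
The comment above relies on S = QK^T being computable one kMmaAtomK-wide slice of the head dimension at a time, so only Br x kMmaAtomK of Q and Bc x kMmaAtomK of K need to be resident per step. A pure-Python sketch of that accumulation (plain lists instead of MMA fragments; `qk_full` and `qk_tiled` are hypothetical names) checks that it matches the untiled product:

```python
import random


def qk_full(q, k):
    """Reference S = Q @ K^T for row-major lists q[Br][d], k[Bc][d]."""
    return [[sum(qi[x] * kj[x] for x in range(len(qi))) for kj in k] for qi in q]


def qk_tiled(q, k, k_atom=16):
    """S = Q @ K^T accumulated over head-dim slices of width k_atom.

    Each step only touches q[:, x0:x0+k_atom] and k[:, x0:x0+k_atom],
    mirroring how an m16n8k16 MMA consumes one k=16 slice at a time.
    """
    br, bc, d = len(q), len(k), len(q[0])
    assert d % k_atom == 0, "head dim must be a multiple of k_atom"
    s = [[0.0] * bc for _ in range(br)]
    for x0 in range(0, d, k_atom):                      # walk the head dim slice by slice
        q_slice = [row[x0:x0 + k_atom] for row in q]    # "load" Br x k_atom of Q
        k_slice = [row[x0:x0 + k_atom] for row in k]    # "load" Bc x k_atom of K
        for i in range(br):
            for j in range(bc):
                s[i][j] += sum(a * b for a, b in zip(q_slice[i], k_slice[j]))
    return s


random.seed(0)
Br, Bc, D = 4, 8, 64
Q = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(Br)]
K = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(Bc)]
ref, tiled = qk_full(Q, K), qk_tiled(Q, K)
print(max(abs(a - b) for ra, rb in zip(ref, tiled) for a, b in zip(ra, rb)))
```

The printed difference is only floating-point rounding; the per-step working set is independent of D, which is why D only affects the V tile.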

Full Changelog: v2.6.9...v2.6.10

FA2 Fully Shared QKV SMEM🎉

19 Dec 08:10
4687e1d

What's Changed

  • [FA2] Update flash-attn-mma shared-kv/qkv🎉 by @DefTruth in #163
  • [FA2] Update flash-attn-mma shared-kv/qkv🎉 by @DefTruth in #164
  • [FA2] Update flash-attn-mma shared-qkv🎉 by @DefTruth in #165
  • [FA2] Update flash-attn-mma shared-kv🎉 by @DefTruth in #166
  • [FA2] Update flash-attn-mma split-kv/q🎉 by @DefTruth in #167
  • [FA2] Update flash-attn-mma shared-qkv🎉 by @DefTruth in #168
  • [FA2] flash-attn-mma get rid of transpose-k✔️ by @DefTruth in #169

Full Changelog: v2.6.8...v2.6.9

FA2 Fully Shared QKV SMEM🎉

17 Dec 03:44
0c6785f

What's Changed

  • [FA2] Release flash-attn-mma split-kv/q🎉 by @DefTruth in #161
  • [FA2] Release flash-attn-mma shared-kv/qkv🎉 by @DefTruth in #162

I have also implemented FlashAttention-2 using pure MMA PTX instructions, supporting features such as Multi-Stages, Tile MMA, Tile Warp, Fully Shared QKV SMEM, Prefetch Q s2r, Collective Store, etc. Currently, for small-scale attention (B<=4, H<=48, SeqLen<=8192), it can run faster than the official FA2 on some devices, for example, the NVIDIA RTX 3080 Laptop.
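
One of the listed features, Collective Store (Shfl), can be illustrated with the FP16 accumulator layout of `mma.m16n8k16` from the PTX ISA: lane `l` holds columns `2*(l % 4)` and `2*(l % 4) + 1` of rows `l // 4` and `l // 4 + 8`, so one 8-column row segment is spread over a quad of 4 lanes, 32 bits each. The Python below simulates gathering a quad's values into its first lane (as `__shfl_sync` would) so that lane could issue one 128-bit store instead of four 32-bit ones. It is a sketch of the idea, not the kernel's code; `acc_fragment` and `collective_row` are hypothetical.

```python
# Simulate the m16n8 FP16 accumulator layout of mma.m16n8k16 and a quad-level gather.
M, N = 16, 8


def acc_fragment(lane):
    """(row, col) pairs held by `lane` for registers c0..c3 (PTX m16n8 C/D layout)."""
    group, tid = lane >> 2, lane % 4
    return [(group + 8 * (i >= 2), tid * 2 + (i & 1)) for i in range(4)]


def collective_row(tile, quad, upper):
    """Gather one 8-half row segment into the quad's leader lane.

    In CUDA each non-leader lane's 32-bit pair would reach the leader via
    __shfl_sync; here we just read the simulated registers of lanes 4q..4q+3.
    """
    regs = {lane: [tile[r][c] for r, c in acc_fragment(lane)] for lane in range(32)}
    half = slice(2, 4) if upper else slice(0, 2)   # c2,c3 -> row+8; c0,c1 -> row
    gathered = []
    for lane in range(4 * quad, 4 * quad + 4):      # leader first, then 3 shuffles
        gathered += regs[lane][half]
    return gathered                                 # 8 halves = one 128-bit store


tile = [[r * N + c for c in range(N)] for r in range(M)]   # value encodes (row, col)
print(collective_row(tile, quad=3, upper=False))   # row 3
print(collective_row(tile, quad=3, upper=True))    # row 11
```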

flash-attn-mma

  • Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
                    B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1617, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
                              B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
          mma(split-kv+stage1): ['0.01960754  ', '0.01452637  ', '-0.02592468 '], time:5.586338ms, TFLOPS:25.08
          mma(split-kv+stage2): ['0.01960754  ', '0.01452637  ', '-0.02592468 '], time:5.326223ms, TFLOPS:26.31
           mma(split-q+stage1): ['0.01960754  ', '0.01452637  ', '-0.02592468 '], time:3.834152ms, TFLOPS:36.54
           mma(split-q+stage2): ['0.01960754  ', '0.01452637  ', '-0.02592468 '], time:4.328346ms, TFLOPS:32.37
  mma(split-q+share-kv+stage1): ['0.01960754  ', '0.01452637  ', '-0.02592468 '], time:2.636528ms, TFLOPS:53.15
 mma(split-q+share-qkv+stage1): ['0.01960754  ', '0.01452637  ', '-0.02592468 '], time:2.594471ms, TFLOPS:54.01
 mma(split-q+share-qkv+stage2): ['0.01960754  ', '0.01452637  ', '-0.02592468 '], time:2.574611ms, TFLOPS:54.42
                       (flash): ['0.01963806  ', '0.0145874   ', '-0.02593994 '], time:3.764462ms, TFLOPS:37.22
-----------------------------------------------------------------------------------------------------------------------
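
To relate the logged times to TFLOPS, a common estimate counts two matmuls (QK^T and PV), each 2·N²·D FLOPs per head, i.e. 4·B·H·N²·D in total. The sketch below applies that estimate to the log above. The exact FLOP count used by flash_attn_mma.py is not shown here and appears to be slightly higher (its TFLOPS figures are about 2% above this estimate), so treat the numbers as approximate; `attention_tflops` is a hypothetical helper.

```python
def attention_tflops(b, h, n, d, time_ms):
    """Approximate TFLOPS of one attention forward: 4*B*H*N^2*D FLOPs (QK^T + PV)."""
    flops = 4 * b * h * n * n * d
    return flops / (time_ms * 1e-3) / 1e12


# Times taken from the B=1, H=8, N=8192, D=64 log above (NVIDIA RTX 3080 Laptop).
shared_qkv_stage2_ms = 2.574611
flash_ms = 3.764462

print(f"share-qkv+stage2: {attention_tflops(1, 8, 8192, 64, shared_qkv_stage2_ms):.2f} TFLOPS")
print(f"(flash)         : {attention_tflops(1, 8, 8192, 64, flash_ms):.2f} TFLOPS")
print(f"speedup         : {flash_ms / shared_qkv_stage2_ms:.2f}x")
```

Because the speedup is a ratio of times, it does not depend on which FLOP count is used.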

However, for large-scale attention computations, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to flash-attention-mma⚡️⚡️ for more details.

|Tensor Cores|Loop over Seqlen/Headdim|Tile Block (Br, Bc)|MMA (m16n8k16)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)|
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|Split KV/Q|
|✔️|✔️|✔️|✔️|
|Shared KV SMEM|Fully Shared QKV SMEM|Prefetch Q s2r|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|

Both Split KV and Split Q are implemented in flash-attention-mma⚡️⚡️ for performance comparison. The Split KV method, which splits all of QKV across MMA (Warps), is slower than the Split Q policy, which splits Q across MMA (Warps) and keeps KV accessible to all MMA (Warps).
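
The Split Q advantage comes from the row-wise softmax: if a warp owns whole rows of Q, it can keep the running max and running sum for those rows in registers while it streams over all KV tiles, with no inter-warp reduction. Under Split KV, the same row's scores are spread across several warps (the warp_KV columns in the layout below), so those statistics would first have to be merged through SMEM or shuffles. The pure-Python sketch below shows the per-row online softmax for one query row (FlashAttention-2-style rescaling; `online_attention_row` and the tile size `bc=4` are illustrative assumptions) and checks it against a direct softmax(qK^T)V.

```python
import math
import random


def attention_row_reference(q, keys, values):
    """softmax(q . K^T) @ V for a single query row, computed in one shot."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]


def online_attention_row(q, keys, values, bc=4):
    """Same result, streaming over KV tiles of size bc with running max/sum (FA2-style)."""
    d_v = len(values[0])
    m_run, l_run = -math.inf, 0.0        # running row max and softmax denominator
    acc = [0.0] * d_v                    # unnormalized output accumulator
    for t in range(0, len(keys), bc):
        k_tile, v_tile = keys[t:t + bc], values[t:t + bc]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_tile]
        m_new = max(m_run, max(scores))
        scale = math.exp(m_run - m_new)  # rescale previous partial results (0.0 on first tile)
        p = [math.exp(s - m_new) for s in scores]
        l_run = l_run * scale + sum(p)
        acc = [a * scale + sum(pi * v[c] for pi, v in zip(p, v_tile)) for c, a in enumerate(acc)]
        m_run = m_new
    return [a / l_run for a in acc]      # normalize once at the end


random.seed(1)
D, N = 8, 16
q = [random.gauss(0, 1) for _ in range(D)]
K = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
ref, out = attention_row_reference(q, K, V), online_attention_row(q, K, V)
print(max(abs(a - b) for a, b in zip(ref, out)))
```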

  • 📚 Split KV (Basic, FlashAttention-1)
// Split QKV across MMA(Warps) using naive matmul MMA&Warp tiling policy.
// case: The layout of 8 MMA(2x4)  [after] kWarpTileSeqLenQxkWarpTileSeqLenK(2x2) -> 32x2,32x2=64x64: 
// |  [64,64]  |    warp_KV 0    |    warp_KV 1    |    warp_KV 2    |    warp_KV 3    |
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
__global__ void 
flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
                                      half* K, // [B, H, D, N] K^T transposed 
                                      half* V, // [B, H, N, D] 
                                      half* O, // [B, H, N, D] 
                                      int QKV_seqlen);
  • 📚 Split Q (Faster, FlashAttention-2)
// Split Q across MMA(Warps) and keep KV accessible to all MMA(Warps),
// in order to reduce communication between warps via SMEM and warp shuffle.
// case: MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
// |   64x64   |      warp_KV 0       |
// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
__global__ void
flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
                                     half* K, // [B, H, D, N] K^T transposed 
                                     half* V, // [B, H, N, D] 
                                     half* O, // [B, H, N, D] 
                                     int QKV_seqlen);
  • 📚 Split Q + Shared KV SMEM (Faster+)
// K and V share the same shared memory, which improves block occupancy.
__global__ void 
flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, 
                                               half* K, 
                                               half* V, 
                                               half* O, 
                                               int QKV_seqlen);
  • 📚 Split Q + Fully Shared QKV SMEM (Faster++)
// Q, K, and V fully share the same shared memory, and Q is prefetched s2r, which improves block occupancy.
__global__ void 
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, 
                                                half* K, 
                                                half* V, 
                                                half* O, 
                                                int QKV_seqlen);
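
The layout comments in the Split KV and Split Q bullets above can be checked by counting how many warps hold scores for each row of the 64x64 S = QK^T tile. The sketch below encodes the split-KV case as a 2x4 warp grid whose warps each cover 32 rows x 16 columns (2x2 m16n8 MMAs), and the split-Q case as 4 warps each covering 16 rows x all 64 columns (8 m16n8 MMAs), following those comments; the warp-to-tile mapping functions are hypothetical.

```python
# Count how many warps hold scores for each Q row of a 64x64 S = Q @ K^T tile.
BR, BC = 64, 64
MMA_M, MMA_N = 16, 8  # m16n8k16


def split_kv_tiles():
    """8 warps in a 2 (Q) x 4 (KV) grid, each with a 2x2 tile of m16n8 MMAs."""
    tiles = {}
    for warp in range(8):
        qp, kv = warp % 2, warp // 2            # MMA 0/1 down, MMA 0/2/4/6 across
        rows = range(qp * 2 * MMA_M, (qp + 1) * 2 * MMA_M)
        cols = range(kv * 2 * MMA_N, (kv + 1) * 2 * MMA_N)
        tiles[warp] = (rows, cols)
    return tiles


def split_q_tiles():
    """4 warps stacked along Q, each with 1x8 m16n8 MMAs spanning all of Bc."""
    return {warp: (range(warp * MMA_M, (warp + 1) * MMA_M), range(BC)) for warp in range(4)}


def warps_per_row(tiles):
    counts = [0] * BR
    for rows, _cols in tiles.values():
        for r in rows:
            counts[r] += 1
    return counts


def covers_tile(tiles):
    cells = {(r, c) for rows, cols in tiles.values() for r in rows for c in cols}
    return len(cells) == BR * BC


print("split-kv warps per Q row:", set(warps_per_row(split_kv_tiles())))
print("split-q  warps per Q row:", set(warps_per_row(split_q_tiles())))
```

A count of 4 means each row's running max and sum must be combined across four warps before the softmax can be applied; with a count of 1, each warp finishes its rows on its own.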

Full Changelog: v2.6.7...v2.6.8

🎉FA2 MMA Split KV/Q

15 Dec 03:31
5afd8c1

What's Changed

  • [FlashAttention] Update flash-attention-mma 0.0.1 🎉 by @DefTruth in #159
  • [FA2] Release flash-attn-mma split-kv/q🎉 by @DefTruth in #160

Full Changelog: v2.6.6...v2.6.7

🎉flash-attention-mma 0.0.1

12 Dec 03:29
b1b923a

What's Changed

New Contributors

Full Changelog: v2.6.5...v2.6.6

⚡️⚡️toy-hgemm library

28 Nov 02:07
37f1554

What's Changed

Full Changelog: v2.6.4...v2.6.5

toy-hgemm library

22 Nov 11:42
56e2fe9
Compare
Choose a tag to compare

What's Changed

Full Changelog: v2.6.3...v2.6.4

toy-hgemm library

22 Nov 06:51
6ea2eb9
Compare
Choose a tag to compare

What's Changed

Full Changelog: v2.6.2...v2.6.3