[FA2] fa2/hgemm manual smem swizzle🎉 (#185)
* Update flash_attn_mma.py

* Create makefile

* Create README.md

* Update and rename matrix_trans_swizzle.cu to mat_trans_swizzle.cu

* Update hgemm_mma_swizzle.cu

* Update mat_trans_swizzle.cu

* Update and rename flash_attn_mma_swizzle_qkv.cu to flash_attn_mma_share_kv_swizzle.cu

* Create flash_attn_mma_share_qkv_swizzle.cu

* Create flash_attn_mma_split_q_swizzle.cu

* Create flash_attn_mma_split_kv_swizzle.cu

* Create flash_attn_mma_tiling_qk_swizzle.cu

* Create flash_attn_mma_tiling_qkv_swizzle.cu

* Update flash_attn_mma_share_qkv_swizzle.cu

* Update flash_attn_mma_split_kv_swizzle.cu

* Update flash_attn_mma_split_q_swizzle.cu

* Update flash_attn_mma_tiling_qk_swizzle.cu

* Update flash_attn_mma_tiling_qkv_swizzle.cu

* Update README.md

* Update hgemm_mma_swizzle.cu

* Update makefile

* Update README.md

* Update README.md

* Update mat_trans_swizzle.cu

* Update makefile

* Update hgemm_mma_swizzle.cu

* Update hgemm_mma_swizzle.cu

* Update README.md

* Update hgemm_mma_stage.cu

* Update hgemm_mma.cu

* Update makefile

* Update utils.h

* Create mma_simple_swizzle.cu

* Update makefile

* Update mma_simple_swizzle.cu

* Update hgemm_mma_swizzle.cu

* Update makefile

* Update utils.py

* Update makefile

* Create hgemm_mma_stage_swizzle.cu

* Update hgemm.py

* Update hgemm.cc

* Update mat_trans_swizzle.cu

* Update flash_attn_mma_tiling_qk_swizzle.cu

* Update flash_attn.cc

* Update flash_attn_mma.py

* Update flash_attn_mma.py

* Update flash_attn_mma.py

* Update flash_attn_mma_tiling_qk_swizzle.cu

* Update flash_attn_mma_tiling_qk_swizzle.cu

* Update flash_attn_mma_share_kv_swizzle.cu

* Update README.md

* Update README.md

* Create print_swizzle_layout.py

* Update flash_attn_mma_tiling_qk_swizzle.cu

* Update flash_attn_mma_share_kv_swizzle.cu

* Update README.md

* Update hgemm_mma_stage_swizzle.cu

* Update README.md

* Update README.md

* Update README.md

* Update mma_simple_swizzle.cu

* Create print_swizzle_layout.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
DefTruth authored Dec 25, 2024
1 parent 98984e0 commit bdd361a
Showing 27 changed files with 3,940 additions and 179 deletions.
README.md — 64 changes: 34 additions & 30 deletions
@@ -35,9 +35,9 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
|✔️|✔️|✔️|✔️|
|Copy Async|Tile MMA (More Threads)|Tile Warp (More Values)|Multi Stages (2/3/4)|
|✔️|✔️|✔️|✔️|
|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle (CuTe)|
|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle (CuTe/MMA)|
|✔️|✔️|✔️|✔️|
|Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32|
|Collective Store (Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32|
|✔️|✔️|✔️|✔️|
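
The SMEM Swizzle entry in the table above refers to permuting the shared-memory column index so that a warp's MMA fragment loads hit different banks, instead of relying on padding. Below is a minimal sketch of the XOR-based idea; the tile sizes, kernel name, and helper are illustrative assumptions, not the exact kernels added in this commit.

```cuda
#include <cuda_fp16.h>

// Hypothetical XOR swizzle for a row-major half tile that is 64 halves
// (128 bytes) wide. Groups of 8 halves (16 bytes) form one chunk; XOR-ing
// the chunk index with the low row bits spreads the same logical column of
// consecutive rows across different shared-memory banks, without padding.
__device__ __forceinline__ int swizzle_col(int row, int col) {
  return (((col >> 3) ^ (row & 7)) << 3) | (col & 7);
}

__global__ void copy_tile_swizzled(const half* __restrict__ src,
                                   half* __restrict__ dst, int ld) {
  __shared__ half tile[16][64];            // 16x64 half tile, no padding
  int row = threadIdx.y;                   // assumes blockDim = (64, 16)
  int col = threadIdx.x;
  // Store global -> shared through the permuted column.
  tile[row][swizzle_col(row, col)] = src[row * ld + col];
  __syncthreads();
  // Every later reader must apply the same permutation to recover the value.
  dst[row * ld + col] = tile[row][swizzle_col(row, col)];
}
```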


@@ -48,7 +48,7 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|Tensor Cores|Loop over Seqlen/Headdim |Tile Block (Br, Bc)|MMA (m16n8k16)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)|
|Pack LDST (128 bits)|SMEM **Swizzle**/Padding |Copy Async|Tile MMA (More Threads)|
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
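
The MMA (m16n8k16) building block in the table above is the tensor-core PTX instruction these kernels are written against. A minimal sketch of issuing one such instruction from CUDA C++ follows; the wrapper name is illustrative, and the fragment registers are assumed to have been filled beforehand (e.g. via ldmatrix).

```cuda
#include <cstdint>

// D = A * B + C for one m16n8k16 tile, f16 inputs and f16 accumulators.
// RA[4], RB[2], RC[2]/RD[2] are the per-thread fragment registers of a warp.
__device__ __forceinline__ void mma_m16n8k16_f16(uint32_t* RD, const uint32_t* RA,
                                                 const uint32_t* RB, const uint32_t* RC) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
      "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"
      : "=r"(RD[0]), "=r"(RD[1])
      : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]),
        "r"(RB[0]), "r"(RB[1]), "r"(RC[0]), "r"(RC[1]));
}
```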
@@ -160,7 +160,6 @@ The kernels listed here will guide you through a step-by-step progression, rangi

|📖 CUDA Kernel| 📖 Elem DType| 📖 Acc DType| 📖 Docs | 📖 Level |
|:---|:---|:---|:---|:---|
| ✔️ [nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️|
| ✔️ [elementwise_f32](./kernels/elementwise/elementwise.cu)|f32|/|[link](./kernels/elementwise/)|⭐️|
| ✔️ [elementwise_f32x4](./kernels/elementwise/elementwise.cu)|f32|/|[link](./kernels/elementwise/)|⭐️|
| ✔️ [elementwise_f16](./kernels/elementwise/elementwise.cu)|f16|/|[link](./kernels/elementwise/)|⭐️|
@@ -205,27 +204,27 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [mat_trans_f32_diagonal2d](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_col2row{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_row2col{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
| ✔️ [warp_reduce_[all]](./kernels/reduce/block_all_reduce.cu)|all|all|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f32_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f32x4_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x2_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x2_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x8_pack_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x2_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x2_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x8_pack_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e4m3_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e5m2_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e4m3x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e5m2x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_i8_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_i8x16_pack_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [warp_reduce_{all}](./kernels/reduce/block_all_reduce.cu)|all|all|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f32_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f32x4_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x2_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x2_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x8_pack_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x2_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x2_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x8_pack_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e4m3_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e5m2_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e4m3x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e5m2x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
| ✔️ [block_all_reduce_i8_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_i8x16_pack_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [dot_product_f32](./kernels/dot-product/dot_product.cu)|f32|f32|[link](./kernels/dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f32x4](./kernels/dot-product/dot_product.cu)|f32|f32|[link](./kernels/dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f16_f32](./kernels/dot-product/dot_product.cu)|f16|f32|[link](./kernels/dot-product/)|⭐️⭐️|
@@ -262,7 +261,8 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [rms_norm_f16x8_pack_f32](./kernels/rms-norm/rms_norm.cu)|f16|f32|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16_f32](./kernels/rms-norm/rms_norm.cu)|f16|f32|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [nms_f32](./kernels/nms/nms.cu)|f32|/|[link](./kernels/nms)|⭐️⭐️|
| ✔️ [notes v1(deprecated)](./kernels/notes-v1.cu)|f32|f32|/|⭐️|
| ✔️ [notes v1(deprecated)](./kernels/notes-v1.cu)|f32|f32|/|⭐️⭐️|
| ✔️ [How to profile with nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️⭐️|
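
The warp_reduce/block_all_reduce entries in the table above are built on warp shuffles. A minimal sketch of the warp-level sum reduction they rely on is shown below; the function name is illustrative, and the repo's kernels additionally wrap this in a block-level reduction through shared memory.

```cuda
// Sum across the 32 lanes of a warp using butterfly shuffles.
// After the loop, every lane holds the full warp sum.
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
  #pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, offset);
  }
  return val;
}
```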

### 📚 Hard ⭐⭐⭐️ ([©️back👆🏻](#cuda-kernel))

@@ -284,7 +284,7 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [sgemm_t_8x8_sliced_k16...dbuf](./kernels/sgemm/sgemm_async.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k16...async](./kernels/sgemm/sgemm_async.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_wmma_m16n16k8...stages*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_wmma_m16n16k8...swizzle*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_wmma_m16n16k8...swizzle{+block}*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_naive_f16](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️|
| ✔️ [hgemm_sliced_k_f16](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_t_8x8_sliced_k_f16x4](./kernels/hgemm/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
@@ -299,12 +299,13 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [hgemm_wmma_m16n16k16...dbuf*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m32n8k16....dbuf*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...stages*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...swizzle*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...swizzle{+block}*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...naive*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...mma2x4*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...stages*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...swizzle*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_stages{swizzle}...cute*](./kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...swizzle{+block}*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...swizzle{+smem}*](./kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_stages_swizzle{+smem}...cute*](./kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_cublas*](./kernels/hgemm/cublas/hgemm_cublas.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️|

### 📚 Hard+ ⭐️⭐️⭐️⭐️ & Hard++ ⭐️⭐️⭐️⭐️⭐️ ([©️back👆🏻](#cuda-kernel))
@@ -318,11 +319,14 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/flash_attn_mma_share_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages...tiling_qk*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma...tiling_qk_swizzle{+smem}*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages_split_kv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_split_kv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages_split_q{f32}*](./kernels/flash-attn/mma/flash_attn_mma_split_q_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages...shared_kv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_share_kv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages...shared_qkv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages...tiling_qk{f32}*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [How to implement MMA smem swizzle*](./kernels/swizzle/mma_simple_swizzle.cu)|f16|f16|[link](./kernels/swizzle)|⭐️⭐️⭐️⭐️|
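
The last two rows above point at the MMA smem-swizzle kernels added in this commit. On the load side, the same permutation used when storing the tile has to be applied before ldmatrix fetches the fragments. A minimal sketch follows; the swizzle helper mirrors the store-side sketch earlier and the tile width is an assumption, not the exact kernel code.

```cuda
#include <cstdint>
#include <cuda_fp16.h>

// Load four 8x8 f16 matrices for one warp with ldmatrix, reading from a
// swizzled shared-memory tile. 'smem' points at the K (or Q) tile in shared
// memory; 'row'/'col' are the logical fragment coordinates for this lane.
__device__ __forceinline__ void ldmatrix_x4_swizzled(uint32_t* R, const half* smem,
                                                     int row, int col) {
  // Same XOR permutation that was applied when the tile was stored.
  int s_col = (((col >> 3) ^ (row & 7)) << 3) | (col & 7);
  uint32_t addr = static_cast<uint32_t>(
      __cvta_generic_to_shared(smem + row * 64 + s_col));  // 64 = tile width
  asm volatile(
      "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3])
      : "r"(addr));
}
```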

## 📖 Blog Contents

<div id="my-blogs-part-1"></div>
kernels/flash-attn/README.md — 2 changes: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
|Tensor Cores|Loop over Seqlen/HeadDim |Tile Block (Br, Bc)|MMA (m16n8k16)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
|Pack LDST (pack 128 bits)|SMEM Padding|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)
|Pack LDST (pack 128 bits)|SMEM **Swizzle**/Padding |Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shfl & Reg Reuse)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
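
The Copy Async (cp.async.cg/ca) column maps to the Ampere async-copy instruction used to stage Q/K/V tiles from global into shared memory. Below is a minimal sketch of one 128-bit (16-byte) async copy per thread; the commit-group/wait calls and the surrounding tiling loop are omitted, and the helper name is illustrative.

```cuda
#include <cstdint>
#include <cuda_fp16.h>

// Issue one 16-byte async copy from global to shared memory (sm_80+).
__device__ __forceinline__ void cp_async_16B(half* smem_dst, const half* gmem_src) {
  uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
               :: "r"(smem_addr), "l"(gmem_src));
}

// The issuing side then commits and waits before the data is used:
//   asm volatile("cp.async.commit_group;\n" ::);
//   asm volatile("cp.async.wait_group 0;\n" ::);
//   __syncthreads();
```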
kernels/flash-attn/flash_attn_mma.py — 15 changes: 11 additions & 4 deletions
@@ -83,6 +83,7 @@ def get_args():
'./mma/flash_attn_mma_share_kv.cu',
'./mma/flash_attn_mma_share_qkv.cu',
'./mma/flash_attn_mma_tiling_qk.cu',
'./mma/flash_attn_mma_tiling_qk_swizzle.cu',
'./pybind/flash_attn.cc'
],
extra_cuda_cflags=[
@@ -218,11 +219,11 @@ def run_benchmark(perf_func: callable,
else:
improve = 0
MAX_TFLOPS = TFLOPS
print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, "
print(f"{out_info:>38}: {out_val}, time:{mean_time:<.6f}ms, "
f"TFLOPS:{TFLOPS:<6.2f}(+{improve:.2f}%)")
else:
if not only_show_improved or "flash" in tag:
print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, "
if not only_show_improved or "flash" in tag or "sdpa" in tag:
print(f"{out_info:>38}: {out_val}, time:{mean_time:<.6f}ms, "
f"TFLOPS:{TFLOPS:<6.2f}")

if show_matrix: print(out)
@@ -296,7 +297,7 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
diff = torch.abs(out_flash_or_sdpa.float() - out_mma.float())
all_close = str(torch.allclose(out_flash_or_sdpa.float(), out_mma.float(), atol=1e-2))
pretty_print_line(
f"{true_tag} vs {tag:<18}, all close: {all_close:<6}, "
f"{true_tag} vs {tag:<22}, all close: {all_close:<6}, "
f"max diff: {diff.max().item():.6f}, min diff: {diff.min().item():.6f}, "
f"mean diff: {diff.mean().item():.6f}"
)
@@ -340,6 +341,8 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage2)", o, stages=2)
out_mma_tiling_qk1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk, q, k, v, "mma(split-q+tiling-qk+stage1)", o, stages=1)
out_mma_tiling_qk2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk, q, k, v, "mma(split-q+tiling-qk+stage2)", o, stages=2)
out_mma_tiling_qk_sw1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk_swizzle, q, k, v, "mma(split-q+tiling-qk+swizzle+stage1)", o, stages=1)
out_mma_tiling_qk_sw2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk_swizzle, q, k, v, "mma(split-q+tiling-qk+swizzle+stage2)", o, stages=2)
if D <= 256:
out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
if args.run_torch_sdpa:
@@ -360,9 +363,13 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
check_all_close(out_flash, out_mma_share_qkv2, "out_mma_share_qkv2", args.check_all)
check_all_close(out_flash, out_mma_tiling_qk1, "out_mma_tiling_qk1", args.check_all)
check_all_close(out_flash, out_mma_tiling_qk2, "out_mma_tiling_qk2", args.check_all)
check_all_close(out_flash, out_mma_tiling_qk_sw1, "out_mma_tiling_qk_sw1", args.check_all)
check_all_close(out_flash, out_mma_tiling_qk_sw2, "out_mma_tiling_qk_sw2", args.check_all)
pretty_print_line()
elif args.run_torch_sdpa:
pretty_print_line()
check_all_close(out_sdpa, out_mma_tiling_qk1, "out_mma_tiling_qk1", args.check_all, False)
check_all_close(out_sdpa, out_mma_tiling_qk2, "out_mma_tiling_qk2", args.check_all, False)
check_all_close(out_sdpa, out_mma_tiling_qk_sw1, "out_mma_tiling_qk_sw1", args.check_all, False)
check_all_close(out_sdpa, out_mma_tiling_qk_sw2, "out_mma_tiling_qk_sw2", args.check_all, False)
pretty_print_line()
