Add configs from ROCm/triton tune_gemm (#3226)
Summary:
Pull Request resolved: #3226

X-link: facebookresearch/FBGEMM#324

https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tools/tune_gemm/tune_gemm.py configs

Adapted from D63415810

Reviewed By: jianyuh

Differential Revision: D63929109

fbshipit-source-id: fbe36bf5b08dcd673a1587a831b7bce416d01e86
karthik-man authored and facebook-github-bot committed Oct 5, 2024
1 parent 7a4472a commit 42dca08
Showing 1 changed file with 109 additions and 38 deletions.
147 changes: 109 additions & 38 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -2460,43 +2460,106 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
     return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


+# Configs adapted from https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tools/tune_gemm/tune_gemm.py
 def prune_configs(configs, named_args, **kwargs):
     pruned_configs = []
+    M = named_args["M"]
+    N = named_args["N"]
+    K = named_args["K"]
+    elemBytes_a = named_args["A"].element_size()
+    elemBytes_b = named_args["B"].element_size()

+    if M < 32 or N < 32:
+        mfma = 16
+    else:
+        mfma = 32

-    SIZE_M = named_args["A"].shape[0]
-    SIZE_N = named_args["B"].shape[1]
-    SIZE_K = named_args["C"].shape[1]
+    # TODO (zhanglx): figure out the boundary between large and small gemms
+    large_gemm = False
+    if M >= 2048 and N >= 2048:
+        large_gemm = True

+    pruned_configs = []
     for config in configs:
-        kw = config.kwargs
-        BLOCK_SIZE_M, BLOCK_SIZE_N, _ = (
-            kw["BLOCK_M"],
-            kw["BLOCK_N"],
-            kw["BLOCK_K"],
-        )
-        SPLIT_K = kw["SPLIT_K"]
-        if SIZE_M <= 32 and BLOCK_SIZE_M != 32:
+        BLOCK_SIZE_M = config.kwargs.get("BLOCK_M")
+        BLOCK_SIZE_N = config.kwargs.get("BLOCK_N")
+        BLOCK_SIZE_K = config.kwargs.get("BLOCK_K")
+        num_warps = config.num_warps
+        matrix_instr_nonkdim = config.kwargs.get("matrix_instr_nonkdim")
+        if matrix_instr_nonkdim > mfma:
+            continue
+        if mfma == 4 and BLOCK_SIZE_K < 64:
+            continue
+        # some layouts could not work properly in case
+        # number elemens per thread is less 1
+        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
+            continue
+        SPLIT_K = config.kwargs.get("SPLIT_K")
+        GROUP_M = config.kwargs.get("GROUP_M")
+        if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim:
+            continue
+        if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim:
+            continue
+        if N <= matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim:
             continue
-        if SIZE_N <= 32 and BLOCK_SIZE_N != 32:
+        # Skip BLOCK_SIZE that is too large compare to M/N
+        # unless BLOCK_SIZE is already small enough
+        if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16:
             continue
+        if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16:
+            continue
         # skip large split_k when not necessary
-        if SPLIT_K != 1 and not need_split_k(SIZE_M, SIZE_N, SIZE_K):
+        if SPLIT_K != 1 and not need_split_k(M, N, K):
             continue
+        # skip split_k that leads to EVEN_K = false
+        leap = SPLIT_K * BLOCK_SIZE_K
+        modv = K % leap
+        if modv != 0:
+            continue
+        # skip large GROUP_M
+        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
+            continue
+        # out of shared memory resource
+        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
+        LDS = (
+            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
+            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
+        )
+        if LDS > 65536:
+            continue
+        # Skip small block sizes and num_warps for large gemm
+        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
+        if large_gemm:
+            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
+                continue
+            if BLOCK_SIZE_K < 64:
+                continue
+            if num_warps < 4:
+                continue

         pruned_configs.append(config)
-    logging.info(f"pruned_configs: config len{len(pruned_configs)}")

+    print(f"{len(configs)=} {len(pruned_configs)=}")
+    if len(pruned_configs) == 0:
+        print(f"No configs left after pruning! {M=} {N=} {K=}")
+        pruned_configs = configs[:10]
     return pruned_configs
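The LDS check above acts as a capacity filter: a config is dropped when one A tile plus one B tile would not fit in the 64 KiB of shared memory (LDS) per AMD compute unit. As a rough illustration only (not part of the diff), assuming 1-byte fp8 elements for both operands:

# Illustration of the LDS estimate used by prune_configs (assumed fp8, 1 byte/element).
def estimated_lds_bytes(block_m, block_n, block_k, elem_bytes_a=1, elem_bytes_b=1):
    # One BLOCK_M x BLOCK_K tile of A plus one BLOCK_K x BLOCK_N tile of B.
    return block_k * block_m * elem_bytes_a + block_k * block_n * elem_bytes_b

print(estimated_lds_bytes(256, 256, 128))  # 65536 -> kept (pruned only when > 65536)
print(estimated_lds_bytes(256, 256, 256))  # 131072 -> pruned, exceeds the 64 KiB budget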


-def get_full_non_persistent_tuning_space(use_split_k):
-    if torch.version.hip is None:
-        logger.warning("Using HIP configs on CUDA device, this may be slow.")
+def get_full_non_persistent_tuning_space():
     configs = []
-    block_mn_range = [32, 64, 128, 256]
-    block_k_range = [32, 64, 128]

+    block_mn_range = [16, 32, 64, 128, 256]
+    block_k_range = [16, 32, 64, 128, 256]
     split_k_range = [1]
-    num_warps_range = [1, 2, 4, 8, 16]
-    group_m_range = [1, 4, 8]
+    num_warps_range = [1, 2, 4, 8]
+    group_m_range = [1, 2, 4, 8, 16, 32]
+    # For now we see better perf with num_stages=0 for all gemm configs we care
+    # But keep this explicit so that we do not forget we may need to set it to
+    # other values in the future
     num_stage_range = [0]
+    waves_per_eu_range = [0]
+    matrix_instr_nonkdim_range = [16, 32]
+    kpack_range = [1, 2]

     for block_m in block_mn_range:
         for block_n in block_mn_range:
@@ -2505,28 +2568,36 @@ def get_full_non_persistent_tuning_space(use_split_k):
                     for group_m in group_m_range:
                         for split_k in split_k_range:
                             for num_stages in num_stage_range:
-                                configs.append(
-                                    triton.Config(
-                                        {
-                                            "BLOCK_M": block_m,
-                                            "BLOCK_N": block_n,
-                                            "BLOCK_K": block_k,
-                                            "GROUP_M": group_m,
-                                            "SPLIT_K": split_k,
-                                        },
-                                        num_stages=num_stages,
-                                        num_warps=num_warps,
-                                    )
-                                )

+                                for waves_per_eu in waves_per_eu_range:
+                                    for (
+                                        matrix_instr_nonkdim
+                                    ) in matrix_instr_nonkdim_range:
+                                        for kpack in kpack_range:
+                                            configs.append(
+                                                triton.Config(
+                                                    {
+                                                        "BLOCK_M": block_m,
+                                                        "BLOCK_N": block_n,
+                                                        "BLOCK_K": block_k,
+                                                        "GROUP_M": group_m,
+                                                        "SPLIT_K": split_k,
+                                                        "waves_per_eu": waves_per_eu,
+                                                        "matrix_instr_nonkdim": matrix_instr_nonkdim,
+                                                        "kpack": kpack,
+                                                    },
+                                                    num_warps=num_warps,
+                                                    num_stages=num_stages,
+                                                )
+                                            )
     logger.info(f"all configs #: {len(configs)}")
     return configs
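For scale, the widened ranges multiply out to a much larger search space than before, which is why prune_configs matters. A back-of-the-envelope count, as a sketch (not part of the diff):

# Sketch: size of the tuning space generated above, before prune_configs runs.
import math

range_sizes = [
    5,  # BLOCK_M in [16, 32, 64, 128, 256]
    5,  # BLOCK_N in [16, 32, 64, 128, 256]
    5,  # BLOCK_K in [16, 32, 64, 128, 256]
    4,  # num_warps in [1, 2, 4, 8]
    6,  # GROUP_M in [1, 2, 4, 8, 16, 32]
    1,  # SPLIT_K in [1]
    1,  # num_stages in [0]
    1,  # waves_per_eu in [0]
    2,  # matrix_instr_nonkdim in [16, 32]
    2,  # kpack in [1, 2]
]
print(math.prod(range_sizes))  # 12000 candidate configs, up from 720 with the old ranges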


-MATMUL_CONFIGS: List[Config] = get_full_non_persistent_tuning_space(True)
+MATMUL_CONFIGS_NON_PERSISTENT: List[Config] = get_full_non_persistent_tuning_space()


 @triton.autotune(
-    configs=MATMUL_CONFIGS,
+    configs=MATMUL_CONFIGS_NON_PERSISTENT,
     key=["M", "N", "K"],
     prune_configs_by={
         "early_config_prune": prune_configs,
