Fix triton fp8 handling of non-contiguous inputs
Summary:
This diff fixes an issue where our Triton fp8 quantize functions didn't properly handle non-contiguous inputs. Specifically, they wrote to the output tensor using the input's strides, even though the output is always allocated as contiguous. In some cases this resulted in the output being unintentionally transposed.

As a result, non-contiguous inputs would run without error but silently produce transposed outputs. This was reported on GitHub in pytorch#2713.
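
To make the failure mode concrete, here is a small PyTorch sketch (purely illustrative, not the fbgemm kernel itself) that emulates a kernel addressing a freshly allocated contiguous output with the input's strides. For a transposed input, the data lands in the output buffer transposed:

```python
import torch

# Illustration only: write output element (m, n) at flat offset
# m * stride_am + n * stride_an, i.e. using the *input's* strides,
# even though the output buffer is allocated contiguous.
M, N = 3, 4
a = torch.arange(M * N, dtype=torch.float32).reshape(N, M).t()  # non-contiguous (transposed) view, shape (M, N)
out = torch.empty(M, N)                                         # contiguous, strides (N, 1)

stride_am, stride_an = a.stride()                               # (1, M) for the transposed view
flat_out = out.view(-1)
for m in range(M):
    for n in range(N):
        flat_out[m * stride_am + n * stride_an] = a[m, n]

print(torch.equal(out, a))                 # False: out does not match a
print(torch.equal(out.view(N, M).t(), a))  # True: the buffer holds a transposed copy
```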

Adding explicit output strides to the kernels resolves the issue.
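
Continuing the toy example above, the fix corresponds to computing the write offset from the output's own strides, which the kernels now receive explicitly (the `stride_om` / `stride_on` arguments in the diff below):

```python
# Corrected addressing: offsets come from the output's own strides.
stride_om, stride_on = out.stride()  # (N, 1) for the contiguous output
for m in range(M):
    for n in range(N):
        flat_out[m * stride_om + n * stride_on] = a[m, n]

print(torch.equal(out, a))  # True, regardless of the input's layout
```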

I also found a small issue with D59248142 where scaling wouldn't be applied when the number of elements was smaller than the block size. This caused fp8_gemm_test to fail. I resolved it by extending the check for when to apply scaling.
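
For the scaling fix, here is a rough walk-through of why the old condition could skip scaling entirely when K is smaller than one K-block (the concrete numbers below, including `k_multiple`, are made up for illustration):

```python
# Illustrative values only, not taken from an actual kernel launch.
BLOCK_K, SPLIT_K = 256, 1
k_multiple = 2   # assume a scale block spans two K-blocks
K = 5            # fewer elements than a single K-block
pid_z = 0

k = 0                                      # the k-loop runs once: cdiv(5, 256) == 1
k_remaining = K - k * (BLOCK_K * SPLIT_K)  # 5
pid_k = k * SPLIT_K + pid_z                # 0

old_check = (pid_k + 1) % k_multiple == 0                   # False -> scale never applied
new_check = old_check or (k_remaining < BLOCK_K * SPLIT_K)  # True  -> scale applied on the partial block
print(old_check, new_check)
```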

Reviewed By: jianyuh

Differential Revision: D60535956
jwfromm authored and facebook-github-bot committed Jul 31, 2024
1 parent 336a854 commit 70043c7
Showing 2 changed files with 39 additions and 8 deletions.
16 changes: 14 additions & 2 deletions fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -101,9 +101,13 @@ def _test_matmul_fp8_row(
device: torch.device,
fp8_fast_accum: bool,
use_bias: bool = False,
transpose_input: bool = False,
) -> None:
M, N, K = shape
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
# Make a non-contiguous tensor and check that we still get proper results.
if transpose_input:
a = a.t()
b = torch.randn(N, K, dtype=torch.bfloat16, device=device)
bias = (
torch.randn(N, dtype=torch.float32, device=device) if use_bias else None
@@ -126,6 +130,9 @@
)

_test_matmul_fp8_row((3, 4, 5), torch.device("cuda"), True)
_test_matmul_fp8_row(
(5, 4, 5), torch.device("cuda"), True, transpose_input=True
)
_test_matmul_fp8_row((3, 4, 5), torch.device("cuda"), True, True)
_test_matmul_fp8_row((3, 4, 5), torch.device("cuda"), False)
_test_matmul_fp8_row((3, 4, 5), torch.device("cuda"), False, True)
@@ -171,11 +178,15 @@ def _test_matmul_fp8_block(
shape: Tuple[int, int, int],
block_shape: Tuple[int, int, int],
fp8_fast_accum: bool,
transpose_input: bool = False,
device: str = "cuda",
) -> None:
M, N, K = shape
BLOCK_M, BLOCK_N, BLOCK_K = block_shape
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
# Make a non-contiguous tensor and check that we still get proper results.
if transpose_input:
a = a.t()
b = torch.randn(N, K, dtype=torch.bfloat16, device=device)

# Quantize inputs.
@@ -205,8 +216,9 @@
)

_test_matmul_fp8_block((3, 4, 5), (256, 256, 256), True)
_test_matmul_fp8_block((5, 4, 5), (256, 256, 256), True, transpose_input=True)
_test_matmul_fp8_block((1024, 2048, 4096), (256, 512, 1024), True)
_test_matmul_fp8_block((1024, 2048, 4096), (256, 512, 1024), False)
_test_matmul_fp8_block((3, 4, 5), (256, 256, 256), False)
_test_matmul_fp8_block((3, 4, 5), (256, 256, 256), True, "cpu")
_test_matmul_fp8_block((1024, 2048, 4096), (256, 512, 1024), True, "cpu")
_test_matmul_fp8_block((3, 4, 5), (256, 256, 256), True, device="cpu")
_test_matmul_fp8_block((1024, 2048, 4096), (256, 512, 1024), True, device="cpu")
31 changes: 25 additions & 6 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -1211,12 +1211,12 @@ def _kernel_matmul_fp8_block_fastacc(

for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):

k_remaining = K - k * (BLOCK_K * SPLIT_K)

if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)

a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
if AB_DTYPE:
@@ -1235,7 +1235,7 @@ def _kernel_matmul_fp8_block_fastacc(
# And have s_k+1 be 1.
# Scale_i = pid_i * BLOCK_I / scale_block_i
pid_k = k * SPLIT_K + pid_z
if (pid_k + 1) % k_multiple == 0:
if ((pid_k + 1) % k_multiple == 0) or (k_remaining < BLOCK_K * SPLIT_K):
# Note: Due to split_k access "pid_k" = k * SPLIT_K + pid_z
# Access a_scale[pid_m, k * SPLIT_K + pid_z]
# and b_scale[k * SPLIT_K + pid_z, pid_n]
@@ -1676,6 +1676,8 @@ def _kernel_quantize_fp8_row(
N,
stride_am,
stride_an,
stride_om,
stride_on,
TL_FP8_DTYPE: tl.constexpr,
MAX_FP8: tl.constexpr,
EPS: tl.constexpr,
@@ -1701,6 +1703,8 @@ def _kernel_quantize_fp8_row(
N (int): Number of columns.
stride_am (int): Stride of m dimension of A.
stride_an (int): Stride of n dimension of A.
stride_om (int): Stride of m dimension of output.
stride_on (int): Stride of n dimension of output.
TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
MAX_FP8 (float): Maxmimum expressible value for FP8.
EPS (float): Epsilon value for numerical stability.
@@ -1742,7 +1746,7 @@ def _kernel_quantize_fp8_row(
a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
a_fp8.to(TL_FP8_DTYPE)
tl.store(
A_fp8 + pid * stride_am + n_offset * stride_an, a_fp8, mask=n_offset < N
A_fp8 + pid * stride_om + n_offset * stride_on, a_fp8, mask=n_offset < N
)
n_offset += BLOCK_SIZE

@@ -1779,6 +1783,8 @@ def triton_quantize_fp8_row(
a.shape[1],
a.stride(0),
a.stride(1),
a_fp8.stride(0),
a_fp8.stride(1),
TL_FP8_DTYPE=tl_dtype,
MAX_FP8=max_fp8,
EPS=eps,
@@ -1859,6 +1865,8 @@ def _kernel_scale_fp8_row(
N,
stride_am,
stride_an,
stride_om,
stride_on,
BLOCK_SIZE: tl.constexpr,
) -> None:
"""
@@ -1873,6 +1881,8 @@ def _kernel_scale_fp8_row(
N (int): Number of columns.
stride_am (int): Stride of m dimension of A.
stride_an (int): Stride of n dimension of A.
stride_om (int): Stride of m dimension of output.
stride_on (int): Stride of n dimension of output.
BLOCK_SIZE (int): Block size for data loads.
"""
pid = tl.program_id(0)
@@ -1886,7 +1896,7 @@ def _kernel_scale_fp8_row(
col_scale = tl.load(w_scale + n_offset)
scaled_a = a * row_scale * col_scale
tl.store(
scaled_out + pid * stride_am + n_offset * stride_an,
scaled_out + pid * stride_om + n_offset * stride_on,
scaled_a,
mask=n_offset < N,
)
@@ -1925,6 +1935,8 @@ def scale_fp8_row(
a.shape[1],
a.stride(0),
a.stride(1),
scaled_out.stride(0),
scaled_out.stride(1),
)

return scaled_out
@@ -1940,6 +1952,8 @@ def _kernel_quantize_fp8_block(
K,
stride_am,
stride_ak,
stride_om,
stride_ok,
stride_a_scale_m,
stride_a_scale_k,
TL_FP8_DTYPE: tl.constexpr,
@@ -1967,6 +1981,8 @@ def _kernel_quantize_fp8_block(
K (int): Number of columns.
stride_am (int): Stride of m dimension of A.
stride_ak (int): Stride of k dimension of A.
stride_om (int): Stride of m dimension of output.
stride_ok (int): Stride of k dimension of output.
stride_a_scale_m (int): Stride of m dimension of A_scale.
stride_a_scale_k (int): Stride of k dimension of A_scale.
TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
@@ -1983,6 +1999,7 @@ def _kernel_quantize_fp8_block(
rm = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
rk = block_k * BLOCK_K + tl.arange(0, BLOCK_K)
a_offset = rm[:, None] * stride_am + rk[None, :] * stride_ak
out_offset = rm[:, None] * stride_om + rk[None, :] * stride_ok
a_mask = (rm < M)[:, None] & (rk < K)[None, :]
a_block = tl.load(A + a_offset, mask=a_mask, other=0.0)

@@ -2004,7 +2021,7 @@ def _kernel_quantize_fp8_block(
# handles it, but it's nice to have anyway.
a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
a_fp8.to(TL_FP8_DTYPE)
tl.store(A_fp8 + a_offset, a_fp8, mask=a_mask)
tl.store(A_fp8 + out_offset, a_fp8, mask=a_mask)


def triton_quantize_fp8_block(
@@ -2050,6 +2067,8 @@ def triton_quantize_fp8_block(
K,
x.stride(0),
x.stride(1),
x_fp8.stride(0),
x_fp8.stride(1),
x_scale.stride(0),
x_scale.stride(1),
# pyre-ignore[6]: Incompatible parameter type [6]