Allow N-D inputs to triton fp8 row quantize
Summary: We previously assumed that inputs to fp8 quantize would be 2D; however, we are now working with higher-dimensional workloads that would benefit from FP8. This small diff adds more general shape handling to fp8 quantization.
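To make the generalization concrete, here is a minimal reference sketch of row-wise FP8 quantization over N-D inputs. This is not fbgemm's Triton kernel; the function name, FP8 dtype, and scale convention are assumptions for illustration. It flattens all leading dimensions so that each row of each batch gets its own scale:

import torch

# Reference semantics only; the real implementation is a Triton kernel.
def quantize_fp8_row_reference(a: torch.Tensor):
    # Collapse leading dims: an (..., K) tensor becomes (M, K).
    a_2d = a.reshape(-1, a.shape[-1])
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per row, chosen so the row's absolute max maps to the FP8 max.
    row_max = a_2d.abs().amax(dim=1).clamp(min=1e-12).to(torch.float32)
    scale = fp8_max / row_max
    a_fp8 = (a_2d * scale[:, None]).to(torch.float8_e4m3fn).reshape(a.shape)
    # Return the inverse scale so callers can multiply to dequantize.
    return a_fp8, 1.0 / scale

For a (4, 2, 3) input this yields an FP8 tensor of shape (4, 2, 3) and 4 * 2 = 8 scales, one per row.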

Differential Revision: D63921964
jwfromm authored and facebook-github-bot committed Oct 4, 2024
1 parent 88ef5f9 commit a5ff8da
Showing 2 changed files with 9 additions and 6 deletions.
10 changes: 6 additions & 4 deletions fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -31,14 +31,13 @@ def setUp(self) -> None:
 
     def test_quantize_fp8_row(self) -> None:
         def _test_quantize_fp8_row(
-            shape: Tuple[int, int],
+            shape: Tuple[int, ...],
             use_triton: bool,
             device: torch.device,
             output_device: Optional[torch.device] = None,
             use_scale_ub: bool = False,
         ) -> None:
-            M, K = shape
-            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
+            a = torch.randn(shape, dtype=torch.bfloat16, device=device)
 
             scale_ub = (
                 torch.tensor([1200], dtype=torch.float, device=device)
@@ -52,7 +51,8 @@ def _test_quantize_fp8_row(
 
             # Undo scaling.
             a_torch = a_fp8.to(torch.bfloat16)
-            a_torch *= a_scale[:, None]
+            broadcast_shape = list(a_torch.shape[:-1]) + [-1]
+            a_torch *= a_scale.view(broadcast_shape)
 
             self.assertTrue(
                 torch.allclose(
@@ -61,6 +61,8 @@
             )
 
         _test_quantize_fp8_row((2, 3), True, torch.device("cuda"))
+        # Test with batched input.
+        _test_quantize_fp8_row((4, 2, 3), True, torch.device("cuda"))
         _test_quantize_fp8_row((2, 3), True, torch.device("cuda"), use_scale_ub=True)
         _test_quantize_fp8_row((2, 3), False, torch.device("cpu"), torch.device("cuda"))
         _test_quantize_fp8_row(
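The new dequantization broadcast in the test generalizes the old a_scale[:, None] to any rank. A sketch with illustrative shapes (not taken from the diff): for a (4, 2, 3) input there is one scale per row, i.e. 8 scales, and broadcast_shape becomes [4, 2, -1], so the view lines each scale up with its row:

# Illustrative shapes only.
a_torch = torch.randn(4, 2, 3, dtype=torch.bfloat16)  # dequantized FP8 output
a_scale = torch.rand(4 * 2, dtype=torch.float32)      # one scale per row
broadcast_shape = list(a_torch.shape[:-1]) + [-1]     # [4, 2, -1] -> (4, 2, 1)
a_torch *= a_scale.view(broadcast_shape)              # rescales each row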
5 changes: 3 additions & 2 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -2128,9 +2128,10 @@ def quantize_fp8_row_meta(
     """Shape function for torch compile."""
     if output_device is None:
         output_device = a.device
-    M, K = a.shape
+    # Flatten to 2D since each row of each potential batch gets a scale.
+    M = a.view(-1, a.shape[-1]).shape[0]
     dtype = get_fp8_constants()[0]
-    fake_out = torch.empty((M, K), device=output_device, dtype=dtype)
+    fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
     fake_scale = torch.empty((M), device=output_device, dtype=torch.float32)
     return fake_out, fake_scale
 
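The meta function only needs to produce correctly shaped fakes for torch.compile: the quantized output mirrors the input shape, and the scale has one element per flattened row. A shape-only illustration (values are hypothetical):

a = torch.empty(4, 2, 3)
M = a.view(-1, a.shape[-1]).shape[0]  # flattens (4, 2, 3) -> (8, 3), so M == 8
# fake_out has shape (4, 2, 3); fake_scale has shape (8,).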
