From a5ff8daf158c2073aebe3ada4de664cc3ec51d66 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Fri, 4 Oct 2024 15:09:45 -0700
Subject: [PATCH] Allow N-D inputs to triton fp8 row quantize

Summary: We previously assumed inputs to fp8 quantize would be 2D;
however, we are now working with higher-dimension workloads that would
benefit from FP8. This small diff adds more general shape handling to
fp8 quantization.

Differential Revision: D63921964
---
 fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py   | 10 ++++++----
 fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py |  5 +++--
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
index 43a0e7ec3..1c78f065d 100644
--- a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
+++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -31,14 +31,13 @@ def setUp(self) -> None:
 
     def test_quantize_fp8_row(self) -> None:
         def _test_quantize_fp8_row(
-            shape: Tuple[int, int],
+            shape: Tuple[int, ...],
             use_triton: bool,
             device: torch.device,
             output_device: Optional[torch.device] = None,
             use_scale_ub: bool = False,
         ) -> None:
-            M, K = shape
-            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
+            a = torch.randn(shape, dtype=torch.bfloat16, device=device)
 
             scale_ub = (
                 torch.tensor([1200], dtype=torch.float, device=device)
@@ -52,7 +51,8 @@ def _test_quantize_fp8_row(
 
             # Undo scaling.
             a_torch = a_fp8.to(torch.bfloat16)
-            a_torch *= a_scale[:, None]
+            broadcast_shape = list(a_torch.shape[:-1]) + [-1]
+            a_torch *= a_scale.view(broadcast_shape)
 
             self.assertTrue(
                 torch.allclose(
@@ -61,6 +61,8 @@ def _test_quantize_fp8_row(
             )
 
         _test_quantize_fp8_row((2, 3), True, torch.device("cuda"))
+        # Test with batched input.
+        _test_quantize_fp8_row((4, 2, 3), True, torch.device("cuda"))
         _test_quantize_fp8_row((2, 3), True, torch.device("cuda"), use_scale_ub=True)
         _test_quantize_fp8_row((2, 3), False, torch.device("cpu"), torch.device("cuda"))
         _test_quantize_fp8_row(
diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index d90f12e0a..5862a2f4d 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -2128,9 +2128,10 @@ def quantize_fp8_row_meta(
     """Shape function for torch compile."""
     if output_device is None:
         output_device = a.device
-    M, K = a.shape
+    # Flatten to 2D since each row of each potential batch gets a scale.
+    M = a.view(-1, a.shape[-1]).shape[0]
     dtype = get_fp8_constants()[0]
-    fake_out = torch.empty((M, K), device=output_device, dtype=dtype)
+    fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
     fake_scale = torch.empty((M), device=output_device, dtype=torch.float32)
     return fake_out, fake_scale
 
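
For reference, a minimal sketch (not part of the patch) of the N-D behavior
this diff enables. It assumes quantize_fp8_row is importable from the
fp8_gemm module shown above and keeps the shapes implied by
quantize_fp8_row_meta: an FP8 output matching the input shape, plus one
float32 scale per row of the flattened (-1, K) view. The allclose
tolerances are illustrative, not taken from the test.

    import torch

    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
        quantize_fp8_row,
    )

    # Batched (B, M, K) input; before this patch only 2D (M, K) inputs
    # were handled.
    a = torch.randn(4, 2, 3, dtype=torch.bfloat16, device="cuda")
    a_fp8, a_scale = quantize_fp8_row(a)

    assert a_fp8.shape == a.shape     # output keeps the N-D input shape
    assert a_scale.shape == (4 * 2,)  # one scale per row across all batches

    # Undo the scaling the same way the updated test does: view the scales
    # as (*batch_dims, 1) so they broadcast along the row (K) dimension.
    a_dq = a_fp8.to(torch.bfloat16) * a_scale.view(list(a.shape[:-1]) + [-1])
    assert torch.allclose(a, a_dq, atol=2e-1, rtol=1e-1)  # loose fp8 bound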