diff --git a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
index 43a0e7ec3..1c78f065d 100644
--- a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
+++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -31,14 +31,13 @@ def setUp(self) -> None:
 
     def test_quantize_fp8_row(self) -> None:
         def _test_quantize_fp8_row(
-            shape: Tuple[int, int],
+            shape: Tuple[int, ...],
             use_triton: bool,
             device: torch.device,
             output_device: Optional[torch.device] = None,
             use_scale_ub: bool = False,
         ) -> None:
-            M, K = shape
-            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
+            a = torch.randn(shape, dtype=torch.bfloat16, device=device)
 
             scale_ub = (
                 torch.tensor([1200], dtype=torch.float, device=device)
@@ -52,7 +51,8 @@ def _test_quantize_fp8_row(
 
             # Undo scaling.
             a_torch = a_fp8.to(torch.bfloat16)
-            a_torch *= a_scale[:, None]
+            broadcast_shape = list(a_torch.shape[:-1]) + [-1]
+            a_torch *= a_scale.view(broadcast_shape)
 
             self.assertTrue(
                 torch.allclose(
@@ -61,6 +61,8 @@ def _test_quantize_fp8_row(
             )
 
         _test_quantize_fp8_row((2, 3), True, torch.device("cuda"))
+        # Test with batched input.
+        _test_quantize_fp8_row((4, 2, 3), True, torch.device("cuda"))
         _test_quantize_fp8_row((2, 3), True, torch.device("cuda"), use_scale_ub=True)
         _test_quantize_fp8_row((2, 3), False, torch.device("cpu"), torch.device("cuda"))
         _test_quantize_fp8_row(
diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index d90f12e0a..5862a2f4d 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -2128,9 +2128,10 @@ def quantize_fp8_row_meta(
     """Shape function for torch compile."""
     if output_device is None:
         output_device = a.device
-    M, K = a.shape
+    # Flatten to 2D since each row of each potential batch gets a scale.
+    M = a.view(-1, a.shape[-1]).shape[0]
     dtype = get_fp8_constants()[0]
-    fake_out = torch.empty((M, K), device=output_device, dtype=dtype)
+    fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
     fake_scale = torch.empty((M), device=output_device, dtype=torch.float32)
     return fake_out, fake_scale
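
Not part of the patch: a minimal usage sketch of the broadcasting pattern the updated test relies on. It assumes quantize_fp8_row (from the module patched above) accepts a batched bfloat16 tensor and returns the fp8 tensor plus a flat per-row scale of shape (prod(shape[:-1]),), matching quantize_fp8_row_meta; the call with default arguments and the tolerances are illustrative assumptions.

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

# Batched input: rows are counted across all leading dims, so a (4, 2, 3)
# input yields 4 * 2 = 8 per-row scales (the M computed in quantize_fp8_row_meta).
a = torch.randn(4, 2, 3, dtype=torch.bfloat16, device="cuda")
a_fp8, a_scale = quantize_fp8_row(a)  # assumed: a_fp8 is (4, 2, 3) fp8, a_scale is (8,) fp32

# Undo the scaling the same way the updated test does: view the flat scale
# tensor as (*leading_dims, 1) so it broadcasts along the last dimension.
broadcast_shape = list(a_fp8.shape[:-1]) + [-1]
a_dequant = a_fp8.to(torch.bfloat16) * a_scale.view(broadcast_shape)

# Tolerances are illustrative; fp8 quantization error dominates.
assert torch.allclose(a, a_dequant, atol=2e-1, rtol=1e-1)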