From 9229df9acd912bcf00e8faf138a33382d94e23b2 Mon Sep 17 00:00:00 2001
From: Apurva Jain
Date: Tue, 1 Oct 2024 17:55:26 -0700
Subject: [PATCH] Float8 dynamic autoquant (#946)

---
 test/integration/test_integration.py | 54 +++++++++++++++++-----------
 torchao/_models/llama/generate.py    |  4 ++-
 torchao/kernel/intmm.py              |  2 +-
 torchao/quantization/__init__.py     |  1 +
 torchao/quantization/autoquant.py    | 48 +++++++++++++++++++++++--
 5 files changed, 84 insertions(+), 25 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 1303f2279..5f81858ba 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -73,6 +73,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -677,27 +678,28 @@ def _test_lin_weight_subclass_impl(
     ):
         if not "cuda" in test_device:
             self.skipTest("test requires cuda")
-        m, k, n = test_shape
-        x = torch.randn(m, k, device=test_device, dtype=test_dtype)
-        lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype)
-        ref_f = lin(x)
-
-        lin.weight = torch.nn.Parameter(
-            test_subclass_from_float(lin.weight), requires_grad=False
-        )
-        test = lin(x)
-        self.assertGreater(
-            SQNR(ref_f, test),
-            min_sqnr,
-            f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}"
-        )
-        lin_comp = torch.compile(lin, mode='max-autotune')
-        test_comp = lin_comp(x)
-        self.assertGreater(
-            SQNR(ref_f, test_comp),
-            min_sqnr,
-            f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}"
-        )
+        with torch.no_grad():
+            m, k, n = test_shape
+            x = torch.randn(m, k, device=test_device, dtype=test_dtype)
+            lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype)
+            ref_f = lin(x)
+
+            lin.weight = torch.nn.Parameter(
+                test_subclass_from_float(lin.weight), requires_grad=False
+            )
+            test = lin(x)
+            self.assertGreater(
+                SQNR(ref_f, test),
+                min_sqnr,
+                f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}"
+            )
+            lin_comp = torch.compile(lin, mode='max-autotune')
+            test_comp = lin_comp(x)
+            self.assertGreater(
+                SQNR(ref_f, test_comp),
+                min_sqnr,
+                f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}"
+            )

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen")
@@ -753,6 +755,16 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
             AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
         )

+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
+        if dtype != torch.bfloat16:
+            self.skipTest(f"Fails for {dtype}")
+        self._test_lin_weight_subclass_impl(
+            AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 5fb905dbf..19e42e7cd 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -246,6 +246,8 @@ def main(
         if "autoquant" in quantization:
             if "autoquant-int4" == quantization:
                 model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
+            elif "autoquant-float8" == quantization:
+                model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
             else:
                 model = autoquant(model, manual=True)

@@ -415,7 +417,7 @@ def callback(x):
     parser.add_argument('-q', '--quantization', type=str,
         help=(
             'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
-            +'autoquant-int4, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
+            +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
         )
     )
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index 81e7b19b1..7d076a6e8 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -53,7 +53,7 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         and j_is_nonzero_multiple_of_8
         and k_is_nonzero_multiple_of_8
     )
-    
+
     if device_cpu or bad_dimensions_for_cublas:
         # fallback path
         return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 05c55b255..9eb312dd6 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -23,6 +23,7 @@
     "autoquant",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "OTHER_AUTOQUANT_CLASS_LIST",
     "get_scale",
     "SmoothFakeDynQuantMixin",
     "SmoothFakeDynamicallyQuantizedLinear",
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 089add1d8..a5568c4e1 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -17,6 +17,8 @@
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
 from torchao.quantization.utils import quantize_activation_per_token_absmax
+from torchao.quantization.observer import PerAxis, PerTensor, PerRow
+from torchao.float8.inference import Float8MMConfig
 import torch.nn.functional as F


@@ -25,6 +27,7 @@
     "autoquant",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "OTHER_AUTOQUANT_CLASS_LIST",
 ]


@@ -221,7 +224,6 @@ def do_autoquant_bench(op, *args, **kwargs):
         stream.synchronize()
         torch.cuda.current_stream().wait_stream(stream)
         torch.cuda.synchronize()
-
         graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(graph, stream=stream):
             op(*args, **kwargs)
@@ -492,6 +494,47 @@ def from_float(cls, weight):
         block_size = (1, weight.shape[1])
         return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())

+class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
+    """
+    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per row scaling
+    """
+    activation_granularity: str = PerRow()
+    @classmethod
+    def from_float(cls, weight):
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.quantization.quant_api import _input_activation_quant_func_fp8
+        # weight settings
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.float8_e4m3fn
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_target_dtype = torch.float8_e4m3fn
+        layout_type = Float8LayoutType(mm_config=Float8MMConfig(use_fast_accum=True))
+        input_quant_func = lambda x: _input_activation_quant_func_fp8(
+            x=x,
+            activation_granularity=cls.activation_granularity,
+            activation_dtype=input_target_dtype,
+        )
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            layout_type=layout_type,
+            scale_dtype=torch.float32,
+        )
+        weight = super(AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+

 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -511,6 +554,7 @@ def from_float(cls, weight):

 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
 ]


@@ -638,7 +682,7 @@ def autoquant(
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()

-    if qtensor_class_list in OTHER_AUTOQUANT_CLASS_LIST:
+    if qtensor_class_list is OTHER_AUTOQUANT_CLASS_LIST:
         assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9), "float8 requires CUDA arch >= 8.9"

     # perform initial swap from linear weights
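
A minimal usage sketch (not part of the patch): the per-row float8 dynamic option added above is selected by passing OTHER_AUTOQUANT_CLASS_LIST to autoquant, which is what the generate.py hunk does for "autoquant-float8". The toy model and shapes below are hypothetical; this path assumes a CUDA device with compute capability >= 8.9 (e.g. H100) and a bfloat16 model, per the assert in autoquant().

import torch
from torchao.quantization import autoquant, OTHER_AUTOQUANT_CLASS_LIST

# Hypothetical bfloat16 model; float8 autoquant requires CUDA arch >= 8.9.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(device="cuda", dtype=torch.bfloat16)

# Restrict autoquant to the float8 options (weight-only and the new per-row dynamic variant).
model = autoquant(model, manual=True, qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)

# Run example input(s) so each linear layer's candidate kernels get benchmarked,
# then lock in the fastest choice per layer.
example_input = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
model(example_input)
model.finalize_autoquant()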