Float8 dynamic autoquant (#946)
jainapurva authored Oct 2, 2024
1 parent 83d5b63 commit 9229df9
Showing 5 changed files with 84 additions and 25 deletions.
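
In brief, this commit adds AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, registers it in OTHER_AUTOQUANT_CLASS_LIST, exports that list from torchao.quantization, wires an autoquant-float8 option into the llama generate script, and extends the integration tests. A minimal usage sketch, assuming an H100-class GPU (the diff asserts CUDA arch >= 8.9) and a bfloat16 model; the toy module, shapes, and the manual calibrate-then-finalize flow mirror generate.py but are illustrative rather than copied from this diff:

import torch
from torchao.quantization import autoquant, OTHER_AUTOQUANT_CLASS_LIST

# Toy model; the new float8 dynamic path is only tested with bfloat16 on CUDA arch >= 8.9.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

# Restrict autoquant's search space to the float8 subclasses (weight-only plus the new per-row dynamic one).
model = autoquant(model, manual=True, qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
model(x)                     # calibration pass to record shapes
model.finalize_autoquant()   # benchmark the candidates and pick a quantization per layer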
54 changes: 33 additions & 21 deletions test/integration/test_integration.py
@@ -73,6 +73,7 @@
AQInt8WeightOnlyQuantizedLinearWeight3,
AutoQuantizableLinearWeight,
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
@@ -677,27 +678,28 @@ def _test_lin_weight_subclass_impl(
):
if not "cuda" in test_device:
self.skipTest("test requires cuda")
m, k, n = test_shape
x = torch.randn(m, k, device=test_device, dtype=test_dtype)
lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype)
ref_f = lin(x)

lin.weight = torch.nn.Parameter(
test_subclass_from_float(lin.weight), requires_grad=False
)
test = lin(x)
self.assertGreater(
SQNR(ref_f, test),
min_sqnr,
f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}"
)
lin_comp = torch.compile(lin, mode='max-autotune')
test_comp = lin_comp(x)
self.assertGreater(
SQNR(ref_f, test_comp),
min_sqnr,
f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}"
)
with torch.no_grad():
m, k, n = test_shape
x = torch.randn(m, k, device=test_device, dtype=test_dtype)
lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype)
ref_f = lin(x)

lin.weight = torch.nn.Parameter(
test_subclass_from_float(lin.weight), requires_grad=False
)
test = lin(x)
self.assertGreater(
SQNR(ref_f, test),
min_sqnr,
f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}"
)
lin_comp = torch.compile(lin, mode='max-autotune')
test_comp = lin_comp(x)
self.assertGreater(
SQNR(ref_f, test_comp),
min_sqnr,
f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}"
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen")
@@ -753,6 +755,16 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest("Fails for {dtype}")
self._test_lin_weight_subclass_impl(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
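As a reminder, SQNR above is a signal-to-quantization-noise ratio in dB; the new float8 dynamic test accepts anything above 25 dB, while the weight-only variant uses 30 dB. A rough sketch of the metric under its standard definition, which may differ in detail from the repo's SQNR helper:

import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # 10 * log10(signal power / quantization-noise power), in dB
    noise = ref - test
    return 10 * torch.log10(ref.pow(2).mean() / noise.pow(2).mean())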
4 changes: 3 additions & 1 deletion torchao/_models/llama/generate.py
@@ -246,6 +246,8 @@ def main(
if "autoquant" in quantization:
if "autoquant-int4" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
elif "autoquant-float8" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
else:
model = autoquant(model, manual=True)

@@ -415,7 +417,7 @@ def callback(x):
parser.add_argument('-q', '--quantization', type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+'autoquant-int4, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
)
)
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
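
Usage note: with this change, passing --quantization autoquant-float8 to the generate script selects OTHER_AUTOQUANT_CLASS_LIST, e.g. python torchao/_models/llama/generate.py --quantization autoquant-float8, with the remaining flags (checkpoint path and so on) supplied as usual for this script.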
2 changes: 1 addition & 1 deletion torchao/kernel/intmm.py
@@ -53,7 +53,7 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
and j_is_nonzero_multiple_of_8
and k_is_nonzero_multiple_of_8
)

if device_cpu or bad_dimensions_for_cublas:
# fallback path
return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
1 change: 1 addition & 0 deletions torchao/quantization/__init__.py
@@ -23,6 +23,7 @@
"autoquant",
"DEFAULT_AUTOQUANT_CLASS_LIST",
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
"OTHER_AUTOQUANT_CLASS_LIST",
"get_scale",
"SmoothFakeDynQuantMixin",
"SmoothFakeDynamicallyQuantizedLinear",
48 changes: 46 additions & 2 deletions torchao/quantization/autoquant.py
@@ -17,6 +17,8 @@
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
from torchao.quantization.utils import quantize_activation_per_token_absmax
from torchao.quantization.observer import PerAxis, PerTensor, PerRow
from torchao.float8.inference import Float8MMConfig

import torch.nn.functional as F

@@ -25,6 +27,7 @@
"autoquant",
"DEFAULT_AUTOQUANT_CLASS_LIST",
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
"OTHER_AUTOQUANT_CLASS_LIST",
]


@@ -221,7 +224,6 @@ def do_autoquant_bench(op, *args, **kwargs):
stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
op(*args, **kwargs)
@@ -492,6 +494,47 @@ def from_float(cls, weight):
block_size = (1, weight.shape[1])
return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())

class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
"""
AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per row scaling
"""
activation_granularity: str = PerRow()
@classmethod
def from_float(cls, weight):

# avoid circular dep
from torchao.dtypes import to_affine_quantized_floatx
from torchao.quantization.quant_api import _input_activation_quant_func_fp8
# weight settings
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.float8_e4m3fn

# input settings
def get_per_token_block_size(x):
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size

input_target_dtype = torch.float8_e4m3fn
layout_type = Float8LayoutType(mm_config=Float8MMConfig(use_fast_accum=True))
input_quant_func = lambda x: _input_activation_quant_func_fp8(
x=x,
activation_granularity=cls.activation_granularity,
activation_dtype=input_target_dtype,
)
block_size = get_weight_block_size(weight)
weight = to_affine_quantized_floatx(
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
layout_type=layout_type,
scale_dtype=torch.float32,
)
weight = super(AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
return weight


# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -511,6 +554,7 @@ def from_float(cls, weight):

OTHER_AUTOQUANT_CLASS_LIST = [
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
]


@@ -638,7 +682,7 @@ def autoquant(
if set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

if qtensor_class_list in OTHER_AUTOQUANT_CLASS_LIST:
if qtensor_class_list is OTHER_AUTOQUANT_CLASS_LIST:
assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9), "float8 requires CUDA arch >= 8.9"

# perform initial swap from linear weights
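For intuition, "per row scaling" here means every row (each token row of the activation, each output row of the weight) gets its own runtime scale derived from that row's absolute maximum before casting to float8_e4m3fn; that is what the (1, shape[1]) weight block size and the per-token block size above encode. A self-contained sketch of the scaling idea, assuming the usual amax-to-fp8-max mapping; it illustrates the scheme only and is not the torchao implementation, which routes through to_affine_quantized_floatx and _input_activation_quant_func_fp8 as shown in the diff:

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

def quantize_per_row_fp8(t: torch.Tensor, eps: float = 1e-12):
    # One scale per row: map each row's absolute max onto the fp8 representable range.
    amax = t.abs().amax(dim=-1, keepdim=True).to(torch.float32).clamp(min=eps)
    scale = amax / FP8_MAX
    q = (t.to(torch.float32) / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return q, scale

def dequantize_fp8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale

# Round-trip a bfloat16 activation row-wise through float8.
x = torch.randn(4, 8, dtype=torch.bfloat16)
q, s = quantize_per_row_fp8(x)
x_hat = dequantize_fp8(q, s)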
