-
Notifications
You must be signed in to change notification settings - Fork 227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization #1328
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ | |
|
||
# from functools import reduce | ||
# from math import gcd | ||
from typing import Dict, Optional, Callable, Any, List | ||
from typing import Any, Callable, Dict, List, Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
@@ -37,6 +37,7 @@ | |
from torchao.quantization.quant_api import ( | ||
int4_weight_only, | ||
Int4WeightOnlyQuantizer, | ||
int8_weight_only, | ||
Int8DynActInt4WeightQuantizer, | ||
quantize_, | ||
) | ||
|
@@ -45,8 +46,8 @@ | |
find_multiple, | ||
get_device_str, | ||
get_precision, | ||
set_precision, | ||
name_to_dtype, | ||
set_precision, | ||
state_dict_device, | ||
use_et_backend, | ||
) | ||
|
@@ -60,28 +61,36 @@ | |
|
||
import inspect | ||
|
||
|
||
def get_named_parameters(func: Callable) -> List[str]: | ||
# Get the signature of the function | ||
signature = inspect.signature(func) | ||
|
||
# Extract the parameters from the signature | ||
parameters = signature.parameters | ||
|
||
# Filter and return named parameters | ||
named_params = [ | ||
name for name, param in parameters.items() | ||
if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) | ||
name | ||
for name, param in parameters.items() | ||
if param.kind | ||
in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) | ||
] | ||
return named_params | ||
|
||
def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]: | ||
|
||
def validate_args( | ||
named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None | ||
) -> Dict[str, Any]: | ||
for key in q_kwargs.keys(): | ||
if key not in named_params: | ||
print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.") | ||
print( | ||
f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring." | ||
) | ||
del q_kwargs[key] | ||
return q_kwargs | ||
|
||
|
||
######################################################################### | ||
### torchchat quantization API ### | ||
|
||
|
@@ -110,21 +119,30 @@ def quantize_model( | |
if quantizer not in quantizer_class_dict: | ||
raise RuntimeError(f"unknown quantizer {quantizer} specified") | ||
else: | ||
ao_quant = True | ||
# Use tensor subclass API for int4 weight only. | ||
if device == "cuda" and quantizer == "linear:int4": | ||
quantize_(model, int4_weight_only(q_kwargs["groupsize"])) | ||
elif quantizer == "linear:int8": | ||
print("quantizer is linear int8") | ||
quantize_(model, int8_weight_only()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not integrate it into a QuantHandler class dispatched thru the handler dict at a single call site rather than build a chain of if statements? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @mikekgfb, we will refactor this part in the future after all quant APIs are moved to torchao I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. torchAO already has a class-based API that is used for other quantizers? Why do these differently, and then later refactor them? Or why not do them all a consistent way now, and if you refactor later, do that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, quantizer API is deprecated in favor of |
||
else: | ||
ao_quant = False | ||
if ao_quant: | ||
if not support_tensor_subclass: | ||
unwrap_tensor_subclass(model) | ||
continue | ||
|
||
if quantizer in ["linear:a8wxdq", "embedding:wx"]: | ||
# These quantizers require float32 input weights. Note that after quantization, | ||
# the weights will no longer be float32, but lowbit integers | ||
if get_precision() != torch.float32: | ||
print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.") | ||
print( | ||
f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32." | ||
) | ||
set_precision(torch.float32) | ||
# We set global precision from quantize options if it is specified at cli.py:485 | ||
|
||
# We set global precision from quantize options if it is specified at cli.py:485 | ||
# so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat | ||
precision = get_precision() | ||
|
||
|
@@ -141,14 +159,19 @@ def quantize_model( | |
model = quant_handler.quantize(model) | ||
|
||
|
||
|
||
######################################################################### | ||
### QuantHandler API definition ### | ||
### (unify with torchao in future) ### | ||
|
||
|
||
class QuantHandler: | ||
def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None): | ||
def __init__( | ||
self, | ||
model: Optional[nn.Module] = None, | ||
device="cpu", | ||
precision=None, | ||
tokenizer=None, | ||
): | ||
self.model_ = model | ||
self.device = device | ||
self.tokenizer = tokenizer | ||
|
@@ -176,7 +199,15 @@ def quantize(self, model: nn.Module) -> nn.Module: | |
|
||
|
||
class PrecisionHandler(QuantHandler): | ||
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype): | ||
def __init__( | ||
self, | ||
model: Optional[nn.Module] = None, | ||
device="cpu", | ||
precision=None, | ||
tokenizer=None, | ||
*, | ||
dtype, | ||
): | ||
self.model_ = model | ||
self.device = device | ||
self.tokenizer = tokenizer | ||
|
@@ -205,7 +236,15 @@ def quantized_model(self) -> nn.Module: | |
|
||
|
||
class ExecutorHandler(QuantHandler): | ||
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator): | ||
def __init__( | ||
self, | ||
model: Optional[nn.Module] = None, | ||
device="cpu", | ||
precision=None, | ||
tokenizer=None, | ||
*, | ||
accelerator, | ||
): | ||
self.model_ = model | ||
|
||
if isinstance(accelerator, str): | ||
|
@@ -529,147 +568,6 @@ def linear_int8_et(input, weight, scales): | |
) | ||
|
||
|
||
class WeightOnlyInt8Linear(nn.Module): | ||
__constants__ = ["in_features", "out_features"] | ||
in_features: int | ||
out_features: int | ||
weight: torch.Tensor | ||
scales: torch.Tensor | ||
|
||
def __init__( | ||
self, | ||
in_features, | ||
out_features, | ||
bias=None, | ||
device=None, | ||
dtype=None, | ||
*, | ||
weight: Optional[torch.Tensor] = None, | ||
scales: Optional[torch.Tensor] = None, | ||
groupsize: Optional[int] = None, | ||
): | ||
super().__init__() | ||
if dtype is None: | ||
dtype = torch.get_default_dtype() | ||
|
||
if device is None: | ||
device = "cpu" | ||
|
||
assert not bias, "Bias is not supported by LinearInt8" | ||
self.in_features = in_features | ||
self.out_features = out_features | ||
|
||
assert (weight is None) == bool( | ||
scales is None | ||
), "must specify both weights and scales, or neither" | ||
if weight is None: | ||
weight = torch.empty( | ||
(out_features, in_features), | ||
dtype=torch.int8, | ||
device=device, | ||
) | ||
if groupsize is None or (groupsize == 0): | ||
scales = torch.empty(out_features, dtype=dtype, device=device) | ||
else: | ||
n_groups = (in_features + groupsize - 1) // groupsize | ||
scales = torch.empty(out_features, n_groups, dtype=dtype, device=device) | ||
|
||
self.register_buffer("weight", weight.to(device)) | ||
self.register_buffer("scales", scales.to(device)) | ||
|
||
if use_et_backend(): | ||
self.forward = self.et_forward | ||
else: | ||
self.forward = self.aoti_forward | ||
|
||
def aoti_forward(self, input: torch.Tensor) -> torch.Tensor: | ||
return linear_int8_aoti(input, self.weight, self.scales) | ||
|
||
def et_forward(self, input: torch.Tensor) -> torch.Tensor: | ||
return linear_int8_et(input, self.weight, self.scales) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Int 8 seems like it special cased for ET, reminder to check that as well |
||
|
||
|
||
class WeightOnlyInt8QuantHandler(QuantHandler): | ||
def __init__( | ||
self, | ||
model: Optional[nn.Module] = None, | ||
device = None, | ||
precision=None, | ||
tokenizer=None, | ||
*, | ||
node_type: str = "*", | ||
bitwidth: Optional[int] = None, | ||
groupsize: Optional[int] = None, | ||
): | ||
self.model_ = model | ||
self.device = device | ||
self.groupsize = groupsize | ||
self.node_type = node_type | ||
if bitwidth is None: | ||
self.bitwidth = 8 | ||
else: | ||
self.bitwidth = bitwidth | ||
|
||
@torch.no_grad() | ||
def quantize(self, module): | ||
# cur_state_dict = state_dict_device(self.model_.state_dict()) | ||
# dict_device = "cpu" # self.device | ||
|
||
if self.bitwidth == 4: | ||
range_min = -8 | ||
range_max = 7 | ||
elif self.bitwidth == 8: | ||
range_min = -128 | ||
range_max = 127 | ||
else: | ||
raise ValueError(f"Unsupported bitwidth {self.bitwidth}") | ||
|
||
for name, child in module.named_children(): | ||
# print(f"name: {name}") | ||
if isinstance(child, nn.Linear): | ||
if ( | ||
(self.node_type == "*") | ||
or (self.node_type == "output" and name == "output") | ||
or (self.node_type == "!output" and name != "output") | ||
): | ||
# print(f"{name, child}") | ||
input_weight = child.weight.float() | ||
# print(f"{name, child}") | ||
# print(f"in_features: {child.in_features}") | ||
# print(f"out_features: {child.out_features}") | ||
|
||
# print(f"expanded weight shape {input_weight.shape}") | ||
weight, scales, _ = dynamically_quantize_per_channel( | ||
input_weight, | ||
range_min, | ||
range_max, | ||
torch.int8, | ||
self.groupsize, | ||
scales_dtype=child.weight.dtype, | ||
) | ||
|
||
setattr( | ||
module, | ||
name, | ||
WeightOnlyInt8Linear( | ||
in_features=child.in_features, | ||
out_features=child.out_features, | ||
device=self.device, | ||
# update variables from quantization | ||
weight=weight, | ||
scales=scales, | ||
groupsize=self.groupsize, | ||
), | ||
) | ||
else: | ||
self.quantize(child) | ||
|
||
return module | ||
|
||
def quantized_model(self) -> nn.Module: | ||
return self.quantize(self.model_) | ||
|
||
|
||
######################################################################### | ||
##### embedding table quantization ###### | ||
### (unify with torchao in future) ### | ||
|
@@ -886,10 +784,10 @@ def quantized_model(self) -> nn.Module: | |
# class references | ||
quantizer_class_dict = { | ||
"embedding": EmbeddingOnlyQuantHandler, | ||
"linear:int8": WeightOnlyInt8QuantHandler, | ||
"precision": PrecisionHandler, | ||
"executor": ExecutorHandler, | ||
"linear:int4": Int4WeightOnlyQuantizer, | ||
"linear:int8": int8_weight_only, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can probably use None for now, and remove this later There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We check for int8_weight_only and finished check before it looks at the table I think @vmpuri can you check? |
||
"linear:a8w4dq": Int8DynActInt4WeightQuantizer, | ||
} | ||
|
||
|
@@ -932,11 +830,16 @@ def quantized_model(self) -> nn.Module: | |
print("Slow fallback kernels will be used.") | ||
|
||
except Exception as e: | ||
|
||
class ErrorHandler(QuantHandler): | ||
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None): | ||
def __init__( | ||
self, model: Optional[nn.Module] = None, device="cpu", precision=None | ||
): | ||
global torchao_experimental_load_error | ||
raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}") | ||
|
||
raise Exception( | ||
f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}" | ||
) | ||
|
||
torchao_experimental_load_error = e | ||
quantizer_class_dict["linear:a8wxdq"] = ErrorHandler | ||
quantizer_class_dict["embedding:wx"] = ErrorHandler |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.