Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

testing torchao autoquant [WIP] #2266

Open
wants to merge 2 commits into
base: gh/HDCharles/6/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions torchbenchmark/util/backends/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
)
parser.add_argument(
"--quantization",
choices=["int8dynamic", "int8weightonly", "int4weightonly"],
choices=["int8dynamic", "int8weightonly", "int4weightonly","autoquant","noquant"],
help="Apply quantization to the model before running it",
)
parser.add_argument(
Expand Down Expand Up @@ -184,24 +184,45 @@ def apply_torchdynamo_args(
if args.quantization:
import torchao
from torchao.quantization import (
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
quantize, int8_weight_only, int4_weight_only, int8_dynamic_activation_int8_weight
)
from torchao.utils import unwrap_tensor_subclass

torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.cache_size_limit = 1000
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
torch._dynamo.config.cache_size_limit = 10000
assert "cuda" in model.device
module, example_inputs = model.get_module()
if isinstance(example_inputs, tuple([tuple, list])):
example_inputs = tuple([
x.to(torch.bfloat16)
if isinstance(x, torch.Tensor) and x.dtype in [torch.float32, torch.float16]
else x
for x in example_inputs
])
module=module.to(torch.bfloat16)
with torch.no_grad():
if isinstance(example_inputs, dict):
module(**example_inputs)
else:
module(*example_inputs)
if args.quantization == "int8dynamic":
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(module)
quantize(module, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
elif args.quantization == "int8weightonly":
torch._inductor.config.use_mixed_mm = True
change_linear_weights_to_int8_woqtensors(module)
quantize(module, int8_weight_only(), set_inductor_config=False)
elif args.quantization == "int4weightonly":
change_linear_weights_to_int4_woqtensors(module)
quantize(module, int4_weight_only(), set_inductor_config=False)
if args.quantization == "autoquant":
torchao.autoquant(module, error_on_unseen=False, mode=["interpolate", .85], set_inductor_config=False)
if isinstance(example_inputs, dict):
module(**example_inputs)
else:
module(*example_inputs)
from torchao.quantization.autoquant import AUTOQUANT_CACHE
assert len(AUTOQUANT_CACHE)>0, f"Err: found no autoquantizable layers in model {type(module)}, stopping autoquantization"
else:
unwrap_tensor_subclass(module)

if args.freeze_prepack_weights:
torch._inductor.config.freezing = True
Expand Down
7 changes: 3 additions & 4 deletions userbenchmark/group_bench/configs/torch_ao.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ device: cuda
extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune
metrics:
- latencies
- gpu_peak_mem
test_group:
test_batch_size_default:
subgroup:
- extra_args:
- extra_args: --quantization int8dynamic
- extra_args: --quantization int8weightonly
- extra_args: --quantization int4weightonly
- extra_args: --quantization noquant
- extra_args: --quantization autoquant
23 changes: 14 additions & 9 deletions userbenchmark/group_bench/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,6 @@ def models_from_config(config) -> List[str]:
basic_models_list = list_models()
else:
basic_models_list = [config["model"]]
assert isinstance(config["model", list]), "Config model must be a list or string."
basic_models_list = config["model"]
extended_models_list = []
if "extended_models" in config:
from torchbenchmark.util.experiment.instantiator import list_extended_models
Expand Down Expand Up @@ -218,14 +216,21 @@ def run(args: List[str]):
results = {}
try:
for config in group_config.configs:
metrics = get_metrics(group_config.metrics, config)
if "accuracy" in metrics:
metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
else:
metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
try:
metrics = get_metrics(group_config.metrics, config)
if "accuracy" in metrics:
metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
else:
metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
except KeyboardInterrupt:
raise KeyboardInterrupt
except Exception as e:
print(f"for config, ran into error: {e}")
metrics_dict = {}
config_str = config_to_str(config)
for metric in metrics_dict:
results[f"{config_str}, metric={metric}"] = metrics_dict[metric]
for metric in metrics:
results[f"{config_str}, metric={metric}"] = metrics_dict.get(metric, "err")
print(f">>metric={metric}: {metrics_dict.get(metric, 'err')}")
except KeyboardInterrupt:
print("User keyboard interrupted!")
result = get_output_json(BM_NAME, results)
Expand Down
Loading