testing torchao autoquant [WIP]
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 89eedee21ce6e922710f791c95b05d4556d3184b
Pull Request resolved: #2266

updating apis

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles committed Jun 28, 2024
1 parent 67ef897 commit 123ce91
Showing 3 changed files with 49 additions and 24 deletions.
43 changes: 32 additions & 11 deletions torchbenchmark/util/backends/torchdynamo.py
@@ -86,7 +86,7 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
     )
     parser.add_argument(
         "--quantization",
-        choices=["int8dynamic", "int8weightonly", "int4weightonly"],
+        choices=["int8dynamic", "int8weightonly", "int4weightonly","autoquant","noquant"],
         help="Apply quantization to the model before running it",
     )
     parser.add_argument(
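Two new modes, autoquant and noquant, join the existing choices. As a quick illustration of what the flag now accepts and rejects, here is a standalone argparse sketch that mirrors the definition above (illustrative only, not benchmark code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--quantization",
    choices=["int8dynamic", "int8weightonly", "int4weightonly", "autoquant", "noquant"],
    help="Apply quantization to the model before running it",
)
print(parser.parse_args(["--quantization", "autoquant"]))  # Namespace(quantization='autoquant')
# parser.parse_args(["--quantization", "fp8"]) exits with: invalid choice: 'fp8'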
@@ -184,24 +184,45 @@ def apply_torchdynamo_args(
     if args.quantization:
         import torchao
         from torchao.quantization import (
-            change_linear_weights_to_int4_woqtensors,
-            change_linear_weights_to_int8_dqtensors,
-            change_linear_weights_to_int8_woqtensors,
+            quantize, int8_weight_only, int4_weight_only, int8_dynamic_activation_int8_weight
         )
+        from torchao.utils import unwrap_tensor_subclass
 
         torch._dynamo.config.automatic_dynamic_shapes = False
         torch._dynamo.config.force_parameter_static_shapes = False
-        torch._dynamo.config.cache_size_limit = 1000
-        torch._inductor.config.force_fuse_int_mm_with_mul = True
-        torch._inductor.config.use_mixed_mm = True
+        torch._dynamo.config.cache_size_limit = 10000
         assert "cuda" in model.device
         module, example_inputs = model.get_module()
+        if isinstance(example_inputs, tuple([tuple, list])):
+            example_inputs = tuple([
+                x.to(torch.bfloat16)
+                if isinstance(x, torch.Tensor) and x.dtype in [torch.float32, torch.float16]
+                else x
+                for x in example_inputs
+            ])
+        module=module.to(torch.bfloat16)
+        with torch.no_grad():
+            if isinstance(example_inputs, dict):
+                module(**example_inputs)
+            else:
+                module(*example_inputs)
         if args.quantization == "int8dynamic":
-            change_linear_weights_to_int8_dqtensors(module)
+            torch._inductor.config.force_fuse_int_mm_with_mul = True
+            quantize(module, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
         elif args.quantization == "int8weightonly":
-            change_linear_weights_to_int8_woqtensors(module)
+            torch._inductor.config.use_mixed_mm = True
+            quantize(module, int8_weight_only(), set_inductor_config=False)
         elif args.quantization == "int4weightonly":
-            change_linear_weights_to_int4_woqtensors(module)
+            quantize(module, int4_weight_only(), set_inductor_config=False)
+        if args.quantization == "autoquant":
+            torchao.autoquant(module, error_on_unseen=False, mode=["interpolate", .85], set_inductor_config=False)
+            if isinstance(example_inputs, dict):
+                module(**example_inputs)
+            else:
+                module(*example_inputs)
+            from torchao.quantization.autoquant import AUTOQUANT_CACHE
+            assert len(AUTOQUANT_CACHE)>0, f"Err: found no autoquantizable layers in model {type(module)}, stopping autoquantization"
+        else:
+            unwrap_tensor_subclass(module)
 
     if args.freeze_prepack_weights:
         torch._inductor.config.freezing = True
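The diff migrates from the deleted change_linear_weights_to_*_tensors helpers to the quantize() + config-callable API, and adds the autoquant path. For readers unfamiliar with the new surface, here is a minimal, self-contained sketch of the two call patterns; the toy model, shapes, and torchao version (>= 0.3, where these names were introduced) are assumptions, not part of the patch:

import copy

import torch
import torch.nn as nn
import torchao
from torchao.quantization import int8_weight_only, quantize
from torchao.quantization.autoquant import AUTOQUANT_CACHE
from torchao.utils import unwrap_tensor_subclass

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
model = model.to(device="cuda", dtype=torch.bfloat16).eval()
x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

# Pattern 1: one-shot quantization. quantize() + a config callable replaces
# the old per-dtype module-swapping helpers this commit deletes.
m1 = copy.deepcopy(model)
quantize(m1, int8_weight_only(), set_inductor_config=False)
unwrap_tensor_subclass(m1)  # done before torch.compile, as in the patch
with torch.no_grad():
    m1(x)

# Pattern 2: autoquant. Wrapping the model defers the decision; the first
# forward pass benchmarks candidate kernels per layer and picks the fastest.
m2 = torchao.autoquant(copy.deepcopy(model), error_on_unseen=False)
with torch.no_grad():
    m2(x)  # warm-up forward triggers kernel selection
assert len(AUTOQUANT_CACHE) > 0, "no autoquantizable layers found"

Note the design choice the patch mirrors: autoquant needs a real forward pass with representative inputs before the AUTOQUANT_CACHE check is meaningful, which is why the benchmark runs the module once right after wrapping it.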
7 changes: 3 additions & 4 deletions userbenchmark/group_bench/configs/torch_ao.yaml
@@ -7,10 +7,9 @@ device: cuda
 extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune
 metrics:
   - latencies
-  - gpu_peak_mem
 test_group:
   test_batch_size_default:
     subgroup:
-      - extra_args:
       - extra_args: --quantization int8dynamic
       - extra_args: --quantization int8weightonly
       - extra_args: --quantization int4weightonly
+      - extra_args: --quantization noquant
+      - extra_args: --quantization autoquant
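Read as a sweep spec: group_bench runs every selected model once per subgroup entry, appending that entry's extra_args to the shared extra_args above. A rough sketch of the expansion (expand_configs and the model names are illustrative stand-ins, not group_bench internals):

from itertools import product
from typing import Dict, List

SHARED_ARGS = "--precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune"
SUBGROUP = [
    "--quantization int8dynamic",
    "--quantization int8weightonly",
    "--quantization int4weightonly",
    "--quantization noquant",
    "--quantization autoquant",
]

def expand_configs(models: List[str]) -> List[Dict[str, str]]:
    """One benchmark run per (model, subgroup entry) pair."""
    return [
        {"model": m, "device": "cuda", "extra_args": f"{SHARED_ARGS} {sub}"}
        for m, sub in product(models, SUBGROUP)
    ]

runs = expand_configs(["resnet50", "BERT_pytorch"])
print(len(runs))  # 2 models x 5 quantization settings = 10 runs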
23 changes: 14 additions & 9 deletions userbenchmark/group_bench/run.py
@@ -158,8 +158,6 @@ def models_from_config(config) -> List[str]:
             basic_models_list = list_models()
         else:
             basic_models_list = [config["model"]]
-    assert isinstance(config["model", list]), "Config model must be a list or string."
-    basic_models_list = config["model"]
     extended_models_list = []
     if "extended_models" in config:
         from torchbenchmark.util.experiment.instantiator import list_extended_models
@@ -218,14 +216,21 @@ def run(args: List[str]):
     results = {}
     try:
         for config in group_config.configs:
-            metrics = get_metrics(group_config.metrics, config)
-            if "accuracy" in metrics:
-                metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
-            else:
-                metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
+            try:
+                metrics = get_metrics(group_config.metrics, config)
+                if "accuracy" in metrics:
+                    metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
+                else:
+                    metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
+            except KeyboardInterrupt:
+                raise KeyboardInterrupt
+            except Exception as e:
+                print(f"for config, ran into error: {e}")
+                metrics_dict = {}
             config_str = config_to_str(config)
-            for metric in metrics_dict:
-                results[f"{config_str}, metric={metric}"] = metrics_dict[metric]
+            for metric in metrics:
+                results[f"{config_str}, metric={metric}"] = metrics_dict.get(metric, "err")
+                print(f">>metric={metric}: {metrics_dict.get(metric, 'err')}")
     except KeyboardInterrupt:
         print("User keyboard interrupted!")
     result = get_output_json(BM_NAME, results)
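The run-loop change is a standard error-isolation pattern: one failing config now records "err" for each of its metrics and the sweep continues, while Ctrl-C still aborts everything. A distilled version of the pattern (run_one and the config strings are stand-ins, not benchmark code):

from typing import Dict, List

def run_one(config: str, metrics: List[str]) -> Dict[str, float]:
    # Stand-in for get_metrics + run_config; fails for one config on purpose.
    if config == "bad":
        raise RuntimeError("model failed to load")
    return {m: 1.23 for m in metrics}

def run_sweep(configs: List[str], metrics: List[str]) -> Dict[str, object]:
    results: Dict[str, object] = {}
    for config in configs:
        try:
            metrics_dict = run_one(config, metrics)
        except KeyboardInterrupt:
            raise  # a user abort should still stop the whole sweep
        except Exception as e:
            print(f"for config {config}, ran into error: {e}")
            metrics_dict = {}  # fall through so every metric records "err"
        for metric in metrics:
            # .get() turns a failed run into "err" instead of a KeyError
            results[f"{config}, metric={metric}"] = metrics_dict.get(metric, "err")
    return results

print(run_sweep(["good", "bad"], ["latencies"]))
# {'good, metric=latencies': 1.23, 'bad, metric=latencies': 'err'}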
