diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py
index 8fa8127d00..b26a204f4a 100644
--- a/torchbenchmark/util/backends/torchdynamo.py
+++ b/torchbenchmark/util/backends/torchdynamo.py
@@ -86,7 +86,7 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
     )
     parser.add_argument(
         "--quantization",
-        choices=["int8dynamic", "int8weightonly", "int4weightonly"],
+        choices=["int8dynamic", "int8weightonly", "int4weightonly", "autoquant", "noquant"],
         help="Apply quantization to the model before running it",
     )
     parser.add_argument(
@@ -184,24 +184,45 @@ def apply_torchdynamo_args(
     if args.quantization:
         import torchao
         from torchao.quantization import (
-            change_linear_weights_to_int4_woqtensors,
-            change_linear_weights_to_int8_dqtensors,
-            change_linear_weights_to_int8_woqtensors,
+            quantize, int8_weight_only, int4_weight_only, int8_dynamic_activation_int8_weight
         )
+        from torchao.utils import unwrap_tensor_subclass

         torch._dynamo.config.automatic_dynamic_shapes = False
-        torch._dynamo.config.force_parameter_static_shapes = False
-        torch._dynamo.config.cache_size_limit = 1000
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._inductor.config.use_mixed_mm = True
+        torch._dynamo.config.cache_size_limit = 10000
         assert "cuda" in model.device
         module, example_inputs = model.get_module()
+        if isinstance(example_inputs, (tuple, list)):
+            example_inputs = tuple(
+                x.to(torch.bfloat16)
+                if isinstance(x, torch.Tensor) and x.dtype in [torch.float32, torch.float16]
+                else x
+                for x in example_inputs
+            )
+        module = module.to(torch.bfloat16)
+        with torch.no_grad():
+            if isinstance(example_inputs, dict):
+                module(**example_inputs)
+            else:
+                module(*example_inputs)
         if args.quantization == "int8dynamic":
-            torch._inductor.config.force_fuse_int_mm_with_mul = True
-            change_linear_weights_to_int8_dqtensors(module)
+            quantize(module, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
         elif args.quantization == "int8weightonly":
-            torch._inductor.config.use_mixed_mm = True
-            change_linear_weights_to_int8_woqtensors(module)
+            quantize(module, int8_weight_only(), set_inductor_config=False)
         elif args.quantization == "int4weightonly":
-            change_linear_weights_to_int4_woqtensors(module)
+            quantize(module, int4_weight_only(), set_inductor_config=False)
+        if args.quantization == "autoquant":
+            torchao.autoquant(module, error_on_unseen=False, mode=["interpolate", 0.85], set_inductor_config=False)
+            if isinstance(example_inputs, dict):
+                module(**example_inputs)
+            else:
+                module(*example_inputs)
+            from torchao.quantization.autoquant import AUTOQUANT_CACHE
+            assert len(AUTOQUANT_CACHE) > 0, f"Err: found no autoquantizable layers in model {type(module)}, stopping autoquantization"
+        else:
+            unwrap_tensor_subclass(module)

     if args.freeze_prepack_weights:
         torch._inductor.config.freezing = True
diff --git a/userbenchmark/group_bench/configs/torch_ao.yaml b/userbenchmark/group_bench/configs/torch_ao.yaml
index 762668ea3f..78c206d54c 100644
--- a/userbenchmark/group_bench/configs/torch_ao.yaml
+++ b/userbenchmark/group_bench/configs/torch_ao.yaml
@@ -7,10 +7,9 @@ device: cuda
 extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune
 metrics:
   - latencies
+  - gpu_peak_mem
 test_group:
   test_batch_size_default:
     subgroup:
-      - extra_args:
-      - extra_args: --quantization int8dynamic
-      - extra_args: --quantization int8weightonly
-      - extra_args: --quantization int4weightonly
+      - extra_args: --quantization noquant
+      - extra_args: --quantization autoquant
diff --git a/userbenchmark/group_bench/run.py b/userbenchmark/group_bench/run.py
index 16a7f490f3..fcdc324efd 100644
--- a/userbenchmark/group_bench/run.py
+++ b/userbenchmark/group_bench/run.py
@@ -158,8 +158,6 @@ def models_from_config(config) -> List[str]:
         basic_models_list = list_models()
     else:
         basic_models_list = [config["model"]]
-    assert isinstance(config["model", list]), "Config model must be a list or string."
-    basic_models_list = config["model"]
     extended_models_list = []
     if "extended_models" in config:
         from torchbenchmark.util.experiment.instantiator import list_extended_models
@@ -218,14 +216,21 @@ def run(args: List[str]):
     results = {}
     try:
         for config in group_config.configs:
-            metrics = get_metrics(group_config.metrics, config)
-            if "accuracy" in metrics:
-                metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
-            else:
-                metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
+            metrics = get_metrics(group_config.metrics, config)
+            try:
+                if "accuracy" in metrics:
+                    metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
+                else:
+                    metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
+            except KeyboardInterrupt:
+                raise
+            except Exception as e:
+                print(f"Config {config} ran into error: {e}")
+                metrics_dict = {}
             config_str = config_to_str(config)
-            for metric in metrics_dict:
-                results[f"{config_str}, metric={metric}"] = metrics_dict[metric]
+            for metric in metrics:
+                results[f"{config_str}, metric={metric}"] = metrics_dict.get(metric, "err")
+                print(f">>metric={metric}: {metrics_dict.get(metric, 'err')}")
     except KeyboardInterrupt:
         print("User keyboard interrupted!")
     result = get_output_json(BM_NAME, results)
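
Note: a minimal standalone sketch of the torchao flow this patch adopts, assuming a torchao version that exposes the `quantize`, `autoquant`, and `unwrap_tensor_subclass` APIs imported above; `ToyModel` and the input shapes are hypothetical stand-ins, not part of the patch:

import torch
import torchao
from torchao.quantization import quantize, int8_weight_only
from torchao.utils import unwrap_tensor_subclass

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.linear(x)

model = ToyModel().to(device="cuda", dtype=torch.bfloat16)
example_input = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)

# Static path: swap Linear weights for quantized tensor subclasses, then
# unwrap the subclasses so torch.compile can trace the module cleanly.
quantize(model, int8_weight_only(), set_inductor_config=False)
unwrap_tensor_subclass(model)
opt_model = torch.compile(model, mode="max-autotune")
with torch.no_grad():
    opt_model(example_input)

# Autoquant path (alternative, mirroring the patch): wrap the module, then run
# one forward pass so autoquant can benchmark and pick kernels per layer.
# torchao.autoquant(model, error_on_unseen=False)
# model(example_input)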