diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index 22390d649..436367cb5 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -486,7 +486,7 @@ python run_generation.py \
 
 ### Loading 4 Bit Checkpoints from Hugging Face
 
-You can load pre-quantized 4bit models with the argument `--load_quantized_model`.
+You can load pre-quantized 4bit models with the argument `--load_quantized_model_with_inc`.
 Currently, uint4 checkpoints and single device are supported.
 More information on enabling 4 bit inference in SynapseAI is available here:
 https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_UINT4.html.
@@ -508,7 +508,35 @@ python run_lm_eval.py \
 --attn_softmax_bf16 \
 --bucket_size=128 \
 --bucket_internal \
---load_quantized_model
+--load_quantized_model_with_inc
+```
+
+### Loading 4 Bit Checkpoints from Neural Compressor (INC)
+
+You can load a pre-quantized 4-bit checkpoint with the argument `--quantized_inc_model_path`, together with the original model specified by the argument `--model_name_or_path`.
+Currently, only uint4 checkpoints and single-device configurations are supported.
+**Note:** This flow loads a checkpoint that was quantized using INC.
+More information on enabling 4-bit inference in SynapseAI is available here:
+https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_INT4.html.
+
+Below is an example of loading a Llama2-7b model with a 4-bit checkpoint quantized by INC.
+Please note that the quantized checkpoint path is denoted as `<quantized_model_path>`.
+Additionally, the following environment variables are used for performance optimization and are planned to be removed in future versions:
+`SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1`
+```bash
+SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1 \
+python run_lm_eval.py \
+-o acc_load_uint4_model.txt \
+--model_name_or_path meta-llama/Llama-2-7b-hf \
+--use_hpu_graphs \
+--use_kv_cache \
+--trim_logits \
+--batch_size 1 \
+--bf16 \
+--attn_softmax_bf16 \
+--bucket_size=128 \
+--bucket_internal \
+--quantized_inc_model_path <quantized_model_path>
 ```
 
 ### Using Habana Flash Attention
@@ -539,6 +567,37 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
 
 For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).
 
+### Running with UINT4 Weight Quantization Using AutoGPTQ
+
+
+Llama2-7b in UINT4 weight-only quantization is enabled using the [AutoGPTQ fork](https://github.com/HabanaAI/AutoGPTQ), which provides quantization capabilities in PyTorch.
+Currently, only UINT4 inference of pre-quantized models is supported.
+
+You can run a *UINT4 weight-quantized* model using AutoGPTQ by setting the following environment variables:
+`SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=true` before running the command,
+and by adding the argument `--load_quantized_model_with_autogptq`.
+
+***Note:***
+Setting the above environment variables improves performance. These variables will be removed in future releases.
+
+
+Here is an example of how to run a quantized model:
+```bash
+SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false \
+ENABLE_EXPERIMENTAL_FLAGS=true python run_generation.py \
+--attn_softmax_bf16 \
+--model_name_or_path <path_to_uint4_quantized_model> \
+--use_hpu_graphs \
+--limit_hpu_graphs \
+--use_kv_cache \
+--bucket_size 128 \
+--bucket_internal \
+--trim_logits \
+--max_new_tokens 128 \
+--batch_size 1 \
+--bf16 \
+--load_quantized_model_with_autogptq
+```
 
 ## Language Model Evaluation Harness
 
diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 0e29fb7e5..fd877f746 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -293,21 +293,11 @@ def setup_parser(parser):
         type=str,
         help="Path to serialize const params. Const params will be held on disk memory instead of being allocated on host memory.",
     )
-    parser.add_argument(
-        "--disk_offload",
-        action="store_true",
-        help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.",
-    )
     parser.add_argument(
         "--trust_remote_code",
         action="store_true",
         help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.",
     )
-    parser.add_argument(
-        "--load_quantized_model",
-        action="store_true",
-        help="Whether to load model from hugging face checkpoint.",
-    )
     parser.add_argument(
         "--parallel_strategy",
         type=str,
@@ -321,6 +311,35 @@ def setup_parser(parser):
         help="Whether to enable inputs_embeds or not.",
     )
 
+    parser.add_argument(
+        "--run_partial_dataset",
+        action="store_true",
+        help="Run the inference with the dataset for the specified --n_iterations (default: 5).",
+    )
+
+    quant_parser_group = parser.add_mutually_exclusive_group()
+    quant_parser_group.add_argument(
+        "--load_quantized_model_with_autogptq",
+        action="store_true",
+        help="Load an AutoGPTQ quantized checkpoint using AutoGPTQ.",
+    )
+    quant_parser_group.add_argument(
+        "--disk_offload",
+        action="store_true",
+        help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.",
+    )
+    quant_parser_group.add_argument(
+        "--load_quantized_model_with_inc",
+        action="store_true",
+        help="Load a Hugging Face quantized checkpoint using INC.",
+    )
+    quant_parser_group.add_argument(
+        "--quantized_inc_model_path",
+        type=str,
+        default=None,
+        help="Path to a neural-compressor quantized model. If set, the checkpoint will be loaded from this path.",
+    )
+
     args = parser.parse_args()
 
     if args.torch_compile:
@@ -333,6 +352,9 @@ def setup_parser(parser):
         args.flash_attention_fast_softmax = True
 
     args.quant_config = os.getenv("QUANT_CONFIG", "")
+    if args.quant_config and args.load_quantized_model_with_autogptq:
+        raise RuntimeError("Setting both quant_config and load_quantized_model_with_autogptq is unsupported.")
+
     if args.quant_config == "" and args.disk_offload:
         logger.warning(
             "`--disk_offload` was tested only with fp8, it may not work with full precision. If error raises try to remove the --disk_offload flag."
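The quantization-related flags added above (`--load_quantized_model_with_autogptq`, `--disk_offload`, `--load_quantized_model_with_inc`, `--quantized_inc_model_path`) live in a mutually exclusive argparse group, so combining any two of them is rejected at parse time; only the clash with `QUANT_CONFIG` needs the explicit runtime check, since that value comes from an environment variable rather than the parser. Below is a minimal standalone sketch of that behavior, not part of the patch; it reuses the flag names from the diff but everything else (the tiny parser, the example invocations) is illustrative only.

```python
# Standalone sketch of the mutually exclusive quantization flags added to setup_parser().
import argparse

parser = argparse.ArgumentParser()
quant_parser_group = parser.add_mutually_exclusive_group()
quant_parser_group.add_argument("--load_quantized_model_with_autogptq", action="store_true")
quant_parser_group.add_argument("--load_quantized_model_with_inc", action="store_true")
quant_parser_group.add_argument("--quantized_inc_model_path", type=str, default=None)

# A single flag from the group parses normally.
args = parser.parse_args(["--load_quantized_model_with_inc"])
print(args.load_quantized_model_with_inc)  # True

# Passing two flags from the same group makes argparse exit with an error along the lines of
# "argument --load_quantized_model_with_autogptq: not allowed with argument --load_quantized_model_with_inc".
# parser.parse_args(["--load_quantized_model_with_inc", "--load_quantized_model_with_autogptq"])
```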
diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index c2ae975bc..9be27b5dc 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -237,10 +237,34 @@ def setup_model(args, model_dtype, model_kwargs, logger):
             torch_dtype=model_dtype,
             **model_kwargs,
         )
-    elif args.load_quantized_model:
+    elif args.load_quantized_model_with_autogptq:
+        from transformers import GPTQConfig
+
+        quantization_config = GPTQConfig(bits=4, use_exllama=False)
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs
+        )
+    elif args.load_quantized_model_with_inc:
         from neural_compressor.torch.quantization import load
 
         model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs)
+    elif args.quantized_inc_model_path:
+        org_model = AutoModelForCausalLM.from_pretrained(
+            args.model_name_or_path,
+            **model_kwargs,
+        )
+
+        from neural_compressor.torch.quantization import load
+
+        model = load(
+            model_name_or_path=args.quantized_inc_model_path,
+            format="default",
+            device="hpu",
+            original_model=org_model,
+            **model_kwargs,
+        )
+        # TODO: [SW-195965] Remove once load supports other types
+        model = model.to(model_dtype)
     else:
         if args.assistant_model is not None:
             assistant_model = AutoModelForCausalLM.from_pretrained(
@@ -614,8 +638,7 @@ def initialize_model(args, logger):
         "token": args.token,
         "trust_remote_code": args.trust_remote_code,
     }
-
-    if args.load_quantized_model:
+    if args.load_quantized_model_with_inc or args.quantized_inc_model_path:
         model_kwargs["torch_dtype"] = torch.bfloat16
 
     if args.trust_remote_code:
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 242f1e6a6..a266841e1 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -65,6 +65,9 @@
            ("mistralai/Mixtral-8x7B-v0.1", 2, 48, True, 2048, 2048, 1147.50),
            ("microsoft/phi-2", 1, 1, True, 128, 128, 254.08932787178165),
        ],
+        "load_quantized_model_with_autogptq": [
+            ("TheBloke/Llama-2-7b-Chat-GPTQ", 1, 10, False, 128, 2048, 456.7),
+        ],
        "deepspeed": [
            ("bigscience/bloomz", 8, 1, 36.77314954096159),
            ("meta-llama/Llama-2-70b-hf", 8, 1, 64.10514998902435),
@@ -108,6 +111,7 @@
            ("state-spaces/mamba-130m-hf", 224, False, 794.542),
        ],
        "fp8": [],
+        "load_quantized_model_with_autogptq": [],
        "deepspeed": [
            ("bigscience/bloomz-7b1", 8, 1, 31.994268212011505),
        ],
@@ -130,6 +134,7 @@ def _test_text_generation(
     world_size: int = 8,
     torch_compile: bool = False,
     fp8: bool = False,
+    load_quantized_model_with_autogptq: bool = False,
     max_input_tokens: int = 0,
     max_output_tokens: int = 100,
     parallel_strategy: str = None,
@@ -241,6 +246,8 @@ def _test_text_generation(
             f"--max_input_tokens {max_input_tokens}",
             "--limit_hpu_graphs",
         ]
+    if load_quantized_model_with_autogptq:
+        command += ["--load_quantized_model_with_autogptq"]
     if parallel_strategy is not None:
         command += [
             f"--parallel_strategy={parallel_strategy}",
@@ -334,6 +341,36 @@ def test_text_generation_fp8(
     )
 
 
+@pytest.mark.parametrize(
+    "model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline",
+    MODELS_TO_TEST["load_quantized_model_with_autogptq"],
+)
+def test_text_generation_gptq(
+    model_name: str,
+    baseline: float,
+    world_size: int,
+    batch_size: int,
+    reuse_cache: bool,
+    input_len: int,
+    output_len: int,
+    token: str,
+):
+    deepspeed = True if world_size > 1 else False
+    _test_text_generation(
+        model_name,
+        baseline,
+        token,
+        deepspeed=deepspeed,
+        world_size=world_size,
+        fp8=False,
+        load_quantized_model_with_autogptq=True,
+        batch_size=batch_size,
+        reuse_cache=reuse_cache,
+        max_input_tokens=input_len,
+        max_output_tokens=output_len,
+    )
+
+
 @pytest.mark.parametrize("model_name, world_size, batch_size, baseline", MODELS_TO_TEST["deepspeed"])
 def test_text_generation_deepspeed(model_name: str, baseline: float, world_size: int, batch_size: int, token: str):
     _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, batch_size=batch_size)
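A quick way to exercise only the newly added AutoGPTQ test might look like the sketch below. This is an assumed local workflow rather than part of the change itself: it simply drives pytest's standard `-k` name filter from Python, and it presupposes the usual Gaudi test environment (HPU drivers, the HabanaAI AutoGPTQ fork installed, and whatever Hugging Face token the suite's `token` fixture expects).

```python
# Hypothetical helper for running only the new AutoGPTQ test locally.
# Assumes the standard optimum-habana test environment is already set up.
import sys

import pytest

if __name__ == "__main__":
    # Select the new test by name; any extra pytest arguments passed on the
    # command line (e.g. a token option expected by the suite) are forwarded.
    exit_code = pytest.main(
        ["-v", "-s", "-k", "test_text_generation_gptq", "tests/test_text_generation_example.py", *sys.argv[1:]]
    )
    sys.exit(exit_code)
```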