From 926fb9bf8e36a585e4b35199efbdaf9066c224d4 Mon Sep 17 00:00:00 2001
From: achew010 <165894159+achew010@users.noreply.github.com>
Date: Fri, 20 Sep 2024 22:44:24 +0800
Subject: [PATCH] feat: Add DataClass Arguments to Activate Padding-Free and
 MultiPack Plugin and FastKernels (#280)

* add arguments to activate ilab plugin

Signed-off-by: 1000960000 user
Signed-off-by: 1000850000 user

* plugin rename

Signed-off-by: 1000850000 user

* added multipack dataclass

Signed-off-by: 1000850000 user

* formatted scripts

Signed-off-by: 1000850000 user

* fix unit tests

Signed-off-by: 1000850000 user

* additional fmt fixes

Signed-off-by: 1000850000 user

* fix import

Signed-off-by: Yu Chin Fabian Lim
Signed-off-by: 1000850000 user

* removed unit test enforcing pretokenized datasets with paddingfree

Signed-off-by: 1000850000 user

* modifications to dataclasses to support fast kernels on full finetuning

Signed-off-by: 1000850000 user

* minor syntax fixes and fmt

Signed-off-by: 1000850000 user

* added more checks and unit tests

Signed-off-by: 1000850000 user

* Addressed changes from code review

Signed-off-by: 1000850000 user

* formatting fixes to README

Signed-off-by: 1000850000 user

* removed experimental status for fused lora and fast kernels

Signed-off-by: 1000850000 user

---------

Signed-off-by: 1000960000 user
Signed-off-by: 1000850000 user
Signed-off-by: Yu Chin Fabian Lim
Co-authored-by: Yu Chin Fabian Lim
---
 README.md                                     |  17 +-
 tests/acceleration/spying_utils.py            |   2 +-
 .../test_acceleration_dataclasses.py          |  29 ++
 .../test_acceleration_framework.py            | 308 ++++++++++++++++--
 tests/test_sft_trainer.py                     |   7 +-
 .../config/acceleration_configs/__init__.py   |   1 +
 .../acceleration_framework_config.py          |  48 ++-
 .../attention_and_distributed_packing.py      |  35 ++
 .../fused_ops_and_kernels.py                  |  17 +-
 tuning/sft_trainer.py                         |  40 ++-
 10 files changed, 455 insertions(+), 49 deletions(-)
 create mode 100644 tuning/config/acceleration_configs/attention_and_distributed_packing.py

diff --git a/README.md b/README.md
index 7fd8fd5d..40e78a83 100644
--- a/README.md
+++ b/README.md
@@ -621,15 +621,26 @@ The list of configurations for various `fms_acceleration` plugins:
 - [fused_ops_and_kernels](./tuning/config/acceleration_configs/fused_ops_and_kernels.py) (experimental):
   - `--fused_lora`: fused lora for more efficient LoRA training.
   - `--fast_kernels`: fast cross-entropy, rope, rms loss kernels.
+- [attention_and_distributed_packing](./tuning/config/acceleration_configs/attention_and_distributed_packing.py) (experimental):
+  - `--padding_free`: technique to process multiple examples in a single batch without adding padding tokens that waste compute.
+  - `--multipack`: technique for *multi-gpu training* to balance the number of tokens processed on each device and minimize waiting time.
 
 Notes:
  * `quantized_lora_config` requires that it be used along with LoRA tuning technique. See [LoRA tuning section](https://github.com/foundation-model-stack/fms-hf-tuning/tree/main?tab=readme-ov-file#lora-tuning-example) on the LoRA parameters to pass.
  * When setting `--auto_gptq triton_v2` plus note to also pass `--torch_dtype float16` and `--fp16`, or an exception will be raised. This is because these kernels only support this dtype.
- * Currently, the `fused_ops_and_kernels` is to be used used together QLoRA or GPTQ-LORA via the `quantized_lora_config`. In the future it may be made more flexible such that `fast_kernels` can even be used with full-finetuning.
 * When using `fused_ops_and_kernels` together with `quantized_lora_config`, make sure to appropriately set `--fused_lora auto_gptq True` or `bitsandbytes True`; the `True` sets `fast_lora==True`.
- * Currently `fused_ops_and_kernels` only supports activating `fast_loss,fast_rsm_layernorm,fast_rope_embeddings` all to `True`, so pass `--fast_kernels True True True`.
-
+ * `fused_ops_and_kernels` works for full-finetuning, LoRA, QLoRA and GPTQ-LORA:
+   - pass `--fast_kernels True True True` for full finetuning/LoRA
+   - pass `--fast_kernels True True True --auto_gptq triton_v2 --fused_lora auto_gptq True` for GPTQ-LoRA
+   - pass `--fast_kernels True True True --bitsandbytes nf4 --fused_lora bitsandbytes True` for QLoRA
+ * Notes on Padding Free
+   - works for both *single* and *multi-gpu* training.
+   - works on both *pretokenized* and *untokenized* datasets.
+   - verified against the version found in HF main, merged in via PR https://github.com/huggingface/transformers/pull/31629.
+ * Notes on Multipack
+   - works only for *multi-gpu* training.
+   - currently only includes the version of *multipack* optimized for linear attention implementations like *flash-attn*.
 
 Activate `TRANSFORMERS_VERBOSITY=info` to see the huggingface trainer printouts and verify that `AccelerationFramework` is activated!
 
diff --git a/tests/acceleration/spying_utils.py b/tests/acceleration/spying_utils.py
index ce1ae1f9..5dc9f7ce 100644
--- a/tests/acceleration/spying_utils.py
+++ b/tests/acceleration/spying_utils.py
@@ -36,7 +36,7 @@ def augmentation(
 
     def get_callbacks_and_ready_for_train(self, *args, **kwargs):
         spy["get_ready_for_train_calls"] += 1
-        return plugin_cls.get_callbacks_and_ready_for_train(self, args, **kwargs)
+        return plugin_cls.get_callbacks_and_ready_for_train(self, *args, **kwargs)
 
     attributes = {
         "model_loader": model_loader,
diff --git a/tests/acceleration/test_acceleration_dataclasses.py b/tests/acceleration/test_acceleration_dataclasses.py
index fc031298..13015993 100644
--- a/tests/acceleration/test_acceleration_dataclasses.py
+++ b/tests/acceleration/test_acceleration_dataclasses.py
@@ -23,6 +23,11 @@
     FusedOpsAndKernelsConfig,
     QuantizedLoraConfig,
 )
+from tuning.config.acceleration_configs.attention_and_distributed_packing import (
+    AttentionAndDistributedPackingConfig,
+    MultiPack,
+    PaddingFree,
+)
 from tuning.config.acceleration_configs.fused_ops_and_kernels import (
     FastKernelsConfig,
     FusedLoraConfig,
@@ -65,6 +70,24 @@ def test_dataclass_parse_successfully():
     assert cfg.auto_gptq is None
     assert isinstance(cfg.bnb_qlora, BNBQLoraConfig)
 
+    # 3. Specifying "--padding_free" will parse a PaddingFree class
+    parser = transformers.HfArgumentParser(
+        dataclass_types=AttentionAndDistributedPackingConfig
+    )
+    (cfg,) = parser.parse_args_into_dataclasses(
+        ["--padding_free", "huggingface"],
+    )
+    assert isinstance(cfg.padding_free, PaddingFree)
+
+    # 4. Specifying "--multipack" will parse a MultiPack class
+    parser = transformers.HfArgumentParser(
+        dataclass_types=AttentionAndDistributedPackingConfig
+    )
+    (cfg,) = parser.parse_args_into_dataclasses(
+        ["--multipack", "16"],
+    )
+    assert isinstance(cfg.multipack, MultiPack)
+
 
 def test_two_dataclasses_parse_successfully_together():
     """Ensure that the two dataclasses can parse arguments successfully
@@ -133,3 +156,9 @@ def test_dataclass_will_fail_to_accept_illegal_args():
         ValueError, match="quant_type can only be either 'nf4' or 'fp4."
): BNBQLoraConfig(quant_type="fake-quant-type") + + # 3 padding-free plugin only supports huggingface models + with pytest.raises( + ValueError, match="only 'huggingface' method currently supported." + ): + PaddingFree(method="invalid-method") diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index c907b8f0..b6acf7eb 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -17,6 +17,7 @@ from typing import Annotated from unittest.mock import patch import copy +import os import tempfile # Third Party @@ -37,6 +38,11 @@ from tuning.config.acceleration_configs.acceleration_framework_config import ( ConfigAnnotation, ) +from tuning.config.acceleration_configs.attention_and_distributed_packing import ( + AttentionAndDistributedPackingConfig, + MultiPack, + PaddingFree, +) from tuning.config.acceleration_configs.fused_ops_and_kernels import ( FastKernelsConfig, FusedLoraConfig, @@ -47,11 +53,24 @@ ) from tuning.utils.import_utils import is_fms_accelerate_available +# for some reason the CI will raise an import error if we try to import +# these from tests.data +TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join( + os.path.dirname(__file__), "../data/twitter_complaints_json.json" +) +TWITTER_COMPLAINTS_TOKENIZED = os.path.join( + os.path.dirname(__file__), + "../data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json", +) + # pylint: disable=import-error if is_fms_accelerate_available(): # Third Party - from fms_acceleration.utils.test_utils import build_framework_and_maybe_instantiate + from fms_acceleration.utils.test_utils import ( + build_framework_and_maybe_instantiate, + instantiate_model_patcher, + ) if is_fms_accelerate_available(plugins="peft"): # Third Party @@ -62,7 +81,11 @@ if is_fms_accelerate_available(plugins="foak"): # Third Party - from fms_acceleration_foak import FastQuantizedPeftAccelerationPlugin + from fms_acceleration_foak import FastKernelsAccelerationPlugin + + if is_fms_accelerate_available(plugins="aadp"): + # Third Party + from fms_acceleration_aadp import PaddingFreeAccelerationPlugin # There are more extensive unit tests in the @@ -351,6 +374,8 @@ def test_framework_intialized_properly_peft( train_args.output_dir = tempdir train_args.save_strategy = "no" train_args.fp16 = True + peft_args = copy.deepcopy(PEFT_LORA_ARGS) + peft_args.target_modules = ["q_proj", "k_proj"] installation_path, (MockedPlugin, spy) = mock_and_spy @@ -361,13 +386,14 @@ def test_framework_intialized_properly_peft( [([installation_path], MockedPlugin)], instantiate=False, ): - sft_trainer.train( - model_args, - DATA_ARGS, - train_args, - PEFT_LORA_ARGS, - quantized_lora_config=quantized_lora_config, - ) + with instantiate_model_patcher(): + sft_trainer.train( + model_args, + DATA_ARGS, + train_args, + peft_args, + quantized_lora_config=quantized_lora_config, + ) # spy inside the train to ensure that the acceleration plugin # was called. 
In the context of the AutoGPTQ plugin @@ -399,6 +425,8 @@ def test_framework_intialized_properly_foak(): train_args.output_dir = tempdir train_args.save_strategy = "no" train_args.fp16 = True + peft_args = copy.deepcopy(PEFT_LORA_ARGS) + peft_args.target_modules = ["q_proj", "k_proj"] # setup default quantized lora args dataclass # - with auth gptq as the quantized method @@ -406,7 +434,7 @@ def test_framework_intialized_properly_foak(): fusedops_kernels_config = FusedOpsAndKernelsConfig( fused_lora=FusedLoraConfig(base_layer="auto_gptq", fused_lora=True), fast_kernels=FastKernelsConfig( - fast_loss=True, fast_rsm_layernorm=True, fast_rope_embeddings=True + fast_loss=True, fast_rms_layernorm=True, fast_rope_embeddings=True ), ) @@ -415,7 +443,7 @@ def test_framework_intialized_properly_foak(): "AutoGPTQMock", AutoGPTQAccelerationPlugin ) MockedPlugin2, spy2 = create_mock_plugin_class_and_spy( - "FastPeftMock", FastQuantizedPeftAccelerationPlugin + "FastPeftMock", FastKernelsAccelerationPlugin ) # 1. mock a plugin class @@ -428,14 +456,15 @@ def test_framework_intialized_properly_foak(): ], instantiate=False, ): - sft_trainer.train( - model_args, - DATA_ARGS, - train_args, - PEFT_LORA_ARGS, - quantized_lora_config=quantized_lora_config, - fusedops_kernels_config=fusedops_kernels_config, - ) + with instantiate_model_patcher(): + sft_trainer.train( + model_args, + DATA_ARGS, + train_args, + peft_args, + quantized_lora_config=quantized_lora_config, + fusedops_kernels_config=fusedops_kernels_config, + ) # spy inside the train to ensure that the AutoGPTQ plugin is called assert spy["model_loader_calls"] == 1 @@ -446,3 +475,244 @@ def test_framework_intialized_properly_foak(): assert spy2["model_loader_calls"] == 0 assert spy2["augmentation_calls"] == 1 assert spy2["get_ready_for_train_calls"] == 1 + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="aadp"), + reason="Only runs if fms-accelerate is installed along with \ + attention_and_distributed_packing plugin", +) +def test_framework_initialize_and_trains_with_aadp(): + """ + Ensure that a properly configured aadp dataclass is + correctly activated in train. + """ + + with tempfile.TemporaryDirectory() as tempdir: + + model_args = copy.deepcopy(MODEL_ARGS) + model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3" + model_args.use_flash_attn = True + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.save_strategy = "no" + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT + data_args.response_template = "\n\n### Label:" + data_args.dataset_text_field = "output" + + # initialize a config + aadp_config = AttentionAndDistributedPackingConfig( + padding_free=PaddingFree(method="huggingface") + ) + + # create mocked plugin class for spying + MockedPlugin1, spy = create_mock_plugin_class_and_spy( + "PaddingFreeMock", PaddingFreeAccelerationPlugin + ) + + # 1. mock a plugin class + # 2. register the mocked plugins + # 3. 
call sft_trainer.train + with build_framework_and_maybe_instantiate( + [ + (["training.attention.padding_free"], MockedPlugin1), + ], + instantiate=False, + ): + with instantiate_model_patcher(): + sft_trainer.train( + model_args, + data_args, + train_args, + attention_and_distributed_packing_config=aadp_config, + ) + + # spy inside the train to ensure that the ilab plugin is called + assert spy["model_loader_calls"] == 0 + assert spy["augmentation_calls"] == 1 + assert spy["get_ready_for_train_calls"] == 1 + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="aadp"), + reason="Only runs if fms-accelerate is installed along with \ + attention_and_distributed_packing plugin", +) +def test_error_raised_with_paddingfree_and_flash_attn_disabled(): + """Ensure error raised when padding-free is not used with flash attention""" + with pytest.raises( + ValueError, + match="`--padding_free` argument was called without enabling " + "flash attention, ensure `use_flash_attn = True` to use padding-free flash attention", + ): + attention_and_distributed_packing_config = AttentionAndDistributedPackingConfig( + padding_free=PaddingFree(method="huggingface") + ) + model_args = copy.deepcopy(MODEL_ARGS) + model_args.use_flash_attn = False + sft_trainer.train( + model_args, + DATA_ARGS, + TRAIN_ARGS, + attention_and_distributed_packing_config=attention_and_distributed_packing_config, + ) + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="aadp"), + reason="Only runs if fms-accelerate is installed along with \ + attention_and_distributed_packing plugin", +) +def test_error_raised_with_multipack_and_paddingfree_disabled(): + """Ensure error raised when padding-free is not used with multipack""" + with pytest.raises( + ValueError, + match="`--multipack` is currently only supported with `--padding_free`", + ): + attention_and_distributed_packing_config = AttentionAndDistributedPackingConfig( + multipack=MultiPack(num_processes=16), + padding_free=None, + ) + model_args = copy.deepcopy(MODEL_ARGS) + sft_trainer.train( + model_args, + DATA_ARGS, + TRAIN_ARGS, + attention_and_distributed_packing_config=attention_and_distributed_packing_config, + ) + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="aadp"), + reason="Only runs if fms-accelerate is installed along with \ + attention_and_distributed_packing plugin", +) +def test_error_raised_with_packing_and_paddingfree_enabled(): + """Ensure error raised when padding-free is used with packing""" + with pytest.raises( + ValueError, + match="`--padding_free` argument was called with `packing=True`, " + "Trainer should not perform packing when using `--padding_free`", + ): + attention_and_distributed_packing_config = AttentionAndDistributedPackingConfig( + padding_free=PaddingFree(method="huggingface") + ) + model_args = copy.deepcopy(MODEL_ARGS) + model_args.use_flash_attn = True + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.packing = True + sft_trainer.train( + model_args, + DATA_ARGS, + train_args, + attention_and_distributed_packing_config=attention_and_distributed_packing_config, + ) + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="foak"), + reason="Only runs if fms-accelerate is installed along with \ + fused_ops_and_kernels plugin", +) +def test_error_raised_with_fused_lora_enabled_without_quantized_argument(): + """ + Ensure error is thrown when `--fused_lora` is passed without + `--auto_gptq` or `bitsandbytes` + """ + with pytest.raises( + ValueError, + match="`--fused_lora` must be 
accompanied by a quantized base layer " + "`--auto_gptq` or `--bitsandbytes`.", + ): + with tempfile.TemporaryDirectory() as tempdir: + # instantiate the arguments + model_args = copy.deepcopy(MODEL_ARGS) + model_args.model_name_or_path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ" + model_args.torch_dtype = torch.float16 + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.save_strategy = "no" + train_args.fp16 = True + peft_args = copy.deepcopy(PEFT_LORA_ARGS) + peft_args.target_modules = ["q_proj", "k_proj"] + + # setup FOAK config with fused lora + fusedops_kernels_config = FusedOpsAndKernelsConfig( + fused_lora=FusedLoraConfig(base_layer="auto_gptq", fused_lora=True), + ) + + # pass FOAK config but don't specify quantized base layer to sft_trainer + # expect error in framework instantiation + with build_framework_and_maybe_instantiate( + [ + (["training.fused_ops_and_kernels"], fusedops_kernels_config), + ], + instantiate=False, + ): + with instantiate_model_patcher(): + sft_trainer.train( + model_args, + DATA_ARGS, + train_args, + peft_args, + quantized_lora_config=None, + fusedops_kernels_config=fusedops_kernels_config, + ) + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="foak"), + reason="Only runs if fms-accelerate is installed along with \ + fused_ops_and_kernels plugin", +) +def test_fastkernels_with_full_finetuning_runs_successfully(): + """ + Ensure that a properly configured fastkernels dataclass will train with full FT. + """ + with tempfile.TemporaryDirectory() as tempdir: + + model_args = copy.deepcopy(MODEL_ARGS) + model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3" + model_args.torch_dtype = torch.bfloat16 + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.save_strategy = "no" + train_args.bf16 = True + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT + data_args.response_template = "\n\n### Label:" + data_args.dataset_text_field = "output" + + # initialize a FOAK config + fusedops_kernels_config = FusedOpsAndKernelsConfig( + fast_kernels=FastKernelsConfig( + fast_loss=True, fast_rms_layernorm=True, fast_rope_embeddings=True + ), + ) + + # create mocked plugin class for spying + MockedPlugin1, spy = create_mock_plugin_class_and_spy( + "FastKernelsMock", FastKernelsAccelerationPlugin + ) + + # 1. mock a plugin class + # 2. register the mocked plugins + # 3. 
call sft_trainer.train + with build_framework_and_maybe_instantiate( + [ + (["training.fused_ops_and_kernels"], MockedPlugin1), + ], + instantiate=False, + ): + with instantiate_model_patcher(): + sft_trainer.train( + model_args, + data_args, + train_args, + fusedops_kernels_config=fusedops_kernels_config, + ) + + # spy inside train to ensure that the aadp plugin is called + assert spy["augmentation_calls"] == 1 + assert spy["get_ready_for_train_calls"] == 1 diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 2d55b7de..b2054700 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -334,6 +334,7 @@ def test_parse_arguments(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_copy) assert str(model_args.torch_dtype) == "torch.bfloat16" assert data_args.dataset_text_field == "output" @@ -347,7 +348,7 @@ def test_parse_arguments_defaults(job_config): assert "torch_dtype" not in job_config_defaults assert job_config_defaults["use_flash_attn"] is False assert "save_strategy" not in job_config_defaults - model_args, _, training_args, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( + model_args, _, training_args, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_defaults ) assert str(model_args.torch_dtype) == "torch.bfloat16" @@ -359,14 +360,14 @@ def test_parse_arguments_peft_method(job_config): parser = sft_trainer.get_parser() job_config_pt = copy.deepcopy(job_config) job_config_pt["peft_method"] = "pt" - _, _, _, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_pt ) assert isinstance(tune_config, peft_config.PromptTuningConfig) job_config_lora = copy.deepcopy(job_config) job_config_lora["peft_method"] = "lora" - _, _, _, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_lora ) assert isinstance(tune_config, peft_config.LoraConfig) diff --git a/tuning/config/acceleration_configs/__init__.py b/tuning/config/acceleration_configs/__init__.py index f971e210..4f20a0af 100644 --- a/tuning/config/acceleration_configs/__init__.py +++ b/tuning/config/acceleration_configs/__init__.py @@ -14,5 +14,6 @@ # Local from .acceleration_framework_config import AccelerationFrameworkConfig +from .attention_and_distributed_packing import AttentionAndDistributedPackingConfig from .fused_ops_and_kernels import FusedOpsAndKernelsConfig from .quantized_lora_config import QuantizedLoraConfig diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py index ed0f54d1..46fbe6b0 100644 --- a/tuning/config/acceleration_configs/acceleration_framework_config.py +++ b/tuning/config/acceleration_configs/acceleration_framework_config.py @@ -21,6 +21,7 @@ import yaml # Local +from .attention_and_distributed_packing import MultiPack, PaddingFree from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig from tuning.utils.import_utils import is_fms_accelerate_available @@ -83,7 +84,7 @@ class AccelerationFrameworkConfig: ConfigAnnotation( path="peft.quantization", key="fused_ops_and_kernels", - experimental=True, + experimental=False, required_packages=["foak"], ), ] = None @@ -91,13 +92,50 @@ class AccelerationFrameworkConfig: fast_kernels: Annotated[ 
FastKernelsConfig, ConfigAnnotation( - path="peft.quantization", + path="training", key="fused_ops_and_kernels", - experimental=True, + experimental=False, required_packages=["foak"], ), ] = None + padding_free: Annotated[ + PaddingFree, + ConfigAnnotation( + path="training.attention", + experimental=True, + required_packages=["aadp"], + ), + ] = None + + multipack: Annotated[ + MultiPack, + ConfigAnnotation( + path="training.dataloader", + experimental=True, + required_packages=["aadp"], + ), + ] = None + + def _verify_configured_dataclasses(self): + if self.multipack is not None: + # ensure if multipack is set, padding free is also turned on as well + # this also ensures that the attention implementation for multipack + # will be flash attention as sfttrainer will enforce flash attn to be + # set for padding free + if self.padding_free is None: + raise ValueError( + "`--multipack` is currently only supported with `--padding_free`" + ) + + # Check that fused lora must be activated with either auto_gptq or bitsandbytes + if self.fused_lora is not None: + if self.bitsandbytes is None and self.auto_gptq is None: + raise ValueError( + "`--fused_lora` must be accompanied by a quantized base layer" + " `--auto_gptq` or `--bitsandbytes`." + ) + @staticmethod def from_dataclasses(*dataclasses: Type): "Convert one or many FMS config dataclasses to a monolithic AccelerationConfig" @@ -115,6 +153,7 @@ def from_dataclasses(*dataclasses: Type): # first unroll all the dataclases into a single level nested_dataclasses = [] + for dc in dataclasses: if dc is None: continue @@ -159,10 +198,11 @@ def from_dataclasses(*dataclasses: Type): setattr(config, fi.name, dc) del rem_fields[fi.name] # remove the field + # perform some checks on dataclasse + config._verify_configured_dataclasses() return config def get_framework(self): - if is_fms_accelerate_available(): # to be eventually be made to be passed as a dict to Acceleration diff --git a/tuning/config/acceleration_configs/attention_and_distributed_packing.py b/tuning/config/acceleration_configs/attention_and_distributed_packing.py new file mode 100644 index 00000000..e1ed83a5 --- /dev/null +++ b/tuning/config/acceleration_configs/attention_and_distributed_packing.py @@ -0,0 +1,35 @@ +# Standard +from dataclasses import dataclass + +# Local +from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass + + +@parsable_dataclass +@dataclass +class PaddingFree: + + method: str = "huggingface" + + def __post_init__(self): + if self.method != "huggingface": + raise ValueError("only 'huggingface' method currently supported.") + + +@parsable_dataclass +@dataclass +class MultiPack: + + num_processes: int = 16 + + +@dataclass +class AttentionAndDistributedPackingConfig: + + padding_free: PaddingFree = None + + multipack: MultiPack = None + + def __post_init__(self): + # ensure nested dataclasses initialized + ensure_nested_dataclasses_initialized(self) diff --git a/tuning/config/acceleration_configs/fused_ops_and_kernels.py b/tuning/config/acceleration_configs/fused_ops_and_kernels.py index ded51415..599d6b8c 100644 --- a/tuning/config/acceleration_configs/fused_ops_and_kernels.py +++ b/tuning/config/acceleration_configs/fused_ops_and_kernels.py @@ -54,19 +54,11 @@ class FastKernelsConfig(List): fast_loss: bool = False # fast rms norm triton kernels - fast_rsm_layernorm: bool = False + fast_rms_layernorm: bool = False # fast RoPE embedding triton kernels fast_rope_embeddings: bool = False - def __post_init__(self): - - if not self.fast_loss == 
self.fast_rsm_layernorm == self.fast_rope_embeddings: - raise ValueError( - "fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled " - "together. This restriction may be relaxed in the future." - ) - @dataclass class FusedOpsAndKernelsConfig: @@ -78,13 +70,6 @@ class FusedOpsAndKernelsConfig: fast_kernels: FastKernelsConfig = None def __post_init__(self): - if (self.fused_lora is not None and self.fast_kernels is None) or ( - self.fused_lora is None and self.fast_kernels is not None - ): - raise ValueError( - "fused lora and fast_kernels must be used together. " - "This restriction may be relaxed in the future." - ) # ensure nested dataclasses initialized ensure_nested_dataclasses_initialized(self) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 4b5d9b4a..beb89462 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -44,6 +44,7 @@ from tuning.config import configs, peft_config from tuning.config.acceleration_configs import ( AccelerationFrameworkConfig, + AttentionAndDistributedPackingConfig, FusedOpsAndKernelsConfig, QuantizedLoraConfig, ) @@ -86,6 +87,9 @@ def train( exp_metadata: Optional[Dict] = None, quantized_lora_config: Optional[QuantizedLoraConfig] = None, fusedops_kernels_config: Optional[FusedOpsAndKernelsConfig] = None, + attention_and_distributed_packing_config: Optional[ + AttentionAndDistributedPackingConfig + ] = None, ): """Call the SFTTrainer @@ -113,6 +117,7 @@ def train( fusedops_kernels_config: tuning.config.acceleration_configs.FusedOpsAndKernelsConfig \ Should be used in combination with quantized_lora_config. Also currently fused_lora and fast_kernels must used together (may change in future). \ + attention_and_distributed_packing_config: Used for padding-free attention and multipack. """ train_args, logger = set_log_level(train_args, "sft_trainer_train") @@ -127,6 +132,24 @@ def train( ): raise ValueError("gradient_accumulation_steps has to be an integer >= 1") + if ( + attention_and_distributed_packing_config is not None + and attention_and_distributed_packing_config.padding_free is not None + ): + if model_args.use_flash_attn is False: + raise ValueError( + "`--padding_free` argument was called without enabling flash attention, " + "ensure `use_flash_attn = True` to use padding-free flash attention" + ) + + if train_args.packing: + # We prevent Trainer from performing packing with padding_free. + # Since the plugin computes attention efficiently without padding. 
+ raise ValueError( + "`--padding_free` argument was called with `packing=True`, " + "Trainer should not perform packing when using `--padding_free`" + ) + task_type = "CAUSAL_LM" additional_metrics = {} @@ -178,7 +201,9 @@ def train( trainer_callbacks.append(cb) framework = AccelerationFrameworkConfig.from_dataclasses( - quantized_lora_config, fusedops_kernels_config + quantized_lora_config, + fusedops_kernels_config, + attention_and_distributed_packing_config, ).get_framework() model_loader = AutoModelForCausalLM.from_pretrained @@ -369,7 +394,7 @@ def train( ) if framework is not None: - accelerator = None if not is_accelerate_available else trainer.accelerator + accelerator = None if not is_accelerate_available() else trainer.accelerator # ready for train may produce additional callbacks for the trainer for x in framework.get_callbacks_and_ready_for_train(model, accelerator): @@ -436,6 +461,7 @@ def get_parser(): AimConfig, QuantizedLoraConfig, FusedOpsAndKernelsConfig, + AttentionAndDistributedPackingConfig, ) ) parser.add_argument( @@ -482,6 +508,8 @@ def parse_arguments(parser, json_config=None): Configuration for quantized LoRA (a form of PEFT). FusedOpsAndKernelsConfig Configuration for fused operations and kernels. + AttentionAndDistributedPackingConfig + Configuration for padding free and packing. dict[str, str] Extra AIM metadata. """ @@ -497,6 +525,7 @@ def parse_arguments(parser, json_config=None): aim_config, quantized_lora_config, fusedops_kernels_config, + attention_and_distributed_packing_config, ) = parser.parse_dict(json_config, allow_extra_keys=True) peft_method = json_config.get("peft_method") exp_metadata = json_config.get("exp_metadata") @@ -512,6 +541,7 @@ def parse_arguments(parser, json_config=None): aim_config, quantized_lora_config, fusedops_kernels_config, + attention_and_distributed_packing_config, additional, _, ) = parser.parse_args_into_dataclasses(return_remaining_strings=True) @@ -536,6 +566,7 @@ def parse_arguments(parser, json_config=None): aim_config, quantized_lora_config, fusedops_kernels_config, + attention_and_distributed_packing_config, exp_metadata, ) @@ -556,6 +587,7 @@ def main(): aim_config, quantized_lora_config, fusedops_kernels_config, + attention_and_distributed_packing_config, exp_metadata, ) = parse_arguments(parser, job_config) @@ -567,7 +599,7 @@ def main(): model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \ tune_config %s, file_logger_config, %s aim_config %s, \ quantized_lora_config %s, fusedops_kernels_config %s, \ - exp_metadata %s", + attention_and_distributed_packing_config %s exp_metadata %s", model_args, data_args, training_args, @@ -577,6 +609,7 @@ def main(): aim_config, quantized_lora_config, fusedops_kernels_config, + attention_and_distributed_packing_config, exp_metadata, ) except Exception as e: # pylint: disable=broad-except @@ -618,6 +651,7 @@ def main(): exp_metadata=metadata, quantized_lora_config=quantized_lora_config, fusedops_kernels_config=fusedops_kernels_config, + attention_and_distributed_packing_config=attention_and_distributed_packing_config, ) except (MemoryError, OutOfMemoryError) as e: logger.error(traceback.format_exc())
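For reference, the flags documented in the README changes above map onto the new dataclasses roughly as follows. This is a minimal illustrative sketch (not part of the patch), mirroring the unit tests added here: it assumes `fms-acceleration` is installed with the `aadp` and `foak` plugins, and the model name, dataset path and output directory are placeholders; the argument classes are assumed to come from `tuning.config.configs` as used elsewhere in the repository.

```python
# Sketch: activate padding-free attention and fast kernels programmatically.
# Assumes fms-acceleration is installed with the "aadp" and "foak" plugins.
import torch

from tuning import sft_trainer
from tuning.config import configs  # ModelArguments / DataArguments / TrainingArguments
from tuning.config.acceleration_configs import (
    AttentionAndDistributedPackingConfig,
    FusedOpsAndKernelsConfig,
)
from tuning.config.acceleration_configs.attention_and_distributed_packing import (
    PaddingFree,
)
from tuning.config.acceleration_configs.fused_ops_and_kernels import FastKernelsConfig

# Placeholder arguments; real runs would populate these from the CLI or JSON config.
model_args = configs.ModelArguments(
    model_name_or_path="TinyLlama/TinyLlama-1.1B-Chat-v0.3",  # illustrative model
    torch_dtype=torch.bfloat16,
    use_flash_attn=True,  # --padding_free requires flash attention
)
data_args = configs.DataArguments(
    training_data_path="twitter_complaints_json.json",  # illustrative dataset
    dataset_text_field="output",
    response_template="\n\n### Label:",
)
train_args = configs.TrainingArguments(
    output_dir="outputs",
    bf16=True,
    packing=False,  # Trainer packing must stay off when --padding_free is used
)

# Equivalent of `--padding_free huggingface`; for multi-gpu runs, additionally pass
# multipack=MultiPack(num_processes=16) from the same module.
aadp_config = AttentionAndDistributedPackingConfig(
    padding_free=PaddingFree(method="huggingface"),
)

# Equivalent of `--fast_kernels True True True`; works for full finetuning/LoRA
# now that fast_kernels no longer requires a quantized base layer.
foak_config = FusedOpsAndKernelsConfig(
    fast_kernels=FastKernelsConfig(
        fast_loss=True, fast_rms_layernorm=True, fast_rope_embeddings=True
    ),
)

sft_trainer.train(
    model_args,
    data_args,
    train_args,
    attention_and_distributed_packing_config=aadp_config,
    fusedops_kernels_config=foak_config,
)
```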