
Llama3.1 models do not allow configuring max_seq_len #2202

Open
akashc1 opened this issue Dec 23, 2024 · 4 comments
Labels: bug (Something isn't working), triaged (This issue has been assigned an owner and appropriate label)

akashc1 (Contributor) commented Dec 23, 2024

The Llama 3.1 model builders hardcode the max context length, even though the underlying component builders accept it as an argument. Since the QLoRA builders wrap the same functions, they are affected as well. This prevents anyone from specifying the model's max_seq_len from a config or the CLI. For example, this config will throw an error:

output_dir: /tmp/torchtune/llama3_1_8B/lora # /tmp may be deleted by your system. Change it to your preference.
max_seq_len: 8192

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /models/meta-llama/Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: ${max_seq_len}

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 8  # higher increases accuracy and memory
  lora_alpha: 16  # usually alpha=2*rank
  lora_dropout: 0.0
  max_seq_len: ${max_seq_len}
[rank4]: Traceback (most recent call last):
[rank4]:   File "/torchtune/recipes/lora_finetune_distributed.py", line 938, in <module>
[rank4]:     sys.exit(recipe_main())
[rank4]:   File "/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank4]:     sys.exit(recipe_main(conf))
[rank4]:   File "/torchtune/recipes/lora_finetune_distributed.py", line 932, in recipe_main
[rank4]:     recipe.setup(cfg=cfg)
[rank4]:   File "/torchtune/recipes/lora_finetune_distributed.py", line 272, in setup
[rank4]:     self._model = self._setup_model(
[rank4]:   File "/torchtune/recipes/lora_finetune_distributed.py", line 453, in _setup_model
[rank4]:     model = config.instantiate(cfg_model)
[rank4]:   File "/torchtune/torchtune/config/_instantiate.py", line 112, in instantiate
[rank4]:     return _instantiate_node(OmegaConf.to_object(config), *args)
[rank4]:   File "/torchtune/torchtune/config/_instantiate.py", line 33, in _instantiate_node
[rank4]:     return _create_component(_component_, args, kwargs)
[rank4]:   File "/torchtune/torchtune/config/_instantiate.py", line 22, in _create_component
[rank4]:     return _component_(*args, **kwargs)
[rank4]: TypeError: lora_llama3_1_8b() got an unexpected keyword argument 'max_seq_len'

In my workload (and I'm sure in others as well), I need to specify a different context length.
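
For illustration, here is a minimal sketch (not torchtune's actual code) of what exposing the argument could look like: the size-specific builder would accept max_seq_len and forward it to the lora_llama3_1 component builder instead of fixing it internally. The import path, default value, and the architecture hyperparameters shown for the 8B variant are assumptions for the sketch; the real component-builder signature may differ.

# Hypothetical sketch, not the actual torchtune implementation: expose
# max_seq_len on the size-specific builder and forward it to the component
# builder instead of hardcoding it inside the function body.
from torchtune.models.llama3_1._component_builders import lora_llama3_1  # assumed path

def lora_llama3_1_8b(
    lora_attn_modules,
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    max_seq_len: int = 131072,  # assumed default; previously fixed inside the builder
):
    return lora_llama3_1(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        vocab_size=128_256,       # Llama 3.1 8B architecture values
        num_layers=32,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=4096,
        intermediate_dim=14336,
        max_seq_len=max_seq_len,  # forwarded instead of hardcoded
    )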

felipemello1 (Contributor) commented

Hey @akashc1, I see. We should fix it, but to be clear, you don't have to redefine max_seq_len for the model. It is only used for the positional embeddings, so you can leave it at 131k unless you are trying to go beyond that.

For your specific case, where you just want to limit the sequence length in the batch, change it only in the tokenizer:

output_dir: /tmp/torchtune/llama3_1_8B/lora # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /models/meta-llama/Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 8192

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 8  # higher increases accuracy and memory
  lora_alpha: 16  # usually alpha=2*rank
  lora_dropout: 0.0
  # removed max_seq_len

felipemello1 added the triaged and bug labels on Dec 24, 2024
akashc1 (Contributor, Author) commented Dec 24, 2024

@felipemello1 yes, I understand that; however, the transformer implementation does throw an error if it receives a seq_len longer than the one it was initialized with. I've run into this when, e.g., training Llama 3.1 405B LoRA with tokenizer.max_seq_len = 16384: if you don't also specify it in the model config, the model by default expects max_seq_len = 8192. This is especially pertinent with data packing.

Here is an example config I've been using to ensure the tokenizer and model both produce/expect the same sequence length:

max_seq_len: 16384

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  ...
  max_seq_len: ${max_seq_len}

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_405b
  ...
  max_seq_len: ${max_seq_len}

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.instruct_dataset
  ...
  packed: True

Let me know if this explains what I'm running into, or if I can clarify further :)
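
For reference, the failure with packing comes from a length guard of the kind sketched below. This is only an illustration with an assumed error message, not torchtune's exact decoder code; the point is that the model rejects any input longer than the max_seq_len it was built with, so a 16384-token packed sequence fails against a model initialized for 8192.

# Illustrative sketch of the length guard (assumed, not torchtune's exact code):
# the decoder records max_seq_len at init and rejects longer inputs at forward time.
def check_seq_len(seq_len: int, max_seq_len: int) -> None:
    if seq_len > max_seq_len:
        raise ValueError(
            f"seq_len ({seq_len}) of input tensor should be smaller "
            f"than max_seq_len ({max_seq_len})"
        )

check_seq_len(seq_len=16384, max_seq_len=8192)  # raises ValueError for a packed 16k batch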

felipemello1 (Contributor) commented

Oh, I see! OK, there is a deeper problem: we set max_seq_len for 405B to 8k, which is wrong. The model was trained for 131k. Thanks for raising it.

I left a comment in your PR (https://github.com/pytorch/torchtune/pull/2203/files): let's just update the docstring and I will approve it. By the way, you should try Llama 3.3 70B; it has better or equivalent performance to 405B.
