LoRA fine-tuning weights explosion in FSDP training #421
Comments
@weifengpy any thoughts?
Hi @MinghaoYan, thanks for filing the issue. I have practiced LoRA + FSDP in TorchTune, so I would love to understand whether there are any FSDP bugs. Are you open to taking a look at the loss together with me, just to make sure the LoRA is set up correctly?
If you just need a LoRA + FSDP recipe, or a reference implementation, here is one from TorchTune: https://github.com/pytorch/torchtune/blob/main/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml
Thanks for your reply! I think I might not be loading the weights correctly. I disabled the entire LoRA implementation and reverted to the default llama-3-8b setup. Currently I am trying to manually copy the weights from a HuggingFace Llama-3-8B-Instruct checkpoint into an instantiated Transformer module in torchtitan, by replacing the Transformer init_weights function with the following:
Since the model weights don't fit on a single GPU, I loaded the pretrained model on CPU and then copied the weights over. This causes problems later, because some tensors end up on cpu instead of cuda. However, if I just call model.to("cuda") after init_weights(), it throws an OOM error. I did it this way because I couldn't find a better way to load a HuggingFace checkpoint directly into the Transformer module in torchtitan; I would appreciate a pointer on how to load from a HF checkpoint in torchtitan. I am hoping to build on torchtitan, since later in my project I may need to change the training loop as well as the model architecture, and a lightweight framework would help me down the line. One more thing I discovered is that the init_weights() function is called twice: once through the Transformer module's init function
and once more later on explicitly
Not sure if this is the intended behavior, but I just wanted to let you know.
Here is my reference implementation for loading a HF checkpoint into an FSDP model: https://github.com/pytorch/torchtune/blob/main/recipes/dev/lora_finetune_fsdp2.py#L343.
Good question! We intentionally call init_weights() twice. The 1st time is inside meta init, where we init parameters/tensors on the meta device. The 2nd time is on device='cuda', where we actually allocate tensor storage with real values. For the LoRA/finetuning case, meta init is preferred since we are loading from checkpoints eventually, so we just init params on the meta device and copy tensors from the checkpoint into the model.
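For illustration, here is a minimal sketch of that two-phase pattern, using a toy module rather than the actual torchtitan Transformer (the meta-device/to_empty flow is the point; the class and its init are assumptions, not torchtitan code):

```python
import torch
import torch.nn as nn

# Toy stand-in for the torchtitan Transformer, just to show the two-phase flow.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16, bias=False)
        self.init_weights()           # 1st call: runs during construction (on the meta device)

    def init_weights(self):
        nn.init.normal_(self.proj.weight, std=0.02)

with torch.device("meta"):
    model = TinyModel()               # parameters are meta tensors, no storage allocated

# (apply fully_shard / other parallelisms here in the real code)

model.to_empty(device="cuda")         # allocate real storage; values are uninitialized
with torch.no_grad():
    model.init_weights()              # 2nd call: fill tensors with real values
    # ...or, for LoRA/finetuning, copy tensors from a checkpoint here instead.
```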
Thank you very much for the pointer! After some more investigation, it does seem like the first-step loss is too high (without any LoRA or any training) after loading weights from HF checkpoints directly. I get a loss of 11.79 at step 1, with the train_configs/llama3_8b.toml config and run_llama_train.sh on 4 A10 24GB GPUs. A minimal reproducible example would be: instead of
I loaded weights from the checkpoints instead:
I made one change to your reference implementation, where I mapped HF param names to the names defined in the torchtitan Transformer module:
I would really appreciate it if you could spot anything wrong in my code or reproduce the loss that I got.
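(The remapping code itself is in the attached files rather than inlined here, but the idea looks roughly like the sketch below. The key patterns follow the standard HF Llama layout and torchtitan's Llama module names; the exact mapping in the attached code may differ.)

```python
# Rough sketch of an HF -> torchtitan parameter-name remapping.
def remap_hf_to_torchtitan(hf_key: str) -> str:
    special = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    if hf_key in special:
        return special[hf_key]
    replacements = {
        "model.layers.": "layers.",
        ".self_attn.q_proj.": ".attention.wq.",
        ".self_attn.k_proj.": ".attention.wk.",
        ".self_attn.v_proj.": ".attention.wv.",
        ".self_attn.o_proj.": ".attention.wo.",
        ".mlp.gate_proj.": ".feed_forward.w1.",
        ".mlp.down_proj.": ".feed_forward.w2.",
        ".mlp.up_proj.": ".feed_forward.w3.",
        ".input_layernorm.": ".attention_norm.",
        ".post_attention_layernorm.": ".ffn_norm.",
    }
    for old, new in replacements.items():
        hf_key = hf_key.replace(old, new)
    return hf_key
```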
Good catch on remapping the parameter names.
I would call
If you still hit errors, I can try to reproduce. Feel free to open a PR with your local changes.
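A hedged sketch of that ordering — allocate sharded storage with to_empty() first, then shard each full tensor with distribute_tensor() and load it — in the spirit of the TorchTune recipe linked above (the helper name and keyword handling here are assumptions, not the actual TorchTune or torchtitan API):

```python
import torch
import torch.nn as nn
from torch.distributed._tensor import distribute_tensor

def load_full_state_dict(model: nn.Module, full_sd: dict, device: str = "cuda") -> None:
    # Assumes `model` is already wrapped with fully_shard (FSDP2) on the meta device.
    model.to_empty(device=device)            # allocate sharded storage first
    sharded_template = model.state_dict()    # DTensors describing each param's sharding
    sharded_sd = {}
    for name, full_tensor in full_sd.items():
        template = sharded_template[name]
        full_tensor = full_tensor.to(device=device, dtype=template.dtype)
        # Re-shard the full tensor the same way FSDP sharded the parameter.
        sharded_sd[name] = nn.Parameter(
            distribute_tensor(full_tensor, template.device_mesh, template.placements)
        )
    model.load_state_dict(sharded_sd, strict=False, assign=True)
```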
Thank you for your reply! I moved load_from_full_model_state_dict to after model.to_empty(...); if I keep the torch device as cpu, the behavior is the same. If I change the device to cuda, I hit a new problem where it complains that the tensor is not a leaf (this corresponds to the full_tensor variable in the load_from_full_model_state_dict function).
I seem to be having some permission issues creating branches and PRs; I will look into it. Thanks again!
Thanks for the patience. Let me know if you cannot resolve it after investigation. I might draft some example code to load from a HF checkpoint; this is a common ask and we can improve it.
Thank you! I have created a PR here: #427
I wonder if you are aware of the model definition mismatch between Llama and HF's Transformers (#335). Basically, a permutation of some weights is needed to make the conversion work.
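For reference, the HF conversion script permutes the q/k projection weights for its rotary-embedding layout, so loading an HF checkpoint into the original Llama layout needs the inverse permutation. A sketch based on the permute() in HF's convert_llama_weights_to_hf.py (not taken verbatim from #335):

```python
import torch

def unpermute_hf_qk(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    # Inverse of HF's permute(), which interleaves the two rotary halves of each head.
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
```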
I was not aware of this, thank you!
Dear authors,
I encountered a weights explosion problem while integrating LoRA into torchtitan. I am running the train_configs/llama3_8b.toml config with run_llama_train.sh on 4 A10 24GB GPUs. The PyTorch version is the latest 2.5.0 nightly.
I have made the following changes so far:
In train.py, I added two utility functions, get_parameters() and calculate_parameter_change(), to compute the change in the LoRA weights at each training step, plus a call to mark_only_lora_as_trainable(), which marks all LoRA matrices as trainable and freezes all other layers (implemented by the original LoRA authors in model.py). When creating model_cls from the model args, I changed the device from meta to cpu, since this allows me to load pretrained weights directly (Step 3).
In model.py, I replaced the wq, wk, wv matrices from nn.Linear with the custom Linear layer implemented by the authors of LoRA, to incorporate LoRA adapter training (a rough sketch of this layer and the train.py helpers follows after this list).
In model.py, I replaced the init_weights function in the Transformer module with code that loads pretrained weights from the HF Llama-3-8B-Instruct checkpoint. I checked the loaded weights and they appear to be correct.
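A rough sketch of the pieces described in the changes above: a LoRA-augmented Linear layer (following the reference implementation in the original LoRA authors' loralib) plus the freezing and weight-change-tracking helpers. The names (LoRALinear, get_parameters, calculate_parameter_change, mark_only_lora_as_trainable) mirror the description of the attached code but are reconstructed here, not copied from it:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRALinear(nn.Linear):
    """nn.Linear with a low-rank adapter: y = x W^T + (alpha / r) * x A^T B^T."""
    def __init__(self, in_features, out_features, r=8, lora_alpha=16, bias=False):
        super().__init__(in_features, out_features, bias=bias)
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))  # A random, B zero
        self.scaling = lora_alpha / r
        self.weight.requires_grad = False                       # freeze the base weight

    def forward(self, x):
        base = F.linear(x, self.weight, self.bias)
        return base + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

def mark_only_lora_as_trainable(model: nn.Module) -> None:
    # Freeze everything except the LoRA A/B matrices.
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name

def get_parameters(model: nn.Module) -> dict:
    # Snapshot detached copies of the LoRA parameters for later comparison.
    return {name: p.detach().clone() for name, p in model.named_parameters() if "lora_" in name}

def calculate_parameter_change(prev: dict, model: nn.Module) -> dict:
    # L2 norm of how much each LoRA parameter moved since the previous snapshot.
    return {name: (p.detach() - prev[name]).norm().item()
            for name, p in model.named_parameters() if "lora_" in name}
```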
Since the LoRA implementation has been widely tested, I suppose it should be fine. What I also noticed is that when I did not set the device to cpu when creating model_cls from the model args (the default in the code is meta), my weight-copying operation in Step 3 would essentially copy all 0s into the tensors. However, in that case LoRA-A's behavior is very similar to the case where the weights are copied correctly, while LoRA-B, which is 0-initialized, stays at 0. In the current case, both LoRA-A and LoRA-B have exploding weights. Because LoRA-A behaves similarly under the two initializations, I suspect the bug still lies somewhere in the FSDP setup. I have attached my code below and would appreciate any hints on where the bug might come from (uploading txt files since py is not allowed; they are py files).
Below is sample output showing the progression of the LoRA A/B weight changes for a sampled layer:
[rank0]:{'layers.29._checkpoint_wrapped_module.attention.wq.lora_A': 146.9622344970703, 'layers.29._checkpoint_wrapped_module.attention.wq.lora_B': 5.200856207920879e-07, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_A': 147.65611267089844, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_B': 6.672774333082998e-08, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_A': 147.62704467773438, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_B': 0.018832538276910782}
[rank0]:{'layers.29._checkpoint_wrapped_module.attention.wq.lora_A': 18811.1484375, 'layers.29._checkpoint_wrapped_module.attention.wq.lora_B': 6.657096673734486e-05, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_A': 18899.97265625, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_B': 4.27057602792047e-06, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_A': 18896.2734375, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_B': 1.2052823305130005}
[rank0]:{'layers.29._checkpoint_wrapped_module.attention.wq.lora_A': 2407827.0, 'layers.29._checkpoint_wrapped_module.attention.wq.lora_B': 0.008521084673702717, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_A': 2419196.5, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_B': 0.00027331686578691006, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_A': 2418723.0, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_B': 77.13806915283203}
model.txt
train.txt