LoRA fine-tuning weights explosion in FSDP training #421

Open
MinghaoYan opened this issue Jun 24, 2024 · 12 comments


MinghaoYan commented Jun 24, 2024

Dear authors,

I encountered a weight-explosion problem while integrating LoRA into torchtitan. I am running the train_configs/llama3_8b.toml config with run_llama_train.sh on 4 A10 24GB GPUs. The PyTorch version is the latest 2.5.0 nightly.

I have made the following changes so far:

  1. In train.py, I added two utility functions, get_parameters() and calculate_parameter_change(), to compute how much the LoRA weights change at each training step (a rough sketch of these utilities follows this list), and a call to mark_only_lora_as_trainable(), which marks all LoRA matrices as trainable and freezes all other layers (implemented by the original LoRA authors in model.py). When creating model_cls from the model args, I changed the device from meta to cpu, since this allows me to load pretrained weights directly (step 3).

  2. In model.py, I replaced the nn.Linear used for the wq, wk, wv matrices with the custom Linear layer implemented by the authors of LoRA, in order to incorporate LoRA adapter training.

  3. In model.py, I replaced the init_weights function in the Transformer module with code that loads pretrained weights from the HF Llama-3-8B-Instruct checkpoint. I checked the loaded weights and they appear to be correct.
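Roughly, the two tracking utilities look like this (a simplified sketch; the full versions are in the attached train.txt):

    # Simplified sketch of the tracking utilities (full versions are in train.txt).
    import torch

    def get_parameters(model):
        # Snapshot the LoRA parameters as detached copies.
        return {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if "lora_" in name
        }

    def calculate_parameter_change(prev_params, curr_params):
        # Norm of how much each LoRA tensor moved between two snapshots.
        return {
            name: torch.norm(curr_params[name] - prev_params[name]).item()
            for name in prev_params
        }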

Since the LoRA implementation has been widely tested, I suppose it should be fine. What I also noted was that when I did not set the device to cpu when creating model_cls from the model args (the default in the code is meta), my weight-copying operation in step 3 would essentially copy all 0s into the tensors. However, in that case LoRA-A behaved very similarly to when I copied the weights correctly, while LoRA-B, which is 0-initialized, stayed at 0. In the current case, both LoRA-A and LoRA-B have exploding weights. Because LoRA-A behaves similarly under the two initializations, I suspect the bug still lies somewhere in the FSDP setup. I have attached my code below and would appreciate any hints on where the bug might come from (uploading .txt files since .py is not allowed; they are Python files).

Below is sample output showing the progression of LoRA A/B weight changes from a sampled layer.

[rank0]:{'layers.29._checkpoint_wrapped_module.attention.wq.lora_A': 146.9622344970703, 'layers.29._checkpoint_wrapped_module.attention.wq.lora_B': 5.200856207920879e-07, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_A': 147.65611267089844, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_B': 6.672774333082998e-08, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_A': 147.62704467773438, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_B': 0.018832538276910782}
[rank0]:{'layers.29._checkpoint_wrapped_module.attention.wq.lora_A': 18811.1484375, 'layers.29._checkpoint_wrapped_module.attention.wq.lora_B': 6.657096673734486e-05, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_A': 18899.97265625, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_B': 4.27057602792047e-06, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_A': 18896.2734375, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_B': 1.2052823305130005}
[rank0]:{'layers.29._checkpoint_wrapped_module.attention.wq.lora_A': 2407827.0, 'layers.29._checkpoint_wrapped_module.attention.wq.lora_B': 0.008521084673702717, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_A': 2419196.5, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_B': 0.00027331686578691006, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_A': 2418723.0, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_B': 77.13806915283203}

model.txt
train.txt

@tianyu-l
Contributor

@weifengpy any thoughts?


weifengpy commented Jun 25, 2024

Hi @MinghaoYan, thanks for filing the issue. I have worked on LoRA + FSDP in TorchTune, so I would love to understand whether there are any FSDP bugs.

Are you open to taking a look at the loss together with me, just to make sure LoRA is set up correctly?

  1. Are you loading from a pretrained checkpoint, say Llama 2 or Llama 3?
  2. After loading the model, without applying any LoRA adapters, could I know the loss? For example, loss = 2.0. This just checks whether the model is loaded correctly.
  3. After applying the LoRA adapters, before the 1st optim.step, could I know the loss? I would expect the loss to be similar to the 2.0 from step 2. This checks whether the LoRA adapters are added correctly (a rough sketch of checks 2 and 3 follows below).

If you just need a LoRA + FSDP recipe, or a reference implementation, here is one from TorchTune: https://github.com/pytorch/torchtune/blob/main/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml
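Something like the following would do for checks 2 and 3 (a rough sketch with a hypothetical helper; it assumes the model returns logits of shape (batch, seq, vocab) as in torchtitan):

    # Rough sketch of checks 2 and 3 (hypothetical helper; assumes the model
    # returns logits of shape (batch, seq, vocab) as in torchtitan).
    import torch
    import torch.nn.functional as F

    @torch.no_grad()
    def sanity_check_loss(model, input_ids, labels):
        model.eval()
        logits = model(input_ids)
        return F.cross_entropy(logits.flatten(0, 1), labels.flatten(0, 1)).item()

    # Check 2: pretrained model, no LoRA adapters -> e.g. ~2.0.
    # Check 3: with LoRA adapters attached, before the 1st optim.step -> since
    # lora_B is zero-initialized, this should match check 2 almost exactly.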


MinghaoYan commented Jun 25, 2024

Thanks for your reply! I think I might not be loading the weights correctly. I disabled all the LoRA code and reverted to the default Llama-3-8B setup. Currently I am trying to manually copy the weights from a HuggingFace Llama-3-8B-Instruct checkpoint into an instantiated Transformer module in torchtitan, by replacing the Transformer's init_weights function with the following:

    def init_weights(self):
        with torch.device(self.freqs_cis.device):
            self.freqs_cis = self._precompute_freqs_cis()
        self._copy_weights()

    def _copy_weights(self, pretrained_model_name="meta-llama/Meta-Llama-3-8B-Instruct"):
        # Load the pretrained HF model on CPU
        # (needs `from transformers import AutoModelForCausalLM` at the top of the file)
        pretrained_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name).to('cpu')

        # Copy embedding weights
        self.tok_embeddings.weight.data = pretrained_model.model.embed_tokens.weight.data.clone()

        # print(self.tok_embeddings.weight.data)
        # Copy transformer layer weights
        for i, pretrained_layer in enumerate(pretrained_model.model.layers):

            # Attention weights
            assert self.layers[str(i)].attention.wq.weight.data.shape == pretrained_layer.self_attn.q_proj.weight.data.shape, f"Mismatch in shape for wq at layer {i}"
            assert self.layers[str(i)].attention.wk.weight.data.shape == pretrained_layer.self_attn.k_proj.weight.data.shape, f"Mismatch in shape for wk at layer {i}"
            assert self.layers[str(i)].attention.wv.weight.data.shape == pretrained_layer.self_attn.v_proj.weight.data.shape, f"Mismatch in shape for wv at layer {i}"
            assert self.layers[str(i)].attention.wo.weight.data.shape == pretrained_layer.self_attn.o_proj.weight.data.shape, f"Mismatch in shape for wo at layer {i}"

            self.layers[str(i)].attention.wq.weight.data = pretrained_layer.self_attn.q_proj.weight.data.clone()
            self.layers[str(i)].attention.wk.weight.data = pretrained_layer.self_attn.k_proj.weight.data.clone()
            self.layers[str(i)].attention.wv.weight.data = pretrained_layer.self_attn.v_proj.weight.data.clone()
            self.layers[str(i)].attention.wo.weight.data = pretrained_layer.self_attn.o_proj.weight.data.clone()

            # Feed-forward weights
            assert self.layers[str(i)].feed_forward.w1.weight.data.shape == pretrained_layer.mlp.gate_proj.weight.data.shape, f"Mismatch in shape for w1 at layer {i}"
            assert self.layers[str(i)].feed_forward.w2.weight.data.shape == pretrained_layer.mlp.down_proj.weight.data.shape, f"Mismatch in shape for w2 at layer {i}"
            assert self.layers[str(i)].feed_forward.w3.weight.data.shape == pretrained_layer.mlp.up_proj.weight.data.shape, f"Mismatch in shape for w3 at layer {i}"

            self.layers[str(i)].feed_forward.w1.weight.data = pretrained_layer.mlp.gate_proj.weight.data.clone()
            self.layers[str(i)].feed_forward.w2.weight.data = pretrained_layer.mlp.down_proj.weight.data.clone()
            self.layers[str(i)].feed_forward.w3.weight.data = pretrained_layer.mlp.up_proj.weight.data.clone()

            # LayerNorm weights
            assert self.layers[str(i)].attention_norm.weight.data.shape == pretrained_layer.input_layernorm.weight.data.shape, f"Mismatch in shape for attention_norm weight at layer {i}"
            assert self.layers[str(i)].ffn_norm.weight.data.shape == pretrained_layer.post_attention_layernorm.weight.data.shape, f"Mismatch in shape for ffn_norm weight at layer {i}"

            self.layers[str(i)].attention_norm.weight.data = pretrained_layer.input_layernorm.weight.data.clone()
            self.layers[str(i)].ffn_norm.weight.data = pretrained_layer.post_attention_layernorm.weight.data.clone()

            # Init LoRA weights
            # nn.init.kaiming_uniform_(self.layers[str(i)].attention.wq.lora_A, a=math.sqrt(5))
            # nn.init.zeros_(self.layers[str(i)].attention.wq.lora_B)
            # nn.init.kaiming_uniform_(self.layers[str(i)].attention.wk.lora_A, a=math.sqrt(5))
            # nn.init.zeros_(self.layers[str(i)].attention.wk.lora_B)
            # nn.init.kaiming_uniform_(self.layers[str(i)].attention.wv.lora_A, a=math.sqrt(5))
            # nn.init.zeros_(self.layers[str(i)].attention.wv.lora_B)

        # Copy final layer norm
        assert self.norm.weight.data.shape == pretrained_model.model.norm.weight.data.shape
        self.norm.weight.data = pretrained_model.model.norm.weight.data.clone()

        # Copy lm_head weights
        assert self.output.weight.data.shape == pretrained_model.lm_head.weight.data.shape
        self.output.weight.data = pretrained_model.lm_head.weight.data.clone()

        del pretrained_model

Since the full model weights don't fit on the GPU, I loaded the pretrained model on CPU and then copied the weights over. This causes problems later, since some tensors end up on cpu instead of cuda. However, if I just call model.to("cuda") after init_weights(), it throws an OOM error.

I did this because I couldn't find a better way to load a HuggingFace checkpoint directly into the Transformer module in torchtitan. I would appreciate a pointer on how to load from an HF checkpoint in torchtitan.

I am hoping to build on torchtitan since, later on in my project, I may need to change the training function as well as the model architecture directly, so a lightweight framework would help me down the line.

One more thing I discovered is that the init_weights() function is called twice: once through the Transformer module's init function

with torch.device("meta"):
        model = model_cls.from_model_args(model_config)

and once more later on, explicitly:

    if parallel_dims.pp_enabled:
        pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)
    else:
        # If PP is enabled, we can't rely on init_weights, because some layers are missing.
        # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
        # allocate sharded model on GPU and initialize weights via DTensor
        model.init_weights()

Not sure if this is the intended behavior, but just want to let you guys know.

@weifengpy
Contributor

how to load from a HF checkpoint in torchtitan

Here is my reference implementation for loading an HF checkpoint into an FSDP model: https://github.com/pytorch/torchtune/blob/main/recipes/dev/lora_finetune_fsdp2.py#L343.
full_sd below refers to the HF checkpoint.

from typing import Any, Dict

import torch
import torch.nn as nn
from torch.distributed._tensor import distribute_tensor


def load_from_full_model_state_dict(
    model: "FSDPModule",
    full_sd: Dict[str, Any],
    device: torch.device,
):
    """
    Convert a full state dict into a sharded state dict
    and load it into the FSDP model
    - 'full' means plain tensor
    - 'sharded' means `DTensor`, where each rank has a shard of the plain tensor
    """
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        sharded_meta_param = meta_sharded_sd.get(param_name)
        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
        sharded_tensor = distribute_tensor(
            full_tensor,
            sharded_meta_param.device_mesh,
            sharded_meta_param.placements,
        )
        sharded_sd[param_name] = nn.Parameter(sharded_tensor)
    # choose `assign=True` since we cannot call `copy_` on meta tensor
    return model.load_state_dict(sharded_sd, strict=False, assign=True)

One more thing I discovered was that init_weights() function was called twice, once at through the Transformer module init function

Good question! We intentionally call init_weights() twice. The 1st time is during meta init, where we initialize parameters/tensors on the meta device. The 2nd time is on device='cuda', where we actually allocate tensor storage with real values.

For the LoRA/fine-tuning case, meta init is preferred since we are loading from checkpoints eventually. So we just init the params on the meta device and copy tensors from the checkpoint into the model (load_from_full_model_state_dict). More about meta init: https://pytorch.org/tutorials/prototype/skip_param_init.html
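Schematically, the flow is roughly (a compressed sketch reusing the names from this thread, not torchtitan's exact code):

    # 1) Build the model on the meta device: no storage is allocated yet.
    with torch.device("meta"):
        model = model_cls.from_model_args(model_config)
    # 2) Apply FSDP / other parallelisms while parameters are still on meta.
    # 3) Materialize uninitialized storage on the GPU.
    model.to_empty(device="cuda")
    # 4) Fill in real values from the checkpoint, sharded per rank.
    load_from_full_model_state_dict(model, full_sd, torch.device("cuda"))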

@weifengpy weifengpy self-assigned this Jun 25, 2024

MinghaoYan commented Jun 25, 2024

Thank you very much for the pointer! After some more investigation, it does seem like the first-step loss is too high (without any LoRA or any training) after loading weights from the HF checkpoint directly. I get a loss of 11.79 at step 1, with the train_configs/llama3_8b.toml config and run_llama_train.sh on 4 A10 24GB GPUs. A minimal reproducible example: instead of

    if parallel_dims.pp_enabled:
        pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)
    else:
        # If PP is enabled, we can't rely on init_weights, because some layers are missing.
        # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
        # allocate sharded model on GPU and initialize weights via DTensor
        model.init_weights()

I loaded weights from the checkpoints instead:

    pretrained_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    load_from_full_model_state_dict(model, pretrained_model, "cpu")
    del pretrained_model

    init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
    model.to_empty(device=init_device)

    if parallel_dims.pp_enabled:
        pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)
    else:
        # If PP is enabled, we can't rely on init_weights, because some layers are missing.
        # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
        # allocate sharded model on GPU and initialize weights via DTensor
        # model.init_weights()
        pass

I made one change to your reference implementation: I mapped the HF param names to the names defined in torchtitan's Transformer module:

def load_from_full_model_state_dict(
        model: "FSDPModule",
        full_sd: Dict[str, Any],
        device: torch.device,
    ):
        """
        Converting full state dict into a sharded state dict
        and loading it into FSDP model
        - 'full' means plain tensor
        - 'sharded' means `DTensor` where each rank has a shard of the plain tensor
        """
        print(model)

        param_mapping = {
            'model.embed_tokens.weight': 'tok_embeddings.weight'
        }

        for i, _ in enumerate(full_sd.model.layers):
            param_mapping.update({
                f'model.layers.{i}.self_attn.q_proj.weight': f'layers.{i}.attention.wq.weight',
                f'model.layers.{i}.self_attn.k_proj.weight': f'layers.{i}.attention.wk.weight',
                f'model.layers.{i}.self_attn.v_proj.weight': f'layers.{i}.attention.wv.weight',
                f'model.layers.{i}.self_attn.o_proj.weight': f'layers.{i}.attention.wo.weight',
                f'model.layers.{i}.mlp.gate_proj.weight': f'layers.{i}.feed_forward.w1.weight',
                f'model.layers.{i}.mlp.down_proj.weight': f'layers.{i}.feed_forward.w2.weight',
                f'model.layers.{i}.mlp.up_proj.weight': f'layers.{i}.feed_forward.w3.weight',
                f'model.layers.{i}.input_layernorm.weight': f'layers.{i}.attention_norm.weight',
                f'model.layers.{i}.post_attention_layernorm.weight': f'layers.{i}.ffn_norm.weight'
            })

        param_mapping.update({
            'model.norm.weight': 'norm.weight',
            'lm_head.weight': 'output.weight'
        })

        meta_sharded_sd = model.state_dict()
        sharded_sd = {}
        for param_name, full_tensor in full_sd.named_parameters():
            
            sharded_meta_param = meta_sharded_sd.get(param_mapping[param_name])
            # print(sharded_meta_param)
            full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
            # print(param_name, full_tensor, sharded_meta_param)

            sharded_tensor = distribute_tensor(
                full_tensor,
                sharded_meta_param.device_mesh,
                sharded_meta_param.placements,
            )
            sharded_sd[param_name] = nn.Parameter(sharded_tensor)
        # choose `assign=True` since we cannot call `copy_` on meta tensor
        return model.load_state_dict(sharded_sd, strict=False, assign=True)

I would really appreciate it if you could spot anything wrong in my code, or reproduce the loss that I am seeing.


weifengpy commented Jun 25, 2024

I made one change to your reference implementation, where I mapped HF param names to the names defined in torchtitan Transformers module

Good catch on remapping the parameter names.

if you can spot anything wrong in my code

I would call load_from_full_model_state_dict after model.to_empty(...), and use torch.device("cuda")

  • model.to_empty moves the FSDP model from meta to 'cuda', but the parameter values are unassigned
  • load_from_full_model_state_dict(..., torch.device("cuda")): assigns the parameter values
  • I do not expect a GPU OOM, since load_from_full_model_state_dict shards each full tensor into a 1/N tensor. But let me know if you hit one

If you still hit errors, I can try to reproduce. Feel free to open a PR with your local changes.

    init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
    model.to_empty(device=init_device)

    if parallel_dims.pp_enabled:
        pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)
    else:
        # If PP is enabled, we can't rely on init_weights, because some layers are missing.
        # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
        # allocate sharded model on GPU and initialize weights via DTensor
        # model.init_weights()
        pass

    pretrained_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    load_from_full_model_state_dict(model, pretrained_model, torch.device("cuda"))
    del pretrained_model

@MinghaoYan
Author

Thank you for your reply!

I moved load_from_full_model_state_dict to after model.to_empty(...). If I keep the torch device as cpu, the behavior is the same. If I change the device to cuda, I run into a new problem where it complains that a tensor is not a leaf (this corresponds to the full_tensor variable in the load_from_full_model_state_dict function).

    File "/home/ubuntu/.conda/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
      return f(*args, **kwargs)
    File "/home/ubuntu/torchtitan/train.py", line 252, in main
      load_from_full_model_state_dict(model, pretrained_model, torch.device("cuda"))
    File "/home/ubuntu/torchtitan/train.py", line 540, in load_from_full_model_state_dict
      sharded_tensor = distribute_tensor(
    File "/home/ubuntu/.conda/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 598, in distribute_tensor
      raise RuntimeError(
  RuntimeError: `distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!
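(Just a guess at the cause: calling .to(dtype)/.to(device) on a parameter that requires grad returns a non-leaf tensor, which would also explain why this only shows up when the target device is cuda. If so, detaching full_tensor before sharding might avoid it, e.g.:)

    # Possible workaround (untested guess): detach so distribute_tensor sees a leaf tensor.
    full_tensor = full_tensor.detach().to(sharded_meta_param.dtype).to(device)
    sharded_tensor = distribute_tensor(
        full_tensor,
        sharded_meta_param.device_mesh,
        sharded_meta_param.placements,
    )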

I seem to be having some permission issues creating branches and PRs; I will look into it. Thanks again!

@weifengpy
Contributor

Thanks for the patience. Let me know if you cannot resolve it after investigating. I might draft some example code for loading from an HF checkpoint; this is a common ask and something we can improve.

@MinghaoYan
Author

Thank you! I have created a PR here: #427

@weifengpy
Contributor

Thank you! I have created a PR here: #427

Thanks, I will give it a try. In the meantime, we have a script to convert HF checkpoints to DCP format. Are you interested in giving it a try? #305


tianyu-l commented Jul 4, 2024

@MinghaoYan

Currently I am trying to copy the weights from a HuggingFace Llama-3-8B-Instruct checkpoint to an instantiated Transformer module in torchtitan manually by replacing the Transformer init_weight function with the following

I wonder if you are aware of the model definition mismatch between Llama and HF's Transformer (#335). Basically a permutation of some weights is needed to make the conversion work.
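For reference, the mismatch comes down to an interleaving permutation of the wq/wk weight rows applied by the HF conversion script. Going from the HF layout back to the original layout would look roughly like the sketch below (the helper name and exact reshape are an assumption here; see #335 for the authoritative details):

    # Rough sketch of undoing the HF rotary-interleaving permutation on wq/wk
    # (an assumption; see #335). dim1 is n_heads * head_dim for wq and
    # n_kv_heads * head_dim for wk; dim2 is the model dimension.
    def unpermute(w, n_heads, dim1, dim2):
        return (
            w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
            .transpose(1, 2)
            .reshape(dim1, dim2)
        )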

@MinghaoYan
Author

I was not aware of this, thank you!
