Model init with HuggingFace model #743

Open
neeldani opened this issue Dec 16, 2024 · 4 comments

neeldani commented Dec 16, 2024

I am writing a simple script to run FSDP2 (fully_shard) on the pythia-1b model available on HuggingFace. I am currently running the model on 1 node with 2 devices, following the meta-device initialisation from the FSDP2 docs. However, I think something is wrong with my implementation, since the peak memory usage with FSDP is the same as without FSDP (~1 GB). Further, I get an OOM on my device when I try the pythia-2.8b model. The following snippet shows how I am initialising the model on a meta device using HuggingFace APIs:

model_name = "EleutherAI/pythia-14m"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(model_name)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

for module in model.modules():
    if isinstance(module, GPTNeoXLayer):
        fully_shard(module)

model = fully_shard(model, reshard_after_forward=True)

model = load_checkpoint_and_dispatch(
    model, path_to_safe_tensors
)

This is not straightforward, since the sharded parameters expect DTensors when the weights are loaded via load_checkpoint_and_dispatch. I am looking for suggestions on a good way to make FSDP2 work with HuggingFace models. I don't think accelerate supports FSDP2 yet.

awgu commented Dec 16, 2024

cc: @weifengpy @mori360

@neeldani

👋 Gentle bump on this - mainly to see if there is some workaround for the above issue 👀

mori360 commented Dec 20, 2024

> However, I think there is something wrong with my implementation since the peak memory usage with FSDP is same as without FSDP (~ 1GB).

It depends on where the peak memory occurs. If it occurs during fully_shard, the full_state_dict is sharded into a local_state_dict, and holding both at once uses more memory (full_state_dict + local_state_dict > full_state_dict).
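
To make the inequality above concrete, here is a rough back-of-the-envelope sketch. It assumes that the full state dict and this rank's shard coexist in memory at the moment of sharding, that parameters divide evenly across ranks, and it ignores activations, gradients, and optimizer state. The function name and the round parameter count are illustrative, not measurements from this issue.

```python
def peak_bytes_during_sharding(num_params: int, bytes_per_param: int, world_size: int) -> dict:
    """Estimate per-rank memory when a full state dict is sharded.

    Assumes an even split across ranks (a simplification for illustration).
    """
    full = num_params * bytes_per_param   # unsharded copy held before sharding
    local = full // world_size            # this rank's shard after sharding
    return {"full": full, "local": local, "peak": full + local}


if __name__ == "__main__":
    # Illustrative round number: 1e9 parameters in fp32 on 2 ranks.
    estimate = peak_bytes_during_sharding(1_000_000_000, 4, 2)
    for name, value in estimate.items():
        print(f"{name}: {value / 1024**3:.2f} GiB")
```

Under these assumptions the peak is always above the unsharded size, which would be consistent with seeing no memory savings if the peak is measured around the sharding step.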

> I get an OOM on my device when I try with pythia-2.8b model

Could you give more details on the safe_tensors so that I can reproduce the huge memory cost?
Also, could you describe the device flow so that I can follow up when you switch your devices to GPU?

neeldani commented Dec 23, 2024

> It depends on where you have the peak memory. If it's on fully_shard, then the full_state_dict would shard to a local_state_dict, causing a greater memory. (full_state_dict + local_state_dict > full_state_dict)

I see. Ideally, I am looking for an approach that lets me load the sharded model on each GPU without loading the full_state_dict.
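
As a sketch of what "loading only the shard" would involve: by default FSDP2 shards each parameter along dim 0, so each rank only needs a contiguous block of rows from every tensor. The helper below computes that row range, assuming torch.chunk-style splitting (every rank gets ceil(dim0 / world_size) rows, and trailing ranks may get fewer or none). The function name is hypothetical, and actually reading just those rows from the checkpoint (for example through a lazy-slicing reader) is left out here.

```python
import math


def dim0_shard_range(dim0: int, world_size: int, rank: int) -> tuple[int, int]:
    """Return the [start, end) rows of dim 0 owned by `rank`.

    Mirrors torch.chunk-style splitting: chunk size is ceil(dim0 / world_size),
    so the last ranks can receive a short or empty slice.
    """
    chunk = math.ceil(dim0 / world_size)
    start = min(rank * chunk, dim0)
    end = min(start + chunk, dim0)
    return start, end


if __name__ == "__main__":
    # Example: a weight with 10 rows split over 4 ranks.
    for rank in range(4):
        print(rank, dim0_shard_range(10, 4, rank))
```

Each rank would then only need rows start..end of each tensor, rather than the whole full_state_dict.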

> Could you give more details on the safe_tensors as I could repro the huge memory cost.

I downloaded the model.safetensors for the pythia-1b model from here. These weights are not sharded.

> Also, could you give a device flow so that I could follow up when you switch you devices to gpu.

I am trying to mimic TorchTitan's implementation, but with a HuggingFace model:

  1. Load the empty model on the meta device
  2. Apply FSDP, move the sharded weights to their respective GPUs and materialise them
  3. Re-initialise the sharded weights on each GPU

This is a simple repro of my implementation which can be run using:

torchrun --nnodes=1 --nproc_per_node=2 reproduce.py

import os

import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed._composable.fsdp import fully_shard
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

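def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
    """Count parameters, optionally excluding the input embedding."""
    num_params = sum(p.numel() for p in model.parameters())
    if exclude_embedding:
        # HF models expose the input embedding via get_input_embeddings();
        # GPTNeoX has no `tok_embeddings` attribute.
        num_params -= model.get_input_embeddings().weight.numel()
    return num_params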

def setup(local_rank, world_size):
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    init_process_group("nccl", rank=local_rank, world_size=world_size)

def load():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    setup(local_rank, world_size)

    model_name = "EleutherAI/pythia-2.8b"
    config = AutoConfig.from_pretrained(model_name)
    
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)
    
    if local_rank == 0:
        print("Load models with empty weights")
        print("Device: ", model.device)
        print("Params: ", get_num_params(model))
        print("Peak mem: ", torch.cuda.max_memory_allocated() / (1024 ** 3))

    for module in model.modules():
        if isinstance(module, GPTNeoXLayer):
            fully_shard(module)
    
    model = fully_shard(model, reshard_after_forward=True)

    if local_rank == 0:
        print("Applied FSDP to the model")
        print("Device: ", model.device)
        print("Peak mem: ", torch.cuda.max_memory_allocated() / (1024 ** 3))
        print("# of params: ", get_num_params(model))

    model.to_empty(device='cuda')

    if local_rank == 0:
        print("Materialized the sharded tensors")
        print("Device: ", model.device)
        print("Peak mem: ", torch.cuda.max_memory_allocated() / (1024 ** 3))
        print("# of params: ", get_num_params(model))

    # This is the call that fails: the sharded model expects DTensors as inputs.
    model = load_checkpoint_and_dispatch(
        model,
        "model.safetensors",
        device_map="auto",
        no_split_module_classes=["GPTNeoXLayer"],
    )

if __name__ == "__main__":
    load()

The flow is very similar to TorchTitan's, except that TorchTitan makes an explicit call to re-initialise the weights after materialising them. Since I want to load weights from a pretrained HF model instead, it's a bit more challenging. The code above throws an error at the call to load_checkpoint_and_dispatch, since the model expects DTensors as inputs.
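
One possible workaround, sketched here under assumptions rather than confirmed in this thread, is to replace load_checkpoint_and_dispatch with PyTorch's distributed state-dict helpers, which can scatter a full checkpoint into DTensor parameters. The sketch assumes a PyTorch version in which StateDictOptions accepts broadcast_from_rank0, that the keys in the safetensors file match model.state_dict(), and that rank 0 has enough CPU RAM for the full checkpoint. The function name is hypothetical.

```python
def load_full_checkpoint_into_sharded_model(model, checkpoint_path: str) -> None:
    """Load an unsharded safetensors file into a model wrapped by fully_shard.

    Call on every rank after fully_shard(...) and model.to_empty(device="cuda").
    """
    # Imported lazily so this helper can be defined without torch installed.
    import torch.distributed as dist
    from safetensors.torch import load_file
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_model_state_dict,
    )

    # Only rank 0 reads the file (onto CPU); the other ranks pass an empty dict.
    if dist.get_rank() == 0:
        full_state_dict = load_file(checkpoint_path, device="cpu")
    else:
        full_state_dict = {}

    # Broadcast from rank 0 and copy each tensor's shard into the matching
    # DTensor parameter on every rank.
    set_model_state_dict(
        model,
        full_state_dict,
        options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
    )
```

In the repro above this would stand in for the final load_checkpoint_and_dispatch call, for example load_full_checkpoint_into_sharded_model(model, "model.safetensors"). Rank 0 still holds the full checkpoint in CPU memory, so this avoids a full copy on each GPU but not the full load itself. Any buffers that are not stored in the checkpoint remain uninitialised after to_empty and would need to be recomputed separately.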
