70B Fine-tuning GPUs Utilization #2142

Open
fabiogeraci opened this issue Dec 10, 2024 · 4 comments
Labels: discussion (Start a discussion), distributed (Anything related to distributed env (multi-GPU, multi-node))

Comments

@fabiogeraci

OpenMPI launch script (CLI):
mpirun \
    -np $TOTAL_NUM_GPUS \
    -H \$MPI_HOST_STRING \
    -x PATH \
    -bind-to none \
    -map-by slot \
    --mca pml ob1 --mca btl ^openib \
    --display-allocation \
    --display-map \
    python3 src/full_finetune_distributed.py \
    --config config_files/8B_full_distributed.yaml \
    optimizer_in_bwd=False
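
(Note for anyone reproducing this with mpirun rather than torchrun: torch.distributed's env:// initialization expects RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT to be set, and recipes typically read LOCAL_RANK to pick a device; torchrun normally sets these for you. A minimal sketch of bridging OpenMPI's variables, assuming MASTER_ADDR/MASTER_PORT are exported separately, e.g. via mpirun -x:)

# Illustrative only, not part of torchtune: map OpenMPI's rank variables to the
# env:// variables torch.distributed expects when the recipe is launched via mpirun.
import os

os.environ.setdefault("RANK", os.environ["OMPI_COMM_WORLD_RANK"])
os.environ.setdefault("WORLD_SIZE", os.environ["OMPI_COMM_WORLD_SIZE"])
os.environ.setdefault("LOCAL_RANK", os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])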

full_finetune_distributed.py

import os

from torch.distributed.device_mesh import init_device_mesh

num_nodes = int(os.environ.get("NUM_NODES", "1"))
if num_nodes > 1:
    # 2-D device mesh: one dim across nodes, one dim across the GPUs within a node
    mesh_2d = init_device_mesh(
        "cuda",
        mesh_shape=(num_nodes, int(os.environ["WORLD_SIZE"]) // num_nodes),
        mesh_dim_names=("dp", "tp"),
    )
else:
    mesh_2d = None

training.shard_model(
    model=model,
    shard_conditions=fsdp_shard_conditions,
    cpu_offload=fsdp_cpu_offload,
    reshard_after_forward=reshard_after_forward,
    mesh=mesh_2d,
)

_distributed.py

def shard_model(
    model: TransformerDecoder,
    shard_conditions: List[Callable[[str, nn.Module], bool]],
    *,
    cpu_offload: bool,
    reshard_after_forward: bool = True,
    mesh: Optional[DeviceMesh] = None,  # <-- Add this line
) -> None:
    ...  # existing body unchanged, up to where fsdp_kwargs is built
    if mesh is not None:  # <-- Add this line
        fsdp_kwargs["mesh"] = mesh  # <-- Add this line

Originally posted by @fabiogeraci in #2018 (comment)
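
For completeness, a minimal sketch of what the patched shard_model might look like end to end, assuming (as in torchtune's _distributed.py) that the function forwards fsdp_kwargs to FSDP2's fully_shard; the loop below is illustrative rather than the exact torchtune implementation. Note that when fully_shard receives a 2-D mesh it replicates across the first dim and shards across the second (HSDP), so the "dp"/"tp" names above describe inter-node replication and intra-node sharding, not tensor parallelism.

from typing import Callable, List, Optional

import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.device_mesh import DeviceMesh

def shard_model(
    model: nn.Module,
    shard_conditions: List[Callable[[str, nn.Module], bool]],
    *,
    cpu_offload: bool,
    reshard_after_forward: bool = True,
    mesh: Optional[DeviceMesh] = None,
) -> None:
    fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
    if cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
    if mesh is not None:
        # 2-D mesh -> HSDP: replicate across dim 0, shard across dim 1
        fsdp_kwargs["mesh"] = mesh

    # Shard matching submodules (children before parents, thanks to reversed),
    # then wrap the root module.
    for name, module in reversed(list(model.named_modules())):
        if name and any(cond(name, module) for cond in shard_conditions):
            fully_shard(module, **fsdp_kwargs)
    fully_shard(model, **fsdp_kwargs)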

@fabiogeraci (Author) commented Dec 10, 2024

I am using the configuration above to fine-tune a 70B model on 2 nodes with 8 GPUs each. The job took 75 minutes to compile (is that usual?)

I also noticed that one of the 16 GPUs was not used at all. I hope the video helps; I also attached the NCCL log, 70b_nccl.txt.
Screencast from 10-12-24 09:59:54.webm
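
Regarding the idle GPU, a quick per-rank sanity check might help confirm that all 16 ranks start and that each one binds a distinct device; the helper below is illustrative only (not part of the recipe) and assumes the process group is already initialized.

import os
import socket

import torch
import torch.distributed as dist

def log_rank_binding() -> None:
    # Call after init_process_group; prints which GPU each rank ends up on.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    print(
        f"host={socket.gethostname()} "
        f"rank={dist.get_rank()}/{dist.get_world_size()} "
        f"local_rank={local_rank} "
        f"cuda_device={torch.cuda.current_device()}",
        flush=True,
    )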

The job was eventually killed for the reason shown below; any suggestions?

# LSBATCH: User input
#BSUB -J gpu-test
#BSUB -o /nfs/users/nfs_f/fg12/scripts/logs/gpu-test_o.%J
#BSUB -e /nfs/users/nfs_f/fg12/scripts/logs/gpu-test_e.%J
#BSUB -n 128
#BSUB -q gpu-parallel
#BSUB -gpu "num=8:gmem=80000:mode=shared:block=yes"
#BSUB -M 768G
#BSUB -R "select[mem>768G] rusage[mem=768G] span[ptile=64]"

TERM_MEMLIMIT: job killed after reaching LSF memory usage limit.
Exited with signal termination: 9.

Resource usage summary:

    CPU time :                                   69121.00 sec.
    Max Memory :                                 793870 MB
    Average Memory :                             398370.69 MB
    Total Requested Memory :                     1572864.00 MB
    Delta Memory :                               778994.00 MB
    Max Swap :                                   -
    Max Processes :                              559
    Max Threads :                                5356
    Run time :                                   6266 sec.
    Turnaround time :                            6269 sec.

70b_config.txt

@joecummings self-assigned this Dec 10, 2024
@joecummings added the discussion and distributed labels Dec 10, 2024
@joecummings (Contributor)

Thanks for the report! Based on your config and the setup you have, I don't see immediately why this would hit your specified memory limit of 768G. Let me get ahold of a multi-node setup today and test this out.
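
One way to narrow this down would be to log per-rank host RSS alongside peak CUDA memory every few steps and see which ranks drive the growth toward the LSF limit. A rough sketch (illustrative only, assuming psutil is available; not part of the recipe):

import os

import psutil
import torch

def log_memory(step: int, every: int = 50) -> None:
    # Periodically report host RSS and peak CUDA memory for this rank.
    if step % every != 0:
        return
    rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1e9
    cuda_gb = torch.cuda.max_memory_allocated() / 1e9
    rank = os.environ.get("RANK", "0")
    print(
        f"[rank {rank}] step {step}: host RSS {rss_gb:.1f} GB, "
        f"peak CUDA {cuda_gb:.1f} GB",
        flush=True,
    )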

@joecummings (Contributor)

Hey @fabiogeraci, just updating you on this. I'm waiting on a request for a multi-node server (PyTorch has limited quantity). If I don't hear back today, I'll just rent one out on Lambda Labs or something.

@fabiogeraci (Author)

> Hey @fabiogeraci, just updating you on this. I'm waiting on a request for a multi-node server (PyTorch has limited quantity). If I don't hear back today, I'll just rent one out on Lambda Labs or something.

Thank you.
