
Questions about FSDP2 support and memory usage. #658

Open
tangjiasheng opened this issue Oct 29, 2024 · 6 comments
Labels
question Further information is requested

Comments

@tangjiasheng

What is the current support status of FSDP2 in main PyTorch?
I just saw this here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fully_shard.py#L45

"torch.distributed._composable.fully_shard will be removed after PyTorch 2.5."

Will FSDP2 be deprecated? Can FSDP1 work with DTensor as well as TP?

I tried FSDP2 in my new project, but I got higher GPU memory usage compared to FSDP1. What might cause this? The model is a 10B DiT-like model with an extra embedding layer compared to LLMs. My main concern is whether I need to wrap more modules with fully_shard to reduce memory usage.

Since the transformer block is quite similar to Llama's, I use the same fully_shard wrapping as in your project.

@awgu
Contributor

awgu commented Oct 29, 2024

Sorry for the confusion. The fully_shard in torch.distributed._composable.fully_shard was an experimental API that we were working on that simply calls into the same code as FSDP1 (FullyShardedDataParallel).

FSDP2 is in torch.distributed._composable.fsdp.fully_shard (note the extra .fsdp).
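
For reference, a minimal import sketch (both modules are private and still evolving, so the exact paths may change in future releases):

    # FSDP2 frontend (note the extra .fsdp in the path).
    from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

    # FSDP1 frontend, for comparison.
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP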

Are you comparing FSDP1 vs. FSDP2 with the same "wrapping"? I.e. are you calling fully_shard on the same submodules as FullyShardedDataParallel (possibly through the auto_wrap_policy)?

If you are able to share some code, that would be helpful too!

@tianyu-l added the question (Further information is requested) label on Oct 29, 2024
@tangjiasheng
Author

tangjiasheng commented Oct 30, 2024

> Sorry for the confusion. The fully_shard in torch.distributed._composable.fully_shard was an experimental API that we were working on that simply calls into the same code as FSDP1 (FullyShardedDataParallel).
>
> FSDP2 is in torch.distributed._composable.fsdp.fully_shard (note the extra .fsdp).
>
> Are you comparing FSDP1 vs. FSDP2 with the same "wrapping"? I.e. are you calling fully_shard on the same submodules as FullyShardedDataParallel (possibly through the auto_wrap_policy)?
>
> If you are able to share some code, that would be helpful too!

Thanks for your clarification!
For FSDP2, can you share more information on its future plans? And will FullyShardedDataParallel work with DTensor and TP?

For my own usage, the FSDP1 setup looks like this:

    import functools
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    fpSixteen = MixedPrecision(
        param_dtype=dtype,
        reduce_dtype=torch.float,  # gradient communication precision
        buffer_dtype=dtype,        # buffer precision
    )

    my_size_based_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=5e7
    )

    model = FSDP(
        model,
        mixed_precision=fpSixteen,
        auto_wrap_policy=my_size_based_auto_wrap_policy,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        use_orig_params=True,
    )

I tested FSDP1 on 8 GPUs in one machine; it uses about 55 GB of GPU memory.

As for FSDP2:

    from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

    mp_policy = MixedPrecisionPolicy(param_dtype=dtype, reduce_dtype=torch.float)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    for layer_id, st_blocks in enumerate(model.transformer_blocks):
        # Do not reshard the last block after forward, since backward needs it first.
        reshard_after_forward = layer_id < len(model.transformer_blocks) - 1
        fully_shard(
            st_blocks,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
    fully_shard(model, **fsdp_config, reshard_after_forward=True)

It uses about 75-80 GB of GPU memory.

Sorry, I cannot provide more code for the model architecture. But the architecture is quite similar to Llama: a DiT-like model with extra input embeddings and bidirectional attention compared to Llama. I don't think this should lead to such a large increase in GPU memory.

@awgu
Contributor

awgu commented Oct 30, 2024

In short, I think there could be two possibilities:

  1. Different "FSDP wrapping"
  2. Uneven sharding with FSDP2

The FSDP1 frontend API is FullyShardedDataParallel, and the FSDP2 frontend API is fully_shard. (The former is an nn.Module wrapper, while the latter is not.)

When you pass in an auto_wrap_policy to FSDP1, the policy is simply syntactic sugar that helps call FullyShardedDataParallel on submodules following the policy's predicate and reassigns the new wrapper FullyShardedDataParallel module back to its parent.

Each time you call an FSDP frontend API on a module, it creates one "parameter group" to be communicated together. Specifically, all parameters in module.parameters() except those already assigned to a nested FSDP module will be assigned to this new "parameter group".

For FSDP1, you are using a size-based wrapping policy. This may lead to much smaller "parameter groups". In other words, for FSDP2, since you are only wrapping each transformer block, perhaps there are some other modules in your model that are not being assigned into their own "parameter group" but should be for a fairer comparison. Intuitively, wrapping more modules means more communication kernels, but generally lower peak memory usage.
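
For example, to make the comparison apples-to-apples, you could wrap exactly the same set of submodules in both. A rough sketch (the submodule names here are just placeholders, not from your actual model):

    import functools
    from torch.distributed._composable.fsdp import fully_shard
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

    # Placeholder submodules; substitute whatever your model actually has.
    modules_to_wrap = list(model.transformer_blocks) + [model.embed]

    # Run one of the following two setups at a time, not both on the same model.

    # FSDP1: the policy decides which submodules get their own "parameter group".
    policy = functools.partial(
        lambda_auto_wrap_policy, lambda_fn=lambda m: m in modules_to_wrap
    )
    fsdp1_model = FSDP(model, auto_wrap_policy=policy)

    # FSDP2: call fully_shard explicitly on the same submodules, then on the root.
    for m in modules_to_wrap:
        fully_shard(m)
    fully_shard(model)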


Since FSDP2 shards parameters on dim-0, if you have some parameters with shape [dim0, dim1, ...], where dim0 is less than your FSDP world size (seems 8 for you) while the product of the other dims is large, then there will be significant padding. In adversarial cases, this could lead to significant extra memory usage.
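
A toy illustration of how the padding can blow up (made-up shapes, not from your model):

    import math

    world_size = 8
    dim0, rest = 4, 8192  # parameter of shape [4, 8192]

    # FSDP2 pads dim-0 up to a multiple of the world size before splitting it.
    padded_dim0 = math.ceil(dim0 / world_size) * world_size   # 8
    per_rank_elems = (padded_dim0 // world_size) * rest       # 8192
    ideal_per_rank = dim0 * rest / world_size                 # 4096

    # Each rank stores 2x the "ideal" shard for this parameter; folding dims
    # (e.g. viewing it as [4 * 8192]) would make dim-0 large enough to shard evenly.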

@awgu
Contributor

awgu commented Oct 30, 2024

Ultimately, the best way for you to debug this (if you want) is to compare memory snapshots. It is somewhat hard for me to say without more information. I can only recommend making the comparison as apples-to-apples as possible. Making the 'wrapping' the same across FSDP1 and FSDP2 is always doable. For the uneven sharding issue, there is not much we can do right now; perhaps you can fold some dims together in your parameters to make dim-0 larger.
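
In case it helps, this is roughly how I would capture a snapshot (these are private torch.cuda.memory APIs, so the details may differ across versions); the resulting pickle can be loaded at https://pytorch.org/memory_viz:

    import torch

    # Start recording allocator events (private API; arguments may vary by version).
    torch.cuda.memory._record_memory_history(max_entries=100000)

    # ... run a few training iterations with FSDP1, then separately with FSDP2 ...

    # Dump the snapshot and stop recording.
    torch.cuda.memory._dump_snapshot("fsdp_memory_snapshot.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)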

@tangjiasheng
Author

Gu, thanks!
As you say, "for FSDP2, since you are only wrapping each transformer block, perhaps there are some other modules in your model that are not being assigned into their own 'parameter group'". I've wrapped the other embedding layers with FSDP2's fully_shard, but it doesn't seem to help GPU memory usage.
Next I will try to rule out the difference from auto_wrap_policy; I will post my results tomorrow.
I'm not sure whether this policy would be fair for the comparison:

    import functools
    from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

    my_auto_wrap_policy = functools.partial(
        lambda_auto_wrap_policy,
        lambda_fn=lambda m: m in model.blocks,
    )

I will also try to make this reproducible with the open-source Flux (a text-to-image model), maybe.
As for the uneven sharding issue, it is probably not the reason in my case.

Thanks again!

@awgu
Contributor

awgu commented Oct 30, 2024

If you can get an open-source repro, I would be happy to help take a look!
