Questions about FSDP2 support and memory usage. #658
Sorry for the confusion. FSDP2's frontend lives under `torch/distributed/_composable/fsdp/`; the file you linked is the older composable FSDP1 `fully_shard`, which is what carries the deprecation note. Are you comparing FSDP1 vs. FSDP2 with the same "wrapping"? I.e., are you calling `fully_shard` on the same set of modules that FSDP1 wraps? If you are able to share some code, that would be helpful too!
Thanks for your clarification! For my own usage, FSDP1 is set up along these lines.
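Roughly the following minimal sketch, with a stand-in model in place of the real architecture and a hypothetical size threshold (assuming a size-based auto-wrap policy and a single-node torchrun launch):

```python
import functools

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Assumes launch via torchrun on a single node so rank/world-size env vars are set.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Stand-in for the real ~10B DiT-like model (hypothetical shapes).
model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)])

# Size-based wrapping: any submodule whose not-yet-wrapped parameter count
# exceeds the threshold becomes its own FSDP unit, i.e. its own "parameter
# group" that is all-gathered / reduce-scattered together.
wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=1_000_000
)

model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)
```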
I tested FSDP1 on 8 GPUs on a single machine; it uses about 55 GB of GPU memory. As for FSDP2, the wrapping looks like the following.
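A sketch of that setup, again with a stand-in model (the block structure, attribute names, and dimensions are placeholders; the `fully_shard` import path is where FSDP2 lives on current PyTorch main, and in recent releases it is also re-exported from `torch.distributed.fsdp`):

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Stand-in model with a list of "transformer blocks" (hypothetical shapes).
class TinyDiT(nn.Module):
    def __init__(self, num_blocks=8, dim=4096):
        super().__init__()
        self.embed = nn.Linear(dim, dim)
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))
        self.out = nn.Linear(dim, dim)

    def forward(self, x):
        x = self.embed(x)
        for blk in self.blocks:
            x = blk(x)
        return self.out(x)

model = TinyDiT().cuda()

# Each fully_shard call creates one FSDP2 parameter group: here one per
# transformer block, plus a root group that picks up the remaining parameters.
for block in model.blocks:
    fully_shard(block)
fully_shard(model)
```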
With this setup it uses about 75-80 GB of GPU memory. Sorry, I cannot provide more code for the model architecture, but it is quite similar to Llama: a DiT-like model with extra input embeddings and bidirectional attention compared to Llama. I don't think that should lead to such a large increase in GPU memory.
In short, I think there could be two possibilities:
The FSDP1 frontend API is `FullyShardedDataParallel`, and the FSDP2 frontend API is `fully_shard`. When you pass in an `auto_wrap_policy` to FSDP1, the frontend API is effectively applied to every module matched by the policy. Each time you call an FSDP frontend API on a module, it creates one "parameter group" to be communicated together. Specifically, all parameters in `module.parameters()` that are not already assigned to a nested parameter group go into that group.

For FSDP1, you are using a size-based wrapping policy. This may lead to much smaller "parameter groups". In other words, for FSDP2, since you are only wrapping each transformer block, perhaps there are some other modules in your model that are not being assigned their own "parameter group" but should be, for a fairer comparison. Intuitively, wrapping more modules means more communication kernels, but generally lower peak memory usage.

Since FSDP2 shards parameters on dim-0, if you have some parameters whose dim-0 is small relative to the number of shards (or not evenly divisible by it), the sharding is uneven and the padding can add memory overhead.
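For illustration, continuing the sketch above, an apples-to-apples FSDP2 wrapping could also give the extra modules their own parameter groups (the attribute names are placeholders for whatever your extra modules actually are):

```python
from torch.distributed._composable.fsdp import fully_shard

# Match the size-based FSDP1 grouping more closely: wrap the extra,
# non-block modules too, so they are not all lumped into the root group.
for block in model.blocks:
    fully_shard(block)
fully_shard(model.embed)  # e.g. the extra input embedding (hypothetical name)
fully_shard(model.out)    # e.g. the output projection (hypothetical name)
fully_shard(model)        # root group picks up any remaining parameters
```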
Ultimately, the best way for you to debug this (if you want) is to compare memory snapshots. It is somewhat hard for me to say without more information. I can only recommend making the comparison as apples-to-apples as possible. Making the "wrapping" the same across FSDP1 and FSDP2 is always doable. For the uneven sharding issue, there is not much we can do right now; perhaps you can fold some dims together in your parameters to make dim-0 larger.
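For the memory snapshots, a minimal sketch of recording and dumping one around a few training steps (the training-step call and data iterator are placeholders):

```python
import torch

# Record allocator events (with stack traces) for the steps of interest.
torch.cuda.memory._record_memory_history(max_entries=100_000)

for step in range(5):
    loss = train_step(model, next(data_iter))  # hypothetical training step

# Dump a snapshot; it can be inspected at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("fsdp2_memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```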
Gu, thanks!
I will also try to make this reproducible with the open-source Flux (a t2i model), maybe. Thanks again!
If you can get an open-source repro, I would be happy to help take a look!
What is the current status of FSDP2 support in main PyTorch?
I only see this here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fully_shard.py#L45
Will FSDP2 be deprecated? Can FSDP1 work with DTensor as well as TP?
I tried FSDP2 in my new project, but I got higher GPU memory usage compared to FSDP1. What might cause this? The model is a 10B DiT-like model with an extra embedding layer compared to LLMs. My main concern is whether I should wrap more modules with fully_shard to reduce the memory usage.
Since the transformer block is quite similar to Llama's, I use the same fully_shard wrapping as your project.