Support FSDP in PyTorch #796
Comments
From meeting minutes from Michael Shi: Challenge is ensuring that JAX and PyTorch are equivalent. PyTorch should be doable by changing the DDP wrapper to the FSDP wrapper.
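For reference, the wrapper swap mentioned in the minutes could look roughly like the sketch below. The helper name `wrap_model` and its `use_fsdp` flag are hypothetical and not taken from this repository; `DistributedDataParallel` and `FullyShardedDataParallel` are the standard PyTorch classes. It assumes `torch.distributed` is already initialized and that the model fits on the given device before wrapping.

```python
def wrap_model(model, device_id, use_fsdp=False):
    """Wrap `model` for data-parallel training (hypothetical helper).

    Assumes torch.distributed.init_process_group() has already been called.
    Imports are done lazily so this module loads even without PyTorch.
    """
    if use_fsdp:
        # FSDP shards parameters, gradients, and optimizer state across ranks.
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        return FSDP(model, device_id=device_id)
    # DDP keeps a full replica of the model on every rank.
    from torch.nn.parallel import DistributedDataParallel as DDP
    return DDP(model.to(device_id), device_ids=[device_id])
```

With FSDP the optimizer would typically need to be constructed after wrapping, so that it sees the sharded parameters; whether that holds for the workloads here is an assumption to verify.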
...For the sake of self-referencing notes and approaches. For the PyTorch case, there are two ways of doing this:
For the JAX case, which is the one I am less familiar with:
Hi,
Edited to add: Also, I turned off torch.compile for this workload. I think that is also due to the PyTorch version.
Thanks for the update! Regarding torch.compile, that seems a little more problematic. When you have time, could you paste a traceback of the issue with torch.compile (maybe with https://gist.github.com/) or in the GH issue thread? If the fix requires updating PyTorch, we should probably bump the priority on that.
Hi, I am not sure what to make of this error yet. Also, there is this related blog post: https://dev-discuss.pytorch.org/t/torchdynamo-update-11-making-fsdp-and-dynamo-work-together/1037
Hi,
It is useful to shard optimizer state across devices (to save significant memory). This reflects current practice. We want to support it.
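To make the memory saving concrete, here is a rough back-of-the-envelope sketch. It assumes an Adam-style optimizer that keeps two fp32 buffers per parameter and an even split of parameters across devices; the function name and the example sizes are illustrative, not measurements from this project.

```python
def adam_state_bytes_per_device(num_params, num_devices, shard, bytes_per_value=4):
    """Approximate bytes of Adam state held by each device.

    Assumption: two moment buffers (first and second moment) per parameter,
    each stored with `bytes_per_value` bytes (4 for fp32).
    """
    if shard:
        # Each device keeps state only for its slice of the parameters.
        params_here = -(-num_params // num_devices)  # ceiling division
    else:
        # Every device keeps state for all parameters (replicated).
        params_here = num_params
    return 2 * params_here * bytes_per_value


# Illustrative numbers: 1B parameters on 8 devices.
replicated = adam_state_bytes_per_device(1_000_000_000, 8, shard=False)
sharded = adam_state_bytes_per_device(1_000_000_000, 8, shard=True)
print(replicated / 1e9, "GB replicated vs", sharded / 1e9, "GB sharded")
```

Under these assumptions the per-device optimizer state shrinks by a factor equal to the number of devices.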