meta device issue with float8 delayed scale #654
cc @vkuzo
Without debugging, my guess would be something like:
I can take a look next week, unless someone gets to it faster.
If you run the repo for the first time, the torchtitan/output/checkpoint folder will be empty. The model won't load checkpoints, but the error still occurs. We do meta init and call init_weights to move the model from meta to CUDA. The buffers for delayed scaling might need some treatment.
Thanks!
I see, then this line is relevant: https://github.com/pytorch/ao/blob/e85c1a318b06bbdb3b8c7f92f3257999864446b0/torchao/float8/float8_linear.py#L648 We'll have to think about whether we can do this automatically without introducing one more API. If not, we'll have to design such an API.
I see. It sounds plausible.
Summary:

Context: pytorch/torchtitan#654

If the user has delayed scaling and FSDP float8 all-gather on, there is a subtle bug that can happen if the user calls `model.to_empty(device="cuda")`:

1. `to_empty` recreates the buffers for tracking weight amax and scale.
2. Step 1 leaves the buffers pointed to by `Float8Linear.weight._amax_buffer`, etc. orphaned, because those references don't participate in `to_empty`.

I couldn't think of an easy and clean way to auto-fix this, since we can't expect `torch.nn.Module` to know that our logic has multiple references to the same buffer, so I'm exposing a private API for now until we can think of something better. With the current fix, the user can call `_maybe_fixup_delayed_scaling_buffers` manually to relink the buffers to the correct new versions.

Test Plan: CI
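To make the failure mode concrete, here is a torch-free sketch that models it with plain Python objects. `Buffer`, `ToyModule`, `WeightWrapper`, `_scale_buffer`, and `relink_buffers` are hypothetical stand-ins (only `_amax_buffer` comes from the description above), and devices are just strings. It illustrates the aliasing problem under those assumptions; it is not torchao's implementation of `_maybe_fixup_delayed_scaling_buffers`.

```python
class Buffer:
    """Stand-in for a tensor buffer; `device` is just a string here."""

    def __init__(self, device):
        self.device = device


class ToyModule:
    """Hypothetical stand-in for a module that owns delayed-scaling buffers by name."""

    def __init__(self, device):
        self._buffers = {"weight_amax": Buffer(device), "weight_scale": Buffer(device)}

    def to_empty(self, device):
        # Like nn.Module.to_empty, allocate a *new* object for every registered
        # buffer. Only the module's own registry is updated.
        for name in list(self._buffers):
            self._buffers[name] = Buffer(device)


class WeightWrapper:
    """Hypothetical stand-in for the float8 weight, which holds extra references."""

    def __init__(self, module):
        self._amax_buffer = module._buffers["weight_amax"]
        self._scale_buffer = module._buffers["weight_scale"]  # hypothetical name


def relink_buffers(module, wrapper):
    # Hypothetical fixup: re-point the wrapper's references at the module's
    # current buffers, similar in spirit to the private API described above.
    wrapper._amax_buffer = module._buffers["weight_amax"]
    wrapper._scale_buffer = module._buffers["weight_scale"]


module = ToyModule("meta")
weight = WeightWrapper(module)

module.to_empty("cuda")
print(weight._amax_buffer.device)  # meta: the reference is orphaned
print(weight._amax_buffer is module._buffers["weight_amax"])  # False

relink_buffers(module, weight)
print(weight._amax_buffer.device)  # cuda
print(weight._amax_buffer is module._buffers["weight_amax"])  # True
```

Because `to_empty` only rewrites the module's own registry, something outside it has to relink the extra references, which is why the fix takes the form of a manual call after materialization.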
I really don't love this solution, but we could do something like this: pytorch/ao#1292. Thoughts?
Thanks for the fix!
Reopening, as the fix hasn't landed yet :)
Repro: