Using FSDP2 to wrap the Flux (text-to-image) model, gradient is inconsistent with FSDP1 #734
Comments
Can you give some advice? Thanks.
Are you wrapping FSDP1 and FSDP2 in exactly the same way, and is there any parameter sharing for that last layer?
@yanmj0601 Do you have any repro for this that we can look into?
Thank you, this problem has been solved; it was mainly due to a conflict between FSDP2's mixed precision and torch.amp.autocast. But I ran into another problem: while aligning accuracy, I found that when FSDP1 and FSDP2 call torch.nn.LayerNorm, FSDP1's return value is fp32 precision but FSDP2's return value is fp16 precision. Where can I configure the output precision of layer norm in FSDP2?
@yanmj0601 Are you sure you are wrapping FSDP1 and FSDP2 in the same way then? FSDP1 does not have any mechanism to make layer norm return fp32. Perhaps you are wrapping the layer norms separately and using a different mixed precision config for them? I need to clarify which of the following you want (possibly more than one):
Depending on where the layer norms are in your model, I would say that 4 may not make sense. For example, if you have a layer norm before a linear and that linear runs in bf16, then you may as well return the layer norm output in bf16, since it will need to be cast to bf16 for the linear anyway. If you want 1 and 3, then you can apply
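
The maintainer's exact suggestion is cut off above; as a minimal sketch of one way to keep layer-norm parameters and outputs in fp32 under FSDP2, the norm modules can be sharded separately with their own `MixedPrecisionPolicy`. The `model` name is a placeholder, and this assumes the norms carry learnable affine parameters:

```python
# A sketch only, not the code from this thread. Assumes torch >= 2.6
# (earlier releases expose the same API under torch.distributed._composable.fsdp).
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

# Give the layer norms their own policy: keep params in fp32, cast their
# forward inputs to fp32, and return fp32 outputs.
fp32_policy = MixedPrecisionPolicy(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    output_dtype=torch.float32,
    cast_forward_inputs=True,
)

for module in model.modules():          # `model` is a placeholder
    if isinstance(module, nn.LayerNorm):
        fully_shard(module, mp_policy=fp32_policy)
# ...then wrap the transformer blocks and the root with the usual bf16 policy.
```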
FSDP1 probably returns fp32 because my model training runs in the autocast context, but FSDP2 in the autocast context has the vanishing-gradient problem described above. I haven't found the specific reason for this yet. Thanks for your reply; I will try this method first.
One more question: do you recommend using FSDP2 mixed precision together with the autocast context?
@yanmj0601 so I am not sure why FSDP1 vs. FSDP2 would have any difference with autocast. See my comment here: #700 (comment). FSDP1 and FSDP2 mainly provide some convenient configs for inserting dtype casts into the training flow, but there is no magic happening. Autocast happens at the operator level, and FSDP1/2 mixed precision happens at the wrapped-module level. If the wrapping is the same, then FSDP1 and FSDP2 have the same behavior with respect to mixed precision. You can use autocast with FSDP2 mixed precision if you want the complementary behavior -- namely, if you want the autocast upcast-to-fp32 logic for some operators.
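
For reference, a minimal sketch (not from the thread) of that complementary behavior: FSDP2 mixed precision handles parameter/gradient dtypes at the module level, while autocast additionally decides per-operator dtypes (e.g. upcasting layer norm and softmax to fp32). `flux_model`, its forward signature, and the loss are placeholders:

```python
import torch

# Assumes `flux_model` has already been wrapped with fully_shard and a bf16
# MixedPrecisionPolicy (module-level casts). Autocast then applies its
# operator-level dtype rules inside the forward.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = flux_model(latents, text_emb)   # hypothetical forward signature
    loss = out.float().pow(2).mean()      # placeholder loss
loss.backward()                           # backward runs outside autocast
```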
I use register_full_backward_hook to print the grad during backward, like this:
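
The original snippet is not preserved in this capture; below is a minimal sketch of how such a hook might look. `model` and the filter on `nn.Linear` are placeholders:

```python
import torch
import torch.nn as nn

def print_grad_hook(module, grad_input, grad_output):
    # grad_output: gradients w.r.t. the module outputs (reported as consistent
    # below); grad_input: gradients w.r.t. the module inputs.
    go = grad_output[0]
    gi = grad_input[0]
    print(f"{module.__class__.__name__}: "
          f"grad_output norm={go.norm().item():.6e}, "
          f"grad_input norm={gi.norm().item() if gi is not None else None}")

for name, module in model.named_modules():   # `model` is a placeholder
    if isinstance(module, nn.Linear):         # e.g. hook only the linear layers
        module.register_full_backward_hook(print_grad_hook)
```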
But I discovered that the last layer's grad is inconsistent between FSDP1 and FSDP2 (the grad output is consistent).
Below is my code to wrap the Flux model. Currently I'm not using compile or activation checkpointing.
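
The wrapping code is also missing from this capture; a minimal FSDP2 sketch, assuming the Flux transformer exposes `double_blocks`/`single_blocks` (hypothetical attribute names) and that bf16 params with fp32 gradient reduction are wanted:

```python
import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy  # torch >= 2.6

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# Shard each transformer block separately, then the root module, mirroring the
# per-block auto-wrap policy typically used with FSDP1's transformer wrapping.
for block in list(flux_model.double_blocks) + list(flux_model.single_blocks):  # hypothetical attrs
    fully_shard(block, mp_policy=mp_policy)
fully_shard(flux_model, mp_policy=mp_policy)
```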