
Using the FSDP2 wrapper on the Flux (text-to-image) model, the gradient is inconsistent with FSDP1 #734

Open · yanmj0601 opened this issue Dec 13, 2024 · 9 comments
Labels: question (Further information is requested)

yanmj0601 commented Dec 13, 2024

I use register_full_backward_hook to print gradients during the backward pass, like this:

def print_grad_hook(name):
    # Print the gradients flowing into and out of each top-level submodule
    # during the backward pass.
    def hook(module, grad_input, grad_output):
        print(f"Layer Name: {name}, Grad input: {grad_input}, Grad output: {grad_output}")
    return hook

for name, layer in model.named_children():
    layer.register_full_backward_hook(print_grad_hook(name))

But I found that the last layer's grad input is inconsistent between FSDP1 and FSDP2 (the grad output is consistent).

fsdp1 grad:
Layer Name: proj_out,Grad input: (tensor([[[-1.4901e-08,  2.2445e-07,  5.4250e-08,   ...,  3.7812e-07,
           4.0606e-07, -3.8184e-07]]], device='cuda:0'),),Grad output: (tensor([[[-2.3991e-06,  2.3693e-06,  1.3947e-05,  ..., 
           4.0233e-07,  8.0466e-07]]], device='cuda:0', dtype=torch.bfloat16),)

fsdp2 grad:
Layer Name: proj_out,Grad input: (tensor([[[-0.0000e+00,  2.3842e-07,  5.9605e-08,  ...,  8.9407e-07,
           4.1723e-07, -3.5763e-07]]], device='cuda:0'),),Grad output: (tensor([[[-2.3991e-06,  2.3693e-06,  1.3947e-05,  ..., 
           4.0233e-07,  8.0466e-07]]], device='cuda:0', dtype=torch.bfloat16),)

Below is my code to wrap the Flux model. Currently I am not using compile or activation checkpointing:

for layer_id, transformer_block in model.transformer_blocks.named_children():
    if pp_enabled:
        # For PP, do not reshard after forward to avoid per-microbatch
        # all-gathers, which can be expensive and non-overlapped
        reshard_after_forward = False
    else:
        # As an optimization, do not reshard after forward for the last
        # transformer block since FSDP would prefetch it immediately
        reshard_after_forward = True
    fully_shard(
        transformer_block,
        **fsdp_config,
        reshard_after_forward=reshard_after_forward,
    )
for layer_id, transformer_block in model.single_transformer_blocks.named_children():
    if pp_enabled:
        # For PP, do not reshard after forward to avoid per-microbatch
        # all-gathers, which can be expensive and non-overlapped
        reshard_after_forward = False
    else:
        # As an optimization, do not reshard after forward for the last
        # transformer block since FSDP would prefetch it immediately
        reshard_after_forward = int(layer_id) < len(model.single_transformer_blocks) - 1
    fully_shard(
        transformer_block,
        **fsdp_config,
        reshard_after_forward=reshard_after_forward,
    )
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
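
(For reference, `fsdp_config` is not shown above; with FSDP2 it is typically just a dict of fully_shard keyword arguments along the lines below. This is an illustrative sketch with assumed values, not the exact config used in this issue.)

# Illustrative only: a typical FSDP2 `fsdp_config` dict passed to fully_shard.
# The actual mesh and mixed precision policy used here may differ.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

fsdp_config = {
    "mesh": init_device_mesh("cuda", (torch.distributed.get_world_size(),)),
    "mp_policy": MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,   # all-gather / compute params in bf16
        reduce_dtype=torch.float32,   # reduce-scatter gradients in fp32
    ),
}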
yanmj0601 (Author) commented

Can you give some advice? Thanks.


awgu commented Dec 13, 2024

Are you wrapping FSDP1 and FSDP2 in exactly the same way, and is there any parameter sharing for that last layer?

tianyu-l added the question (Further information is requested) label on Dec 13, 2024
yanmj0601 (Author) commented

> Are you wrapping FSDP1 and FSDP2 in exactly the same way, and is there any parameter sharing for that last layer?

Yes, I am wrapping FSDP1 and FSDP2 in the same way, and there is no parameter sharing.
In FSDP2, there seems to be a vanishing gradient problem (FSDP1 grad: left, FSDP2 grad: right).
[Screenshot 2024-12-16 14:33:54: side-by-side gradient dump, FSDP1 (left) vs. FSDP2 (right)]


awgu commented Dec 16, 2024

@yanmj0601 Do you have any repro for this that we can look into?

yanmj0601 (Author) commented

> @yanmj0601 Do you have any repro for this that we can look into?

Thank you, this problem has been solved; it was mainly caused by a conflict between FSDP2's mixed precision and torch.amp.autocast. But I ran into another problem. While aligning the numerics, I found that when FSDP1 and FSDP2 call torch.nn.LayerNorm, FSDP1's return value is in fp32, but FSDP2's return value is in fp16. Where can I configure the output precision of LayerNorm in FSDP2?
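
(To check this, one quick option is a forward hook on the norm modules, in the same spirit as the backward hook above. A rough sketch, assuming `model` is the wrapped Flux transformer:)

# Sketch: print input/output dtypes of every LayerNorm during the forward pass,
# to compare what FSDP1 vs. FSDP2 actually feeds into and out of the norms.
import torch

def print_dtype_hook(name):
    def hook(module, inputs, output):
        print(f"{name}: input dtype {inputs[0].dtype}, output dtype {output.dtype}")
    return hook

for name, module in model.named_modules():
    if isinstance(module, torch.nn.LayerNorm):
        module.register_forward_hook(print_dtype_hook(name))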


awgu commented Dec 17, 2024

@yanmj0601 Are you sure you are wrapping FSDP1 and FSDP2 in the same way then? FSDP1 does not have any mechanism to make layer norm return fp32. Perhaps you are wrapping the layer norms separately and using a different mixed precision config for them?

I need you to clarify which of the following you want (possibly more than one):

  1. Cast layer norm input from bf16 to fp32
  2. Cast layer norm weight from bf16 to fp32
  3. All-gather layer norm weight directly in fp32 (avoiding any cast from bf16 to fp32)
  4. Cast layer norm output from bf16 to fp32

Depending on where the layer norms are in your model, I would say that 4 may not make sense. For example, if you have a layer norm before a linear and that linear runs in bf16, then you may as well return the layer norm output in bf16, since it will need to be cast to bf16 for the linear anyway.

If you want 1 and 3, then you can apply fully_shard to the layer norm modules separately first, e.g. like here (except this example is for batch norm):
https://github.com/pytorch/pytorch/blob/ea0f60ecfabe0501485015841c4176d5a09c8247/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py#L543-L549
FSDP2's MixedPrecisionPolicy supports an optional output_dtype that you can use to cast your layer norm output back to bf16 if needed.
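
A rough sketch of that wrapping order (untested, with illustrative module names and policies; it assumes the `transformer_blocks` structure from the snippet above):

# Sketch: give the norm modules their own fp32 FSDP2 group before wrapping the
# parent block, so their weights are all-gathered in fp32 (option 3) and their
# inputs are cast to fp32 (option 1); output_dtype casts the result back to bf16.
import torch
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

bf16_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fp32_norm_policy = MixedPrecisionPolicy(
    param_dtype=torch.float32,      # all-gather norm weights directly in fp32
    reduce_dtype=torch.float32,
    output_dtype=torch.bfloat16,    # cast the norm output back to bf16 if needed
    cast_forward_inputs=True,       # cast the norm input to fp32
)

for transformer_block in model.transformer_blocks:
    # Wrap norms (with affine params) first so they get their own fp32 group ...
    for submodule in transformer_block.modules():
        if isinstance(submodule, torch.nn.LayerNorm) and submodule.elementwise_affine:
            fully_shard(submodule, mp_policy=fp32_norm_policy)
    # ... then wrap the rest of the block with the usual bf16 policy.
    fully_shard(transformer_block, mp_policy=bf16_policy)
fully_shard(model, mp_policy=bf16_policy)

In practice you would also pass the rest of your fsdp_config (mesh, reshard_after_forward, etc.) alongside mp_policy, as in the wrapping code earlier in this thread.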

yanmj0601 (Author) commented

> @yanmj0601 Are you sure you are wrapping FSDP1 and FSDP2 in the same way then? FSDP1 does not have any mechanism to make layer norm return fp32. Perhaps you are wrapping the layer norms separately and using a different mixed precision config for them?
> […]

FSDP1 probably returns fp32 because my model training runs inside the autocast context, but FSDP2 inside the autocast context hits the vanishing gradient problem described above. I haven't found the specific reason for this yet. Thanks for your reply; I will try this method first.

yanmj0601 (Author) commented

One more question: do you recommend using FSDP2 mixed precision together with the autocast context?


awgu commented Dec 18, 2024

@yanmj0601 I am not sure why FSDP1 vs. FSDP2 would have any difference with autocast. See my comment here: #700 (comment)

FSDP1 and FSDP2 mainly provide some convenient configs for inserting dtype casts into the training flow, but there is no magic happening. Autocast happens on the operator level, and FSDP1/2 mixed precision happens at the wrapped module level. If the wrapping is the same, then FSDP1 and FSDP2 have the same behavior with respect to mixed precision.

You can use autocast with FSDP2 mixed precision if you want the complementary behavior -- namely, if you want the autocast upcast-to-fp32 logic for some operators.
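
Roughly, combining the two looks like this (a sketch with assumed names, not your exact training loop):

# Sketch: FSDP2 mixed precision handles dtype at the module boundary
# (bf16 params for compute, fp32 gradient reduction), while autocast decides
# per-operator dtypes inside the forward (e.g. keeping some ops in fp32).
import torch
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
for transformer_block in model.transformer_blocks:
    fully_shard(transformer_block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

with torch.autocast("cuda", dtype=torch.bfloat16):
    output = model(batch)            # `batch` and the loss below are placeholders
    loss = output.float().mean()
loss.backward()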
