Vote on new features in Discussions #694

Open
tianyu-l opened this issue Nov 23, 2024 · 16 comments

Comments

@tianyu-l
Contributor

Hi torchtitanists,

Thank you for your interest in torchtitan!

We created #693 for the community to add feature requests and vote on them. We'll try to prioritize the most requested features. Please share what you'd like to see next!

@zigzagcaiseason

zigzagcaiseason commented Nov 28, 2024

@tianyu-l

Hi developers,

First, thanks for the great work demonstrating the power of PyTorch's newly released features!

I have one question about the usage of FSDP2 fully_shard.
Does FSDP2 support mixed precision within one wrapping module, such as torch.float32 and torch.bfloat16 within one FSDPParamGroup?

To put it more clearly: in most LLM training use cases, such as Llama 2, the precision of RMSNorm is usually torch.float32, but the other components within the DecoderLayer are usually torch.bfloat16. When we train the model with FSDP2, we have to wrap RMSNorm separately since it has a separate dtype, which introduces additional all-gathers and reduce-scatters.

From the profiling results, we found that this approach (wrapping RMSNorm separately) leads to poor computation-communication overlap, especially in the backward pass.

Apart from that, there are other use cases: the dtype of MoE gating layers is required to be torch.float32, but the other components in the DecoderLayer are torch.bfloat16. We also found that separately wrapping MoE.GateLayer causes poor computation-communication overlap.

So, is mixed precision within an FSDPParamGroup supported? Or could this be a new feature in the future?

Thanks!

@mayank31398

@zigzagcai RMSNorm only has activations in fp32; the weights are still bf16.
Also, I believe FSDP2 is quite flexible in allowing different dtypes for different tensors.
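To make the activation/weight distinction concrete, here is a minimal sketch of a Llama-style RMSNorm (modeled on the widely used reference pattern, not copied from torchtitan) in which only the activations are upcast to fp32 for the reduction, while the learnable weight is stored in bf16. The class and its defaults are illustrative assumptions.

```python
import torch


class RMSNorm(torch.nn.Module):
    """RMSNorm with a bf16 weight; the normalization math runs in fp32."""

    def __init__(self, dim: int, eps: float = 1e-6, dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.eps = eps
        # The learnable scale is stored in the low-precision dtype.
        self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast activations to fp32 for the mean-of-squares reduction.
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast back to the input dtype, then apply the bf16 scale.
        return normed.type_as(x) * self.weight


norm = RMSNorm(8)
x = torch.randn(2, 8, dtype=torch.bfloat16)
y = norm(x)
print(norm.weight.dtype, y.dtype)  # torch.bfloat16 torch.bfloat16
```

With this pattern the fp32 requirement applies to the intermediate math, not to the stored parameter.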

@tianyu-l
Contributor Author

tianyu-l commented Dec 2, 2024

cc: @awgu

@aniltrkkn

It should be simple, but:

Gradient Accumulation

It is very useful for SFTing (supervised fine-tuning) big models.

@samsja
Contributor

samsja commented Dec 2, 2024

Gradient accumulation is not that worthwhile with full sharding, since you need to all-gather the weights in every forward pass anyway.

Though yeah, it could still make sense to have it.

@awgu
Contributor

awgu commented Dec 2, 2024

@samsja, with FSDP2 you can avoid the per-microbatch all-gather/reduce-scatter using (hopefully) intuitive APIs:

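# microbatches: list of (inputs, labels) pairs for one optimizer step
num_microbatches = len(microbatches)
for microbatch_idx, (inputs, labels) in enumerate(microbatches):
    is_last_microbatch = microbatch_idx == num_microbatches - 1
    model.set_requires_gradient_sync(is_last_microbatch)  # only reduce-scatter on last microbatch
    model.set_reshard_after_backward(is_last_microbatch)  # only all-gather on 1st microbatch
    # Run forward/backward (loss_fn is a placeholder for your loss function)
    loss = loss_fn(model(inputs), labels) / num_microbatches
    loss.backward()
optim.step()
optim.zero_grad()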

This will use extra memory since unsharded parameters and gradients are held through forward and backward (roughly equivalent to ZeRO-1).
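A gradient-accumulation loop commonly scales each microbatch loss by 1/num_microbatches. The reason can be checked with plain arithmetic: for a mean-reduced loss and equal-sized microbatches, summing the scaled microbatch gradients reproduces the full-batch gradient. The sketch below assumes a hypothetical one-parameter least-squares model and made-up data, chosen only for illustration.

```python
import math


def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, num_microbatches):
    """Sum per-microbatch gradients of a loss scaled by 1 / num_microbatches."""
    size = len(xs) // num_microbatches  # assumes equal-sized microbatches
    total = 0.0
    for i in range(num_microbatches):
        mb_x = xs[i * size:(i + 1) * size]
        mb_y = ys[i * size:(i + 1) * size]
        # Scaling the loss by 1 / num_microbatches scales its gradient the same way.
        total += grad_mse(w, mb_x, mb_y) / num_microbatches
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0, 3.9, 6.2, 8.1, 9.8, 12.3]
w = 0.5
full = grad_mse(w, xs, ys)
accum = accumulated_grad(w, xs, ys, num_microbatches=3)
print(math.isclose(full, accum, rel_tol=1e-12))  # True
```

Without the scaling, the accumulated gradient would be num_microbatches times larger than the full-batch one, which effectively changes the learning rate.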

@samsja
Contributor

samsja commented Dec 2, 2024

Hmm, nice, I did not know that set_reshard_after_backward was a thing.

Yeah, so gradient accumulation makes sense with ZeRO-1 but not with ZeRO-2.

@awgu
Contributor

awgu commented Dec 3, 2024

@zigzagcai sorry for the delay -- I was out last week.

> Does FSDP2 support mixed precision within one wrapping module, such as torch.float32 and torch.bfloat16 within one FSDPParamGroup?

This is not well-supported (at least not simply). Part of this is an API design question that trades off against performance. For example, how would you specify which parameters in the parameter group use fp32 vs. bf16? (Let me know if you have ideas here.)

> To put it more clearly: in most LLM training use cases, such as Llama 2, the precision of RMSNorm is usually torch.float32, but the other components within the DecoderLayer are usually torch.bfloat16. When we train the model with FSDP2, we have to wrap RMSNorm separately since it has a separate dtype, which introduces additional all-gathers and reduce-scatters.

> From the profiling results, we found that this approach (wrapping RMSNorm separately) leads to poor computation-communication overlap, especially in the backward pass.

The reason is that FSDP2's default prefetching algorithm effectively allows only one in-flight all-gather at a time in backward. This leads to poor overlap, like you saw, when the all-gather sizes alternate between small and large, since we cannot overlap the transformer block all-gather with just the RMSNorm backward, for example.
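The small/large alternation effect can be illustrated with a toy timeline model. It is not FSDP2's actual scheduler: it only encodes the constraint that at most one all-gather is in flight during backward, with the next unit's all-gather issued when the current unit starts its backward compute. The function name and the unit times (arbitrary units) are hypothetical.

```python
def exposed_comm_time(comm, compute):
    """Toy model: exposed all-gather time in backward with one all-gather in flight.

    comm[i] and compute[i] are the all-gather time and backward compute time of
    the i-th FSDP unit, listed in backward execution order. Illustrative
    assumptions: one communication stream, the first all-gather cannot be
    overlapped, and unit i + 1's all-gather is issued when unit i starts computing.
    """
    exposed = comm[0]    # the first all-gather has nothing to hide behind
    now = comm[0]        # current time on the compute stream
    comm_done = comm[0]  # finish time of the most recently issued all-gather
    for i in range(len(compute)):
        has_next = i + 1 < len(comm)
        if has_next:
            # Prefetch unit i + 1 while unit i runs its backward.
            comm_done = max(now, comm_done) + comm[i + 1]
        now += compute[i]
        if has_next and comm_done > now:
            # Unit i + 1 must wait for its parameters: this wait is exposed.
            exposed += comm_done - now
            now = comm_done
    return exposed


# Large units only (e.g. transformer blocks): all-gather 4, compute 5.
print(exposed_comm_time([4, 4, 4], [5, 5, 5]))  # 4
# Large units alternating with small ones (e.g. separately wrapped norms).
print(exposed_comm_time([4, 1, 4, 1], [5, 0.5, 5, 0.5]))  # 7.5
```

With large units only, just the first all-gather is exposed. When a small unit sits between two large ones, its short backward is all there is to hide the next block's all-gather behind, so most of that all-gather becomes exposed.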

FSDP2 exposes some manual APIs to configure the prefetching that can help here. I will need to find some time to put together an example. Let me see if I can do it tomorrow or later this week. Mainly, you can use set_modules_to_forward_prefetch and set_modules_to_backward_prefetch to overwrite the default prefetching schedule.
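As a rough illustration of overriding the default schedule, the sketch below calls set_modules_to_forward_prefetch and set_modules_to_backward_prefetch on a list of FSDP units in forward order, so that each unit prefetches up to num_to_prefetch neighbors. The helper configure_prefetch, the num_to_prefetch knob, and the FakeUnit stand-in (used so the schedule can be inspected without GPUs) are hypothetical; in real use the list would hold modules already wrapped with fully_shard.

```python
def configure_prefetch(units, num_to_prefetch=1):
    """Explicitly set FSDP2 forward/backward prefetching for a list of units.

    `units` holds modules already wrapped with fully_shard, in forward order.
    Each unit prefetches up to `num_to_prefetch` neighbors: the following units
    in forward and the preceding units (nearest first) in backward.
    """
    for i, unit in enumerate(units):
        fwd = units[i + 1 : i + 1 + num_to_prefetch]
        if fwd:
            unit.set_modules_to_forward_prefetch(fwd)
        bwd = units[max(0, i - num_to_prefetch) : i][::-1]
        if bwd:
            unit.set_modules_to_backward_prefetch(bwd)


class FakeUnit:
    """Hypothetical stand-in that records the prefetch lists it is given."""

    def __init__(self, name):
        self.name = name
        self.fwd = []
        self.bwd = []

    def set_modules_to_forward_prefetch(self, modules):
        self.fwd = [m.name for m in modules]

    def set_modules_to_backward_prefetch(self, modules):
        self.bwd = [m.name for m in modules]


units = [FakeUnit(n) for n in ["block0", "norm0", "block1", "norm1"]]
configure_prefetch(units, num_to_prefetch=2)
for u in units:
    print(u.name, "fwd:", u.fwd, "bwd:", u.bwd)
```

With num_to_prefetch=2, a unit issues the all-gathers of both the adjacent small norm and the next large block, rather than only the norm's. Prefetching more units keeps more unsharded parameters alive at once, so it trades memory for overlap.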

> Apart from that, there are other use cases: the dtype of MoE gating layers is required to be torch.float32, but the other components in the DecoderLayer are torch.bfloat16. We also found that separately wrapping MoE.GateLayer causes poor computation-communication overlap.

Do you mean that the MoE router weight must be in fp32? I do want to clarify the use case somewhat (though the feature request is still valid). For the cases I have seen, using bf16 RMSNorm weights and bf16 router weights is sufficient. The computation kernel can upcast intermediates as needed, but that does not mean the weight itself needs to be in fp32.
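A minimal sketch of upcasting only the intermediates in an MoE router, assuming a simple top-k softmax router (the Router class, its initialization, and its shapes are illustrative, not torchtitan's MoE code): the weight is stored in bf16, and only the logits and softmax are computed in fp32.

```python
import torch
import torch.nn.functional as F


class Router(torch.nn.Module):
    """Top-k softmax router with a bf16 weight and fp32 logits."""

    def __init__(self, dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        # The stored weight stays in bf16.
        self.weight = torch.nn.Parameter(
            torch.randn(num_experts, dim, dtype=torch.bfloat16) * dim ** -0.5
        )

    def forward(self, x: torch.Tensor):
        # Upcast inputs and weight only for this computation.
        logits = F.linear(x.float(), self.weight.float())
        probs = logits.softmax(dim=-1)
        scores, expert_ids = probs.topk(self.top_k, dim=-1)
        return scores, expert_ids


router = Router(dim=16, num_experts=8)
x = torch.randn(4, 16, dtype=torch.bfloat16)
scores, expert_ids = router(x)
print(router.weight.dtype, scores.dtype, tuple(expert_ids.shape))  # torch.bfloat16 torch.float32 (4, 2)
```

Because the `.float()` casts are differentiable, gradients still flow back to the bf16 weight.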

@zyushun

zyushun commented Dec 3, 2024

@aniltrkkn Thanks for the great suggestion! I also agree that gradient accumulation is quite important. I implemented a version myself; I hope it helps a bit.

See line 517 of https://github.com/zyushun/Adam-mini/blob/main/examples/llama/train.py

@zyushun

zyushun commented Dec 3, 2024

@tianyu-l Thanks for organizing this great discussion! I have one request, though I am not sure whether it already exists: is there demo code that converts a saved checkpoint into the Hugging Face Transformers format? That would be quite useful for downstream evaluation or further SFT and RL.

@samsja
Contributor

samsja commented Dec 3, 2024

There is a conversion script that should be compatible with torchtitan here: https://github.com/PrimeIntellect-ai/prime/blob/main/scripts/export_dcp.py

@tianyu-l
Contributor Author

tianyu-l commented Dec 3, 2024

@zyushun I agree such a script is desirable but missing.
Alternatively, you may try following the "DCP -> torch.save -> HF" flow noted in #420 (comment).

@awgu
Contributor

awgu commented Dec 4, 2024

@zigzagcai do you have a repro of your setup? On Llama3-8B with 8xH100s, I see some exposed collectives (e.g., the norm reduce-scatter), but it should not be detrimental to training throughput:

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 4d4c60b..95b1871 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -339,10 +339,16 @@ def apply_fsdp(
     Apply data parallelism to the model. FSDP2 is used here.
     """
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+    fp32_policy = MixedPrecisionPolicy(output_dtype=torch.bfloat16)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+    fp32_config = {"mesh": dp_mesh, "mp_policy": fp32_policy}
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()
 
+    for module_name, module in model.named_modules():
+        if "norm" in module_name:
+            fully_shard(module, **fp32_config)
+
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch

@nighting0le01

Add tests/support for low-bit optimizers and FlashAttention-3.
