
[question] can't disable CP for specific (unsupported) SDPA op #757

Open
FindDefinition opened this issue Dec 20, 2024 · 1 comment
Labels: context_parallel, enhancement (New feature or request)


@FindDefinition

Problem

Currently, the context parallel API has five problems.

  1. CP can only be applied to the whole model. If the model has, e.g., cross-attention with an unsupported shape in its preprocessing part, it's impossible to apply CP, since _context_parallel always overrides every SDPA call and needs to wrap the whole backward pass.
  2. There is no shard/unshard with gradient support. When I try to apply CP to the transformer blocks only and keep the other SDPA calls replicated, context_parallel_unshard in PyTorch has a no_grad decorator (see the sketch after this list).
  3. Weight gradients inside the CP region are divided by the size of the CP mesh because we reduce them over DP+CP. This may be fine for optimizers with gradient-norm support, but it makes unit tests harder to write: we have to scale the gradients back to get the same gradients as the model without CP.
  4. The sequence length must be divisible by the CP degree (CP * 2 for round-robin load balancing).
  5. A replicated input to the CP region may receive a wrong gradient because its gradient may be Partial; we have to check every replicated input and use to_local(grad_placements=[Partial()]) (also sketched below).
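
For items 2 and 5, this is roughly what I have in mind. It is only a sketch built on the public torch.distributed / DTensor APIs, assuming the plain chunked sequence layout (no round-robin load balancing) and eliding the mesh / process-group setup:

```python
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Partial, Replicate


class AllGatherSeq(torch.autograd.Function):
    """Differentiable unshard of the CP-sharded sequence dim (item 2).

    Forward all-gathers the local shard into the full sequence; backward
    reduce-scatters the gradient, so each rank receives the sum of every
    rank's contribution to its own shard."""

    @staticmethod
    def forward(ctx, local_shard, group, seq_dim):
        ctx.group, ctx.seq_dim = group, seq_dim
        world = dist.get_world_size(group)
        gathered = [torch.empty_like(local_shard) for _ in range(world)]
        dist.all_gather(gathered, local_shard.contiguous(), group=group)
        return torch.cat(gathered, dim=seq_dim)

    @staticmethod
    def backward(ctx, grad_out):
        world = dist.get_world_size(ctx.group)
        chunks = [c.contiguous() for c in grad_out.chunk(world, dim=ctx.seq_dim)]
        grad_local = torch.empty_like(chunks[0])
        dist.reduce_scatter(grad_local, chunks, group=ctx.group)
        return grad_local, None, None


def mark_replicated_cp_input(x, cp_mesh):
    """Item 5: a replicated input to the CP region gets a Partial gradient.

    The from_local/to_local round trip is an identity in forward, but telling
    DTensor that the gradient is Partial makes it all-reduce (sum) the
    gradient over the CP mesh in backward instead of silently keeping the
    local partial value."""
    return DTensor.from_local(x, cp_mesh, [Replicate()]).to_local(
        grad_placements=[Partial()]
    )
```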

To resolve problem 1 above, I remove the context_parallel context so the SDPA override is disabled and only enable the _enable_cp_dispatcher context; CP SDPA then kicks in iff all inputs are converted to DTensor (a sketch of this follows below). Problem 2 is easy to resolve: just write some autograd functions.
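
A minimal sketch of that workaround for a single attention call, assuming a 2D dp x cp device mesh (dp_size and cp_size are placeholders) and the private _enable_cp_dispatcher context manager from torch.distributed.tensor.experimental._attention; since it is not a public API, the exact name, signature, and import path may differ between PyTorch versions:

```python
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
from torch.distributed.tensor.experimental._attention import _enable_cp_dispatcher

# Hypothetical 2D mesh: outer DP dim, inner CP dim.
mesh = init_device_mesh("cuda", (dp_size, cp_size), mesh_dim_names=("dp", "cp"))
cp_mesh = mesh["cp"]


def cp_sdpa(q, k, v):
    # q/k/v: (batch, heads, local_seq, head_dim), already sharded on dim 2
    # across the CP mesh. Wrapping them as DTensors sharded on the sequence
    # dim and enabling the CP dispatcher routes this SDPA call to the
    # ring-attention implementation; SDPA calls elsewhere that only ever see
    # plain tensors stay untouched.
    qd = DTensor.from_local(q, cp_mesh, [Shard(2)])
    kd = DTensor.from_local(k, cp_mesh, [Shard(2)])
    vd = DTensor.from_local(v, cp_mesh, [Shard(2)])
    with _enable_cp_dispatcher():
        out = F.scaled_dot_product_attention(qd, kd, vd, is_causal=True)
    return out.to_local()
```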

Here are my questions:

  1. Is there a better way to support a CP region?
  2. Do you have any plan to support CP regions officially and resolve the issues above?
@fegin
Contributor

fegin commented Dec 20, 2024

We do see some requests to apply CP to only some regions. This will require some communication before SDPA, as we will have to borrow some ranks from the DP sharding dimensions. Current CP doesn't support this for sure. We will discuss how and what is the best way to support this feature.

fegin added the enhancement (New feature or request) label Dec 20, 2024