Problem

Currently, the context parallel (CP) API has five problems.
Problem 1: CP can only be applied to the whole model. If the preprocessing part of the model contains cross-attention with an unsupported shape, it is impossible to apply CP, because `_context_parallel` always overrides every SDPA call and has to wrap the entire backward pass.
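For context, this is roughly how CP has to be applied today: the whole forward and backward wrapped in a single `context_parallel` region (a minimal sketch; `model`, `inputs`, and `cp_mesh` are placeholders):

```python
from torch.distributed.tensor.experimental import context_parallel

def run_whole_model_cp(model, inputs, cp_mesh):
    # context_parallel shards `inputs` along the sequence dim and overrides
    # *every* SDPA call in the region -- there is no way to leave e.g. the
    # cross-attention in the preprocessing stage un-overridden.
    with context_parallel(cp_mesh, buffers=[inputs], buffer_seq_dims=[1]):
        out = model(inputs)
        out.sum().backward()  # the backward pass must also stay inside the region
    return out
```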
Problem 2: there is no shard/unshard with gradient support. When I try to apply CP to the transformer blocks only and keep the other SDPA calls replicated, I cannot use `context_parallel_unshard`, because the version in PyTorch has a `no_grad` decorator.
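A gradient-aware unshard ends up being a small custom autograd function. Below is a minimal sketch (my own illustrative code, not a PyTorch API), assuming a plain sequential shard layout, i.e. no round-robin load balancing:

```python
import torch
import torch.distributed as dist

class _UnshardSeq(torch.autograd.Function):
    """All-gather the local sequence shard in forward; in backward, hand each
    rank only the gradient slice that corresponds to its local shard."""

    @staticmethod
    def forward(ctx, local_chunk, group, seq_dim):
        ctx.group, ctx.seq_dim = group, seq_dim
        world = dist.get_world_size(group)
        gathered = [torch.empty_like(local_chunk) for _ in range(world)]
        dist.all_gather(gathered, local_chunk.contiguous(), group=group)
        return torch.cat(gathered, dim=seq_dim)

    @staticmethod
    def backward(ctx, grad_out):
        world = dist.get_world_size(ctx.group)
        rank = dist.get_rank(ctx.group)
        local_grad = grad_out.chunk(world, dim=ctx.seq_dim)[rank]
        return local_grad.contiguous(), None, None

def unshard_with_grad(local_chunk, group, seq_dim=1):
    return _UnshardSeq.apply(local_chunk, group, seq_dim)
```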
Problem 3: weight gradients inside the CP region are divided by the size of the CP mesh, because they are reduced over the combined DP+CP mesh. This may be fine for an optimizer with gradient-norm support, but it makes unit tests harder to write: we have to scale the gradients back to get the same values as the model without CP.
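For reference, this is the kind of rescaling the unit test currently has to do (a sketch with illustrative names): multiply the CP model's gradients by the CP mesh size before comparing them against the non-CP reference model.

```python
import torch

def assert_grads_match(cp_model, ref_model, cp_mesh, rtol=1e-5, atol=1e-5):
    # Gradients reduced over the combined DP+CP mesh come back divided by
    # cp_mesh.size(), so scale them back before comparing.
    scale = cp_mesh.size()
    for (name, p_cp), (_, p_ref) in zip(
        cp_model.named_parameters(), ref_model.named_parameters()
    ):
        torch.testing.assert_close(
            p_cp.grad * scale, p_ref.grad, rtol=rtol, atol=atol, msg=name
        )
```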
Problem 4: the sequence length must be divisible by the CP size (by CP size * 2 for the round-robin load balancer).
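As a workaround, callers currently have to pad the sequence themselves; a hypothetical helper could look like this (assuming right-padding on the sequence dim, with the padding masked or trimmed afterwards):

```python
import torch.nn.functional as F

def pad_to_cp_multiple(x, cp_size, seq_dim=1, round_robin=True):
    # The round-robin load balancer splits the sequence into 2 * cp_size chunks.
    multiple = 2 * cp_size if round_robin else cp_size
    pad_len = (-x.size(seq_dim)) % multiple
    if pad_len == 0:
        return x, 0
    # F.pad takes (left, right) pairs starting from the last dim.
    pad = [0, 0] * (x.dim() - seq_dim - 1) + [0, pad_len]
    return F.pad(x, pad), pad_len
```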
Problem 5: a replicated input of the CP region may receive a wrong gradient, because its gradient may actually be `Partial`. We have to check every replicated input and use `to_local(grad_placements=[Partial()])`.
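A sketch of the check this forces on the caller (the helper name is mine; `DTensor`/`Partial` live under `torch.distributed.tensor` in recent releases and `torch.distributed._tensor` in older ones):

```python
from torch.distributed.tensor import DTensor, Partial

def to_local_with_partial_grad(t):
    # A replicated input whose gradient is really a sum of per-rank partial
    # gradients must declare that via grad_placements; otherwise each rank
    # silently keeps only its own partial contribution.
    if isinstance(t, DTensor) and all(p.is_replicate() for p in t.placements):
        return t.to_local(grad_placements=[Partial()] * t.device_mesh.ndim)
    return t
```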
To resolve problem 1 above, I drop the `context_parallel` context to disable the SDPA override and only enable the `_enable_cp_dispatcher` context; then CP SDPA kicks in iff all SDPA inputs are converted to DTensors. Problem 2 is easy to resolve: just write some autograd functions.
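A minimal sketch of that workaround, assuming a 1-D CP mesh, q/k/v already sharded along the sequence dim, and the (B, H, S, D) SDPA layout; note that `_enable_cp_dispatcher` is a private API and its import path (`torch.distributed.tensor.experimental._attention` here) may differ between PyTorch versions:

```python
import torch.nn.functional as F
from torch.distributed.tensor import DTensor, Shard
from torch.distributed.tensor.experimental._attention import _enable_cp_dispatcher

def cp_region_sdpa(q, k, v, cp_mesh, seq_dim=2):
    # Wrap the local shards as DTensors sharded on the sequence dim; with the
    # dispatcher enabled, SDPA on DTensor inputs is routed to the CP ring
    # attention implementation, while plain-tensor SDPA elsewhere is untouched.
    q, k, v = (
        DTensor.from_local(t, cp_mesh, [Shard(seq_dim)], run_check=False)
        for t in (q, k, v)
    )
    with _enable_cp_dispatcher():
        out = F.scaled_dot_product_attention(q, k, v)
    return out.to_local()
```

The shard/unshard at the region boundary then only needs gradient-aware gather/slice functions like the one sketched under problem 2.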
Here are my questions:
1. Is there a better way to support a CP region?
2. Do you have any plan to support a CP region officially and resolve the issues above?
We do see some requests to apply CP to only some regions. This will require some communication before SDPA, as we will have to borrow some ranks from the DP sharding dimensions. Current CP definitely doesn't support this. We will discuss how and what is the best way to support this feature.