[TT-TO-TTGWARP] Detect and handle flash attention with causal masking #2013
Conversation
Signed-off-by: Julian Oppermann <[email protected]>
@@ -383,6 +385,51 @@ class ConvertTritonToTritonGPUWarp
      }
      return WalkResult::advance();
    });

    if (loops.size() == 2 && workloads.front() == Workload::Attention &&
Why do there have to be exactly 2 loops?
Flash attention with a causal mask has 2 loops.
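For intuition, here is a minimal scalar sketch of that two-loop structure (illustrative only, not the actual Triton kernel; all names and sizes are made up). For a given query position, KV blocks strictly before its own block are fully visible under the causal mask, while the diagonal block has to mask out future positions:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  constexpr int SEQ = 8, BLOCK = 4;
  const int qIdx = 5; // query position, inside the second KV block
  std::vector<float> s(SEQ), v(SEQ); // toy scores q.k_j and values
  for (int j = 0; j < SEQ; ++j) { s[j] = 0.1f * j; v[j] = float(j); }

  // Streaming-softmax state, as in flash attention.
  float m = -INFINITY, l = 0.0f, acc = 0.0f;
  auto update = [&](float score, float value) {
    float mNew = std::max(m, score);
    float scale = std::exp(m - mNew);
    float p = std::exp(score - mNew);
    l = l * scale + p;
    acc = acc * scale + p * value;
    m = mNew;
  };

  int qBlockStart = (qIdx / BLOCK) * BLOCK;
  // Loop 1: KV blocks fully below the diagonal, no masking needed.
  for (int j = 0; j < qBlockStart; ++j)
    update(s[j], v[j]);
  // Loop 2: the diagonal KV block, future positions masked with -inf.
  for (int j = qBlockStart; j < qBlockStart + BLOCK; ++j)
    update(j <= qIdx ? s[j] : -INFINITY, v[j]);

  std::printf("output for query %d: %f\n", qIdx, acc / l);
}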
(Two review comments on third_party/intel/lib/TritonToTritonGPUWarp/TritonToTritonGPUWarpPass.cpp were marked outdated and resolved.)
    return;

    if (auto cst = dyn_cast<arith::ConstantOp>(op)) {
      transformArithConstantOp(cst, blockLayout);
Here we just assume the type without an encoding has blockLayout, right?
I have a local change that aims to cover the full propagation, which would make most cases hit the early return at L405.
Yes, this IR walk just patches previously unhandled ops, but it is completely specific to causal flash attention. I added a FIXME above to make that clearer. It would be great if we could drop this workaround soon.
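To make that concrete, here is a hedged sketch of the kind of patching such a walk does for constants: attaching blockLayout to a tensor type that has no encoding yet. The helper name and the exact rebuild steps are illustrative, not the actual transformArithConstantOp:

// Illustrative only: rebuild an un-encoded arith.constant so that its
// tensor type carries the block layout.
static void patchConstantSketch(mlir::arith::ConstantOp cst,
                                mlir::Attribute blockLayout) {
  auto ty = mlir::dyn_cast<mlir::RankedTensorType>(cst.getType());
  if (!ty || ty.getEncoding())
    return; // not a ranked tensor, or already encoded
  auto newTy = mlir::RankedTensorType::get(ty.getShape(),
                                           ty.getElementType(), blockLayout);
  // Re-type the dense value; the shape is unchanged and only the
  // encoding differs, so reshape is a no-op on the data.
  auto newVal =
      mlir::cast<mlir::DenseElementsAttr>(cst.getValue()).reshape(newTy);
  mlir::OpBuilder b(cst);
  auto newCst = b.create<mlir::arith::ConstantOp>(cst.getLoc(), newVal);
  cst.getResult().replaceAllUsesWith(newCst);
  cst.erase();
}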
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall LGTM
Signed-off-by: Julian Oppermann <[email protected]>
Detect and handle flash attention with causal masking in -convert-triton-to-tritongpu-warp by supporting tt.make_range and two dependent attention for-loops. See #1947 for more context / the complete PoC.
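And a rough sketch of the tt.make_range part (again illustrative, reusing the op and blockLayout variables from the walk shown above): tt.make_range produces the 1-D offset tensor that the causal mask comparison is built from, so it needs a layout assigned like any other tensor-producing op.

// Illustrative only: give an un-encoded tt.make_range result the block
// layout, so later mask-building ops see a consistent encoding.
if (auto range = llvm::dyn_cast<mlir::triton::MakeRangeOp>(op)) {
  auto ty = mlir::cast<mlir::RankedTensorType>(range.getType());
  if (!ty.getEncoding())
    range.getResult().setType(mlir::RankedTensorType::get(
        ty.getShape(), ty.getElementType(), blockLayout));
}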