
[TT-TO-TTGWARP] Detect and handle flash attention with causal masking #2013

Merged: 2 commits merged into llvm-target on Sep 4, 2024

Conversation

@jopperm (Contributor) commented on Aug 27, 2024

Detect and handle flash attention with causal masking in -convert-triton-to-tritongpu-warp by supporting tt.make_range and two dependent attention-for-loops.

See #1947 for more context and the complete PoC.
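
To make the detection criterion concrete, the following is a minimal C++ sketch of that check; apart from Workload::Attention and the loops.size() == 2 condition, which appear in the diff below, all names (including the second conjunct and the helper signature) are illustrative assumptions rather than the actual pass code.

// Illustrative sketch only; not the actual ConvertTritonToTritonGPUWarp code.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/ArrayRef.h"

// Simplified stand-in for the pass's per-loop workload classification.
enum class Workload { None, Gemm, Attention };

// Causal flash attention is recognized by exactly two dependent loops that
// were both classified as attention workloads: the first iterates over
// key/value blocks that are fully visible under the causal mask, the second
// over the blocks on the diagonal where the mask (built via tt.make_range)
// applies.
static bool looksLikeCausalFlashAttention(llvm::ArrayRef<mlir::scf::ForOp> loops,
                                          llvm::ArrayRef<Workload> workloads) {
  return loops.size() == 2 &&
         workloads.front() == Workload::Attention &&
         workloads.back() == Workload::Attention;
}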

@jopperm self-assigned this on Aug 27, 2024
@jopperm changed the title from "Detect and handle flash attention with causal masking" to "[TT-TO-TTGWARP] Detect and handle flash attention with causal masking" on Aug 27, 2024
Review thread on test/Conversion/intel/triton_to_tritongpu_warp.mlir (outdated, resolved)
@@ -383,6 +385,51 @@ class ConvertTritonToTritonGPUWarp
}
return WalkResult::advance();
});

if (loops.size() == 2 && workloads.front() == Workload::Attention &&
Contributor:

Why does the loop size have to be 2?

Contributor:

Flash attention with a causal mask has 2 loops.
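
To spell out where the two loops come from, here is a plain C++ sketch of the block-level iteration structure of a causal flash-attention forward pass; it is an illustration under the assumption that the query tile size is a multiple of the key tile size, not Triton code and not part of this PR.

#include <algorithm>

// Hypothetical names throughout; sketches the two inner loops only.
void attentionForQueryTile(int qStart, int blockM, int blockN, int seqLen) {
  // Loop 1: key/value blocks lying entirely before the query tile are fully
  // visible under the causal mask, so no per-element masking is needed.
  for (int k = 0; k < qStart; k += blockN) {
    // online-softmax update with an unmasked QK^T tile
  }
  // Loop 2: key/value blocks overlapping the diagonal; here the kernel builds
  // an element-wise mask from row/column offsets (tt.make_range) and applies
  // it before the softmax update.
  for (int k = qStart; k < std::min(qStart + blockM, seqLen); k += blockN) {
    // online-softmax update with a masked QK^T tile
  }
}

The second loop starts where the first one ends, which is why the pass sees two dependent attention loops rather than one.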

return;

if (auto cst = dyn_cast<arith::ConstantOp>(op)) {
transformArithConstantOp(cst, blockLayout);
Contributor:

Here we just assume that a type without an encoding gets blockLayout, right?
I have a local change that aims to cover the full propagation, making most of the cases hit the early return at L405.

Contributor (PR author):

Yes, this IR walk just patches previously unhandled ops, but it is completely specific to causal flash attention. I added a FIXME above to make that clearer, so it would be great if we could drop this workaround soon.
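
For readers following along, here is a rough sketch of what attaching the block layout to an unencoded tensor constant could look like; it is an assumption-laden illustration of the idea under discussion, not the actual transformArithConstantOp implementation.

// Hypothetical sketch; the real logic lives in transformArithConstantOp.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static void attachBlockLayout(arith::ConstantOp cst, Attribute blockLayout) {
  auto tensorTy = dyn_cast<RankedTensorType>(cst.getType());
  if (!tensorTy || tensorTy.getEncoding())
    return; // not a tensor constant, or it already carries a layout
  // Assumption discussed above: a tensor type without an encoding is given
  // the block layout of the surrounding attention workload.
  auto newTy = RankedTensorType::get(tensorTy.getShape(),
                                     tensorTy.getElementType(), blockLayout);
  // Rebuild the dense value attribute with the encoded type and mutate the
  // result type in place.
  auto value = cast<DenseElementsAttr>(cst.getValue());
  cst.setValueAttr(value.reshape(newTy));
  cst.getResult().setType(newTy);
}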

@Dewei-Wang-sh (Contributor) left a review comment:

Overall LGTM.

Signed-off-by: Julian Oppermann <[email protected]>
@jopperm requested a review from etiotto on September 2, 2024 at 11:24
@whitneywhtsang merged commit 0d6f060 into llvm-target on Sep 4, 2024
4 checks passed
@whitneywhtsang deleted the jopperm/tt-to-ttgwarp-causal branch on September 4, 2024 at 19:27
Development

Successfully merging this pull request may close these issues.

[#6 Attention Performance] extend attention support for Causal = True
5 participants