
Support causal masking in FlashAttention #1947

Draft: wants to merge 2 commits into base branch perf_attn.

Conversation

@jopperm (Contributor) commented Aug 20, 2024

A proof-of-concept (PoC) extension of the advance path to handle causal masking in FlashAttention-2.
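The causal-masking semantics this PR lowers can be sketched in NumPy (this is only the reference semantics, not the generated kernel code):

```python
import numpy as np

def causal_attention_scores(q, k):
    """Reference semantics of causal masking: score[i, j] is set to
    -inf whenever key position j lies ahead of query position i, so
    softmax assigns it zero weight."""
    scores = q @ k.T                    # (seq_q, seq_k) raw scores
    i = np.arange(q.shape[0])[:, None]  # query positions (column vector)
    j = np.arange(k.shape[0])[None, :]  # key positions (row vector)
    return np.where(j <= i, scores, -np.inf)

q = np.ones((4, 8), dtype=np.float32)
k = np.ones((4, 8), dtype=np.float32)
s = causal_attention_scores(q, k)
# Entries above the diagonal (j > i) are -inf; the rest keep the raw scores.
```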

Summary of changes:

  • Support for `tt.make_range` throughout all passes of the advance path.
  • Extended `TritonToTritonGPUWarp` and `MatchTargetSize` to support two dependent attention `for`-loops.
  • Extended the lowering of `tt.broadcast` for row vectors: here we need to select and splat a single value per thread, so I introduced an op in TritonGEN for querying the lane ID.
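The per-thread select-and-splat step in the last bullet can be sketched in plain Python (the lane-ID parameter and subgroup width here are illustrative assumptions, not the actual TritonGEN op):

```python
# Sketch of lowering a row-vector tt.broadcast across a tile: each
# thread (lane) owns one column, selects the single element of the
# row vector that corresponds to its column, and splats it down the
# tile's rows.
SUBGROUP_SIZE = 16  # assumed SIMD/subgroup width

def broadcast_row_vector(row, num_rows, lane_id):
    """Each lane owns one column of a num_rows x SUBGROUP_SIZE tile.
    It selects row[lane_id] (this is why a lane-ID query is needed)
    and splats that value over its column."""
    value = row[lane_id]          # per-lane select
    return [value] * num_rows     # splat over the tile's rows

row = list(range(SUBGROUP_SIZE))  # the 1x16 row vector
col_for_lane3 = broadcast_row_vector(row, num_rows=4, lane_id=3)
# lane 3 ends up holding [3, 3, 3, 3]
```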

The generated code passes result verification against PyTorch (both causal=True and causal=False).

Remaining issues:

  • Enabling the schedule-load pass leads to invalid IR (an operation uses its own result).
  • No lit tests yet.

@jopperm jopperm marked this pull request as ready for review August 22, 2024 15:02
@jopperm jopperm self-assigned this Aug 22, 2024
@etiotto (Contributor) left a comment


As discussed please post a PR against the llvm-target branch for this feature.

@etiotto etiotto marked this pull request as draft August 26, 2024 19:37
whitneywhtsang pushed a commit that referenced this pull request Aug 30, 2024
Add support for distributing `tt.make_range` ops according to the
desired warp size. This is a PoC which assumes that multiple warps are
only needed along one dimension.

See #1947 for
more context / the complete PoC.

---------

Signed-off-by: Julian Oppermann <[email protected]>
whitneywhtsang pushed a commit that referenced this pull request Sep 4, 2024
…#2013)

Detect and handle flash attention with causal masking in
`-convert-triton-to-tritongpu-warp` by supporting `tt.make_range` and
*two* dependent attention-`for`-loops.

See #1947 for more context / the complete PoC.

---------

Signed-off-by: Julian Oppermann <[email protected]>
whitneywhtsang pushed a commit that referenced this pull request Sep 4, 2024
…2043)

Splits `make_range` into SG-sized subranges, and handles row-vector
broadcasts (e.g. `1x64 -> 16x64`) in `MatchTargetSize`.

See #1947 for
more context / the complete PoC.

---------

Signed-off-by: Julian Oppermann <[email protected]>
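The subrange splitting and row-vector broadcast described in the commit above can be illustrated in NumPy (the shapes come from the commit message; the subgroup size of 16 is an assumption):

```python
import numpy as np

SG_SIZE = 16  # assumed subgroup size

# Splitting a make_range result into SG-sized subranges: a 64-element
# range becomes 4 contiguous subranges of 16 lanes each.
full_range = np.arange(64)
subranges = full_range.reshape(-1, SG_SIZE)

# Row-vector broadcast 1x64 -> 16x64, as in the commit message:
# the single row is replicated over 16 rows.
row = np.arange(64).reshape(1, 64)
tile = np.broadcast_to(row, (16, 64))
```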
Dewei-Wang-sh pushed a commit that referenced this pull request Sep 9, 2024
Support canonicalization of dependent `scf.for` loops by re-gluing
individual results after the loop.

See #1947 for
more context / the complete PoC.

---------

Signed-off-by: Julian Oppermann <[email protected]>

Successfully merging this pull request may close these issues:

  • [#6 Attention Performance] extend attention support for Causal = True