[TT-TO-TTGWARP] Detect and handle flash attention with causal masking #2013

Merged: 2 commits, Sep 4, 2024
162 changes: 158 additions & 4 deletions test/Conversion/intel/triton_to_tritongpu_warp.mlir
@@ -1,10 +1,14 @@
// RUN: triton-opt %s -split-input-file --convert-triton-to-tritongpu-warp="num-warps=32" | FileCheck %s --check-prefix=CHECK
// RUN: triton-opt %s -split-input-file --convert-triton-to-tritongpu-warp="num-warps=8" | FileCheck %s --check-prefix=CHECK1
// RUN: triton-opt %s -split-input-file --convert-triton-to-tritongpu-warp="num-warps=32" | FileCheck %s --check-prefix=CHECK --implicit-check-not="tensor<{{[0-9]+x[0-9]+x[if][0-9]+}}>"
// RUN: triton-opt %s -split-input-file --convert-triton-to-tritongpu-warp="num-warps=8" | FileCheck %s --check-prefix=CHECK1 --implicit-check-not="tensor<{{[0-9]+x[0-9]+x[if][0-9]+}}>"

// COM: The implicit-check-not ensures that an encoding attribute is added to all 2D tensors.
// COM: Ideally, we should also check 1D tensors, but there are some tensor-typed attributes (`tt.divisibility_arg1`) that are not changed by the pass,
// COM: which are impractical to filter out.

// CHECK: #triton_gpu.blocked<{sizePerThread = [32, 64], threadsPerWarp = [1, 1], warpsPerCTA = [8, 4], order = [1, 0]}>
// CHECK: "triton_gpu.num-warps" = 32
module {
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16, 1>, %arg1: !tt.ptr<f16, 1>, %arg2: !tt.ptr<f32, 1>, %arg3: i32, %arg4: i32, %arg5: i32) {
// CHECK: tt.load
// CHECK-SAME: tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}
// CHECK: tt.load
@@ -53,7 +57,7 @@ module {
// CHECK1: #triton_gpu.blocked<{sizePerThread = [8, 32], threadsPerWarp = [1, 1], warpsPerCTA = [1, 8], order = [1, 0]}>
// CHECK1: "triton_gpu.num-warps" = 8
module {
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16, 1>, %arg1: !tt.ptr<f16, 1>, %arg2: !tt.ptr<f32, 1>, %arg3: i32, %arg4: i32, %arg5: i32) {
// CHECK1: tt.load
// CHECK1-SAME: tensor<8x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}
// CHECK1: tt.load
@@ -184,3 +188,153 @@ module {
tt.return
}
}

// -----

// COM: FlashAttention with causal masking
// COM: - two loops with workload=attention (=4) are detected
// COM: - encodings are propagated to the operations that correspond to the causal mask computation

// CHECK1: #triton_gpu.blocked<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], warpsPerCTA = [8, 1], order = [1, 0]}>
// CHECK1: "triton_gpu.num-warps" = 8
module {
tt.func public @_attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: f32, %arg4: !tt.ptr<f32>, %arg5: !tt.ptr<f32>) {
%cst = arith.constant dense<1.000000e+00> : tensor<128xf32>
%cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32>
%c1_i32 = arith.constant 1 : i32
%cst_1 = arith.constant dense<-1.000000e+06> : tensor<128x64xf32>
%c64_i32 = arith.constant 64 : i32
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32>
%c65536_i64 = arith.constant 65536 : i64
%c131072_i64 = arith.constant 131072 : i64
%cst_3 = arith.constant 1.44269502 : f32
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%c64_i64 = arith.constant 64 : i64
%c1024_i64 = arith.constant 1024 : i64
%c128_i32 = arith.constant 128 : i32
%0 = tt.get_program_id z : i32
%1 = tt.get_program_id x : i32
%2 = tt.get_program_id y : i32
%3 = arith.extsi %1 : i32 to i64
%4 = arith.muli %3, %c131072_i64 : i64
%5 = arith.extsi %2 : i32 to i64
%6 = arith.muli %5, %c65536_i64 : i64
%7 = arith.addi %4, %6 : i64
%8 = tt.addptr %arg0, %7 : !tt.ptr<f16>, i64
%9 = arith.muli %0, %c128_i32 : i32
%10 = tt.make_tensor_ptr %8, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16>>
%11 = tt.addptr %arg2, %7 : !tt.ptr<f16>, i64
%12 = tt.make_tensor_ptr %11, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16>>
%13 = tt.addptr %arg1, %7 : !tt.ptr<f16>, i64
%14 = tt.make_tensor_ptr %13, [%c64_i64, %c1024_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x64xf16>>
%15 = tt.addptr %arg5, %7 : !tt.ptr<f32>, i64
%16 = tt.make_tensor_ptr %15, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf32>>
// CHECK: tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%18 = tt.splat %9 : i32 -> tensor<128xi32>
%19 = arith.addi %18, %17 : tensor<128xi32>
// CHECK: tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%21 = arith.mulf %arg3, %cst_3 : f32
%22 = tt.load %10 : !tt.ptr<tensor<128x64xf16>>
%23 = tt.splat %21 : f32 -> tensor<128xf32>
%24 = tt.splat %21 : f32 -> tensor<128x64xf32>
%25:5 = scf.for %arg6 = %c0_i32 to %9 step %c64_i32 iter_args(%arg7 = %cst, %arg8 = %cst_2, %arg9 = %cst_0, %arg10 = %12, %arg11 = %14) -> (tensor<128xf32>, tensor<128x64xf32>, tensor<128xf32>, !tt.ptr<tensor<64x64xf16>>, !tt.ptr<tensor<64x64xf16>>) : i32 {
%39 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16>>
%40 = tt.dot %22, %39, %cst_2, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x64xf16> -> tensor<128x64xf32>
%41 = "tt.reduce"(%40) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%62 = arith.maxnumf %arg12, %arg13 : f32
tt.reduce.return %62 : f32
}) : (tensor<128x64xf32>) -> tensor<128xf32>
%42 = arith.mulf %41, %23 : tensor<128xf32>
%43 = arith.maxnumf %arg9, %42 : tensor<128xf32>
%44 = arith.mulf %40, %24 : tensor<128x64xf32>
%45 = tt.expand_dims %43 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32>
%46 = tt.broadcast %45 : tensor<128x1xf32> -> tensor<128x64xf32>
%47 = arith.subf %44, %46 : tensor<128x64xf32>
%48 = math.exp2 %47 : tensor<128x64xf32>
%49 = "tt.reduce"(%48) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%62 = arith.addf %arg12, %arg13 : f32
tt.reduce.return %62 : f32
}) : (tensor<128x64xf32>) -> tensor<128xf32>
%50 = arith.subf %arg9, %43 : tensor<128xf32>
%51 = math.exp2 %50 : tensor<128xf32>
%52 = arith.mulf %arg7, %51 : tensor<128xf32>
%53 = arith.addf %52, %49 : tensor<128xf32>
%54 = tt.expand_dims %51 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32>
%55 = tt.broadcast %54 : tensor<128x1xf32> -> tensor<128x64xf32>
%56 = arith.mulf %arg8, %55 : tensor<128x64xf32>
%57 = tt.load %arg10 : !tt.ptr<tensor<64x64xf16>>
%58 = arith.truncf %48 : tensor<128x64xf32> to tensor<128x64xf16>
%59 = tt.dot %58, %57, %56, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x64xf16> -> tensor<128x64xf32>
%60 = tt.advance %arg10, [%c64_i32, %c0_i32] : <tensor<64x64xf16>>
%61 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16>>
scf.yield %53, %59, %43, %60, %61 : tensor<128xf32>, tensor<128x64xf32>, tensor<128xf32>, !tt.ptr<tensor<64x64xf16>>, !tt.ptr<tensor<64x64xf16>>
// CHECK1: workload = 4
}
gpu.barrier
%26 = arith.muli %0, %c128_i32 : i32
%27 = arith.addi %0, %c1_i32 : i32
%28 = arith.muli %27, %c128_i32 : i32
%29 = tt.advance %14, [%c0_i32, %26] : <tensor<64x64xf16>>
%30 = tt.advance %12, [%26, %c0_i32] : <tensor<64x64xf16>>
// CHECK1: [[EXP_DIM1:%.*]] = tt.expand_dims {{%.*}} {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
// CHECK1: [[EXP_DIM2:%.*]] = tt.expand_dims {{%.*}} {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%31 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32>
%32 = tt.expand_dims %20 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
// CHECK1: [[BC1:%.*]] = tt.broadcast [[EXP_DIM1]] : tensor<128x1xi32, #blocked> -> tensor<128x64xi32, #blocked>
%33 = tt.broadcast %31 : tensor<128x1xi32> -> tensor<128x64xi32>
%34 = tt.splat %21 : f32 -> tensor<128x64xf32>
%35:5 = scf.for %arg6 = %26 to %28 step %c64_i32 iter_args(%arg7 = %25#0, %arg8 = %25#1, %arg9 = %25#2, %arg10 = %30, %arg11 = %29) -> (tensor<128xf32>, tensor<128x64xf32>, tensor<128xf32>, !tt.ptr<tensor<64x64xf16>>, !tt.ptr<tensor<64x64xf16>>) : i32 {
%39 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16>>
%40 = tt.dot %22, %39, %cst_2, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x64xf16> -> tensor<128x64xf32>
%41 = tt.splat %arg6 : i32 -> tensor<1x64xi32>
// CHECK1: [[OFFSET:%.*]] = arith.addi {{%.*}}, [[EXP_DIM2]] : tensor<1x64xi32, #blocked>
%42 = arith.addi %41, %32 : tensor<1x64xi32>
// CHECK1: [[BC2:%.*]] = tt.broadcast [[OFFSET]] : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked>
%43 = tt.broadcast %42 : tensor<1x64xi32> -> tensor<128x64xi32>
// CHECK1: arith.cmpi sge, [[BC1]], [[BC2]] : tensor<128x64xi32, #blocked>
%44 = arith.cmpi sge, %33, %43 : tensor<128x64xi32>
%45 = arith.mulf %40, %34 : tensor<128x64xf32>
%46 = arith.select %44, %cst_2, %cst_1 : tensor<128x64xi1>, tensor<128x64xf32>
%47 = arith.addf %45, %46 : tensor<128x64xf32>
%48 = "tt.reduce"(%47) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%67 = arith.maxnumf %arg12, %arg13 : f32
tt.reduce.return %67 : f32
}) : (tensor<128x64xf32>) -> tensor<128xf32>
%49 = arith.maxnumf %arg9, %48 : tensor<128xf32>
%50 = tt.expand_dims %49 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32>
%51 = tt.broadcast %50 : tensor<128x1xf32> -> tensor<128x64xf32>
%52 = arith.subf %47, %51 : tensor<128x64xf32>
%53 = math.exp2 %52 : tensor<128x64xf32>
%54 = "tt.reduce"(%53) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%67 = arith.addf %arg12, %arg13 : f32
tt.reduce.return %67 : f32
}) : (tensor<128x64xf32>) -> tensor<128xf32>
%55 = arith.subf %arg9, %49 : tensor<128xf32>
%56 = math.exp2 %55 : tensor<128xf32>
%57 = arith.mulf %arg7, %56 : tensor<128xf32>
%58 = arith.addf %57, %54 : tensor<128xf32>
%59 = tt.expand_dims %56 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32>
%60 = tt.broadcast %59 : tensor<128x1xf32> -> tensor<128x64xf32>
%61 = arith.mulf %arg8, %60 : tensor<128x64xf32>
%62 = tt.load %arg10 : !tt.ptr<tensor<64x64xf16>>
%63 = arith.truncf %53 : tensor<128x64xf32> to tensor<128x64xf16>
%64 = tt.dot %63, %62, %61, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x64xf16> -> tensor<128x64xf32>
%65 = tt.advance %arg10, [%c64_i32, %c0_i32] : <tensor<64x64xf16>>
%66 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16>>
scf.yield %58, %64, %49, %65, %66 : tensor<128xf32>, tensor<128x64xf32>, tensor<128xf32>, !tt.ptr<tensor<64x64xf16>>, !tt.ptr<tensor<64x64xf16>>
// CHECK1: workload = 4
}
%36 = tt.expand_dims %35#0 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32>
%37 = tt.broadcast %36 : tensor<128x1xf32> -> tensor<128x64xf32>
%38 = arith.divf %35#1, %37 : tensor<128x64xf32>
tt.store %16, %38 : !tt.ptr<tensor<128x64xf32>>
tt.return
}
}
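For orientation: the new test exercises causal FlashAttention split into two `scf.for` loops. The first loop walks key blocks strictly before the query block, where every key column is already visible and no mask is needed; the second loop walks the key blocks overlapping the diagonal and adds 0 or -1e6 to each score via the `arith.cmpi sge` / `arith.select` / `arith.addf` sequence checked above. The blocked layout in the CHECK1 lines (`sizePerThread = [16, 64]`, `warpsPerCTA = [8, 1]`) tiles the 128x64 query block as 8 warps of 16 rows, each covering all 64 columns. Below is a minimal scalar C++ sketch of the masking step only; the function name, the dense score matrix, and the block-size parameters are illustrative assumptions, and the dot products and online softmax of the real kernel are elided.

```cpp
#include <vector>

// Scalar sketch (illustrative, not part of the PR) of the causal masking the
// second loop performs. Key blocks with kStart < qStart (the first loop) are
// fully visible and left untouched; blocks overlapping the diagonal get 0 or
// -1e6 added per element, matching %cst_2 / %cst_1 and the select in the IR.
void maskDiagonalBlocks(std::vector<std::vector<float>> &scores, int qStart,
                        int blockM, int blockN) {
  for (int kStart = qStart; kStart < qStart + blockM; kStart += blockN)
    for (int i = 0; i < blockM; ++i)
      for (int j = 0; j < blockN; ++j) {
        // Visible iff the query row index is >= the key column index
        // (arith.cmpi sge on the broadcast index tensors %33 and %43).
        bool visible = (qStart + i) >= (kStart + j);
        scores[qStart + i][kStart + j] += visible ? 0.0f : -1.0e6f;
      }
}
```

In the test, `qStart` corresponds to `%9`/`%26` (the program id times 128), `blockM` to the 128-row query tile, and `blockN` to the 64-element step of both loops.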
@@ -182,6 +182,7 @@ class ConvertTritonToTritonGPUWarp
return;
valueAttrMap.clear();
DenseMap<scf::ForOp, LoopDotInfo> loopMap;
SmallVector<Workload, 2> workloads;
for (auto loop : loops) {
auto dots = llvm::to_vector(loop.getOps<tt::DotOp>());
assert(dots.size() <= 2 && "only support 1 or 2 dot in a loop");
@@ -191,6 +192,7 @@ class ConvertTritonToTritonGPUWarp
loopMap[loop] = loopDotInfo;
// DAG pattern match
Workload workLoadKind = matchLoopWorkload(loop, loopDotInfo);
workloads.push_back(workLoadKind);

/// get tensor layout attr according to workload pattern
switch (workLoadKind) {
@@ -383,6 +385,55 @@
}
return WalkResult::advance();
});

if (loops.size() == 2 && workloads.front() == Workload::Attention &&
Contributor: Why does the loop size have to be 2?

Contributor: FlashAttention with a causal mask has two loops.

workloads.back() == Workload::Attention) {
// match attention with causal masking
// FIXME: This is a workaround to attach layouts to tensor ops that have
// not been handled before. This should instead be covered by a
// more generic layout propagation approach.
Attribute blockLayout = loopMap[loops.front()]
.dotInfo0.dot.getResult()
.getType()
.getEncoding();

func.walk<WalkOrder::PreOrder>([&](Operation *op) {
SmallVector<RankedTensorType> typesWithoutEncoding;
for (Type ty : op->getResultTypes()) {
if (auto tty = dyn_cast<RankedTensorType>(ty))
if (!tty.getEncoding())
typesWithoutEncoding.push_back(tty);
}

if (typesWithoutEncoding.empty())
return;

if (auto cst = dyn_cast<arith::ConstantOp>(op)) {
transformArithConstantOp(cst, blockLayout);
Contributor: Here we just assume that a type without an encoding gets blockLayout, right? I have a local change that aims to cover the full propagation, making most of these cases hit the early return at L405.

Contributor (author): Yes, this IR walk just patches previously unhandled ops, but it is completely specific to causal flash attention. I added a FIXME above to make that clearer, so it would be great if we could drop this workaround soon.

return;
}

// Assign:
// - rank-2 operations: block layout
// - rank-1 operations: slice layout
assert(op->getNumResults() == 1 &&
"Unexpected tensor operation with multiple results");
OpResult res = op->getOpResult(0);
auto tty = cast<RankedTensorType>(res.getType());
if (tty.getRank() == 2)
res.setType(addAttrToType(tty, blockLayout));
// Rank==1 tensors get a slice layout with the axis depending on the
// use.
if (auto expand = dyn_cast<tt::ExpandDimsOp>(op)) {
Attribute sliceLayout = triton::gpu::SliceEncodingAttr::get(
blockLayout.getContext(), expand.getAxis(), blockLayout);
DenseSet<Value> chainedVals;
expandDefChain(expand.getSrc(), chainedVals);
for (auto cv : chainedVals)
cv.setType(addAttrToType(cv.getType(), sliceLayout));
}
});
}
}

/// adding module attributes
@@ -414,9 +465,14 @@ class ConvertTritonToTritonGPUWarp
result.setType(newType);
} else if (auto expand = dyn_cast<tt::ExpandDimsOp>(op)) {
auto src = expand.getSrc();
auto attr = cast<ttg::SliceEncodingAttr>(src.getType().getEncoding());
Type newType = addAttrToType(result.getType(), attr.getParent());
result.setType(newType);
if (auto attr = dyn_cast_if_present<ttg::SliceEncodingAttr>(
src.getType().getEncoding())) {
Type newType = addAttrToType(result.getType(), attr.getParent());
result.setType(newType);
}
// else: will patch the encoding later in the causal-attention-specific
// layout propagation.
// FIXME: Remove this workaround.
}
}

@@ -666,6 +722,11 @@ class ConvertTritonToTritonGPUWarp
Value res = loop.getResult(use.getOperandNumber());
chainedVals.insert(res);
expandUseChain(res, chainedVals);
} else if (auto forLoop = dyn_cast<scf::ForOp>(op)) {
auto arg = forLoop.getRegionIterArg(use.getOperandNumber() -
forLoop.getNumControlOperands());
chainedVals.insert(arg);
expandUseChain(arg, chainedVals);
// expanddims, splat, store
} else if (isa<tt::ExpandDimsOp, tt::SplatOp, tt::StoreOp>(op)) {
continue;
@@ -685,15 +746,21 @@ class ConvertTritonToTritonGPUWarp
assert(loop);
auto loopArg = loop.getInitArgs()[arg.getArgNumber() - 1];
expandDefChain(loopArg, chainedVals);
} else if (auto def = val.getDefiningOp()) {
} else if (auto opRes = dyn_cast<OpResult>(val)) {
Operation *def = opRes.getOwner();
if (def->getDialect() == arithDialect ||
def->getDialect() == mathDialect) {
for (auto operand : def->getOperands()) {
expandDefChain(operand, chainedVals);
expandUseChain(operand, chainedVals);
}
} else if (isa<tt::SplatOp, tt::BroadcastOp, tt::ReduceOp>(def)) {
;
} else if (auto forLoop = dyn_cast<scf::ForOp>(def)) {
Value yieldArg = forLoop.getYieldedValues()[opRes.getResultNumber()];
chainedVals.insert(yieldArg);
expandDefChain(yieldArg, chainedVals);
} else if (isa<tt::SplatOp, tt::BroadcastOp, tt::ReduceOp,
tt::MakeRangeOp>(def)) {
chainedVals.insert(def->getResult(0));
} else {
assert(0 && "add more support");
}
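A note on the `expandUseChain` / `expandDefChain` additions at the end of this diff: they let the chain walk cross `scf.for` boundaries by treating the faces of a loop-carried value (the init operand, the region iter arg, the yielded value, and the loop result) as aliases that must end up with the same encoding. A hedged C++ sketch of that correspondence, using only accessors that already appear in the diff; the helper name is an illustrative assumption, not code from the PR:

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"

// Illustrative helper: for one loop-carried slot of an scf.for, identified by
// the operand number of its init value, collect the values that the pass
// would want to share a layout encoding.
static llvm::SmallVector<mlir::Value, 4>
loopCarriedAliases(mlir::scf::ForOp forLoop, unsigned initOperandNumber) {
  // Position of the slot among the iter_args (skip the lb/ub/step operands).
  unsigned idx = initOperandNumber - forLoop.getNumControlOperands();
  return {
      forLoop.getInitArgs()[idx],      // value fed in before the loop
      forLoop.getRegionIterArg(idx),   // block argument seen inside the body
      forLoop.getYieldedValues()[idx], // value carried to the next iteration
      forLoop.getResult(idx),          // value visible after the loop
  };
}
```

This mirrors the new cases above: when a use chain reaches a loop init operand it continues at the corresponding iter arg, and when a def chain reaches a loop result it continues at the corresponding yielded value.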