Propagate mma layout to atomic_rmw op (#2312)
This change helps resolve issue #1716.
Propagating the mma layout from the dot op to the atomic_rmw op eliminates
`convert_layout` ops from/to the large mma layout, which would otherwise
require oversized shared memory.
LiyangLingIntel authored Sep 27, 2024
1 parent 59edf2c commit 659470b
Showing 2 changed files with 142 additions and 1 deletion.
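
As a minimal sketch of the targeted pattern (adapted from the new test in combine.mlir below; the SSA names %acc, %cvt, %res, %ptrs, %mask and the #mma/#blocked attribute aliases are illustrative, with #mma standing for the dpas encoding), the accumulator of a tt.dot previously had to be converted out of the mma layout before feeding the atomic:

  %acc = tt.dot %a, %b, %c, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
  %cvt = triton_gpu.convert_layout %acc : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
  %res = tt.atomic_rmw fadd, acq_rel, gpu, %ptrs, %cvt, %mask : (tensor<256x256x!tt.ptr<f32>, #blocked>, tensor<256x256xf32, #blocked>, tensor<256x256xi1, #blocked>) -> tensor<256x256xf32, #blocked>

With the mma layout propagated through the pointer and mask tensors to the atomic op, the convert_layout of the 256x256 tile, and the shared local memory it would stage through, is no longer needed:

  %res = tt.atomic_rmw fadd, acq_rel, gpu, %ptrs, %acc, %mask : (tensor<256x256x!tt.ptr<f32>, #mma>, tensor<256x256xf32, #mma>, tensor<256x256xi1, #mma>) -> tensor<256x256xf32, #mma>
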
109 changes: 109 additions & 0 deletions test/TritonIntelGPU/combine.mlir
@@ -2297,3 +2297,112 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
tt.return %3 : tensor<128x256xf32, #blocked>
}
}


// -----

// COM: Check that dpas layout can be propagated from dot op to atomic_rmw op
// CHECK-NOT: #triton_gpu.blocked<{.*}>
// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [32], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [32, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 16], order = [1, 0]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 32], order = [0, 1]}>
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func public @propagate_mma_to_atomic_rmw
tt.func public @propagate_mma_to_atomic_rmw(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) attributes {noinline = false} {
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%c32_i32 = arith.constant 32 : i32
%c128_i32 = arith.constant 128 : i32
%c256_i32 = arith.constant 256 : i32
%c4096_i32 = arith.constant 4096 : i32
%c4096_i64 = arith.constant 4096 : i64
%cst = arith.constant dense<4096> : tensor<256xi32, #blocked>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked2>
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
// CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
// CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
%12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #blocked3>>
%14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<32x256xbf16, #blocked2>>
// CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>) : i32 {
%15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst_1, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>) : i32 {
%47 = tt.load %arg5 : !tt.ptr<tensor<256x32xbf16, #blocked3>>
%48 = tt.load %arg6 : !tt.ptr<tensor<32x256xbf16, #blocked2>>
// CHECK-NOT: triton_gpu.convert_layout
%49 = triton_gpu.convert_layout %arg4 : tensor<256x256xf32, #blocked2> -> tensor<256x256xf32, #mma>
%50 = triton_gpu.convert_layout %47 : tensor<256x32xbf16, #blocked3> -> tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%51 = triton_gpu.convert_layout %48 : tensor<32x256xbf16, #blocked2> -> tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%52 = tt.dot %50, %51, %49, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
%53 = triton_gpu.convert_layout %52 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked2>
// CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
// CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
// CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
%54 = tt.advance %arg5, [%c0_i32, %c128_i32] : <tensor<256x32xbf16, #blocked3>>
%55 = tt.advance %arg6, [%c128_i32, %c0_i32] : <tensor<32x256xbf16, #blocked2>>
scf.yield %53, %54, %55 : tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>
}
%16 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
%32 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #blocked2>
%38 = arith.cmpi slt, %16, %cst : tensor<256xi32, #blocked>
// CHECK-NOT: triton_gpu.convert_layout
%39 = triton_gpu.convert_layout %38 : tensor<256xi1, #blocked> -> tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>
%40 = tt.expand_dims %39 {axis = 0 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x256xi1, #blocked4>
%41 = triton_gpu.convert_layout %40 : tensor<1x256xi1, #blocked4> -> tensor<1x256xi1, #blocked2>
%42 = tt.broadcast %41 : tensor<1x256xi1, #blocked2> -> tensor<256x256xi1, #blocked2>
// CHECK: %[[VAL_5:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, {{.*}} : (tensor<256x256x!tt.ptr<f32>, #[[$DPAS]]>, tensor<256x256xf32, #[[$DPAS]]>, tensor<256x256xi1, #[[$DPAS]]>) -> tensor<256x256xf32, #[[$DPAS]]>
%46 = tt.atomic_rmw fadd, acq_rel, gpu, %32, %15#0, %42 : (tensor<256x256x!tt.ptr<f32>, #blocked2>, tensor<256x256xf32, #blocked2>, tensor<256x256xi1, #blocked2>) -> tensor<256x256xf32, #blocked2>
tt.return
}
}


// -----

// COM: Check that a bare atomic_rmw op with a blocked layout still has the dpas layout propagated to it
// COM: The blocked layout will not back-propagate and overwrite the dpas layout
// CHECK-NOT: #triton_gpu.blocked<{.*}>
// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func public @bare_atomic_with_blocked_layout
tt.func public @bare_atomic_with_blocked_layout(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%cst_0 = arith.constant dense<3072> : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c128_i32 = arith.constant 128 : i32
%c4096_i64 = arith.constant 4096 : i64
%c4096_i32 = arith.constant 4096 : i32
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>
%14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
%15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #mma>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>) : i32 {
%41 = tt.advance %arg5, [%c0_i32, %c128_i32] : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>
%42 = tt.advance %arg6, [%c128_i32, %c0_i32] : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
%43 = tt.load %arg5 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>
%44 = tt.load %arg6 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
%45 = tt.dot %43, %44, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
scf.yield %45, %41, %42 : tensor<256x256xf32, #mma>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
}
%18 = tt.splat %0 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%28 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #mma>
%30 = arith.cmpi slt, %18, %cst_0 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%31 = tt.expand_dims %30 {axis = 1 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<256x1xi1, #mma>
%34 = tt.broadcast %31 : tensor<256x1xi1, #mma> -> tensor<256x256xi1, #mma>
// CHECK-NOT: triton_gpu.convert_layout
%37 = triton_gpu.convert_layout %28 : tensor<256x256x!tt.ptr<f32>, #mma> -> tensor<256x256x!tt.ptr<f32>, #blocked>
%38 = triton_gpu.convert_layout %15#0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
%39 = triton_gpu.convert_layout %34 : tensor<256x256xi1, #mma> -> tensor<256x256xi1, #blocked>
// CHECK: %[[VAL_0:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, {{.*}} : (tensor<256x256x!tt.ptr<f32>, #[[$DPAS]]>, tensor<256x256xf32, #[[$DPAS]]>, tensor<256x256xi1, #[[$DPAS]]>) -> tensor<256x256xf32, #[[$DPAS]]>
%40 = tt.atomic_rmw fadd, acq_rel, gpu, %37, %38, %39 : (tensor<256x256x!tt.ptr<f32>, #blocked>, tensor<256x256xf32, #blocked>, tensor<256x256xi1, #blocked>) -> tensor<256x256xf32, #blocked>
// CHECK-NOT: triton_gpu.convert_layout
%41 = triton_gpu.convert_layout %40 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
tt.return
}
}
@@ -13,6 +13,7 @@
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"

#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

namespace mlir::triton::gpu::intel {
@@ -252,6 +253,19 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
}
}
}

// HACK: We want to propagate the mma layout to the atomic_rmw op so that we
// do not need an extra ConvertLayout op to convert the layout from mma to
// other layouts, which may consume excessive shared local memory.
// TODO: Investigate the performance impact of an atomic_rmw op with the mma
// layout, compared with a ConvertLayout op plus an atomic_rmw op with a
// blocked layout.
if (auto atomicOp = dyn_cast<AtomicRMWOp>(op)) {
auto tensorType =
dyn_cast<RankedTensorType>(atomicOp.getResult().getType());
if (tensorType && isa<MmaEncodingTrait>(tensorType.getEncoding()))
return true;
}
bool isMMAV3 =
isa<NvidiaMmaEncodingAttr>(encoding) &&
cast<NvidiaMmaEncodingAttr>(encoding).getVersionMajor() == 3;
@@ -272,6 +286,11 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
auto forOp = dyn_cast<scf::ForOp>(yield.getOperation()->getParentOp());
if (!forOp)
continue;
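// Also queue the loop's first result so that its users after the loop
// (e.g. an atomic_rmw consuming the accumulator) are visited.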
for (OpOperand &operand : forOp.getResult(0).getUses()) {
Operation *def = operand.get().getDefiningOp();
if (def && seen.insert(operand.get()).second)
queue.push_back(operand.get());
}
for (OpOperand &operand : yield->getOpOperands()) {
Operation *def = operand.get().getDefiningOp();
if (def && (forwardSlice.count(def) || operand.get() == currentValue) &&
@@ -288,8 +307,12 @@
bool isLayoutAnchor(Operation *op) {
if (isa<LoadOp, StoreOp>(op))
return ttgi::isExpensiveLoadOrStore(op);
if (isa<DotOp, AtomicRMWOp, AtomicCASOp>(op))
if (isa<DotOp, AtomicCASOp>(op))
return true;
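// An atomic_rmw is a layout anchor only once its result already carries an
// mma encoding; otherwise other layouts (e.g. the dot's mma layout) can
// still be propagated onto it.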
if (isa<AtomicRMWOp>(op))
if (auto tensorType =
dyn_cast<RankedTensorType>(op->getResult(0).getType()))
return isa<MmaEncodingTrait>(tensorType.getEncoding());

// Heuristic: Mark permuting reshape as a layout anchor. Its dst can be
// anything, so it stops forward-propagation of layouts. We rely on the
@@ -402,6 +425,15 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
setEncoding({afterArg, result}, info, changed, user);
continue;
}
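// Forward the propagated encoding to atomic_rmw results when every incoming
// encoding is a blocked or mma layout.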
if (auto atomicRMWOp = dyn_cast<AtomicRMWOp>(user)) {
bool isBlockedOrMma = std::all_of(
info.encodings.begin(), info.encodings.end(), [](Attribute encoding) {
return isa<BlockedEncodingAttr, MmaEncodingTrait>(encoding);
});
if (isBlockedOrMma)
setEncoding(user->getResults(), info, changed, user);
continue;
}
if (user->hasTrait<OpTrait::SameOperandsAndResultEncoding>() ||
user->hasTrait<OpTrait::Elementwise>() ||
isa<ReduceOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp,
