From 659470b340e50376d977b4b241bb6b4d1da0c861 Mon Sep 17 00:00:00 2001
From: Liyang Ling
Date: Fri, 27 Sep 2024 10:38:07 +0800
Subject: [PATCH] Propagate mma layout to atomic_rmw op (#2312)

This change helps resolve issue
[#1716](https://github.com/intel/intel-xpu-backend-for-triton/issues/1716).
Propagating the mma layout from the dot op to the atomic_rmw op eliminates
`convert_layout` ops from/to a large mma layout, which would otherwise
require oversized shared memory.
---
 test/TritonIntelGPU/combine.mlir | 109 ++++++++++++++++++
 .../RemoveLayoutConversions.cpp  |  34 +++++-
 2 files changed, 142 insertions(+), 1 deletion(-)

diff --git a/test/TritonIntelGPU/combine.mlir b/test/TritonIntelGPU/combine.mlir
index 35794ced83..5e2b8a3b54 100644
--- a/test/TritonIntelGPU/combine.mlir
+++ b/test/TritonIntelGPU/combine.mlir
@@ -2297,3 +2297,112 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     tt.return %3 : tensor<128x256xf32, #blocked>
   }
 }
+
+
+// -----
+
+// COM: Check that dpas layout can be propagated from dot op to atomic_rmw op
+// CHECK-NOT: #triton_gpu.blocked<{.*}>
+// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [32], order = [0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [32, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 16], order = [1, 0]}>
+#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 32], order = [0, 1]}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: tt.func public @propagate_mma_to_atomic_rmw
+  tt.func public @propagate_mma_to_atomic_rmw(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) attributes {noinline = false} {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c4096_i32 = arith.constant 4096 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %cst = arith.constant dense<4096> : tensor<256xi32, #blocked>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked2>
+    %0 = tt.get_program_id x : i32
+    %1 = tt.get_program_id y : i32
+    // CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : >>
+    // CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : >>
+    %12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array} : >
+    %14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array} : >
+    // CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr>>, !tt.ptr>>) : i32 {
+    %15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst_1, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #blocked2>, !tt.ptr>, !tt.ptr>) : i32 {
+      %47 = tt.load %arg5 : !tt.ptr>
+      %48 = tt.load %arg6 : !tt.ptr>
+      // CHECK-NOT: triton_gpu.convert_layout
+      %49 = triton_gpu.convert_layout %arg4 : tensor<256x256xf32, #blocked2> -> tensor<256x256xf32, #mma>
+      %50 = triton_gpu.convert_layout %47 : tensor<256x32xbf16, #blocked3> -> tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %51 = triton_gpu.convert_layout %48 : tensor<32x256xbf16, #blocked2> -> tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %52 = tt.dot %50, %51, %49, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      %53 = triton_gpu.convert_layout %52 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked2>
+      // CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : >>
+      // CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : >>
+      // CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr>>, !tt.ptr>>
+      %54 = tt.advance %arg5, [%c0_i32, %c128_i32] : >
+      %55 = tt.advance %arg6, [%c128_i32, %c0_i32] : >
+      scf.yield %53, %54, %55 : tensor<256x256xf32, #blocked2>, !tt.ptr>, !tt.ptr>
+    }
+    %16 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
+    %32 = tt.splat %arg2 : !tt.ptr -> tensor<256x256x!tt.ptr, #blocked2>
+    %38 = arith.cmpi slt, %16, %cst : tensor<256xi32, #blocked>
+    // CHECK-NOT: triton_gpu.convert_layout
+    %39 = triton_gpu.convert_layout %38 : tensor<256xi1, #blocked> -> tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>
+    %40 = tt.expand_dims %39 {axis = 0 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x256xi1, #blocked4>
+    %41 = triton_gpu.convert_layout %40 : tensor<1x256xi1, #blocked4> -> tensor<1x256xi1, #blocked2>
+    %42 = tt.broadcast %41 : tensor<1x256xi1, #blocked2> -> tensor<256x256xi1, #blocked2>
+    // CHECK: %[[VAL_5:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, {{.*}} : (tensor<256x256x!tt.ptr, #[[$DPAS]]>, tensor<256x256xf32, #[[$DPAS]]>, tensor<256x256xi1, #[[$DPAS]]>) -> tensor<256x256xf32, #[[$DPAS]]>
+    %46 = tt.atomic_rmw fadd, acq_rel, gpu, %32, %15#0, %42 : (tensor<256x256x!tt.ptr, #blocked2>, tensor<256x256xf32, #blocked2>, tensor<256x256xi1, #blocked2>) -> tensor<256x256xf32, #blocked2>
+    tt.return
+  }
+}
+
+
+// -----
+
+// COM: Check that bare atomic_rmw op with blocked layout can still be propagated to dpas layout
+// COM: Blocked layout will not backpropagate to overwrite dpas layout
+// CHECK-NOT: #triton_gpu.blocked<{.*}>
+// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: tt.func public @bare_atomic_with_blocked_layout
+  tt.func public @bare_atomic_with_blocked_layout(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
+    %cst_0 = arith.constant dense<3072> : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %c4096_i32 = arith.constant 4096 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = tt.get_program_id y : i32
+    %12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array} : >>
+    %14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array} : >>
+    %15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #mma>, !tt.ptr>>, !tt.ptr>>) : i32 {
+      %41 = tt.advance %arg5, [%c0_i32, %c128_i32] : >>
+      %42 = tt.advance %arg6, [%c128_i32, %c0_i32] : >>
+      %43 = tt.load %arg5 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr>>
+      %44 = tt.load %arg6 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr>>
+      %45 = tt.dot %43, %44, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      scf.yield %45, %41, %42 : tensor<256x256xf32, #mma>, !tt.ptr>>, !tt.ptr>>
+    }
+    %18 = tt.splat %0 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %28 = tt.splat %arg2 : !tt.ptr -> tensor<256x256x!tt.ptr, #mma>
+    %30 = arith.cmpi slt, %18, %cst_0 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %31 = tt.expand_dims %30 {axis = 1 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<256x1xi1, #mma>
+    %34 = tt.broadcast %31 : tensor<256x1xi1, #mma> -> tensor<256x256xi1, #mma>
+    // CHECK-NOT: triton_gpu.convert_layout
+    %37 = triton_gpu.convert_layout %28 : tensor<256x256x!tt.ptr, #mma> -> tensor<256x256x!tt.ptr, #blocked>
+    %38 = triton_gpu.convert_layout %15#0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+    %39 = triton_gpu.convert_layout %34 : tensor<256x256xi1, #mma> -> tensor<256x256xi1, #blocked>
+    // CHECK: %[[VAL_0:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, {{.*}} : (tensor<256x256x!tt.ptr, #[[$DPAS]]>, tensor<256x256xf32, #[[$DPAS]]>, tensor<256x256xi1, #[[$DPAS]]>) -> tensor<256x256xf32, #[[$DPAS]]>
+    %40 = tt.atomic_rmw fadd, acq_rel, gpu, %37, %38, %39 : (tensor<256x256x!tt.ptr, #blocked>, tensor<256x256xf32, #blocked>, tensor<256x256xi1, #blocked>) -> tensor<256x256xf32, #blocked>
+    // CHECK-NOT: triton_gpu.convert_layout
+    %41 = triton_gpu.convert_layout %40 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
+    tt.return
+  }
+}
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
index 2a781589c4..e91cfa34c0 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
@@ -13,6 +13,7 @@
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
 #include "triton/Analysis/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"

 namespace mlir::triton::gpu::intel {
@@ -252,6 +253,19 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
           }
         }
       }
+
+      // HACK: we want to propagate mma layout to the atomic_rmw op, so we do
+      // not need an extra ConvertLayout Op to convert layout from mma to other
+      // layouts, which may consume excessive shared local memory.
+      // TODO: we need to investigate the performance impact of atomic_rmw op
+      // with mma layout, compared with ConvertLayout Op + atomic_rmw op with
+      // blocked layout.
+      if (auto atomicOp = dyn_cast(op)) {
+        auto tensorType =
+            dyn_cast(atomicOp.getResult().getType());
+        if (tensorType && isa(tensorType.getEncoding()))
+          return true;
+      }

       bool isMMAV3 = isa(encoding) &&
                      cast(encoding).getVersionMajor() == 3;
@@ -272,6 +286,11 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
       auto forOp = dyn_cast(yield.getOperation()->getParentOp());
       if (!forOp)
         continue;
+      for (OpOperand &operand : forOp.getResult(0).getUses()) {
+        Operation *def = operand.get().getDefiningOp();
+        if (def && (seen.insert(operand.get()).second == true))
+          queue.push_back(operand.get());
+      }
       for (OpOperand &operand : yield->getOpOperands()) {
         Operation *def = operand.get().getDefiningOp();
         if (def && (forwardSlice.count(def) || operand.get() == currentValue) &&
             (seen.insert(operand.get()).second == true))
           queue.push_back(operand.get());
       }
     }
@@ -288,8 +307,12 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
 bool isLayoutAnchor(Operation *op) {
   if (isa(op))
     return ttgi::isExpensiveLoadOrStore(op);
-  if (isa(op))
+  if (isa(op))
     return true;
+  if (isa(op))
+    if (auto tensorType =
+            dyn_cast(op->getResult(0).getType()))
+      return isa(tensorType.getEncoding());

   // Heuristic: Mark permuting reshape as a layout anchor. Its dst can be
   // anything, so it stops forward-propagation of layouts. We rely on the
@@ -402,6 +425,15 @@ SmallVector LayoutPropagation::propagateToUsers(Value value,
       setEncoding({afterArg, result}, info, changed, user);
       continue;
     }
+    if (auto atomicRMWOp = dyn_cast(user)) {
+      bool isBlockedOrMma = std::all_of(
+          info.encodings.begin(), info.encodings.end(), [](Attribute encoding) {
+            return isa(encoding);
+          });
+      if (isBlockedOrMma)
+        setEncoding(user->getResults(), info, changed, user);
+      continue;
+    }
     if (user->hasTrait() || user->hasTrait() || isa