From 835e8a04b122e42a1e1113221863355040482e5c Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 25 Sep 2024 20:27:29 +0000 Subject: [PATCH] Support column major tensor pointer in llvm lowering --- .../rewrite-tensor-pointer.mlir | 56 +++++++++++++++++++ .../RewriteTensorPointer.cpp | 17 +++++- 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/test/TritonIntelGPU/rewrite-tensor-pointer.mlir b/test/TritonIntelGPU/rewrite-tensor-pointer.mlir index fd35510906..1ee5d5f876 100644 --- a/test/TritonIntelGPU/rewrite-tensor-pointer.mlir +++ b/test/TritonIntelGPU/rewrite-tensor-pointer.mlir @@ -274,3 +274,59 @@ module attributes {"triton_intel_gpu.support_sg_2d_block"} { tt.return } } + +// ----- + +// COM: Case 4: +// COM: Check that a matrix multiplication of two tensor pointers with block_io attributes is not rewritten +// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + // CHECK: @matmul_kernel_with_block_pointers + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c5120_i32 = arith.constant 5120 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %4 : i32 + %6 = arith.addi %2, %5 : i32 + %7 = arith.remsi %0, %c64_i32 : i32 + %8 = arith.divsi %7, %4 : i32 + %9 = arith.muli %6, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : >> + %11 = arith.muli %8, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c1_i64, %c5120_i64], [%c0_i32, %11] {order = array} : >> + %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #dpas>, !tt.ptr>>, !tt.ptr>>) : i32 { + // CHECK: tt.load {{.*}} {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>> + // CHECK: tt.load {{.*}} {boundaryCheck = array, triton_intel_gpu.block_io = "column_major"} : !tt.ptr>> + %16 = tt.load %arg5 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>> + %17 = tt.load %arg6 {boundaryCheck = array, triton_intel_gpu.block_io = "column_major"} : !tt.ptr>> + // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]> + // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> + // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> + %18 = tt.dot %16, %17, %arg4, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>> -> tensor<256x256xf32, #dpas> + %19 = tt.advance %arg5, [%c0_i32, %c32_i32] : >> + %20 = tt.advance %arg6, [%c32_i32, %c0_i32] : >> + scf.yield %18, %19, %20 : tensor<256x256xf32, #dpas>, !tt.ptr>>, !tt.ptr>> + } + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %15 = arith.truncf %13#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas> + // CHECK: tt.store {{.*}}, {{.*}}, {{.*}} : !tt.ptr + tt.store %14, %15 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp index ecfa0f4659..a60e0fa26d 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp @@ -46,15 +46,26 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) { TypedValue base = op.getBase(); Operation::operand_range shape = op.getShape(); + unsigned rank = shape.size(); Operation::operand_range strides = op.getStrides(); Operation::operand_range offsets = op.getOffsets(); ArrayRef order = op.getOrder(); ArrayRef tensorShape = tensorType.getShape(); - // TODO: support column-major tensor + int fastChangeDim = -1; + for (size_t i = 0; i < strides.size(); i++) { + if (mlir::triton::gpu::intel::isConstant(strides[i], 1)) { + fastChangeDim = i; + break; + } + } + if (fastChangeDim < 0) { + return false; + } + // HW 2D block read instruction has restriction on pitch divisibility if (strides.size() == 2) { - auto pitch = strides[0]; + auto pitch = strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1]; // Across Intel platforms, the strictest pitch restriction is to be a // multiple of OWord(128 bits). if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) @@ -62,7 +73,7 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) { } // HW 2D block read instruction only supports contiguous accessing. - auto fastChangeStride = strides[1]; + auto fastChangeStride = strides[fastChangeDim]; if (auto stride = fastChangeStride.getDefiningOp()) { if (auto strideInt = dyn_cast(stride.getValue())) return strideInt.getInt() != 1;