Support column major tensor pointer in llvm lowering
alexbaden committed Sep 25, 2024
1 parent db2a0e9 commit ae75f93
Showing 2 changed files with 70 additions and 2 deletions.
56 changes: 56 additions & 0 deletions test/TritonIntelGPU/rewrite-tensor-pointer.mlir
@@ -274,3 +274,59 @@ module attributes {"triton_intel_gpu.support_sg_2d_block"} {
tt.return
}
}

// -----

// COM: Case 4:
// COM: Check that a matrix multiplication of two tensor pointers with block_io attributes is not rewritten
// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
// CHECK: @matmul_kernel_with_block_pointers
%c4_i32 = arith.constant 4 : i32
%c256_i32 = arith.constant 256 : i32
%c1024_i64 = arith.constant 1024 : i64
%c5120_i64 = arith.constant 5120 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c4096_i64 = arith.constant 4096 : i64
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c5120_i32 = arith.constant 5120 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas>
%0 = tt.get_program_id x : i32
%1 = arith.divsi %0, %c64_i32 : i32
%2 = arith.muli %1, %c4_i32 : i32
%3 = arith.subi %c4_i32, %2 : i32
%4 = arith.minsi %3, %c4_i32 : i32
%5 = arith.remsi %0, %4 : i32
%6 = arith.addi %2, %5 : i32
%7 = arith.remsi %0, %c64_i32 : i32
%8 = arith.divsi %7, %4 : i32
%9 = arith.muli %6, %c256_i32 : i32
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
%10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
%11 = arith.muli %8, %c256_i32 : i32
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c1_i64, %c5120_i64], [%c0_i32, %11] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
%13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>) : i32 {
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%16 = tt.load %arg5 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
%17 = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
// CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%18 = tt.dot %16, %17, %arg4, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>> -> tensor<256x256xf32, #dpas>
%19 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
%20 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
scf.yield %18, %19, %20 : tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
}
%14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
%15 = arith.truncf %13#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
// CHECK: tt.store {{.*}}, {{.*}}, {{.*}} : !tt.ptr<tensor<256x256xf16, #[[DPAS]]>
tt.store %14, %15 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #dpas>>
tt.return
}
}
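
The key feature of this test: operand A is row-major (strides [%c5120_i64, %c1_i64], order [1, 0]) while operand B is column-major (strides [%c1_i64, %c5120_i64], order [0, 1]), and both tensor pointers must survive the rewrite pass. As a minimal standalone C++ sketch of the layout classification the pass performs (classifyBlockIO is a hypothetical helper for illustration, not code from this commit; the pass itself scans for the unit-stride dimension):

#include <array>
#include <cstdint>
#include <string>

// Hypothetical helper (illustration only): classify a 2D tensor pointer's
// layout from its element strides by finding the contiguous (stride-1)
// dimension, mirroring the scan the pass performs on tt.make_tensor_ptr.
std::string classifyBlockIO(const std::array<int64_t, 2> &strides) {
  if (strides[1] == 1)
    return "row_major";    // e.g. operand A above: strides [5120, 1]
  if (strides[0] == 1)
    return "column_major"; // e.g. operand B above: strides [1, 5120]
  return "";               // no contiguous dimension: not a 2D-block candidate
}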
14 changes: 14 additions & 2 deletions
@@ -46,23 +46,35 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {

TypedValue<triton::PointerType> base = op.getBase();
Operation::operand_range shape = op.getShape();
unsigned rank = shape.size();
Operation::operand_range strides = op.getStrides();
Operation::operand_range offsets = op.getOffsets();
ArrayRef<int32_t> order = op.getOrder();
ArrayRef<int64_t> tensorShape = tensorType.getShape();

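// Find the fast-changing (contiguous) dimension: the one whose stride is
// the constant 1.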
int fastChangeDim = -1;
for (size_t i = 0; i < strides.size(); i++) {
if (mlir::triton::gpu::intel::isConstant(strides[i], 1)) {
fastChangeDim = i;
break;
}
}
if (fastChangeDim < 0) {
return false;
}

// TODO: support column-major tensor
// HW 2D block read instruction has restriction on pitch divisibility
if (strides.size() == 2) {
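// The pitch is the stride of the other, slower-varying dimension.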
- auto pitch = strides[0];
+ auto pitch = strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1];
// Across Intel platforms, the strictest restriction requires the pitch to
// be a multiple of an OWord (128 bits).
if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth()))
return true;
}

// The HW 2D block read instruction only supports contiguous accesses.
- auto fastChangeStride = strides[1];
+ auto fastChangeStride = strides[fastChangeDim];
if (auto stride = fastChangeStride.getDefiningOp<arith::ConstantOp>()) {
if (auto strideInt = dyn_cast<IntegerAttr>(stride.getValue()))
return strideInt.getInt() != 1;
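
The net effect of the change: instead of hard-coding strides[0] as the pitch and strides[1] as the fast-changing stride (which assumed row-major layout), the pass first locates the unit-stride dimension and derives the pitch from the other one, so column-major tensor pointers pass the same checks. A constant-folded sketch of the OWord pitch check (pitchIsOWordAligned is a hypothetical name; the real code handles SSA stride values via ttgi::isDivisible):

#include <cstdint>

// OWord-divisibility check: the pitch, measured in elements, must be a
// multiple of 128 bits' worth of elements. Assumes elemBitWidth evenly
// divides 128 (true for the f16/bf16/f32 element types used here).
bool pitchIsOWordAligned(int64_t pitchInElems, unsigned elemBitWidth) {
  const int64_t elemsPerOWord = 128 / elemBitWidth; // e.g. 8 for f16
  return pitchInElems % elemsPerOWord == 0;
}

For the f16 tensors in the test above, both pitches are 5120 elements and 5120 % (128 / 16) == 0, so neither tensor pointer is rewritten and the 2D block path is preserved.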
