[MatchTargetSize] Handle tt.make_range and row-vector broadcasts (#2043)

Splits `make_range` into SG-sized subranges, and handles row-vector
broadcasts (e.g. `1x64 -> 16x64`) in `MatchTargetSize`.

See #1947 for
more context / the complete PoC.
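
For illustration, here is a schematic of the `make_range` split (layout encodings and exact SSA names elided; the new lit test below shows the real IR, including the `#warp` slice layouts):

```mlir
// Schematic sketch only -- see the CHECK lines in match-target-size.mlir
// for the exact expected output.
//
// Before: one 64-wide range.
%range = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>

// After: makeRange(0, 64) = glue(r, r + 16, r + 32, r + 48), with r = makeRange(0, 16).
%c16 = arith.constant dense<16> : tensor<16xi32>
%c32 = arith.constant dense<32> : tensor<16xi32>
%c48 = arith.constant dense<48> : tensor<16xi32>
%r  = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
%r1 = arith.addi %r, %c16 : tensor<16xi32>
%r2 = arith.addi %r, %c32 : tensor<16xi32>
%r3 = arith.addi %r, %c48 : tensor<16xi32>
%glued = triton_intel_gpu.glue %r, %r1, %r2, %r3 : (tensor<16xi32>, tensor<16xi32>, tensor<16xi32>, tensor<16xi32>) -> tensor<64xi32>
```

The row-vector broadcast handling follows the same pattern: the `1x64` source is cut into `1x16` pieces with `triton_intel_gpu.extract`, each piece is broadcast to `16x16` with `tt.broadcast`, and the pieces are glued into the `16x64` result.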

---------

Signed-off-by: Julian Oppermann <[email protected]>
jopperm authored Sep 4, 2024
1 parent 0d6f060 commit cb0b7e1
Showing 2 changed files with 135 additions and 2 deletions.
57 changes: 57 additions & 0 deletions test/TritonIntelGPU/match-target-size.mlir
@@ -359,3 +359,60 @@ tt.func public @attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.pt
tt.return
}
}

// -----

#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 1 : i32} {
tt.func public @_attn_fwd(%arg0: i32, %arg1: !tt.ptr<i32>) {
// COM: This op primes the map of known layouts
%cst = arith.constant dense<1> : tensor<16x64xi32, #warp>

// CHECK: %[[CST_48:.*]] = arith.constant dense<48> : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>
// CHECK: %[[CST_32:.*]] = arith.constant dense<32> : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>
// CHECK: %[[CST_16:.*]] = arith.constant dense<16> : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>

// CHECK: %[[MR1:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #warp}>>

// CHECK: %[[MR2:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>
// CHECK: %[[MR2_PLUS_16:.*]] = arith.addi %[[MR2]], %[[CST_16]] : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>
// CHECK: %[[MR2_PLUS_32:.*]] = arith.addi %[[MR2]], %[[CST_32]] : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>
// CHECK: %[[MR2_PLUS_48:.*]] = arith.addi %[[MR2]], %[[CST_48]] : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>
// CHECK: %[[GLUE:.*]] = triton_intel_gpu.glue %[[MR2]], %[[MR2_PLUS_16]], %[[MR2_PLUS_32]], %[[MR2_PLUS_48]] : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>

// CHECK: %[[ED1:.*]] = tt.expand_dims %[[MR1]] {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xi32, #warp>
// CHECK: %[[ED2:.*]] = tt.expand_dims %[[GLUE]] {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>> -> tensor<1x64xi32, #warp>
%2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xi32, #warp>
%3 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>> -> tensor<1x64xi32, #warp>

// CHECK: %[[BC1:.*]] = triton_intel_gpu.broadcast %[[ED1]] : tensor<16x1xi32, #warp> -> tensor<16x16xi32>
%4 = tt.broadcast %2 : tensor<16x1xi32, #warp> -> tensor<16x64xi32, #warp>

// CHECK: %[[EX0:.*]] = triton_intel_gpu.extract %[[ED2]][0] : tensor<1x64xi32, #warp> -> tensor<1x16xi32>
// CHECK: %[[BC20:.*]] = tt.broadcast %[[EX0]] : tensor<1x16xi32> -> tensor<16x16xi32>
// CHECK: %[[EX1:.*]] = triton_intel_gpu.extract %[[ED2]][1] : tensor<1x64xi32, #warp> -> tensor<1x16xi32>
// CHECK: %[[BC21:.*]] = tt.broadcast %[[EX1]] : tensor<1x16xi32> -> tensor<16x16xi32>
// CHECK: %[[EX2:.*]] = triton_intel_gpu.extract %[[ED2]][2] : tensor<1x64xi32, #warp> -> tensor<1x16xi32>
// CHECK: %[[BC22:.*]] = tt.broadcast %[[EX2]] : tensor<1x16xi32> -> tensor<16x16xi32>
// CHECK: %[[EX3:.*]] = triton_intel_gpu.extract %[[ED2]][3] : tensor<1x64xi32, #warp> -> tensor<1x16xi32>
// CHECK: %[[BC23:.*]] = tt.broadcast %[[EX3]] : tensor<1x16xi32> -> tensor<16x16xi32>
%5 = tt.broadcast %3 : tensor<1x64xi32, #warp> -> tensor<16x64xi32, #warp>

// CHECK: arith.addi %[[BC1]], %[[BC20]] : tensor<16x16xi32>
// CHECK: arith.addi %[[BC1]], %[[BC21]] : tensor<16x16xi32>
// CHECK: arith.addi %[[BC1]], %[[BC22]] : tensor<16x16xi32>
// CHECK: arith.addi %[[BC1]], %[[BC23]] : tensor<16x16xi32>
%6 = arith.addi %4, %5 : tensor<16x64xi32, #warp>

// COM: Prevent DCE
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%c64_i64 = arith.constant 64 : i64
%7 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xi32, #warp>>
tt.store %7, %6 : !tt.ptr<tensor<16x64xi32, #warp>>
tt.return
}
}
80 changes: 78 additions & 2 deletions third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp
@@ -299,6 +299,13 @@ class MatchTargetSizePass
op->getResultTypes().end());
types.append(resultTypes);

if (auto mr = dyn_cast<tt::MakeRangeOp>(op)) {
// FIXME: MakeRange's rank-1 tensors are not candidates in the original
// sense of the advance path, but we need to split them here anyways.
transformMakeRangeOp(mr);
return WalkResult::advance();
}

if (llvm::none_of(types, [this](Type type) { return isCandidate(type); }))
return WalkResult::advance();
if (isa<scf::ForOp, scf::YieldOp>(op))
@@ -375,6 +382,7 @@ class MatchTargetSizePass
void transformDotOp(tt::DotOp dot);
void transformReduceOp(tt::ReduceOp op);
void transformBroadcastOp(tt::BroadcastOp op);
void transformMakeRangeOp(tt::MakeRangeOp op);

/// Generic transformation.
void transformGenericOp(Operation *op);
@@ -742,6 +750,10 @@ MatchTargetSizePass::getSubTypeAndShape(Type type, bool isTransposed,
if (useSLM) {
subSize[0] = 16;
subSize[1] = 64;
} else {
// Never exceed the shape of the original type.
subSize[0] = std::min(subSize[0], shape[0]);
subSize[1] = std::min(subSize[1], shape[1]);
}

auto subType = RankedTensorType::get(
@@ -1003,14 +1015,17 @@ void MatchTargetSizePass::transformBroadcastOp(tt::BroadcastOp op) {
RankedTensorType srcType = op.getSrc().getType();
unsigned srcDim0 = srcType.getShape()[0];
unsigned dstDim0 = tType.getShape()[0];
unsigned resDim0 = resType.getShape()[0];
unsigned srcDim1 = srcType.getShape()[1];
unsigned dstDim1 = tType.getShape()[1];
unsigned resDim1 = resType.getShape()[1];
Operation *glue;
if (srcDim0 == dstDim0) {
Value newOp = b.create<ttgi::BroadcastOp>(loc, tType, op.getSrc());
unsigned num = resType.getShape()[1] / tType.getShape()[1];
SmallVector<Value> ops(num, newOp);
glue = b.create<ttgi::GlueOp>(loc, resType, ops);
} else {
assert(srcDim0 == 2 * dstDim0 && "add more support");
} else if (srcDim0 == 2 * dstDim0) {
auto newTy = RankedTensorType::get({srcDim0, tType.getShape()[1]},
tType.getElementType());
auto newOp = b.create<ttgi::BroadcastOp>(loc, newTy, op.getSrc());
@@ -1019,6 +1034,30 @@ void MatchTargetSizePass::transformBroadcastOp(tt::BroadcastOp op) {
SmallVector<Value> ops{extract0, extract1, extract0, extract1,
extract0, extract1, extract0, extract1};
glue = b.create<ttgi::GlueOp>(loc, resType, ops);
} else if (srcDim0 == 1 && srcDim1 == resDim1) {
// Handle row-vector broadcasts, e.g. 1x64 --> 16x64.
auto subRowVecTy =
RankedTensorType::get({1, tType.getShape()[1]}, tType.getElementType());

// How many extracts do we need to cover the width of the input tensor?
unsigned nExtracts = srcDim1 / dstDim1;
SmallVector<Value> subBroadcasts;
for (int i = 0; i < nExtracts; ++i) {
auto ext = b.create<ttgi::ExtractOp>(loc, subRowVecTy, op.getSrc(), i);
auto sbc = b.create<tt::BroadcastOp>(loc, tType, ext);
subBroadcasts.push_back(sbc);
}

// How often do we need to repeat a sub broadcast to cover the height of the
// result tensor?
unsigned nRepeats = resDim0 / dstDim0;
SmallVector<Value> ops;
for (int i = 0; i < nRepeats * nExtracts; ++i)
ops.push_back(subBroadcasts[i / nRepeats]);

glue = b.create<ttgi::GlueOp>(loc, resType, ops);
} else {
llvm::report_fatal_error("Unhandled broadcast");
}

op->replaceAllUsesWith(glue->getResults());
@@ -1027,6 +1066,43 @@ void MatchTargetSizePass::transformBroadcastOp(tt::BroadcastOp op) {
return;
}

void MatchTargetSizePass::transformMakeRangeOp(tt::MakeRangeOp op) {
constexpr unsigned subgroupSize = 16;
unsigned start = op.getStart();
unsigned end = op.getEnd();
assert(start == 0 && end % subgroupSize == 0 && "Unsupported range");

if (end == subgroupSize)
// nothing to do
return;

// Transform the range like this: (SG = subgroup size = 16)
// makeRange(0, N) = glue(
// splat(0 * SG) + makeRange(0, SG),
// splat(1 * SG) + makeRange(0, SG),
// ...
// splat((N/SG-1) * SG) + makeRange(0, SG)
// )

OpBuilder b(op);
Location loc = op.getLoc();
RankedTensorType origTy = op.getType();
Type elemTy = origTy.getElementType();
auto subRangeTy =
RankedTensorType::get({subgroupSize}, elemTy, origTy.getEncoding());
auto subRange = b.create<tt::MakeRangeOp>(loc, subRangeTy, 0, subgroupSize);
SmallVector<Value> subRanges;
for (int i = 0; i < end / subgroupSize; ++i) {
Value offset =
b.create<arith::ConstantIntOp>(loc, i * subgroupSize, elemTy);
Value offsetTensor = b.create<tt::SplatOp>(loc, subRangeTy, offset);
subRanges.push_back(b.create<arith::AddIOp>(loc, subRange, offsetTensor));
}
auto glue = b.create<ttgi::GlueOp>(loc, op.getType(), subRanges);
op.replaceAllUsesWith(glue->getResults()[0]);
op->erase();
}

void MatchTargetSizePass::transformGenericOp(Operation *op) {
unsigned numResults = op->getResults().size();
unsigned dotIdx = 2;