Distribute tt.make_range to warps (#2026)
Add support for distributing `tt.make_range` ops according to the
desired warp size. This is a PoC that assumes multiple warps are
needed along only one dimension.

See #1947 for more context and the complete PoC.
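
Conceptually, an N-element `tt.make_range` spanning W warps along one
dimension becomes a per-warp range of N/W elements plus a warp-dependent
offset. A simplified before/after sketch, distilled from the new test below
(encodings elided; `%warpId` stands for the subgroup id the pass materializes):

// Before: one range covers all 128 rows; warpsPerCTA = [8, 1].
%r = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, ...>

// After: each warp builds its own 16-element slice, shifted by warpId * 16.
%sub = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, ...>
%c16 = arith.constant 16 : i32
%off = arith.muli %warpId, %c16 : i32
%spl = tt.splat %off : i32 -> tensor<16xi32, ...>
%res = arith.addi %sub, %spl : tensor<16xi32, ...>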

---------

Signed-off-by: Julian Oppermann <[email protected]>
jopperm authored Aug 30, 2024
1 parent 41f1ee6 commit 42b073e
Showing 2 changed files with 103 additions and 0 deletions.
57 changes: 57 additions & 0 deletions test/TritonIntelGPU/distribute-to-warps.mlir
@@ -189,3 +189,60 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], warpsPerCTA = [8, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 1 : i32} {
tt.func public @_attn_fwd(%arg6 : i32, %dummyptr : !tt.ptr<f32>) {
// CHECK: %[[SG_ID:.*]] = gpu.subgroup_id : index
// CHECK: %[[SG_ID1:.*]] = arith.index_cast %[[SG_ID]] : index to i32
// CHECK: %[[MINUS_INF:.*]] = arith.constant dense<-1.000000e+06> : tensor<16x64xf32, #warp>
%cst_1 = arith.constant dense<-1.000000e+06> : tensor<128x64xf32, #blocked>
// CHECK: %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : tensor<16x64xf32, #warp>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked>
// CHECK: %[[CNST_128:.*]] = arith.constant 128 : i32
%c128_i32 = arith.constant 128 : i32
// CHECK: %[[PROG_ID:.*]] = tt.get_program_id z : i32
%0 = tt.get_program_id z : i32
// CHECK: %[[PROG_ID_SCALED:.*]] = arith.muli %[[PROG_ID]], %[[CNST_128]] : i32
%9 = arith.muli %0, %c128_i32 : i32
// CHECK: %[[RANGE1:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
// CHECK: %[[CNST_16:.*]] = arith.constant 16 : i32
// CHECK: %[[SG_ID_SCALED:.*]] = arith.muli %[[SG_ID1]], %[[CNST_16]] : i32
// CHECK: %[[OFFSET1:.*]] = tt.splat %[[SG_ID_SCALED]] : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
// CHECK: %[[RANGE1_OFFSET:.*]] = arith.addi %[[RANGE1]], %[[OFFSET1]] : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%25 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK: %[[OFFSET2:.*]] = tt.splat %[[PROG_ID_SCALED]] : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%26 = tt.splat %9 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK: %[[RANGE1_OFFSET2:.*]] = arith.addi %[[OFFSET2]], %[[RANGE1_OFFSET]] : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%27 = arith.addi %26, %25 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK: %[[RANGE2:.*]] = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>
%28 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
// CHECK: %[[EXP_DIM1:.*]] = tt.expand_dims %[[RANGE1_OFFSET2]] {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xi32, #warp>
%39 = tt.expand_dims %27 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
// CHECK: %[[EXP_DIM2:.*]] = tt.expand_dims %[[RANGE2]] {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>> -> tensor<1x64xi32, #warp>
%40 = tt.expand_dims %28 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
// CHECK: %[[BC1:.*]] = tt.broadcast %[[EXP_DIM1]] : tensor<16x1xi32, #warp> -> tensor<16x64xi32, #warp>
%41 = tt.broadcast %39 : tensor<128x1xi32, #blocked> -> tensor<128x64xi32, #blocked>
// CHECK: %[[OFFSET3:.*]] = tt.splat %arg0 : i32 -> tensor<1x64xi32, #warp>
%49 = tt.splat %arg6 : i32 -> tensor<1x64xi32, #blocked>
// CHECK: %[[RANGE2_OFFSET:.*]] = arith.addi %[[OFFSET3]], %[[EXP_DIM2]] : tensor<1x64xi32, #warp>
%50 = arith.addi %49, %40 : tensor<1x64xi32, #blocked>
// CHECK: %[[BC2:.*]] = tt.broadcast %[[RANGE2_OFFSET]] : tensor<1x64xi32, #warp> -> tensor<16x64xi32, #warp>
%51 = tt.broadcast %50 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked>
// CHECK: %[[MASK:.*]] = arith.cmpi sge, %[[BC1]], %[[BC2]] : tensor<16x64xi32, #warp>
%52 = arith.cmpi sge, %41, %51 : tensor<128x64xi32, #blocked>
// CHECK: %[[SELECT:.*]] = arith.select %[[MASK]], %[[ZERO]], %[[MINUS_INF]] : tensor<16x64xi1, #warp>, tensor<16x64xf32, #warp>
%54 = arith.select %52, %cst_2, %cst_1 : tensor<128x64xi1, #blocked>, tensor<128x64xf32, #blocked>

// COM: store result to prevent DCE; this is not part of flashattn kernel
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%c64_i64 = arith.constant 64 : i64
%dummyblock = tt.make_tensor_ptr %dummyptr, [%c0_i64, %c0_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf32, #blocked>>
tt.store %dummyblock, %54 : !tt.ptr<tensor<128x64xf32, #blocked>>
tt.return
}
}
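
The RUN line for this test sits above the excerpted hunk and is not shown in
the diff; it presumably resembles the following, where the exact pass flag is
an assumption derived from the pass class name:

// RUN: triton-opt %s --split-input-file --tritonintelgpu-distribute-to-warps | FileCheck %s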
46 changes: 46 additions & 0 deletions
@@ -197,6 +197,50 @@ void distributeGenericOp(Operation *op) {
op->erase();
}

void distributeMakeRangeOp(tt::MakeRangeOp op, Value warpId) {
assert(op.getStart() == 0 && "Expected zero-based range");

auto loc = op.getLoc();
OpBuilder b(op);

auto tensorTy = op.getType();
auto sliceLayout = dyn_cast<ttg::SliceEncodingAttr>(tensorTy.getEncoding());
assert(sliceLayout && "Expected slice layout");

auto parentWarpsPerCTA = ttg::getWarpsPerCTA(sliceLayout.getParent());
assert(parentWarpsPerCTA.size() == 2 && "Only slice of 2D layout supported");
assert(parentWarpsPerCTA.back() == 1 &&
       "Warp distribution along the second dimension is unsupported");

auto convTy = convertType(tensorTy);
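// `convertType` is a helper defined elsewhere in this pass; presumably it
// rewrites the tensor type to its per-warp shape (e.g. 128 -> 16 elements
// for 8 warps, per the test above).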
unsigned numElems = convTy.getNumElements();
auto subRange = b.create<tt::MakeRangeOp>(loc, convTy, 0, numElems);

// If the number of elements stays the same, we don't need any offset
// computation. Use the newly constructed op because the type's encoding was
// changed by the conversion.
if (tensorTy.getNumElements() == numElems) {
op->replaceAllUsesWith(subRange);
op->erase();
return;
}

// Else: We need to take the warp ID into account. `tt.make_range` only
// supports constant boundaries, so we have to construct the offset
// calculation manually.
//
// FIXME: This assumes that warpsPerCTA[dim 1] is 1; I believe for the generic
// case, we would need to determine a dimension-specific offset similar to
// `tt.make_tensor_ptr`'s distribution pattern.
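// Worked example (mirrors the test above): a 128-element range with
// warpsPerCTA = [8, 1] gives numElems == 16, so warp w produces the offsets
// [16 * w, 16 * w + 16).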
auto elemTy = convTy.getElementType();
auto numElemsConst = b.create<arith::ConstantIntOp>(loc, numElems, elemTy);
auto rangeOffset = b.create<arith::MulIOp>(loc, warpId, numElemsConst);
auto splat = b.create<tt::SplatOp>(loc, convTy, rangeOffset);
auto newRange = b.create<arith::AddIOp>(loc, subRange, splat);
op->replaceAllUsesWith(newRange);
op->erase();
}

void distributeArithConstantOp(arith::ConstantOp op) {
auto type = dyn_cast<RankedTensorType>(op.getType());
if (!type)
@@ -381,6 +425,8 @@ class TritonIntelGPUDistributeToWarpsPass
distributeArithConstantOp(cstOp);
else if (auto convertOp = dyn_cast<ttg::ConvertLayoutOp>(op))
distributeConvertLayoutOp(convertOp, warpId, typeMap[convertOp]);
else if (auto makeRange = dyn_cast<tt::MakeRangeOp>(op))
distributeMakeRangeOp(makeRange, warpId);
else if (isa<tt::LoadOp, tt::DotOp, tt::AdvanceOp, tt::ReduceOp,
tt::SplatOp, tt::BroadcastOp, tt::ExpandDimsOp>(op) ||
op->getDialect() == arithDialect ||