Skip to content

Commit

Permalink
[convertBlockPtrToTensorOfPtr] Do not generate mask when `boundary_check` is not set (#2366)
Browse files Browse the repository at this point in the history

`RewriteTensorPointer` also does not generate a mask when `boundary_check`
is not set:
https://github.com/intel/intel-xpu-backend-for-triton/blob/main/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp#L206

Signed-off-by: Whitney Tsang <[email protected]>
  • Loading branch information
whitneywhtsang authored Sep 27, 2024
1 parent db07b9e commit 59edf2c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
2 changes: 2 additions & 0 deletions test/TritonIntelGPU/blockptr_load.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
#dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @non_contiguous_load_dot_layout
// COM: Check mask is not generated when boundary_check is not set.
// CHECK-NOT: llvm.icmp "slt"
tt.func public @non_contiguous_load_dot_layout(%arg0: !tt.ptr<f16>, %col_stride: i64) {
%c64_i64 = arith.constant 64 : i64
%c1_i64 = arith.constant 1 : i64
Expand Down
22 changes: 12 additions & 10 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ struct LoadStoreConversionBase {
};

SmallVector<Value> ptrElems(numElems);
SmallVector<Value> maskElems(numElems);
SmallVector<Value> maskElems;
for (unsigned i = 0; i < numElems; ++i) {
auto index = indices[i];
SmallVector<Value> indicesInTensor(rank);
Expand All @@ -251,15 +251,17 @@ struct LoadStoreConversionBase {
ptrElems[i] = gep(ptr_ty(rewriter.getContext(), 1 /*global*/),
valueElemTy, blockPtr[blockBase], offset);

// Get the LLVM values for mask
maskElems[i] = linearize(
indicesInTensor,
{blockPtr.begin() + blockShape, blockPtr.begin() + blockStride},
int_val(1, 1),
[&](const Value &index, const Value &shape, const Value &mask) {
// mask = mask && (index < shape)
return and_(icmp_slt(index, trunc(i32_ty, shape)), mask);
});
if (boundaryCheck.size() > 0) {
// Get the LLVM values for mask
maskElems.push_back(linearize(
indicesInTensor,
{blockPtr.begin() + blockShape, blockPtr.begin() + blockStride},
int_val(1, 1),
[&](const Value &index, const Value &shape, const Value &mask) {
// mask = mask && (index < shape)
return and_(icmp_slt(index, trunc(i32_ty, shape)), mask);
}));
}
}

// Get the LLVM values for `other`
Expand Down

0 comments on commit 59edf2c

Please sign in to comment.