diff --git a/test/TritonIntelGPU/blockptr_load.mlir b/test/TritonIntelGPU/blockptr_load.mlir index c445193b1..2a93b6bac 100644 --- a/test/TritonIntelGPU/blockptr_load.mlir +++ b/test/TritonIntelGPU/blockptr_load.mlir @@ -261,6 +261,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war #dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}> module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-LABEL: llvm.func spir_kernelcc @non_contiguous_load_dot_layout + // COM: Check mask is not generated when boundary_check is not set. + // CHECK-NOT: llvm.icmp "slt" tt.func public @non_contiguous_load_dot_layout(%arg0: !tt.ptr, %col_stride: i64) { %c64_i64 = arith.constant 64 : i64 %c1_i64 = arith.constant 1 : i64 diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index cbb7d1c7b..1c5cda12c 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -230,7 +230,7 @@ struct LoadStoreConversionBase { }; SmallVector ptrElems(numElems); - SmallVector maskElems(numElems); + SmallVector maskElems; for (unsigned i = 0; i < numElems; ++i) { auto index = indices[i]; SmallVector indicesInTensor(rank); @@ -251,15 +251,17 @@ struct LoadStoreConversionBase { ptrElems[i] = gep(ptr_ty(rewriter.getContext(), 1 /*global*/), valueElemTy, blockPtr[blockBase], offset); - // Get the LLVM values for mask - maskElems[i] = linearize( - indicesInTensor, - {blockPtr.begin() + blockShape, blockPtr.begin() + blockStride}, - int_val(1, 1), - [&](const Value &index, const Value &shape, const Value &mask) { - // mask = mask && (index < shape) - return and_(icmp_slt(index, trunc(i32_ty, shape)), mask); - }); + if (boundaryCheck.size() > 0) { + // Get the LLVM values for mask + maskElems.push_back(linearize( + indicesInTensor, + {blockPtr.begin() + blockShape, blockPtr.begin() + blockStride}, + int_val(1, 1), + [&](const Value &index, const Value &shape, const Value &mask) { + // mask = mask && (index < shape) + return and_(icmp_slt(index, trunc(i32_ty, shape)), mask); + })); + } } // Get the LLVM values for `other`