Skip to content

Commit

Permalink
[SLM]: replace Q with SLM ld/st for FlashAttention (#1656)
Browse files Browse the repository at this point in the history
Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
  • Loading branch information
3 people authored Aug 14, 2024
1 parent 03327b1 commit dede10f
Show file tree
Hide file tree
Showing 9 changed files with 497 additions and 36 deletions.
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"NVPTX_ENABLE_DUMP",
"TRITON_INTEL_ADVANCED_PATH",
"TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
"TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
"TRITON_INTEL_ENABLE_INSTR_SCHED",
"TRITON_INTEL_ENABLE_FAST_PREFETCH",
"TRITONGEN_FORCE_GENISA"
Expand Down
43 changes: 43 additions & 0 deletions test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,46 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
tt.return %0 : tensor<8x16xf32>
}
}

// -----

// COM: Checks tt.load lowering for SLM

// COM: DPAS layout and its A-operand (opIdx = 0) dot layout used below.
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockRead(!llvm.ptr<3>) -> vector<64xi16>
// CHECK-LABEL: @slm_load
// COM: A tt.load of a block pointer in address space 3 (SLM) lowers to a
// COM: call of the GenISA simdBlockRead builtin returning vector<64xi16>.
tt.func public @slm_load(%arg0: !tt.ptr<f16, 3>) {
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%c64_i64 = arith.constant 64 : i64
%ptr = tt.make_tensor_ptr %arg0, [%c0_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf16, #dot0>, 3>
// CHECK: {{.*}} = llvm.call spir_funccc @llvm.genx.GenISA.simdBlockRead({{.*}}) {function_type = !llvm.func<vector<64xi16> (ptr<3>)>, linkage = #llvm.linkage<external>, sym_name = "llvm.genx.GenISA.simdBlockRead", visibility_ = 0 : i64} : (!llvm.ptr<3>) -> vector<64xi16>
%ld = tt.load %ptr {DotIdx = 0 : i32} : !tt.ptr<tensor<16x64xf16, #dot0>, 3>
tt.return
}
}

// -----

// COM: Checks tt.store lowering for SLM

// COM: Single-warp DPAS layout for the store test below.
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockWrite(!llvm.ptr<3>, vector<64xi16>)
// CHECK-LABEL: @slm_store
// COM: A tt.store to a block pointer in address space 3 (SLM) bitcasts the
// COM: f16 value vector to i16 and calls the GenISA simdBlockWrite builtin.
tt.func public @slm_store(%arg0: !tt.ptr<f16, 3>, %arg1: tensor<16x64xf16, #dot0>) {
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%c64_i64 = arith.constant 64 : i64
%ptr = tt.make_tensor_ptr %arg0, [%c0_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf16, #dot0>, 3>
// CHECK: [[CAST:%.*]] = llvm.bitcast {{.*}} : vector<64xf16> to vector<64xi16>
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockWrite({{.*}}, [[CAST]]) {function_type = !llvm.func<void (ptr<3>, vector<64xi16>)>, linkage = #llvm.linkage<external>, sym_name = "llvm.genx.GenISA.simdBlockWrite", visibility_ = 0 : i64} : (!llvm.ptr<3>, vector<64xi16>) -> ()
tt.store %ptr, %arg1 {DotIdx = 0 : i32} : !tt.ptr<tensor<16x64xf16, #dot0>, 3>
tt.return
}
}
22 changes: 22 additions & 0 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,25 @@ llvm.func @triton_gen.dpas.bf16_accum(%c: vector<8xbf16>, %a : vector<8xi16>, %b
// CHECK-NEXT: {{%.*}} = llvm.bitcast [[RES]] : vector<8xi16> to vector<8xbf16>
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockRead(!llvm.ptr<3>) -> vector<64xi16>

// COM: triton_gen.simdblockread lowers to a call of the GenISA builtin; the
// COM: builtin declaration is emitted at module scope (matched above).
llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr<3>) {
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockRead(%arg0) {{.*}} : (!llvm.ptr<3>) -> vector<64xi16>
%ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<64xi16>
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockWrite(!llvm.ptr<3>, vector<64xi16>)

// COM: triton_gen.simdblockwrite lowers to a call of the GenISA builtin
// COM: taking the destination pointer and the value vector.
llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<64xi16>) {
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr<3>, %arg1: vector<64xi16>) {
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockWrite(%arg0, %arg1) {{.*}} : (!llvm.ptr<3>, vector<64xi16>) -> ()
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<64xi16>)
llvm.return
}
14 changes: 14 additions & 0 deletions test/TritonGEN/tritongen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,17 @@ llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
llvm.return
}

// COM: Round-trip (parse/print) test for triton_gen.simdblockread.
llvm.func @triton_gen.simdblockread(%ptr : !llvm.ptr) {
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr) {
// CHECK-NEXT: triton_gen.simdblockread %arg0 : (!llvm.ptr) -> vector<64xi16>
triton_gen.simdblockread %ptr : (!llvm.ptr) -> vector<64xi16>
llvm.return
}

// COM: Round-trip (parse/print) test for triton_gen.simdblockwrite.
llvm.func @triton_gen.simdblockwrite(%ptr : !llvm.ptr, %val : vector<64xi16>) {
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr, %arg1: vector<64xi16>) {
// CHECK-NEXT: triton_gen.simdblockwrite %arg0, %arg1 : (!llvm.ptr, vector<64xi16>)
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr, vector<64xi16>)
llvm.return
}
61 changes: 61 additions & 0 deletions test/TritonIntelGPU/slm-match-target-size.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// RUN: env TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM=1 triton-opt %s -tritonintelgpu-match-target-size | FileCheck %s

#warp = #triton_intel_gpu.warp<{sizePerThread = [32, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>
#dot1 = #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>

// COM: Test codegen in match-target-size for SLM path
// COM: With TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM=1 the pass stores the
// COM: loop-invariant A operand (%18, loaded before the loop) into address
// COM: space 3 (SLM) and reloads it from SLM inside the loop; the module
// COM: gains a triton_gpu.shared = 4096 attribute for the allocation.
// CHECK: module attributes {"triton_gpu.num-warps" = 1 : i32, triton_gpu.shared = 4096 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: @matmul_with_fixed_a
tt.func @matmul_with_fixed_a(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: f32, %arg4: !tt.ptr<f32>, %arg5: !tt.ptr<f32>) {
%c1024_i32 = arith.constant 1024 : i32
%c64_i32 = arith.constant 64 : i32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #warp>
%c65536_i64 = arith.constant 65536 : i64
%c3145728_i64 = arith.constant 3145728 : i64
%cst_2 = arith.constant 1.44269502 : f32
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%c64_i64 = arith.constant 64 : i64
%c1024_i64 = arith.constant 1024 : i64
%c128_i32 = arith.constant 128 : i32
%0 = tt.get_program_id z : i32
%1 = tt.get_program_id x : i32
%2 = tt.get_program_id y : i32
%3 = arith.extsi %1 : i32 to i64
%4 = arith.muli %3, %c3145728_i64 : i64
%5 = arith.extsi %2 : i32 to i64
%6 = arith.muli %5, %c65536_i64 : i64
%7 = arith.addi %4, %6 : i64
%8 = tt.addptr %arg0, %7 : !tt.ptr<f16>, i64
%9 = arith.muli %0, %c128_i32 : i32
%10 = tt.make_tensor_ptr %8, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot0>>
%13 = tt.addptr %arg1, %7 : !tt.ptr<f16>, i64
%14 = tt.make_tensor_ptr %13, [%c64_i64, %c1024_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x64xf16, #dot1>>
%15 = tt.addptr %arg5, %7 : !tt.ptr<f32>, i64
%16 = tt.make_tensor_ptr %15, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf32, #warp>>
%17 = arith.mulf %arg3, %cst_2 : f32
%18 = tt.load %10 : !tt.ptr<tensor<32x64xf16, #dot0>>
// COM: The 32x64 A tile is split into two 32x32 loads, glued back together,
// COM: then re-extracted as two 16x64 pieces and stored to SLM (space 3).
// CHECK: [[subA1:%.*]] = tt.load {{.*}} {DotIdx = 0 : i32} : !tt.ptr<tensor<32x32xf16>>
// CHECK: [[subA2:%.*]] = tt.load {{.*}} {DotIdx = 0 : i32} : !tt.ptr<tensor<32x32xf16>>
// CHECK: [[glueA:%.*]] = triton_intel_gpu.glue [[subA1]], [[subA2]] : (tensor<32x32xf16>, tensor<32x32xf16>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>
// CHECK: [[extracA1:%.*]] = triton_intel_gpu.extract [[glueA]][0] : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> -> tensor<16x64xf16>
// CHECK: tt.store {{.*}}, [[extracA1]] : !tt.ptr<tensor<16x64xf16>, 3>
// CHECK: [[extracA2:%.*]] = triton_intel_gpu.extract [[glueA]][1] : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> -> tensor<16x64xf16>
// CHECK: tt.store {{.*}}, [[extracA2]] : !tt.ptr<tensor<16x64xf16>, 3>
%21:3 = scf.for %arg6 = %c0_i32 to %c1024_i32 step %c64_i32 iter_args(%arg8 = %cst_1, %arg10 = %10, %arg11 = %14) -> (tensor<32x64xf32, #warp>, !tt.ptr<tensor<32x64xf16, #dot0>>, !tt.ptr<tensor<64x64xf16, #dot1>>) : i32 {
// COM: Inside the loop A is reloaded from SLM and fed to the dot product.
// CHECK: [[loadA1:%.*]] = tt.load {{.*}} {DotIdx = 0 : i32} : !tt.ptr<tensor<16x64xf16>, 3>
// CHECK: [[loadA2:%.*]] = tt.load {{.*}} {DotIdx = 0 : i32} : !tt.ptr<tensor<16x64xf16>, 3>
// CHECK: [[extractDotA:%.*]] = triton_intel_gpu.extract [[loadA1]][0] : tensor<16x64xf16> -> tensor<8x16xf16>
// CHECK: [[dot1:%.*]] = tt.dot [[extractDotA]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
%25 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16, #dot1>>
%26 = tt.dot %18, %25, %cst_1, inputPrecision = tf32 : tensor<32x64xf16, #dot0> * tensor<64x64xf16, #dot1> -> tensor<32x64xf32, #warp>
%27 = tt.advance %arg10, [%c128_i32, %c0_i32] : <tensor<32x64xf16, #dot0>>
%28 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16, #dot1>>
scf.yield %26, %27, %28 : tensor<32x64xf32, #warp>, !tt.ptr<tensor<32x64xf16, #dot0>>, !tt.ptr<tensor<64x64xf16, #dot1>>
}
tt.store %16, %21#0 : !tt.ptr<tensor<32x64xf32, #warp>>
tt.return
}
}
38 changes: 38 additions & 0 deletions third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -507,4 +507,42 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">,
let hasVerifier = 1;
}

// ODS definition of triton_gen.simdblockread: one pointer operand (tagged
// MemRead for side-effect modeling) producing a fixed vector of matrix
// element type. Lowered to the GenISA simdBlockRead builtin.
def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
Results<(outs FixedVectorOf<[TritonGEN_MatrixElemType]>:$res)>,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr
)> {

let summary = "simd block read";

let description = [{
The `triton_gen.simdblockread` operation performs simd block read from
a start address without laneId offset. The parameters are:
$ptr - the base address to read data
}];

// No custom verifier: any LLVM pointer / element vector combination accepted.
let assemblyFormat = [{
operands ` ` attr-dict `:` functional-type(operands, results)
}];
}

// ODS definition of triton_gen.simdblockwrite: pointer operand (tagged
// MemWrite) plus the value vector to store; no results. Lowered to the
// GenISA simdBlockWrite builtin.
def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
FixedVectorOf<[TritonGEN_MatrixElemType]>:$val
)> {

let summary = "simd block write";

let description = [{
The `triton_gen.simdblockwrite` operation performs simd block write to
a start address without laneId offset. The parameters are:
$ptr - the base address to be written
$val - the value vector to write
}];

// Result-less op: only operand types appear in the assembly format.
let assemblyFormat = [{
operands ` ` attr-dict `:` `(` type(operands) `)`
}];
}
#endif // TRITONGEN_OPS
49 changes: 48 additions & 1 deletion third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,52 @@ struct TritonMatrix2DBlockPrefetchLowering
}
};

struct TritonSIMDBlockReadLowering
    : public ConvertOpToLLVMPattern<TritonGEN::SIMDBlockReadOp> {
  using ConvertOpToLLVMPattern<
      TritonGEN::SIMDBlockReadOp>::ConvertOpToLLVMPattern;

  /// Lower triton_gen.simdblockread to a call of the GenISA simdBlockRead
  /// builtin, preserving the op's vector result type.
  LogicalResult
  matchAndRewrite(TritonGEN::SIMDBlockReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Use the adaptor operand rather than op.getPtr(): during dialect
    // conversion the original operand may not have been legalized yet.
    Value ptr = adaptor.getPtr();
    auto ptrTy = cast<LLVM::LLVMPointerType>(ptr.getType());
    VectorType vecTy = op.getRes().getType();

    // TODO: Remove GenISA lowering after PoC productization is completed.
    const StringLiteral funcName = "llvm.genx.GenISA.simdBlockRead";
    intel::AttributeList attrs;
    LLVM::CallOp call = createDeviceFunctionCall(rewriter, funcName, vecTy,
                                                 {ptrTy}, {ptr}, attrs);

    rewriter.replaceOp(op, call.getResult());
    return success();
  }
};

struct TritonSIMDBlockWriteLowering
    : public ConvertOpToLLVMPattern<TritonGEN::SIMDBlockWriteOp> {
  using ConvertOpToLLVMPattern<
      TritonGEN::SIMDBlockWriteOp>::ConvertOpToLLVMPattern;

  /// Lower triton_gen.simdblockwrite to a call of the GenISA simdBlockWrite
  /// builtin taking the destination pointer and the value vector.
  LogicalResult
  matchAndRewrite(TritonGEN::SIMDBlockWriteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    // Use the adaptor operands rather than the op's originals: during
    // dialect conversion the originals may not have been legalized yet.
    Value ptr = adaptor.getPtr();
    Value val = adaptor.getVal();
    auto ptrTy = cast<LLVM::LLVMPointerType>(ptr.getType());
    auto vecTy = cast<VectorType>(val.getType());

    // TODO: Remove GenISA lowering after PoC productization is completed.
    const StringLiteral funcName = "llvm.genx.GenISA.simdBlockWrite";
    intel::AttributeList attrs;
    LLVM::CallOp call = createDeviceFunctionCall(
        rewriter, funcName, void_ty(ctx), {ptrTy, vecTy}, {ptr, val}, attrs);

    rewriter.replaceOp(op, call);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1439,7 +1485,8 @@ void mlir::triton::populateTritonGENToLLVMConversionPatterns(
TritonSubGroupReduceLowering, TritonSubGroupScanLowering,
TritonSubGroupShuffleLowering, TritonMatrixDPASLowering,
TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering,
TritonMatrix2DBlockPrefetchLowering>(converter);
TritonMatrix2DBlockPrefetchLowering, TritonSIMDBlockReadLowering,
TritonSIMDBlockWriteLowering>(converter);
}

void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry &registry) {
Expand Down
78 changes: 74 additions & 4 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace {

VectorType getVectorType(RankedTensorType tensorType, Type elemType) {
// Determine a vector type of the given `elemType` that covers 1/16 of
// `tensorType`, i.e. the amout of data a single subgroup lane will work on.
// `tensorType`, i.e. the amount of data a single subgroup lane will work on.
size_t tensorSize =
tensorType.getNumElements() * tensorType.getElementTypeBitWidth();
size_t num = (tensorSize / 16) / elemType.getIntOrFloatBitWidth();
Expand Down Expand Up @@ -150,7 +150,11 @@ class LoadStorePrefetchOpConversion
vBlks = ceil(blockWidth, blockWidthUnit);
blockWidth = blockWidthUnit;
}
assert((vBlks == 1 || vBlks == 2) && "only support 1 or 2 blocks");
bool isLocalSpace = (ptrType.getAddressSpace() ==
TritonGEN::TritonGENMemorySpace::kWorkgroup);

assert(isLocalSpace ||
(vBlks == 1 || vBlks == 2) && "only support 1 or 2 blocks");

Value ptr = op.getPtr();
if (auto cast =
Expand All @@ -167,6 +171,9 @@ class LoadStorePrefetchOpConversion

OpBuilder::InsertPoint insertPoint = rewriter.saveInsertionPoint();
rewriter.setInsertionPointAfter(ptrOp);
if (isLocalSpace)
return rewriteLocalSpace(op, base, insertPoint, adaptor, rewriter);

Location loc = op.getLoc();
bool transpose = ptrOp.getOrder()[0] == 0;
Value bytes =
Expand Down Expand Up @@ -223,7 +230,7 @@ class LoadStorePrefetchOpConversion
auto newOp = rewriter.create<TritonGEN::Matrix2DBlockPrefetchOp>(
loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize,
blockWidth, blockHeight, vBlks, TritonGEN::LoadCacheControl::L1C_L3C);
VERIFY_OPERATION(newOp);
VERIFY_OPERATION(newOp)

rewriter.eraseOp(op);
} else {
Expand All @@ -234,13 +241,76 @@ class LoadStorePrefetchOpConversion
loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize,
blockWidth, blockHeight, vBlks,
bitcast(adaptor.getValue(), vectorType));
VERIFY_OPERATION(newOp);
VERIFY_OPERATION(newOp)

rewriter.eraseOp(op);
}

return success();
}

private:
LogicalResult rewriteLocalSpace(OpType op, Value base,
                                OpBuilder::InsertPoint insertPoint,
                                typename OpType::Adaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  // Lower a tt.load/tt.store on a block pointer in the local (workgroup/SLM)
  // address space to a TritonGEN SIMD block read/write of a 64 x i16 vector.
  auto ptrType = cast<PointerType>(op.getPtr().getType());
  assert(ptrType.getAddressSpace() ==
             TritonGEN::TritonGENMemorySpace::kWorkgroup &&
         "expecting local space");

  MLIRContext *ctx = rewriter.getContext();
  Location loc = op.getLoc();

  // Strip the unrealized conversion cast to recover the offset vector
  // produced by the tensor-pointer lowering.
  Value llPtr = adaptor.getPtr();
  if (auto cast =
          dyn_cast<mlir::UnrealizedConversionCastOp>(llPtr.getDefiningOp()))
    llPtr = cast.getInputs()[0];

  // sg_size(16) x i64 = 64 x i16
  VectorType v64i16Ty = VectorType::get(64, i16_ty);
  LLVM::LLVMPointerType ptrToSharedMemTy =
      ptr_ty(ctx, ptrType.getAddressSpace());
  Value offsetX = extract_element(llPtr, i32_val(0));
  Value offsetY = extract_element(llPtr, i32_val(1));

  // Map the 2D block offset to a linear i16 index into shared memory.
  // NOTE(review): assumes blocks of 8 rows x 16 columns laid out 4 per row,
  // 128 i16 elements per block -- confirm against the SLM allocation scheme
  // used by the match-target-size pass.
  Value blkId = add(mul(udiv(offsetY, i32_val(8)), i32_val(4)),
                    udiv(offsetX, i32_val(16)));
  Value index = mul(blkId, i32_val(128));
  base = gep(ptrToSharedMemTy, i16_ty, base, index);

  if constexpr (std::is_same_v<OpType, LoadOp>) {
    // (Removed unused local `v64f16Ty` that was previously declared here.)
    rewriter.restoreInsertionPoint(insertPoint);

    TritonGEN::SIMDBlockReadOp simdRead =
        rewriter.create<TritonGEN::SIMDBlockReadOp>(loc, v64i16Ty, base);
    rewriter.replaceOp(op, simdRead.getRes());

    return success();
  }

  if constexpr (std::is_same_v<OpType, StoreOp>) {
    rewriter.restoreInsertionPoint(insertPoint);

    // The value may arrive as an LLVM struct of scalars; normalize it to a
    // vector before bitcasting to 64 x i16.
    Value val = adaptor.getValue();
    // NOTE(review): `shuffleOp.getRes()` is the same SSA value as `val`, so
    // this branch looks like a no-op -- confirm the original intent.
    if (auto shuffleOp =
            dyn_cast_or_null<LLVM::ShuffleVectorOp>(val.getDefiningOp()))
      val = shuffleOp.getRes();
    if (isa<LLVM::LLVMStructType>(val.getType())) {
      SmallVector<Value> unpackedVal = unpackLLElements(loc, val, rewriter);
      val = packLLVector(loc, unpackedVal, rewriter);
    }
    val = bitcast(val, v64i16Ty);

    rewriter.create<TritonGEN::SIMDBlockWriteOp>(loc, base, val);

    rewriter.eraseOp(op);
    return success();
  }

  return failure();
}
};

/// TritonGen DpasOp Desc: XeHP SDV: dot product accumulate systolic
Expand Down
Loading

0 comments on commit dede10f

Please sign in to comment.