diff --git a/test/Conversion/intel/triton_to_tritongpu_warp.mlir b/test/Conversion/intel/triton_to_tritongpu_warp.mlir
index 7f443ca1aa..95492bbfed 100644
--- a/test/Conversion/intel/triton_to_tritongpu_warp.mlir
+++ b/test/Conversion/intel/triton_to_tritongpu_warp.mlir
@@ -96,3 +96,91 @@ module {
     tt.return
   }
 }
+
+// -----
+// CHECK1: [[BLOCKED:#.*]] = #triton_gpu.blocked<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], warpsPerCTA = [8, 1], order = [1, 0]}>
+// CHECK1: "triton_gpu.num-warps" = 8
+module {
+  tt.func public @_attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: f32, %arg4: !tt.ptr<f32>, %arg5: !tt.ptr<f32>) {
+    // CHECK1: tt.load {{.*}} : !tt.ptr<tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[BLOCKED]]}>>>
+    // CHECK1: tt.splat {{.*}} : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = [[BLOCKED]]}>>
+    // CHECK1: tt.splat {{.*}} : f32 -> tensor<128x64xf32, [[BLOCKED]]>
+    // CHECK1: tt.load {{.*}} : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[BLOCKED]]}>>>
+    // CHECK1: tt.dot {{.*}} -> tensor<128x64xf32, [[BLOCKED]]>
+    // CHECK1: tt.load {{.*}} : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[BLOCKED]]}>>>
+    // CHECK1: tt.dot {{.*}} -> tensor<128x64xf32, [[BLOCKED]]>
+    %cst = arith.constant dense<1.000000e+00> : tensor<128xf32>
+    %cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32>
+    %c1024_i32 = arith.constant 1024 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32>
+    %c65536_i64 = arith.constant 65536 : i64
+    %c131072_i64 = arith.constant 131072 : i64
+    %cst_2 = arith.constant 1.44269502 : f32
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %c64_i64 = arith.constant 64 : i64
+    %c1024_i64 = arith.constant 1024 : i64
+    %c128_i32 = arith.constant 128 : i32
+    %0 = tt.get_program_id z : i32
+    %1 = tt.get_program_id x : i32
+    %2 = tt.get_program_id y : i32
+    %3 = arith.extsi %1 : i32 to i64
+    %4 = arith.muli %3, %c131072_i64 : i64
+    %5 = arith.extsi %2 : i32 to i64
+    %6 = arith.muli %5, %c65536_i64 : i64
+    %7 = arith.addi %4, %6 : i64
+    %8 = tt.addptr %arg0, %7 : !tt.ptr<f16>, i64
+    %9 = arith.muli %0, %c128_i32 : i32
+    %10 = tt.make_tensor_ptr %8, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16>>
+    %11 = tt.addptr %arg2, %7 : !tt.ptr<f16>, i64
+    %12 = tt.make_tensor_ptr %11, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16>>
+    %13 = tt.addptr %arg1, %7 : !tt.ptr<f16>, i64
+    %14 = tt.make_tensor_ptr %13, [%c64_i64, %c1024_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x64xf16>>
+    %15 = tt.addptr %arg5, %7 : !tt.ptr<f32>, i64
+    %16 = tt.make_tensor_ptr %15, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf32>>
+    %17 = arith.mulf %arg3, %cst_2 : f32
+    %18 = tt.load %10 : !tt.ptr<tensor<128x64xf16>>
+    %19 = tt.splat %17 : f32 -> tensor<128xf32>
+    %20 = tt.splat %17 : f32 -> tensor<128x64xf32>
+    %21:5 = scf.for %arg6 = %c0_i32 to %c1024_i32 step %c64_i32 iter_args(%arg7 = %cst, %arg8 = %cst_1, %arg9 = %cst_0, %arg10 = %12, %arg11 = %14) -> (tensor<128xf32>, tensor<128x64xf32>, tensor<128xf32>, !tt.ptr<tensor<64x64xf16>>, !tt.ptr<tensor<64x64xf16>>) : i32 {
+      %25 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16>>
+      %26 = tt.dot %18, %25, %cst_1, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x64xf16> -> tensor<128x64xf32>
+      %27 = "tt.reduce"(%26) <{axis = 1 : i32}> ({
+      ^bb0(%arg12: f32, %arg13: f32):
+        %48 = arith.maxnumf %arg12, %arg13 : f32
+        tt.reduce.return %48 : f32
+      }) : (tensor<128x64xf32>) -> tensor<128xf32>
+      %28 = arith.mulf %27, %19 : tensor<128xf32>
+      %29 = arith.maxnumf %arg9, %28 : tensor<128xf32>
+      %30 = arith.mulf %26, %20 : tensor<128x64xf32>
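+      // Online-softmax step: %29 folds this tile's scaled row maxima into the
+      // running maxima; the exp2/rescale sequence below keeps the accumulated
+      // numerator and denominator consistent with the updated maxima.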
+      %31 = tt.expand_dims %29 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32>
+      %32 = tt.broadcast %31 : tensor<128x1xf32> -> tensor<128x64xf32>
+      %33 = arith.subf %30, %32 : tensor<128x64xf32>
+      %34 = math.exp2 %33 : tensor<128x64xf32>
+      %35 = "tt.reduce"(%34) <{axis = 1 : i32}> ({
+      ^bb0(%arg12: f32, %arg13: f32):
+        %48 = arith.addf %arg12, %arg13 : f32
+        tt.reduce.return %48 : f32
+      }) : (tensor<128x64xf32>) -> tensor<128xf32>
+      %36 = arith.subf %arg9, %29 : tensor<128xf32>
+      %37 = math.exp2 %36 : tensor<128xf32>
+      %38 = arith.mulf %arg7, %37 : tensor<128xf32>
+      %39 = arith.addf %38, %35 : tensor<128xf32>
+      %40 = tt.expand_dims %37 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32>
+      %41 = tt.broadcast %40 : tensor<128x1xf32> -> tensor<128x64xf32>
+      %42 = arith.mulf %arg8, %41 : tensor<128x64xf32>
+      %43 = tt.load %arg10 : !tt.ptr<tensor<64x64xf16>>
+      %44 = arith.truncf %34 : tensor<128x64xf32> to tensor<128x64xf16>
+      %45 = tt.dot %44, %43, %42, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x64xf16> -> tensor<128x64xf32>
+      %46 = tt.advance %arg10, [%c64_i32, %c0_i32] : <tensor<64x64xf16>>
+      %47 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16>>
+      scf.yield %39, %45, %29, %46, %47 : tensor<128xf32>, tensor<128x64xf32>, tensor<128xf32>, !tt.ptr<tensor<64x64xf16>>, !tt.ptr<tensor<64x64xf16>>
+    }
+    %22 = tt.expand_dims %21#0 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32>
+    %23 = tt.broadcast %22 : tensor<128x1xf32> -> tensor<128x64xf32>
+    %24 = arith.divf %21#1, %23 : tensor<128x64xf32>
+    tt.store %16, %24 : !tt.ptr<tensor<128x64xf32>>
+    tt.return
+  }
+}
diff --git a/third_party/intel/lib/TritonToTritonGPUWarp/TritonToTritonGPUWarpPass.cpp b/third_party/intel/lib/TritonToTritonGPUWarp/TritonToTritonGPUWarpPass.cpp
index 2fcce6ddf6..7a295baae8 100644
--- a/third_party/intel/lib/TritonToTritonGPUWarp/TritonToTritonGPUWarpPass.cpp
+++ b/third_party/intel/lib/TritonToTritonGPUWarp/TritonToTritonGPUWarpPass.cpp
@@ -31,6 +31,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/Debug.h"
 #include <memory>
@@ -136,6 +137,19 @@ class ConvertTritonToTritonGPUWarp
   ConvertTritonToTritonGPUWarp() = default;
   ConvertTritonToTritonGPUWarp(unsigned numWarps) { this->numWarps = numWarps; }

+private:
+  DenseMap<Value, Attribute> valueAttrMap;
+  Dialect *arithDialect = nullptr;
+  Dialect *mathDialect = nullptr;
+
+public:
+  LogicalResult initialize(MLIRContext *context) override {
+    arithDialect = context->getLoadedDialect("arith");
+    mathDialect = context->getLoadedDialect("math");
+    valueAttrMap.clear();
+    return success();
+  }
+
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     ModuleOp mod = getOperation();
@@ -166,7 +180,7 @@ class ConvertTritonToTritonGPUWarp
       // only handle gemm/flashattention for now
       if (!hasDot || loops.size() == 0)
         return;
-      DenseMap<Value, Attribute> valueAttrMap;
+      valueAttrMap.clear();
       DenseMap<scf::ForOp, LoopDotInfo> loopMap;
       for (auto loop : loops) {
         auto dots = llvm::to_vector(loop.getOps<tt::DotOp>());
@@ -216,39 +230,154 @@ class ConvertTritonToTritonGPUWarp
         LDBG("\n");
         return;
       case Workload::ElementWise:
-      case Workload::Reduction:
-      case Workload::Attention:
+      case Workload::Reduction: {
+        break;
+      }
+      case Workload::Attention: {
+        DotInfo &info0 = loopDotInfo.dotInfo0;
+        DotInfo &info1 = loopDotInfo.dotInfo1;
+        DotOp dot0 = info0.dot;
+        auto aType = cast<RankedTensorType>(dot0.getA().getType());
+        auto bType = cast<RankedTensorType>(dot0.getB().getType());
+        unsigned Br = aType.getShape()[0];
+        unsigned d = bType.getShape()[0];
+        unsigned Bc = bType.getShape()[1];
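+        // Flash-attention tile sizes in the usual notation: Br = query rows
+        // per block, Bc = key/value rows per inner-loop step, d = head dim.
+        // Q rows are split across warps, so each warp owns Br/numWarps rows.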
+        assert(Br % numWarps == 0 && "rows should be multiple of numWarps");
+        assert(Bc % numWarps == 0 && "columns should be multiple of numWarps");
+        SmallVector<unsigned> warpsPerCTA{numWarps, 1};
+        SmallVector<unsigned> sizePerWarpQ{Br / numWarps, d};
+        SmallVector<unsigned> sizePerWarpK{d, Bc};
+        SmallVector<unsigned> sizePerWarpQK{Br / numWarps, Bc};
+        SmallVector<unsigned> sizePerWarpV{Bc, d};
+        SmallVector<unsigned> sizePerWarpO{Br / numWarps, d};
+        auto ctaLayout = ttg::CTALayoutAttr::get(ctx, {1, 1}, {1, 1}, {1, 0});
+        auto oLayout = ttg::BlockedEncodingAttr::get(
+            ctx, sizePerWarpO, {1, 1}, warpsPerCTA, {1, 0}, ctaLayout);
+        auto vLayout = ttg::DotOperandEncodingAttr::get(
+            ctx, 1, oLayout, aType.getElementType());
+        auto qkLayout1 = ttg::DotOperandEncodingAttr::get(
+            ctx, 0, oLayout, aType.getElementType());
+        OpBuilder b(info1.dot);
+        auto dot1A = info1.dot.getA();
+        auto cvtType = addAttrToType(dot1A.getType(), qkLayout1);
+        // add a convert-layout op for dot1's A operand
+        auto cvt = b.create<ttg::ConvertLayoutOp>(info1.dot.getLoc(), cvtType,
+                                                  dot1A);
+        dot1A.replaceAllUsesExcept(cvt, cvt);
+        auto qkLayout0 = ttg::BlockedEncodingAttr::get(
+            ctx, sizePerWarpQK, {1, 1}, warpsPerCTA, {1, 0}, ctaLayout);
+        auto qLayout = ttg::DotOperandEncodingAttr::get(
+            ctx, 0, qkLayout0, aType.getElementType());
+        auto kLayout = ttg::DotOperandEncodingAttr::get(
+            ctx, 1, qkLayout0, aType.getElementType());
+
+        // record each chained value's layout attribute
+        for (auto val : info0.chainOpsA)
+          valueAttrMap[val] = qLayout;
+        for (auto val : info0.chainOpsB)
+          valueAttrMap[val] = kLayout;
+        for (auto val : info0.chainOpsC)
+          valueAttrMap[val] = qkLayout0;
+        if (info0.advanceA)
+          valueAttrMap[info0.advanceA] = qLayout;
+        if (info0.advanceB)
+          valueAttrMap[info0.advanceB] = kLayout;
+
+        assert(info1.chainOpsA.empty());
+        for (auto val : info1.chainOpsB)
+          valueAttrMap[val] = vLayout;
+        for (auto val : info1.chainOpsC) {
+          if (valueAttrMap.count(val) == 0) {
+            valueAttrMap[val] = oLayout;
+          } else if (valueAttrMap[val] == oLayout) {
+            continue;
+          } else {
+            auto op = val.getDefiningOp();
+            // clone the value if it is used with more than one layout
+            if (auto cst = dyn_cast<arith::ConstantOp>(op)) {
+              OpBuilder b(cst);
+              auto newOp = b.clone(*op);
+              auto result = newOp->getResults()[0];
+              valueAttrMap[result] = oLayout;
+              val.replaceUsesWithIf(result, [&](OpOperand &use) {
+                Operation *user = use.getOwner();
+                auto val = user->getResults()[0];
+                if (std::find(info1.chainOpsC.begin(), info1.chainOpsC.end(),
+                              val) != info1.chainOpsC.end())
+                  return true;
+                return false;
+              });
+            } else {
+              assert(0 && "add more support");
+            }
+          }
+        }
+        assert(!info1.advanceA);
+        if (info1.advanceB)
+          valueAttrMap[info1.advanceB] = vLayout;
         break;
       }
+      }
       loop->setAttr(AttrWorkloadName,
                     IntegerAttr::get(i32Ty, int64_t(workLoadKind)));
     }

-    /// adding tensor layout attr to related ops
-    func.walk([&](Operation *op) -> WalkResult {
-      auto hasTensorType = [&](Type type) {
-        if (isa<RankedTensorType>(type))
-          return true;
-        else if (auto ptrType = dyn_cast<tt::PointerType>(type))
-          if (isa<RankedTensorType>(ptrType.getPointeeType()))
-            return true;
-        return false;
-      };
+    auto opHasTensorType = [&](Operation *op) {
       auto oprndHasTensorType =
-          llvm::any_of(op->getOperandTypes(), hasTensorType);
+          llvm::any_of(op->getOperandTypes(), isTensorOrTensorPointerType);
       auto resultHasTensorType =
-          llvm::any_of(op->getResultTypes(), hasTensorType);
-      if (!oprndHasTensorType && !resultHasTensorType)
+          llvm::any_of(op->getResultTypes(), isTensorOrTensorPointerType);
+      return oprndHasTensorType || resultHasTensorType;
+    };
+
+    /// derive remaining values' layout attrs from their def/use chains
+    func.walk([&](Operation *op) {
+      if (auto loop = dyn_cast<scf::ForOp>(op))
+        return;
+      else if (!opHasTensorType(op))
+        return;
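+      // A tt.reduce result is laid out as a slice of its source layout along
+      // the reduced axis, and that slice layout must follow every value
+      // reachable from the reduce result.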
+      else if (auto reduce = dyn_cast<tt::ReduceOp>(op)) {
+        assert(reduce.getSrcs().size() == 1);
+        auto axis = reduce.getAxis();
+        auto src = reduce.getSrcs()[0];
+        assert(valueAttrMap.count(src) != 0 &&
+               "reduce source attr should be already figured out");
+        auto sliceAttr =
+            ttg::SliceEncodingAttr::get(ctx, axis, valueAttrMap[src]);
+        auto result = reduce.getResults()[0];
+        DenseSet<Value> chainedVals;
+        chainedVals.insert(result);
+        expandUseChain(result, chainedVals);
+        for (auto val : chainedVals) {
+          valueAttrMap[val] = sliceAttr;
+        }
+      } else if (op->getDialect() == arithDialect ||
+                 op->getDialect() == mathDialect) {
+        // FIXME: really ad hoc; force the mul's rhs to take the layout of
+        // the result so it is not left unassigned.
+        if (auto mul = dyn_cast<arith::MulFOp>(op)) {
+          auto rhs = mul.getRhs();
+          valueAttrMap[rhs] = valueAttrMap[op->getResults()[0]];
+        }
+      }
+    });
+
+    /// adding tensor layout attr to related ops
+    func.walk([&](Operation *op) -> WalkResult {
+      if (!opHasTensorType(op))
         return WalkResult::advance();
-      auto numResults = op->getResults().size();
+      unsigned numResults = op->getResults().size();
       if (auto cst = dyn_cast<arith::ConstantOp>(op)) {
         transformArithConstantOp(cst, valueAttrMap[cst]);
       } else if (auto loop = dyn_cast<scf::ForOp>(op)) {
         transformScfForOp(loop);
       } else if (auto store = dyn_cast<tt::StoreOp>(op)) {
         transformStoreOp(store);
+      } else if (auto convert = dyn_cast<ttg::ConvertLayoutOp>(op)) {
+        ; // already carries its target layout; nothing to do
+        // arith, math, tt::ExpandDimsOp, tt::SplatOp
       } else if (numResults != 0) {
         assert(numResults == 1 && "only support 1 result");
         transformGenericOp(op, valueAttrMap);
@@ -267,8 +396,6 @@ class ConvertTritonToTritonGPUWarp
   }

   void transformGenericOp(Operation *op, DenseMap<Value, Attribute> &map) {
-    Dialect *arithDialect = op->getContext()->getLoadedDialect("arith");
-    Dialect *mathDialect = op->getContext()->getLoadedDialect("math");
     auto result = op->getResults()[0];
     // if already got
     if (map.count(result) != 0) {
@@ -277,7 +404,7 @@ class ConvertTritonToTritonGPUWarp
     }
     // get the attr by propagating
     else if (op->getDialect() == arithDialect ||
-             op->getDialect() == mathDialect) {
+             op->getDialect() == mathDialect || isa<tt::BroadcastOp>(op)) {
       Attribute attr;
       for (auto operand : op->getOperands()) {
         if (auto type = dyn_cast<RankedTensorType>(operand.getType()))
@@ -286,6 +413,11 @@ class ConvertTritonToTritonGPUWarp
       }
       auto newType = addAttrToType(result.getType(), attr);
       result.setType(newType);
+    } else if (auto expand = dyn_cast<tt::ExpandDimsOp>(op)) {
+      auto src = expand.getSrc();
+      auto attr = cast<ttg::SliceEncodingAttr>(src.getType().getEncoding());
+      Type newType = addAttrToType(result.getType(), attr.getParent());
+      result.setType(newType);
     }
   }

@@ -387,6 +519,33 @@ class ConvertTritonToTritonGPUWarp
           offsetsB[1] == 0)
         return Workload::Gemm;
     }
+    // match the attention qkv pattern:
+    //   %q (loop-invariant)
+    //   scf.for idx
+    //     %k = tt.load %ptrK
+    //     %s = tt.dot %q, %k
+    //     %ss = arith/math ops applied to %s
+    //     %v = tt.load %ptrV
+    //     %o = tt.dot %ss, %v
+    //     tt.advance %ptrK, [0, stepK]
+    //     tt.advance %ptrV, [stepV, 0]
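+    // (K is stored d x seq and V seq x d, as in the test above, so K must
+    // advance along dim 1 and V along dim 0 by the same constant step.)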
+    else if (loopDotInfo.dotInfo0.dot && loopDotInfo.dotInfo1.dot) {
+      if (!loopDotInfo.connectDotA || loopDotInfo.connectDotB)
+        return Workload::None;
+      auto &info0 = loopDotInfo.dotInfo0;
+      auto &info1 = loopDotInfo.dotInfo1;
+      if (!info0.chainOpsA.empty() && // Q is loop-invariant
+          info0.chainOpsA[0].getDefiningOp()->isBeforeInBlock(loop) &&
+          info0.advanceB && info1.advanceB) {
+        SmallVector<OpFoldResult> rawOffsetsK = info0.advanceB.getOffsets();
+        SmallVector<OpFoldResult> rawOffsetsV = info1.advanceB.getOffsets();
+        auto offsetsK = *getConstantIntValues(rawOffsetsK);
+        auto offsetsV = *getConstantIntValues(rawOffsetsV);
+        if (offsetsK.size() == 2 && offsetsV.size() == 2 &&
+            offsetsK[0] == 0 && offsetsV[1] == 0 &&
+            offsetsK[1] == offsetsV[0])
+          return Workload::Attention;
+      }
+    }
     return Workload::None;
   }

@@ -404,8 +563,6 @@ class ConvertTritonToTritonGPUWarp

   void expandDefChain(scf::ForOp loop, Value val, SmallVector<Value> &ops,
                       tt::LoadOp &load, tt::AdvanceOp &advance) {
-    Dialect *arithDialect = val.getContext()->getLoadedDialect("arith");
-    Dialect *mathDialect = val.getContext()->getLoadedDialect("math");
     ops.push_back(val);
     // // val is loop invariant
     // if (loop.isDefinedOutsideOfLoop(val))
@@ -443,8 +600,6 @@ class ConvertTritonToTritonGPUWarp

   void expandDotCChain(scf::ForOp loop, tt::DotOp dot, SmallVector<Value> &ops,
                        LoopDotInfo &loopDotInfo) {
-    Dialect *arithDialect = dot.getContext()->getLoadedDialect("arith");
-    Dialect *mathDialect = dot.getContext()->getLoadedDialect("math");
     SmallVector<Value> defList;
     tt::LoadOp nullLoad;
     tt::AdvanceOp nullAdv;
@@ -462,6 +617,84 @@ class ConvertTritonToTritonGPUWarp
       else if (op->getDialect() == arithDialect ||
               op->getDialect() == mathDialect)
         ops.push_back(op->getResults()[0]);
+      else if (isa<tt::ReduceOp, tt::ExpandDimsOp, tt::BroadcastOp>(op))
+        ;
+      else if (auto dot1 = dyn_cast<tt::DotOp>(op)) {
+        auto &info1 = loopDotInfo.dotInfo1;
+        info1.dot = dot1;
+        auto dotA = dot1.getA();
+        auto dotB = dot1.getB();
+        auto dotC = dot1.getC();
+        if (std::find(ops.begin(), ops.end(), dotA) == ops.end())
+          expandDefChain(loop, dotA, info1.chainOpsA, info1.loadA,
+                         info1.advanceA);
+        else
+          loopDotInfo.connectDotA = true;
+        if (std::find(ops.begin(), ops.end(), dotB) == ops.end())
+          expandDefChain(loop, dotB, info1.chainOpsB, info1.loadB,
+                         info1.advanceB);
+        else
+          loopDotInfo.connectDotB = true;
+        if (std::find(ops.begin(), ops.end(), dotC) == ops.end())
+          expandDotCChain(loop, dot1, info1.chainOpsC, loopDotInfo);
+        else
+          loopDotInfo.connectDotC = true;
+      }
     }
   }
+
+  void expandUseChain(Value val, DenseSet<Value> &chainedVals) {
+    for (auto &use : val.getUses()) {
+      Operation *op = use.getOwner();
+      // arith/math ops
+      if (op->getDialect() == arithDialect ||
+          op->getDialect() == mathDialect) {
+        Value result = op->getResults()[0];
+        if (chainedVals.count(result) == 0) {
+          chainedVals.insert(result);
+          expandUseChain(result, chainedVals);
+        }
+        for (auto operand : op->getOperands()) {
+          expandDefChain(operand, chainedVals);
+        }
+        // yield
+      } else if (auto yield = dyn_cast<scf::YieldOp>(op)) {
+        auto loop = cast<scf::ForOp>(yield->getParentOp());
+        Value res = loop.getResult(use.getOperandNumber());
+        chainedVals.insert(res);
+        expandUseChain(res, chainedVals);
+        // expanddims, splat, store
+      } else if (isa<tt::ExpandDimsOp, tt::SplatOp, tt::StoreOp>(op)) {
+        continue;
+        // other ops
+      } else {
+        assert(0 && "add more support");
+      }
+    }
+  }
+
+  void expandDefChain(Value val, DenseSet<Value> &chainedVals) {
+    if (chainedVals.count(val))
+      return;
+    chainedVals.insert(val);
+    if (auto arg = dyn_cast<BlockArgument>(val)) {
+      auto loop = dyn_cast<scf::ForOp>(arg.getOwner()->getParentOp());
+      assert(loop);
+      auto loopArg = loop.getInitArgs()[arg.getArgNumber() - 1];
+      expandDefChain(loopArg, chainedVals);
+    } else if (auto def = val.getDefiningOp()) {
+      if (def->getDialect() == arithDialect ||
+          def->getDialect() == mathDialect) {
+        for (auto operand : def->getOperands()) {
+          expandDefChain(operand, chainedVals);
+          expandUseChain(operand, chainedVals);
+        }
+      } else if (isa<tt::ReduceOp, tt::SplatOp>(def)) {
+        ;
+      } else {
+        assert(0 && "add more support");
+      }
+    } else {
+      assert(0 && "add more support");
+    }
+  }
 };