Support causal masking in FlashAttention #1947

Draft · wants to merge 2 commits into base: perf_attn
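
For context, "causal" masking means each query position attends only to key positions at or before it. The sketch below is illustrative only and not part of this PR (the small shapes, float32, default device, and tolerances are assumptions); it mirrors the reference check in the updated tutorial by comparing PyTorch's is_causal flag against an explicit lower-triangular mask.

# Illustrative sketch (not part of this PR): the is_causal flag of
# scaled_dot_product_attention is equivalent to masking the score matrix
# with a lower-triangular mask before the softmax.
import torch

Z, H, N_CTX, D_HEAD = 1, 2, 128, 64
q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=torch.float32)
k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=torch.float32)
v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=torch.float32)

ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

mask = torch.tril(torch.ones(N_CTX, N_CTX, dtype=torch.bool))  # keep only j <= i
scores = (q @ k.transpose(-2, -1)) / (D_HEAD ** 0.5)           # default SDPA scaling
scores = scores.masked_fill(~mask, float("-inf"))
out = torch.softmax(scores, dim=-1) @ v

print("is_causal matches explicit mask:", torch.allclose(ref, out, atol=1e-4, rtol=1e-4))
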
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -101,7 +101,7 @@ endfunction()


# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-deprecated-declarations -Wno-covered-switch-default -fvisibility=hidden")

include_directories(".")
include_directories(${MLIR_INCLUDE_DIRS})
87 changes: 70 additions & 17 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUWarpPass.cpp
@@ -171,6 +171,7 @@ class ConvertTritonToTritonGPUWarp
return;
valueAttrMap.clear();
DenseMap<scf::ForOp, LoopDotInfo> loopMap;
SmallVector<Workload, 2> workloads;
for (auto loop : loops) {
auto dots = llvm::to_vector(loop.getOps<tt::DotOp>());
assert(dots.size() <= 2 && "only support 1 or 2 dot in a loop");
@@ -180,6 +181,7 @@ class ConvertTritonToTritonGPUWarp
loopMap[loop] = loopDotInfo;
// DAG pattern match
auto workLoadKind = matchLoopWorkload(loop, loopDotInfo);
workloads.push_back(workLoadKind);
if (workLoadKind == Workload::None) {
LDBG("\n");
LDBG("***********************************************\n");
@@ -384,6 +386,51 @@ class ConvertTritonToTritonGPUWarp
}
return WalkResult::advance();
});

if (loops.size() == 2 && workloads.front() == Workload::Attention &&
workloads.back() == Workload::Attention) {
// match attention with causal masking
Attribute blockLayout = loopMap[loops.front()]
.dotInfo0.dot.getResult()
.getType()
.getEncoding();

func.walk<WalkOrder::PreOrder>([&](Operation *op) {
SmallVector<RankedTensorType> typesWithoutEncoding;
for (Type ty : op->getResultTypes()) {
if (auto tty = dyn_cast<RankedTensorType>(ty))
if (!tty.getEncoding())
typesWithoutEncoding.push_back(tty);
}

if (typesWithoutEncoding.empty())
return;

if (auto cst = dyn_cast<arith::ConstantOp>(op)) {
transformArithConstantOp(cst, blockLayout);
return;
}

// Assign:
// - rank-2 operations: block layout
// - rank-1 operations: slice layout
assert(op->getNumResults() == 1);
OpResult res = op->getOpResult(0);
auto tty = cast<RankedTensorType>(res.getType());
if (tty.getRank() == 2)
res.setType(addAttrToType(tty, blockLayout));
// Rank==1 tensors get a slice layout with the axis depending on the
// use.
if (auto expand = dyn_cast<tt::ExpandDimsOp>(op)) {
Attribute sliceLayout = triton::gpu::SliceEncodingAttr::get(
blockLayout.getContext(), expand.getAxis(), blockLayout);
DenseSet<Value> chainedVals;
expandDefChain(expand.getSrc(), chainedVals);
for (auto cv : chainedVals)
cv.setType(addAttrToType(cv.getType(), sliceLayout));
}
});
}
}

/// adding module attributes
@@ -415,9 +462,12 @@ class ConvertTritonToTritonGPUWarp
result.setType(newType);
} else if (auto expand = dyn_cast<tt::ExpandDimsOp>(op)) {
auto src = expand.getSrc();
auto attr = cast<ttg::SliceEncodingAttr>(src.getType().getEncoding());
Type newType = addAttrToType(result.getType(), attr.getParent());
result.setType(newType);
if (auto attr = dyn_cast_if_present<ttg::SliceEncodingAttr>(
src.getType().getEncoding())) {
Type newType = addAttrToType(result.getType(), attr.getParent());
result.setType(newType);
}
// else: will patch up later
}
}

@@ -645,7 +695,8 @@ class ConvertTritonToTritonGPUWarp
// if (chainedVals.count(val))
// return;
// chainedVals.insert(val);
for (auto user : val.getUsers()) {
for (auto &use : val.getUses()) {
Operation *user = use.getOwner();
if (user->getDialect() == arithDialect ||
user->getDialect() == mathDialect) {
auto result = user->getResults()[0];
@@ -657,19 +708,15 @@ class ConvertTritonToTritonGPUWarp
expandDefChain(operand, chainedVals);
}
} else if (auto yield = dyn_cast<scf::YieldOp>(user)) {
unsigned resNum = -1;
unsigned i = 0;
for (auto operand : user->getOperands()) {
if (operand == val) {
resNum = i;
break;
}
i++;
}
auto loop = dyn_cast<scf::ForOp>(yield->getParentOp());
auto res = loop.getResult(resNum);
auto res = loop.getResult(use.getOperandNumber());
chainedVals.insert(res);
expandUseChain(res, chainedVals);
} else if (auto forLoop = dyn_cast<scf::ForOp>(user)) {
auto arg = forLoop.getRegionIterArg(use.getOperandNumber() -
forLoop.getNumControlOperands());
chainedVals.insert(arg);
expandUseChain(arg, chainedVals);
} else if (isa<tt::ExpandDimsOp, tt::SplatOp, tt::StoreOp>(user)) {
;
} else {
@@ -687,15 +734,21 @@ class ConvertTritonToTritonGPUWarp
assert(loop);
auto loopArg = loop.getInitArgs()[arg.getArgNumber() - 1];
expandDefChain(loopArg, chainedVals);
} else if (auto def = val.getDefiningOp()) {
} else if (auto opRes = dyn_cast<OpResult>(val)) {
Operation *def = opRes.getOwner();
if (def->getDialect() == arithDialect ||
def->getDialect() == mathDialect) {
for (auto operand : def->getOperands()) {
expandDefChain(operand, chainedVals);
expandUseChain(operand, chainedVals);
}
} else if (isa<tt::SplatOp, tt::BroadcastOp, tt::ReduceOp>(def)) {
;
} else if (auto forLoop = dyn_cast<scf::ForOp>(def)) {
Value yieldArg = forLoop.getYieldedValues()[opRes.getResultNumber()];
chainedVals.insert(yieldArg);
expandDefChain(yieldArg, chainedVals);
} else if (isa<tt::SplatOp, tt::BroadcastOp, tt::ReduceOp,
tt::MakeRangeOp>(def)) {
chainedVals.insert(def->getResult(0));
} else {
assert(0 && "add more support");
}
45 changes: 22 additions & 23 deletions python/tutorials/06-fused-attention.forward.py
@@ -155,7 +155,6 @@ def forward(q, k, v, causal, sm_scale):
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
num_warps = 8 if Lq == 64 else 16
causal = False
stage = 3 if causal else 1
grid = (q.shape[0], q.shape[1],triton.cdiv(q.shape[2], BLOCK_M))
print("Q stride =", q.stride(0), q.stride(1), q.stride(2), q.stride(3))
@@ -182,28 +181,28 @@ def forward(q, k, v, causal, sm_scale):
)
return o

# torch.manual_seed(0)
# Z = 1
# H = 2
# N_CTX = 1024
# D_HEAD = 64
# causal = False
# dtype=torch.float16
# q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype)
# k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype)
# v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype)
# sm_scale = 0.125
# dout = torch.randn_like(q)
# #torch_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
# #torch.save(torch_output, "./torch_output.pt")
torch.manual_seed(0)
Z = 1
H = 2
N_CTX = 1024
D_HEAD = 64
causal = True
dtype=torch.float16
q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype)
k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype)
v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype)
sm_scale = 0.125
dout = torch.randn_like(q)
torch_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal)
# torch.save(torch_output, "./torch_output.pt")
# torch_output = torch.load("./torch_output.pt")
# triton_output = forward(q, k, v, causal, sm_scale)
triton_output = forward(q, k, v, causal, sm_scale)

# torch_outputf32 = torch_output.to(torch.float32)
# if torch.allclose(triton_output, torch_outputf32, atol=1e-3, rtol=1e-3):
# print("✅ Triton and Torch match")
# else:
# print("❌ Triton and Torch differ")
torch_outputf32 = torch_output.to(torch.float32)
if torch.allclose(triton_output, torch_outputf32, atol=1e-3, rtol=1e-3):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")



@@ -228,15 +227,15 @@ def forward(q, k, v, causal, sm_scale):
args={},
))
def benchmark(Z, H, N_CTX, D_HEAD, provider):
causal = False
causal = True
dtype=torch.float16
q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype)
k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype)
v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype)
sm_scale = 0.125
quantiles = [0.5, 0.2, 0.8]
if provider == 'onednn':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale = sm_scale), rep=1000, quantiles=quantiles,
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal, scale = sm_scale), rep=1000, quantiles=quantiles,
fast_flush=False)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: forward(q, k, v, causal, sm_scale), rep=1000, quantiles=quantiles,
3 changes: 2 additions & 1 deletion third_party/intel/backend/compiler.py
@@ -99,7 +99,8 @@ def make_ttgir(mod, metadata, opt, device_arch):
intel.passes.ttgpuir.add_match_target_size(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
intel.passes.ttgpuir.add_schedule_load(pm)
# FIXME: re-enable pass
# intel.passes.ttgpuir.add_schedule_load(pm)
passes.common.add_symbol_dce(pm)
pm.run(mod)
return mod
13 changes: 13 additions & 0 deletions third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
@@ -144,6 +144,19 @@ def TritonGEN_SubgroupIdOp : TritonGEN_Op<"subgroup.id", [Pure]> {
}];
}

def TritonGEN_LaneIdOp : TritonGEN_Op<"lane.id", [Pure]> {
let summary = "Lane Index";
string baseDescription = [{
The `gen.lane.id` operation returns the lane ID which is a number
from 0 to the subgroup size minus one.
}];
let arguments = (ins);
let results = (outs I32:$res);
let assemblyFormat = [{
attr-dict `:` type($res)
}];
}

//===----------------------------------------------------------------------===//
// Synchronization
//===----------------------------------------------------------------------===//
4 changes: 3 additions & 1 deletion third_party/intel/lib/GPUToTritonGEN/GPUToTritonGENPass.cpp
@@ -198,7 +198,9 @@ void mlir::triton::populateGPUToTritonGENConversionPatterns(
GPUIndexIntrinsicOpLowering<mlir::gpu::GridDimOp, TritonGEN::GridDimXOp,
TritonGEN::GridDimYOp, TritonGEN::GridDimZOp>,
SingleDimLaunchConfigLowering<mlir::gpu::SubgroupIdOp,
TritonGEN::SubgroupIdOp>>(converter);
TritonGEN::SubgroupIdOp>,
SingleDimLaunchConfigLowering<mlir::gpu::LaneIdOp, TritonGEN::LaneIdOp>>(
converter);
patterns.add<GPUFuncOpLowering>(
converter,
/*allocaAddrSpace=*/TritonGEN::TritonGENMemorySpace::kFunction,
35 changes: 28 additions & 7 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
@@ -857,6 +857,27 @@ struct TritonGENSubgroupIdLowering
}
};

//===----------------------------------------------------------------------===//
// LaneID Op Lowering
//===----------------------------------------------------------------------===//

struct TritonGENLaneIdLowering
: public ConvertOpToLLVMPattern<TritonGEN::LaneIdOp> {
using ConvertOpToLLVMPattern<TritonGEN::LaneIdOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(TritonGEN::LaneIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto retType = rewriter.getIntegerType(32);

intel::AttributeList attrs;
LLVM::CallOp callOp = createDeviceFunctionCall(
rewriter, "_Z22get_sub_group_local_idv", retType, {}, {}, attrs);
rewriter.replaceOp(op, callOp);
return success();
}
};

//===----------------------------------------------------------------------===//
// Synchronization Ops Lowerings
//===----------------------------------------------------------------------===//
@@ -1220,13 +1241,13 @@ void mlir::triton::populateTritonGENToLLVMConversionPatterns(
TritonGENBlockDimXLowering, TritonGENBlockDimYLowering,
TritonGENBlockDimZLowering, TritonGENGridDimXLowering,
TritonGENGridDimYLowering, TritonGENGridDimZLowering,
TritonGENSubgroupIdLowering, TritonGENBarrierLowering,
TritonGENSplitBarrierSignalLowering, TritonGENSplitBarrierWaitLowering,
TritonGENNamedBarrierSignalLowering, TritonGENNamedBarrierWaitLowering,
TritonSubGroupShuffleLowering, TritonSubgroupReduceLowering,
TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering>(
converter);
TritonGENSubgroupIdLowering, TritonGENLaneIdLowering,
TritonGENBarrierLowering, TritonGENSplitBarrierSignalLowering,
TritonGENSplitBarrierWaitLowering, TritonGENNamedBarrierSignalLowering,
TritonGENNamedBarrierWaitLowering, TritonSubGroupShuffleLowering,
TritonSubgroupReduceLowering, TritonMatrixDPASLowering,
TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering,
TritonMatrix2DBlockPrefetchLowering>(converter);
}

void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry &registry) {