Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move output swizzling pass before fusions #1651

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions mlir/include/mlir/Dialect/Rock/utility/memoryUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef MLIR_DIALECT_ROCK_UTILITY_MEMORYUTILS_H
#define MLIR_DIALECT_ROCK_UTILITY_MEMORYUTILS_H

#include "mlir/Dialect/Rock/IR/Rock.h"
#include "llvm/ADT/MapVector.h"

namespace mlir {
namespace rock {

/// Summary of the LDS (workgroup memory) allocations of a function, as
/// collected by createInterferenceGraph().
struct LDSInfo {
  /// Interference edges: maps each GpuAllocOp to the set of GpuAllocOps whose
  /// live ranges overlap with it (and therefore cannot share storage).
  llvm::SmallDenseMap<GpuAllocOp, llvm::SetVector<GpuAllocOp>>
      interferenceGraph;
  /// All LDS allocations found in the function.
  SmallVector<GpuAllocOp> allocs;
  /// All LDS deallocations found in the function.
  SmallVector<GpuDeallocOp> deallocs;
  /// Per allocation, the set of allocations deallocated before it.
  /// NOTE(review): exact ordering semantics not visible here — confirm against
  /// createInterferenceGraph()'s implementation.
  llvm::SmallDenseMap<GpuAllocOp, llvm::SetVector<GpuAllocOp>> deallocBefore;
};

/// Utility function to get workgroup memory size.
/// Returns std::nullopt when `type` is not a workgroup (LDS) memref.
std::optional<int64_t> getWorkgroupMemorySize(MemRefType type);

/// Utility function to check if there is enough LDS on the target
/// architecture; `ldsBytes` is the total number of bytes requested.
LogicalResult checkLDSSize(Operation *op, int64_t ldsBytes);

/// This is a greedy graph coloring algorithm.
/// There are some changes to make it work for LDS, the main one:
/// each alloc can be assigned more than one color, this is because
/// in graph coloring all vertices are assumed to be the same size
/// (for example, register allocation).
/// Example: A=GpuAllocOp(1kB), B=GpuAllocOp(1kB), C=GpuAllocOp(2kB)
/// A <--> B C (A and B have an edge, C disjoint)
/// In this case, we can assign colors: A -> {0}, B -> {1}, and C -> {0, 1}.
/// Colors 0 and 1 are 1kB each.
/// Note: If an alloc has more than one color assigned, they have to be
/// consecutive.
/// Returns a (color -> size) map and, per alloc, a tuple describing its
/// assignment; the meaning of the two int64_t/bool fields is defined by the
/// implementation — see memoryUtils.cpp.
std::tuple<llvm::MapVector<int64_t, int64_t>,
           SmallVector<std::tuple<GpuAllocOp, int64_t, int64_t, bool>>>
graphColoring(LDSInfo &ldsInfo);

/// Utility function to create an interference graph of GPUAllocs and
/// GPUDeallocs.
FailureOr<LDSInfo> createInterferenceGraph(func::FuncOp &func);

/// Utility function to compute allocated LDS after LDS reuse pass.
/// Allows passes that run before RockReuseLDSPass to estimate the LDS
/// footprint that reuse will ultimately produce.
FailureOr<int64_t> getAllocatedLDSAfterReuse(func::FuncOp &func);

} // namespace rock
} // namespace mlir

#endif // MLIR_DIALECT_ROCK_UTILITY_MEMORYUTILS_H
33 changes: 16 additions & 17 deletions mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ void rock::buildKernelPipeline(OpPassManager &pm,
/* rocmlir-opt --rock-affix-params --rock-conv-to-gemm
* --rock-fold-broadcast --rock-affix-params --rock-gemm-to-gridwise
* --rock-regularize --rock-gridwise-gemm-to-blockwise
* --rock-blockwise-gemm-to-threadwise --rock-output-swizzle
*/
auto &funcPm = pm.nest<func::FuncOp>();
funcPm.addPass(rock::createRockAffixTuningParametersPass(
Expand All @@ -153,31 +154,29 @@ void rock::buildKernelPipeline(OpPassManager &pm,
funcPm.addPass(rock::createRockRegularizePass());
funcPm.addPass(rock::createRockGridwiseGemmToBlockwisePass());
funcPm.addPass(rock::createRockBlockwiseGemmToThreadwisePass());

if (options.enableFusion) {
// align linalg tiling
/* rocmlir-opt --rock-linalg-align --canonicalize
* --convert-linalg-to-affine-loops
*/
funcPm.addPass(rock::createRockLinalgAlignPass());
funcPm.addPass(rock::createRockPipelinePass());
funcPm.addPass(createCanonicalizerPass());
funcPm.addPass(createConvertLinalgToAffineLoopsPass());
funcPm.addPass(rock::createRockVectorizeFusionsPass());
}
funcPm.addPass(rock::createRockReuseLDSPass());
funcPm.addPass(rock::createRockOutputSwizzlePass());
funcPm.addPass(rock::createRockReuseLDSPass());

if (!options.enableApplicability) {
if (options.enableFusion) {
// align linalg tiling
/* rocmlir-opt --rock-linalg-align --rock-pipeline --canonicalize
* --convert-linalg-to-affine-loops --rock-vectorize-fusions
*/
funcPm.addPass(rock::createRockLinalgAlignPass());
funcPm.addPass(rock::createRockPipelinePass());
funcPm.addPass(createCanonicalizerPass());
funcPm.addPass(createConvertLinalgToAffineLoopsPass());
funcPm.addPass(rock::createRockVectorizeFusionsPass());
}
funcPm.addPass(rock::createRockReuseLDSPass());

// rock lowering for reductions
/* rocmlir-opt --rock-lower-reduce
*/
funcPm.addPass(rock::createRockLowerReducePass());

// rock lowering (block to thread)
/* rocmlir-opt --rock-lowering-blockwise-gemm-to-threadwise
* --canonicalize --rock-threadwise-gemm-lowering
// rock lowering (thread down)
/* rocmlir-opt --rock-threadwise-gemm-lowering
* --rock-analyze-memory-use --rock-sugar-to-loops --rock-clean-math
* --math-extend-to-supported-types="source-types=f64,f32,f16
* target-type=f32"
Expand Down
57 changes: 24 additions & 33 deletions mlir/lib/Dialect/Rock/Transforms/OutputSwizzle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/Dialect/Rock/IR/TransformMapBuilder.h"
#include "mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h"
#include "mlir/Dialect/Rock/utility/AmdArchDb.h"
#include "mlir/Dialect/Rock/utility/loweringUtils.h"
#include "mlir/Dialect/Rock/utility/math.h"
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
Expand All @@ -32,6 +31,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Rock/Passes.h"
#include "mlir/Dialect/Rock/utility/builderUtils.h"
#include "mlir/Dialect/Rock/utility/memoryUtils.h"
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
Expand Down Expand Up @@ -80,31 +80,6 @@ static bool hasGlobalMemoryAddressSpace(MemRefType type) {
!hasPrivateMemoryAddressSpace(type);
}

/// Sums the sizes (in bytes) of all LDS (workgroup address space) GpuAllocOps
/// in `func`. Allocations in other address spaces are ignored.
static int64_t getLDSTotalSize(func::FuncOp &func) {
  int64_t totalSize = 0;
  func.walk([&](GpuAllocOp gpuAlloc) {
    mlir::MemRefType type = gpuAlloc.getOutput().getType();
    // Bug fix: the memory space may be absent or not a gpu::AddressSpaceAttr;
    // calling getValue() on the null result of dyn_cast_or_null would be UB.
    auto memSpaceAttr =
        dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (memSpaceAttr &&
        memSpaceAttr.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace()) {
      totalSize += type.getNumElements() * getByteWidth(type.getElementType());
    }
  });
  return totalSize;
}

/// Verifies that `ldsBytes` fits within the target architecture's shared
/// memory budget. If the architecture cannot be determined from `op`, the
/// check is skipped and success is returned.
static LogicalResult checkLDSSize(Operation *op, int64_t ldsBytes) {
  FailureOr<StringAttr> maybeArch = getArch(op);
  if (failed(maybeArch))
    return success(); // Unknown arch: nothing to validate against.
  const int64_t maxLdsBytes =
      rock::lookupArchInfo(maybeArch.value()).maxSharedMemPerWG;
  return success(ldsBytes <= maxLdsBytes);
}

static std::optional<std::tuple<int64_t, int64_t, ArrayAttr>>
getIdToLDS(ThreadwiseWriteAllOp &op, OpBuilder &b) {
ArrayAttr srcTransform = op.getExtraViewsAttr();
Expand Down Expand Up @@ -137,6 +112,9 @@ struct ThreadwiseWriteAllRewritePattern
PatternRewriter &b) const override {
Location loc = op.getLoc();

if (!hasGlobalMemoryAddressSpace(op.getDest().getType()))
return b.notifyMatchFailure(op, "isn't writing to global memory");

// Prepare some useful constants.
Value convertedC = op.getSource();
Value matC = op.getDest();
Expand Down Expand Up @@ -206,8 +184,9 @@ struct ThreadwiseWriteAllRewritePattern
return success();
}
size_t extraIdxCount = op.getExtraIndices().size();
VectorizationResult vectorRes =
getMaxVectorization(destView, extraIdxCount);
VectorizationResult vectorRes = getMaxVectorization(
destView, extraIdxCount, /*inputDimLen=*/std::nullopt,
destView.getDefiningOp());
int64_t originalVectorLen = vectorRes.max;

if (vectorLen <= originalVectorLen) {
Expand Down Expand Up @@ -409,12 +388,24 @@ void RockOutputSwizzlePass::runOnOperation() {
if (!func->hasAttr("kernel"))
return;

// Get total LDS memory allocated
int64_t ldsAllocated = getLDSTotalSize(func);
// Get allocated LDS after "reuse LDS" pass
FailureOr<int64_t> maybeAllocatedLDS = getAllocatedLDSAfterReuse(func);
if (failed(maybeAllocatedLDS)) {
LLVM_DEBUG(llvm::dbgs() << "Failed calling getAllocatedLDS\n");
return signalPassFailure();
}
int64_t allocatedLDS = maybeAllocatedLDS.value();

// not enough LDS memory
if (failed(checkLDSSize(func, allocatedLDS))) {
LLVM_DEBUG(llvm::dbgs() << "We require too much LDS memory: "
<< allocatedLDS << " bytes\n");
return signalPassFailure();
}

SmallVector<Operation *, 4> writes;
func.walk([&writes, &rewriter,
ldsAllocated](ThreadwiseWriteAllOp threadwiseWriteAll) {
allocatedLDS](ThreadwiseWriteAllOp threadwiseWriteAll) {
MemRefType destMemRefType =
cast<MemRefType>(threadwiseWriteAll.getDest().getType());

Expand Down Expand Up @@ -442,10 +433,10 @@ void RockOutputSwizzlePass::runOnOperation() {
return;
}
// heuristic: if we need more LDS, skip this pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should check if there's any performance regression due to this. I'm happy to do this if you are busy with other things.

if (ldsRequiredBytes > ldsAllocated) {
if (ldsRequiredBytes > allocatedLDS) {
LLVM_DEBUG(llvm::dbgs()
<< "OutputSwizzle requires more LDS memory, current usage: "
<< ldsAllocated << " bytes, required: " << ldsRequiredBytes
<< allocatedLDS << " bytes, required: " << ldsRequiredBytes
<< " bytes, skipping pass\n");
return;
}
Expand Down
Loading