
Refactor Reuse LDS to be able to keep the same heuristic and also revert the enableApplicability change
dhernandez0 authored and krzysz00 committed Oct 4, 2024
1 parent 4dde937 commit 3a49093
Showing 6 changed files with 398 additions and 305 deletions.
49 changes: 49 additions & 0 deletions mlir/include/mlir/Dialect/Rock/utility/memoryUtils.h
@@ -0,0 +1,49 @@
#ifndef MLIR_DIALECT_ROCK_UTILITY_MEMORYUTILS_H
#define MLIR_DIALECT_ROCK_UTILITY_MEMORYUTILS_H

#include "mlir/Dialect/Rock/IR/Rock.h"
#include "llvm/ADT/MapVector.h"

namespace mlir {
namespace rock {

struct LDSInfo {
  llvm::SmallDenseMap<GpuAllocOp, llvm::SetVector<GpuAllocOp>>
      interferenceGraph;
  SmallVector<GpuAllocOp> allocs;
  SmallVector<GpuDeallocOp> deallocs;
  llvm::SmallDenseMap<GpuAllocOp, llvm::SetVector<GpuAllocOp>> deallocBefore;
};

/// Utility function to get workgroup memory size
std::optional<int64_t> getWorkgroupMemorySize(MemRefType type);

/// Utility function to check if there is enough LDS on the target architecture
LogicalResult checkLDSSize(Operation *op, int64_t ldsBytes);

/// This is a greedy graph coloring algorithm, adapted for LDS allocation.
/// The main change is that each alloc can be assigned more than one color,
/// because classical graph coloring assumes all vertices are the same size
/// (as in register allocation, for example), which does not hold for LDS
/// buffers.
/// Example: A=GpuAllocOp(1kB), B=GpuAllocOp(1kB), C=GpuAllocOp(2kB)
///   A <--> B    C   (A and B have an edge; C is disjoint)
/// In this case, we can assign colors: A -> {0}, B -> {1}, and C -> {0, 1},
/// where colors 0 and 1 are 1kB each.
/// Note: if an alloc is assigned more than one color, the colors must be
/// consecutive.
std::tuple<llvm::MapVector<int64_t, int64_t>,
           SmallVector<std::tuple<GpuAllocOp, int64_t, int64_t, bool>>>
graphColoring(LDSInfo &ldsInfo);

/// Utility function to create an interference graph of GPUAllocs and
/// GPUDeallocs
FailureOr<LDSInfo> createInterferenceGraph(func::FuncOp &func);

/// Utility function to compute allocated LDS after LDS reuse pass.
FailureOr<int64_t> getAllocatedLDSAfterReuse(func::FuncOp &func);

} // namespace rock
} // namespace mlir

#endif // MLIR_DIALECT_ROCK_UTILITY_MEMORYUTILS_H
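
To make the multi-color assignment described in the graphColoring comment concrete, here is a small standalone C++ sketch of the A/B/C example. The container types and sizes are hypothetical illustrations, not the actual return types of graphColoring: colors are laid out back to back in LDS, and an alloc occupies the consecutive colors it was assigned.

#include <cstdint>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main() {
  // The example from the comment: colors 0 and 1 are 1 kB each.
  std::map<int64_t, int64_t> colorSizes = {{0, 1024}, {1, 1024}};

  // A -> {0}, B -> {1}, C -> {0, 1}; a multi-color alloc uses consecutive
  // colors so that its LDS range stays contiguous.
  std::map<std::string, std::vector<int64_t>> assignment = {
      {"A", {0}}, {"B", {1}}, {"C", {0, 1}}};

  // Lay the colors out back to back; an alloc starts at the offset of its
  // first color and spans the sum of its colors' sizes.
  std::map<int64_t, int64_t> colorOffset;
  int64_t totalLDS = 0;
  for (auto [color, size] : colorSizes) {
    colorOffset[color] = totalLDS;
    totalLDS += size;
  }
  for (const auto &[name, colors] : assignment) {
    int64_t bytes = 0;
    for (int64_t c : colors)
      bytes += colorSizes[c];
    // Prints: A at 0 (1024 B), B at 1024 (1024 B), C at 0 (2048 B).
    std::printf("%s at %lld (%lld B)\n", name.c_str(),
                (long long)colorOffset[colors.front()], (long long)bytes);
  }
  // Total LDS with reuse: 2048 B instead of 4096 B without it.
  std::printf("total %lld B\n", (long long)totalLDS);
  return 0;
}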
26 changes: 13 additions & 13 deletions mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp
@@ -156,20 +156,20 @@ void rock::buildKernelPipeline(OpPassManager &pm,
funcPm.addPass(rock::createRockBlockwiseGemmToThreadwisePass());
funcPm.addPass(rock::createRockOutputSwizzlePass());

if (options.enableFusion) {
// align linalg tiling
/* rocmlir-opt --rock-linalg-align --rock-pipeline --canonicalize
* --convert-linalg-to-affine-loops --rock-vectorize-fusions
*/
funcPm.addPass(rock::createRockLinalgAlignPass());
funcPm.addPass(rock::createRockPipelinePass());
funcPm.addPass(createCanonicalizerPass());
funcPm.addPass(createConvertLinalgToAffineLoopsPass());
funcPm.addPass(rock::createRockVectorizeFusionsPass());
}
funcPm.addPass(rock::createRockReuseLDSPass());

if (!options.enableApplicability) {
if (options.enableFusion) {
// align linalg tiling
/* rocmlir-opt --rock-linalg-align --rock-pipeline --canonicalize
* --convert-linalg-to-affine-loops --rock-vectorize-fusions
*/
funcPm.addPass(rock::createRockLinalgAlignPass());
funcPm.addPass(rock::createRockPipelinePass());
funcPm.addPass(createCanonicalizerPass());
funcPm.addPass(createConvertLinalgToAffineLoopsPass());
funcPm.addPass(rock::createRockVectorizeFusionsPass());
}
funcPm.addPass(rock::createRockReuseLDSPass());

// rock lowering for reductions
/* rocmlir-opt --rock-lower-reduce
*/
44 changes: 29 additions & 15 deletions mlir/lib/Dialect/Rock/Transforms/OutputSwizzle.cpp
@@ -22,7 +22,6 @@
#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/Dialect/Rock/IR/TransformMapBuilder.h"
#include "mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h"
#include "mlir/Dialect/Rock/utility/AmdArchDb.h"
#include "mlir/Dialect/Rock/utility/loweringUtils.h"
#include "mlir/Dialect/Rock/utility/math.h"
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
@@ -32,6 +31,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Rock/Passes.h"
#include "mlir/Dialect/Rock/utility/builderUtils.h"
#include "mlir/Dialect/Rock/utility/memoryUtils.h"
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -80,17 +80,6 @@ static bool hasGlobalMemoryAddressSpace(MemRefType type) {
!hasPrivateMemoryAddressSpace(type);
}

static LogicalResult checkLDSSize(Operation *op, int64_t ldsBytes) {
// Check for arch limitations exceeded
FailureOr<StringAttr> maybeArch = getArch(op);
if (succeeded(maybeArch)) {
StringAttr arch = maybeArch.value();
const int64_t ldsSize = rock::lookupArchInfo(arch).maxSharedMemPerWG;
return success(ldsBytes <= ldsSize);
}
return success();
}

static std::optional<std::tuple<int64_t, int64_t, ArrayAttr>>
getIdToLDS(ThreadwiseWriteAllOp &op, OpBuilder &b) {
ArrayAttr srcTransform = op.getExtraViewsAttr();
@@ -195,8 +184,9 @@ struct ThreadwiseWriteAllRewritePattern
return success();
}
size_t extraIdxCount = op.getExtraIndices().size();
VectorizationResult vectorRes =
getMaxVectorization(destView, extraIdxCount, /*inputDimLen=*/std::nullopt, destView.getDefiningOp());
VectorizationResult vectorRes = getMaxVectorization(
destView, extraIdxCount, /*inputDimLen=*/std::nullopt,
destView.getDefiningOp());
int64_t originalVectorLen = vectorRes.max;

if (vectorLen <= originalVectorLen) {
@@ -398,8 +388,24 @@ void RockOutputSwizzlePass::runOnOperation() {
if (!func->hasAttr("kernel"))
return;

// Get allocated LDS after "reuse LDS" pass
FailureOr<int64_t> maybeAllocatedLDS = getAllocatedLDSAfterReuse(func);
if (failed(maybeAllocatedLDS)) {
LLVM_DEBUG(llvm::dbgs() << "Failed calling getAllocatedLDS\n");
return signalPassFailure();
}
int64_t allocatedLDS = maybeAllocatedLDS.value();

// not enough LDS memory
if (failed(checkLDSSize(func, allocatedLDS))) {
LLVM_DEBUG(llvm::dbgs() << "We require too much LDS memory: "
<< allocatedLDS << " bytes\n");
return signalPassFailure();
}

SmallVector<Operation *, 4> writes;
func.walk([&writes, &rewriter](ThreadwiseWriteAllOp threadwiseWriteAll) {
func.walk([&writes, &rewriter,
allocatedLDS](ThreadwiseWriteAllOp threadwiseWriteAll) {
MemRefType destMemRefType =
cast<MemRefType>(threadwiseWriteAll.getDest().getType());

Expand All @@ -426,6 +432,14 @@ void RockOutputSwizzlePass::runOnOperation() {
<< ldsRequiredBytes << " bytes, skipping pass\n");
return;
}
// heuristic: if we need more LDS, skip this pass
if (ldsRequiredBytes > allocatedLDS) {
LLVM_DEBUG(llvm::dbgs()
<< "OutputSwizzle requires more LDS memory, current usage: "
<< allocatedLDS << " bytes, required: " << ldsRequiredBytes
<< " bytes, skipping pass\n");
return;
}
writes.push_back(threadwiseWriteAll);
}
});
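
Condensed, the new per-write gating in runOnOperation reduces to two comparisons. The sketch below restates that logic standalone, with plain constants standing in for the arch lookup behind checkLDSSize and for the value returned by getAllocatedLDSAfterReuse; both numbers are assumptions for illustration, not values from the pass.

#include <cstdint>
#include <cstdio>

// Stand-ins for what the pass queries:
//  - the hardware limit comes from the arch info used by checkLDSSize,
//  - the post-reuse usage comes from getAllocatedLDSAfterReuse(func).
constexpr int64_t kMaxSharedMemPerWG = 64 * 1024; // hypothetical target limit
constexpr int64_t kAllocatedLDS = 32 * 1024;      // hypothetical post-reuse usage

// Should OutputSwizzle rewrite a given ThreadwiseWriteAllOp?
bool shouldSwizzle(int64_t ldsRequiredBytes) {
  if (ldsRequiredBytes > kMaxSharedMemPerWG)
    return false; // would exceed the hardware LDS limit
  if (ldsRequiredBytes > kAllocatedLDS)
    return false; // heuristic: do not grow LDS beyond what reuse already allocated
  return true;
}

int main() {
  // Prints: 1 0 0 for these hypothetical requirements.
  std::printf("%d %d %d\n", shouldSwizzle(16 * 1024),
              shouldSwizzle(48 * 1024), shouldSwizzle(80 * 1024));
  return 0;
}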
