Skip to content

Commit

Permalink
Remove dead values before shape refinement
Browse files Browse the repository at this point in the history
Some programs have a lot of dead values, causing shape refinement to fail to
converge. Indeed, even if there is no shape refinement to do, greedy pattern
rewrites keep iterating as long as the IR changes, which includes removing dead
values.

PiperOrigin-RevId: 619606321
  • Loading branch information
Michael Levesque-Dion authored and TensorFlow MLIR Team committed Mar 27, 2024
1 parent 0272de9 commit 0b73f94
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 0 deletions.
1 change: 1 addition & 0 deletions stablehlo/stablehlo/experimental/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ cc_library(
"transforms/ChloRecomposeOps.cpp",
"transforms/StablehloCanonicalizeDynamism.cpp",
"transforms/StablehloRefineShapes.cpp",
"transforms/StablehloTrivialDce.cpp",
],
hdrs = [
"transforms/Passes.h",
Expand Down
1 change: 1 addition & 0 deletions stablehlo/stablehlo/experimental/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ cc_library(
"transforms/ChloRecomposeOps.cpp",
"transforms/StablehloCanonicalizeDynamism.cpp",
"transforms/StablehloRefineShapes.cpp",
"transforms/StablehloTrivialDce.cpp",
],
hdrs = [
"transforms/Passes.h",
Expand Down
1 change: 1 addition & 0 deletions stablehlo/stablehlo/experimental/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ add_mlir_dialect_library(ExperimentalStablehloPasses
ChloRecomposeOps.cpp
StablehloCanonicalizeDynamism.cpp
StablehloRefineShapes.cpp
StablehloTrivialDce.cpp

DEPENDS
ExperimentalPassesIncGen
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/stablehlo/experimental/transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,19 @@ def ChloRecomposeOpsPass : Pass<"experimental-chlo-recompose-ops", "ModuleOp"> {
}];
let dependentDialects = ["chlo::ChloDialect"];
}

def StablehloTrivialDcePass : Pass<"experimental-stablehlo-trivial-dce", "ModuleOp"> {
let summary = "(Experimental) Performs a single bottom up pass to remove values that are trivially dead.";
let description = [{
An experimental pass to remove dead values prior to running other passes
that may fail to converge otherwise. For example, running shape refinement
on a program that has a lot of dead values can fail because shape refinement
is top down and removing values causes a new iteration to be triggered, and
removing all the dead values with a top down traversal can take a lot of
iterations (10+), which is slow.

Performing a single pass should be fast, and doing it bottom up means that
values that are transitively dead can be removed since leaf values will be
processed first.
}];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/* Copyright 2022 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/experimental/transforms/Passes.h"

namespace mlir {
namespace stablehlo {
namespace experimental {

#define GEN_PASS_DEF_STABLEHLOTRIVIALDCEPASS
#include "stablehlo/experimental/transforms/Passes.h.inc"

namespace {

struct StablehloTrivialDcePass
: public impl::StablehloTrivialDcePassBase<StablehloTrivialDcePass> {
using StablehloTrivialDcePassBase::StablehloTrivialDcePassBase;

void runOnOperation() override {
GreedyRewriteConfig config;

// Hardcode defaults for stability.
config.enableRegionSimplification = true;
config.maxNumRewrites = GreedyRewriteConfig::kNoLimit;
config.strictMode = GreedyRewriteStrictness::AnyOp;

// Run a single bottom up pass.
config.useTopDownTraversal = false;
config.maxIterations = 1;

// Running a greedy rewrite will cause trivially dead values to be removed.
// Doing it without patterns ensures that no other changes are made to the
// IR. Doing it bottom-up ensures that values that are transitively dead are
// also removed. Although 1 pass should be enough,
// applyPatternsAndFoldGreedily will want to run at least 1 more iteration
// to confirm convergence, but we don't need to check for convergence, so we
// ignore the return value.
(void)applyPatternsAndFoldGreedily(getOperation(), RewritePatternSet(&getContext()), config);
}
};

} // namespace
} // namespace experimental
} // namespace stablehlo
} // namespace mlir

0 comments on commit 0b73f94

Please sign in to comment.