Skip to content

Commit

Permalink
Integrate StableHLO at openxla/stablehlo@b6406a43
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629847908
  • Loading branch information
sdasgup3 authored and copybara-github committed May 1, 2024
1 parent 2a88402 commit 69f3233
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 19 deletions.
15 changes: 2 additions & 13 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt
--- stablehlo/CMakeLists.txt
+++ stablehlo/CMakeLists.txt
@@ -13,153 +13,20 @@
@@ -13,154 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
Expand Down Expand Up @@ -43,6 +43,7 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt
#-------------------------------------------------------------------------------
-option(STABLEHLO_BUILD_EMBEDDED "Build StableHLO as part of another project" OFF)
-option(STABLEHLO_ENABLE_BINDINGS_PYTHON "Enables StableHLO Python bindings" OFF)
-option(STABLEHLO_ENABLE_PYTHON_TF_TESTS "Enables StableHLO to SavedModel tests requiring TF" OFF)
-option(STABLEHLO_ENABLE_STRICT_BUILD "Build StableHLO with strict warnings and warnings as errors" OFF)
-option(STABLEHLO_ENABLE_SANITIZER "Enable a sanitizer [OFF, address]" OFF)
-option(STABLEHLO_ENABLE_SPLIT_DWARF "Enable split DWARF if the platform supports it" OFF)
Expand Down Expand Up @@ -2563,16 +2564,4 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp
--- stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp
+++ stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp
@@ -735,7 +735,7 @@
addToWorkList(retOp.getOperand(resultNo));
while (!workList.empty()) {
auto definition = workList.pop_back_val();
- if (auto blockArg = definition.dyn_cast<BlockArgument>()) {
+ if (auto blockArg = dyn_cast<BlockArgument>(definition)) {
// using one argument implies using the whole argument pair
const auto pairNo = blockArg.getArgNumber() % numOperandPairs;
usedArgs.set(pairNo);

4 changes: 2 additions & 2 deletions third_party/stablehlo/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
# LINT.IfChange
STABLEHLO_COMMIT = "5217297204acb9e5a21e40fa825aa0769fb3c33f"
STABLEHLO_SHA256 = "9090e7f31420ef2bf5dc17a385f6e828f0072a79fa59d065c01c9a90e46ee730"
STABLEHLO_COMMIT = "b6406a43b48b7803f3efdbc235b1fbb5da782449"
STABLEHLO_SHA256 = "d2ecc5fe29f4a2fd17723f1f403109d24cafe8479f1e1a95d27d447c84a01976"
# LINT.ThenChange(Google-internal path)

tf_http_archive(
Expand Down
8 changes: 4 additions & 4 deletions xla/mlir_hlo/tests/Dialect/mhlo/verifier_conv_op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,8 @@ func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,

// -----

func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
%arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
func.func @invalid_conv_dimensions(%arg0: tensor<3x8x8x207xf32>,
%arg1: tensor<3x3x207x16xf32>) -> tensor<3x8x8x16xf32> {
// expected-error@+1 {{expects output feature dimension size (16) to be a multiple of batch_group_count. Got batch_group_count = 3.}}
%0 = mhlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
Expand All @@ -342,8 +342,8 @@ func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
feature_group_count = 1 : i64,
precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
} :
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
func.return %0 : tensor<1x8x8x16xf32>
(tensor<3x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<3x8x8x16xf32>
func.return %0 : tensor<3x8x8x16xf32>
}

// -----
Expand Down

0 comments on commit 69f3233

Please sign in to comment.