Skip to content

Commit cc4fbd8

Browse files
GleasonK authored and Google-ML-Automation committed
[MHLO] Preserve discardable attrs when canonicalizing while op
PiperOrigin-RevId: 737640159
1 parent 02f01e6 commit cc4fbd8

File tree

3 files changed

+83
-1
lines changed

3 files changed

+83
-1
lines changed

third_party/stablehlo/temporary.patch

+57
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,50 @@ diff --ruN a/stablehlo/stablehlo/integrations/python/tests/chlo.py b/stablehlo/s
350350
+ assert attr.rhs_contracting_dimensions == [2]
351351
+ assert attr.lhs_ragged_dimensions == [1]
352352
+ assert attr.rhs_group_dimensions == [0]
353+
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
354+
--- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
355+
+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
356+
@@ -1751,7 +1751,7 @@
357+
// -----
358+
359+
////////
360+
-// WhileOp DCE
361+
+// WhileOp
362+
363+
// CHECK-LABEL: while_op_with_outfeed_no_dce
364+
func.func @while_op_with_outfeed_no_dce(%arg0: tensor<i64>) -> tensor<i64> {
365+
@@ -1780,6 +1780,31 @@
366+
stablehlo.return %iterArg : tensor<i64>
367+
}
368+
return %arg0 : tensor<i64>
369+
+}
370+
+
371+
+// Constant capture
372+
+// CHECK-LABEL: while_op_constant_capture
373+
+func.func @while_op_constant_capture(%arg0: tensor<10xf32>) -> (tensor<10xf32>) {
374+
+ %c = stablehlo.constant dense<1> : tensor<i32>
375+
+ %c_0 = stablehlo.constant dense<10> : tensor<i32>
376+
+ %c_1 = stablehlo.constant dense<0> : tensor<i32>
377+
+ %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
378+
+ %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<10xf32>
379+
+ // CHECK: stablehlo.while(%iterArg = %c_1, %iterArg_2 = %0) : tensor<i32>, tensor<10xf32> attributes {mhlo.frontend_attributes = {test_attr = "true"}}
380+
+ %1:3 = stablehlo.while(%iterArg = %arg0, %iterArg_2 = %c_1, %iterArg_3 = %0) : tensor<10xf32>, tensor<i32>, tensor<10xf32> attributes {mhlo.frontend_attributes = {test_attr = "true"}}
381+
+ cond {
382+
+ %2 = stablehlo.compare LT, %iterArg_2, %c_0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
383+
+ stablehlo.return %2 : tensor<i1>
384+
+ } do {
385+
+ %2 = stablehlo.dynamic_slice %iterArg, %iterArg_2, sizes = [1] : (tensor<10xf32>, tensor<i32>) -> tensor<1xf32>
386+
+ %3 = stablehlo.reshape %2 : (tensor<1xf32>) -> tensor<f32>
387+
+ %4 = stablehlo.sine %3 : tensor<f32>
388+
+ %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<1xf32>
389+
+ %6 = stablehlo.dynamic_update_slice %iterArg_3, %5, %iterArg_2 : (tensor<10xf32>, tensor<1xf32>, tensor<i32>) -> tensor<10xf32>
390+
+ %7 = stablehlo.add %iterArg_2, %c : tensor<i32>
391+
+ stablehlo.return %iterArg, %7, %6 : tensor<10xf32>, tensor<i32>, tensor<10xf32>
392+
+ }
393+
+ return %1#2 : tensor<10xf32>
394+
}
395+
396+
// -----
353397
diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp
354398
--- stablehlo/stablehlo/transforms/VhloToVersion.cpp
355399
+++ stablehlo/stablehlo/transforms/VhloToVersion.cpp
@@ -446,4 +490,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
446490
auto elementType = resultType.getElementType();
447491

448492
if (!elementType.isInteger())
493+
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
494+
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
495+
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
496+
@@ -1369,7 +1369,8 @@
497+
bodyReturnOp->eraseOperand(idx);
498+
499+
WhileOp newWhileOp = rewriter.create<WhileOp>(
500+
- whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands);
501+
+ whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands,
502+
+ whileOp->getAttrs());
503+
newWhileOp.getBodyRegion(0).takeBody(whileOp.getBodyRegion(0));
504+
newWhileOp.getBodyRegion(1).takeBody(whileOp.getBodyRegion(1));
505+
for (auto results : llvm::zip(resultsToReplace, newWhileOp->getResults()))
449506

xla/mlir_hlo/mhlo/IR/hlo_ops.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -6428,7 +6428,8 @@ static LogicalResult whileCanonicalization(WhileOp whileOp,
64286428
bodyReturnOp->eraseOperand(idx);
64296429

64306430
WhileOp newWhileOp = rewriter.create<WhileOp>(
6431-
whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands);
6431+
whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands,
6432+
whileOp->getAttrs());
64326433
newWhileOp.getBodyRegion(0).takeBody(whileOp.getBodyRegion(0));
64336434
newWhileOp.getBodyRegion(1).takeBody(whileOp.getBodyRegion(1));
64346435
for (auto results : llvm::zip(resultsToReplace, newWhileOp->getResults()))

xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir

+24
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,30 @@ func.func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
929929
func.return %arg0 : tensor<i64>
930930
}
931931

932+
// CHECK-LABEL: while_op_dce_no_side_effect
933+
func.func @while_op_dce_no_side_effect(%arg0: tensor<10xf32>) -> tensor<10xf32> {
934+
%0 = mhlo.constant dense<1> : tensor<i32>
935+
%1 = mhlo.constant dense<10> : tensor<i32>
936+
%2 = mhlo.constant dense<0> : tensor<i32>
937+
%3 = mhlo.constant dense<0.000000e+00> : tensor<f32>
938+
%4 = "mhlo.broadcast_in_dim"(%3) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>) -> tensor<10xf32>
939+
// CHECK: mhlo.while(%iterArg = %2, %iterArg_0 = %3) : tensor<i32>, tensor<10xf32> attributes {mhlo.frontend_attributes = {test_attr = "true"}}
940+
%5:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %2, %iterArg_1 = %4) : tensor<10xf32>, tensor<i32>, tensor<10xf32> attributes {mhlo.frontend_attributes = {test_attr = "true"}}
941+
cond {
942+
%6 = mhlo.compare LT, %iterArg_0, %1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
943+
mhlo.return %6 : tensor<i1>
944+
} do {
945+
%6 = "mhlo.dynamic_slice"(%iterArg, %iterArg_0) <{slice_sizes = dense<1> : tensor<1xi64>}> : (tensor<10xf32>, tensor<i32>) -> tensor<1xf32>
946+
%7 = mhlo.reshape %6 : (tensor<1xf32>) -> tensor<f32>
947+
%8 = mhlo.sine %7 : tensor<f32>
948+
%9 = "mhlo.broadcast_in_dim"(%8) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>) -> tensor<1xf32>
949+
%10 = mhlo.dynamic_update_slice %iterArg_1, %9, %iterArg_0 : (tensor<10xf32>, tensor<1xf32>, tensor<i32>) -> tensor<10xf32>
950+
%11 = mhlo.add %iterArg_0, %0 : tensor<i32>
951+
mhlo.return %iterArg, %11, %10 : tensor<10xf32>, tensor<i32>, tensor<10xf32>
952+
}
953+
return %5#2 : tensor<10xf32>
954+
}
955+
932956
////////
933957
// Tensor/Shape canonicalize
934958

0 commit comments

Comments (0)