@@ -350,6 +350,50 @@ diff --ruN a/stablehlo/stablehlo/integrations/python/tests/chlo.py b/stablehlo/s
350
350
+ assert attr.rhs_contracting_dimensions == [2]
351
351
+ assert attr.lhs_ragged_dimensions == [1]
352
352
+ assert attr.rhs_group_dimensions == [0]
353
+ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
354
+ --- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
355
+ +++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
356
+ @@ -1751,7 +1751,7 @@
357
+ // -----
358
+
359
+ ////////
360
+ - // WhileOp DCE
361
+ + // WhileOp
362
+
363
+ // CHECK-LABEL: while_op_with_outfeed_no_dce
364
+ func.func @while_op_with_outfeed_no_dce(%arg0: tensor<i64>) -> tensor<i64> {
365
+ @@ -1780,6 +1780,31 @@
366
+ stablehlo.return %iterArg : tensor<i64>
367
+ }
368
+ return %arg0 : tensor<i64>
369
+ + }
370
+ +
371
+ + // Constant capture
372
+ + // CHECK-LABEL: while_op_constant_capture
373
+ + func.func @while_op_constant_capture(%arg0: tensor<10xf32>) -> (tensor<10xf32>) {
374
+ + %c = stablehlo.constant dense<1> : tensor<i32>
375
+ + %c_0 = stablehlo.constant dense<10> : tensor<i32>
376
+ + %c_1 = stablehlo.constant dense<0> : tensor<i32>
377
+ + %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
378
+ + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<10xf32>
379
+ + // CHECK: stablehlo.while(%iterArg = %c_1, %iterArg_2 = %0) : tensor<i32>, tensor<10xf32> attributes {mhlo.frontend_attributes = {test_attr = "true"}}
380
+ + %1:3 = stablehlo.while(%iterArg = %arg0, %iterArg_2 = %c_1, %iterArg_3 = %0) : tensor<10xf32>, tensor<i32>, tensor<10xf32> attributes {mhlo.frontend_attributes = {test_attr = "true"}}
381
+ + cond {
382
+ + %2 = stablehlo.compare LT, %iterArg_2, %c_0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
383
+ + stablehlo.return %2 : tensor<i1>
384
+ + } do {
385
+ + %2 = stablehlo.dynamic_slice %iterArg, %iterArg_2, sizes = [1] : (tensor<10xf32>, tensor<i32>) -> tensor<1xf32>
386
+ + %3 = stablehlo.reshape %2 : (tensor<1xf32>) -> tensor<f32>
387
+ + %4 = stablehlo.sine %3 : tensor<f32>
388
+ + %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<1xf32>
389
+ + %6 = stablehlo.dynamic_update_slice %iterArg_3, %5, %iterArg_2 : (tensor<10xf32>, tensor<1xf32>, tensor<i32>) -> tensor<10xf32>
390
+ + %7 = stablehlo.add %iterArg_2, %c : tensor<i32>
391
+ + stablehlo.return %iterArg, %7, %6 : tensor<10xf32>, tensor<i32>, tensor<10xf32>
392
+ + }
393
+ + return %1#2 : tensor<10xf32>
394
+ }
395
+
396
+ // -----
353
397
diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp
354
398
--- stablehlo/stablehlo/transforms/VhloToVersion.cpp
355
399
+++ stablehlo/stablehlo/transforms/VhloToVersion.cpp
@@ -446,4 +490,17 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveFold
446
490
auto elementType = resultType.getElementType();
447
491
448
492
if (!elementType.isInteger())
493
+ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
494
+ --- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
495
+ +++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
496
+ @@ -1369,7 +1369,8 @@
497
+ bodyReturnOp->eraseOperand(idx);
498
+
499
+ WhileOp newWhileOp = rewriter.create<WhileOp>(
500
+ - whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands);
501
+ + whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands,
502
+ + whileOp->getAttrs());
503
+ newWhileOp.getBodyRegion(0).takeBody(whileOp.getBodyRegion(0));
504
+ newWhileOp.getBodyRegion(1).takeBody(whileOp.getBodyRegion(1));
505
+ for (auto results : llvm::zip(resultsToReplace, newWhileOp->getResults()))
449
506
0 commit comments