[mlir][Tensor] Use output_shape for DimOp->ExpandShapeOp folding #118203

Closed · 1 commit
Conversation

Groverkss (Member):

We already have the output shape available in the operation, so there is no need to do any arithmetic to figure it out. This PR makes the tensor.dim folding directly use the available output shape.
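For illustration, a minimal made-up example of the fold (the function and value names are invented, not taken from the patch): a tensor.dim query on a dynamic result dimension of the expansion is answered directly by the corresponding output_shape operand, with no affine arithmetic.

```mlir
// Before canonicalization: query a dynamic result dim of the expansion.
func.func @fold_dim_example(%t: tensor<?xf32>, %sz: index) -> index {
  %c0 = arith.constant 0 : index
  %e = tensor.expand_shape %t [[0, 1]] output_shape [%sz, 4]
      : tensor<?xf32> into tensor<?x4xf32>
  %d = tensor.dim %e, %c0 : tensor<?x4xf32>
  return %d : index
}

// After the fold, %d is replaced by the output_shape operand directly:
//   return %sz : index
```

Previously the same query produced a tensor.dim on the source plus an affine.apply computing the source size floordiv the product of the static dims in the reassociation group.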

llvmbot (Member) commented on Dec 1, 2024:

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Kunwar Grover (Groverkss)


Full diff: https://github.com/llvm/llvm-project/pull/118203.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+6-26)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+2-6)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 616d4a7d0a0ab5..a6ae728b20fa47 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1971,32 +1971,12 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
     if (!dim.has_value())
       return failure();
 
-    // Skip static dims. These are folded to constant ops.
-    RankedTensorType resultType = expandShapeOp.getResultType();
-    if (!resultType.isDynamicDim(*dim))
-      return failure();
-
-    // Find reassociation group that contains this result dimension.
-    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
-    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
-    // ExpandShapeOp would be ambiguous.)
-    int64_t product = 1;
-    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
-    for (int64_t d : grp) {
-      if (d != dim) {
-        assert(!resultType.isDynamicDim(d) && "expected static dim");
-        product *= resultType.getDimSize(d);
-      }
-    }
-
-    // result dim size = src dim size / (product(other dims in reassoc group))
-    Value srcDimSz =
-        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
-    AffineExpr expr;
-    bindSymbols(dimOp.getContext(), expr);
-    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
-        dimOp, expr.floorDiv(product), srcDimSz);
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+    OpFoldResult outputDim = outputShape[dim.value()];
+    rewriter.replaceOp(dimOp, getValueOrCreateConstantIndexOp(
+                                  rewriter, dimOp.getLoc(), outputDim));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..3a0f8e0e073acd 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2278,13 +2278,9 @@ func.func @empty_tensor_canonicalize(%i : index) {
 
 // -----
 
-//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
 // CHECK-LABEL: func @dim_of_expand_shape(
-//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
-//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
-//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
-//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
-//       CHECK:   return %[[apply]]
+//  CHECK-SAME:     %{{.*}}: tensor<?x?xf32>, %{{.*}}: index, %[[ARG2:.+]]: index
+//       CHECK:   return %[[ARG2]]
 func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
   %c2 = arith.constant 2 : index
   %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]

Review comment from a Contributor, on the new fold logic:

Can we just roll this into the folder for DimOp? The folder already handles the static case and has similar handling for extract_slice. This would have the added benefit of improving createOrFold<DimOp>, which is used quite frequently.
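A rough sketch of what that suggestion might look like (hypothetical code, not the patch under review; it reuses the accessors already visible in the diff and assumes the existing DimOp::fold entry point):

```cpp
// Hypothetical sketch only: fold tensor.dim of a tensor.expand_shape result
// inside DimOp::fold itself, so createOrFold<DimOp> picks it up without
// running the canonicalization pattern.
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // ... existing folding (constant index, static dims, extract_slice) ...

  std::optional<int64_t> index = getConstantIndex();
  auto expandOp = getSource().getDefiningOp<ExpandShapeOp>();
  if (index && expandOp) {
    ArrayRef<int64_t> staticShape = expandOp.getStaticOutputShape();
    // Static sizes are already folded to constants by the existing code
    // above; only the dynamic case is new.
    if (ShapedType::isDynamic(staticShape[*index])) {
      // output_shape holds only the dynamic sizes, in order. Count the
      // dynamic dims before `index` to find the matching operand.
      unsigned dynIdx = 0;
      for (int64_t i = 0; i < *index; ++i)
        dynIdx += ShapedType::isDynamic(staticShape[i]);
      return expandOp.getOutputShape()[dynIdx];
    }
  }
  return {};
}
```

Folding, unlike the pattern, cannot create new constant ops, which is why the static case is left to the existing handling of the result type's static dims.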

MaheshRavishankar (Contributor) commented on Dec 1, 2024:

I had this change already upstream (#113501) but didn't land it because it caused some strange errors in IREE (iree-org/iree#18907).

MaheshRavishankar (Contributor) left a review:

I have a few follow-ups. Please wait for some comments.

Groverkss (Member, Author):

Landed as #113501

Groverkss closed this on Feb 9, 2025.