[mlir][Tensor] Use output_shape for DimOp->ExpandShapeOp folding #118203
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor

Author: Kunwar Grover (Groverkss)

Changes

We already have the output shape available in the operation, so there is no need to do any arithmetic to figure it out. This PR makes the tensor.dim folding directly use the available output shape.

Full diff: https://github.com/llvm/llvm-project/pull/118203.diff

2 Files Affected:
- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
- mlir/test/Dialect/Tensor/canonicalize.mlir
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 616d4a7d0a0ab5..a6ae728b20fa47 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1971,32 +1971,12 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
if (!dim.has_value())
return failure();
- // Skip static dims. These are folded to constant ops.
- RankedTensorType resultType = expandShapeOp.getResultType();
- if (!resultType.isDynamicDim(*dim))
- return failure();
-
- // Find reassociation group that contains this result dimension.
- int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
- // `dim` is the only dynamic dimension in `group`. (Otherwise, the
- // ExpandShapeOp would be ambiguous.)
- int64_t product = 1;
- ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
- for (int64_t d : grp) {
- if (d != dim) {
- assert(!resultType.isDynamicDim(d) && "expected static dim");
- product *= resultType.getDimSize(d);
- }
- }
-
- // result dim size = src dim size / (product(other dims in reassoc group))
- Value srcDimSz =
- rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
- AffineExpr expr;
- bindSymbols(dimOp.getContext(), expr);
- rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
- dimOp, expr.floorDiv(product), srcDimSz);
+ SmallVector<OpFoldResult> outputShape =
+ getMixedValues(expandShapeOp.getStaticOutputShape(),
+ expandShapeOp.getOutputShape(), rewriter);
+ OpFoldResult outputDim = outputShape[dim.value()];
+ rewriter.replaceOp(dimOp, getValueOrCreateConstantIndexOp(
+ rewriter, dimOp.getLoc(), outputDim));
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..3a0f8e0e073acd 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2278,13 +2278,9 @@ func.func @empty_tensor_canonicalize(%i : index) {
// -----
-// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
// CHECK-LABEL: func @dim_of_expand_shape(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
-// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
-// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
-// CHECK: return %[[apply]]
+// CHECK-SAME: %{{.*}}: tensor<?x?xf32>, %{{.*}}: index, %[[ARG2:.+]]: index
+// CHECK: return %[[ARG2]]
func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
%c2 = arith.constant 2 : index
%0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
    : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
%1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
return %1 : index
}
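In terms of canonicalized IR, this test now folds the queried dimension straight to the corresponding output_shape operand instead of recomputing it. A sketch of the old vs. new output for this test, reconstructed from the CHECK lines above (the 40 in the old map is 1 * 5 * 1 * 8, the product of the other, static sizes in the reassociation group):

// Before this PR: the size was recomputed from the collapsed source dim.
%c1 = arith.constant 1 : index
%dim = tensor.dim %t, %c1 : tensor<?x?xf32>
%r = affine.apply affine_map<()[s0] -> (s0 floordiv 40)>()[%dim]
return %r : index

// After this PR: the dim is just the output_shape operand for result dim 2.
return %sz1 : index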
Can we just roll this into the folder for DimOp? The folder already handles the static case and has similar handling for extract_slice. This would have the added benefit of improving createOrFold<DimOp>, which is used quite frequently.

I had this change already upstream (#113501) but didn't land it because it caused some strange errors in IREE (iree-org/iree#18907).
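For context on the folder suggestion: the DimOp folder already resolves a constant-index dim of a statically-sized dimension to a constant, e.g. (a minimal sketch, not part of this PR's diff):

// The existing tensor.dim folder turns this:
%c0 = arith.constant 0 : index
%d = tensor.dim %t, %c0 : tensor<4x?xf32>
// into a plain constant:
%d = arith.constant 4 : index

Handling the expand_shape case in the folder as well would let createOrFold<DimOp> resolve it at creation time, without waiting for a canonicalization pass.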
I have a few follow-ups. Please wait for some comments.

Landed as #113501.