Skip to content

Commit 205dce6

Browse files
authored
[mlir][linalg] Add a folder for transpose(fill) -> fill (#83623)
This is similar to the existing folder for a linalg.copy. Transposing a filled tensor is the same as filling the destination of the transpose.
1 parent f505a92 commit 205dce6

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

+17-1
Original file line number | Diff line number | Diff line change
@@ -815,6 +815,22 @@ struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
815815
}
816816
};
817817

818+
/// Fold fill with transpose.
819+
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
820+
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
821+
822+
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
823+
PatternRewriter &rewriter) const override {
824+
if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
825+
rewriter.replaceOpWithNewOp<FillOp>(
826+
transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
827+
transposeOp.getDpsInitOperand(0)->get());
828+
return success();
829+
}
830+
return failure();
831+
}
832+
};
833+
818834
} // namespace
819835

820836
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -823,7 +839,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
823839
.add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
824840
FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
825841
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
826-
FoldInsertPadIntoFill>(context);
842+
FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
827843
}
828844

829845
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/canonicalize.mlir

+14
Original file line number | Diff line number | Diff line change
@@ -993,6 +993,20 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
993993

994994
// -----
995995

996+
// Test: a linalg.transpose whose input is the result of a linalg.fill should
// canonicalize to a linalg.fill of the same scalar into the transpose's
// destination (%arg1) rather than the original fill's destination (%arg0).
// CHECK-LABEL: func @canonicalize_fill_to_transpose_input(
997+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
998+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
999+
// CHECK: %[[ZERO:.+]] = arith.constant 0.0
1000+
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>)
1001+
func.func @canonicalize_fill_to_transpose_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
1002+
%c0 = arith.constant 0.0 : f32
1003+
%fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
1004+
%transpose = linalg.transpose ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
1005+
return %transpose : tensor<?x?xf32>
1006+
}
1007+
1008+
// -----
1009+
9961010
// CHECK-LABEL: func @broadcast_same_shape(
9971011
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
9981012
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)

0 commit comments

Comments
 (0)