@@ -815,6 +815,22 @@ struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
815
815
}
816
816
};
817
817
818
+ // / Fold fill with transpose.
819
+ struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
820
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
821
+
822
+ LogicalResult matchAndRewrite (linalg::TransposeOp transposeOp,
823
+ PatternRewriter &rewriter) const override {
824
+ if (auto fillOp = transposeOp.getInput ().getDefiningOp <FillOp>()) {
825
+ rewriter.replaceOpWithNewOp <FillOp>(
826
+ transposeOp, transposeOp.getResultTypes (), fillOp.getInputs (),
827
+ transposeOp.getDpsInitOperand (0 )->get ());
828
+ return success ();
829
+ }
830
+ return failure ();
831
+ }
832
+ };
833
+
818
834
} // namespace
819
835
820
836
void FillOp::getCanonicalizationPatterns (RewritePatternSet &results,
@@ -823,7 +839,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
823
839
.add <FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
824
840
FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
825
841
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
826
- FoldInsertPadIntoFill>(context);
842
+ FoldInsertPadIntoFill, FoldFillWithTranspose >(context);
827
843
}
828
844
829
845
// ===----------------------------------------------------------------------===//
0 commit comments