Skip to content

Commit 43644d3

Browse files
committed
Elide chained transfers at Flow level
Signed-off-by: Alex Vasile <48962821+Alex-Vasile@users.noreply.github.com>
1 parent c846333 commit 43644d3

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -1192,11 +1192,35 @@ struct ElideRedundantTransfer : public OpRewritePattern<TensorTransferOp> {
11921192
}
11931193
};
11941194

1195+
// Attempts to identify trivial case of chained transfer ops (A -> B -> C) and rewrite it as (A -> C)
1196+
// Writes it as A -> B and A -> C relying on dead code elimination to remove the unused A -> B transfer.
1197+
struct ElideChainedTransfer : public OpRewritePattern<TensorTransferOp> {
1198+
using OpRewritePattern::OpRewritePattern;
1199+
LogicalResult matchAndRewrite(TensorTransferOp currTransferOp,
1200+
PatternRewriter &rewriter) const override {
1201+
auto baseValue =
1202+
IREE::Util::TiedOpInterface::findTiedBaseValue(currTransferOp.getOperand());
1203+
if (auto prevTransferOp = dyn_cast_if_present<IREE::Flow::TensorTransferOp>(
1204+
baseValue.getDefiningOp())) {
1205+
rewriter.replaceOpWithNewOp<TensorTransferOp>(
1206+
currTransferOp,
1207+
currTransferOp->getResultTypes(),
1208+
prevTransferOp.getOperand(),
1209+
currTransferOp.getOperandDims(),
1210+
currTransferOp.getTarget());
1211+
return success();
1212+
}
1213+
return failure();
1214+
}
1215+
};
1216+
1217+
11951218
} // namespace
11961219

11971220
void TensorTransferOp::getCanonicalizationPatterns(RewritePatternSet &results,
11981221
MLIRContext *context) {
11991222
results.insert<ElideRedundantTransfer>(context);
1223+
results.insert<ElideChainedTransfer>(context);
12001224
}
12011225

12021226
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir

+30
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,36 @@ util.func public @ElideRedundantTransfer(%operand: tensor<4x?xf32>, %dim: index)
424424

425425
// -----
426426

427+
// Two directly chained transfers must fold into a single transfer from the
// original operand to the *final* target.
// CHECK-LABEL: @ElideChainedTransferTwoTransfers
// CHECK-SAME: (%[[OPERAND:.+]]: tensor<1xf16>)
util.func public @ElideChainedTransferTwoTransfers(%operand: tensor<1xf16>) -> tensor<1xf16> {
  // The intermediate hop to "target1" must not survive canonicalization.
  // CHECK-NOT: flow.tensor.transfer
  %redundant = flow.tensor.transfer %operand : tensor<1xf16> to "target1"
  // Pin the destination so a fold that keeps the wrong target is caught.
  // CHECK: %[[RESULT:.+]] = flow.tensor.transfer %[[OPERAND]] : tensor<1xf16> to "target2"
  %result = flow.tensor.transfer %redundant : tensor<1xf16> to "target2"
  // CHECK-NEXT: util.return %[[RESULT]]
  util.return %result : tensor<1xf16>
}
437+
438+
// -----
439+
440+
// A longer chain must fold all the way down to a single transfer from the
// original operand to the *last* target in the chain.
// CHECK-LABEL: @ElideChainedTransferFourTransfers
// CHECK-SAME: (%[[OPERAND:.+]]: tensor<1xf16>)
util.func public @ElideChainedTransferFourTransfers(%operand: tensor<1xf16>) -> tensor<1xf16> {
  // None of the intermediate hops ("target1".."target3") may survive; one
  // CHECK-NOT covers the whole region before the final transfer.
  // CHECK-NOT: flow.tensor.transfer
  %redundant = flow.tensor.transfer %operand : tensor<1xf16> to "target1"
  %redundant2 = flow.tensor.transfer %redundant : tensor<1xf16> to "target2"
  %redundant3 = flow.tensor.transfer %redundant2 : tensor<1xf16> to "target3"
  // Pin the destination so a fold that keeps an intermediate target is caught.
  // CHECK: %[[RESULT:.+]] = flow.tensor.transfer %[[OPERAND]] : tensor<1xf16> to "target4"
  %result = flow.tensor.transfer %redundant3 : tensor<1xf16> to "target4"
  // CHECK-NEXT: util.return %[[RESULT]]
  util.return %result : tensor<1xf16>
}
454+
455+
// -----
456+
427457
// CHECK-LABEL: @sliceConst0D
428458
util.func public @sliceConst0D() -> tensor<i32> {
429459
%0 = arith.constant dense<0> : tensor<i32>

0 commit comments

Comments
 (0)