@@ -44,7 +44,12 @@ constexpr const ::llvm::StringLiteral
44
44
45
45
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts (
46
46
RewriterBase &rewriter, const AnalysisState &state) {
47
+ OpBuilder::InsertionGuard g (rewriter);
47
48
Operation *op = getOperation ();
49
+ SmallVector<OpOperand *> outOfPlaceOpOperands;
50
+ SmallVector<OpResult> outOfPlaceOpResults;
51
+
52
+ // Find all out-of-place OpOperands.
48
53
for (OpOperand &opOperand : op->getOpOperands ()) {
49
54
Type operandType = opOperand.get ().getType ();
50
55
if (!operandType.isa <TensorType>())
@@ -53,17 +58,52 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
53
58
continue ;
54
59
if (operandType.isa <UnrankedTensorType>())
55
60
return op->emitError (" copies of unranked tensors are not supported" );
56
- auto tensorType = operandType.dyn_cast <RankedTensorType>();
57
- if (!tensorType)
58
- continue ;
61
+
59
62
SmallVector<OpResult> aliasingOpResults =
60
63
state.getAliasingOpResult (opOperand);
64
+ if (aliasingOpResults.size () == 1 &&
65
+ !state.bufferizesToMemoryWrite (opOperand) &&
66
+ state.getAliasingOpOperand (aliasingOpResults.front ()).size () == 1 ) {
67
+ // The op itself does not write but may create exactly one alias. Instead
68
+ // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
69
+ // be smaller than the OpOperand (e.g., in the case of an extract_slice,
70
+ // where the result is usually a smaller part of the source).
71
+ outOfPlaceOpResults.push_back (aliasingOpResults.front ());
72
+ } else {
73
+ // In all other cases, make a copy of the OpOperand.
74
+ outOfPlaceOpOperands.push_back (&opOperand);
75
+ }
76
+ }
77
+
78
+ // Insert copies of OpOperands.
79
+ rewriter.setInsertionPoint (op);
80
+ for (OpOperand *opOperand : outOfPlaceOpOperands) {
81
+ auto tensorType = opOperand->get ().getType ().cast <RankedTensorType>();
82
+ SmallVector<OpResult> aliasingOpResults =
83
+ state.getAliasingOpResult (*opOperand);
61
84
bool escape = llvm::any_of (
62
85
aliasingOpResults, [&](Value v) { return state.isTensorYielded (v); });
63
86
Value copy = rewriter.create <AllocTensorOp>(
64
- op->getLoc (), tensorType, ValueRange (), opOperand.get (), escape);
65
- rewriter.updateRootInPlace (op, [&]() { opOperand.set (copy); });
87
+ op->getLoc (), tensorType, ValueRange (), opOperand->get (), escape);
88
+ rewriter.updateRootInPlace (op, [&]() { opOperand->set (copy); });
89
+ }
90
+
91
+ // Insert copies of OpResults.
92
+ rewriter.setInsertionPointAfter (op);
93
+ for (OpResult opResult : outOfPlaceOpResults) {
94
+ auto tensorType = opResult.getType ().cast <RankedTensorType>();
95
+ bool escape = state.isTensorYielded (opResult);
96
+ Value copy = rewriter.create <AllocTensorOp>(op->getLoc (), tensorType,
97
+ ValueRange (), opResult, escape);
98
+ SmallVector<OpOperand *> uses = llvm::to_vector (llvm::map_range (
99
+ opResult.getUses (), [](OpOperand &use) { return &use; }));
100
+ for (OpOperand *use : uses) {
101
+ // Do not update the alloc_tensor op that we just created.
102
+ if (use->getOwner () != copy.getDefiningOp ())
103
+ rewriter.updateRootInPlace (use->getOwner (), [&]() { use->set (copy); });
104
+ }
66
105
}
106
+
67
107
return success ();
68
108
}
69
109
0 commit comments