Skip to content

Commit 87b4677

Browse files
[mlir][bufferize] Improve resolveConflicts for ExtractSliceOp
It is sometimes better to make a copy of the OpResult instead of making a copy of the OpOperand. E.g., when bufferizing tensor.extract_slice. This implementation will eventually make parts of extract_slice's `bufferize` implementation obsolete (and simplify it). It will only need to handle in-place OpOperands. Differential Revision: https://reviews.llvm.org/D126819
1 parent 72a049d commit 87b4677

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

+45-5
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,12 @@ constexpr const ::llvm::StringLiteral
4444

4545
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
4646
RewriterBase &rewriter, const AnalysisState &state) {
47+
OpBuilder::InsertionGuard g(rewriter);
4748
Operation *op = getOperation();
49+
SmallVector<OpOperand *> outOfPlaceOpOperands;
50+
SmallVector<OpResult> outOfPlaceOpResults;
51+
52+
// Find all out-of-place OpOperands.
4853
for (OpOperand &opOperand : op->getOpOperands()) {
4954
Type operandType = opOperand.get().getType();
5055
if (!operandType.isa<TensorType>())
@@ -53,17 +58,52 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
5358
continue;
5459
if (operandType.isa<UnrankedTensorType>())
5560
return op->emitError("copies of unranked tensors are not supported");
56-
auto tensorType = operandType.dyn_cast<RankedTensorType>();
57-
if (!tensorType)
58-
continue;
61+
5962
SmallVector<OpResult> aliasingOpResults =
6063
state.getAliasingOpResult(opOperand);
64+
if (aliasingOpResults.size() == 1 &&
65+
!state.bufferizesToMemoryWrite(opOperand) &&
66+
state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
67+
// The op itself does not write but may create exactly one alias. Instead
68+
// of copying the OpOperand, copy the OpResult. The OpResult can sometimes
69+
// be smaller than the OpOperand (e.g., in the case of an extract_slice,
70+
// where the result is usually a smaller part of the source).
71+
outOfPlaceOpResults.push_back(aliasingOpResults.front());
72+
} else {
73+
// In all other cases, make a copy of the OpOperand.
74+
outOfPlaceOpOperands.push_back(&opOperand);
75+
}
76+
}
77+
78+
// Insert copies of OpOperands.
79+
rewriter.setInsertionPoint(op);
80+
for (OpOperand *opOperand : outOfPlaceOpOperands) {
81+
auto tensorType = opOperand->get().getType().cast<RankedTensorType>();
82+
SmallVector<OpResult> aliasingOpResults =
83+
state.getAliasingOpResult(*opOperand);
6184
bool escape = llvm::any_of(
6285
aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
6386
Value copy = rewriter.create<AllocTensorOp>(
64-
op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
65-
rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); });
87+
op->getLoc(), tensorType, ValueRange(), opOperand->get(), escape);
88+
rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
89+
}
90+
91+
// Insert copies of OpResults.
92+
rewriter.setInsertionPointAfter(op);
93+
for (OpResult opResult : outOfPlaceOpResults) {
94+
auto tensorType = opResult.getType().cast<RankedTensorType>();
95+
bool escape = state.isTensorYielded(opResult);
96+
Value copy = rewriter.create<AllocTensorOp>(op->getLoc(), tensorType,
97+
ValueRange(), opResult, escape);
98+
SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
99+
opResult.getUses(), [](OpOperand &use) { return &use; }));
100+
for (OpOperand *use : uses) {
101+
// Do not update the alloc_tensor op that we just created.
102+
if (use->getOwner() != copy.getDefiningOp())
103+
rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
104+
}
66105
}
106+
67107
return success();
68108
}
69109

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: mlir-opt %s -tensor-copy-insertion -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -tensor-copy-insertion="bufferize-function-boundaries allow-return-allocs" -split-input-file | FileCheck %s --check-prefix=CHECK-FUNC
3+
4+
// CHECK-LABEL: func @extract_slice(
5+
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
6+
// CHECK-FUNC-LABEL: func @extract_slice(
7+
func.func @extract_slice(%t: tensor<?xf32>, %idx: index, %f: f32)
8+
-> (tensor<5xf32>, tensor<?xf32>)
9+
{
10+
// CHECK: %[[extract_slice:.*]] = tensor.extract_slice %[[t]][10] [5] [1]
11+
%0 = tensor.extract_slice %t[10][5][1] : tensor<?xf32> to tensor<5xf32>
12+
// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[extract_slice]]) {escape = false} : tensor<5xf32>
13+
// CHECK-FUNC: bufferization.alloc_tensor() copy(%{{.*}}) {escape = true} : tensor<5xf32>
14+
// CHECK: %[[insert:.*]] = tensor.insert %{{.*}} into %[[alloc]]
15+
%1 = tensor.insert %f into %0[%idx] : tensor<5xf32>
16+
// CHECK: return %[[insert]], %[[t]]
17+
return %1, %t : tensor<5xf32>, tensor<?xf32>
18+
}

0 commit comments

Comments
 (0)