Skip to content

Commit d3821b5

Browse files
committed
Revert "[mlir][scf] Track replacements using a listener in TileAndFuse (llvm#120999)"
This reverts commit 6e3631d.
1 parent 6efaa90 commit d3821b5

File tree

1 file changed

+21
-59
lines changed

1 file changed

+21
-59
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

+21-59
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
#include "mlir/Interfaces/TilingInterface.h"
2929
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
3030
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31-
#include "llvm/ADT/ScopeExit.h"
3231
#include "llvm/ADT/TypeSwitch.h"
3332
#include "llvm/Support/Debug.h"
3433
#include <optional>
@@ -1468,47 +1467,6 @@ void SliceTrackingListener::notifyOperationReplaced(Operation *op,
14681467
ValueRange replacement) {
14691468
removeOp(op);
14701469
}
1471-
1472-
//===----------------------------------------------------------------------===//
1473-
// ReplacementListener
1474-
//===----------------------------------------------------------------------===//
1475-
1476-
/// Listener that tracks updates replacements for values which can be mutated.
1477-
/// This listener runs on top of the existing listener for the rewriter,
1478-
/// to make sure external users can still run listeners.
1479-
class ReplacementListener : public RewriterBase::ForwardingListener {
1480-
public:
1481-
ReplacementListener(DenseMap<Value, Value> &replacements,
1482-
OpBuilder::Listener *listener)
1483-
: ForwardingListener(listener), replacements(replacements) {}
1484-
1485-
void updateReplacementValues(ValueRange origValues,
1486-
ValueRange replaceValues) {
1487-
// This can probably be written better, but just iterates over the map
1488-
// and the new replacements for now.
1489-
for (auto &[key, val] : replacements) {
1490-
for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1491-
if (val == orig) {
1492-
val = replace;
1493-
}
1494-
}
1495-
}
1496-
}
1497-
1498-
void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1499-
ForwardingListener::notifyOperationReplaced(op, newOp);
1500-
updateReplacementValues(op->getResults(), newOp->getResults());
1501-
}
1502-
1503-
void notifyOperationReplaced(Operation *op, ValueRange values) override {
1504-
ForwardingListener::notifyOperationReplaced(op, values);
1505-
updateReplacementValues(op->getResults(), values);
1506-
}
1507-
1508-
private:
1509-
DenseMap<Value, Value> &replacements;
1510-
};
1511-
15121470
} // namespace
15131471

15141472
/// Implementation of tile consumer and fuse producer greedily.
@@ -1535,27 +1493,26 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15351493
for (auto *tiledOp : tilingResult->tiledOps)
15361494
tiledAndFusedOps.insert(tiledOp);
15371495

1538-
DenseMap<Value, Value> replacements;
1539-
for (auto [origVal, replacement] : llvm::zip_equal(
1540-
consumer->getResults(), tilingResult->mergeResult.replacements)) {
1541-
replacements[origVal] = replacement;
1542-
}
1543-
15441496
// If there are no loops generated, fusion is immaterial.
15451497
auto &loops = tilingResult->loops;
15461498
if (loops.empty()) {
1499+
DenseMap<Value, Value> replacements;
1500+
for (auto [origVal, replacement] : llvm::zip_equal(
1501+
consumer->getResults(), tilingResult->mergeResult.replacements)) {
1502+
replacements[origVal] = replacement;
1503+
}
15471504
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
15481505
replacements};
15491506
}
15501507

1551-
// Since the loop gets potentially replaced during fusion, we need to track
1552-
// the mutation of replacement values. To do this, we attach a listener to
1553-
// update the replacements as they happen.
1554-
OpBuilder::Listener *previousListener = rewriter.getListener();
1555-
auto resetListener =
1556-
llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1557-
ReplacementListener replaceListener(replacements, previousListener);
1558-
rewriter.setListener(&replaceListener);
1508+
// To keep track of replacements for now just record the map from the
1509+
// original untiled value to the result number of the for loop. Since the
1510+
// loop gets potentially replaced during fusion, keeping the value directly
1511+
// wont work.
1512+
DenseMap<Value, size_t> origValToResultNumber;
1513+
for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1514+
origValToResultNumber[result] = index;
1515+
}
15591516

15601517
// 2. Typically, the operands of the tiled operation are slices of the
15611518
// operands of the untiled operation. These are expressed in IR using
@@ -1624,9 +1581,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
16241581
worklistCandidates.append(newSlices.value());
16251582
for (auto [index, result] :
16261583
llvm::enumerate(fusableProducerOp->getResults())) {
1627-
replacements[result] = loops.front()->getResult(
1628-
loops.front()->getNumResults() -
1629-
fusableProducerOp->getNumResults() + index);
1584+
origValToResultNumber[result] = loops.front()->getNumResults() -
1585+
fusableProducerOp->getNumResults() +
1586+
index;
16301587
}
16311588
}
16321589
if (Operation *tiledAndFusedOp =
@@ -1640,6 +1597,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
16401597
}
16411598
}
16421599

1600+
DenseMap<Value, Value> replacements;
1601+
for (auto [origVal, resultNumber] : origValToResultNumber) {
1602+
replacements[origVal] = loops.front()->getResult(resultNumber);
1603+
}
1604+
16431605
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
16441606
replacements};
16451607
}

0 commit comments

Comments
 (0)