Skip to content

Commit 6e3631d

Browse files
authored
[mlir][scf] Track replacements using a listener in TileAndFuse (llvm#120999)
This PR makes TileAndFuse explicitly track replacements using a listener instead of assuming that the results always come from the outer most tiling loop. scf::tileUsingInterface can introduce merge operations whose results are the actual replacements to use, instead of the outer most loop results.
1 parent 852feea commit 6e3631d

File tree

1 file changed

+59
-21
lines changed

1 file changed

+59
-21
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

+59-21
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/Interfaces/TilingInterface.h"
2929
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
3030
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31+
#include "llvm/ADT/ScopeExit.h"
3132
#include "llvm/ADT/TypeSwitch.h"
3233
#include "llvm/Support/Debug.h"
3334
#include <optional>
@@ -1467,6 +1468,47 @@ void SliceTrackingListener::notifyOperationReplaced(Operation *op,
14671468
ValueRange replacement) {
14681469
removeOp(op);
14691470
}
1471+
1472+
//===----------------------------------------------------------------------===//
1473+
// ReplacementListener
1474+
//===----------------------------------------------------------------------===//
1475+
1476+
/// Listener that tracks updates replacements for values which can be mutated.
1477+
/// This listener runs on top of the existing listener for the rewriter,
1478+
/// to make sure external users can still run listeners.
1479+
class ReplacementListener : public RewriterBase::ForwardingListener {
1480+
public:
1481+
ReplacementListener(DenseMap<Value, Value> &replacements,
1482+
OpBuilder::Listener *listener)
1483+
: ForwardingListener(listener), replacements(replacements) {}
1484+
1485+
void updateReplacementValues(ValueRange origValues,
1486+
ValueRange replaceValues) {
1487+
// This can probably be written better, but just iterates over the map
1488+
// and the new replacements for now.
1489+
for (auto &[key, val] : replacements) {
1490+
for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1491+
if (val == orig) {
1492+
val = replace;
1493+
}
1494+
}
1495+
}
1496+
}
1497+
1498+
void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1499+
ForwardingListener::notifyOperationReplaced(op, newOp);
1500+
updateReplacementValues(op->getResults(), newOp->getResults());
1501+
}
1502+
1503+
void notifyOperationReplaced(Operation *op, ValueRange values) override {
1504+
ForwardingListener::notifyOperationReplaced(op, values);
1505+
updateReplacementValues(op->getResults(), values);
1506+
}
1507+
1508+
private:
1509+
DenseMap<Value, Value> &replacements;
1510+
};
1511+
14701512
} // namespace
14711513

14721514
/// Implementation of tile consumer and fuse producer greedily.
@@ -1493,26 +1535,27 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14931535
for (auto *tiledOp : tilingResult->tiledOps)
14941536
tiledAndFusedOps.insert(tiledOp);
14951537

1538+
DenseMap<Value, Value> replacements;
1539+
for (auto [origVal, replacement] : llvm::zip_equal(
1540+
consumer->getResults(), tilingResult->mergeResult.replacements)) {
1541+
replacements[origVal] = replacement;
1542+
}
1543+
14961544
// If there are no loops generated, fusion is immaterial.
14971545
auto &loops = tilingResult->loops;
14981546
if (loops.empty()) {
1499-
DenseMap<Value, Value> replacements;
1500-
for (auto [origVal, replacement] : llvm::zip_equal(
1501-
consumer->getResults(), tilingResult->mergeResult.replacements)) {
1502-
replacements[origVal] = replacement;
1503-
}
15041547
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
15051548
replacements};
15061549
}
15071550

1508-
// To keep track of replacements for now just record the map from the
1509-
// original untiled value to the result number of the for loop. Since the
1510-
// loop gets potentially replaced during fusion, keeping the value directly
1511-
// wont work.
1512-
DenseMap<Value, size_t> origValToResultNumber;
1513-
for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1514-
origValToResultNumber[result] = index;
1515-
}
1551+
// Since the loop gets potentially replaced during fusion, we need to track
1552+
// the mutation of replacement values. To do this, we attach a listener to
1553+
// update the replacements as they happen.
1554+
OpBuilder::Listener *previousListener = rewriter.getListener();
1555+
auto resetListener =
1556+
llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1557+
ReplacementListener replaceListener(replacements, previousListener);
1558+
rewriter.setListener(&replaceListener);
15161559

15171560
// 2. Typically, the operands of the tiled operation are slices of the
15181561
// operands of the untiled operation. These are expressed in IR using
@@ -1581,9 +1624,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15811624
worklistCandidates.append(newSlices.value());
15821625
for (auto [index, result] :
15831626
llvm::enumerate(fusableProducerOp->getResults())) {
1584-
origValToResultNumber[result] = loops.front()->getNumResults() -
1585-
fusableProducerOp->getNumResults() +
1586-
index;
1627+
replacements[result] = loops.front()->getResult(
1628+
loops.front()->getNumResults() -
1629+
fusableProducerOp->getNumResults() + index);
15871630
}
15881631
}
15891632
if (Operation *tiledAndFusedOp =
@@ -1597,11 +1640,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15971640
}
15981641
}
15991642

1600-
DenseMap<Value, Value> replacements;
1601-
for (auto [origVal, resultNumber] : origValToResultNumber) {
1602-
replacements[origVal] = loops.front()->getResult(resultNumber);
1603-
}
1604-
16051643
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
16061644
replacements};
16071645
}

0 commit comments

Comments
 (0)