28
28
#include " mlir/Interfaces/TilingInterface.h"
29
29
#include " mlir/Rewrite/FrozenRewritePatternSet.h"
30
30
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
31
+ #include " llvm/ADT/ScopeExit.h"
31
32
#include " llvm/ADT/TypeSwitch.h"
32
33
#include " llvm/Support/Debug.h"
33
34
#include < optional>
@@ -1467,6 +1468,47 @@ void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1467
1468
ValueRange replacement) {
1468
1469
removeOp (op);
1469
1470
}
1471
+
1472
+ // ===----------------------------------------------------------------------===//
1473
+ // ReplacementListener
1474
+ // ===----------------------------------------------------------------------===//
1475
+
1476
+ // / Listener that tracks updates replacements for values which can be mutated.
1477
+ // / This listener runs on top of the existing listener for the rewriter,
1478
+ // / to make sure external users can still run listeners.
1479
+ class ReplacementListener : public RewriterBase ::ForwardingListener {
1480
+ public:
1481
+ ReplacementListener (DenseMap<Value, Value> &replacements,
1482
+ OpBuilder::Listener *listener)
1483
+ : ForwardingListener(listener), replacements(replacements) {}
1484
+
1485
+ void updateReplacementValues (ValueRange origValues,
1486
+ ValueRange replaceValues) {
1487
+ // This can probably be written better, but just iterates over the map
1488
+ // and the new replacements for now.
1489
+ for (auto &[key, val] : replacements) {
1490
+ for (auto [orig, replace] : llvm::zip_equal (origValues, replaceValues)) {
1491
+ if (val == orig) {
1492
+ val = replace;
1493
+ }
1494
+ }
1495
+ }
1496
+ }
1497
+
1498
+ void notifyOperationReplaced (Operation *op, Operation *newOp) override {
1499
+ ForwardingListener::notifyOperationReplaced (op, newOp);
1500
+ updateReplacementValues (op->getResults (), newOp->getResults ());
1501
+ }
1502
+
1503
+ void notifyOperationReplaced (Operation *op, ValueRange values) override {
1504
+ ForwardingListener::notifyOperationReplaced (op, values);
1505
+ updateReplacementValues (op->getResults (), values);
1506
+ }
1507
+
1508
+ private:
1509
+ DenseMap<Value, Value> &replacements;
1510
+ };
1511
+
1470
1512
} // namespace
1471
1513
1472
1514
// / Implementation of tile consumer and fuse producer greedily.
@@ -1493,26 +1535,27 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1493
1535
for (auto *tiledOp : tilingResult->tiledOps )
1494
1536
tiledAndFusedOps.insert (tiledOp);
1495
1537
1538
+ DenseMap<Value, Value> replacements;
1539
+ for (auto [origVal, replacement] : llvm::zip_equal (
1540
+ consumer->getResults (), tilingResult->mergeResult .replacements )) {
1541
+ replacements[origVal] = replacement;
1542
+ }
1543
+
1496
1544
// If there are no loops generated, fusion is immaterial.
1497
1545
auto &loops = tilingResult->loops ;
1498
1546
if (loops.empty ()) {
1499
- DenseMap<Value, Value> replacements;
1500
- for (auto [origVal, replacement] : llvm::zip_equal (
1501
- consumer->getResults (), tilingResult->mergeResult .replacements )) {
1502
- replacements[origVal] = replacement;
1503
- }
1504
1547
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1505
1548
replacements};
1506
1549
}
1507
1550
1508
- // To keep track of replacements for now just record the map from the
1509
- // original untiled value to the result number of the for loop. Since the
1510
- // loop gets potentially replaced during fusion, keeping the value directly
1511
- // wont work.
1512
- DenseMap<Value, size_t > origValToResultNumber;
1513
- for ( auto [ index , result] : llvm::enumerate (consumer-> getResults ())) {
1514
- origValToResultNumber[result] = index ;
1515
- }
1551
+ // Since the loop gets potentially replaced during fusion, we need to track
1552
+ // the mutation of replacement values. To do this, we attach a listener to
1553
+ // update the replacements as they happen.
1554
+ OpBuilder::Listener *previousListener = rewriter. getListener ();
1555
+ auto resetListener =
1556
+ llvm::make_scope_exit ([&]() { rewriter. setListener (previousListener); });
1557
+ ReplacementListener replaceListener (replacements, previousListener) ;
1558
+ rewriter. setListener (&replaceListener);
1516
1559
1517
1560
// 2. Typically, the operands of the tiled operation are slices of the
1518
1561
// operands of the untiled operation. These are expressed in IR using
@@ -1581,9 +1624,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1581
1624
worklistCandidates.append (newSlices.value ());
1582
1625
for (auto [index , result] :
1583
1626
llvm::enumerate (fusableProducerOp->getResults ())) {
1584
- origValToResultNumber [result] = loops.front ()->getNumResults () -
1585
- fusableProducerOp ->getNumResults () +
1586
- index ;
1627
+ replacements [result] = loops.front ()->getResult (
1628
+ loops. front () ->getNumResults () -
1629
+ fusableProducerOp-> getNumResults () + index ) ;
1587
1630
}
1588
1631
}
1589
1632
if (Operation *tiledAndFusedOp =
@@ -1597,11 +1640,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1597
1640
}
1598
1641
}
1599
1642
1600
- DenseMap<Value, Value> replacements;
1601
- for (auto [origVal, resultNumber] : origValToResultNumber) {
1602
- replacements[origVal] = loops.front ()->getResult (resultNumber);
1603
- }
1604
-
1605
1643
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1606
1644
replacements};
1607
1645
}
0 commit comments