28
28
#include " mlir/Interfaces/TilingInterface.h"
29
29
#include " mlir/Rewrite/FrozenRewritePatternSet.h"
30
30
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
31
- #include " llvm/ADT/ScopeExit.h"
32
31
#include " llvm/ADT/TypeSwitch.h"
33
32
#include " llvm/Support/Debug.h"
34
33
#include < optional>
@@ -1468,47 +1467,6 @@ void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1468
1467
ValueRange replacement) {
1469
1468
removeOp (op);
1470
1469
}
1471
-
1472
- // ===----------------------------------------------------------------------===//
1473
- // ReplacementListener
1474
- // ===----------------------------------------------------------------------===//
1475
-
1476
- // / Listener that tracks updates replacements for values which can be mutated.
1477
- // / This listener runs on top of the existing listener for the rewriter,
1478
- // / to make sure external users can still run listeners.
1479
- class ReplacementListener : public RewriterBase ::ForwardingListener {
1480
- public:
1481
- ReplacementListener (DenseMap<Value, Value> &replacements,
1482
- OpBuilder::Listener *listener)
1483
- : ForwardingListener(listener), replacements(replacements) {}
1484
-
1485
- void updateReplacementValues (ValueRange origValues,
1486
- ValueRange replaceValues) {
1487
- // This can probably be written better, but just iterates over the map
1488
- // and the new replacements for now.
1489
- for (auto &[key, val] : replacements) {
1490
- for (auto [orig, replace] : llvm::zip_equal (origValues, replaceValues)) {
1491
- if (val == orig) {
1492
- val = replace;
1493
- }
1494
- }
1495
- }
1496
- }
1497
-
1498
- void notifyOperationReplaced (Operation *op, Operation *newOp) override {
1499
- ForwardingListener::notifyOperationReplaced (op, newOp);
1500
- updateReplacementValues (op->getResults (), newOp->getResults ());
1501
- }
1502
-
1503
- void notifyOperationReplaced (Operation *op, ValueRange values) override {
1504
- ForwardingListener::notifyOperationReplaced (op, values);
1505
- updateReplacementValues (op->getResults (), values);
1506
- }
1507
-
1508
- private:
1509
- DenseMap<Value, Value> &replacements;
1510
- };
1511
-
1512
1470
} // namespace
1513
1471
1514
1472
// / Implementation of tile consumer and fuse producer greedily.
@@ -1535,27 +1493,26 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1535
1493
for (auto *tiledOp : tilingResult->tiledOps )
1536
1494
tiledAndFusedOps.insert (tiledOp);
1537
1495
1538
- DenseMap<Value, Value> replacements;
1539
- for (auto [origVal, replacement] : llvm::zip_equal (
1540
- consumer->getResults (), tilingResult->mergeResult .replacements )) {
1541
- replacements[origVal] = replacement;
1542
- }
1543
-
1544
1496
// If there are no loops generated, fusion is immaterial.
1545
1497
auto &loops = tilingResult->loops ;
1546
1498
if (loops.empty ()) {
1499
+ DenseMap<Value, Value> replacements;
1500
+ for (auto [origVal, replacement] : llvm::zip_equal (
1501
+ consumer->getResults (), tilingResult->mergeResult .replacements )) {
1502
+ replacements[origVal] = replacement;
1503
+ }
1547
1504
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1548
1505
replacements};
1549
1506
}
1550
1507
1551
- // Since the loop gets potentially replaced during fusion, we need to track
1552
- // the mutation of replacement values. To do this, we attach a listener to
1553
- // update the replacements as they happen.
1554
- OpBuilder::Listener *previousListener = rewriter. getListener ();
1555
- auto resetListener =
1556
- llvm::make_scope_exit ([&]() { rewriter. setListener (previousListener); });
1557
- ReplacementListener replaceListener (replacements, previousListener) ;
1558
- rewriter. setListener (&replaceListener);
1508
+ // To keep track of replacements for now just record the map from the
1509
+ // original untiled value to the result number of the for loop. Since the
1510
+ // loop gets potentially replaced during fusion, keeping the value directly
1511
+ // wont work.
1512
+ DenseMap<Value, size_t > origValToResultNumber;
1513
+ for ( auto [ index , result] : llvm::enumerate (consumer-> getResults ())) {
1514
+ origValToResultNumber[result] = index ;
1515
+ }
1559
1516
1560
1517
// 2. Typically, the operands of the tiled operation are slices of the
1561
1518
// operands of the untiled operation. These are expressed in IR using
@@ -1624,9 +1581,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1624
1581
worklistCandidates.append (newSlices.value ());
1625
1582
for (auto [index , result] :
1626
1583
llvm::enumerate (fusableProducerOp->getResults ())) {
1627
- replacements [result] = loops.front ()->getResult (
1628
- loops. front () ->getNumResults () -
1629
- fusableProducerOp-> getNumResults () + index ) ;
1584
+ origValToResultNumber [result] = loops.front ()->getNumResults () -
1585
+ fusableProducerOp ->getNumResults () +
1586
+ index ;
1630
1587
}
1631
1588
}
1632
1589
if (Operation *tiledAndFusedOp =
@@ -1640,6 +1597,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1640
1597
}
1641
1598
}
1642
1599
1600
+ DenseMap<Value, Value> replacements;
1601
+ for (auto [origVal, resultNumber] : origValToResultNumber) {
1602
+ replacements[origVal] = loops.front ()->getResult (resultNumber);
1603
+ }
1604
+
1643
1605
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1644
1606
replacements};
1645
1607
}
0 commit comments