#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+ #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
@@ -246,8 +247,7 @@ getReassociationForExpansion(AffineMap indexingMap,
  return reassociation;
}

- template <typename OpTy>
- static bool isFusableWithReshapeByDimExpansion(OpTy op,
+ static bool isFusableWithReshapeByDimExpansion(AttentionOp op,
                                                OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
@@ -256,10 +256,11 @@ static bool isFusableWithReshapeByDimExpansion(OpTy op,
  // - All the loops for the reshaped operand are parallel loops.
  SmallVector<utils::IteratorType> iteratorTypes = op.getLoopIteratorTypes();
  AffineMap operandMap = op.getMatchingIndexingMap(fusableOpOperand);
-   return op.hasPureTensorSemantics() &&
-          llvm::all_of(
-              op.getIndexingMapsArray(),
-              [](AffineMap map) { return map.isProjectedPermutation(); }) &&
+   return operandMap && op.hasPureTensorSemantics() &&
+          llvm::all_of(op.getIndexingMapsArray(),
+                       [](AffineMap map) {
+                         return map && map.isProjectedPermutation();
+                       }) &&
         operandMap.getNumResults() > 0;
}

@@ -391,6 +392,197 @@ static std::optional<SmallVector<Value>> fuseAttentionWithReshapeByExpansion(
  return resultVals;
}

+ namespace {
+ class ScatterExpansionInfo {
+ public:
+   // Helper class similar to `ExpansionInfo`, but only for
+   // `LinalgExt::ScatterOp` due to its special semantics (i.e. not all dims
+   // map to the iteration space).
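+   //
+   // Purely illustrative sketch (the shapes below are hypothetical, not taken
+   // from the source): for a scatter with `updates: tensor<4x8xf32>`,
+   // `indices: tensor<4x1xi32>`, `original: tensor<10x8xf32>`, and an
+   // expansion of `updates` to `tensor<2x2x8xf32>` via reassociation
+   // [[0, 1], [2]], `compute` would produce:
+   //   updatesShapeMap  = [[2, 2], [8]]  -> updatesReassoc  = [[0, 1], [2]]
+   //   indicesShapeMap  = [[2, 2], [1]]  -> indicesReassoc  = [[0, 1], [2]]
+   //   originalShapeMap = [[10], [8]]    -> originalReassoc = [[0], [1]]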
+   LogicalResult compute(LinalgExt::ScatterOp scatterOp,
+                         OpOperand *fusableOpOperand,
+                         ArrayRef<ReassociationIndices> reassociationIndices,
+                         ArrayRef<int64_t> expandedShape,
+                         ArrayRef<int64_t> collapsedShape,
+                         PatternRewriter &rewriter);
+
+   SmallVector<ReassociationIndices> updatesReassoc;
+   SmallVector<SmallVector<int64_t>> updatesShapeMap;
+   SmallVector<ReassociationIndices> indicesReassoc;
+   SmallVector<SmallVector<int64_t>> indicesShapeMap;
+   SmallVector<ReassociationIndices> originalReassoc;
+   SmallVector<SmallVector<int64_t>> originalShapeMap;
+ };
+
+ } // namespace
+
+ // Use the innermost indices in `reassoc` to construct a shape map out of
+ // `shape`.
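+ //
+ // For example (hypothetical values), reassoc = [[0, 1], [2]] and
+ // shape = [2, 3, 4] yield the shape map [[2, 3], [4]].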
+ static SmallVector<SmallVector<int64_t>>
+ computeShapeMapFromReassoc(ArrayRef<ReassociationIndices> reassoc,
+                            ArrayRef<int64_t> shape) {
+   SmallVector<SmallVector<int64_t>> shapeMap;
+   for (auto &indices : reassoc) {
+     shapeMap.emplace_back(shape.slice(indices.front(), indices.size()));
+   }
+   return shapeMap;
+ }
+
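+ // Inverse of `computeShapeMapFromReassoc`: number the dims of `shapeMap`
+ // consecutively and group them per entry, e.g. (hypothetical values) the
+ // shape map [[2, 3], [4]] yields the reassociation [[0, 1], [2]].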
+ static SmallVector<ReassociationIndices>
+ computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
+   SmallVector<ReassociationIndices> reassoc;
+   int64_t dimCount = 0;
+   for (auto &shape : shapeMap) {
+     reassoc.emplace_back(
+         llvm::to_vector(llvm::seq<int64_t>(dimCount, dimCount + shape.size())));
+     dimCount += shape.size();
+   }
+   return reassoc;
+ }
+
+ LogicalResult ScatterExpansionInfo::compute(
+     LinalgExt::ScatterOp scatterOp, OpOperand *fusableOpOperand,
+     ArrayRef<ReassociationIndices> reassociationIndices,
+     ArrayRef<int64_t> expandedShape, ArrayRef<int64_t> collapsedShape,
+     PatternRewriter &rewriter) {
+   if (reassociationIndices.empty())
+     return failure();
+   assert(fusableOpOperand->getOwner() == scatterOp);
+
+   auto updatesShape = scatterOp.getUpdateType().getShape();
+   auto originalShape = scatterOp.getOriginalType().getShape();
+   auto rankOfContiguousSlice =
+       scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
+
+   // Helper lambda to convert a shape to an identity shape map.
+   auto shapeToIdentShapeMap = [](ArrayRef<int64_t> shape) {
+     return llvm::map_to_vector(
+         shape, [](int64_t dim) { return SmallVector<int64_t>{dim}; });
+   };
+
+   // Set `batchShapeMap` and `sliceShapeMap` based on the specific operand.
+   int64_t operandNum = fusableOpOperand->getOperandNumber();
+
+   // In the case of `original`, there is no change to the iteration space.
+   SmallVector<SmallVector<int64_t>> batchShapeMap =
+       operandNum == ScatterOp::kOriginalOpNum
+           ? shapeToIdentShapeMap(
+                 originalShape.take_front(scatterOp.getBatchRank()))
+           : computeShapeMapFromReassoc(
+                 reassociationIndices.take_front(scatterOp.getBatchRank()),
+                 expandedShape);
+   // In the case of `indices`, there is no change to the iteration space.
+   SmallVector<SmallVector<int64_t>> sliceShapeMap =
+       operandNum == ScatterOp::kIndicesOpNum
+           ? shapeToIdentShapeMap(originalShape.take_back(rankOfContiguousSlice))
+           : computeShapeMapFromReassoc(
+                 reassociationIndices.take_back(rankOfContiguousSlice),
+                 expandedShape);
+
+   // Early exit if the iteration space is unchanged.
+   if (llvm::all_of(batchShapeMap, [&](auto vec) { return vec.size() == 1; }) &&
+       llvm::all_of(sliceShapeMap, [&](auto vec) { return vec.size() == 1; })) {
+     return failure();
+   }
+
+   updatesShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
+       batchShapeMap,
+       shapeToIdentShapeMap(updatesShape.slice(scatterOp.getBatchRank(),
+                                               scatterOp.getUpdateSliceRank() -
+                                                   rankOfContiguousSlice)),
+       sliceShapeMap));
+   indicesShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
+       batchShapeMap, shapeToIdentShapeMap(scatterOp.getIndexDepth())));
+   originalShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
+       shapeToIdentShapeMap(originalShape.drop_back(rankOfContiguousSlice)),
+       sliceShapeMap));
+
+   updatesReassoc = computeReassocFromShapeMap(updatesShapeMap);
+   indicesReassoc = computeReassocFromShapeMap(indicesShapeMap);
+   originalReassoc = computeReassocFromShapeMap(originalShapeMap);
+   return success();
+ }
+
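+ // Fuse `scatterOp` with `reshapeOp` (a `tensor.expand_shape` or
+ // `tensor.collapse_shape` on `fusableOpOperand`) by expanding the scatter:
+ // the operands are reshaped to their expanded forms, an expanded scatter is
+ // created, and its result is collapsed back to the type of `original`.
+ //
+ // Rough sketch of the producer case (names, shapes, and the abbreviated
+ // scatter syntax below are illustrative only):
+ //
+ //   %c = tensor.collapse_shape %updates [[0, 1], [2]]
+ //   %r = iree_linalg_ext.scatter ins(%c, %indices) outs(%original)
+ //
+ // becomes, roughly,
+ //
+ //   %u = tensor.expand_shape %c [[0, 1], [2]]
+ //   %i = tensor.expand_shape %indices [...]
+ //   %s = iree_linalg_ext.scatter ins(%u, %i) outs(%original)
+ //   %r = tensor.collapse_shape %s [...]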
+ static std::optional<Value>
+ fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
+                                   OpOperand *fusableOpOperand,
+                                   PatternRewriter &rewriter) {
+   Location loc = scatterOp.getLoc();
+   // Check if reshape is expanding or collapsing.
+   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
+   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
+   bool isExpanding = (expandingReshapeOp != nullptr);
+   RankedTensorType expandedType = isExpanding
+                                       ? expandingReshapeOp.getResultType()
+                                       : collapsingReshapeOp.getSrcType();
+   RankedTensorType collapsedType = isExpanding
+                                        ? expandingReshapeOp.getSrcType()
+                                        : collapsingReshapeOp.getResultType();
+   ScatterExpansionInfo info;
+   if (failed(info.compute(
+           scatterOp, fusableOpOperand,
+           isExpanding ? expandingReshapeOp.getReassociationIndices()
+                       : collapsingReshapeOp.getReassociationIndices(),
+           expandedType.getShape(), collapsedType.getShape(), rewriter))) {
+     return std::nullopt;
+   }
+
+   // Build the expanded tensor type by flattening `shapeMap`, reusing the
+   // element type of `type`.
+   auto getType = [&](SmallVector<SmallVector<int64_t>> &shapeMap,
+                      ShapedType type) {
+     SmallVector<int64_t> flattenedArray;
+     for (auto &shape : shapeMap) {
+       flattenedArray.append(shape.begin(), shape.end());
+     }
+     return RankedTensorType::get(flattenedArray, type.getElementType());
+   };
+
+   OpBuilder::InsertionGuard g(rewriter);
+   rewriter.setInsertionPoint(scatterOp);
+
+   auto isIdentityReassoc = [](SmallVector<ReassociationIndices> &indices) {
+     for (auto &index : indices) {
+       if (index.size() != 1)
+         return false;
+     }
+     return true;
+   };
+
+   Value newUpdates = rewriter.create<tensor::ExpandShapeOp>(
+       loc, getType(info.updatesShapeMap, scatterOp.getUpdateType()),
+       scatterOp.getUpdates(), info.updatesReassoc);
+   Value newIndices =
+       isIdentityReassoc(info.indicesReassoc)
+           ? scatterOp.getIndices()
+           : rewriter.create<tensor::ExpandShapeOp>(
+                 loc, getType(info.indicesShapeMap, scatterOp.getIndicesType()),
+                 scatterOp.getIndices(), info.indicesReassoc);
+   Value newOriginal =
+       isIdentityReassoc(info.originalReassoc)
+           ? scatterOp.getOriginal()
+           : rewriter.create<tensor::ExpandShapeOp>(
+                 loc,
+                 getType(info.originalShapeMap, scatterOp.getOriginalType()),
+                 scatterOp.getOriginal(), info.originalReassoc);
+
+   auto newScatter = rewriter.create<ScatterOp>(
+       loc, newOriginal.getType(), ValueRange{newUpdates, newIndices},
+       ValueRange{newOriginal}, scatterOp.getDimensionMap(),
+       scatterOp.getUniqueIndices());
+   rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(),
+                               newScatter.getRegion().begin());
+
+   // Collapse back to the original shape.
+   auto newCollapse = rewriter.create<tensor::CollapseShapeOp>(
+       loc, scatterOp.getOriginalType(), newScatter.getResult(0),
+       info.originalReassoc);
+
+   return {newCollapse};
+ }
+
+ //===----------------------------------------------------------------------===//
+ // Fuse By Expansion Patterns
+ //===----------------------------------------------------------------------===//
+
namespace {

// Fold attention with its consumer expand_shape op.
@@ -552,6 +744,74 @@ struct FoldScatterNonIterationUnitDims final
  linalg::ControlDropUnitDims options;
};

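+ // Folds a `tensor.collapse_shape` that produces an operand of a
+ // `LinalgExt::ScatterOp` by expanding the scatter instead, subject to
+ // `controlFoldingReshapes`.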
+ struct FoldScatterWithProducerReshapeByExpansion final
+     : public OpRewritePattern<ScatterOp> {
+   FoldScatterWithProducerReshapeByExpansion(
+       MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
+       PatternBenefit benefit = 1)
+       : OpRewritePattern<ScatterOp>(context, benefit),
+         controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
+
+   LogicalResult matchAndRewrite(ScatterOp scatterOp,
+                                 PatternRewriter &rewriter) const override {
+     for (OpOperand &opOperand : scatterOp->getOpOperands()) {
+       tensor::CollapseShapeOp reshapeOp =
+           opOperand.get().getDefiningOp<tensor::CollapseShapeOp>();
+       if (!reshapeOp)
+         continue;
+       if (!controlFoldingReshapes(&opOperand))
+         continue;
+
+       std::optional<Value> replacementValue = fuseScatterWithReshapeByExpansion(
+           scatterOp, reshapeOp, &opOperand, rewriter);
+       if (!replacementValue)
+         return failure();
+       rewriter.replaceOp(scatterOp, *replacementValue);
+       return success();
+     }
+     return failure();
+   }
+
+   linalg::ControlFusionFn controlFoldingReshapes;
+ };
+
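+ // Folds a consumer `tensor.expand_shape` of a `LinalgExt::ScatterOp` result
+ // into the scatter by expanding the scatter instead, subject to
+ // `controlFoldingReshapes`.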
+ struct FoldScatterWithConsumerReshapeByExpansion final
+     : public OpRewritePattern<tensor::ExpandShapeOp> {
+   FoldScatterWithConsumerReshapeByExpansion(
+       MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
+       PatternBenefit benefit = 1)
+       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+         controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
+
+   LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                 PatternRewriter &rewriter) const override {
+     auto producerResult = dyn_cast<OpResult>(expandOp.getSrc());
+     if (!producerResult) {
+       return rewriter.notifyMatchFailure(expandOp,
+                                          "source not produced by an operation");
+     }
+
+     auto scatterOp = producerResult.getDefiningOp<LinalgExt::ScatterOp>();
+     if (!scatterOp) {
+       return failure();
+     }
+
+     if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+       return failure();
+     }
+
+     std::optional<Value> replacementValue = fuseScatterWithReshapeByExpansion(
+         scatterOp, expandOp, scatterOp.getTiedOpOperand(producerResult),
+         rewriter);
+     if (!replacementValue)
+       return failure();
+     rewriter.replaceOp(scatterOp, *replacementValue);
+     return success();
+   }
+
+   linalg::ControlFusionFn controlFoldingReshapes;
+ };
+
} // namespace

/// Return the `reassociation` indices to use to collapse the operand when the
@@ -772,6 +1032,10 @@ void populateFoldReshapeOpsByExpansionPatterns(
      patterns.getContext(), controlFoldingReshapes);
  patterns.add<FoldAttentionWithProducerReshapeByExpansion>(
      patterns.getContext(), controlFoldingReshapes);
+   patterns.add<FoldScatterWithProducerReshapeByExpansion>(
+       patterns.getContext(), controlFoldingReshapes);
+   patterns.add<FoldScatterWithConsumerReshapeByExpansion>(
+       patterns.getContext(), controlFoldingReshapes);
}

SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {