Skip to content

Commit 7cd1fad

Browse files
committed
[LinalgExt] Scatter fusion by expansion
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
1 parent 5daf9c8 commit 7cd1fad

File tree

2 files changed

+471
-6
lines changed

2 files changed

+471
-6
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp

+270-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
1515
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1616
#include "mlir/Dialect/Tensor/IR/Tensor.h"
17+
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1718
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1819
#include "mlir/IR/MLIRContext.h"
1920
#include "mlir/IR/PatternMatch.h"
@@ -246,8 +247,7 @@ getReassociationForExpansion(AffineMap indexingMap,
246247
return reassociation;
247248
}
248249

249-
template <typename OpTy>
250-
static bool isFusableWithReshapeByDimExpansion(OpTy op,
250+
static bool isFusableWithReshapeByDimExpansion(AttentionOp op,
251251
OpOperand *fusableOpOperand) {
252252
// Is fusable only if:
253253
// - All the indexing maps for operands and results are projected
@@ -256,10 +256,11 @@ static bool isFusableWithReshapeByDimExpansion(OpTy op,
256256
// - All the loops for the reshaped operand are parallel loops.
257257
SmallVector<utils::IteratorType> iteratorTypes = op.getLoopIteratorTypes();
258258
AffineMap operandMap = op.getMatchingIndexingMap(fusableOpOperand);
259-
return op.hasPureTensorSemantics() &&
260-
llvm::all_of(
261-
op.getIndexingMapsArray(),
262-
[](AffineMap map) { return map.isProjectedPermutation(); }) &&
259+
return operandMap && op.hasPureTensorSemantics() &&
260+
llvm::all_of(op.getIndexingMapsArray(),
261+
[](AffineMap map) {
262+
return map && map.isProjectedPermutation();
263+
}) &&
263264
operandMap.getNumResults() > 0;
264265
}
265266

@@ -391,6 +392,197 @@ static std::optional<SmallVector<Value>> fuseAttentionWithReshapeByExpansion(
391392
return resultVals;
392393
}
393394

395+
namespace {
// Helper class similar to `ExpansionInfo` but only for `LinalgExt::ScatterOp`
// due to its special semantics (i.e. not all dims map to the iteration space).
class ScatterExpansionInfo {
public:
  // Populates the reassociations and shape maps below for the expansion of
  // `fusableOpOperand` described by `reassociationIndices`/`expandedShape`.
  // Fails when the reshape leaves the scatter's iteration space unchanged.
  LogicalResult compute(LinalgExt::ScatterOp scatterOp,
                        OpOperand *fusableOpOperand,
                        ArrayRef<ReassociationIndices> reassociationIndices,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);

  // For each operand: the reassociation groups of the expansion, and a "shape
  // map" listing, per collapsed dim, the expanded sizes it splits into.
  SmallVector<ReassociationIndices> updatesReassoc;
  SmallVector<SmallVector<int64_t>> updatesShapeMap;
  SmallVector<ReassociationIndices> indicesReassoc;
  SmallVector<SmallVector<int64_t>> indicesShapeMap;
  SmallVector<ReassociationIndices> originalReassoc;
  SmallVector<SmallVector<int64_t>> originalShapeMap;
};

} // namespace
416+
417+
// Use the innermost indices in `reassoc` to construct a shape map out of
418+
// `shape`
419+
static SmallVector<SmallVector<int64_t>>
420+
computeShapeMapFromReassoc(ArrayRef<ReassociationIndices> reassoc,
421+
ArrayRef<int64_t> shape) {
422+
SmallVector<SmallVector<int64_t>> shapeMap;
423+
for (auto &indices : reassoc) {
424+
shapeMap.emplace_back(shape.slice(indices.front(), indices.size()));
425+
}
426+
return shapeMap;
427+
}
428+
429+
// Inverse of `computeShapeMapFromReassoc`: numbers the expanded dims of
// `shapeMap` consecutively and groups them per collapsed dimension.
static SmallVector<ReassociationIndices>
computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
  SmallVector<ReassociationIndices> reassoc;
  reassoc.reserve(shapeMap.size());
  int64_t nextDim = 0;
  for (const SmallVector<int64_t> &group : shapeMap) {
    ReassociationIndices &indices = reassoc.emplace_back();
    for (int64_t dim : llvm::seq<int64_t>(nextDim, nextDim + group.size()))
      indices.push_back(dim);
    nextDim += group.size();
  }
  return reassoc;
}
440+
441+
/// Computes the expanded shape maps and reassociations for all three scatter
/// operands, given that `fusableOpOperand` is reshaped according to
/// `reassociationIndices`/`expandedShape`. Fails when `reassociationIndices`
/// is empty or the reshape does not change the scatter's iteration space.
/// (`collapsedShape` and `rewriter` are currently unused but kept for
/// signature parity with `ExpansionInfo`.)
LogicalResult ScatterExpansionInfo::compute(
    LinalgExt::ScatterOp scatterOp, OpOperand *fusableOpOperand,
    ArrayRef<ReassociationIndices> reassociationIndices,
    ArrayRef<int64_t> expandedShape, ArrayRef<int64_t> collapsedShape,
    PatternRewriter &rewriter) {
  if (reassociationIndices.empty())
    return failure();
  assert(fusableOpOperand->getOwner() == scatterOp);

  auto updatesShape = scatterOp.getUpdateType().getShape();
  auto originalShape = scatterOp.getOriginalType().getShape();
  // Trailing dims of `original` that are copied contiguously, i.e. not
  // addressed through `indices`.
  auto rankOfContiguousSlice =
      scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();

  // Helper lambda to convert a shape to an identity shape map.
  auto shapeToIdentShapeMap = [](ArrayRef<int64_t> shape) {
    return llvm::map_to_vector(
        shape, [](int64_t dim) { return SmallVector<int64_t>{dim}; });
  };

  // Set `batchShapeMap` and `sliceShapeMap` based on the specific operand.
  int64_t operandNum = fusableOpOperand->getOperandNumber();

  // In the case of `original`, no change to the batch portion of the
  // iteration space. The batch dims live on `updates` (the leading dims of
  // `original` are the indexed dims), so the identity map must be built from
  // `updatesShape`, not `originalShape`.
  SmallVector<SmallVector<int64_t>> batchShapeMap =
      operandNum == ScatterOp::kOriginalOpNum
          ? shapeToIdentShapeMap(
                updatesShape.take_front(scatterOp.getBatchRank()))
          : computeShapeMapFromReassoc(
                reassociationIndices.take_front(scatterOp.getBatchRank()),
                expandedShape);
  // In the case of `indices`, no change to the contiguous-slice portion of
  // the iteration space.
  SmallVector<SmallVector<int64_t>> sliceShapeMap =
      operandNum == ScatterOp::kIndicesOpNum
          ? shapeToIdentShapeMap(originalShape.take_back(rankOfContiguousSlice))
          : computeShapeMapFromReassoc(
                reassociationIndices.take_back(rankOfContiguousSlice),
                expandedShape);

  // Early exit if iteration space is unchanged (every group is a single dim).
  if (llvm::all_of(batchShapeMap, [&](auto vec) { return vec.size() == 1; }) &&
      llvm::all_of(sliceShapeMap, [&](auto vec) { return vec.size() == 1; })) {
    return failure();
  }

  // updates = [batch..., untouched middle update-slice dims..., slice...].
  updatesShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
      batchShapeMap,
      shapeToIdentShapeMap(updatesShape.slice(scatterOp.getBatchRank(),
                                              scatterOp.getUpdateSliceRank() -
                                                  rankOfContiguousSlice)),
      sliceShapeMap));
  // indices = [batch..., index_depth].
  indicesShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
      batchShapeMap, shapeToIdentShapeMap(scatterOp.getIndexDepth())));
  // original = [indexed dims..., slice...].
  originalShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
      shapeToIdentShapeMap(originalShape.drop_back(rankOfContiguousSlice)),
      sliceShapeMap));

  updatesReassoc = computeReassocFromShapeMap(updatesShapeMap);
  indicesReassoc = computeReassocFromShapeMap(indicesShapeMap);
  originalReassoc = computeReassocFromShapeMap(originalShapeMap);
  return success();
}
503+
504+
static std::optional<Value>
505+
fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
506+
OpOperand *fusableOpOperand,
507+
PatternRewriter &rewriter) {
508+
Location loc = scatterOp.getLoc();
509+
// Check if reshape is expanding or collapsing.
510+
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
511+
auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
512+
bool isExpanding = (expandingReshapeOp != nullptr);
513+
RankedTensorType expandedType = isExpanding
514+
? expandingReshapeOp.getResultType()
515+
: collapsingReshapeOp.getSrcType();
516+
RankedTensorType collapsedType = isExpanding
517+
? expandingReshapeOp.getSrcType()
518+
: collapsingReshapeOp.getResultType();
519+
ScatterExpansionInfo info;
520+
if (failed(info.compute(
521+
scatterOp, fusableOpOperand,
522+
isExpanding ? expandingReshapeOp.getReassociationIndices()
523+
: collapsingReshapeOp.getReassociationIndices(),
524+
expandedType.getShape(), collapsedType.getShape(), rewriter))) {
525+
return std::nullopt;
526+
}
527+
528+
// Returns `reassociation` with indices modified so that they are a contiguous
529+
// grouping of indices.
530+
auto getType = [&](SmallVector<SmallVector<int64_t>> &shapeMap,
531+
ShapedType type) {
532+
SmallVector<int64_t> flattenedArray;
533+
for (auto &shape : shapeMap) {
534+
flattenedArray.append(shape.begin(), shape.end());
535+
}
536+
return RankedTensorType::get(flattenedArray, type.getElementType());
537+
};
538+
539+
OpBuilder::InsertionGuard g(rewriter);
540+
rewriter.setInsertionPoint(scatterOp);
541+
542+
auto isIdentityReassoc = [](SmallVector<ReassociationIndices> &indices) {
543+
for (auto &index : indices) {
544+
if (index.size() != 1)
545+
return false;
546+
}
547+
return true;
548+
};
549+
550+
Value newUpdates = rewriter.create<tensor::ExpandShapeOp>(
551+
loc, getType(info.updatesShapeMap, scatterOp.getUpdateType()),
552+
scatterOp.getUpdates(), info.updatesReassoc);
553+
Value newIndices =
554+
isIdentityReassoc(info.indicesReassoc)
555+
? scatterOp.getIndices()
556+
: rewriter.create<tensor::ExpandShapeOp>(
557+
loc, getType(info.indicesShapeMap, scatterOp.getIndicesType()),
558+
scatterOp.getIndices(), info.indicesReassoc);
559+
Value newOriginal =
560+
isIdentityReassoc(info.originalReassoc)
561+
? scatterOp.getOriginal()
562+
: rewriter.create<tensor::ExpandShapeOp>(
563+
loc,
564+
getType(info.originalShapeMap, scatterOp.getOriginalType()),
565+
scatterOp.getOriginal(), info.originalReassoc);
566+
567+
auto newScatter = rewriter.create<ScatterOp>(
568+
loc, newOriginal.getType(), ValueRange{newUpdates, newIndices},
569+
ValueRange{newOriginal}, scatterOp.getDimensionMap(),
570+
scatterOp.getUniqueIndices());
571+
rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(),
572+
newScatter.getRegion().begin());
573+
574+
// Collapse back to originanl shape.
575+
auto newCollapse = rewriter.create<tensor::CollapseShapeOp>(
576+
loc, scatterOp.getOriginalType(), newScatter.getResult(0),
577+
info.originalReassoc);
578+
579+
return {newCollapse};
580+
}
581+
582+
//===----------------------------------------------------------------------===//
583+
// Fuse By Expansion Patterns
584+
//===----------------------------------------------------------------------===//
585+
394586
namespace {
395587

396588
// Fold attention with its consumer expand_shape op.
@@ -552,6 +744,74 @@ struct FoldScatterNonIterationUnitDims final
552744
linalg::ControlDropUnitDims options;
553745
};
554746

747+
struct FoldScatterWithProducerReshapeByExpansion final
748+
: public OpRewritePattern<ScatterOp> {
749+
FoldScatterWithProducerReshapeByExpansion(
750+
MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
751+
PatternBenefit benefit = 1)
752+
: OpRewritePattern<ScatterOp>(context, benefit),
753+
controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
754+
755+
LogicalResult matchAndRewrite(ScatterOp scatterOp,
756+
PatternRewriter &rewriter) const override {
757+
for (OpOperand &opOperand : scatterOp->getOpOperands()) {
758+
tensor::CollapseShapeOp reshapeOp =
759+
opOperand.get().getDefiningOp<tensor::CollapseShapeOp>();
760+
if (!reshapeOp)
761+
continue;
762+
if (!controlFoldingReshapes(&opOperand))
763+
continue;
764+
765+
std::optional<Value> replacementValue = fuseScatterWithReshapeByExpansion(
766+
scatterOp, reshapeOp, &opOperand, rewriter);
767+
if (!replacementValue)
768+
return failure();
769+
rewriter.replaceOp(scatterOp, *replacementValue);
770+
return success();
771+
}
772+
return failure();
773+
}
774+
775+
linalg::ControlFusionFn controlFoldingReshapes;
776+
};
777+
778+
struct FoldScatterWithConsumerReshapeByExpansion final
779+
: public OpRewritePattern<tensor::ExpandShapeOp> {
780+
FoldScatterWithConsumerReshapeByExpansion(
781+
MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
782+
PatternBenefit benefit = 1)
783+
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
784+
controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
785+
786+
LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
787+
PatternRewriter &rewriter) const override {
788+
auto producerResult = dyn_cast<OpResult>(expandOp.getSrc());
789+
if (!producerResult) {
790+
return rewriter.notifyMatchFailure(expandOp,
791+
"source not produced by an operation");
792+
}
793+
794+
auto scatterOp = producerResult.getDefiningOp<LinalgExt::ScatterOp>();
795+
if (!scatterOp) {
796+
return failure();
797+
}
798+
799+
if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
800+
return failure();
801+
}
802+
803+
std::optional<Value> replacementValue = fuseScatterWithReshapeByExpansion(
804+
scatterOp, expandOp, scatterOp.getTiedOpOperand(producerResult),
805+
rewriter);
806+
if (!replacementValue)
807+
return failure();
808+
rewriter.replaceOp(scatterOp, *replacementValue);
809+
return success();
810+
}
811+
812+
linalg::ControlFusionFn controlFoldingReshapes;
813+
};
814+
555815
} // namespace
556816

557817
/// Return the `reassociation` indices to use to collapse the operand when the
@@ -772,6 +1032,10 @@ void populateFoldReshapeOpsByExpansionPatterns(
7721032
patterns.getContext(), controlFoldingReshapes);
7731033
patterns.add<FoldAttentionWithProducerReshapeByExpansion>(
7741034
patterns.getContext(), controlFoldingReshapes);
1035+
patterns.add<FoldScatterWithProducerReshapeByExpansion>(
1036+
patterns.getContext(), controlFoldingReshapes);
1037+
patterns.add<FoldScatterWithConsumerReshapeByExpansion>(
1038+
patterns.getContext(), controlFoldingReshapes);
7751039
}
7761040

7771041
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {

0 commit comments

Comments
 (0)