diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index 901288bd4788..6ad28e5b6d4d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -14,14 +14,50 @@
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
+static bool
+isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
+  for (auto &index : indices) {
+    if (index.size() != 1)
+      return false;
+  }
+  return true;
+};
+
+static SmallVector<ReassociationIndices>
+computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
+  SmallVector<ReassociationIndices> reassoc;
+  int64_t dimCount = 0;
+  for (auto &shape : shapeMap) {
+    reassoc.emplace_back(
+        llvm::to_vector(llvm::seq<int64_t>(dimCount, dimCount + shape.size())));
+    dimCount += shape.size();
+  }
+  return reassoc;
+}
+
 namespace {
+
+/// Helper class that supports fusing reshapes with operands when not all of the
+/// shape dims map to the iteration space.
+struct ReshapeOperandInfo {
+  static constexpr int64_t kNoMapping = -1;
+
+  // Original shape of this operand.
+  ArrayRef<int64_t> originalShape;
+
+  // Similar to the results of the operand's `AffineMap` except `kNoMapping` if
+  // that dim doesn't map to the iteration space. For example, the indexed
+  // dimensions in a LinalgExt::ScatterOp.
+  SmallVector<int64_t> operandToIterationSpace;
+};
+
 /// Information needed to expand an operation to fold the reshape with
 /// it.
 class ExpansionInfo {
@@ -30,32 +66,78 @@ class ExpansionInfo {
   // of the expanded op given the `indexingMap` of the fused operand/result of
   // the op, the `reassocationMaps` of the reshape op and the shape of
   // the expanded op.
-  template <typename OpTy>
-  LogicalResult compute(OpTy op, OpOperand *fusableOpOperand,
-                        ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
-                        PatternRewriter &rewriter);
-  unsigned getOrigOpNumDims() const { return reassociation.size(); }
-  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
-  ReassociationIndicesRef getExpandedDims(unsigned i) const {
-    return reassociation[i];
-  }
-  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
-    return expandedShapeMap[i];
-  }
-  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+  LogicalResult compute(SmallVector<ReshapeOperandInfo> infos,
+                        SmallVector<int64_t> loopRanges,
+                        OpOperand *fusableOpOperand,
+                        ArrayRef<ReassociationIndices> operandReassoc,
+                        ArrayRef<int64_t> expandedShape);
+
+  std::optional<Value> getOrCreateExpanded(Location loc, OpOperand *operand,
+                                           RewriterBase &rewriter) {
+    auto shapeMap = this->getShapeMap(operand);
+    auto reassoc = computeReassocFromShapeMap(shapeMap);
+    if (isIdentityReassoc(reassoc)) {
+      return operand->get();
+    }
+    SmallVector<int64_t> flattenedArray;
+    for (auto &shape : shapeMap) {
+      flattenedArray.append(shape.begin(), shape.end());
+    }
+    auto oldType = cast<RankedTensorType>(operand->get().getType());
+    auto newType =
+        RankedTensorType::get(flattenedArray, oldType.getElementType());
+    if (failed(reshapeLikeShapesAreCompatible(
+            [&](const Twine &msg) {
+              return rewriter.notifyMatchFailure(loc, msg);
+            },
+            oldType.getShape(), newType.getShape(), reassoc,
+            /*isExpandingReshape=*/true))) {
+      return {};
+    }
+    return rewriter.create<tensor::ExpandShapeOp>(loc, newType, operand->get(),
+                                                  reassoc);
+  };
+
+  /// Get the shape map for the operand.
+  SmallVector<SmallVector<int64_t>> getShapeMap(OpOperand *operand) const {
+    auto info = reshapeInfos[operand->getOperandNumber()];
+    SmallVector<SmallVector<int64_t>> shapeMap;
+    for (auto [operandIdx, loopIdx] :
+         llvm::enumerate(info.operandToIterationSpace)) {
+      if (loopIdx == ReshapeOperandInfo::kNoMapping) {
+        shapeMap.push_back(
+            SmallVector<int64_t>{info.originalShape[operandIdx]});
+      } else {
+        shapeMap.push_back(loopShapeMap[loopIdx]);
+      }
+    }
+    return shapeMap;
+  }
+
+  SmallVector<ReassociationIndices> getReassoc(OpOperand *operand) const {
+    auto shapeMap = this->getShapeMap(operand);
+    return computeReassocFromShapeMap(shapeMap);
+  }
+
+  unsigned getOrigNumLoops() const { return loopReassoc.size(); }
+  unsigned getExpandedNumLoops() const { return expandedOpNumDims; }
+  ReassociationIndicesRef getExpandedLoops(unsigned i) const {
+    return loopReassoc[i];
+  }
+  ArrayRef<int64_t> getExpandedShapeOfLoop(unsigned i) const {
+    return loopShapeMap[i];
+  }
 
 private:
-  /// Reassociation from the dimensions in the original operation to the
-  /// dimension of the expanded operation.
-  SmallVector<ReassociationIndices> reassociation;
+  /// Extent of the iteration space in the original operation.
+  SmallVector<int64_t> loopRanges;
+  SmallVector<ReassociationIndices> loopReassoc;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> expandedShapeMap;
-  /// Extent of the loop in the original operation.
-  SmallVector<int64_t> originalLoopExtent;
+  SmallVector<SmallVector<int64_t>> loopShapeMap;
   unsigned expandedOpNumDims;
+
+  /// Info about the reassociation and original shape for each operand.
+ SmallVector reshapeInfos; }; class CollapsingInfo { @@ -109,50 +191,46 @@ class CollapsingInfo { } // namespace -template -LogicalResult ExpansionInfo::compute(OpTy op, OpOperand *fusableOpOperand, - ArrayRef reassociationMaps, - ArrayRef expandedShape, - ArrayRef collapsedShape, - PatternRewriter &rewriter) { - if (reassociationMaps.empty()) +LogicalResult ExpansionInfo::compute( + SmallVector infos, SmallVector loopRanges, + OpOperand *fusableOpOperand, ArrayRef operandReassoc, + ArrayRef expandedShape) { + if (operandReassoc.empty()) return failure(); - AffineMap fusedIndexMap = op.getMatchingIndexingMap(fusableOpOperand); - FailureOr> originalLoopRange = op.getStaticLoopRanges(); - if (failed(originalLoopRange)) { + + int64_t operandNum = fusableOpOperand->getOperandNumber(); + ReshapeOperandInfo &fusionOperandInfo = infos[operandNum]; + this->loopShapeMap.clear(); + this->loopShapeMap.resize(loopRanges.size()); + for (auto [operandIdx, loopIdx] : + llvm::enumerate(fusionOperandInfo.operandToIterationSpace)) { + if (loopIdx == ReshapeOperandInfo::kNoMapping) { + continue; + } + + // Compute the shape map at element `loopIdx` + ReassociationIndicesRef indices = operandReassoc[operandIdx]; + for (auto [dimIdx, shapeIdx] : llvm::enumerate(indices)) { + this->loopShapeMap[loopIdx].push_back(expandedShape[shapeIdx]); + } + } + + // Fill in the remaining elements with `loopRanges` + this->expandedOpNumDims = 0; + for (const auto &[loopIdx, shapeMap] : llvm::enumerate(this->loopShapeMap)) { + if (shapeMap.empty()) { + this->loopShapeMap[loopIdx] = SmallVector{loopRanges[loopIdx]}; + } + this->expandedOpNumDims += shapeMap.size(); + } + + if (llvm::all_of(this->loopShapeMap, + [&](auto vec) { return vec.size() == 1; })) { return failure(); } - originalLoopExtent.assign(originalLoopRange->begin(), - originalLoopRange->end()); - - reassociation.clear(); - expandedShapeMap.clear(); - // Compute the number of dimension in the expanded op that correspond to each - // 
dimension of the original op. - SmallVector numExpandedDims(fusedIndexMap.getNumDims(), 1); - expandedShapeMap.resize(fusedIndexMap.getNumDims()); - for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) { - unsigned pos = cast(resultExpr.value()).getPosition(); - AffineMap foldedDims = reassociationMaps[resultExpr.index()]; - numExpandedDims[pos] = foldedDims.getNumResults(); - ArrayRef shape = - expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]); - expandedShapeMap[pos].assign(shape.begin(), shape.end()); - } - // The remaining dimensions remain the same. - for (unsigned i : llvm::seq(0, fusedIndexMap.getNumDims())) - if (expandedShapeMap[i].empty()) - expandedShapeMap[i] = {originalLoopExtent[i]}; - - // Compute reassociation map from the original op to the expanded op. - unsigned sum = 0; - reassociation.reserve(fusedIndexMap.getNumDims()); - for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) { - auto seq = llvm::seq(sum, sum + numFoldedDim.value()); - reassociation.emplace_back(seq.begin(), seq.end()); - sum += numFoldedDim.value(); - } - expandedOpNumDims = sum; + this->loopReassoc = computeReassocFromShapeMap(this->loopShapeMap); + this->reshapeInfos = std::move(infos); + this->loopRanges = std::move(loopRanges); return success(); } @@ -201,6 +279,77 @@ CollapsingInfo::initialize(unsigned origNumLoops, return success(); } +static SmallVector +getAttentionReshapeInfo(LinalgExt::AttentionOp attentionOp) { + return llvm::map_to_vector( + attentionOp->getOpOperands(), [&](OpOperand &opOperand) { + ReshapeOperandInfo operandInfo; + auto operandType = dyn_cast(opOperand.get().getType()); + if (!operandType) { + assert( + attentionOp.getMatchingIndexingMap(&opOperand).getNumResults() == + 0 && + "expected non-shaped type to have no results in indexing map"); + return operandInfo; + } + + operandInfo.originalShape = operandType.getShape(); + for (auto result : + 
attentionOp.getMatchingIndexingMap(&opOperand).getResults()) { + operandInfo.operandToIterationSpace.push_back( + cast(result).getPosition()); + } + return operandInfo; + }); +} + +static SmallVector +getScatterReshapeInfo(LinalgExt::ScatterOp scatterOp) { + SmallVector infos; + auto rankOfContiguousSlice = + scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth(); + auto updateRank = scatterOp.getUpdateType().getRank(); + + // Operand #0 Updates + { + ReshapeOperandInfo updateInfo; + updateInfo.originalShape = scatterOp.getUpdateType().getShape(); + llvm::append_range(updateInfo.operandToIterationSpace, + llvm::seq(0, scatterOp.getBatchRank())); + updateInfo.operandToIterationSpace.append( + updateRank - (rankOfContiguousSlice + scatterOp.getBatchRank()), + ReshapeOperandInfo::kNoMapping); + llvm::append_range( + updateInfo.operandToIterationSpace, + llvm::seq(updateRank - rankOfContiguousSlice, updateRank)); + infos.push_back(std::move(updateInfo)); + } + + // Operand#1 Indices + { + ReshapeOperandInfo indicesInfo; + indicesInfo.originalShape = scatterOp.getIndicesType().getShape(); + llvm::append_range(indicesInfo.operandToIterationSpace, + llvm::seq(0, scatterOp.getBatchRank())); + indicesInfo.operandToIterationSpace.push_back( + ReshapeOperandInfo::kNoMapping); + infos.push_back(std::move(indicesInfo)); + } + + // Operand #2 Original + { + ReshapeOperandInfo originalInfo; + originalInfo.originalShape = scatterOp.getOriginalType().getShape(); + originalInfo.operandToIterationSpace.append(scatterOp.getIndexDepth(), + ReshapeOperandInfo::kNoMapping); + llvm::append_range( + originalInfo.operandToIterationSpace, + llvm::seq(updateRank - rankOfContiguousSlice, updateRank)); + infos.push_back(std::move(originalInfo)); + } + return infos; +}; + static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo) { @@ -208,46 +357,17 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, 
for (AffineExpr expr : indexingMap.getResults()) { unsigned pos = cast(expr).getPosition(); auto expandedExprs = llvm::to_vector_of( - llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { + llvm::map_range(expansionInfo.getExpandedLoops(pos), [&](int64_t v) { return builder.getAffineDimExpr(static_cast(v)); })); newExprs.append(expandedExprs.begin(), expandedExprs.end()); } - return AffineMap::get(expansionInfo.getExpandedOpNumDims(), + return AffineMap::get(expansionInfo.getExpandedNumLoops(), indexingMap.getNumSymbols(), newExprs, builder.getContext()); } -static RankedTensorType getExpandedType(RankedTensorType originalType, - AffineMap indexingMap, - const ExpansionInfo &expansionInfo) { - SmallVector expandedShape; - for (AffineExpr expr : indexingMap.getResults()) { - unsigned dim = cast(expr).getPosition(); - auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim); - expandedShape.append(dimExpansion.begin(), dimExpansion.end()); - } - return RankedTensorType::get(expandedShape, originalType.getElementType()); -} - -static SmallVector -getReassociationForExpansion(AffineMap indexingMap, - const ExpansionInfo &expansionInfo) { - SmallVector reassociation; - unsigned numReshapeDims = 0; - for (AffineExpr expr : indexingMap.getResults()) { - unsigned dim = cast(expr).getPosition(); - auto numExpandedDims = expansionInfo.getExpandedDims(dim).size(); - SmallVector indices = llvm::to_vector<2>( - llvm::seq(numReshapeDims, numReshapeDims + numExpandedDims)); - reassociation.emplace_back(std::move(indices)); - numReshapeDims += numExpandedDims; - } - return reassociation; -} - -template -static bool isFusableWithReshapeByDimExpansion(OpTy op, +static bool isFusableWithReshapeByDimExpansion(AttentionOp op, OpOperand *fusableOpOperand) { // Is fusable only if: // - All the indexing maps for operands and results are projected @@ -256,10 +376,11 @@ static bool isFusableWithReshapeByDimExpansion(OpTy op, // - All the loops for the reshaped operand 
are parallel loops. SmallVector iteratorTypes = op.getLoopIteratorTypes(); AffineMap operandMap = op.getMatchingIndexingMap(fusableOpOperand); - return op.hasPureTensorSemantics() && - llvm::all_of( - op.getIndexingMapsArray(), - [](AffineMap map) { return map.isProjectedPermutation(); }) && + return operandMap && op.hasPureTensorSemantics() && + llvm::all_of(op.getIndexingMapsArray(), + [](AffineMap map) { + return map && map.isProjectedPermutation(); + }) && operandMap.getNumResults() > 0; } @@ -277,16 +398,13 @@ static std::optional> fuseAttentionWithReshapeByExpansion( RankedTensorType expandedType = isExpanding ? expandingReshapeOp.getResultType() : collapsingReshapeOp.getSrcType(); - RankedTensorType collapsedType = isExpanding - ? expandingReshapeOp.getSrcType() - : collapsingReshapeOp.getResultType(); - ExpansionInfo expansionInfo; if (failed(expansionInfo.compute( - attentionOp, fusableOpOperand, - isExpanding ? expandingReshapeOp.getReassociationMaps() - : collapsingReshapeOp.getReassociationMaps(), - expandedType.getShape(), collapsedType.getShape(), rewriter))) + getAttentionReshapeInfo(attentionOp), + attentionOp.getStaticLoopRanges().value(), fusableOpOperand, + isExpanding ? expandingReshapeOp.getReassociationIndices() + : collapsingReshapeOp.getReassociationIndices(), + expandedType.getShape()))) return std::nullopt; auto expandedOpIndexingMaps = llvm::to_vector_of( llvm::map_range(attentionOp.getIndexingMapsArray(), [&](AffineMap m) { @@ -305,54 +423,23 @@ static std::optional> fuseAttentionWithReshapeByExpansion( : collapsingReshapeOp.getSrc()); continue; } - if (auto opOperandType = - dyn_cast(opOperand->get().getType())) { - AffineMap indexingMap = attentionOp.getMatchingIndexingMap(opOperand); - RankedTensorType expandedOperandType = - getExpandedType(opOperandType, indexingMap, expansionInfo); - if (expandedOperandType != opOperand->get().getType()) { - // Reshape the operand to get the right type. 
- SmallVector reassociation = - getReassociationForExpansion(indexingMap, expansionInfo); - if (failed(reshapeLikeShapesAreCompatible( - [&](const Twine &msg) { - return rewriter.notifyMatchFailure(attentionOp, msg); - }, - opOperandType.getShape(), expandedOperandType.getShape(), - reassociation, - /*isExpandingReshape=*/true))) - return std::nullopt; - expandedOpOperands.push_back(rewriter.create( - loc, expandedOperandType, opOperand->get(), reassociation)); - continue; - } - } - expandedOpOperands.push_back(opOperand->get()); + // Reshape the operand to get the right type. + std::optional expanded = + expansionInfo.getOrCreateExpanded(loc, opOperand, rewriter); + if (!expanded) + return std::nullopt; + expandedOpOperands.push_back(*expanded); + continue; } Value output; OpOperand &outOperand = attentionOp.getOutputMutable(); - AffineMap indexingMap = attentionOp.getMatchingIndexingMap(&outOperand); - auto opOperandType = cast(outOperand.get().getType()); - RankedTensorType expandedOutputType = - getExpandedType(opOperandType, indexingMap, expansionInfo); - if (expandedOutputType != outOperand.get().getType()) { - SmallVector reassociation = - getReassociationForExpansion(indexingMap, expansionInfo); - if (failed(reshapeLikeShapesAreCompatible( - [&](const Twine &msg) { - return rewriter.notifyMatchFailure(attentionOp, msg); - }, - opOperandType.getShape(), expandedOutputType.getShape(), - reassociation, - /*isExpandingReshape=*/true))) - return std::nullopt; - output = rewriter.create( - loc, expandedOutputType, outOperand.get(), reassociation); - } else { - output = outOperand.get(); - } + std::optional maybeOutput = + expansionInfo.getOrCreateExpanded(loc, &outOperand, rewriter); + if (!maybeOutput) + return std::nullopt; + output = *maybeOutput; Value maskOperand; if (expandedOpOperands.size() > 4) { @@ -377,9 +464,7 @@ static std::optional> fuseAttentionWithReshapeByExpansion( int64_t resultNumber = opResult.getResultNumber(); if (resultTypes[resultNumber] 
!= opResult.getType()) { SmallVector reassociation = - getReassociationForExpansion( - attentionOp.getIndexingMapsForResults()[resultNumber], - expansionInfo); + expansionInfo.getReassoc(attentionOp.getTiedOpOperand(opResult)); resultVals.push_back(rewriter.create( attentionOp.getLoc(), opResult.getType(), fusedOp->getResult(resultNumber), reassociation)); @@ -391,6 +476,64 @@ static std::optional> fuseAttentionWithReshapeByExpansion( return resultVals; } +static std::optional +fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp, + OpOperand *fusableOpOperand, + PatternRewriter &rewriter) { + Location loc = scatterOp.getLoc(); + // Check if reshape is expanding or collapsing. + auto expandingReshapeOp = dyn_cast(*reshapeOp); + auto collapsingReshapeOp = dyn_cast(*reshapeOp); + bool isExpanding = (expandingReshapeOp != nullptr); + RankedTensorType expandedType = isExpanding + ? expandingReshapeOp.getResultType() + : collapsingReshapeOp.getSrcType(); + ExpansionInfo info; + if (failed(info.compute( + getScatterReshapeInfo(scatterOp), + scatterOp.getStaticLoopRanges().value(), fusableOpOperand, + isExpanding ? 
expandingReshapeOp.getReassociationIndices() + : collapsingReshapeOp.getReassociationIndices(), + expandedType.getShape()))) { + return std::nullopt; + } + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(scatterOp); + + OpOperand *update = scatterOp.getDpsInputOperand(0); + OpOperand *indices = scatterOp.getDpsInputOperand(1); + OpOperand *original = scatterOp.getDpsInitOperand(0); + auto newUpdates = info.getOrCreateExpanded(loc, update, rewriter).value(); + auto newIndices = info.getOrCreateExpanded(loc, indices, rewriter).value(); + auto newOriginal = info.getOrCreateExpanded(loc, original, rewriter).value(); + + auto newScatter = rewriter.create( + loc, newOriginal.getType(), ValueRange{newUpdates, newIndices}, + ValueRange{newOriginal}, scatterOp.getDimensionMap(), + scatterOp.getUniqueIndices()); + rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(), + newScatter.getRegion().begin()); + + auto originalShapeMap = info.getShapeMap(original); + SmallVector originalReassoc = + computeReassocFromShapeMap(originalShapeMap); + + // Collapse back to original shape. + if (isIdentityReassoc(originalReassoc)) { + return {newScatter.getResult(0)}; + } + auto newCollapse = rewriter.create( + loc, scatterOp.getOriginalType(), newScatter.getResult(0), + originalReassoc); + + return {newCollapse}; +} + +//===----------------------------------------------------------------------===// +// Fuse By Expansion Patterns +//===----------------------------------------------------------------------===// + namespace { // Fold attention with its consumer expand_shape op. 
@@ -553,6 +696,74 @@ struct FoldScatterNonIterationUnitDims final linalg::ControlDropUnitDims options; }; +struct FoldScatterWithProducerReshapeByExpansion final + : public OpRewritePattern { + FoldScatterWithProducerReshapeByExpansion( + MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFoldingReshapes(std::move(controlFoldingReshapes)) {} + + LogicalResult matchAndRewrite(ScatterOp scatterOp, + PatternRewriter &rewriter) const override { + for (OpOperand &opOperand : scatterOp->getOpOperands()) { + tensor::CollapseShapeOp reshapeOp = + opOperand.get().getDefiningOp(); + if (!reshapeOp) + continue; + if (!controlFoldingReshapes(&opOperand)) + continue; + + std::optional replacementValue = fuseScatterWithReshapeByExpansion( + scatterOp, reshapeOp, &opOperand, rewriter); + if (!replacementValue) + return failure(); + rewriter.replaceOp(scatterOp, *replacementValue); + return success(); + } + return failure(); + } + + linalg::ControlFusionFn controlFoldingReshapes; +}; + +struct FoldScatterWithConsumerReshapeByExpansion final + : public OpRewritePattern { + FoldScatterWithConsumerReshapeByExpansion( + MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFoldingReshapes(std::move(controlFoldingReshapes)) {} + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, + PatternRewriter &rewriter) const override { + auto producerResult = dyn_cast(expandOp.getSrc()); + if (!producerResult) { + return rewriter.notifyMatchFailure(expandOp, + "source not produced by an operation"); + } + + auto scatterOp = producerResult.getDefiningOp(); + if (!scatterOp) { + return failure(); + } + + if (!controlFoldingReshapes(&expandOp.getSrcMutable())) { + return failure(); + } + + std::optional replacementValue = fuseScatterWithReshapeByExpansion( + scatterOp, expandOp, 
scatterOp.getTiedOpOperand(producerResult), + rewriter); + if (!replacementValue) + return failure(); + rewriter.replaceOp(scatterOp, *replacementValue); + return success(); + } + + linalg::ControlFusionFn controlFoldingReshapes; +}; + } // namespace /// Return the `reassociation` indices to use to collapse the operand when the @@ -773,6 +984,10 @@ void populateFoldReshapeOpsByExpansionPatterns( patterns.getContext(), controlFoldingReshapes); patterns.add( patterns.getContext(), controlFoldingReshapes); + patterns.add( + patterns.getContext(), controlFoldingReshapes); + patterns.add( + patterns.getContext(), controlFoldingReshapes); } SmallVector defaultControlDropUnitDims(Operation *op) { diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir index 25e15537a7f4..b3fdd9038b49 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir @@ -481,3 +481,204 @@ util.func public @dont_sink_through_k2(%0 : tensor<128x64x128x1x1xf16>, %1 : ten // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> // CHECK-SAME: ins(%[[ARG2]], %[[ARG1]], %[[COLLAPSED]], %[[ARG3]] : // CHECK: util.return %[[ATTENTION]] + + +// ----- + +util.func @scatter_collapse_updates(%arg0: tensor<4x?x2x16x4x128xf16>, %arg1: tensor, %arg2: tensor) -> tensor { + %collapsed = tensor.collapse_shape %arg0[[0, 1], [2], [3], [4], [5]] : tensor<4x?x2x16x4x128xf16> into tensor + %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%arg7: f16, %arg8: f16): + iree_linalg_ext.yield %arg7 : f16 + } -> tensor + util.return %1 : tensor +} + +// CHECK-LABEL: util.func public @scatter_collapse_updates +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: 
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK: %[[INDICES:.+]] = tensor.expand_shape +// CHECK-SAME: tensor into tensor<4x?x1xi32> +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[ARG0]], %[[INDICES]] +// CHECK-SAME: outs(%[[ARG2]] +// CHECK: util.return %[[SCATTER]] + +// ----- + +util.func @scatter_collapse_updates_partial(%arg0: tensor<4x?x2x2x16x4x128xf16>, %arg1: tensor, %arg2: tensor<10x16x4x128xf16>) -> tensor<10x16x4x128xf16> { + %collapsed = tensor.collapse_shape %arg0[[0, 1], [2, 3], [4], [5], [6]] : tensor<4x?x2x2x16x4x128xf16> into tensor + %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor, tensor) outs(%arg2 : tensor<10x16x4x128xf16>) { + ^bb0(%arg7: f16, %arg8: f16): + iree_linalg_ext.yield %arg7 : f16 + } -> tensor<10x16x4x128xf16> + util.return %1 : tensor<10x16x4x128xf16> +} + +// CHECK-LABEL: util.func public @scatter_collapse_updates_partial +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK-DAG: %[[INDICES:.+]] = tensor.expand_shape %[[ARG1]] {{.*}} tensor into tensor<4x?x1xi32> +// CHECK-DAG: %[[UPDATES:.+]] = tensor.collapse_shape %[[ARG0]] {{.*}} tensor<4x?x2x2x16x4x128xf16> into tensor<4x?x4x16x4x128xf16> +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]] +// CHECK-SAME: outs(%[[ARG2]] +// CHECK: util.return %[[SCATTER]] + +// ----- + +util.func @scatter_collapse_original(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3 : index) -> tensor { + %collapsed = tensor.collapse_shape %arg2 [[0], [1, 2], [3, 4], [5, 6]] : tensor into tensor + %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%collapsed : tensor) { + ^bb0(%arg6: f16, %arg7: f16): + iree_linalg_ext.yield %arg6 : f16 + } -> tensor + util.return %1 : tensor +} + +// CHECK-LABEL: util.func public 
@scatter_collapse_original +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK-SAME: tensor into tensor +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]] +// CHECK-SAME: outs(%[[ARG2]] +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]] +// CHECK: util.return %[[COLLAPSE]] + +// ----- + +util.func @scatter_original_noop(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3 : index) -> tensor { + %collapsed = tensor.collapse_shape %arg2 [[0, 1, 2], [3], [4], [5]] : tensor into tensor + %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%collapsed : tensor) { + ^bb0(%arg6: f16, %arg7: f16): + iree_linalg_ext.yield %arg6 : f16 + } -> tensor + util.return %1 : tensor +} + +// CHECK-LABEL: util.func public @scatter_original_noop +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG2]] +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] +// CHECK-SAME: outs(%[[COLLAPSE]] +// CHECK: util.return %[[SCATTER]] + + +// ----- + +util.func @scatter_collapse_original_partial(%arg0: tensor, %arg1: tensor, %arg2: tensor<5x?x2x16x4x2x64x2xf16>, %arg3 : index) -> tensor { + %collapsed = tensor.collapse_shape %arg2 [[0, 1], [2, 3], [4, 5], [6, 7]] : tensor<5x?x2x16x4x2x64x2xf16> into tensor + %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%collapsed : tensor) { + ^bb0(%arg6: f16, %arg7: f16): + iree_linalg_ext.yield %arg6 : f16 + } -> tensor + util.return %1 : tensor +} + +// CHECK-LABEL: util.func public @scatter_collapse_original_partial +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: 
%[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor into tensor +// TODO(IanWood1): fix this so the collapse folds with the expand +// CHECK-DAG: %[[ORIGINAL:.+]] = tensor.expand_shape {{.*}} tensor into tensor +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]] +// CHECK-SAME: outs(%[[ORIGINAL]] +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]] +// CHECK: util.return %[[COLLAPSE]] + +// ----- + +util.func @scatter_collapse_indices(%arg0: tensor, %arg1: tensor<4x?x1xi32>, %arg2: tensor) -> tensor { + %collapsed = tensor.collapse_shape %arg1[[0, 1], [2]] : tensor<4x?x1xi32> into tensor + %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %collapsed: tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%arg7: f16, %arg8: f16): + iree_linalg_ext.yield %arg7 : f16 + } -> tensor + util.return %1 : tensor +} + +// CHECK-LABEL: util.func public @scatter_collapse_indices +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor into tensor<4x?x2x16x4x128xf16> +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]] +// CHECK-SAME: outs(%[[ARG2]] +// CHECK: util.return %[[SCATTER]] + +// ----- + +util.func @scatter_collapse_indices_partial(%arg0: tensor, %arg1: tensor<4x?x1x1xi32>, %arg2: tensor) -> tensor { + %collapsed = tensor.collapse_shape %arg1[[0, 1], [2, 3]] : tensor<4x?x1x1xi32> into tensor + %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %collapsed: tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%arg7: f16, %arg8: f16): + iree_linalg_ext.yield %arg7 : f16 + } -> tensor + util.return %1 : tensor +} + +// CHECK-LABEL: util.func public @scatter_collapse_indices_partial +// CHECK-SAME: 
%[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape {{.*}} tensor into tensor<4x?x2x16x4x128xf16> +// CHECK-DAG: %[[ORIGINAL:.+]] = tensor.collapse_shape {{.*}} tensor<4x?x1x1xi32> into tensor<4x?x1xi32> +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[UPDATES]], %[[ORIGINAL]] +// CHECK-SAME: outs(%[[ARG2]] +// CHECK: util.return %[[SCATTER]] + +// ----- + +util.func public @scatter_collapse(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %c0 = arith.constant 0 : index + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%arg3: f16, %arg4: f16): + iree_linalg_ext.yield %arg3 : f16 + } -> tensor + %dim = tensor.dim %arg0, %c0 : tensor + %expanded = tensor.expand_shape %0 [[0], [1], [2], [3], [4, 5]] output_shape [%dim, 2, 16, 4, 4, 32] : tensor into tensor + util.return %expanded : tensor +} +// CHECK-LABEL: util.func public @scatter_collapse +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor into tensor +// CHECK-DAG: %[[ORIGINAL:.+]] = tensor.expand_shape %[[ARG2]] {{.*}} tensor into tensor +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]] +// CHECK-SAME: outs(%[[ORIGINAL]] +// CHECK: util.return %[[SCATTER]] + +// ----- + +util.func public @scatter_collapse_noop(%arg0: tensor<10xf16>, %arg1: tensor<10x1xi32>, %arg2: tensor<128xf16>) -> tensor<4x4x4x2xf16> { + %c0 = arith.constant 0 : index + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<10xf16>, tensor<10x1xi32>) outs(%arg2 : tensor<128xf16>) { + ^bb0(%arg3: f16, %arg4: f16): + iree_linalg_ext.yield %arg3 : f16 + } -> tensor<128xf16> + %expanded = 
tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape[4, 4, 4, 2] : tensor<128xf16> into tensor<4x4x4x2xf16> + util.return %expanded : tensor<4x4x4x2xf16> +} +// CHECK-LABEL: util.func public @scatter_collapse_noop +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] +// CHECK-SAME: outs(%[[ARG2]] +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SCATTER]] +// CHECK: util.return %[[EXPANDED]]