From 9ab2b9b5408417d51a0209a30afa45d3c8377c33 Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Tue, 10 Dec 2024 05:07:54 -0800
Subject: [PATCH 1/4] [LinalgExt] Scatter fusion by expansion

Signed-off-by: Ian Wood
---
 .../LinalgExt/Transforms/ReshapeFusion.cpp    | 276 +++++++++++++++++-
 .../test/attention_fuse_by_expansion.mlir     | 201 +++++++++++++
 2 files changed, 471 insertions(+), 6 deletions(-)

diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index 901288bd4788..f0f725dc2d52 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -14,6 +14,7 @@
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/PatternMatch.h"
@@ -246,8 +247,7 @@ getReassociationForExpansion(AffineMap indexingMap,
   return reassociation;
 }
 
-template <typename OpTy>
-static bool isFusableWithReshapeByDimExpansion(OpTy op,
+static bool isFusableWithReshapeByDimExpansion(AttentionOp op,
                                                OpOperand *fusableOpOperand) {
   // Is fusable only if:
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - All the loops for the reshaped operand are parallel loops.
   SmallVector<utils::IteratorType> iteratorTypes = op.getLoopIteratorTypes();
   AffineMap operandMap = op.getMatchingIndexingMap(fusableOpOperand);
-  return op.hasPureTensorSemantics() &&
-         llvm::all_of(
-             op.getIndexingMapsArray(),
-             [](AffineMap map) { return map.isProjectedPermutation(); }) &&
+  return operandMap && op.hasPureTensorSemantics() &&
+         llvm::all_of(op.getIndexingMapsArray(),
+                      [](AffineMap map) {
+                        return map && map.isProjectedPermutation();
+                      }) &&
          operandMap.getNumResults() > 0;
 }
 
@@ -391,6 +392,197 @@ static std::optional<SmallVector<Value>> fuseAttentionWithReshapeByExpansion(
   return resultVals;
 }
 
+namespace {
+class ScatterExpansionInfo {
+public:
+  // Helper class similar to `ExpansionInfo` but only for
+  // `LinalgExt::ScatterOp` due to its special semantics (i.e. not all dims
+  // map to the iteration space).
+  LogicalResult compute(LinalgExt::ScatterOp scatterOp,
+                        OpOperand *fusableOpOperand,
+                        ArrayRef<ReassociationIndices> reassociationIndices,
+                        ArrayRef<int64_t> expandedShape,
+                        ArrayRef<int64_t> collapsedShape,
+                        PatternRewriter &rewriter);
+
+  SmallVector<ReassociationIndices> updatesReassoc;
+  SmallVector<SmallVector<int64_t>> updatesShapeMap;
+  SmallVector<ReassociationIndices> indicesReassoc;
+  SmallVector<SmallVector<int64_t>> indicesShapeMap;
+  SmallVector<ReassociationIndices> originalReassoc;
+  SmallVector<SmallVector<int64_t>> originalShapeMap;
+};
+
+} // namespace
+
+// Use the innermost indices in `reassoc` to construct a shape map out of
+// `shape`.
+static SmallVector<SmallVector<int64_t>>
+computeShapeMapFromReassoc(ArrayRef<ReassociationIndices> reassoc,
+                           ArrayRef<int64_t> shape) {
+  SmallVector<SmallVector<int64_t>> shapeMap;
+  for (auto &indices : reassoc) {
+    shapeMap.emplace_back(shape.slice(indices.front(), indices.size()));
+  }
+  return shapeMap;
+}
+
+static SmallVector<ReassociationIndices>
+computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
+  SmallVector<ReassociationIndices> reassoc;
+  int64_t dimCount = 0;
+  for (auto &shape : shapeMap) {
+    reassoc.emplace_back(llvm::to_vector(
+        llvm::seq<int64_t>(dimCount, dimCount + shape.size())));
+    dimCount += shape.size();
+  }
+  return reassoc;
+}
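+
+// For illustration (not exercised in code): with reassoc = [[0, 1], [2]] and
+// shape = [4, 8, 16], computeShapeMapFromReassoc yields {{4, 8}, {16}}, and
+// computeReassocFromShapeMap maps that back to [[0, 1], [2]].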
+
+LogicalResult ScatterExpansionInfo::compute(
+    LinalgExt::ScatterOp scatterOp, OpOperand *fusableOpOperand,
+    ArrayRef<ReassociationIndices> reassociationIndices,
+    ArrayRef<int64_t> expandedShape, ArrayRef<int64_t> collapsedShape,
+    PatternRewriter &rewriter) {
+  if (reassociationIndices.empty())
+    return failure();
+  assert(fusableOpOperand->getOwner() == scatterOp);
+
+  auto updatesShape = scatterOp.getUpdateType().getShape();
+  auto originalShape = scatterOp.getOriginalType().getShape();
+  auto rankOfContiguousSlice =
+      scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
+
+  // Helper lambda to convert a shape to an identity shape map.
+  auto shapeToIdentShapeMap = [](ArrayRef<int64_t> shape) {
+    return llvm::map_to_vector(
+        shape, [](int64_t dim) { return SmallVector<int64_t>{dim}; });
+  };
+
+  // Set `batchShapeMap` and `sliceShapeMap` based on the specific operand.
+  int64_t operandNum = fusableOpOperand->getOperandNumber();
+
+  // In the case of `original`, no change to the iteration space
+  SmallVector<SmallVector<int64_t>> batchShapeMap =
+      operandNum == ScatterOp::kOriginalOpNum
+          ? shapeToIdentShapeMap(
+                originalShape.take_front(scatterOp.getBatchRank()))
+          : computeShapeMapFromReassoc(
+                reassociationIndices.take_front(scatterOp.getBatchRank()),
+                expandedShape);
+  // In the case of `indices`, no change to the iteration space
+  SmallVector<SmallVector<int64_t>> sliceShapeMap =
+      operandNum == ScatterOp::kIndicesOpNum
+          ? shapeToIdentShapeMap(
+                originalShape.take_back(rankOfContiguousSlice))
+          : computeShapeMapFromReassoc(
+                reassociationIndices.take_back(rankOfContiguousSlice),
+                expandedShape);
+
+  // Early exit if iteration space is unchanged
+  if (llvm::all_of(batchShapeMap, [&](auto vec) { return vec.size() == 1; }) &&
+      llvm::all_of(sliceShapeMap, [&](auto vec) { return vec.size() == 1; })) {
+    return failure();
+  }
+
+  updatesShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
+      batchShapeMap,
+      shapeToIdentShapeMap(updatesShape.slice(scatterOp.getBatchRank(),
+                                              scatterOp.getUpdateSliceRank() -
+                                                  rankOfContiguousSlice)),
+      sliceShapeMap));
+  indicesShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
+      batchShapeMap, shapeToIdentShapeMap(scatterOp.getIndexDepth())));
+  originalShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
+      shapeToIdentShapeMap(originalShape.drop_back(rankOfContiguousSlice)),
+      sliceShapeMap));
+
+  updatesReassoc = computeReassocFromShapeMap(updatesShapeMap);
+  indicesReassoc = computeReassocFromShapeMap(indicesShapeMap);
+  originalReassoc = computeReassocFromShapeMap(originalShapeMap);
+  return success();
+}
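+
+// For illustration (not exercised in code): for a scatter with batch rank 1,
+// index depth 1, updates tensor<?x8xf16>, indices tensor<?x1xi32>, and
+// original tensor<100x8xf16>, expanding the updates batch dim as {4, ?} gives
+// updatesShapeMap {{4, ?}, {8}}, indicesShapeMap {{4, ?}, {1}}, and
+// originalShapeMap {{100}, {8}} (the indexed dim is untouched).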
+
+static std::optional<Value>
+fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
+                                  OpOperand *fusableOpOperand,
+                                  PatternRewriter &rewriter) {
+  Location loc = scatterOp.getLoc();
+  // Check if reshape is expanding or collapsing.
+  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
+  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
+  bool isExpanding = (expandingReshapeOp != nullptr);
+  RankedTensorType expandedType = isExpanding
+                                      ? expandingReshapeOp.getResultType()
+                                      : collapsingReshapeOp.getSrcType();
+  RankedTensorType collapsedType = isExpanding
+                                       ? expandingReshapeOp.getSrcType()
+                                       : collapsingReshapeOp.getResultType();
+  ScatterExpansionInfo info;
+  if (failed(info.compute(
+          scatterOp, fusableOpOperand,
+          isExpanding ? expandingReshapeOp.getReassociationIndices()
+                      : collapsingReshapeOp.getReassociationIndices(),
+          expandedType.getShape(), collapsedType.getShape(), rewriter))) {
+    return std::nullopt;
+  }
+
+  // Helper to compute the expanded type for `type` by flattening the given
+  // shape map.
+  auto getType = [&](SmallVector<SmallVector<int64_t>> &shapeMap,
+                     ShapedType type) {
+    SmallVector<int64_t> flattenedArray;
+    for (auto &shape : shapeMap) {
+      flattenedArray.append(shape.begin(), shape.end());
+    }
+    return RankedTensorType::get(flattenedArray, type.getElementType());
+  };
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(scatterOp);
+
+  auto isIdentityReassoc = [](SmallVector<ReassociationIndices> &indices) {
+    for (auto &index : indices) {
+      if (index.size() != 1)
+        return false;
+    }
+    return true;
+  };
+
+  Value newUpdates = rewriter.create<tensor::ExpandShapeOp>(
+      loc, getType(info.updatesShapeMap, scatterOp.getUpdateType()),
+      scatterOp.getUpdates(), info.updatesReassoc);
+  Value newIndices =
+      isIdentityReassoc(info.indicesReassoc)
+          ? scatterOp.getIndices()
+          : rewriter.create<tensor::ExpandShapeOp>(
+                loc, getType(info.indicesShapeMap, scatterOp.getIndicesType()),
+                scatterOp.getIndices(), info.indicesReassoc);
+  Value newOriginal =
+      isIdentityReassoc(info.originalReassoc)
+          ? scatterOp.getOriginal()
+          : rewriter.create<tensor::ExpandShapeOp>(
+                loc,
+                getType(info.originalShapeMap, scatterOp.getOriginalType()),
+                scatterOp.getOriginal(), info.originalReassoc);
+
+  auto newScatter = rewriter.create<ScatterOp>(
+      loc, newOriginal.getType(), ValueRange{newUpdates, newIndices},
+      ValueRange{newOriginal}, scatterOp.getDimensionMap(),
+      scatterOp.getUniqueIndices());
+  rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(),
+                              newScatter.getRegion().begin());
+
+  // Collapse back to originanl shape.
+  auto newCollapse = rewriter.create<tensor::CollapseShapeOp>(
+      loc, scatterOp.getOriginalType(), newScatter.getResult(0),
+      info.originalReassoc);
+
+  return {newCollapse};
+}
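+
+// Sketch of the rewrite this implements (illustrative only): a collapse
+// feeding the `updates` operand, e.g.
+//   %c = tensor.collapse_shape %u : tensor<4x?x8xf16> into tensor<?x8xf16>
+//   %r = iree_linalg_ext.scatter ins(%c, %idx) outs(%orig)
+// becomes a scatter on the expanded shapes whose other operands are expanded
+// to match, followed by a tensor.collapse_shape on the result; the created
+// expand_shape is expected to cancel with the producing collapse.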
+
+//===----------------------------------------------------------------------===//
+// Fuse By Expansion Patterns
+//===----------------------------------------------------------------------===//
+
 namespace {
 
 // Fold attention with its consumer expand_shape op.
@@ -553,6 +745,74 @@ struct FoldScatterNonIterationUnitDims final
   linalg::ControlDropUnitDims options;
 };
 
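+// Note (illustrative): the "producer" direction fuses a tensor.collapse_shape
+// that feeds one of the scatter's operands, while the "consumer" direction
+// below fuses a tensor.expand_shape applied to the scatter's result; both
+// funnel into fuseScatterWithReshapeByExpansion.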
+struct FoldScatterWithProducerReshapeByExpansion final
+    : public OpRewritePattern<ScatterOp> {
+  FoldScatterWithProducerReshapeByExpansion(
+      MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<ScatterOp>(context, benefit),
+        controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
+
+  LogicalResult matchAndRewrite(ScatterOp scatterOp,
+                                PatternRewriter &rewriter) const override {
+    for (OpOperand &opOperand : scatterOp->getOpOperands()) {
+      tensor::CollapseShapeOp reshapeOp =
+          opOperand.get().getDefiningOp<tensor::CollapseShapeOp>();
+      if (!reshapeOp)
+        continue;
+      if (!controlFoldingReshapes(&opOperand))
+        continue;
+
+      std::optional<Value> replacementValue =
+          fuseScatterWithReshapeByExpansion(scatterOp, reshapeOp, &opOperand,
+                                            rewriter);
+      if (!replacementValue)
+        return failure();
+      rewriter.replaceOp(scatterOp, *replacementValue);
+      return success();
+    }
+    return failure();
+  }
+
+  linalg::ControlFusionFn controlFoldingReshapes;
+};
+
+struct FoldScatterWithConsumerReshapeByExpansion final
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+  FoldScatterWithConsumerReshapeByExpansion(
+      MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto producerResult = dyn_cast<OpResult>(expandOp.getSrc());
+    if (!producerResult) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "source not produced by an operation");
+    }
+
+    auto scatterOp = producerResult.getDefiningOp<ScatterOp>();
+    if (!scatterOp) {
+      return failure();
+    }
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return failure();
+    }
+
+    std::optional<Value> replacementValue = fuseScatterWithReshapeByExpansion(
+        scatterOp, expandOp, scatterOp.getTiedOpOperand(producerResult),
+        rewriter);
+    if (!replacementValue)
+      return failure();
+    rewriter.replaceOp(scatterOp, *replacementValue);
+    return success();
+  }
+
+  linalg::ControlFusionFn controlFoldingReshapes;
+};
+
 } // namespace
 
 /// Return the `reassociation` indices to use to collapse the operand when the
@@ -773,6 +1033,10 @@ void populateFoldReshapeOpsByExpansionPatterns(
       patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldAttentionWithConsumerReshapeByExpansion>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldScatterWithProducerReshapeByExpansion>(
+      patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldScatterWithConsumerReshapeByExpansion>(
+      patterns.getContext(), controlFoldingReshapes);
 }
 
 SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
index 25e15537a7f4..b3fdd9038b49 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
@@ -481,3 +481,204 @@ util.func public @dont_sink_through_k2(%0 : tensor<128x64x128x1x1xf16>, %1 : ten
 // CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 // CHECK-SAME:     ins(%[[ARG2]], %[[ARG1]], %[[COLLAPSED]], %[[ARG3]] :
 // CHECK:        util.return %[[ATTENTION]]
+
+
+// -----
+
+util.func @scatter_collapse_updates(%arg0: tensor<4x?x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x16x4x128xf16>) -> tensor<?x16x4x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3], [4], [5]] : tensor<4x?x2x16x4x128xf16> into tensor<?x2x16x4x128xf16>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x16x4x128xf16>) {
+  ^bb0(%arg7: f16, %arg8: f16):
+    iree_linalg_ext.yield %arg7 : f16
+  } -> tensor<?x16x4x128xf16>
+  util.return %1 : tensor<?x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_updates
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK: %[[INDICES:.+]] = tensor.expand_shape
+// CHECK-SAME: tensor<?x1xi32> into tensor<4x?x1xi32>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[ARG0]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: util.return %[[SCATTER]]
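+
+// Note: the expand_shape created for the updates operand cancels with the
+// collapse_shape above, so the fused scatter consumes %arg0 directly.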
+
+// -----
+
+util.func @scatter_collapse_updates_partial(%arg0: tensor<4x?x2x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<10x16x4x128xf16>) -> tensor<10x16x4x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2, 3], [4], [5], [6]] : tensor<4x?x2x2x16x4x128xf16> into tensor<?x4x16x4x128xf16>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x4x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<10x16x4x128xf16>) {
+  ^bb0(%arg7: f16, %arg8: f16):
+    iree_linalg_ext.yield %arg7 : f16
+  } -> tensor<10x16x4x128xf16>
+  util.return %1 : tensor<10x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_updates_partial
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-DAG: %[[INDICES:.+]] = tensor.expand_shape %[[ARG1]] {{.*}} tensor<?x1xi32> into tensor<4x?x1xi32>
+// CHECK-DAG: %[[UPDATES:.+]] = tensor.collapse_shape %[[ARG0]] {{.*}} tensor<4x?x2x2x16x4x128xf16> into tensor<4x?x4x16x4x128xf16>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: util.return %[[SCATTER]]
+
+// -----
+
+util.func @scatter_collapse_original(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3 : index) -> tensor {
+  %collapsed = tensor.collapse_shape %arg2 [[0], [1, 2], [3, 4], [5, 6]] : tensor into tensor
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%collapsed : tensor) {
+  ^bb0(%arg6: f16, %arg7: f16):
+    iree_linalg_ext.yield %arg6 : f16
+  } -> tensor
+  util.return %1 : tensor
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_original
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME: tensor into tensor
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]]
+// CHECK: util.return %[[COLLAPSE]]
+
+// -----
+
+util.func @scatter_original_noop(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3 : index) -> tensor {
+  %collapsed = tensor.collapse_shape %arg2 [[0, 1, 2], [3], [4], [5]] : tensor into tensor
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%collapsed : tensor) {
+  ^bb0(%arg6: f16, %arg7: f16):
+    iree_linalg_ext.yield %arg6 : f16
+  } -> tensor
+  util.return %1 : tensor
+}
+
+// CHECK-LABEL: util.func public @scatter_original_noop
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG2]]
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+// CHECK-SAME: outs(%[[COLLAPSE]]
+// CHECK: util.return %[[SCATTER]]
+
+
+// -----
+
+util.func @scatter_collapse_original_partial(%arg0: tensor<?x32x8x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<5x?x2x16x4x2x64x2xf16>, %arg3 : index) -> tensor<?x32x8x128xf16> {
+  %collapsed = tensor.collapse_shape %arg2 [[0, 1], [2, 3], [4, 5], [6, 7]] : tensor<5x?x2x16x4x2x64x2xf16> into tensor<?x32x8x128xf16>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x32x8x128xf16>, tensor<?x1xi32>) outs(%collapsed : tensor<?x32x8x128xf16>) {
+  ^bb0(%arg6: f16, %arg7: f16):
+    iree_linalg_ext.yield %arg6 : f16
+  } -> tensor<?x32x8x128xf16>
+  util.return %1 : tensor<?x32x8x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_original_partial
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x32x8x128xf16> into tensor<?x2x16x4x2x64x2xf16>
+// TODO(IanWood1): fix this so the collapse folds with the expand
+// CHECK-DAG: %[[ORIGINAL:.+]] = tensor.expand_shape {{.*}} tensor<?x32x8x128xf16> into tensor<?x2x16x4x2x64x2xf16>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]]
+// CHECK: util.return %[[COLLAPSE]]
+
+// -----
+
+util.func @scatter_collapse_indices(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<4x?x1xi32>, %arg2: tensor<?x16x4x128xf16>) -> tensor<?x16x4x128xf16> {
+  %collapsed = tensor.collapse_shape %arg1 [[0, 1], [2]] : tensor<4x?x1xi32> into tensor<?x1xi32>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %collapsed : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x16x4x128xf16>) {
+  ^bb0(%arg7: f16, %arg8: f16):
+    iree_linalg_ext.yield %arg7 : f16
+  } -> tensor<?x16x4x128xf16>
+  util.return %1 : tensor<?x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_indices
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<4x?x2x16x4x128xf16>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: util.return %[[SCATTER]]
+
+// -----
+
+util.func @scatter_collapse_indices_partial(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<4x?x1x1xi32>, %arg2: tensor<?x16x4x128xf16>) -> tensor<?x16x4x128xf16> {
+  %collapsed = tensor.collapse_shape %arg1 [[0, 1], [2, 3]] : tensor<4x?x1x1xi32> into tensor<?x1xi32>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %collapsed : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x16x4x128xf16>) {
+  ^bb0(%arg7: f16, %arg8: f16):
+    iree_linalg_ext.yield %arg7 : f16
+  } -> tensor<?x16x4x128xf16>
+  util.return %1 : tensor<?x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_indices_partial
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape {{.*}} tensor<?x2x16x4x128xf16> into tensor<4x?x2x16x4x128xf16>
+// CHECK-DAG: %[[INDICES:.+]] = tensor.collapse_shape {{.*}} tensor<4x?x1x1xi32> into tensor<4x?x1xi32>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: util.return %[[SCATTER]]
+
+// -----
+
+util.func public @scatter_collapse(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x4x32xf16> {
+  %c0 = arith.constant 0 : index
+  %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
+  ^bb0(%arg3: f16, %arg4: f16):
+    iree_linalg_ext.yield %arg3 : f16
+  } -> tensor<?x2x16x4x128xf16>
+  %dim = tensor.dim %arg0, %c0 : tensor<?x2x16x4x128xf16>
+  %expanded = tensor.expand_shape %0 [[0], [1], [2], [3], [4, 5]] output_shape [%dim, 2, 16, 4, 4, 32] : tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16>
+  util.return %expanded : tensor<?x2x16x4x4x32xf16>
+}
+// CHECK-LABEL: util.func public @scatter_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16>
+// CHECK-DAG: %[[ORIGINAL:.+]] = tensor.expand_shape %[[ARG2]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: util.return %[[SCATTER]]
+
+// -----
+
+util.func public @scatter_collapse_noop(%arg0: tensor<10xf16>, %arg1: tensor<10x1xi32>, %arg2: tensor<128xf16>) -> tensor<4x4x4x2xf16> {
+  %c0 = arith.constant 0 : index
+  %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<10xf16>, tensor<10x1xi32>) outs(%arg2 : tensor<128xf16>) {
+  ^bb0(%arg3: f16, %arg4: f16):
+    iree_linalg_ext.yield %arg3 : f16
+  } -> tensor<128xf16>
+  %expanded = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape [4, 4, 4, 2] : tensor<128xf16> into tensor<4x4x4x2xf16>
+  util.return %expanded : tensor<4x4x4x2xf16>
+}
+// CHECK-LABEL: util.func public @scatter_collapse_noop
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SCATTER]]
+// CHECK: util.return %[[EXPANDED]]

From 8b870e62a8ac6bcb81099599c8b65d85aca7fde7 Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Wed, 8 Jan 2025 01:26:56 -0800
Subject: [PATCH 2/4] Refactored ExpansionInfo to also support Scatter

Signed-off-by: Ian Wood
---
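Note (editorial, not part of the commit): the core idea of the refactor, as a
sketch -- each operand dim either tracks a loop's expanded shape or stays as a
singleton group:

  // Pseudocode of ExpansionInfo::getShapeMap below:
  // for each (operandIdx, loopIdx) in info.operandToIterationSpace:
  //   shapeMap += (loopIdx == kNoMapping)
  //                   ? {originalShape[operandIdx]}   // untouched dim
  //                   : loopShapeMap[loopIdx];        // expanded loop dims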
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" @@ -21,8 +23,47 @@ namespace mlir::iree_compiler::IREE::LinalgExt { +// Use the innermost indices in `reassoc` to construct a shape map out of +// `shape` +static SmallVector> +computeShapeMapFromReassoc(ArrayRef reassoc, + ArrayRef shape) { + SmallVector> shapeMap; + for (auto &indices : reassoc) { + shapeMap.emplace_back(shape.slice(indices.front(), indices.size())); + } + return shapeMap; +} + +static SmallVector +computeReassocFromShapeMap(ArrayRef> shapeMap) { + SmallVector reassoc; + int64_t dimCount = 0; + for (auto &shape : shapeMap) { + reassoc.emplace_back( + llvm::to_vector(llvm::seq(dimCount, dimCount + shape.size()))); + dimCount += shape.size(); + } + return reassoc; +} + namespace { +/// Helper class that supports fusing reshapes with operands when not all of the +/// shape dims map to the iteration space. +struct ReshapeOperandInfo { +public: + static constexpr int64_t kNoMapping = -1; + + // Original shape of this operand. + ArrayRef originalShape; + + // Similar to the results of the operand's `AffineMap` except `kNoMapping` if + // that dim doesn't map to the iteration space. For example, the indexed + // dimensions in a LinalgExt::ScatterOp. + SmallVector operandToIterationSpace; +}; + /// Information needed to expand an operation to fold the reshape with /// it. class ExpansionInfo { @@ -31,32 +72,50 @@ class ExpansionInfo { // of the expanded op given the `indexingMap` of the fused operand/result of // the op, the `reassocationMaps` of the reshape op and the shape of // the expanded op. 
+};
+
 /// Information needed to expand an operation to fold the reshape with
 /// it.
 class ExpansionInfo {
@@ -31,32 +72,50 @@ class ExpansionInfo {
   // of the expanded op given the `indexingMap` of the fused operand/result of
   // the op, the `reassocationMaps` of the reshape op and the shape of
   // the expanded op.
-  template <typename OpTy>
-  LogicalResult compute(OpTy op, OpOperand *fusableOpOperand,
-                        ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
-                        PatternRewriter &rewriter);
-  unsigned getOrigOpNumDims() const { return reassociation.size(); }
-  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
-  ReassociationIndicesRef getExpandedDims(unsigned i) const {
-    return reassociation[i];
+  LogicalResult compute(SmallVector<ReshapeOperandInfo> infos,
+                        SmallVector<int64_t> loopRanges,
+                        OpOperand *fusableOpOperand,
+                        ArrayRef<ReassociationIndices> operandReassoc,
+                        ArrayRef<int64_t> expandedShape);
+
+  SmallVector<SmallVector<int64_t>> getShapeMap(OpOperand *operand) const {
+    auto info = reshapeInfos[operand->getOperandNumber()];
+    SmallVector<SmallVector<int64_t>> shapeMap;
+    for (auto [operandIdx, loopIdx] :
+         llvm::enumerate(info.operandToIterationSpace)) {
+      if (loopIdx == ReshapeOperandInfo::kNoMapping) {
+        shapeMap.push_back(
+            SmallVector<int64_t>{info.originalShape[operandIdx]});
+      } else {
+        shapeMap.push_back(loopShapeMap[loopIdx]);
+      }
+    }
+    return shapeMap;
+  }
+
+  unsigned getOrigNumLoops() const { return loopReassoc.size(); }
+  unsigned getExpandedNumLoops() const { return expandedOpNumDims; }
+  ReassociationIndicesRef getExpandedLoops(unsigned i) const {
+    return loopReassoc[i];
+  }
+  ArrayRef<int64_t> getExpandedShapeOfLoop(unsigned i) const {
+    return loopShapeMap[i];
   }
-  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
-    return expandedShapeMap[i];
+
+  SmallVector<ReassociationIndices> getReassoc(OpOperand *operand) const {
+    return computeReassocFromShapeMap(getShapeMap(operand));
   }
-  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
 
 private:
-  /// Reassociation from the dimensions in the original operation to the
-  /// dimension of the expanded operation.
-  SmallVector<ReassociationIndices> reassociation;
+  /// Extent of the iteration space in the original operation.
+  SmallVector<int64_t> loopRanges;
+  SmallVector<ReassociationIndices> loopReassoc;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> expandedShapeMap;
-  /// Extent of the loop in the original operation.
-  SmallVector<int64_t> originalLoopExtent;
+  SmallVector<SmallVector<int64_t>> loopShapeMap;
   unsigned expandedOpNumDims;
+  /// Info about the reassociation and original shape for each operand.
+  SmallVector<ReshapeOperandInfo> reshapeInfos;
 };
 
 class CollapsingInfo {
@@ -110,50 +169,46 @@ class CollapsingInfo {
 
 } // namespace
 
-template <typename OpTy>
-LogicalResult ExpansionInfo::compute(OpTy op, OpOperand *fusableOpOperand,
-                                     ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape,
-                                     ArrayRef<int64_t> collapsedShape,
-                                     PatternRewriter &rewriter) {
-  if (reassociationMaps.empty())
-    return failure();
-  AffineMap fusedIndexMap = op.getMatchingIndexingMap(fusableOpOperand);
-  FailureOr<SmallVector<int64_t>> originalLoopRange = op.getStaticLoopRanges();
-  if (failed(originalLoopRange)) {
+LogicalResult ExpansionInfo::compute(
+    SmallVector<ReshapeOperandInfo> infos, SmallVector<int64_t> loopRanges,
+    OpOperand *fusableOpOperand, ArrayRef<ReassociationIndices> operandReassoc,
+    ArrayRef<int64_t> expandedShape) {
+  if (operandReassoc.empty())
     return failure();
+
+  int64_t operandNum = fusableOpOperand->getOperandNumber();
+  ReshapeOperandInfo &fusionOperandInfo = infos[operandNum];
+  this->loopShapeMap.clear();
+  this->loopShapeMap.resize(loopRanges.size());
+  for (auto [operandIdx, loopIdx] :
+       llvm::enumerate(fusionOperandInfo.operandToIterationSpace)) {
+    if (loopIdx == ReshapeOperandInfo::kNoMapping) {
+      continue;
+    }
+
+    // Compute the shape map at element `loopIdx`.
+    ReassociationIndicesRef indices = operandReassoc[operandIdx];
+    for (auto [dimIdx, shapeIdx] : llvm::enumerate(indices)) {
+      this->loopShapeMap[loopIdx].push_back(expandedShape[shapeIdx]);
+    }
   }
-  originalLoopExtent.assign(originalLoopRange->begin(),
-                            originalLoopRange->end());
-
-  reassociation.clear();
-  expandedShapeMap.clear();
-  // Compute the number of dimension in the expanded op that correspond to each
-  // dimension of the original op.
-  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
-  expandedShapeMap.resize(fusedIndexMap.getNumDims());
-  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
-    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
-    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
-    numExpandedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
-        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
-    expandedShapeMap[pos].assign(shape.begin(), shape.end());
+
+  // Fill in the remaining elements with `loopRanges`.
+  this->expandedOpNumDims = 0;
+  for (const auto &[loopIdx, shapeMap] : llvm::enumerate(this->loopShapeMap)) {
+    if (shapeMap.empty()) {
+      this->loopShapeMap[loopIdx] = SmallVector<int64_t>{loopRanges[loopIdx]};
+    }
+    this->expandedOpNumDims += shapeMap.size();
   }
-  // The remaining dimensions remain the same.
-  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
-    if (expandedShapeMap[i].empty())
-      expandedShapeMap[i] = {originalLoopExtent[i]};
-
-  // Compute reassociation map from the original op to the expanded op.
-  unsigned sum = 0;
-  reassociation.reserve(fusedIndexMap.getNumDims());
-  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
-    auto seq = llvm::seq<unsigned>(sum, sum + numFoldedDim.value());
-    reassociation.emplace_back(seq.begin(), seq.end());
-    sum += numFoldedDim.value();
+
+  if (llvm::all_of(this->loopShapeMap,
+                   [&](auto vec) { return vec.size() == 1; })) {
+    return failure();
   }
-  expandedOpNumDims = sum;
+  this->loopReassoc = computeReassocFromShapeMap(this->loopShapeMap);
+  this->reshapeInfos = std::move(infos);
+  this->loopRanges = std::move(loopRanges);
   return success();
 }
 
@@ -202,6 +257,77 @@ CollapsingInfo::initialize(unsigned origNumLoops,
   return success();
 }
 
+static SmallVector<ReshapeOperandInfo>
+getAttentionReshapeInfo(LinalgExt::AttentionOp attentionOp) {
+  return llvm::map_to_vector(
+      attentionOp->getOpOperands(), [&](OpOperand &opOperand) {
+        ReshapeOperandInfo operandInfo;
+        auto operandType = dyn_cast<ShapedType>(opOperand.get().getType());
+        if (!operandType) {
+          assert(
+              attentionOp.getMatchingIndexingMap(&opOperand).getNumResults() ==
+                  0 &&
+              "expected non-shaped type to have no results in indexing map");
+          return operandInfo;
+        }
+
+        operandInfo.originalShape = operandType.getShape();
+        for (auto result :
+             attentionOp.getMatchingIndexingMap(&opOperand).getResults()) {
+          operandInfo.operandToIterationSpace.push_back(
+              cast<AffineDimExpr>(result).getPosition());
+        }
+        return operandInfo;
+      });
+}
+
+static SmallVector<ReshapeOperandInfo>
+getScatterReshapeInfo(LinalgExt::ScatterOp scatterOp) {
+  SmallVector<ReshapeOperandInfo> infos;
+  auto rankOfContiguousSlice =
+      scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
+  auto updateRank = scatterOp.getUpdateType().getRank();
+
+  // Operand #0 Updates
+  {
+    ReshapeOperandInfo updateInfo;
+    updateInfo.originalShape = scatterOp.getUpdateType().getShape();
+    llvm::append_range(updateInfo.operandToIterationSpace,
                        llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
+    updateInfo.operandToIterationSpace.append(
+        updateRank - (rankOfContiguousSlice + scatterOp.getBatchRank()),
+        ReshapeOperandInfo::kNoMapping);
+    llvm::append_range(
+        updateInfo.operandToIterationSpace,
+        llvm::seq<int64_t>(updateRank - rankOfContiguousSlice, updateRank));
+    infos.push_back(std::move(updateInfo));
+  }
+
+  // Operand #1 Indices
+  {
+    ReshapeOperandInfo indicesInfo;
+    indicesInfo.originalShape = scatterOp.getIndicesType().getShape();
+    llvm::append_range(indicesInfo.operandToIterationSpace,
+                       llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
+    indicesInfo.operandToIterationSpace.push_back(
+        ReshapeOperandInfo::kNoMapping);
+    infos.push_back(std::move(indicesInfo));
+  }
+
+  // Operand #2 Original
+  {
+    ReshapeOperandInfo originalInfo;
+    originalInfo.originalShape = scatterOp.getOriginalType().getShape();
+    originalInfo.operandToIterationSpace.append(scatterOp.getIndexDepth(),
+                                                ReshapeOperandInfo::kNoMapping);
+    llvm::append_range(
+        originalInfo.operandToIterationSpace,
+        llvm::seq<int64_t>(updateRank - rankOfContiguousSlice, updateRank));
+    infos.push_back(std::move(originalInfo));
+  }
+  return infos;
+};
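+// For example (illustrative): with batch rank 1, index depth 1, updates
+// tensor<?x2x16xf16> and original tensor<100x16xf16>, the slice rank is 1 and
+// the maps are updates {0, kNoMapping, 2}, indices {0, kNoMapping}, and
+// original {kNoMapping, 2}.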
+
 static AffineMap
 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                            const ExpansionInfo &expansionInfo) {
@@ -209,12 +335,12 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned pos = cast<AffineDimExpr>(expr).getPosition();
     auto expandedExprs = llvm::to_vector_of<AffineExpr>(
-        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
+        llvm::map_range(expansionInfo.getExpandedLoops(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
     newExprs.append(expandedExprs.begin(), expandedExprs.end());
   }
-  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
+  return AffineMap::get(expansionInfo.getExpandedNumLoops(),
                         indexingMap.getNumSymbols(), newExprs,
                         builder.getContext());
 }
@@ -225,7 +351,7 @@ static RankedTensorType getExpandedType(RankedTensorType originalType,
   SmallVector<int64_t> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    auto dimExpansion = expansionInfo.getExpandedShapeOfLoop(dim);
     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
   }
   return RankedTensorType::get(expandedShape, originalType.getElementType());
@@ -238,7 +364,7 @@ getReassociationForExpansion(AffineMap indexingMap,
   unsigned numReshapeDims = 0;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
+    auto numExpandedDims = expansionInfo.getExpandedLoops(dim).size();
     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
     reassociation.emplace_back(std::move(indices));
@@ -278,16 +404,13 @@ static std::optional<SmallVector<Value>> fuseAttentionWithReshapeByExpansion(
   RankedTensorType expandedType = isExpanding
                                       ? expandingReshapeOp.getResultType()
                                       : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
-
   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
-          attentionOp, fusableOpOperand,
-          isExpanding ? expandingReshapeOp.getReassociationMaps()
-                      : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), collapsedType.getShape(), rewriter)))
+          getAttentionReshapeInfo(attentionOp),
+          attentionOp.getStaticLoopRanges().value(), fusableOpOperand,
+          isExpanding ? expandingReshapeOp.getReassociationIndices()
+                      : collapsingReshapeOp.getReassociationIndices(),
+          expandedType.getShape())))
     return std::nullopt;
   auto expandedOpIndexingMaps = llvm::to_vector_of<AffineMap>(
       llvm::map_range(attentionOp.getIndexingMapsArray(), [&](AffineMap m) {
@@ -392,115 +515,6 @@ static std::optional<SmallVector<Value>> fuseAttentionWithReshapeByExpansion(
   return resultVals;
 }
 
-namespace {
-class ScatterExpansionInfo {
-public:
-  // Helper class similar to `ExpansionInfo` but only for
-  // `LinalgExt::ScatterOp` due to its special semantics (i.e. not all dims
-  // map to the iteration space).
-  LogicalResult compute(LinalgExt::ScatterOp scatterOp,
-                        OpOperand *fusableOpOperand,
-                        ArrayRef<ReassociationIndices> reassociationIndices,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
-                        PatternRewriter &rewriter);
-
-  SmallVector<ReassociationIndices> updatesReassoc;
-  SmallVector<SmallVector<int64_t>> updatesShapeMap;
-  SmallVector<ReassociationIndices> indicesReassoc;
-  SmallVector<SmallVector<int64_t>> indicesShapeMap;
-  SmallVector<ReassociationIndices> originalReassoc;
-  SmallVector<SmallVector<int64_t>> originalShapeMap;
-};
-
-} // namespace
-
-// Use the innermost indices in `reassoc` to construct a shape map out of
-// `shape`.
-static SmallVector<SmallVector<int64_t>>
-computeShapeMapFromReassoc(ArrayRef<ReassociationIndices> reassoc,
-                           ArrayRef<int64_t> shape) {
-  SmallVector<SmallVector<int64_t>> shapeMap;
-  for (auto &indices : reassoc) {
-    shapeMap.emplace_back(shape.slice(indices.front(), indices.size()));
-  }
-  return shapeMap;
-}
-
-static SmallVector<ReassociationIndices>
-computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
-  SmallVector<ReassociationIndices> reassoc;
-  int64_t dimCount = 0;
-  for (auto &shape : shapeMap) {
-    reassoc.emplace_back(llvm::to_vector(
-        llvm::seq<int64_t>(dimCount, dimCount + shape.size())));
-    dimCount += shape.size();
-  }
-  return reassoc;
-}
-
-LogicalResult ScatterExpansionInfo::compute(
-    LinalgExt::ScatterOp scatterOp, OpOperand *fusableOpOperand,
-    ArrayRef<ReassociationIndices> reassociationIndices,
-    ArrayRef<int64_t> expandedShape, ArrayRef<int64_t> collapsedShape,
-    PatternRewriter &rewriter) {
-  if (reassociationIndices.empty())
-    return failure();
-  assert(fusableOpOperand->getOwner() == scatterOp);
-
-  auto updatesShape = scatterOp.getUpdateType().getShape();
-  auto originalShape = scatterOp.getOriginalType().getShape();
-  auto rankOfContiguousSlice =
-      scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
-
-  // Helper lambda to convert a shape to an identity shape map.
-  auto shapeToIdentShapeMap = [](ArrayRef<int64_t> shape) {
-    return llvm::map_to_vector(
-        shape, [](int64_t dim) { return SmallVector<int64_t>{dim}; });
-  };
-
-  // Set `batchShapeMap` and `sliceShapeMap` based on the specific operand.
-  int64_t operandNum = fusableOpOperand->getOperandNumber();
-
-  // In the case of `original`, no change to the iteration space
-  SmallVector<SmallVector<int64_t>> batchShapeMap =
-      operandNum == ScatterOp::kOriginalOpNum
-          ? shapeToIdentShapeMap(
-                originalShape.take_front(scatterOp.getBatchRank()))
-          : computeShapeMapFromReassoc(
-                reassociationIndices.take_front(scatterOp.getBatchRank()),
-                expandedShape);
-  // In the case of `indices`, no change to the iteration space
-  SmallVector<SmallVector<int64_t>> sliceShapeMap =
-      operandNum == ScatterOp::kIndicesOpNum
-          ? shapeToIdentShapeMap(
-                originalShape.take_back(rankOfContiguousSlice))
-          : computeShapeMapFromReassoc(
-                reassociationIndices.take_back(rankOfContiguousSlice),
-                expandedShape);
-
-  // Early exit if iteration space is unchanged
-  if (llvm::all_of(batchShapeMap, [&](auto vec) { return vec.size() == 1; }) &&
-      llvm::all_of(sliceShapeMap, [&](auto vec) { return vec.size() == 1; })) {
-    return failure();
-  }
-
-  updatesShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
-      batchShapeMap,
-      shapeToIdentShapeMap(updatesShape.slice(scatterOp.getBatchRank(),
-                                              scatterOp.getUpdateSliceRank() -
-                                                  rankOfContiguousSlice)),
-      sliceShapeMap));
-  indicesShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
-      batchShapeMap, shapeToIdentShapeMap(scatterOp.getIndexDepth())));
-  originalShapeMap = llvm::to_vector(llvm::concat<SmallVector<int64_t>>(
-      shapeToIdentShapeMap(originalShape.drop_back(rankOfContiguousSlice)),
-      sliceShapeMap));
-
-  updatesReassoc = computeReassocFromShapeMap(updatesShapeMap);
-  indicesReassoc = computeReassocFromShapeMap(indicesShapeMap);
-  originalReassoc = computeReassocFromShapeMap(originalShapeMap);
-  return success();
-}
-
 static std::optional<Value>
 fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
                                   OpOperand *fusableOpOperand,
@@ -513,21 +527,19 @@ fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
   RankedTensorType expandedType = isExpanding
                                       ? expandingReshapeOp.getResultType()
                                       : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
-  ScatterExpansionInfo info;
+  ExpansionInfo info;
   if (failed(info.compute(
-          scatterOp, fusableOpOperand,
+          getScatterReshapeInfo(scatterOp),
+          scatterOp.getStaticLoopRanges().value(), fusableOpOperand,
           isExpanding ? expandingReshapeOp.getReassociationIndices()
                      : collapsingReshapeOp.getReassociationIndices(),
-          expandedType.getShape(), collapsedType.getShape(), rewriter))) {
+          expandedType.getShape()))) {
    return std::nullopt;
  }
 
   // Helper to compute the expanded type for `type` by flattening the given
   // shape map.
-  auto getType = [&](SmallVector<SmallVector<int64_t>> &shapeMap,
+  auto getType = [&](const SmallVector<SmallVector<int64_t>> &shapeMap,
                      ShapedType type) {
     SmallVector<int64_t> flattenedArray;
     for (auto &shape : shapeMap) {
@@ -539,30 +551,38 @@ fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(scatterOp);
 
-  auto isIdentityReassoc = [](SmallVector<ReassociationIndices> &indices) {
-    for (auto &index : indices) {
-      if (index.size() != 1)
-        return false;
-    }
-    return true;
-  };
+  auto isIdentityReassoc =
+      [](const SmallVector<ReassociationIndices> &indices) {
+        for (auto &index : indices) {
+          if (index.size() != 1)
+            return false;
+        }
+        return true;
+      };
+
+  OpOperand *update = scatterOp.getDpsInputOperand(0);
+  OpOperand *indices = scatterOp.getDpsInputOperand(1);
+  OpOperand *original = scatterOp.getDpsInitOperand(0);
+  RankedTensorType newIndicesType =
+      getType(info.getShapeMap(indices), scatterOp.getIndicesType());
+  RankedTensorType newOriginalType =
+      getType(info.getShapeMap(original), scatterOp.getOriginalType());
+  SmallVector<ReassociationIndices> indicesReassoc = info.getReassoc(indices);
+  SmallVector<ReassociationIndices> originalReassoc = info.getReassoc(original);
 
   Value newUpdates = rewriter.create<tensor::ExpandShapeOp>(
-      loc, getType(info.updatesShapeMap, scatterOp.getUpdateType()),
-      scatterOp.getUpdates(), info.updatesReassoc);
+      loc, getType(info.getShapeMap(update), scatterOp.getUpdateType()),
+      scatterOp.getUpdates(), info.getReassoc(update));
   Value newIndices =
-      isIdentityReassoc(info.indicesReassoc)
+      isIdentityReassoc(indicesReassoc)
          ? scatterOp.getIndices()
-          : rewriter.create<tensor::ExpandShapeOp>(
-                loc, getType(info.indicesShapeMap, scatterOp.getIndicesType()),
-                scatterOp.getIndices(), info.indicesReassoc);
+          : rewriter.create<tensor::ExpandShapeOp>(
+                loc, newIndicesType, scatterOp.getIndices(), indicesReassoc);
   Value newOriginal =
-      isIdentityReassoc(info.originalReassoc)
+      isIdentityReassoc(originalReassoc)
          ? scatterOp.getOriginal()
-          : rewriter.create<tensor::ExpandShapeOp>(
-                loc,
-                getType(info.originalShapeMap, scatterOp.getOriginalType()),
-                scatterOp.getOriginal(), info.originalReassoc);
+          : rewriter.create<tensor::ExpandShapeOp>(
+                loc, newOriginalType, scatterOp.getOriginal(), originalReassoc);
 
   auto newScatter = rewriter.create<ScatterOp>(
       loc, newOriginal.getType(), ValueRange{newUpdates, newIndices},
@@ -571,10 +591,10 @@ fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
   rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(),
                               newScatter.getRegion().begin());
 
-  // Collapse back to originanl shape.
+  // Collapse back to original shape.
   auto newCollapse = rewriter.create<tensor::CollapseShapeOp>(
       loc, scatterOp.getOriginalType(), newScatter.getResult(0),
-      info.originalReassoc);
+      info.getReassoc(original));
 
   return {newCollapse};
 }

From 6b8773c87bc3cae4eb1604d61936ac547d61dbc2 Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Wed, 8 Jan 2025 03:02:04 -0800
Subject: [PATCH 3/4] Further refactoring and cleanup

Signed-off-by: Ian Wood
---
 .../LinalgExt/Transforms/ReshapeFusion.cpp    | 94 +++++++------------
 1 file changed, 36 insertions(+), 58 deletions(-)

diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index 7224aa58c135..0a04fa10fcbb 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -12,8 +12,6 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/Sequence.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
@@ -23,17 +21,14 @@
 
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
-// Use the innermost indices in `reassoc` to construct a shape map out of
-// `shape`.
-static SmallVector<SmallVector<int64_t>>
-computeShapeMapFromReassoc(ArrayRef<ReassociationIndices> reassoc,
-                           ArrayRef<int64_t> shape) {
-  SmallVector<SmallVector<int64_t>> shapeMap;
-  for (auto &indices : reassoc) {
-    shapeMap.emplace_back(shape.slice(indices.front(), indices.size()));
+static bool
+isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
+  for (auto &index : indices) {
+    if (index.size() != 1)
+      return false;
   }
-  return shapeMap;
-}
+  return true;
+};
 
 static SmallVector<ReassociationIndices>
 computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
@@ -52,7 +47,6 @@ namespace {
 /// Helper class that supports fusing reshapes with operands when not all of
 /// the shape dims map to the iteration space.
 struct ReshapeOperandInfo {
-public:
   static constexpr int64_t kNoMapping = -1;
 
   // Original shape of this operand.
@@ -78,6 +72,25 @@ class ExpansionInfo {
                         ArrayRef<ReassociationIndices> operandReassoc,
                         ArrayRef<int64_t> expandedShape);
 
+  Value getOrCreateExpanded(Location loc, OpOperand *operand,
+                            RewriterBase &rewriter) {
+    auto shapeMap = this->getShapeMap(operand);
+    auto reassoc = computeReassocFromShapeMap(shapeMap);
+    if (isIdentityReassoc(reassoc)) {
+      return operand->get();
+    }
+    SmallVector<int64_t> flattenedArray;
+    for (auto &shape : shapeMap) {
+      flattenedArray.append(shape.begin(), shape.end());
+    }
+    auto newType = RankedTensorType::get(
+        flattenedArray,
+        cast<ShapedType>(operand->get().getType()).getElementType());
+    return rewriter.create<tensor::ExpandShapeOp>(loc, newType, operand->get(),
+                                                  reassoc);
+  };
+
+  /// Get the shape map for the operand.
   SmallVector<SmallVector<int64_t>> getShapeMap(OpOperand *operand) const {
     auto info = reshapeInfos[operand->getOperandNumber()];
     SmallVector<SmallVector<int64_t>> shapeMap;
@@ -102,10 +115,6 @@ class ExpansionInfo {
     return loopShapeMap[i];
   }
 
-  SmallVector<ReassociationIndices> getReassoc(OpOperand *operand) const {
-    return computeReassocFromShapeMap(getShapeMap(operand));
-  }
-
 private:
   /// Extent of the iteration space in the original operation.
   SmallVector<int64_t> loopRanges;
@@ -537,52 +546,15 @@ fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
     return std::nullopt;
   }
 
-  // Helper to compute the expanded type for `type` by flattening the given
-  // shape map.
-  auto getType = [&](const SmallVector<SmallVector<int64_t>> &shapeMap,
-                     ShapedType type) {
-    SmallVector<int64_t> flattenedArray;
-    for (auto &shape : shapeMap) {
-      flattenedArray.append(shape.begin(), shape.end());
-    }
-    return RankedTensorType::get(flattenedArray, type.getElementType());
-  };
-
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(scatterOp);
 
-  auto isIdentityReassoc =
-      [](const SmallVector<ReassociationIndices> &indices) {
-        for (auto &index : indices) {
-          if (index.size() != 1)
-            return false;
-        }
-        return true;
-      };
-
   OpOperand *update = scatterOp.getDpsInputOperand(0);
   OpOperand *indices = scatterOp.getDpsInputOperand(1);
   OpOperand *original = scatterOp.getDpsInitOperand(0);
-  RankedTensorType newIndicesType =
-      getType(info.getShapeMap(indices), scatterOp.getIndicesType());
-  RankedTensorType newOriginalType =
-      getType(info.getShapeMap(original), scatterOp.getOriginalType());
-  SmallVector<ReassociationIndices> indicesReassoc = info.getReassoc(indices);
-  SmallVector<ReassociationIndices> originalReassoc = info.getReassoc(original);
-
-  Value newUpdates = rewriter.create<tensor::ExpandShapeOp>(
-      loc, getType(info.getShapeMap(update), scatterOp.getUpdateType()),
-      scatterOp.getUpdates(), info.getReassoc(update));
-  Value newIndices =
-      isIdentityReassoc(indicesReassoc)
-          ? scatterOp.getIndices()
-          : rewriter.create<tensor::ExpandShapeOp>(
-                loc, newIndicesType, scatterOp.getIndices(), indicesReassoc);
-  Value newOriginal =
-      isIdentityReassoc(originalReassoc)
-          ? scatterOp.getOriginal()
-          : rewriter.create<tensor::ExpandShapeOp>(
-                loc, newOriginalType, scatterOp.getOriginal(), originalReassoc);
+  Value newUpdates = info.getOrCreateExpanded(loc, update, rewriter);
+  Value newIndices = info.getOrCreateExpanded(loc, indices, rewriter);
+  Value newOriginal = info.getOrCreateExpanded(loc, original, rewriter);
 
   auto newScatter = rewriter.create<ScatterOp>(
       loc, newOriginal.getType(), ValueRange{newUpdates, newIndices},
@@ -592,9 +564,15 @@ fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
   rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(),
                               newScatter.getRegion().begin());
 
   // Collapse back to original shape.
+  auto originalShapeMap = info.getShapeMap(original);
+  SmallVector<ReassociationIndices> originalReassoc =
+      computeReassocFromShapeMap(originalShapeMap);
+  if (isIdentityReassoc(originalReassoc)) {
+    return {newScatter.getResult(0)};
+  }
   auto newCollapse = rewriter.create<tensor::CollapseShapeOp>(
       loc, scatterOp.getOriginalType(), newScatter.getResult(0),
-      info.getReassoc(original));
+      originalReassoc);
 
   return {newCollapse};
 }

From b1611ae0d39c79060e8951c4dbe49af94445f695 Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Wed, 8 Jan 2025 05:32:35 -0800
Subject: [PATCH 4/4] Move ExpandShape logic to getOrCreateExpanded

Also, clean up some duplicated logic

Signed-off-by: Ian Wood
---
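Note (editorial, not part of the commit): `getOrCreateExpanded` now reports
failure instead of silently building an invalid reshape, e.g. (sketch):

  // std::optional<Value> expanded = info.getOrCreateExpanded(loc, operand, rewriter);
  // if (!expanded) return std::nullopt;  // incompatible reshape shapes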
 .../LinalgExt/Transforms/ReshapeFusion.cpp    | 119 ++++++------------
 1 file changed, 36 insertions(+), 83 deletions(-)

diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index 0a04fa10fcbb..6ad28e5b6d4d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -72,8 +72,8 @@ class ExpansionInfo {
                         ArrayRef<ReassociationIndices> operandReassoc,
                         ArrayRef<int64_t> expandedShape);
 
-  Value getOrCreateExpanded(Location loc, OpOperand *operand,
-                            RewriterBase &rewriter) {
+  std::optional<Value> getOrCreateExpanded(Location loc, OpOperand *operand,
+                                           RewriterBase &rewriter) {
     auto shapeMap = this->getShapeMap(operand);
     auto reassoc = computeReassocFromShapeMap(shapeMap);
     if (isIdentityReassoc(reassoc)) {
@@ -83,9 +83,17 @@ class ExpansionInfo {
     for (auto &shape : shapeMap) {
       flattenedArray.append(shape.begin(), shape.end());
     }
-    auto newType = RankedTensorType::get(
-        flattenedArray,
-        cast<ShapedType>(operand->get().getType()).getElementType());
+    auto oldType = cast<ShapedType>(operand->get().getType());
+    auto newType =
+        RankedTensorType::get(flattenedArray, oldType.getElementType());
+    if (failed(reshapeLikeShapesAreCompatible(
+            [&](const Twine &msg) {
+              return rewriter.notifyMatchFailure(loc, msg);
+            },
+            oldType.getShape(), newType.getShape(), reassoc,
+            /*isExpandingReshape=*/true))) {
+      return {};
+    }
     return rewriter.create<tensor::ExpandShapeOp>(loc, newType, operand->get(),
                                                   reassoc);
   };
@@ -106,6 +114,11 @@ class ExpansionInfo {
     return shapeMap;
   }
 
+  SmallVector<ReassociationIndices> getReassoc(OpOperand *operand) const {
+    auto shapeMap = this->getShapeMap(operand);
+    return computeReassocFromShapeMap(shapeMap);
+  }
+
   unsigned getOrigNumLoops() const { return loopReassoc.size(); }
   unsigned getExpandedNumLoops() const { return expandedOpNumDims; }
   ReassociationIndicesRef getExpandedLoops(unsigned i) const {
     return loopReassoc[i];
   }
@@ -354,34 +367,6 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                         builder.getContext());
 }
 
-static RankedTensorType getExpandedType(RankedTensorType originalType,
-                                        AffineMap indexingMap,
-                                        const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedShape;
-  for (AffineExpr expr : indexingMap.getResults()) {
-    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto dimExpansion = expansionInfo.getExpandedShapeOfLoop(dim);
-    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
-  }
-  return RankedTensorType::get(expandedShape, originalType.getElementType());
-}
-
-static SmallVector<ReassociationIndices>
-getReassociationForExpansion(AffineMap indexingMap,
-                             const ExpansionInfo &expansionInfo) {
-  SmallVector<ReassociationIndices> reassociation;
-  unsigned numReshapeDims = 0;
-  for (AffineExpr expr : indexingMap.getResults()) {
-    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto numExpandedDims = expansionInfo.getExpandedLoops(dim).size();
-    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
-        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
-    reassociation.emplace_back(std::move(indices));
-    numReshapeDims += numExpandedDims;
-  }
-  return reassociation;
-}
-
 static bool isFusableWithReshapeByDimExpansion(AttentionOp op,
                                                OpOperand *fusableOpOperand) {
   // Is fusable only if:
@@ -438,54 +423,23 @@ static std::optional<SmallVector<Value>> fuseAttentionWithReshapeByExpansion(
                        : collapsingReshapeOp.getSrc());
       continue;
     }
-    if (auto opOperandType =
-            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
-      AffineMap indexingMap = attentionOp.getMatchingIndexingMap(opOperand);
-      RankedTensorType expandedOperandType =
-          getExpandedType(opOperandType, indexingMap, expansionInfo);
-      if (expandedOperandType != opOperand->get().getType()) {
-        // Reshape the operand to get the right type.
-        SmallVector<ReassociationIndices> reassociation =
-            getReassociationForExpansion(indexingMap, expansionInfo);
-        if (failed(reshapeLikeShapesAreCompatible(
-                [&](const Twine &msg) {
-                  return rewriter.notifyMatchFailure(attentionOp, msg);
-                },
-                opOperandType.getShape(), expandedOperandType.getShape(),
-                reassociation,
-                /*isExpandingReshape=*/true)))
-          return std::nullopt;
-        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            loc, expandedOperandType, opOperand->get(), reassociation));
-        continue;
-      }
-    }
-    expandedOpOperands.push_back(opOperand->get());
+    // Reshape the operand to get the right type.
+    std::optional<Value> expanded =
+        expansionInfo.getOrCreateExpanded(loc, opOperand, rewriter);
+    if (!expanded)
+      return std::nullopt;
+    expandedOpOperands.push_back(*expanded);
+    continue;
   }
 
   Value output;
   OpOperand &outOperand = attentionOp.getOutputMutable();
-  AffineMap indexingMap = attentionOp.getMatchingIndexingMap(&outOperand);
-  auto opOperandType = cast<RankedTensorType>(outOperand.get().getType());
-  RankedTensorType expandedOutputType =
-      getExpandedType(opOperandType, indexingMap, expansionInfo);
-  if (expandedOutputType != outOperand.get().getType()) {
-    SmallVector<ReassociationIndices> reassociation =
-        getReassociationForExpansion(indexingMap, expansionInfo);
-    if (failed(reshapeLikeShapesAreCompatible(
-            [&](const Twine &msg) {
-              return rewriter.notifyMatchFailure(attentionOp, msg);
-            },
-            opOperandType.getShape(), expandedOutputType.getShape(),
-            reassociation,
-            /*isExpandingReshape=*/true)))
-      return std::nullopt;
-    output = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandedOutputType, outOperand.get(), reassociation);
-  } else {
-    output = outOperand.get();
-  }
+  std::optional<Value> maybeOutput =
+      expansionInfo.getOrCreateExpanded(loc, &outOperand, rewriter);
+  if (!maybeOutput)
+    return std::nullopt;
+  output = *maybeOutput;
 
   Value maskOperand;
   if (expandedOpOperands.size() > 4) {
@@ -510,9 +464,7 @@ static std::optional<SmallVector<Value>> fuseAttentionWithReshapeByExpansion(
     int64_t resultNumber = opResult.getResultNumber();
     if (resultTypes[resultNumber] != opResult.getType()) {
       SmallVector<ReassociationIndices> reassociation =
-          getReassociationForExpansion(
-              attentionOp.getIndexingMapsForResults()[resultNumber],
-              expansionInfo);
+          expansionInfo.getReassoc(attentionOp.getTiedOpOperand(opResult));
       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          attentionOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
@@ -552,9 +504,9 @@ fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
   OpOperand *update = scatterOp.getDpsInputOperand(0);
   OpOperand *indices = scatterOp.getDpsInputOperand(1);
   OpOperand *original = scatterOp.getDpsInitOperand(0);
-  Value newUpdates = info.getOrCreateExpanded(loc, update, rewriter);
-  Value newIndices = info.getOrCreateExpanded(loc, indices, rewriter);
-  Value newOriginal = info.getOrCreateExpanded(loc, original, rewriter);
rewriter); - Value newOriginal = info.getOrCreateExpanded(loc, original, rewriter); + auto newUpdates = info.getOrCreateExpanded(loc, update, rewriter).value(); + auto newIndices = info.getOrCreateExpanded(loc, indices, rewriter).value(); + auto newOriginal = info.getOrCreateExpanded(loc, original, rewriter).value(); auto newScatter = rewriter.create( loc, newOriginal.getType(), ValueRange{newUpdates, newIndices}, @@ -563,10 +515,11 @@ fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp, rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(), newScatter.getRegion().begin()); - // Collapse back to original shape. auto originalShapeMap = info.getShapeMap(original); SmallVector originalReassoc = computeReassocFromShapeMap(originalShapeMap); + + // Collapse back to original shape. if (isIdentityReassoc(originalReassoc)) { return {newScatter.getResult(0)}; }