diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 9f5f9f3fca97a..d2cddfe00ac78 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -85,6 +85,36 @@ struct SCFTilingOptions { return *this; } + /// Specify how reduction dimensions should be tiled. + /// + /// Tiling can be thought of as splitting a dimension into 2 and materializing + /// the outer dimension as a loop: + /// + /// op[original] -> op[original / x, x] -> loop[original] { op[x] } + /// + /// For parallel dimensions, the split can only happen in one way, with both + /// dimensions being parallel. For reduction dimensions however, there is a + /// choice in how we split the reduction dimension. This enum exposes this + /// choice. + enum class ReductionTilingStrategy { + // [reduction] -> [reduction1, reduction2] + // -> loop[reduction1] { [reduction2] } + FullReduction, + // [reduction] -> [reduction1, parallel2] + // -> loop[reduction1] { [parallel2] }; merge[reduction1] + PartialReductionOuterReduction, + // [reduction] -> [parallel1, reduction2] + // -> loop[parallel1] { [reduction2] }; merge[parallel1] + PartialReductionOuterParallel + }; + ReductionTilingStrategy reductionStrategy = + ReductionTilingStrategy::FullReduction; + SCFTilingOptions & + setReductionTilingStrategy(ReductionTilingStrategy strategy) { + reductionStrategy = strategy; + return *this; + } + /// Specify mapping of loops to devices. This is only respected when the loop /// constructs support such a mapping (like `scf.forall`). Will be ignored /// when using loop constructs that dont support such a mapping (like @@ -102,11 +132,16 @@ struct SCFTilingResult { /// matter except the last op. The replacements are expected to be the results /// of the last op. SmallVector tiledOps; + /// The initial destination values passed to the tiled operations. 
+ SmallVector initialValues; /// The `scf.for` operations that iterate over the tiles. SmallVector loops; - /// Values to use as replacements for the untiled op. Is the same size as the - /// number of results of the untiled op. - SmallVector replacements; + /// The result generated by the loop nest in tiling, may hold partial results, + /// which need to be merged to match the computation of the untiled operation. + /// `mergeResult` contains the operations used to perform this merge from + /// partial results and the values that can be used as replacements of + /// the untiled operation. + MergeResult mergeResult; /// Slices generated after tiling that can be used for fusing with the tiled /// producer. SmallVector generatedSlices; @@ -300,20 +335,6 @@ tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp); FailureOr> lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op); -/// Transformation information returned after reduction tiling. -struct SCFReductionTilingResult { - /// The partial reduction tiled op generated. - SmallVector parallelTiledOps; - /// The final reduction operation merging all the partial reductions. - SmallVector mergeOps; - /// Initial values used for reduction. - SmallVector initialValues; - /// The loop operations that iterate over the tiles. - SmallVector loops; - /// The replacements to use for the results of the tiled operation. - SmallVector replacements; -}; - /// Method to tile a reduction and generate a parallel op within a serial loop. /// Each of the partial reductions are calculated in parallel. Then after the /// loop all the partial reduction are merged into a final reduction. 
@@ -338,7 +359,7 @@ struct SCFReductionTilingResult { /// %6 = linalg.generic %1 ["parallel", "reduction"] /// : tensor<7x4xf32> -> tensor<7xf32> /// ``` -FailureOr +FailureOr tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef tileSize); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 8839faf4cafb2..66a3947e0f91f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2224,7 +2224,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, return emitDefaultDefiniteFailure(target); if (target->getNumResults()) - rewriter.replaceOp(target, maybeTilingResult->replacements); + rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements); else rewriter.eraseOp(target); @@ -2631,17 +2631,18 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne( transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); - FailureOr result = scf::tileReductionUsingScf( + FailureOr result = scf::tileReductionUsingScf( rewriter, cast(target.getOperation()), getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()))); if (failed(result)) return emitDefaultSilenceableFailure(target); + rewriter.replaceOp(target, result->mergeResult.replacements); for (Value initValue : result->initialValues) results.push_back(initValue.getDefiningOp()); - for (auto parallelTiledOp : result->parallelTiledOps) + for (auto parallelTiledOp : result->tiledOps) results.push_back(parallelTiledOp); - for (auto mergeOp : result->mergeOps) + for (auto mergeOp : result->mergeResult.mergeOps) results.push_back(mergeOp); results.push_back(result->loops.front()); return DiagnosedSilenceableFailure::success(); @@ -3065,7 +3066,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter, if 
(failed(maybeTilingResult)) return DiagnosedSilenceableFailure::definiteFailure(); - rewriter.replaceOp(op, maybeTilingResult->replacements); + rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements); tiled.append(maybeTilingResult->tiledOps); for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops)) @@ -3304,7 +3305,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl( if (failed(maybeTilingResult)) return transformOp.emitDefaultSilenceableFailure(tileableOp); - rewriter.replaceOp(tileableOp, maybeTilingResult->replacements); + rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements); tilingResult = *maybeTilingResult; diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 6a4a6b4393380..ef5d4370e7810 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -570,6 +570,144 @@ static LogicalResult generateLoopNest( return rewriter.notifyMatchFailure(loc, "unhandled loop type"); } +static FailureOr> +createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, + ArrayRef tileSizes, + const scf::SCFTilingOptions &options) { + SmallVector initTensors; + Location loc = op->getLoc(); + switch (options.reductionStrategy) { + case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: + if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors))) + return failure(); + return initTensors; + case scf::SCFTilingOptions::ReductionTilingStrategy:: + PartialReductionOuterReduction: { + auto redOp = dyn_cast(op.getOperation()); + if (!redOp) { + return rewriter.notifyMatchFailure( + op, "PartialReductionOuterReduction tiling strategy is only supported" + " for operations implementing PartialReductionOpInterface"); + } + // Get reduction dimensions. 
+ // TODO: PartialReductionOpInterface should really query TilingInterface + // itself and find reduction dimensions. + SmallVector reductionDims; + for (auto [idx, iteratorType] : + llvm::enumerate(op.getLoopIteratorTypes())) { + if (iteratorType == utils::IteratorType::reduction) + reductionDims.push_back(idx); + } + return redOp.generateInitialTensorForPartialReduction( + rewriter, loc, tileSizes, reductionDims); + } + default: + return rewriter.notifyMatchFailure(op, + "unhandled reduction tiling strategy"); + } +} + +static FailureOr +getTiledImplementation(RewriterBase &rewriter, TilingInterface op, + ValueRange regionIterArg, ArrayRef offsets, + ArrayRef sizes, + const scf::SCFTilingOptions &options) { + switch (options.reductionStrategy) { + case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: + return op.getTiledImplementation(rewriter, offsets, sizes); + case scf::SCFTilingOptions::ReductionTilingStrategy:: + PartialReductionOuterReduction: { + auto redOp = dyn_cast(op.getOperation()); + if (!redOp) { + return rewriter.notifyMatchFailure( + op, "PartialReductionOuterReduction tiling strategy is only " + "supported for operations " + "implementing PartialReductionOpInterface"); + } + // Get reduction dimensions. + // TODO: PartialReductionOpInterface should really query TilingInterface + // itself and find reduction dimensions. 
+ SmallVector reductionDims; + for (auto [idx, iteratorType] : + llvm::enumerate(op.getLoopIteratorTypes())) { + if (iteratorType == utils::IteratorType::reduction) + reductionDims.push_back(idx); + } + return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg, + offsets, sizes, reductionDims); + } + default: + return rewriter.notifyMatchFailure(op, + "unhandled reduction tiling strategy"); + } +} + +static LogicalResult +getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, + TilingInterface op, ArrayRef offsets, + ArrayRef sizes, + SmallVector &resultOffset, + SmallVector &resultSize, + const scf::SCFTilingOptions &options) { + + switch (options.reductionStrategy) { + case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: + return op.getResultTilePosition(rewriter, index, offsets, sizes, + resultOffset, resultSize); + case scf::SCFTilingOptions::ReductionTilingStrategy:: + PartialReductionOuterReduction: { + // TODO: This does not work for non-identity accesses to the result tile. + // The proper fix is to add a getPartialResultTilePosition method to + // PartialReductionOpInterface. + resultOffset = + SmallVector(offsets.size(), rewriter.getIndexAttr(0)); + for (size_t i = 0; i < offsets.size(); i++) { + resultSize.push_back( + tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i)); + } + return success(); + } + default: + return rewriter.notifyMatchFailure(op, + "unhandled reduction tiling strategy"); + } +} + +static FailureOr +mergeTilingResults(RewriterBase &rewriter, TilingInterface op, + ValueRange partialResults, + const scf::SCFTilingOptions &options) { + switch (options.reductionStrategy) { + case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: + // No need to merge results for the full reduction tiling strategy. 
+ return MergeResult{{}, partialResults}; + case scf::SCFTilingOptions::ReductionTilingStrategy:: + PartialReductionOuterReduction: { + auto redOp = dyn_cast(op.getOperation()); + if (!redOp) { + return rewriter.notifyMatchFailure( + op, "PartialReductionOuterReduction tiling strategy is only " + "supported for operations " + "implementing PartialReductionOpInterface"); + } + // Get reduction dimensions. + // TODO: PartialReductionOpInterface should really query TilingInterface + // itself and find reduction dimensions. + SmallVector reductionDims; + for (auto [idx, iteratorType] : + llvm::enumerate(op.getLoopIteratorTypes())) { + if (iteratorType == utils::IteratorType::reduction) + reductionDims.push_back(idx); + } + return redOp.mergeReductions(rewriter, op.getLoc(), partialResults, + reductionDims); + } + default: + return rewriter.notifyMatchFailure(op, + "unhandled reduction tiling strategy"); + } +} + /// Append the specified additional `newInitOperands` operands to the /// loops existing `init` operands (or similar), and replace `loopOp` with /// the new loop that has the additional init operands. The loop body of @@ -710,11 +848,11 @@ FailureOr yieldTiledValuesAndReplaceLoop( }); } -/// Method to add new init values to a loop nest. Updates `loops` in-place with -/// new loops that use the `newInitValues`. -/// The outer-loops are updated to yield the new result values of the inner -/// loop. For the innermost loop, the call back `getNewYields` is invoked to get -/// the additional values to yield form the innermost loop. +/// Method to add new init values to a loop nest. Updates `loops` in-place +/// with new loops that use the `newInitValues`. The outer-loops are updated +/// to yield the new result values of the inner loop. For the innermost loop, +/// the callback `getNewTiledYieldsFn` is invoked to get the additional values +/// to yield from the innermost loop. 
static LogicalResult addInitOperandsToLoopNest( RewriterBase &rewriter, MutableArrayRef loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { @@ -852,9 +990,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, auto clonedOp = cast( cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs)); - // 5b. Early return cloned op if tiling is not happening. We can not return - // the original op because it could lead to - // `rewriter.replaceOp(op, op->getResults())` and users would get crash. + // 5b. Early return cloned op if tiling is not happening. We can not + // return the original op because it could lead to `rewriter.replaceOp(op, + // op->getResults())` and users would get crash. if (llvm::all_of(tileSizes, isZeroIndex)) { tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); tilingResult = @@ -864,7 +1002,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, } // 5c. Tile the cloned operation. - tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes); + tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs, + offsets, sizes, options); if (failed(tilingResult)) { rewriter.eraseOp(clonedOp); return op.emitOpError("faild to tile operation"); @@ -879,8 +1018,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, llvm::enumerate(tilingResult->tiledValues)) { tiledResults.push_back(tiledValue); SmallVector resultOffset, resultSize; - if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes, - resultOffset, resultSize))) { + if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets, + sizes, resultOffset, resultSize, + options))) { for (auto op : tilingResult->tiledOps) { rewriter.eraseOp(op); } @@ -895,158 +1035,65 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, }; // 6. Find the destination tensors to use for the operation. 
- SmallVector destinationTensors; - if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, - destinationTensors))) { - return rewriter.notifyMatchFailure(op, - "unable to create destination tensors"); + FailureOr> maybeInits = + createInitialTensorsForTiling(rewriter, op, tileSizes, options); + if (failed(maybeInits)) { + return rewriter.notifyMatchFailure( + op, "unable to create initial tensors for tiling"); } + SmallVector &initTensors = maybeInits.value(); // 7. Generate the tiled loops nest using the callback defined above. SmallVector loops; if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, - tileSizes, numThreads, destinationTensors, + tileSizes, numThreads, initTensors, innerYieldTiledValuesFn, loops))) return op.emitOpError("failed to generate tiling loops"); assert(succeeded(tilingResult) && "expected tiling result to be computed after loop generation"); - // If loops are empty, the tiled op is used as the replacement for the untiled - // op. + SmallVector partialResults; if (loops.empty()) { - return scf::SCFTilingResult{tilingResult->tiledOps, loops, - tilingResult->tiledValues, - tilingResult->generatedSlices}; + // If loops are empty, the tiled op is used as the replacement for the + // untiled op. 
+ partialResults = tilingResult->tiledValues; + } else { + partialResults = llvm::map_to_vector(loops.front()->getResults(), + [](OpResult r) -> Value { return r; }); } - SmallVector replacements = llvm::map_to_vector( - loops.front()->getResults(), [](OpResult r) -> Value { return r; }); - return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements, + FailureOr mergeResult = + mergeTilingResults(rewriter, op, partialResults, options); + if (failed(mergeResult)) { + return rewriter.notifyMatchFailure( + op, "Failed to merge partial results from tiling"); + } + + return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops, + mergeResult.value(), tilingResult->generatedSlices}; } -FailureOr +FailureOr mlir::scf::tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef tileSizes) { - Location loc = op.getLoc(); - // Ops implementing PartialReductionOpInterface are expected to implement - // TilingInterface. - auto tilingInterfaceOp = cast(op.getOperation()); - SmallVector iterationDomain = tilingInterfaceOp.getIterationDomain(b); - auto tileSizesVector = llvm::to_vector(tileSizes); - if (tileSizesVector.size() < iterationDomain.size()) { - auto zero = b.getIndexAttr(0); - tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(), - zero); - } - SmallVector iterators = - tilingInterfaceOp.getLoopIteratorTypes(); - - SmallVector reductionDims; - for (auto [idx, iteratorType] : - llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { - if (iteratorType == utils::IteratorType::reduction) - reductionDims.push_back(idx); - } - - // 2. create the inital tensor value. - FailureOr> maybeInitTensors = - op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector, - reductionDims); - if (failed(maybeInitTensors)) { - return b.notifyMatchFailure(op, "Failed to create initial tensors."); - } - SmallVector &initTensors = maybeInitTensors.value(); - - // 3. 
Define the callback to use for generating the inner most tile loop body. - SmallVector parallelTiledOps; - auto innerYieldTiledValuesFn = - [&](RewriterBase &rewriter, Location loc, ValueRange ivs, - ValueRange regionIterArgs, SmallVector &tiledResult, - SmallVector> &resultOffsets, - SmallVector> &resultSizes) - -> LogicalResult { - SmallVector offsets, sizes; - { - int materializedLoopNum = 0; - for (auto [tileSize, loopRange] : - llvm::zip_equal(tileSizesVector, iterationDomain)) { - if (isConstantIntValue(tileSize, 0)) { - offsets.push_back(loopRange.offset); - sizes.push_back(loopRange.size); - continue; - } - Value iv = ivs[materializedLoopNum++]; - offsets.push_back(iv); - sizes.push_back( - getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize)); - } - } - - // 4a. Clone the operation. - { - auto clonedOp = cast( - cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); - - // 4b. Tile the cloned operation. - FailureOr partialTilingResult = - clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets, - sizes, reductionDims); - if (failed(partialTilingResult)) { - return failure(); - } - std::swap(parallelTiledOps, partialTilingResult->tiledOps); - std::swap(tiledResult, partialTilingResult->tiledValues); - - // 4c. Delete the cloned operation. - b.eraseOp(clonedOp); - } - - // 4d. Compute the offsets and sizes needed to insert the result of the - // tiled value back into destination before yielding the destination. - for (auto result : tiledResult) { - SmallVector outOffsets(offsets.size(), b.getIndexAttr(0)); - resultOffsets.emplace_back(std::move(outOffsets)); - - SmallVector outSizes; - for (size_t i = 0; i < offsets.size(); i++) { - outSizes.push_back(tensor::getMixedSize(b, loc, result, i)); - } - resultSizes.emplace_back(std::move(outSizes)); - } - return success(); - }; - - // 5. Generate the tiled implementation using the destination tensors. 
- SmallVector loops; - scf::SCFTilingOptions options; - options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); - if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector, - /*numThreads=*/ArrayRef{}, - initTensors, innerYieldTiledValuesFn, loops))) - return b.notifyMatchFailure(op, "failed to tile for parallel reduction"); - - SmallVector replacements = llvm::map_to_vector( - loops.front()->getResults(), [](OpResult r) -> Value { return r; }); - - // 5. Apply the merge reduction to combine all the partial values. - b.setInsertionPointAfter(*loops.begin()); - FailureOr mergeResult = - op.mergeReductions(b, loc, replacements, reductionDims); - if (failed(mergeResult)) { - return failure(); - } - b.replaceOp(op, mergeResult->replacements); - - SCFReductionTilingResult reductionTilingResult; - std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps); - std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps); - std::swap(reductionTilingResult.initialValues, initTensors); - std::swap(reductionTilingResult.loops, loops); - std::swap(reductionTilingResult.replacements, mergeResult->replacements); - - return reductionTilingResult; + SCFTilingOptions options; + options.setLoopType(SCFTilingOptions::LoopType::ForOp); + options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy:: + PartialReductionOuterReduction); + options.setTileSizes(tileSizes); + + TilingInterface tilingInterfaceOp = + dyn_cast(op.getOperation()); + if (!tilingInterfaceOp) { + return b.notifyMatchFailure( + op, + "Operation implementing PartialReductionOpInterface should implement " + "TilingInterface"); + } + + return tileUsingSCF(b, tilingInterfaceOp, options); } //===----------------------------------------------------------------------===// @@ -1055,9 +1102,10 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, /// Return the untiled producer whose slice is used in a tiled consumer. 
The /// method traverses the tile loop nest (`loops`) if needed, and returns the -/// `iter_args` of the outer most that is encountered. Traversing the iter_args -/// indicates that this is a destination operand of the consumer. If there was -/// no loop traversal needed, the second value of the returned tuple is empty. +/// `iter_args` of the outer most that is encountered. Traversing the +/// iter_args indicates that this is a destination operand of the consumer. If +/// there was no loop traversal needed, the second value of the returned tuple +/// is empty. static std::tuple> getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef loops) { @@ -1115,8 +1163,8 @@ mlir::scf::tileAndFuseProducerOfSlice( Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( rewriter, fusableProducerOp, clonedOpDestinationTensors); // 2d. Update the source of the candidateSlice to be the cloned producer. - // Easier to just clone the slice with different source since replacements - // and DCE of cloned ops becomes easier + // Easier to just clone the slice with different source since + // replacements and DCE of cloned ops becomes easier SmallVector candidateSliceOpOperands = llvm::to_vector(candidateSliceOp->getOperands()); candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber); @@ -1250,13 +1298,13 @@ FailureOr> mlir::scf::yieldReplacementForFusedProducer( failed(tilableOp.getIterationDomainTileFromResultTile( rewriter, sliceResultNumber, sliceOffset, sliceSizes, iterDomainOffset, iterDomainSizes))) { - // In theory, it is unnecessary to raise an error here. Actually although - // it fails to reconstruct the result tensor, it should not broke current - // fusion anyway. The reason why we must return failure currently is that - // the callback function `newYieldValuesFn` will be called after new init - // operand(s) has already been appended. It will take more refactoring to - // make sure the init operands are added consistently in the future. 
For - // more details, please refer to: + // In theory, it is unnecessary to raise an error here. Actually + // although it fails to reconstruct the result tensor, it should not + // broke current fusion anyway. The reason why we must return failure + // currently is that the callback function `newYieldValuesFn` will be + // called after new init operand(s) has already been appended. It will + // take more refactoring to make sure the init operands are added + // consistently in the future. For more details, please refer to: // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 return failure(); } @@ -1282,7 +1330,8 @@ FailureOr> mlir::scf::yieldReplacementForFusedProducer( } } - // d. create `extract_slice` for `iter_args` for DPS operation if necessary + // d. create `extract_slice` for `iter_args` for DPS operation if + // necessary if (auto tiledDestStyleOp = dyn_cast(tiledOwner)) { rewriter.setInsertionPoint(tiledDestStyleOp); @@ -1334,9 +1383,10 @@ class SliceTrackingListener : public RewriterBase::Listener { std::optional patterns); SliceTrackingListener() = default; - /// Adds the given list of operations to the worklist, and if present, applies - /// the list of `patterns` to the newly added operations. This only processes - /// the given operations and any newly inserted ones by the pattern set. + /// Adds the given list of operations to the worklist, and if present, + /// applies the list of `patterns` to the newly added operations. This only + /// processes the given operations and any newly inserted ones by the + /// pattern set. LogicalResult insertAndApplyPatterns(ArrayRef newOps); /// Add to the new operation worklist if it is an extract_slice. @@ -1357,7 +1407,8 @@ class SliceTrackingListener : public RewriterBase::Listener { std::deque worklist; private: - /// Optional pattern set to apply when adding new operations to the worklist. + /// Optional pattern set to apply when adding new operations to the + /// worklist. 
std::optional patterns = std::nullopt; }; @@ -1390,8 +1441,9 @@ void SliceTrackingListener::notifyOperationInserted( worklist.push_back(slice); } -// Scan the worklist for the given op and remove it if present. The expectation -// is for the worklist to be small and for removal to be relatively rare. +// Scan the worklist for the given op and remove it if present. The +// expectation is for the worklist to be small and for removal to be +// relatively rare. void SliceTrackingListener::removeOp(Operation *op) { if (!isa(op)) return; @@ -1445,17 +1497,18 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( auto &loops = tilingResult->loops; if (loops.empty()) { DenseMap replacements; - for (auto [origVal, replacement] : - llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) { + for (auto [origVal, replacement] : llvm::zip_equal( + consumer->getResults(), tilingResult->mergeResult.replacements)) { replacements[origVal] = replacement; } return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, replacements}; } - // To keep track of replacements for now just record the map from the original - // untiled value to the result number of the for loop. Since the loop gets - // potentially replaced during fusion, keeping the value directly wont work. + // To keep track of replacements for now just record the map from the + // original untiled value to the result number of the for loop. Since the + // loop gets potentially replaced during fusion, keeping the value directly + // wont work. DenseMap origValToResultNumber; for (auto [index, result] : llvm::enumerate(consumer->getResults())) { origValToResultNumber[result] = index; @@ -1463,11 +1516,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( // 2. Typically, the operands of the tiled operation are slices of the // operands of the untiled operation. These are expressed in IR using - // `tensor.extract_slice` operations with source being the operands of the - // untiled operation. 
Create a worklist of these `tensor.extract_slice` - // operations. If the producers of the source of the `tensor.extract_slice` - // can be tiled such that the tiled value is generated in-place, that - // effectively tiles + fuses the operations. + // `tensor.extract_slice` operations with source being the operands of + // the untiled operation. Create a worklist of these + // `tensor.extract_slice` operations. If the producers of the source of + // the `tensor.extract_slice` can be tiled such that the tiled value is + // generated in-place, that effectively tiles + fuses the operations. struct WorklistItem { tensor::ExtractSliceOp candidateSlice; SCFTileAndFuseOptions::ControlFnResult controlFnResult; @@ -1511,9 +1564,10 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( SmallVector worklistCandidates = fusedResult->generatedSlices; if (worklistItem.controlFnResult.yieldProducerReplacement) { - // Reconstruct and yield all opResult of fusableProducerOp by default. The - // caller can specific which one to yield by designating optional argument - // named `yieldResultNumber` of `yieldReplacementForFusedProducer`. + // Reconstruct and yield all opResult of fusableProducerOp by default. + // The caller can specify which one to yield by designating optional + // argument named `yieldResultNumber` of + // `yieldReplacementForFusedProducer`. Operation *fusableProducerOp = fusedResult->origProducer.getOwner(); FailureOr> newSlices = yieldReplacementForFusedProducer(rewriter, @@ -1582,8 +1636,8 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { return success(); } -/// An utility to get the first user of the given loopOp. If any of user stay in -/// different block of loopOp, return failure. +/// A utility to get the first user of the given loopOp. If any user is in a +/// different block than loopOp, return failure. 
static FailureOr getFirstUserOfLoop(Operation *loopOp) { if (!isa(loopOp)) return failure(); @@ -1616,11 +1670,11 @@ static FailureOr getFirstUserOfLoop(Operation *loopOp) { return firstUserOfLoop; } -/// This utility currently checks whether the first userOp of loop is NOT before -/// the last defineOp of consumer operand. Because that we need to move the -/// whole loop structure right before the `firstUserOfLoop`. This utility thus -/// helps ensuring that no invalid IR is formed, i.e. no backward slice of -/// consumerOp is dominated by the `firstUserOfLoop`. Saying that: +/// This utility currently checks whether the first userOp of loop is NOT +/// before the last defineOp of consumer operand. Because that we need to move +/// the whole loop structure right before the `firstUserOfLoop`. This utility +/// thus helps ensuring that no invalid IR is formed, i.e. no backward slice +/// of consumerOp is dominated by the `firstUserOfLoop`. Saying that: /// /// ``` /// %0 = scf.for() { @@ -1634,9 +1688,9 @@ static FailureOr getFirstUserOfLoop(Operation *loopOp) { /// %3 = consumerOp(%2) /// ``` /// -/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would -/// be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a. -/// use-def chain violation: +/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it +/// would be invalid to move the `loopOp` right before the `firstUserOfLoop`, +/// a.k.a. use-def chain violation: /// /// ``` /// %0:2 = scf.for() { @@ -1650,10 +1704,10 @@ static FailureOr getFirstUserOfLoop(Operation *loopOp) { /// /// @param loopOp: loop operation /// @param consumerOp: consumer operation -/// @param reorderOperations: the flag controls whether to reorder the backward -/// slice w.r.t. the defineOp of `consumerOp` operands. -/// @return: computed backward slice of consumerOp, but excluding those already -/// dominates `firstUserOfLoop`. 
+/// @param reorderOperations: the flag controls whether to reorder the +/// backward slice w.r.t. the defineOp of `consumerOp` operands. +/// @return: computed backward slice of consumerOp, but excluding those that +/// already dominate `firstUserOfLoop`. static FailureOr> checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations) { @@ -1713,8 +1767,8 @@ static FailureOr getConsumerFromLoopUses(RewriterBase &rewriter, if (!isa(consumerOp) || !isa(consumerOp)) { // TODO: We have to init result of consumer before scf.for, use - // DestinationStyleOpInterface to get result shape from init for now. Add - // support for other op such as op has InferTypeOpInterface. + // DestinationStyleOpInterface to get result shape from init for now. + // Add support for other op such as op has InferTypeOpInterface. continue; } // Step 2. Check if user stay in the same block. @@ -1729,7 +1783,8 @@ static FailureOr getConsumerFromLoopUses(RewriterBase &rewriter, checkAssumptionForLoop(loopOp, consumerOp, true); if (failed(slice)) continue; - // Step 5. If backward sice is not empty, move them before firstUserOfLoop. + // Step 5. If backward slice is not empty, move them before + // firstUserOfLoop. if (!slice->empty()) { mlir::topologicalSort(*slice); FailureOr firstUserOfLoop = getFirstUserOfLoop(loopOp); @@ -1743,8 +1798,8 @@ static FailureOr getConsumerFromLoopUses(RewriterBase &rewriter, return failure(); } -/// Find the perfectly nested loops outside of given loop(included) sorted from -/// outer to inner. +/// Find the perfectly nested loops outside of given loop(included) sorted +/// from outer to inner. /// /// E.g. /// @@ -1997,10 +2052,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, } // 10. Try to get iter domain position from input position. Use - // clonedConsumerOp instead of tiledConsumerOp, because the iteration domain - // may require index computation based on the result size. 
The sizes and - // offsets should be the same either way, but using tiledConsumerOp could - // lead to some chained unnecessary extra index computation. + // clonedConsumerOp instead of tiledConsumerOp, because the iteration + // domain may require index computation based on the result size. The + // sizes and offsets should be the same either way, but using + // tiledConsumerOp could lead to some chained unnecessary extra index + // computation. SmallVector iterDomainOffsets, iterDomainSizes; if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile( rewriter, operandNumber, offsets, sizes, iterDomainOffsets, @@ -2067,7 +2123,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, "unable to add new inits to nest loop"); } - // 15. Replace the result of scf loop and consumer op with new loop's results. + // 15. Replace the result of scf loop and consumer op with new loop's + // results. for (auto &&[oldResult, newResult] : llvm::zip( consumerOp->getResults(), diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 5e903e378daf8..7380b766935ff 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -250,7 +250,8 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp, return failure(); // Perform the replacement of tiled and fused values. - rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements); + rewriter.replaceOp(tilingInterfaceOp, + tiledResults->mergeResult.replacements); // Report back the relevant handles to the transform op. tiledOps.push_back(tiledResults->tiledOps.front());