Commit ad938ae

[DT][NFC] Localize CPU specific encoding materialization logic. (#19452)
The revision moves the CPU materialization logic from Dialect/Codegen/Utils/Utils.[h|cpp] to CPUEncodingExternalModels. These methods were only made public during the transition; now that all the CPU layout attributes are implemented, they no longer need to be exposed. Additionally, it removes the outdated fallback logic from the MaterializeContractionOp pattern, and it drops the `transposeNarrowN` input argument from the lowerContractionOpWithEncoding method because all the CPU backends enable the transposeNarrowN feature.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent c618134 commit ad938ae

File tree

4 files changed: +294 -318 lines
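The net effect on the materialization pattern is easiest to see condensed before reading the raw diffs. A minimal sketch of the new lowering flow (mirroring the first hunk below; `converter`, `op`, `operands`, and `rewriter` are supplied by the surrounding conversion pattern and are elided here):

    // The layout attribute is now assumed to always be present, so the
    // pattern delegates to it unconditionally; the old fallback through
    // lowerContractionOpWithEncoding is gone.
    IREE::Codegen::LayoutAttrInterface layoutAttr = converter->getLayoutAttr();
    SmallVector<Type> convertedResTypes;
    for (auto init : op.getDpsInits())
      convertedResTypes.push_back(converter->convertType(init.getType()));
    Operation *newOp =
        layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
    rewriter.replaceOp(op, newOp->getResults());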

compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp (+8 -18)

@@ -10,6 +10,7 @@
 #include "iree/compiler/Codegen/Common/EncodingUtils.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
@@ -740,25 +741,14 @@ class MaterializeContractionOp
     auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
         this->getTypeConverter());

-    if (auto layoutAttr = converter->getLayoutAttr()) {
-      SmallVector<Type> convertedResTypes;
-      for (auto init : op.getDpsInits()) {
-        convertedResTypes.push_back(converter->convertType(init.getType()));
-      }
-      Operation *newOp =
-          layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
-      rewriter.replaceOp(op, newOp->getResults());
-      return success();
-    }
-
-    FailureOr<Operation *> convertedOp =
-        IREE::Codegen::lowerContractionOpWithEncoding(
-            rewriter, op, operands, converter->getTransposeNarrowN(),
-            converter->getLayoutAttr());
-    if (failed(convertedOp)) {
-      return failure();
+    IREE::Codegen::LayoutAttrInterface layoutAttr = converter->getLayoutAttr();
+    SmallVector<Type> convertedResTypes;
+    for (auto init : op.getDpsInits()) {
+      convertedResTypes.push_back(converter->convertType(init.getType()));
     }
-    rewriter.replaceOp(op.getOperation(), convertedOp.value()->getResult(0));
+    Operation *newOp =
+        layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
+    rewriter.replaceOp(op, newOp->getResults());
     return success();
   }
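For orientation, this is the contract the pattern now relies on. The interface definition itself is not part of this diff, so the exact signature below is an assumption inferred from the call site above:

    // Assumed shape of LayoutAttrInterface::lowerOp (hypothetical; the real
    // declaration lives with the IREECodegenInterfaces definition). It takes
    // the already type-converted result types and operands and returns the op
    // whose results replace the encoded contraction -- for CPU layouts,
    // typically a linalg.mmt4d built on the packed operands.
    Operation *lowerOp(OpBuilder &b, Operation *op,
                       TypeRange convertedResTypes,
                       ValueRange convertedOperands) const;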

compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp (-270)

@@ -305,274 +305,4 @@ getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK) {
   return encodingInfo;
 }

-static RankedTensorType dropEncoding(RankedTensorType type) {
-  return RankedTensorType::get(type.getShape(), type.getElementType());
-}
-
-static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
-                                         ValueRange convertedInputOperands,
-                                         ValueRange convertedOutputOperands) {
-  SmallVector<Value> operands;
-  operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
-  operands.append(convertedOutputOperands.begin(),
-                  convertedOutputOperands.end());
-  return mlir::clone(builder, op,
-                     {dropEncoding(cast<RankedTensorType>(
-                         convertedOutputOperands[0].getType()))},
-                     operands);
-}
-
-static RankedTensorType
-getExpandedType(RankedTensorType type, bool isBatched, bool isTransposed,
-                SmallVectorImpl<ReassociationIndices> &ri) {
-  if (!isBatched) {
-    ri.assign({{0, 1}, {2, 3}});
-    if (!isTransposed) {
-      return RankedTensorType::get(
-          {1, type.getDimSize(0), 1, type.getDimSize(1)},
-          type.getElementType());
-    }
-    return RankedTensorType::get({type.getDimSize(0), 1, type.getDimSize(1), 1},
-                                 type.getElementType());
-  }
-
-  ri.assign({{0}, {1, 2}, {3, 4}});
-  if (!isTransposed) {
-    return RankedTensorType::get(
-        {type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)},
-        type.getElementType());
-  }
-  return RankedTensorType::get(
-      {type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1},
-      type.getElementType());
-}
-
-/// Given an input Value and a desired output element type, create and return
-/// an element-wise linalg::GenericOp that extends the input Value to the
-/// output element type.
-static Value createElementWiseExtUIOp(OpBuilder &builder, Value input,
-                                      Location loc, Type outElemType) {
-  auto inputType = cast<RankedTensorType>(input.getType());
-  SmallVector<AffineMap> maps(
-      2, builder.getMultiDimIdentityMap(inputType.getRank()));
-  SmallVector<utils::IteratorType> iteratorTypes(inputType.getRank(),
-                                                 utils::IteratorType::parallel);
-  auto castedType = inputType.clone(outElemType);
-  SmallVector<OpFoldResult> inputMixedSizes =
-      tensor::getMixedSizes(builder, loc, input);
-  Value init =
-      builder.create<tensor::EmptyOp>(loc, inputMixedSizes, outElemType);
-  return builder
-      .create<linalg::GenericOp>(
-          loc, castedType, input, init, maps, iteratorTypes,
-          [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
-            Value castRes =
-                b.create<arith::ExtUIOp>(nestedLoc, outElemType, args[0])
-                    ->getResult(0);
-            b.create<linalg::YieldOp>(nestedLoc, castRes);
-          })
-      .getResult(0);
-}
-
-/// If needed, expand the input Value, and return the resulting input with
-/// the canonical mmt4d input shape. If the input element type is unsigned,
-/// create a producer linalg::GenericOp on the input that unsigned-extends the
-/// input to the output element type. This extension is required to keep the
-/// unsignedness information on the input for ukernels. If `transpose` is true,
-/// the `linalgOp`'s indexing maps are transposed.
-static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
-                             bool transpose, OpBuilder &builder,
-                             SmallVectorImpl<ReassociationIndices> &ri,
-                             ArrayRef<Type> elemTypes, int operandIdx) {
-  assert(linalgOp.getNumDpsInputs() == 2);
-  assert(linalgOp.getNumDpsInits() == 1);
-  auto cDims = linalg::inferContractionDims(linalgOp);
-  Location loc = linalgOp->getLoc();
-  Value expandedValue = value;
-  // If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the
-  // operand is a vector and must be extended.
-  if ((cDims->m.empty() && operandIdx != 1) ||
-      (cDims->n.empty() && operandIdx != 0)) {
-    auto type = cast<RankedTensorType>(value.getType());
-    RankedTensorType newType = getExpandedType(
-        type, /*isBatched=*/!cDims->batch.empty(),
-        /*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()), ri);
-    expandedValue =
-        builder.create<tensor::ExpandShapeOp>(loc, newType, value, ri);
-  }
-  if (elemTypes[operandIdx].isUnsignedInteger()) {
-    return createElementWiseExtUIOp(builder, expandedValue, loc,
-                                    elemTypes.back());
-  }
-  return expandedValue;
-}
-
-TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
-                           IREE::Encoding::MatmulNarrowDim narrowDim,
-                           ArrayRef<int64_t> hostDefinedUpperBound) {
-  assert((hostDefinedUpperBound.empty() || hostDefinedUpperBound.size() >= 3) &&
-         "expected hostDefinedUpperBound is empty or has upper bound for {M, "
-         "N, K}");
-  // Handle narrow-N by transposing to reduce to narrow-M. Note: the
-  // enumeratedTiles currently only enumerate narrow-M cases.
-  if (narrowDim.isN()) {
-    SmallVector<int64_t> newHostDefinedUpperBound(hostDefinedUpperBound);
-    std::swap(newHostDefinedUpperBound[0], newHostDefinedUpperBound[1]);
-    narrowDim.dim = IREE::Encoding::MatmulNarrowDim::Dim::M;
-    TileMxNxK tile =
-        chooseMatmulTile(enumeratedTiles, narrowDim, newHostDefinedUpperBound);
-    std::swap(tile.M, tile.N);
-    return tile;
-  }
-  // Handle kDynamic: currently this is only used with VMVX, where there is only
-  // one enumerated tile and it has all three M/N/K dimensions dynamic, so for
-  // now we only support that. Generalize that as needed when more dynamic tile
-  // sizes are used outside of VMVX, e.g. perhaps some day with Arm SVE. Decide
-  // how to incorporate the handling of kDynamic in the cost-model evaluation
-  // below to decide when to prefer a dynamic vs a static tile shape.
-  for (auto tile : enumeratedTiles) {
-    if (ShapedType::isDynamic(tile.M) || ShapedType::isDynamic(tile.N) ||
-        ShapedType::isDynamic(tile.K)) {
-      assert(enumeratedTiles.size() == 1);
-      assert(ShapedType::isDynamic(tile.M) && ShapedType::isDynamic(tile.N) &&
-             ShapedType::isDynamic(tile.K));
-      return tile;
-    }
-  }
-  // We're going to "rate" the enumerated tiles.
-  struct RatedTileMxNxK : TileMxNxK {
-    RatedTileMxNxK() {}
-    RatedTileMxNxK(TileMxNxK tile) : TileMxNxK(tile) {}
-    // Penalize tiles that are wider in the M dimension than matmulNarrowM.
-    int64_t paddingPenalty = 0;
-    // Favor larger tiles, as long as they still minimize paddingPenalty.
-    int64_t productMxNxK = 0;
-  };
-  SmallVector<RatedTileMxNxK> ratedTiles;
-  ratedTiles.reserve(enumeratedTiles.size());
-  int64_t bestPaddingPenalty = INT64_MAX;
-  int64_t mUB = INT64_MAX;
-  int64_t nUB = INT64_MAX;
-  int64_t kUB = INT64_MAX;
-  if (!hostDefinedUpperBound.empty()) {
-    mUB = hostDefinedUpperBound[0];
-    nUB = hostDefinedUpperBound[1];
-    kUB = hostDefinedUpperBound[2];
-  }
-  for (auto tile : enumeratedTiles) {
-    if (tile.M > mUB || tile.N > nUB || tile.K > kUB) {
-      LLVM_DEBUG(llvm::dbgs() << "[" << DEBUG_TYPE << "]: tile (";
-                 llvm::interleaveComma(
-                     ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
-                 llvm::dbgs()
-                 << ") is skipped because it is not valid for upper_bound (";
-                 llvm::interleaveComma(ArrayRef<int64_t>{mUB, nUB, kUB},
-                                       llvm::dbgs());
-                 llvm::dbgs() << ")\n");
-      continue;
-    }
-    RatedTileMxNxK ratedTile(tile);
-    ratedTile.paddingPenalty = 0;
-    // If we are choosing a tile for a narrow-M case, we want to minimize
-    // padding along the M dimension.
-    // The PowerOf2Ceil is so that we are OK with padding up to the next
-    // power of two; we just try to avoid padding beyond that. For example,
-    // if matmulNarrowM==7 and we have enumerated tiles with M=8,4,2,1, we
-    // are OK with the tile that has M==8 even though it requires some padding.
-    // Otherwise, we would be penalizing the tiles with M==8,4,2 and we would
-    // end up selecting the vecmat tile (M==1) for that case!
-    if (narrowDim) {
-      ratedTile.paddingPenalty =
-          std::max<int64_t>(tile.M - llvm::PowerOf2Ceil(narrowDim.size), 0);
-    }
-    ratedTile.productMxNxK = tile.M * tile.N * tile.K;
-    ratedTiles.push_back(ratedTile);
-    LLVM_DEBUG(llvm::dbgs() << "candidate: "; llvm::interleaveComma(
-                   ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
-               llvm::dbgs() << " penalty:" << ratedTile.paddingPenalty << "\n");
-    bestPaddingPenalty = std::min(bestPaddingPenalty, ratedTile.paddingPenalty);
-  }
-  RatedTileMxNxK bestRatedTile;
-  for (auto ratedTile : ratedTiles) {
-    // Choose only among tiles that minimize paddingPenalty. Among those,
-    // maximize productMxNxK.
-    if (ratedTile.paddingPenalty == bestPaddingPenalty &&
-        bestRatedTile.productMxNxK < ratedTile.productMxNxK) {
-      bestRatedTile = ratedTile;
-    }
-  }
-  // Sanity check. This assert can only fail if there's a programming mistake
-  // locally here.
-  assert(bestRatedTile.paddingPenalty == bestPaddingPenalty);
-  return bestRatedTile;
-}
-
-FailureOr<Operation *>
-lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
-                               ValueRange operands, bool transposeNarrowN,
-                               LayoutAttrInterface layoutAttr) {
-  if (!linalgOp.hasPureTensorSemantics()) {
-    return failure();
-  }
-
-  auto inputs = linalgOp.getDpsInputOperands();
-  auto outputs = linalgOp.getDpsInits();
-
-  auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
-  auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
-  auto resultType = cast<RankedTensorType>(outputs[0].getType());
-  auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType);
-  auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType);
-  auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType);
-  if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
-    return failure();
-  }
-
-  if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS ||
-      rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS ||
-      resultEncoding.getOperandIndex().getValue() !=
-          IREE::Encoding::MATMUL_RESULT) {
-    return failure();
-  }
-
-  MaterializeEncodingInfo encodingInfo = layoutAttr.getEncodingInfo(
-      cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
-
-  if (isIdentityLayout(encodingInfo)) {
-    return dropEncodingAndCloneOp(builder, linalgOp,
-                                  operands.take_front(inputs.size()),
-                                  operands.drop_front(inputs.size()));
-  }
-
-  bool transpose = transposeNarrowN && isNarrowNResult(resultEncoding);
-  SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
-  SmallVector<ReassociationIndices> ri;
-  Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, builder, ri,
-                                 elemTypes, /*operandIdx=*/0);
-  Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, builder, ri,
-                                 elemTypes, /*operandIdx=*/1);
-  Value newResult = getMmt4dOperand(operands[2], linalgOp, transpose, builder,
-                                    ri, elemTypes, /*operandIdx=*/2);
-  if (transpose) {
-    std::swap(newLhs, newRhs);
-  }
-  Type newResultType = newResult.getType();
-  auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
-  Operation *result;
-  if (cDims->batch.empty()) {
-    result = builder.create<linalg::Mmt4DOp>(linalgOp.getLoc(), newResultType,
-                                             ValueRange{newLhs, newRhs},
-                                             ValueRange{newResult});
-  } else {
-    result = builder.create<linalg::BatchMmt4DOp>(
-        linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
-        ValueRange{newResult});
-  }
-  if (!ri.empty()) {
-    result = builder.create<tensor::CollapseShapeOp>(
-        linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
-  }
-  return result;
-}
-
 } // namespace mlir::iree_compiler::IREE::Codegen
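The rating heuristic removed above (it now lives in CPUEncodingExternalModels) is easiest to check with numbers. Below is a self-contained sketch of the cost model, condensed to a single pass; `powerOf2Ceil` stands in for `llvm::PowerOf2Ceil`, and the tile values in the closing comment are illustrative:

    #include <algorithm>
    #include <cstdint>
    #include <limits>
    #include <vector>

    struct TileMxNxK {
      int64_t M = 1, N = 1, K = 1;
    };

    // Stand-in for llvm::PowerOf2Ceil.
    static int64_t powerOf2Ceil(int64_t v) {
      int64_t p = 1;
      while (p < v)
        p <<= 1;
      return p;
    }

    // Penalize M-padding beyond the next power of two of the narrow-M size;
    // among the least-penalized tiles, prefer the largest M*N*K product.
    static TileMxNxK chooseTile(const std::vector<TileMxNxK> &tiles,
                                int64_t narrowM) {
      int64_t bestPenalty = std::numeric_limits<int64_t>::max();
      int64_t bestProduct = -1;
      TileMxNxK best;
      for (const TileMxNxK &t : tiles) {
        int64_t penalty =
            narrowM ? std::max<int64_t>(t.M - powerOf2Ceil(narrowM), 0) : 0;
        int64_t product = t.M * t.N * t.K;
        if (penalty < bestPenalty ||
            (penalty == bestPenalty && product > bestProduct)) {
          bestPenalty = penalty;
          bestProduct = product;
          best = t;
        }
      }
      return best;
    }

    // Worked example from the removed comment: narrowM == 7 with tiles
    // M = 8, 4, 2, 1 (N = K = 8). powerOf2Ceil(7) == 8, so all four tiles get
    // penalty 0 and the largest (M == 8) wins -- rather than the vecmat tile
    // (M == 1), which a penalty without the power-of-two ceiling would pick.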

compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h (-23)

@@ -75,29 +75,6 @@ struct TileMxNxK {
 MaterializeEncodingInfo
 getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK);

-//===----------------------------------------------------------------------===//
-// Operation Lowering Utilities.
-//===----------------------------------------------------------------------===//
-
-// TODO(hanchung): The below methods are exposed to public because they are
-// shared between MaterializeEncodingIntoPackUnPack.cpp and
-// CPUEncodingExternalModels.cpp. They will be moved to other places after all
-// the CPU backends implement their layout attributes.
-
-/// Returns the best TileMxNxK from the `enumeratedTiles` pool. If
-/// `hostDefinedUpperBound` is not empty, the chosen tile sizes cannot be
-/// greater than the values.
-/// TODO(#16933): Remove `hostDefinedUpperBound` once we can propagate such
-/// information to host. For now, they are defined by host.
-TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
-                           IREE::Encoding::MatmulNarrowDim narrowDim,
-                           ArrayRef<int64_t> hostDefinedUpperBound = {});
-
-FailureOr<Operation *>
-lowerContractionOpWithEncoding(OpBuilder &builder, linalg::LinalgOp linalgOp,
-                               ValueRange operands, bool transposeNarrowN,
-                               LayoutAttrInterface layoutAttr);
-
 } // namespace mlir::iree_compiler::IREE::Codegen

 #endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_UTILS_H_
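A before/after view of what this header cleanup implies for call sites (a hypothetical composite, pieced together from the removed declaration and the new pattern code; the commit itself shows no external caller):

    // Before: any client could reach the shared utility directly and had to
    // thread the transposeNarrowN flag through.
    FailureOr<Operation *> lowered =
        IREE::Codegen::lowerContractionOpWithEncoding(
            rewriter, linalgOp, operands, /*transposeNarrowN=*/true,
            layoutAttr);

    // After: the layout attribute interface is the single entry point, and
    // narrow-N transposition happens unconditionally inside the CPU models.
    Operation *newOp =
        layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);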
