@@ -305,274 +305,4 @@ getEncodingInfoForMatmul(Encoding::EncodingAttr encoding, TileMxNxK tileMxNxK) {
305
305
return encodingInfo;
306
306
}
307
307
308
- static RankedTensorType dropEncoding (RankedTensorType type) {
309
- return RankedTensorType::get (type.getShape (), type.getElementType ());
310
- }
311
-
312
- static Operation *dropEncodingAndCloneOp (OpBuilder &builder, Operation *op,
313
- ValueRange convertedInputOperands,
314
- ValueRange convertedOutputOperands) {
315
- SmallVector<Value> operands;
316
- operands.append (convertedInputOperands.begin (), convertedInputOperands.end ());
317
- operands.append (convertedOutputOperands.begin (),
318
- convertedOutputOperands.end ());
319
- return mlir::clone (builder, op,
320
- {dropEncoding (cast<RankedTensorType>(
321
- convertedOutputOperands[0 ].getType ()))},
322
- operands);
323
- }
324
-
325
- static RankedTensorType
326
- getExpandedType (RankedTensorType type, bool isBatched, bool isTransposed,
327
- SmallVectorImpl<ReassociationIndices> &ri) {
328
- if (!isBatched) {
329
- ri.assign ({{0 , 1 }, {2 , 3 }});
330
- if (!isTransposed) {
331
- return RankedTensorType::get (
332
- {1 , type.getDimSize (0 ), 1 , type.getDimSize (1 )},
333
- type.getElementType ());
334
- }
335
- return RankedTensorType::get ({type.getDimSize (0 ), 1 , type.getDimSize (1 ), 1 },
336
- type.getElementType ());
337
- }
338
-
339
- ri.assign ({{0 }, {1 , 2 }, {3 , 4 }});
340
- if (!isTransposed) {
341
- return RankedTensorType::get (
342
- {type.getDimSize (0 ), 1 , type.getDimSize (1 ), 1 , type.getDimSize (2 )},
343
- type.getElementType ());
344
- }
345
- return RankedTensorType::get (
346
- {type.getDimSize (0 ), type.getDimSize (1 ), 1 , type.getDimSize (2 ), 1 },
347
- type.getElementType ());
348
- }
349
-
350
- // / Given an input Value and a desired output element type, create and return
351
- // / an element-wise linalg::GenericOp that extends the input Value to the
352
- // / output element type.
353
- static Value createElementWiseExtUIOp (OpBuilder &builder, Value input,
354
- Location loc, Type outElemType) {
355
- auto inputType = cast<RankedTensorType>(input.getType ());
356
- SmallVector<AffineMap> maps (
357
- 2 , builder.getMultiDimIdentityMap (inputType.getRank ()));
358
- SmallVector<utils::IteratorType> iteratorTypes (inputType.getRank (),
359
- utils::IteratorType::parallel);
360
- auto castedType = inputType.clone (outElemType);
361
- SmallVector<OpFoldResult> inputMixedSizes =
362
- tensor::getMixedSizes (builder, loc, input);
363
- Value init =
364
- builder.create <tensor::EmptyOp>(loc, inputMixedSizes, outElemType);
365
- return builder
366
- .create <linalg::GenericOp>(
367
- loc, castedType, input, init, maps, iteratorTypes,
368
- [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
369
- Value castRes =
370
- b.create <arith::ExtUIOp>(nestedLoc, outElemType, args[0 ])
371
- ->getResult (0 );
372
- b.create <linalg::YieldOp>(nestedLoc, castRes);
373
- })
374
- .getResult (0 );
375
- }
376
-
377
- // / If needed, expand and the input Value, and return the resulting input with
378
- // / the canonical mmt4d input shape. If the input element type is unsigned,
379
- // / create a producer Linalg::GenericOp on the input that unsigned extends the
380
- // / input to the output element type. This extension is required to keep the
381
- // / unsignedness information on the input for ukernels. If `transpose` is true,
382
- // / the `linalgOp`'s indexing maps are transposed.
383
- static Value getMmt4dOperand (Value value, linalg::LinalgOp linalgOp,
384
- bool transpose, OpBuilder &builder,
385
- SmallVectorImpl<ReassociationIndices> &ri,
386
- ArrayRef<Type> elemTypes, int operandIdx) {
387
- assert (linalgOp.getNumDpsInputs () == 2 );
388
- assert (linalgOp.getNumDpsInits () == 1 );
389
- auto cDims = linalg::inferContractionDims (linalgOp);
390
- Location loc = linalgOp->getLoc ();
391
- Value expandedValue = value;
392
- // If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the
393
- // operand is a vector and must be extended
394
- if ((cDims->m .empty () && operandIdx != 1 ) ||
395
- (cDims->n .empty () && operandIdx != 0 )) {
396
- auto type = cast<RankedTensorType>(value.getType ());
397
- RankedTensorType newType = getExpandedType (
398
- type, /* isBatched=*/ !cDims->batch .empty (),
399
- /* isTransposed=*/ operandIdx == 2 && (transpose ^ cDims->n .empty ()), ri);
400
- expandedValue =
401
- builder.create <tensor::ExpandShapeOp>(loc, newType, value, ri);
402
- }
403
- if (elemTypes[operandIdx].isUnsignedInteger ()) {
404
- return createElementWiseExtUIOp (builder, expandedValue, loc,
405
- elemTypes.back ());
406
- }
407
- return expandedValue;
408
- }
409
-
410
- TileMxNxK chooseMatmulTile (ArrayRef<TileMxNxK> enumeratedTiles,
411
- IREE::Encoding::MatmulNarrowDim narrowDim,
412
- ArrayRef<int64_t > hostDefinedUpperBound) {
413
- assert ((hostDefinedUpperBound.empty () || hostDefinedUpperBound.size () >= 3 ) &&
414
- " expected hostDefinedUpperBound is empty or has upper bound for {M, "
415
- " N, K}" );
416
- // Handle narrow-N by transposing to reduce to narrow-M. Note: the
417
- // enumeratedTiles currently only enumerate narrow-M cases.
418
- if (narrowDim.isN ()) {
419
- SmallVector<int64_t > newHostDefinedUpperBound (hostDefinedUpperBound);
420
- std::swap (newHostDefinedUpperBound[0 ], newHostDefinedUpperBound[1 ]);
421
- narrowDim.dim = IREE::Encoding::MatmulNarrowDim::Dim::M;
422
- TileMxNxK tile =
423
- chooseMatmulTile (enumeratedTiles, narrowDim, newHostDefinedUpperBound);
424
- std::swap (tile.M , tile.N );
425
- return tile;
426
- }
427
- // Handle kDynamic: currently this is only used with VMVX, where there is only
428
- // one enumerated tile and it has all three M/N/K dimensions dynamic, so for
429
- // now we only support that. Generalize that as needed when more dynamic tile
430
- // sizes are used outside of VMVX, e.g. perhaps some day with Arm SVE. Decide
431
- // how to incorporate the handling of kDynamic in the cost-model evaluation
432
- // below to decide when to prefer a dynamic vs a static tile shape.
433
- for (auto tile : enumeratedTiles) {
434
- if (ShapedType::isDynamic (tile.M ) || ShapedType::isDynamic (tile.N ) ||
435
- ShapedType::isDynamic (tile.K )) {
436
- assert (enumeratedTiles.size () == 1 );
437
- assert (ShapedType::isDynamic (tile.M ) && ShapedType::isDynamic (tile.N ) &&
438
- ShapedType::isDynamic (tile.K ));
439
- return tile;
440
- }
441
- }
442
- // We're going to "rate" the enumerated tiles.
443
- struct RatedTileMxNxK : TileMxNxK {
444
- RatedTileMxNxK () {}
445
- RatedTileMxNxK (TileMxNxK tile) : TileMxNxK(tile) {}
446
- // Penalize tiles that are wider in the M dimension than matmulNarrowM.
447
- int64_t paddingPenalty = 0 ;
448
- // Favor larger tiles, as long as they still minimize paddingPenalty.
449
- int64_t productMxNxK = 0 ;
450
- };
451
- SmallVector<RatedTileMxNxK> ratedTiles;
452
- ratedTiles.reserve (enumeratedTiles.size ());
453
- int64_t bestPaddingPenalty = INT64_MAX;
454
- int64_t mUB = INT64_MAX;
455
- int64_t nUB = INT64_MAX;
456
- int64_t kUB = INT64_MAX;
457
- if (!hostDefinedUpperBound.empty ()) {
458
- mUB = hostDefinedUpperBound[0 ];
459
- nUB = hostDefinedUpperBound[1 ];
460
- kUB = hostDefinedUpperBound[2 ];
461
- }
462
- for (auto tile : enumeratedTiles) {
463
- if (tile.M > mUB || tile.N > nUB || tile.K > kUB ) {
464
- LLVM_DEBUG (llvm::dbgs () << " [" << DEBUG_TYPE << " ]: tile (" ;
465
- llvm::interleaveComma (
466
- ArrayRef<int64_t >{tile.M , tile.N , tile.K }, llvm::dbgs ());
467
- llvm::dbgs ()
468
- << " ) is skipped because it is not valid for upper_bound (" ;
469
- llvm::interleaveComma (ArrayRef<int64_t >{mUB , nUB, kUB },
470
- llvm::dbgs ());
471
- llvm::dbgs () << " )\n " );
472
- continue ;
473
- }
474
- RatedTileMxNxK ratedTile (tile);
475
- ratedTile.paddingPenalty = 0 ;
476
- // If we are choosing a tile for a narrow-M case, we want to minimize
477
- // padding along the M dimension.
478
- // The PowerOf2Ceil is so that we are OK with padding up to the next
479
- // power of two, we just try to avoid padding beyond that. For example,
480
- // if matmulNarrowM==7 and we have enumerated tiles with M=8,4,2,1, we
481
- // are OK with the tile that has M==8 even though it requires some padding.
482
- // Otherwise, we would be penalizing the tiles with M==8,4,2 and we would
483
- // end up selecting the vecmat tile (M==1) for that case!
484
- if (narrowDim) {
485
- ratedTile.paddingPenalty =
486
- std::max<int64_t >(tile.M - llvm::PowerOf2Ceil (narrowDim.size ), 0 );
487
- }
488
- ratedTile.productMxNxK = tile.M * tile.N * tile.K ;
489
- ratedTiles.push_back (ratedTile);
490
- LLVM_DEBUG (llvm::dbgs () << " candidate: " ; llvm::interleaveComma (
491
- ArrayRef<int64_t >{tile.M , tile.N , tile.K }, llvm::dbgs ());
492
- llvm::dbgs () << " penalty:" << ratedTile.paddingPenalty << " \n " );
493
- bestPaddingPenalty = std::min (bestPaddingPenalty, ratedTile.paddingPenalty );
494
- }
495
- RatedTileMxNxK bestRatedTile;
496
- for (auto ratedTile : ratedTiles) {
497
- // Choose only among tiles that minimize paddingPenalty. Among those,
498
- // maximize productMxNxK.
499
- if (ratedTile.paddingPenalty == bestPaddingPenalty &&
500
- bestRatedTile.productMxNxK < ratedTile.productMxNxK ) {
501
- bestRatedTile = ratedTile;
502
- }
503
- }
504
- // Sanity check. This assert can only fail if there's a programming mistake
505
- // locally here.
506
- assert (bestRatedTile.paddingPenalty == bestPaddingPenalty);
507
- return bestRatedTile;
508
- }
509
-
510
- FailureOr<Operation *>
511
- lowerContractionOpWithEncoding (OpBuilder &builder, linalg::LinalgOp linalgOp,
512
- ValueRange operands, bool transposeNarrowN,
513
- LayoutAttrInterface layoutAttr) {
514
- if (!linalgOp.hasPureTensorSemantics ()) {
515
- return failure ();
516
- }
517
-
518
- auto inputs = linalgOp.getDpsInputOperands ();
519
- auto outputs = linalgOp.getDpsInits ();
520
-
521
- auto lhsType = cast<RankedTensorType>(inputs[0 ]->get ().getType ());
522
- auto rhsType = cast<RankedTensorType>(inputs[1 ]->get ().getType ());
523
- auto resultType = cast<RankedTensorType>(outputs[0 ].getType ());
524
- auto lhsEncoding = IREE::Encoding::getEncodingAttr (lhsType);
525
- auto rhsEncoding = IREE::Encoding::getEncodingAttr (rhsType);
526
- auto resultEncoding = IREE::Encoding::getEncodingAttr (resultType);
527
- if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
528
- return failure ();
529
- }
530
-
531
- if (lhsEncoding.getOperandIndex ().getValue () != IREE::Encoding::MATMUL_LHS ||
532
- rhsEncoding.getOperandIndex ().getValue () != IREE::Encoding::MATMUL_RHS ||
533
- resultEncoding.getOperandIndex ().getValue () !=
534
- IREE::Encoding::MATMUL_RESULT) {
535
- return failure ();
536
- }
537
-
538
- MaterializeEncodingInfo encodingInfo = layoutAttr.getEncodingInfo (
539
- cast<RankedTensorType>(linalgOp->getResultTypes ()[0 ]));
540
-
541
- if (isIdentityLayout (encodingInfo)) {
542
- return dropEncodingAndCloneOp (builder, linalgOp,
543
- operands.take_front (inputs.size ()),
544
- operands.drop_front (inputs.size ()));
545
- }
546
-
547
- bool transpose = transposeNarrowN && isNarrowNResult (resultEncoding);
548
- SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray ();
549
- SmallVector<ReassociationIndices> ri;
550
- Value newLhs = getMmt4dOperand (operands[0 ], linalgOp, transpose, builder, ri,
551
- elemTypes, /* operandIdx=*/ 0 );
552
- Value newRhs = getMmt4dOperand (operands[1 ], linalgOp, transpose, builder, ri,
553
- elemTypes, /* operandIdx=*/ 1 );
554
- Value newResult = getMmt4dOperand (operands[2 ], linalgOp, transpose, builder,
555
- ri, elemTypes, /* operandIdx=*/ 2 );
556
- if (transpose) {
557
- std::swap (newLhs, newRhs);
558
- }
559
- Type newResultType = newResult.getType ();
560
- auto cDims = IREE::Encoding::getEncodingContractionDims (lhsEncoding);
561
- Operation *result;
562
- if (cDims->batch .empty ()) {
563
- result = builder.create <linalg::Mmt4DOp>(linalgOp.getLoc (), newResultType,
564
- ValueRange{newLhs, newRhs},
565
- ValueRange{newResult});
566
- } else {
567
- result = builder.create <linalg::BatchMmt4DOp>(
568
- linalgOp.getLoc (), newResultType, ValueRange{newLhs, newRhs},
569
- ValueRange{newResult});
570
- }
571
- if (!ri.empty ()) {
572
- result = builder.create <tensor::CollapseShapeOp>(
573
- linalgOp->getLoc (), operands[2 ].getType (), result->getResult (0 ), ri);
574
- }
575
- return result;
576
- }
577
-
578
308
} // namespace mlir::iree_compiler::IREE::Codegen
0 commit comments