@@ -118,8 +118,8 @@ LogicalResult setDataTiledMultiMmaLoweringConfig(
 /// problem based on the available mma intrinsics.
 static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
     IREE::GPU::TargetAttr target, GPUMatmulShapeType problem,
-    bool transposedLhs, bool transposedRhs, bool mustBeAligned = true,
-    bool doCPromotion = false) {
+    bool transposedLhs, bool transposedRhs, bool isIGEMM,
+    bool mustBeAligned = true, bool doCPromotion = false) {
   const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();
   SmallVector<GPUMatmulShapeType> intrinsics;
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
@@ -142,20 +142,22 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
   // See https://github.com/iree-org/iree/issues/16341 for details.
   int64_t mSize = ShapedType::getNumElements(problem.mSizes);
   int64_t nSize = ShapedType::getNumElements(problem.nSizes);
+  int64_t cacheLineSizeElements = kCacheLineSizeBits / inBitWidth;
+  int64_t bestKElementCountPerSubgroup =
+      isIGEMM ? cacheLineSizeElements / 2 : cacheLineSizeElements;
   if (mSize * nSize <= 512 * 512) {
     // For matmuls with small M*N size, we want to distribute M*N onto more
     // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
     // and a larger bestKTileCountPerSubgroup.
     seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
              /*bestMNTileCountPerSubgroup=*/4,
-             /*bestKTileCountPerSubgroup=*/8,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 2 /
-                 inBitWidth};
+             /*bestKTileCountPerSubgroup=*/8, bestKElementCountPerSubgroup * 2};
   } else {
+    int64_t bestKElementCountPerSubgroup =
+        isIGEMM ? cacheLineSizeElements / 2 : cacheLineSizeElements;
     seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
              /*bestMNTileCountPerSubgroup=*/16,
-             /*bestKTileCountPerSubgroup=*/4,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
+             /*bestKTileCountPerSubgroup=*/4, bestKElementCountPerSubgroup};
   }
 
   // We target slightly below the full available shared Memory to leave room for
@@ -181,7 +183,8 @@ static FailureOr<std::pair<LoweringConfigAttr, int64_t>>
 getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
                                         ArrayRef<AffineMap> maps,
                                         ArrayRef<Value> operands,
-                                        IREE::GPU::TargetAttr target) {
+                                        IREE::GPU::TargetAttr target,
+                                        bool isIGEMM) {
   if (target.getWgp().getMma().empty())
     return failure();
 
@@ -249,7 +252,7 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
   bool mustBeAligned = true;
   bool doCPromotion = false;
   std::optional<GPUMMASchedule> schedule = getMmaScheduleFromProblemAndTarget(
-      target, problem, transposedLhs, transposedRhs);
+      target, problem, transposedLhs, transposedRhs, isIGEMM);
 
   // TODO(nirvedhmeshram, qedawkins): The performance with this will be bad if
   // the GEMM is accumulating (i.e doesnt have a zero fill dpsInit) as that
@@ -259,9 +262,9 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
     LDBG("Attempting to deduce unaligned TileAndFuse MMA schedulee");
     mustBeAligned = false;
     doCPromotion = true;
-    schedule = getMmaScheduleFromProblemAndTarget(target, problem,
-                                                  transposedLhs, transposedRhs,
-                                                  mustBeAligned, doCPromotion);
+    schedule = getMmaScheduleFromProblemAndTarget(
+        target, problem, transposedLhs, transposedRhs, isIGEMM, mustBeAligned,
+        doCPromotion);
   }
 
   if (!schedule) {
@@ -384,7 +387,8 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target,
   SmallVector<int64_t> bounds = igemmLoopBounds.value();
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
       getMatmulLoweringConfigAndWorkgroupSize(
-          bounds, igemmContractionMaps.value(), igemmOperands.value(), target);
+          bounds, igemmContractionMaps.value(), igemmOperands.value(), target,
+          /*isIGEMM=*/true);
   if (failed(configAndWgSize)) {
     return failure();
   }
@@ -434,7 +438,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   LDBG("Matmul TileAndFuse Config");
 
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
-      getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target);
+      getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target,
+                                              /*isIGEMM=*/false);
   if (failed(configAndWgSize)) {
     return failure();
   }
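
Note for readers: the seed arithmetic introduced in the second hunk can be reproduced standalone, as in the sketch below. This is not the IREE implementation; kCacheLineSizeBits = 1024 (a 128-byte cache line) and the Seeds/pickSeeds names are assumptions made here for illustration. The change itself only threads an isIGEMM flag through getMatmulLoweringConfigAndWorkgroupSize and halves the bestKElementCountPerSubgroup seed on the implicit-GEMM convolution path, where the K dimension is assembled through an im2col view rather than read as one contiguous run.

// Hedged sketch, not the IREE source: reproduces the seed selection from
// the diff above. kCacheLineSizeBits = 1024 is an assumption (128-byte
// cache line); Seeds and pickSeeds are hypothetical names.
#include <cstdint>
#include <iostream>

struct Seeds {
  int64_t bestSubgroupCountPerWorkgroup;
  int64_t bestMNTileCountPerSubgroup;
  int64_t bestKTileCountPerSubgroup;
  int64_t bestKElementCountPerSubgroup;
};

constexpr int64_t kCacheLineSizeBits = 1024; // assumed value

Seeds pickSeeds(int64_t mSize, int64_t nSize, int64_t inBitWidth,
                bool isIGEMM) {
  int64_t cacheLineSizeElements = kCacheLineSizeBits / inBitWidth;
  // The IGEMM path seeds K with half a cache line per subgroup.
  int64_t bestKElementCountPerSubgroup =
      isIGEMM ? cacheLineSizeElements / 2 : cacheLineSizeElements;
  if (mSize * nSize <= 512 * 512) {
    // Small M*N: fewer MN tiles per subgroup, deeper K tiling.
    return {4, 4, 8, bestKElementCountPerSubgroup * 2};
  }
  return {4, 16, 4, bestKElementCountPerSubgroup};
}

int main() {
  // f16 inputs (inBitWidth = 16): 64 elements per cache line, halved to 32
  // on the IGEMM path, then doubled in the small-M*N branch -> 64.
  Seeds s = pickSeeds(256, 256, 16, /*isIGEMM=*/true);
  std::cout << s.bestKElementCountPerSubgroup << "\n"; // prints 64
  return 0;
}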