@@ -1295,9 +1295,11 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
1295
1295
CodeGenPipeline pipeline) {
1296
1296
TileSizesListType tileSizes;
1297
1297
unsigned numParallelLoops = op.getNumParallelLoops ();
1298
- SmallVector<int64_t > workgroupTileSizes (numParallelLoops - 2 , 1 );
1299
- workgroupTileSizes.append ({tileX, tileY});
1300
- workgroupTileSizes.append (op.getNumReductionLoops (), tileK);
1298
+ unsigned numReductionLoops = op.getNumReductionLoops ();
1299
+ SmallVector<int64_t > workgroupTileSizes (
1300
+ numParallelLoops + numReductionLoops, 1 );
1301
+ workgroupTileSizes[numParallelLoops - 2 ] = tileX;
1302
+ workgroupTileSizes[numParallelLoops - 1 ] = tileY;
1301
1303
1302
1304
SmallVector<unsigned > partitionedLoops =
1303
1305
cast<PartitionableLoopsInterface>(op.getOperation ())
@@ -1311,11 +1313,63 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
1311
1313
}
1312
1314
}
1313
1315
1314
- tileSizes.emplace_back (std::move (workgroupTileSizes)); // Workgroup level.
1315
1316
std::optional<int64_t > subgroupSize = std::nullopt;
1316
1317
if (!subgroupSizes.empty ())
1317
1318
subgroupSize = subgroupSizes.front ();
1318
1319
1320
+ // For the LLVMGPUTileAndFuse pipeline, we need to split tile sizes
1321
+ // for workgroup, thread, and reduction.
1322
+ if (pipeline == CodeGenPipeline::LLVMGPUTileAndFuse) {
1323
+
1324
+ auto context = op.getContext ();
1325
+ Builder b (context);
1326
+ SmallVector<NamedAttribute, 1 > attrs;
1327
+
1328
+ SmallVector<int64_t > threadTileSizes (numParallelLoops + numReductionLoops,
1329
+ 0 );
1330
+ std::fill (threadTileSizes.begin (),
1331
+ threadTileSizes.begin () + numParallelLoops, 1 );
1332
+
1333
+ threadTileSizes[numParallelLoops - 2 ] =
1334
+ (tileX / workgroupSize[0 ]) < 1 ? 1 : (tileX / workgroupSize[0 ]);
1335
+ threadTileSizes[numParallelLoops - 1 ] =
1336
+ (tileY / workgroupSize[1 ]) < 1 ? 1 : (tileY / workgroupSize[1 ]);
1337
+
1338
+ SmallVector<int64_t > reductionTileSizes (
1339
+ numParallelLoops + numReductionLoops, 0 );
1340
+ reductionTileSizes[numParallelLoops + numReductionLoops - 1 ] = tileK;
1341
+
1342
+ attrs.emplace_back (b.getStringAttr (" workgroup" ),
1343
+ b.getI64ArrayAttr (workgroupTileSizes));
1344
+ attrs.emplace_back (b.getStringAttr (" thread" ),
1345
+ b.getI64ArrayAttr (threadTileSizes));
1346
+ attrs.emplace_back (b.getStringAttr (" reduction" ),
1347
+ b.getI64ArrayAttr (reductionTileSizes));
1348
+
1349
+ auto configDict = b.getDictionaryAttr (attrs);
1350
+ auto loweringConfig =
1351
+ IREE::GPU::LoweringConfigAttr::get (context, configDict);
1352
+ SmallVector<NamedAttribute, 1 > pipelineAttrs;
1353
+ auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get (
1354
+ context, /* prefetchSharedMemory=*/ false ,
1355
+ /* no_reduce_shared_memory_bank_conflicts=*/ true ,
1356
+ /* use_igemm_convolution=*/ false ,
1357
+ /* reorder_workgroups_strategy=*/ std::nullopt);
1358
+ pipelineAttrs.emplace_back (
1359
+ b.getStringAttr (IREE::GPU::GPUPipelineOptionsAttr::getDictKeyName ()),
1360
+ pipelineOptions);
1361
+ auto pipelineConfig = b.getDictionaryAttr (pipelineAttrs);
1362
+
1363
+ return setOpConfigAndEntryPointFnTranslation (
1364
+ entryPoint, op, loweringConfig, pipeline, workgroupSize, subgroupSize,
1365
+ pipelineConfig);
1366
+ }
1367
+
1368
+ // Other pipelines (e.g. MatmulTensorCore) expect the reduction tile size to be in
1369
+ // the same list.
1370
+ workgroupTileSizes[numParallelLoops + numReductionLoops - 1 ] = tileK;
1371
+ tileSizes.emplace_back (std::move (workgroupTileSizes));
1372
+
1319
1373
return setOpConfigAndEntryPointFnTranslation (
1320
1374
entryPoint, op, tileSizes, pipeline, workgroupSize, subgroupSize,
1321
1375
getSoftwarePipeliningAttrDict (op->getContext (), softwarePipelineDepth,
@@ -1390,7 +1444,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
1390
1444
return setMatmulConfig (
1391
1445
sizeN, sizeM, 4 , {sizeM, sizeN, 1 },
1392
1446
target.getWgp ().getSubgroupSizeChoices ().asArrayRef (),
1393
- softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUMatmulSimt );
1447
+ softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUTileAndFuse );
1394
1448
}
1395
1449
1396
1450
// SIMT matmul case. Query the best configuration.
@@ -1404,7 +1458,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
1404
1458
config.tileSize [0 ], config.tileSize [1 ], config.tileSize [2 ],
1405
1459
config.workgroupSize ,
1406
1460
target.getWgp ().getSubgroupSizeChoices ().asArrayRef (),
1407
- softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUMatmulSimt );
1461
+ softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUTileAndFuse );
1408
1462
}
1409
1463
}
1410
1464
}
@@ -1429,7 +1483,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
1429
1483
return setMatmulConfig (tileX, tileY, tileK, workgroupSize,
1430
1484
target.getWgp ().getSubgroupSizeChoices ().asArrayRef (),
1431
1485
softwarePipelineDepthSimt,
1432
- CodeGenPipeline::LLVMGPUMatmulSimt );
1486
+ CodeGenPipeline::LLVMGPUTileAndFuse );
1433
1487
}
1434
1488
1435
1489
// ====---------------------------------------------------------------------===//
0 commit comments