Commit 91a09c7

give separate heuristics to IGEMM

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
Parent: d999ed1

2 files changed: +25 -20 lines


compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp (+19 -14)

@@ -118,8 +118,8 @@ LogicalResult setDataTiledMultiMmaLoweringConfig(
 /// problem based on the available mma intrinsics.
 static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
     IREE::GPU::TargetAttr target, GPUMatmulShapeType problem,
-    bool transposedLhs, bool transposedRhs, bool mustBeAligned = true,
-    bool doCPromotion = false) {
+    bool transposedLhs, bool transposedRhs, bool isIGEMM,
+    bool mustBeAligned = true, bool doCPromotion = false) {
   const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();
   SmallVector<GPUMatmulShapeType> intrinsics;
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
@@ -142,20 +142,22 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
   // See https://github.com/iree-org/iree/issues/16341 for details.
   int64_t mSize = ShapedType::getNumElements(problem.mSizes);
   int64_t nSize = ShapedType::getNumElements(problem.nSizes);
+  int64_t cacheLineSizeElements = kCacheLineSizeBits / inBitWidth;
+  int64_t bestKElementCountPerSubgroup =
+      isIGEMM ? cacheLineSizeElements / 2 : cacheLineSizeElements;
   if (mSize * nSize <= 512 * 512) {
     // For matmuls with small M*N size, we want to distribute M*N onto more
     // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
     // and a larger bestKTileCountPerSubgroup.
     seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
              /*bestMNTileCountPerSubgroup=*/4,
-             /*bestKTileCountPerSubgroup=*/8,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 2 /
-                 inBitWidth};
+             /*bestKTileCountPerSubgroup=*/8, bestKElementCountPerSubgroup * 2};
   } else {
+    int64_t bestKElementCountPerSubgroup =
+        isIGEMM ? cacheLineSizeElements / 2 : cacheLineSizeElements;
     seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
              /*bestMNTileCountPerSubgroup=*/16,
-             /*bestKTileCountPerSubgroup=*/4,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
+             /*bestKTileCountPerSubgroup=*/4, bestKElementCountPerSubgroup};
   }
 
   // We target slightly below the full available shared Memory to leave room for
@@ -181,7 +183,8 @@
 getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
                                         ArrayRef<AffineMap> maps,
                                         ArrayRef<Value> operands,
-                                        IREE::GPU::TargetAttr target) {
+                                        IREE::GPU::TargetAttr target,
+                                        bool isIGEMM) {
   if (target.getWgp().getMma().empty())
     return failure();
 
@@ -249,7 +252,7 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
   bool mustBeAligned = true;
   bool doCPromotion = false;
   std::optional<GPUMMASchedule> schedule = getMmaScheduleFromProblemAndTarget(
-      target, problem, transposedLhs, transposedRhs);
+      target, problem, transposedLhs, transposedRhs, isIGEMM);
 
   // TODO (nirvedhmeshram, qedawkins): The performance with this will be bad if
   // the GEMM is accumulating (i.e doesnt have a zero fill dpsInit) as that
@@ -259,9 +262,9 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
     LDBG("Attempting to deduce unaligned TileAndFuse MMA schedulee");
     mustBeAligned = false;
     doCPromotion = true;
-    schedule = getMmaScheduleFromProblemAndTarget(target, problem,
-                                                  transposedLhs, transposedRhs,
-                                                  mustBeAligned, doCPromotion);
+    schedule = getMmaScheduleFromProblemAndTarget(
+        target, problem, transposedLhs, transposedRhs, isIGEMM, mustBeAligned,
+        doCPromotion);
   }
 
   if (!schedule) {
@@ -384,7 +387,8 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target,
   SmallVector<int64_t> bounds = igemmLoopBounds.value();
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
       getMatmulLoweringConfigAndWorkgroupSize(
-          bounds, igemmContractionMaps.value(), igemmOperands.value(), target);
+          bounds, igemmContractionMaps.value(), igemmOperands.value(), target,
+          /*isIGEMM=*/true);
   if (failed(configAndWgSize)) {
     return failure();
   }
@@ -434,7 +438,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   LDBG("Matmul TileAndFuse Config");
 
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
-      getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target);
+      getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target,
+                                              /*isIGEMM=*/false);
   if (failed(configAndWgSize)) {
     return failure();
   }
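
For readers skimming the diff: the heuristic change amounts to halving the K-element seed when the contraction comes from IGEMM (implicit-GEMM convolution), while leaving the subgroup and tile-count seeds untouched. Below is a minimal, standalone C++ sketch of that selection logic. The struct, the helper name pickSeeds, and the 128-byte (1024-bit) cache-line constant are assumptions for illustration only, not code copied from ConfigUtils.cpp; only the halve-for-IGEMM choice is taken from the commit.

// Hypothetical, self-contained sketch of the new seed selection.
// Assumptions: a 128-byte (1024-bit) cache line and the names below.
#include <cstdint>

struct MmaSeeds {
  int64_t bestSubgroupCountPerWorkgroup;
  int64_t bestMNTileCountPerSubgroup;
  int64_t bestKTileCountPerSubgroup;
  int64_t bestKElementCountPerSubgroup;
};

MmaSeeds pickSeeds(int64_t mSize, int64_t nSize, int64_t inBitWidth,
                   bool isIGEMM) {
  constexpr int64_t kCacheLineSizeBits = 128 * 8; // assumed cache-line size
  int64_t cacheLineSizeElements = kCacheLineSizeBits / inBitWidth;
  // IGEMM is seeded with half as many K elements per subgroup as a plain
  // matmul; the other seeds are unchanged.
  int64_t bestKElementCountPerSubgroup =
      isIGEMM ? cacheLineSizeElements / 2 : cacheLineSizeElements;
  if (mSize * nSize <= 512 * 512) {
    // Small M*N: spread M*N over more workgroups and go deeper in K.
    return {/*bestSubgroupCountPerWorkgroup=*/4,
            /*bestMNTileCountPerSubgroup=*/4,
            /*bestKTileCountPerSubgroup=*/8, bestKElementCountPerSubgroup * 2};
  }
  return {/*bestSubgroupCountPerWorkgroup=*/4,
          /*bestMNTileCountPerSubgroup=*/16,
          /*bestKTileCountPerSubgroup=*/4, bestKElementCountPerSubgroup};
}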

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir (+6 -6)

@@ -24,7 +24,7 @@ func.func @nhwc_conv_mfma() {
 // CHECK: linalg.conv_2d_nhwc_hwcf {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
 // CHECK-SAME: promote_operands = [0, 1]
-// CHECK-SAME: reduction = [0, 0, 0, 0, 16]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
 // CHECK-SAME: subgroup = [1, 2, 2, 1, 0]
 // CHECK-SAME: workgroup = [1, 2, 32, 64, 0]
 
@@ -53,7 +53,7 @@ func.func @nchw_conv_mfma() {
 // CHECK: linalg.conv_2d_nchw_fchw {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
 // CHECK-SAME: promote_operands = [0, 1]
-// CHECK-SAME: reduction = [0, 0, 0, 0, 16]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
 // CHECK-SAME: subgroup = [1, 2, 2, 1, 0]
 // CHECK-SAME: workgroup = [1, 64, 2, 32, 0]
 
@@ -81,9 +81,9 @@ func.func @nhwc_conv_unaligned_mfma() {
 
 // CHECK: linalg.conv_2d_nhwc_hwcf {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
-// CHECK-SAME: padding = [2, 1, 32, 64, 64]
+// CHECK-SAME: padding = [2, 1, 32, 64, 32]
 // CHECK-SAME: promote_operands = [0, 1, 2]
-// CHECK-SAME: reduction = [0, 0, 0, 0, 16]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
 // CHECK-SAME: subgroup = [2, 1, 2, 1, 0]
 // CHECK-SAME: workgroup = [2, 1, 32, 64, 0]
 
@@ -111,8 +111,8 @@ func.func @nchw_conv_unaligned_mfma() {
 
 // CHECK: linalg.conv_2d_nchw_fchw {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
-// CHECK-SAME: padding = [1, 64, 2, 32, 64]
+// CHECK-SAME: padding = [1, 64, 2, 32, 32]
 // CHECK-SAME: promote_operands = [0, 1, 2]
-// CHECK-SAME: reduction = [0, 0, 0, 0, 16]
+// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
 // CHECK-SAME: subgroup = [1, 2, 2, 1, 0]
 // CHECK-SAME: workgroup = [1, 64, 2, 32, 0]
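
The updated CHECK values are consistent with the halved seed. A small arithmetic check follows; it assumes a 1024-bit cache line, f32 inputs, that these convolutions take the small-M*N branch of the heuristic, that the reduction tile is counted in units of the MFMA intrinsic's K = 4, and that the K padding is counted in elements. These are interpretations for illustration, not statements from the commit.

// Hedged arithmetic check for the f32 conv tests above (assumptions noted
// in the comments; this is not code from the commit).
#include <cassert>
#include <cstdint>

int main() {
  constexpr int64_t kCacheLineSizeBits = 128 * 8; // assumed
  constexpr int64_t inBitWidth = 32;              // f32 inputs
  constexpr int64_t intrinsicK = 4;               // MFMA_F32_16x16x4_F32

  int64_t cacheLineSizeElements = kCacheLineSizeBits / inBitWidth; // 32
  // The small-M*N branch doubles the K-element seed.
  int64_t oldKElems = cacheLineSizeElements * 2;       // 64 elements
  int64_t newKElems = (cacheLineSizeElements / 2) * 2; // 32 elements (IGEMM)
  assert(oldKElems / intrinsicK == 16); // old CHECKs: reduction 16, padding 64
  assert(newKElems / intrinsicK == 8);  // new CHECKs: reduction 8, padding 32
  return 0;
}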
