Commit ed9a028

GPU Data-tiled multi-mma: subgroup dimensions should be outer (#19521)
This was already the idea, but there was an accidental exception: in the accumulator tensor, if there was both an `unroll_m` dimension and a `subgroup_n` dimension, the `subgroup_n` dimension wasn't on the outside of `unroll_m` as it was meant to be. Noticed this when it required corresponding strides in the ukernel.

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
1 parent 16097c1 commit ed9a028
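
To see the fix concretely, one can apply the updated `linalg.transpose` permutation from the tests below to the expanded accumulator shape. A minimal standalone C++ sketch, not IREE code: the shape and permutations are copied from `gpu_materialize_encoding_gfx942.mlir`, and which dim is `subgroup_n` versus `unroll_m` is my reading of the commit message.

#include <array>
#include <cstdio>

int main() {
  // Expanded accumulator shape from the set_encoding tests:
  // tensor<2x5x8x4x4x4x2x16xf32>.
  std::array<int, 8> shape = {2, 5, 8, 4, 4, 4, 2, 16};
  // New transpose permutation from this commit (output dim i reads input
  // dim perm[i]). The old permutation was {0, 1, 2, 5, 6, 3, 7, 4}.
  std::array<int, 8> perm = {0, 1, 5, 2, 6, 3, 7, 4};
  std::array<int, 8> result;
  for (int i = 0; i < 8; ++i) result[i] = shape[perm[i]];
  // Prints 2x5x4x8x2x4x16x4: the size-4 subgroup_n dim now lands outside
  // the size-8 unroll_m dim, matching the updated CHECK lines. The old
  // permutation produced 2x5x8x4x2x4x16x4, with subgroup_n inside unroll_m.
  for (int i = 0; i < 8; ++i) std::printf(i ? "x%d" : "%d", result[i]);
  std::printf("\n");
  return 0;
}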

File tree

3 files changed: +23 −23

3 files changed

+23
-23
lines changed

compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx942.mlir (+15 −15)

@@ -230,8 +230,8 @@ func.func @set_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 // CHECK-SAME : tensor<2x5x128x128xf32> into tensor<2x5x8x4x4x4x2x16xf32>
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
 // CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x2x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 6, 3, 7, 4]
+// CHECK-SAME: outs({{.*}} : tensor<2x5x4x8x2x4x16x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 5, 2, 6, 3, 7, 4]
 // CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]

 // -----
@@ -255,9 +255,9 @@ func.func @unset_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {

 // CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xf32>)
+// CHECK-SAME: ins(%{{.+}} : tensor<2x5x4x8x2x4x16x4xf32>)
 // CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
+// CHECK-SAME: permutation = [0, 1, 3, 5, 7, 2, 4, 6]
 // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
 // CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xf32> into tensor<2x5x128x128xf32>
 // CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
@@ -298,9 +298,9 @@ func.func @unset_encoding_ACC_dynamic_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 }
 // CHECK-LABEL: func.func @unset_encoding_ACC_dynamic_unroll8x8x4_MFMA_F32_16x16x4_F32
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%{{.+}} : tensor<?x?x8x4x2x4x16x4xf32>)
+// CHECK-SAME: ins(%{{.+}} : tensor<?x?x4x8x2x4x16x4xf32>)
 // CHECK-SAME: outs({{.*}} : tensor<?x?x8x4x4x4x2x16xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
+// CHECK-SAME: permutation = [0, 1, 3, 5, 7, 2, 4, 6]
 // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
 // CHECK-SAME: : tensor<?x?x8x4x4x4x2x16xf32> into tensor<?x?x128x128xf32>
 // CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
@@ -362,7 +362,7 @@ func.func @matmul_lowering_MFMA_F32_16x16x4_F32() {
 // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
 // CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x4xf32>
 // CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x4xf32>
-// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xf32>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x8x2x4x16x4xf32>
 // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
@@ -422,7 +422,7 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x4_F32() {
 // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
 // CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x4xf32>
 // CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x4xf32>
-// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x8x4x2x4x16x4xf32>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x4x8x2x4x16x4xf32>
 // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
@@ -528,8 +528,8 @@ func.func @set_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 // CHECK-SAME : tensor<2x5x128x128xi32> into tensor<2x5x8x4x4x4x2x16xi32>
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
 // CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xi32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x2x4x16x4xi32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 6, 3, 7, 4]
+// CHECK-SAME: outs({{.*}} : tensor<2x5x4x8x2x4x16x4xi32>)
+// CHECK-SAME: permutation = [0, 1, 5, 2, 6, 3, 7, 4]
 // CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]

 // -----
@@ -553,9 +553,9 @@ func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {

 // CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xi32>)
+// CHECK-SAME: ins(%{{.+}} : tensor<2x5x4x8x2x4x16x4xi32>)
 // CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xi32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
+// CHECK-SAME: permutation = [0, 1, 3, 5, 7, 2, 4, 6]
 // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
 // CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xi32> into tensor<2x5x128x128xi32>
 // CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
@@ -618,7 +618,7 @@ func.func @matmul_lowering_MFMA_I32_16x16x32_I8() {
 // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
 // CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x2x8xi8>
 // CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x2x8xi8>
-// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xi32>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x8x2x4x16x4xi32>
 // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
@@ -1124,7 +1124,7 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x32_F8E4M3FNUZ() {
 // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
 // CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x2x8xf8E4M3FNUZ>
 // CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x2x8xf8E4M3FNUZ>
-// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x8x4x2x4x16x4xf32>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x4x8x2x4x16x4xf32>
 // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
@@ -1184,7 +1184,7 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16() {
 // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
 // CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x2x4xbf16>
 // CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x2x4xbf16>
-// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x8x4x2x4x16x4xf32>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x4x8x2x4x16x4xf32>
 // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
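
A related sanity check on the tests above: the `unset_encoding` permutation `[0, 1, 3, 5, 7, 2, 4, 6]` is exactly the inverse of the `set_encoding` permutation `[0, 1, 5, 2, 6, 3, 7, 4]`, so materializing and then un-materializing the encoding still round-trips to the identity. A small standalone sketch of that check (not IREE code):

#include <array>
#include <cassert>

int main() {
  // set_encoding transpose permutation from this commit.
  std::array<int, 8> set = {0, 1, 5, 2, 6, 3, 7, 4};
  // unset_encoding transpose permutation from this commit.
  std::array<int, 8> unset = {0, 1, 3, 5, 7, 2, 4, 6};
  // Composing the two transposes (round-trip output dim i reads original
  // dim set[unset[i]]) must yield the identity, in both orders.
  for (int i = 0; i < 8; ++i) assert(set[unset[i]] == i);
  for (int i = 0; i < 8; ++i) assert(unset[set[i]] == i);
  return 0;
}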

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp (+3 −3)

@@ -183,12 +183,12 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
   if (mma.getUnrollN() > 1) {
     expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollN()});
   }
-  if (mma.getSubgroupsN() > 1) {
-    expand(swizzle, 1, {Kind::CrossThread, mma.getSubgroupsN()});
-  }
   if (mma.getUnrollM() > 1) {
     expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
   }
+  if (mma.getSubgroupsN() > 1) {
+    expand(swizzle, 1, {Kind::CrossThread, mma.getSubgroupsN()});
+  }
   if (mma.getSubgroupsM() > 1) {
     expand(swizzle, 0, {Kind::CrossThread, mma.getSubgroupsM()});
   }
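
The invariant this reordering restores is the one in the commit title: among the data-tiled dims layered on top of the intrinsic tile, subgroup (cross-thread) dims sit outside unroll (cross-intrinsic) dims. A toy model of that check, with dim kinds and sizes read off the accumulator test shapes above; this models only the leading data-tiled dims, not the intrinsic tile, and is not the real `TileSwizzle` API:

#include <cassert>
#include <vector>

// Toy model of the data-tiled dims stacked outside the intrinsic tile.
// Kind names mirror the ones used in the diff above.
enum class Kind { CrossThread, CrossIntrinsic };
struct Dim { Kind kind; int size; };

// True iff no subgroup (CrossThread) dim appears inside an unroll
// (CrossIntrinsic) dim: subgroup dims must be outermost.
bool subgroupDimsAreOuter(const std::vector<Dim> &dims) {
  bool seenUnroll = false;
  for (const Dim &d : dims) {
    if (d.kind == Kind::CrossIntrinsic) seenUnroll = true;
    else if (seenUnroll) return false;  // subgroup dim inside an unroll dim
  }
  return true;
}

int main() {
  // Accumulator data-tiled dims before the fix: unroll_m(8), subgroup_n(4),
  // unroll_n(2) -- the 8x4x2x... prefix in the old test shapes.
  std::vector<Dim> before = {{Kind::CrossIntrinsic, 8},
                             {Kind::CrossThread, 4},
                             {Kind::CrossIntrinsic, 2}};
  // After the fix: subgroup_n(4), unroll_m(8), unroll_n(2) -- the 4x8x2x...
  // prefix in the new test shapes.
  std::vector<Dim> after = {{Kind::CrossThread, 4},
                            {Kind::CrossIntrinsic, 8},
                            {Kind::CrossIntrinsic, 2}};
  assert(!subgroupDimsAreOuter(before));
  assert(subgroupDimsAreOuter(after));
  return 0;
}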

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir (+5 −5)

@@ -755,11 +755,11 @@ hal.executable public @main {
 // CHECK: gpu.barrier
 // CHECK-DAG: %[[A_READ:.+]] = vector.transfer_read %[[A_ALLOC]]{{.*}} vector<8x1x1x4xf32>
 // CHECK-DAG: %[[B_READ:.+]] = vector.transfer_read %[[B_ALLOC]]{{.*}} vector<2x1x1x4xf32>
-// CHECK-DAG: %[[C_READ:.+]] = vector.transfer_read %[[BINDING_C]]{{.*}} vector<8x1x2x1x1x4xf32>
-// CHECK-DAG: %[[C_00_0:.+]] = vector.extract %[[C_READ]][0, 0, 0, 0, 0] : vector<4xf32> from vector<8x1x2x1x1x4xf32>
-// CHECK-DAG: %[[C_01_0:.+]] = vector.extract %[[C_READ]][0, 0, 1, 0, 0] : vector<4xf32> from vector<8x1x2x1x1x4xf32>
-// CHECK-DAG: %[[C_70_0:.+]] = vector.extract %[[C_READ]][7, 0, 0, 0, 0] : vector<4xf32> from vector<8x1x2x1x1x4xf32>
-// CHECK-DAG: %[[C_71_0:.+]] = vector.extract %[[C_READ]][7, 0, 1, 0, 0] : vector<4xf32> from vector<8x1x2x1x1x4xf32>
+// CHECK-DAG: %[[C_READ:.+]] = vector.transfer_read %[[BINDING_C]]{{.*}} vector<8x2x1x1x4xf32>
+// CHECK-DAG: %[[C_00_0:.+]] = vector.extract %[[C_READ]][0, 0, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
+// CHECK-DAG: %[[C_01_0:.+]] = vector.extract %[[C_READ]][0, 1, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
+// CHECK-DAG: %[[C_70_0:.+]] = vector.extract %[[C_READ]][7, 0, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
+// CHECK-DAG: %[[C_71_0:.+]] = vector.extract %[[C_READ]][7, 1, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
 // CHECK-DAG: %[[A_EXTRACT00:.+]] = vector.extract %[[A_READ]][0, 0, 0, 0] : f32 from vector<8x1x1x4xf32>
 // CHECK-DAG: %[[A_EXTRACT01:.+]] = vector.extract %[[A_READ]][0, 0, 0, 1] : f32 from vector<8x1x1x4xf32>
 // CHECK-DAG: %[[A_EXTRACT02:.+]] = vector.extract %[[A_READ]][0, 0, 0, 2] : f32 from vector<8x1x1x4xf32>
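
In this test the per-thread accumulator only drops a unit dimension (rank 6 to rank 5); each `vector.extract` still addresses the same slot in the row-major accumulator. A quick standalone check by hand-linearizing the old and new indices (not IREE code; the "same slot" reading is my interpretation of the diff):

#include <array>
#include <cassert>
#include <cstddef>

// Linear slot index of a vector<4xf32> accumulator inside a row-major
// shape, with the innermost vector<4> dim peeled off.
template <std::size_t N>
int linearize(const std::array<int, N> &shape, const std::array<int, N> &idx) {
  int linear = 0;
  for (std::size_t i = 0; i < N; ++i) linear = linear * shape[i] + idx[i];
  return linear;
}

int main() {
  // Old per-thread accumulator vector<8x1x2x1x1x4xf32>, extracted at
  // [i, 0, j, 0, 0]; new vector<8x2x1x1x4xf32>, extracted at [i, j, 0, 0].
  std::array<int, 5> oldShape = {8, 1, 2, 1, 1};
  std::array<int, 4> newShape = {8, 2, 1, 1};
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 2; ++j) {
      // Both indexings reach slot i * 2 + j; only the unit dim moved.
      assert(linearize(oldShape, {i, 0, j, 0, 0}) ==
             linearize(newShape, {i, j, 0, 0}));
    }
  }
  return 0;
}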
