
Commit 2111358

[GPU] Enable tile and fuse matmul by default
Signed-off-by: Nirvedh <nirvedh@gmail.com>
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
1 parent 6fdb2c5 commit 2111358

11 files changed: +88 -59

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

+19 -9

@@ -164,13 +164,9 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
       problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
       transposedLhs, transposedRhs, /*canUpcastAcc=*/false,
       /*mustBeAligned*/ mustBeAligned, doCPromotion);
-  if (!schedule) {
-    // Then try again by allowing upcasting accumulator.
-    schedule = deduceMMASchedule(
-        problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
-        transposedLhs, transposedRhs, /*canUpcastAcc=*/true,
-        /*mustBeAligned*/ mustBeAligned, doCPromotion);
-  }
+  // TODO (nirvedhmeshram) : Add support for upcasting accumulator schedule.
+  // Currently we dont have this for TileAndFuse path, see
+  // https://github.com/iree-org/iree/issues/19532
   return schedule;
 }

@@ -392,9 +388,16 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target,
   std::array<int64_t, 3> workgroupSize = {configAndWgSize->second, 1, 1};
   LoweringConfigAttr loweringConfig = configAndWgSize->first;

+  bool usePrefetchSharedMemory = true;
+  // Prefetching has issues when doing c promotion, see
+  // https://github.com/iree-org/iree/issues/19612.
+  if (llvm::any_of(getPromotedOperandList(loweringConfig).value(),
+                   [](int64_t promote) { return promote == 2; })) {
+    usePrefetchSharedMemory = false;
+  }
   SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get(
-      linalgOp->getContext(), /*prefetchSharedMemory=*/true,
+      linalgOp->getContext(), /*prefetchSharedMemory=*/usePrefetchSharedMemory,
       /*no_reduce_shared_memory_bank_conflicts=*/false,
       /*use_igemm_convolution=*/true,
       /*reorder_workgroups_strategy=*/std::nullopt);

@@ -435,9 +438,16 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   std::array<int64_t, 3> workgroupSize = {configAndWgSize->second, 1, 1};
   LoweringConfigAttr loweringConfig = configAndWgSize->first;

+  bool usePrefetchSharedMemory = true;
+  // Prefetching has issues when doing c promotion, see
+  // https://github.com/iree-org/iree/issues/19612.
+  if (llvm::any_of(getPromotedOperandList(loweringConfig).value(),
+                   [](int64_t promote) { return promote == 2; })) {
+    usePrefetchSharedMemory = false;
+  }
   SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get(
-      linalgOp->getContext(), /*prefetchSharedMemory=*/true,
+      linalgOp->getContext(), /*prefetchSharedMemory=*/usePrefetchSharedMemory,
       /*no_reduce_shared_memory_bank_conflicts=*/false,
       /*use_igemm_convolution=*/false,
       /*reorder_workgroups_strategy=*/std::nullopt);
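For context: the new guard reads the promoted-operand list off the lowering config and keeps shared-memory prefetching on only when the accumulator is not promoted. A minimal standalone sketch of the same predicate, in plain C++ rather than the IREE/LLVM APIs (the helper name and the std::vector parameter are illustrative, not from this commit):

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical mirror of the check above. Operand indices follow the
// matmul convention: 0 = LHS, 1 = RHS, 2 = accumulator (C). Prefetching
// is disabled whenever C is promoted (see iree-org/iree issue #19612).
static bool shouldPrefetchSharedMemory(
    const std::vector<int64_t> &promotedOperands) {
  bool cIsPromoted =
      std::any_of(promotedOperands.begin(), promotedOperands.end(),
                  [](int64_t promote) { return promote == 2; });
  return !cIsPromoted;
}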

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

+4 -4

@@ -48,10 +48,10 @@
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 namespace mlir::iree_compiler {

-llvm::cl::opt<bool> clGPUTestTileAndFuseMatmul(
-    "iree-codegen-llvmgpu-test-tile-and-fuse-matmul",
+llvm::cl::opt<bool> clGPUEnableTileAndFuseMatmul(
+    "iree-codegen-llvmgpu-enable-tile-and-fuse-matmul",
     llvm::cl::desc("test the the tile and fuse pipeline for matmul"),
-    llvm::cl::init(false));
+    llvm::cl::init(true));

 llvm::cl::opt<bool> clGPUTestTileAndFuseVectorize(
     "iree-codegen-llvmgpu-test-tile-and-fuse-vectorize",

@@ -2352,7 +2352,7 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
     LDBG("Tile and fuse data tiled multi_mma config");
     return success();
   }
-  if (clGPUTestTileAndFuseMatmul) {
+  if (clGPUEnableTileAndFuseMatmul) {
     if (succeeded(IREE::GPU::setMatmulLoweringConfig(target, entryPointFn,
                                                      computeOp))) {
       LDBG("Tile and fuse matmul config");
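With the default flipped to true, the TileAndFuse matmul path is used unless callers opt out; the test updates below pass the flag explicitly to keep exercising the legacy VectorDistribute/WMMA/mma.sync paths. A representative opt-out invocation, mirroring the RUN lines in the updated tests (the input file name is a placeholder):

iree-opt --iree-gpu-test-target=gfx942 \
  --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \
  --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" input.mlir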

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir

+2 -2

@@ -76,7 +76,7 @@ func.func @nhwc_conv_unaligned_mfma() {

 // CHECK-LABEL: func.func @nhwc_conv_unaligned_mfma
 // CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
-// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
+// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false
 // CHECK-SAME: use_igemm_convolution = true

 // CHECK: linalg.conv_2d_nhwc_hwcf {{.*}}lowering_config = #iree_gpu.lowering_config

@@ -106,7 +106,7 @@ func.func @nchw_conv_unaligned_mfma() {

 // CHECK-LABEL: func.func @nchw_conv_unaligned_mfma
 // CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
-// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
+// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false
 // CHECK-SAME: use_igemm_convolution = true

 // CHECK: linalg.conv_2d_nchw_fchw {{.*}}lowering_config = #iree_gpu.lowering_config

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir

+42 -36

@@ -1,5 +1,5 @@
 // RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx942 \
-// RUN: --iree-codegen-llvmgpu-test-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
+// RUN: --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
 // RUN: --iree-codegen-llvmgpu-use-igemm=false \
 // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s

@@ -10,21 +10,23 @@
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
-func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<10x64x2048xf16>) -> tensor<2x10x64x64xf16> {
+func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<10x64x2048xf16>) -> tensor<2x10x64x64xf32> {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f16
-  %5 = tensor.empty() : tensor<2x10x64x64xf16>
-  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
+  %cst = arith.constant 0.000000e+00 : f32
+  %5 = tensor.empty() : tensor<2x10x64x64xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x10x64x64xf32>) -> tensor<2x10x64x64xf32>
   %7 = linalg.generic {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
-    ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) {
-  ^bb0(%in: f16, %in_0: f16, %out: f16):
-    %8 = arith.mulf %in, %in_0 : f16
-    %9 = arith.addf %8, %out : f16
-    linalg.yield %9 : f16
-  } -> tensor<2x10x64x64xf16>
-  return %7 : tensor<2x10x64x64xf16>
+    ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %8 = arith.extf %in : f16 to f32
+    %9 = arith.extf %in_0 : f16 to f32
+    %10 = arith.mulf %8, %9 : f32
+    %11 = arith.addf %10, %out : f32
+    linalg.yield %11 : f32
+  } -> tensor<2x10x64x64xf32>
+  return %7 : tensor<2x10x64x64xf32>
 }

 // CHECK-LABEL: func.func @expanded_matmul_transpose_b

@@ -46,21 +48,23 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor
 #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d5)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf16> {
+func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf32> {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f16
-  %5 = tensor.empty() : tensor<10x4x32x32xf16>
-  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<10x4x32x32xf16>) -> tensor<10x4x32x32xf16>
+  %cst = arith.constant 0.000000e+00 : f32
+  %5 = tensor.empty() : tensor<10x4x32x32xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<10x4x32x32xf32>) -> tensor<10x4x32x32xf32>
   %7 = linalg.generic {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-    ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf16>) {
-  ^bb0(%in: f16, %in_0: f16, %out: f16):
-    %8 = arith.mulf %in, %in_0 : f16
-    %9 = arith.addf %8, %out : f16
-    linalg.yield %9 : f16
-  } -> tensor<10x4x32x32xf16>
-  return %7 : tensor<10x4x32x32xf16>
+    ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %8 = arith.extf %in : f16 to f32
+    %9 = arith.extf %in_0 : f16 to f32
+    %10 = arith.mulf %8, %9 : f32
+    %11 = arith.addf %10, %out : f32
+    linalg.yield %11 : f32
+  } -> tensor<10x4x32x32xf32>
+  return %7 : tensor<10x4x32x32xf32>
 }

 // CHECK-LABEL: func.func @multi_dim_mma_schedule

@@ -79,23 +83,25 @@ func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4
 #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d5, d6)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d5, d6)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor<?x6x16x?x16xf16>, %rhs: tensor<?x32x?x16xf16>) -> tensor<?x6x?x16x32xf16> {
+func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor<?x6x16x?x16xf16>, %rhs: tensor<?x32x?x16xf16>) -> tensor<?x6x?x16x32xf32> {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f16
+  %cst = arith.constant 0.000000e+00 : f32
   %d0 = tensor.dim %lhs, %c0 : tensor<?x6x16x?x16xf16>
   %d2 = tensor.dim %rhs, %c0 : tensor<?x32x?x16xf16>
-  %5 = tensor.empty(%d0, %d2) : tensor<?x6x?x16x32xf16>
-  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<?x6x?x16x32xf16>) -> tensor<?x6x?x16x32xf16>
+  %5 = tensor.empty(%d0, %d2) : tensor<?x6x?x16x32xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<?x6x?x16x32xf32>) -> tensor<?x6x?x16x32xf32>
   %7 = linalg.generic {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-    ins(%lhs, %rhs : tensor<?x6x16x?x16xf16>, tensor<?x32x?x16xf16>) outs(%6 : tensor<?x6x?x16x32xf16>) {
-  ^bb0(%in: f16, %in_0: f16, %out: f16):
-    %8 = arith.mulf %in, %in_0 : f16
-    %9 = arith.addf %8, %out : f16
-    linalg.yield %9 : f16
-  } -> tensor<?x6x?x16x32xf16>
-  return %7 : tensor<?x6x?x16x32xf16>
+    ins(%lhs, %rhs : tensor<?x6x16x?x16xf16>, tensor<?x32x?x16xf16>) outs(%6 : tensor<?x6x?x16x32xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %8 = arith.extf %in : f16 to f32
+    %9 = arith.extf %in_0 : f16 to f32
+    %10 = arith.mulf %8, %9 : f32
+    %11 = arith.addf %10, %out : f32
+    linalg.yield %11 : f32
+  } -> tensor<?x6x?x16x32xf32>
+  return %7 : tensor<?x6x?x16x32xf32>
 }

 // CHECK-LABEL: func.func @dynamic_multi_dim_mma_schedule

@@ -271,7 +277,7 @@ func.func @unaligned_to_intrinsic_batched_matmul(%lhs : tensor<12x577x577xf32>,

 // CHECK-LABEL: func.func @unaligned_to_intrinsic_batched_matmul
 // CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
-// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
+// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
 // CHECK: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: padding = [1, 16, 16, 4]
 // CHECK-SAME: promote_operands = [0, 1, 2]

@@ -300,7 +306,7 @@ func.func @unaligned_to_intrinsic_batched_matmul_tiling_check(%lhs : tensor<12x5

 // CHECK-LABEL: func.func @unaligned_to_intrinsic_batched_matmul_tiling_check
 // CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
-// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
+// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
 // CHECK: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: padding = [1, 16, 512, 4]
 // CHECK-SAME: promote_operands = [0, 1, 2]

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir

+1

@@ -1,4 +1,5 @@
 // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 --iree-codegen-llvmgpu-use-vector-distribution \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \
 // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=WMMA

 // TODO: This test is still using the legacy LLVMGPU kernel config. This needs

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir

+1

@@ -1,5 +1,6 @@
 // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --iree-codegen-llvmgpu-use-vector-distribution \
 // RUN: --iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution --iree-codegen-llvmgpu-use-igemm=false \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \
 // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s

 // TODO: This test is still using the legacy LLVMGPU kernel config. This needs

compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir

+2 -2

@@ -33,14 +33,14 @@ func.func @custom_op(%arg0 : tensor<384x512xf32>, %arg1 : tensor<512x128xf32>,
   return %1 : tensor<384x128xf32>
 }
 // CHECK: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0]]>
-// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64,
+// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64,
 // CHECK: func @custom_op
 // CHECK-SAME: translation_info = #[[TRANSLATION]]
 // CHECK: iree_linalg_ext.custom_op
 // CHECK-SAME: lowering_config = #[[CONFIG]]
 // CHECK: ^bb
 // CHECK: linalg.matmul
-// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, promote_operands = [0, 1], reduction = [0, 0, 32], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 64, 0]}>
+// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, promote_operands = [0, 1], reduction = [0, 0, 8], subgroup = [2, 2, 0], workgroup = [64, 64, 0]}>
 // CHECK: iree_linalg_ext.yield

 // -----

compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir

+2 -2

@@ -1,7 +1,7 @@
 // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \
-// RUN: --iree-gpu-test-target=sm_60 %s | FileCheck %s
+// RUN: --iree-gpu-test-target=sm_60 --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false %s | FileCheck %s
 // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \
-// RUN: --iree-gpu-test-target=sm_80 %s | FileCheck %s --check-prefix=SM80
+// RUN: --iree-gpu-test-target=sm_80 --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false %s | FileCheck %s --check-prefix=SM80

 // Transform dialect attributes are tested separately.

compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir

+3 -1

@@ -1,4 +1,6 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))' --iree-gpu-test-target=sm_80 -split-input-file %s -o - | FileCheck %s
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))' \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false --iree-gpu-test-target=sm_80 -split-input-file %s -o - | FileCheck %s

 // This test checks that the lowering of nvvm includes the extraction
 // and optimization of address computations.

compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir

+4 -1

@@ -1,4 +1,7 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-mma-sync %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-mma-sync %s | FileCheck %s

 // Verify that a simple element wise op gets lowered succefully all the way to
 // nvvm/llvm dialect via mma.sync path.

compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir

+8 -2

@@ -1,5 +1,11 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_60 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s --check-prefix=SM80
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_60 \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s --check-prefix=SM80

 // Verify that a simple element wise op gets lowered succefully all the way to
 // nvvm/llvm dialect.
