
Commit 345b1da

Revert "[LLVMGPU] Deprecate the matmul simt pipeline (#19335)" (#19508)
This reverts commit 6ff00a8, which causes the Llama3.1 8B fp16 model to generate NaN logits during prefill/decode. Issue: #19506

Signed-off-by: archana-ramalingam <archana.ramalingam@amd.com>
1 parent: 8ae1b54 · commit: 345b1da

17 files changed: +259 -108 lines

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_reorder_workgroups_static.mlir (+1 -1)

@@ -25,7 +25,7 @@
 ]>
 hal.executable private @main_dispatch_0 {
   hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
-    hal.executable.export public @main_dispatch_0_matmul_transpose_b_32000x32000x4096_f16 ordinal(0) layout(#pipeline_layout) attributes {subgroup_size = 64 : index, translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>, workgroup_size = [64 : index, 16 : index, 1 : index]} {
+    hal.executable.export public @main_dispatch_0_matmul_transpose_b_32000x32000x4096_f16 ordinal(0) layout(#pipeline_layout) attributes {subgroup_size = 64 : index, translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>, workgroup_size = [64 : index, 16 : index, 1 : index]} {
     ^bb0(%arg0: !hal.device):
       %c250 = arith.constant 250 : index
       %c500 = arith.constant 500 : index

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td (+12 -10)

@@ -40,24 +40,26 @@ def LLVMGPU_SimpleDistribute
     : I32EnumAttrCase<"LLVMGPUDistribute", 102>;
 def LLVMGPU_Vectorize
     : I32EnumAttrCase<"LLVMGPUVectorize", 103>;
+def LLVMGPU_MatmulSimt
+    : I32EnumAttrCase<"LLVMGPUMatmulSimt", 104>;
 def LLVMGPU_MatmulTensorCore
-    : I32EnumAttrCase<"LLVMGPUMatmulTensorCore", 104>;
+    : I32EnumAttrCase<"LLVMGPUMatmulTensorCore", 105>;
 def LLVMGPU_TransposeSharedMem
-    : I32EnumAttrCase<"LLVMGPUTransposeSharedMem", 105>;
+    : I32EnumAttrCase<"LLVMGPUTransposeSharedMem", 106>;
 def LLVMGPU_WarpReduction
-    : I32EnumAttrCase<"LLVMGPUWarpReduction", 106>;
+    : I32EnumAttrCase<"LLVMGPUWarpReduction", 107>;
 def LLVMGPU_PackUnPack
-    : I32EnumAttrCase<"LLVMGPUPackUnPack", 107>;
+    : I32EnumAttrCase<"LLVMGPUPackUnPack", 108>;
 def LLVMGPU_MatmulTensorCoreMmaSync
-    : I32EnumAttrCase<"LLVMGPUMatmulTensorCoreMmaSync", 108>;
+    : I32EnumAttrCase<"LLVMGPUMatmulTensorCoreMmaSync", 109>;
 def LLVMGPU_VectorDistribute
-    : I32EnumAttrCase<"LLVMGPUVectorDistribute", 109>;
+    : I32EnumAttrCase<"LLVMGPUVectorDistribute", 110>;
 def LLVMGPU_PadAndVectorDistribute
-    : I32EnumAttrCase<"LLVMGPUPadAndVectorDistribute", 110>;
+    : I32EnumAttrCase<"LLVMGPUPadAndVectorDistribute", 111>;
 def LLVMGPU_WinogradVectorize
-    : I32EnumAttrCase<"LLVMGPUWinogradVectorize", 111>;
+    : I32EnumAttrCase<"LLVMGPUWinogradVectorize", 112>;
 def LLVMGPU_TileAndFuse
-    : I32EnumAttrCase<"LLVMGPUTileAndFuse", 112>;
+    : I32EnumAttrCase<"LLVMGPUTileAndFuse", 113>;

 def SPIRV_BaseLowering
     : I32EnumAttrCase<"SPIRVBaseLowering", 200>;
@@ -96,7 +98,7 @@ def DispatchLoweringPassPipelineEnum : I32EnumAttr<

     // LLVMGPU CodeGen pipelines
     LLVMGPU_Default, LLVMGPU_BaseLowering, LLVMGPU_SimpleDistribute,
-    LLVMGPU_Vectorize, LLVMGPU_MatmulTensorCore,
+    LLVMGPU_Vectorize, LLVMGPU_MatmulSimt, LLVMGPU_MatmulTensorCore,
     LLVMGPU_TransposeSharedMem, LLVMGPU_WarpReduction, LLVMGPU_PackUnPack,
     LLVMGPU_MatmulTensorCoreMmaSync, LLVMGPU_VectorDistribute,
     LLVMGPU_PadAndVectorDistribute, LLVMGPU_WinogradVectorize,
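
Net effect of the TableGen hunk above: LLVMGPUMatmulSimt is reinstated at value 104 and every later LLVMGPU pipeline case shifts up by one. A standalone C++ sketch of the resulting numbering, illustrative only (the real enum is generated from IREECodegenAttrs.td):

// Illustration of the pipeline enum numbering after the revert; the values are
// copied from the diff above, the type itself is a stand-in.
enum class DispatchLoweringPassPipelineSketch : int {
  LLVMGPUDistribute = 102,
  LLVMGPUVectorize = 103,
  LLVMGPUMatmulSimt = 104,            // reinstated by this revert
  LLVMGPUMatmulTensorCore = 105,      // was 104 before the revert
  LLVMGPUTransposeSharedMem = 106,
  LLVMGPUWarpReduction = 107,
  LLVMGPUPackUnPack = 108,
  LLVMGPUMatmulTensorCoreMmaSync = 109,
  LLVMGPUVectorDistribute = 110,
  LLVMGPUPadAndVectorDistribute = 111,
  LLVMGPUWinogradVectorize = 112,
  LLVMGPUTileAndFuse = 113,
};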

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp (+7 -63)

@@ -1295,11 +1295,9 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
                                        CodeGenPipeline pipeline) {
   TileSizesListType tileSizes;
   unsigned numParallelLoops = op.getNumParallelLoops();
-  unsigned numReductionLoops = op.getNumReductionLoops();
-  SmallVector<int64_t> workgroupTileSizes(
-      numParallelLoops + numReductionLoops, 1);
-  workgroupTileSizes[numParallelLoops - 2] = tileX;
-  workgroupTileSizes[numParallelLoops - 1] = tileY;
+  SmallVector<int64_t> workgroupTileSizes(numParallelLoops - 2, 1);
+  workgroupTileSizes.append({tileX, tileY});
+  workgroupTileSizes.append(op.getNumReductionLoops(), tileK);

   SmallVector<unsigned> partitionedLoops =
       cast<PartitionableLoopsInterface>(op.getOperation())
@@ -1313,65 +1311,11 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
     }
   }

+  tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level.
   std::optional<int64_t> subgroupSize = std::nullopt;
   if (!subgroupSizes.empty())
     subgroupSize = subgroupSizes.front();

-  // For the LLVMGPUTileAndFuse pipeline, we need to split tile sizes
-  // for workgroup, thread, and reduction.
-  if (pipeline == CodeGenPipeline::LLVMGPUTileAndFuse) {
-
-    auto context = op.getContext();
-    Builder b(context);
-    SmallVector<NamedAttribute, 1> attrs;
-
-    SmallVector<int64_t> threadTileSizes(numParallelLoops + numReductionLoops,
-                                         0);
-    std::fill(threadTileSizes.begin(),
-              threadTileSizes.begin() + numParallelLoops, 1);
-
-    threadTileSizes[numParallelLoops - 2] =
-        (tileX / workgroupSize[0]) < 1 ? 1 : (tileX / workgroupSize[0]);
-    threadTileSizes[numParallelLoops - 1] =
-        (tileY / workgroupSize[1]) < 1 ? 1 : (tileY / workgroupSize[1]);
-
-    SmallVector<int64_t> reductionTileSizes(
-        numParallelLoops + numReductionLoops, 0);
-    reductionTileSizes[numParallelLoops + numReductionLoops - 1] = tileK;
-
-    attrs.emplace_back(b.getStringAttr("workgroup"),
-                       b.getI64ArrayAttr(workgroupTileSizes));
-    attrs.emplace_back(b.getStringAttr("thread"),
-                       b.getI64ArrayAttr(threadTileSizes));
-    attrs.emplace_back(b.getStringAttr("reduction"),
-                       b.getI64ArrayAttr(reductionTileSizes));
-
-    // Promote operands to use shared memory for LHS and RHS.
-    IREE::GPU::setPromotedOperandList(context, attrs, {0, 1});
-    auto configDict = b.getDictionaryAttr(attrs);
-    auto loweringConfig =
-        IREE::GPU::LoweringConfigAttr::get(context, configDict);
-    SmallVector<NamedAttribute, 1> pipelineAttrs;
-    auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get(
-        context, /*prefetchSharedMemory=*/false,
-        /*no_reduce_shared_memory_bank_conflicts=*/true,
-        /*use_igemm_convolution=*/false,
-        /*reorder_workgroups_strategy=*/std::nullopt);
-    pipelineAttrs.emplace_back(
-        b.getStringAttr(IREE::GPU::GPUPipelineOptionsAttr::getDictKeyName()),
-        pipelineOptions);
-    auto pipelineConfig = b.getDictionaryAttr(pipelineAttrs);
-
-    return setOpConfigAndEntryPointFnTranslation(
-        entryPoint, op, loweringConfig, pipeline, workgroupSize, subgroupSize,
-        pipelineConfig);
-  }
-
-  // Other pipeline (MatmulTensorCore) expect the reduction tile size to be in
-  // the same list.
-  workgroupTileSizes[numParallelLoops + numReductionLoops - 1] = tileK;
-  tileSizes.emplace_back(std::move(workgroupTileSizes));
-
   return setOpConfigAndEntryPointFnTranslation(
       entryPoint, op, tileSizes, pipeline, workgroupSize, subgroupSize,
       getSoftwarePipeliningAttrDict(op->getContext(), softwarePipelineDepth,
@@ -1446,7 +1390,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
     return setMatmulConfig(
         sizeN, sizeM, 4, {sizeM, sizeN, 1},
         target.getWgp().getSubgroupSizeChoices().asArrayRef(),
-        softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUTileAndFuse);
+        softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUMatmulSimt);
   }

   // SIMT matmul case. Query the best configuration.
@@ -1460,7 +1404,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
           config.tileSize[0], config.tileSize[1], config.tileSize[2],
           config.workgroupSize,
           target.getWgp().getSubgroupSizeChoices().asArrayRef(),
-          softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUTileAndFuse);
+          softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUMatmulSimt);
     }
   }
 }
@@ -1485,7 +1429,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
   return setMatmulConfig(tileX, tileY, tileK, workgroupSize,
                          target.getWgp().getSubgroupSizeChoices().asArrayRef(),
                          softwarePipelineDepthSimt,
-                         CodeGenPipeline::LLVMGPUTileAndFuse);
+                         CodeGenPipeline::LLVMGPUMatmulSimt);
 }

 //====---------------------------------------------------------------------===//
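
With the KernelConfig.cpp change, the SIMT path hands the lowering a single workgroup-level tile-size list of the form [1, ..., 1, tileX, tileY, tileK]: ones for the outer parallel dims, then the M/N workgroup tiles, then the reduction tile, instead of the separate workgroup/thread/reduction attributes the deprecating commit built for LLVMGPUTileAndFuse. A minimal self-contained C++ sketch of that assembly, using std::vector in place of llvm::SmallVector; the values in main are made up for illustration:

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the restored workgroup tile-size assembly from the diff above.
std::vector<int64_t> buildSimtWorkgroupTileSizes(unsigned numParallelLoops,
                                                 unsigned numReductionLoops,
                                                 int64_t tileX, int64_t tileY,
                                                 int64_t tileK) {
  std::vector<int64_t> sizes(numParallelLoops - 2, 1); // outer parallel dims
  sizes.push_back(tileX);                              // M tile
  sizes.push_back(tileY);                              // N tile
  sizes.insert(sizes.end(), numReductionLoops, tileK); // reduction tile(s)
  return sizes;
}

int main() {
  // e.g. a batch matmul with parallel loops (B, M, N) and one reduction loop (K).
  for (int64_t s : buildSimtWorkgroupTileSizes(3, 1, 32, 128, 16))
    std::cout << s << ' '; // prints: 1 32 128 16
  std::cout << '\n';
}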

compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp (+3)

@@ -114,6 +114,9 @@ void LLVMGPULowerExecutableTargetPass::runOnOperation() {
   case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUWinogradVectorize:
     addGPUWinogradVectorizePassPipeline(pipeline);
     break;
+  case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt:
+    addGPUMatmulSimtPassPipeline(pipeline, pipelineOptions);
+    break;
   case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulTensorCore: {
     FailureOr<int64_t> maybeDepth =
         getSoftwarePipelineDepth(translationInfo.getConfiguration());

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp (+66)

@@ -526,6 +526,72 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
   funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());
 }

+//===---------------------------------------------------------------------===//
+// MatmulSIMT
+//===---------------------------------------------------------------------===//
+
+void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager,
+                                  const GPUPipelineOptions &options) {
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
+
+  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
+  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
+  funcPassManager.addPass(createCSEPass());
+
+  funcPassManager.addPass(createGPUTensorTileToSerialLoopsPass());
+  funcPassManager.addPass(createGPUTensorAlloc());
+  funcPassManager.addPass(createGPUTensorTilePass());
+
+  // Linalg -> vector
+  addGPUVectorizationPasses(funcPassManager);
+
+  // tensor to memref
+  addBufferizePasses(funcPassManager);
+
+  // distribute foreach threads
+  funcPassManager.addPass(createGPUDistributePass());
+
+  funcPassManager.addPass(createMemrefCopyToLinalgPass());
+  funcPassManager.addPass(createGPUDistributeSharedMemoryCopyPass());
+  funcPassManager.addPass(createCanonicalizerPass());
+  funcPassManager.addPass(createCSEPass());
+
+  if (options.enableReduceSharedMemoryBankConflicts) {
+    funcPassManager.addPass(createGPUReduceBankConflictsPass());
+  }
+
+  ReorderWorkgroupsStrategy reorderStrategy =
+      getReorderWorkgroupsStrategy(options.reorderStrategy);
+  funcPassManager.addPass(
+      createReorderWorkgroups(reorderStrategy, canReorderWorkgroups));
+
+  funcPassManager.addPass(createCanonicalizerPass());
+  funcPassManager.addPass(createCSEPass());
+
+  funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
+  funcPassManager.addPass(createCSEPass());
+  funcPassManager.addPass(createCanonicalizerPass());
+  funcPassManager.addPass(createCSEPass());
+
+  // Even though we vectorize before bufferization we are not able to hoist
+  // accumulator load/store out of the K loop until distribution. This is
+  // because we materialize the fill and the matmul in two different scf.forall
+  // regions, when they should be in the same scf.forall. Newer pipelines
+  // like TileAndFuse don't have this problem, because they coalesce these
+  // scf.forall regions into a single scf.forall.
+  //
+  // Therefore we still rely on buffer level transformations for transfer ops
+  // hoisting and store to load forwarding. This relies on shacky alias
+  // analysis and we need to move this to tensor level once we have better
+  // abstractions.
+  funcPassManager.addPass(createOptimizeVectorTransferPass());
+
+  // Hoist loop invariant code to avoid pipelining it.
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
+  // Pipeline memory operations.
+  funcPassManager.addPass(createGPUPipeliningPass());
+}
+
 //===---------------------------------------------------------------------===//
 // Matmul Tensor Core
 //===---------------------------------------------------------------------===//
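
Read top to bottom, the restored pipeline is roughly: distribute to workgroups, tile reductions to serial loops, allocate and tile shared-memory tensors, vectorize, bufferize, distribute to threads, distribute shared-memory copies, optionally pad shared memory against bank conflicts, optionally reorder workgroups, then clean up, optimize vector transfers, hoist loop-invariant code, and software-pipeline memory operations. A toy standalone C++ sketch of how the two option-gated stages are selected; the field names follow the diff, everything else is illustrative and not the IREE API:

#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Toy stand-in for the GPUPipelineOptions fields referenced in the diff above.
struct PipelineOptionsSketch {
  bool enableReduceSharedMemoryBankConflicts = true;
  std::optional<std::string> reorderStrategy; // e.g. "Transpose" (hypothetical value)
};

std::vector<std::string> simtStages(const PipelineOptionsSketch &opts) {
  std::vector<std::string> stages = {
      "tile-and-distribute-to-workgroups", "tile-reductions-to-serial-loops",
      "alloc-and-tile-shared-memory",      "vectorize",
      "bufferize",                         "distribute-to-threads",
      "distribute-shared-memory-copies"};
  if (opts.enableReduceSharedMemoryBankConflicts)
    stages.push_back("reduce-shared-memory-bank-conflicts"); // optional padding
  if (opts.reorderStrategy)
    stages.push_back("reorder-workgroups:" + *opts.reorderStrategy);
  stages.insert(stages.end(), {"optimize-vector-transfers",
                               "loop-invariant-code-motion",
                               "software-pipelining"});
  return stages;
}

int main() {
  for (const std::string &s : simtStages(PipelineOptionsSketch{}))
    std::cout << s << '\n';
}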

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h (+4)

@@ -28,6 +28,10 @@ using IREE::GPU::GPUPipelineOptions;
 // LLVMGPU Backend Pass Pipelines
 //----------------------------------------------------------------------------//

+/// Lowering using SIMT CUDA core operations.
+void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager,
+                                  const GPUPipelineOptions &options);
+
 /// Lowering using mma.sync Tensor Core operations.
 void addGPUMatmulTensorCoreMmaSyncPassPipeline(
     OpPassManager &funcPassManager, const GPUPipelineOptions &options,

compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp (+10 -1)

@@ -38,6 +38,10 @@ getInstructionShape(Operation *op, CodeGenPipeline pipeline,
                     Type inputElementType,
                     SmallVector<int64_t> &instructionShape) {
   switch (pipeline) {
+  case CodeGenPipeline::LLVMGPUMatmulSimt:
+    // SIMT Pipeline / CUDA Cores
+    instructionShape = {1, 1, 1};
+    break;
   case CodeGenPipeline::LLVMGPUMatmulTensorCore:
     // Tensor Core Pipeline / WMMA API
     if (inputElementType.isF16() || inputElementType.isBF16()) {
@@ -77,7 +81,8 @@ verifyGPUMatmulPipeline(Operation *op,
                         ArrayRef<int64_t> workgroupSize) {
   // This verifier only applies to matmul.
   CodeGenPipeline pipeline = translationInfo.getDispatchLoweringPassPipeline();
-  if (pipeline != CodeGenPipeline::LLVMGPUMatmulTensorCore &&
+  if (pipeline != CodeGenPipeline::LLVMGPUMatmulSimt &&
+      pipeline != CodeGenPipeline::LLVMGPUMatmulTensorCore &&
       pipeline != CodeGenPipeline::LLVMGPUMatmulTensorCoreMmaSync) {
     return success();
   }
@@ -175,6 +180,10 @@ verifyGPUMatmulPipeline(Operation *op,
            << pipelineName;
   }

+  // Return success for SIMT/CUDA cores.
+  if (pipeline == CodeGenPipeline::LLVMGPUMatmulSimt)
+    return success();
+
   //
   // Additional verification Tensor Core pipelines.
   //
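
The Verifiers.cpp change means SIMT matmuls are no longer skipped by the matmul verifier: they are assigned a trivial 1x1x1 instruction shape and return success before the Tensor Core specific checks. A simplified self-contained C++ sketch of that control flow, with stand-in types (the real code uses CodeGenPipeline, LogicalResult, and SmallVector; the Tensor Core shape below is only a placeholder):

#include <array>
#include <cstdint>
#include <iostream>

// Stand-ins for the IREE/MLIR types used by verifyGPUMatmulPipeline.
enum class PipelineSketch { MatmulSimt, MatmulTensorCore, MatmulTensorCoreMmaSync, Other };
using InstructionShape = std::array<int64_t, 3>; // {M, N, K} handled per instruction

bool verifyMatmulPipelineSketch(PipelineSketch pipeline) {
  // The verifier only applies to the matmul pipelines; everything else passes.
  if (pipeline != PipelineSketch::MatmulSimt &&
      pipeline != PipelineSketch::MatmulTensorCore &&
      pipeline != PipelineSketch::MatmulTensorCoreMmaSync)
    return true;

  // SIMT / CUDA cores do a scalar multiply-accumulate per lane, so the
  // per-instruction shape is 1x1x1.
  InstructionShape shape = {1, 1, 1};
  if (pipeline != PipelineSketch::MatmulSimt)
    shape = InstructionShape{16, 16, 16}; // placeholder for the WMMA/mma.sync shape

  // ... checks shared by all matmul pipelines would run here ...

  // SIMT returns early; only the Tensor Core pipelines get the extra checks
  // that compare tile sizes against the instruction shape.
  if (pipeline == PipelineSketch::MatmulSimt)
    return true;

  (void)shape; // Tensor Core specific verification elided in this sketch.
  return true;
}

int main() {
  std::cout << verifyMatmulPipelineSketch(PipelineSketch::MatmulSimt) << '\n'; // 1
}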

compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir (+3 -2)

@@ -267,11 +267,12 @@ func.func @not_vmt() {
   return
 }

-// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = true, use_igemm_convolution = false>}>
+// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 128, 8]{{\]}}>
+// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt workgroup_size = [32, 1, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 // CHECK: func.func @not_vmt()
 // CHECK-SAME: translation_info = #[[$TRANSLATION]]
 // CHECK: linalg.generic
-// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], reduction = [0, 0, 8], thread = [1, 128, 0], workgroup = [1, 128, 1]}>
+// CHECK-SAME: lowering_config = #[[$CONFIG]]

 // -----

compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_root_op_attribute.mlir (+1 -1)

@@ -9,4 +9,4 @@ func.func @matmul(%lhs: tensor<4x4xf32>, %rhs: tensor<4x4xf32>) -> tensor<4x4xf32>
   return %result : tensor<4x4xf32>
 }

-// CHECK: %2 = linalg.matmul {lowering_config = #{{.*}}, root_op} ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %2 = linalg.matmul {lowering_config = #config, root_op} ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>

compiler/src/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir (+4 -4)

@@ -9,7 +9,7 @@
 #map = affine_map<()[s0] -> (s0 * 2)>
 #map1 = affine_map<()[s0] -> (s0 * 256)>
 #map2 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
-#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
+#translation = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt workgroup_size = [64, 1, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 func.func @dot_dispatch_0() attributes {translation_info = #translation} {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
@@ -79,7 +79,7 @@ func.func @dot_dispatch_0() attributes {translation_info = #translation} {
 #map2 = affine_map<(d0, d1, d2)[s0] -> (d0 * 32768 + s0 + d1 * 1024 + d2)>
 #map3 = affine_map<(d0, d1, d2)[s0] -> (d0 * 65536 + s0 + d1 * 64 + d2)>
 #map4 = affine_map<(d0, d1, d2)[s0] -> (d0 * 2048 + s0 + d1 * 64 + d2)>
-#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [8, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
+#translation = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt workgroup_size = [8, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 func.func @batch_matmul_func() attributes {translation_info = #translation} {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
@@ -148,7 +148,7 @@ func.func @batch_matmul_func() attributes {translation_info = #translation} {
 #map = affine_map<()[s0] -> (s0 * 2)>
 #map1 = affine_map<()[s0] -> (s0 * 32)>
 #map2 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
-#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
+#translation = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt workgroup_size = [64, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 func.func @dot_dispatch_0() attributes {translation_info = #translation} {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
@@ -312,7 +312,7 @@ module {
   #hal.pipeline.binding<storage_buffer>
 ]>
 #config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 2, 256, 4]]>
-#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
+#translation = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt workgroup_size = [64, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 #map = affine_map<()[s0] -> (s0 * 2)>
 #map1 = affine_map<()[s0] -> (s0 * 256)>
 #map2 = affine_map<(d0)[s0] -> (-d0 + s0, 2)>