Skip to content

Commit de7e313

Browse files
committed
Make our trivial loop detector understand delinearize
1 parent c0b4d05 commit de7e313

File tree

3 files changed

+132
-8
lines changed

3 files changed

+132
-8
lines changed

compiler/src/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir

+70
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,73 @@ hal.executable private @simple_mul {
240240
// CHECK-LABEL: func.func @simple_mul
241241
// CHECK: scf.for
242242
// CHECK: scf.for
243+
244+
// -----
245+
246+
// Test case: bounds must be propagated through affine.delinearize_index so
// that both loops below are recognized as executing exactly once.
#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
247+
#hal.pipeline.binding<storage_buffer>,
248+
#hal.pipeline.binding<storage_buffer>
249+
]>
250+
#translation_info = #iree_codegen.translation_info<pipeline = None workgroup_size = [64, 1, 1]>
251+
// CHECK-LABEL: func.func @dispatch_0()
252+
hal.executable private @dispatch_0 {
253+
hal.executable.variant @cuda target(#hal.executable.target<"cuda", "cuda-nvptx-fb">) {
254+
hal.executable.export @dispatch_0 layout(#pipeline_layout) {
255+
^bb0(%arg0: !hal.device) :
256+
%c1 = arith.constant 1 : index
257+
hal.return %c1, %c1, %c1 : index, index, index
258+
}
259+
builtin.module {
260+
func.func @dispatch_0() attributes {translation_info = #translation_info} {
261+
%c256 = arith.constant 256 : index
262+
%tidx = gpu.thread_id x
// Split the thread id over a (2, 32) basis: %idsX#0 is at most 1 and
// %idsX#1 is at most 31 when %tidx is in bounds. NOTE(review): presumably
// the pass bounds %tidx by the 64-wide workgroup size from
// #translation_info -- confirm against the pass implementation.
263+
%idsX:2 = affine.delinearize_index %tidx into (2, 32) : index, index
264+
// CHECK-NOT: scf.for
265+
// CHECK: gpu.barrier
// Lower bound is at most 1 * 128 = 128 < 256 and the step equals the upper
// bound, so the body runs once and only the barrier is expected to remain.
266+
%0 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%idsX#0]
267+
scf.for %arg4 = %0 to %c256 step %c256 {
268+
gpu.barrier
269+
}
270+
// CHECK-NOT: scf.for
271+
// CHECK: gpu.barrier
// Lower bound is at most 31 * 8 = 248 < 256, so this loop is single-trip too.
272+
%1 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%idsX#1]
273+
scf.for %arg4 = %1 to %c256 step %c256 {
274+
gpu.barrier
275+
}
276+
return
277+
}
278+
}
279+
}
280+
}
281+
282+
// -----
283+
284+
// Test case: the range attached by util.assume.int must be enough to prove
// that the loop below executes exactly once.
#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
285+
#hal.pipeline.binding<storage_buffer>,
286+
#hal.pipeline.binding<storage_buffer>
287+
]>
288+
#translation_info = #iree_codegen.translation_info<pipeline = None workgroup_size = [64, 1, 1]>
289+
// CHECK-LABEL: func.func @dispatch_0()
290+
hal.executable private @dispatch_0 {
291+
hal.executable.variant @cuda target(#hal.executable.target<"cuda", "cuda-nvptx-fb">) {
292+
hal.executable.export @dispatch_0 layout(#pipeline_layout) {
293+
^bb0(%arg0: !hal.device) :
294+
%c1 = arith.constant 1 : index
295+
hal.return %c1, %c1, %c1 : index, index, index
296+
}
297+
builtin.module {
298+
func.func @dispatch_0() attributes {translation_info = #translation_info} {
299+
%c256 = arith.constant 256 : index
// The loop lower bound comes from a push constant whose value is unknown at
// compile time; only the assumed range below constrains it.
300+
%0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
301+
%1 = arith.index_cast %0 : i32 to index
302+
%2 = util.assume.int %1[<umin=0, umax=255>] : index
303+
// CHECK-NOT: scf.for
304+
// CHECK: gpu.barrier
// %2 is assumed to lie in [0, 255], i.e. strictly below the upper bound 256,
// and the step is 256, so the body runs once and only the barrier remains.
305+
scf.for %arg4 = %2 to %c256 step %c256 {
306+
gpu.barrier
307+
}
308+
return
309+
}
310+
}
311+
}
312+
}

compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir

+6-6
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,10 @@ func.func @matmul_256x1024x128_div_add() attributes {translation_info = #transla
109109
// CHECK: gpu.barrier
110110
// CHECK: scf.for %[[IV_Y:.+]] = %[[OFFSET_Y]] to %[[C32]] step %[[C32]]
111111
// CHECK: %[[LHS_VIEW:.+]] = memref.subview %[[LHS_ALLOC]][%[[IV_Y]], 0]
112-
// CHECK: scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]]
113-
// CHECK: %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][0, %[[IV_X]]]
114112
// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]]]
115113
// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C16]]]
114+
// CHECK: scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]]
115+
// CHECK: %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][0, %[[IV_X]]]
116116
// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]]]
117117
// CHECK-DAG: %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C16]], %[[C0]]]
118118
// CHECK-DAG: %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]]]
@@ -246,10 +246,10 @@ func.func @matmul_256x1024x128_div_add() attributes {translation_info = #transla
246246
// CHECK: scf.for %[[IV_Z:.+]] = %[[ID_Z]] to %[[C1]] step %[[C1]]
247247
// CHECK: scf.for %[[IV_Y:.+]] = %[[OFFSET_Y]] to %[[C32]] step %[[C32]]
248248
// CHECK: %[[LHS_VIEW:.+]] = memref.subview %[[LHS_ALLOC]][%[[IV_Z]], %[[IV_Y]], 0] [1, 16, 32]
249-
// CHECK: scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]] {
250-
// CHECK: %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][%[[IV_Z]], 0, %[[IV_X]]] [1, 32, 16]
251249
// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]]
252250
// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C16]]]
251+
// CHECK: scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]] {
252+
// CHECK: %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][%[[IV_Z]], 0, %[[IV_X]]] [1, 32, 16]
253253
// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]]
254254
// CHECK-DAG: %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C16]], %[[C0]]]
255255
// CHECK-DAG: %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]], %[[C0]]]
@@ -369,10 +369,10 @@ func.func @matmul_256x1024x128_mixed_signedness_int8() {
369369
// CHECK: gpu.barrier
370370
// CHECK: scf.for %[[IV_Y:.+]] = %[[OFFSET_Y]] to %[[C32]] step %[[C32]]
371371
// CHECK: %[[LHS_VIEW:.+]] = memref.subview %[[LHS_ALLOC]][%[[IV_Y]], 0]
372-
// CHECK: scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]]
373-
// CHECK: %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][0, %[[IV_X]]]
374372
// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]]]
375373
// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C16]]]
374+
// CHECK: scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]]
375+
// CHECK: %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][0, %[[IV_X]]]
376376
// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]]]
377377
// CHECK-DAG: %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C16]], %[[C0]]]
378378
// CHECK-DAG: %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]]]

compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp

+56-2
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,67 @@
1717
#include "mlir/Dialect/Affine/Utils.h"
1818
#include "mlir/IR/BuiltinOps.h"
1919
#include "mlir/IR/PatternMatch.h"
20+
#include "mlir/IR/Value.h"
2021

2122
#define DEBUG_TYPE "iree-codegen-remove-single-iteration"
2223

2324
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
2425

2526
namespace mlir::iree_compiler {
2627

28+
/// Traverse affine.delinearize_index and affine.linearize_index and util
29+
/// assumption ops to get bounds. In the long run, this should either be added
30+
/// as a composition utility to affine and/or as calls to
31+
/// IntRangeInferenceInterface.
32+
static std::optional<std::pair<AffineExpr, AffineExpr>>
33+
getMinMaxExprWrapper(Value dim, SmallVectorImpl<Value> &dims,
34+
SmallVectorImpl<Value> &syms,
35+
GetMinMaxExprFn getMinMaxExpr) {
36+
if (auto delinOp = dim.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
37+
if (!delinOp.getDynamicBasis().empty()) {
38+
LLVM_DEBUG(
39+
DBGS()
40+
<< "not handling delinearize with dynamic dimensions for now\n");
41+
return std::nullopt;
42+
}
43+
Value linearIdx = delinOp.getLinearIndex();
44+
ArrayRef<int64_t> basis = delinOp.getStaticBasis();
45+
unsigned resultNum = cast<OpResult>(dim).getResultNumber();
46+
auto linearMinMax =
47+
getMinMaxExprWrapper(linearIdx, dims, syms, getMinMaxExpr);
48+
if (resultNum == 0 && !delinOp.hasOuterBound()) {
49+
if (!linearMinMax.has_value())
50+
return std::nullopt;
51+
auto [min, max] = *linearMinMax;
52+
int64_t divisor = ShapedType::getNumElements(basis);
53+
return std::make_pair(min.floorDiv(divisor), max.floorDiv(divisor));
54+
}
55+
unsigned basisArg = resultNum - (delinOp.hasOuterBound() ? 0 : 1);
56+
int64_t modulus = basis[basisArg];
57+
int64_t divisor = ShapedType::getNumElements(basis.drop_front(basisArg));
58+
if (linearMinMax.has_value()) {
59+
auto [min, max] = *linearMinMax;
60+
return std::make_pair(min.floorDiv(divisor) % modulus,
61+
max.floorDiv(divisor) % modulus);
62+
}
63+
if (resultNum > 0)
64+
return std::make_pair(
65+
getAffineConstantExpr(0, dim.getContext()),
66+
getAffineConstantExpr(modulus - 1, dim.getContext()));
67+
return std::nullopt;
68+
}
69+
70+
if (auto assumeOp = dim.getDefiningOp<IREE::Util::AssumeIntOp>()) {
71+
auto [min, max] =
72+
assumeOp.getUnionedUnsignedRange(cast<OpResult>(dim).getResultNumber());
73+
if (!min || !max)
74+
return std::nullopt;
75+
return std::make_pair(getAffineConstantExpr(*min, dim.getContext()),
76+
getAffineConstantExpr(*max, dim.getContext()));
77+
}
78+
return getMinMaxExpr(dim, dims, syms);
79+
}
80+
2781
/// Compose map with apply affine ops and try to simplify it.
2882
static void combineAndSimplifyMap(AffineMap &map, SmallVectorImpl<Value> &dims,
2983
SmallVectorImpl<Value> &symbols) {
@@ -52,7 +106,7 @@ static AffineMap substituteMin(AffineMap map, SmallVectorImpl<Value> &dims,
52106
substituted = false;
53107
for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
54108
Value dim = dims[dimIdx];
55-
auto minMax = getMinMaxExpr(dim, dims, symbols);
109+
auto minMax = getMinMaxExprWrapper(dim, dims, symbols, getMinMaxExpr);
56110
if (!minMax)
57111
continue;
58112
AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
@@ -70,7 +124,7 @@ static AffineMap substituteMin(AffineMap map, SmallVectorImpl<Value> &dims,
70124
// Substitute symbols
71125
for (unsigned symIdx = 0; symIdx < symbols.size(); ++symIdx) {
72126
Value sym = symbols[symIdx];
73-
auto minMax = getMinMaxExpr(sym, dims, symbols);
127+
auto minMax = getMinMaxExprWrapper(sym, dims, symbols, getMinMaxExpr);
74128
if (!minMax)
75129
continue;
76130
AffineExpr symExpr = getAffineSymbolExpr(symIdx, expr.getContext());

0 commit comments

Comments
 (0)