Commit bd955e2

Narrow for loops too
1 parent 1068d83

3 files changed: +111 -11

compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp (+19 -6)

@@ -506,17 +506,30 @@ void moveLoopInvariantCodeFromGuaranteedLoops(Operation *target) {
         loopLike.getLoopLowerBounds();
     std::optional<SmallVector<OpFoldResult>> maybeUpperBounds =
         loopLike.getLoopUpperBounds();
-    if (!maybeLowerBounds || !maybeUpperBounds) {
+    std::optional<SmallVector<Value>> maybeIvs =
+        loopLike.getLoopInductionVars();
+    if (!maybeLowerBounds || !maybeUpperBounds || !maybeIvs) {
       return;
     }
 
     // If any lower + upper bound pair cannot be definitely verified as lb < ub
     // then the loop may have a zero trip count.
-    for (auto [lb, ub] :
-         llvm::zip_equal(*maybeLowerBounds, *maybeUpperBounds)) {
-      if (!ValueBoundsConstraintSet::compare(lb, ValueBoundsConstraintSet::LT,
-                                             ub)) {
-        return;
+    for (auto [lb, ub, iv] :
+         llvm::zip_equal(*maybeLowerBounds, *maybeUpperBounds, *maybeIvs)) {
+      if (iv.getType().isIndex()) {
+        if (!ValueBoundsConstraintSet::compare(lb, ValueBoundsConstraintSet::LT,
+                                               ub)) {
+          return;
+        }
+      } else {
+        // Weaker test for non-`index` operands to some loops like scf.for,
+        // since the value bounds interface requires index types.
+        auto maybeLb = getConstantIntValue(lb);
+        auto maybeUb = getConstantIntValue(ub);
+        if (!maybeLb || !maybeUb)
+          return;
+        if (*maybeLb >= *maybeUb)
+          return;
       }
     }

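To illustrate the new fallback branch: a minimal hypothetical loop (the function and operand names are invented here, not taken from the commit) whose i64 induction variable keeps ValueBoundsConstraintSet from comparing the bounds, but whose constant bounds still prove 0 < 4, so the loop is guaranteed to execute and the invariant load may be hoisted:

  util.func @hoist_from_i64_loop(%src: memref<f32>, %dst: memref<4xf32>) {
    %c0 = arith.constant 0 : i64
    %c4 = arith.constant 4 : i64
    %c1 = arith.constant 1 : i64
    scf.for %i = %c0 to %c4 step %c1 : i64 {
      // Loop-invariant load: safe to hoist once the trip count is known >= 1.
      %v = memref.load %src[] : memref<f32>
      %idx = arith.index_castui %i : i64 to index
      memref.store %v, %dst[%idx] : memref<4xf32>
    }
    util.return
  }
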
compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp (+63 -1)

@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -254,6 +255,66 @@ struct RemoveIndexCastForAssumeOfI32
   DataFlowSolver &solver;
 };
 
+//===----------------------------------------------------------------------===//
+// scf.for induction variable range narrowing
+// If the induction variable of an scf.for can be represented as an I32,
+// make that change to save on registers etc.
+//===----------------------------------------------------------------------===//
+struct NarrowSCFForIvToI32 : public OpRewritePattern<scf::ForOp> {
+  NarrowSCFForIvToI32(MLIRContext *context, DataFlowSolver &solver)
+      : OpRewritePattern(context), solver(solver) {}
+
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = forOp.getLoc();
+    Value iv = forOp.getInductionVar();
+    Type srcType = iv.getType();
+    if (!srcType.isIndex() && !srcType.isInteger(64))
+      return rewriter.notifyMatchFailure(forOp, "IV isn't an index or i64");
+    if (!staticallyLegalToConvertToUnsigned(solver, iv))
+      return rewriter.notifyMatchFailure(forOp, "IV isn't non-negative");
+    if (!staticallyLegalToConvertToUnsigned(solver, forOp.getStep()))
+      return rewriter.notifyMatchFailure(forOp, "Step isn't non-negative");
+    auto *ivState = solver.lookupState<IntegerValueRangeLattice>(iv);
+    if (ivState->getValue().getValue().smax().getActiveBits() > 31)
+      return rewriter.notifyMatchFailure(forOp, "IV won't fit in signed int32");
+
+    Type i32 = rewriter.getI32Type();
+    auto doCastDown = [&](Value v) -> Value {
+      if (srcType.isIndex())
+        return rewriter.create<arith::IndexCastUIOp>(loc, i32, v);
+      else
+        return rewriter.create<arith::TruncIOp>(loc, i32, v);
+    };
+    Value newLb = doCastDown(forOp.getLowerBound());
+    Value newUb = doCastDown(forOp.getUpperBound());
+    Value newStep = doCastDown(forOp.getStep());
+    {
+      PatternRewriter::InsertionGuard g(rewriter);
+      rewriter.setInsertionPointToStart(&forOp.getRegion().front());
+      Value castBackOp;
+      if (srcType.isIndex())
+        castBackOp =
+            rewriter.create<arith::IndexCastUIOp>(iv.getLoc(), srcType, iv);
+      else
+        castBackOp = rewriter.create<arith::ExtUIOp>(iv.getLoc(), srcType, iv);
+      (void)solver.getOrCreateState<IntegerValueRangeLattice>(castBackOp)
+          ->join(*ivState);
+      rewriter.replaceAllUsesExcept(iv, castBackOp, castBackOp.getDefiningOp());
+    }
+    solver.eraseState(iv);
+    rewriter.modifyOpInPlace(forOp, [&]() {
+      iv.setType(i32);
+      forOp.getLowerBoundMutable().assign(newLb);
+      forOp.getUpperBoundMutable().assign(newUb);
+      forOp.getStepMutable().assign(newStep);
+    });
+    return success();
+  }
+
+  DataFlowSolver &solver;
+};
+
 //===----------------------------------------------------------------------===//
 // Divisibility
 //===----------------------------------------------------------------------===//
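In effect, the pattern swaps the loop control over to i32 and reintroduces the original type at the top of the body, so existing uses stay valid and the surrounding narrowing patterns can later fold the casts into their users. A minimal before/after sketch, assuming index-typed control values that the solver has proven non-negative and below 2^31 (%lb, %ub, %step, and the "test.use" consumer are invented for illustration):

  // Before: `index`-typed induction variable.
  scf.for %i = %lb to %ub step %step {
    "test.use"(%i) : (index) -> ()
  }

  // After NarrowSCFForIvToI32: i32 loop control, IV cast back on entry.
  %lb_i32 = arith.index_castui %lb : index to i32
  %ub_i32 = arith.index_castui %ub : index to i32
  %step_i32 = arith.index_castui %step : index to i32
  scf.for %i = %lb_i32 to %ub_i32 step %step_i32 : i32 {
    %i_idx = arith.index_castui %i : i32 to index
    "test.use"(%i_idx) : (index) -> ()
  }

The join/eraseState bookkeeping carries the solver's range facts over to the cast-back value, so narrowing can continue through the loop body without re-running the analysis.
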
@@ -396,7 +457,8 @@ class OptimizeIntArithmeticPass
 
     if (narrowToI32) {
       arith::populateIntRangeNarrowingPatterns(patterns, solver, {32});
-      patterns.add<RemoveIndexCastForAssumeOfI32>(ctx, solver);
+      patterns.add<NarrowSCFForIvToI32, RemoveIndexCastForAssumeOfI32>(ctx,
+                                                                       solver);
     }
 
     // Populate canonicalization patterns.

compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic_narrowing.mlir (+29 -4)

@@ -13,8 +13,8 @@
 // CHECK-DAG: %[[TID_I32:.+]] = arith.index_castui %[[THREAD_ID_X]] : index to i32
 // CHECK: %[[V0:.+]] = arith.divui %[[TID_I32]], %[[C16]] : i32
 // CHECK-NEXT: %[[V1:.+]] = arith.remui %[[TID_I32]], %[[C16]] : i32
-// CHECK-NEXT: %[[V2:.+]] = arith.muli %[[V0]], %[[V32]] : i32
-// CHECK-NEXT; %[[V3:.+]] = arith.addi %[[V2]], %[[V1]] : i32
+// CHECK-NEXT: %[[V2:.+]] = arith.muli %[[V0]], %[[C32]] : i32
+// CHECK-NEXT: %[[V3:.+]] = arith.addi %[[V2]], %[[V1]] : i32
 // CHECK-NEXT: %[[RET:.+]] = arith.index_castui %[[V3]] : i32 to index
 // CHECK: return %[[RET]]
 util.func @narrow_tid_computations() -> index {
@@ -32,12 +32,37 @@ util.func @narrow_tid_computations() -> index {
 
 // CHECK-LABEL: @narrow_assumes
 // CHECK-SAME: (%[[ARG0:.+]]: i32)
-// CHECK-NEXT: %[[ASSUME:.+]] = util.assume.int %[[ARG0]][<umin = 16, umax = 122, udiv = 16>] : i32
+// CHECK-NEXT: %[[ASSUME:.+]] = util.assume.int %[[ARG0]]<umin = 16, umax = 122, udiv = 16> : i32
 // CHECK-NEXT: %[[AS_INDEX:.+]] = arith.index_castui %[[ASSUME]] : i32 to index
 // CHECK-NEXT: util.return %[[ASSUME]], %[[AS_INDEX]]
 util.func @narrow_assumes(%arg0: i32) -> (i32, index) {
   %0 = arith.index_castui %arg0 : i32 to index
-  %1 = util.assume.int %0[<umin = 16, umax = 122, udiv = 16>] : index
+  %1 = util.assume.int %0<umin = 16, umax = 122, udiv = 16> : index
   %2 = arith.index_castui %1 : index to i32
   util.return %2, %1 : i32, index
 }
+
+// -----
+
+// CHECK-LABEL: @narrow_scf_for
+// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : i32
+// CHECK-DAG: %[[C96:.+]] = arith.constant 96 : i32
+// CHECK-DAG: %[[C512:.+]] = arith.constant 512 : i32
+// CHECK-DAG: %[[TID:.+]] = gpu.thread_id x upper_bound 64
+// CHECK-DAG: %[[TID_I32:.+]] = arith.index_castui %[[TID]] : index to i32
+// CHECK: scf.for %[[ARG1:.+]] = %[[TID_I32]] to %[[C96]] step %[[C64]]
+// CHECK-NEXT: %[[V0:.+]] = arith.addi %[[ARG1]], %[[C512]]
+// CHECK-NEXT: %[[V0_IDX:.+]] = arith.index_castui %[[V0]] : i32 to index
+// CHECK-NEXT: memref.store {{.*}}[%[[V0_IDX]]]
+util.func @narrow_scf_for(%arg0: memref<?xf32>) {
+  %c0_f32 = arith.constant 0.0 : f32
+  %c64 = arith.constant 64 : index
+  %c96 = arith.constant 96 : index
+  %c512 = arith.constant 512 : index
+  %tid = gpu.thread_id x upper_bound 64
+  scf.for %arg1 = %tid to %c96 step %c64 {
+    %0 = arith.addi %arg1, %c512 : index
+    memref.store %c0_f32, %arg0[%0] : memref<?xf32>
+  }
+  util.return
+}
