|
18 | 18 | #include "mlir/Dialect/Affine/Utils.h"
|
19 | 19 | #include "mlir/Dialect/Arith/IR/Arith.h"
|
20 | 20 | #include "mlir/Dialect/Arith/Transforms/Passes.h"
|
| 21 | +#include "mlir/Dialect/SCF/IR/SCF.h" |
21 | 22 | #include "mlir/IR/Matchers.h"
|
22 | 23 | #include "mlir/IR/PatternMatch.h"
|
23 | 24 | #include "mlir/Pass/Pass.h"
|
@@ -254,6 +255,66 @@ struct RemoveIndexCastForAssumeOfI32
|
254 | 255 | DataFlowSolver &solver;
|
255 | 256 | };
|
256 | 257 |
|
| 258 | +//===----------------------------------------------------------------------===// |
| 259 | +// scf.for induction variable range narrowing |
| 260 | +// If the induction variable of an scf.for can be represented as an I32, |
| 261 | +// make that change to save on registers etc. |
| 262 | +//===----------------------------------------------------------------------===// |
| 263 | +struct NarrowSCFForIvToI32 : public OpRewritePattern<scf::ForOp> { |
| 264 | + NarrowSCFForIvToI32(MLIRContext *context, DataFlowSolver &solver) |
| 265 | + : OpRewritePattern(context), solver(solver) {} |
| 266 | + |
| 267 | + LogicalResult matchAndRewrite(scf::ForOp forOp, |
| 268 | + PatternRewriter &rewriter) const override { |
| 269 | + Location loc = forOp.getLoc(); |
| 270 | + Value iv = forOp.getInductionVar(); |
| 271 | + Type srcType = iv.getType(); |
| 272 | + if (!srcType.isIndex() && !srcType.isInteger(64)) |
| 273 | + return rewriter.notifyMatchFailure(forOp, "IV isn't an index or i64"); |
| 274 | + if (!staticallyLegalToConvertToUnsigned(solver, iv)) |
| 275 | + return rewriter.notifyMatchFailure(forOp, "IV isn't non-negative"); |
| 276 | + if (!staticallyLegalToConvertToUnsigned(solver, forOp.getStep())) |
| 277 | + return rewriter.notifyMatchFailure(forOp, "Step isn't non-negative"); |
| 278 | + auto *ivState = solver.lookupState<IntegerValueRangeLattice>(iv); |
| 279 | + if (ivState->getValue().getValue().smax().getActiveBits() > 31) |
| 280 | + return rewriter.notifyMatchFailure(forOp, "IV won't fit in signed int32"); |
| 281 | + |
| 282 | + Type i32 = rewriter.getI32Type(); |
| 283 | + auto doCastDown = [&](Value v) -> Value { |
| 284 | + if (srcType.isIndex()) |
| 285 | + return rewriter.create<arith::IndexCastUIOp>(loc, i32, v); |
| 286 | + else |
| 287 | + return rewriter.create<arith::TruncIOp>(loc, i32, v); |
| 288 | + }; |
| 289 | + Value newLb = doCastDown(forOp.getLowerBound()); |
| 290 | + Value newUb = doCastDown(forOp.getUpperBound()); |
| 291 | + Value newStep = doCastDown(forOp.getStep()); |
| 292 | + { |
| 293 | + PatternRewriter::InsertionGuard g(rewriter); |
| 294 | + rewriter.setInsertionPointToStart(&forOp.getRegion().front()); |
| 295 | + Value castBackOp; |
| 296 | + if (srcType.isIndex()) |
| 297 | + castBackOp = |
| 298 | + rewriter.create<arith::IndexCastUIOp>(iv.getLoc(), srcType, iv); |
| 299 | + else |
| 300 | + castBackOp = rewriter.create<arith::ExtUIOp>(iv.getLoc(), srcType, iv); |
| 301 | + (void)solver.getOrCreateState<IntegerValueRangeLattice>(castBackOp) |
| 302 | + ->join(*ivState); |
| 303 | + rewriter.replaceAllUsesExcept(iv, castBackOp, castBackOp.getDefiningOp()); |
| 304 | + } |
| 305 | + solver.eraseState(iv); |
| 306 | + rewriter.modifyOpInPlace(forOp, [&]() { |
| 307 | + iv.setType(i32); |
| 308 | + forOp.getLowerBoundMutable().assign(newLb); |
| 309 | + forOp.getUpperBoundMutable().assign(newUb); |
| 310 | + forOp.getStepMutable().assign(newStep); |
| 311 | + }); |
| 312 | + return success(); |
| 313 | + } |
| 314 | + |
| 315 | + DataFlowSolver &solver; |
| 316 | +}; |
| 317 | + |
257 | 318 | //===----------------------------------------------------------------------===//
|
258 | 319 | // Divisibility
|
259 | 320 | //===----------------------------------------------------------------------===//
|
@@ -396,7 +457,8 @@ class OptimizeIntArithmeticPass
|
396 | 457 |
|
397 | 458 | if (narrowToI32) {
|
398 | 459 | arith::populateIntRangeNarrowingPatterns(patterns, solver, {32});
|
399 |
| - patterns.add<RemoveIndexCastForAssumeOfI32>(ctx, solver); |
| 460 | + patterns.add<NarrowSCFForIvToI32, RemoveIndexCastForAssumeOfI32>(ctx, |
| 461 | + solver); |
400 | 462 | }
|
401 | 463 |
|
402 | 464 | // Populate canonicalization patterns.
|
|
0 commit comments