17
17
#include " mlir/Dialect/Affine/Utils.h"
18
18
#include " mlir/IR/BuiltinOps.h"
19
19
#include " mlir/IR/PatternMatch.h"
20
+ #include " mlir/IR/Value.h"
20
21
21
22
#define DEBUG_TYPE " iree-codegen-remove-single-iteration"
22
23
23
24
#define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE << " ]: " )
24
25
25
26
namespace mlir ::iree_compiler {
26
27
28
+ // / Traverse affine.delinearize_index and affine.linearize_index and util
29
+ // / assumption ops to get bounds. In the long run, this should either be added
30
+ // / as a composition utility to affine and/or as calls to
31
+ // / IntRangeInferenceInterface.
32
+ static std::optional<std::pair<AffineExpr, AffineExpr>>
33
+ getMinMaxExprWrapper (Value dim, SmallVectorImpl<Value> &dims,
34
+ SmallVectorImpl<Value> &syms,
35
+ GetMinMaxExprFn getMinMaxExpr) {
36
+ if (auto delinOp = dim.getDefiningOp <affine::AffineDelinearizeIndexOp>()) {
37
+ if (!delinOp.getDynamicBasis ().empty ()) {
38
+ LLVM_DEBUG (
39
+ DBGS ()
40
+ << " not handling delinearize with dynamic dimensions for now\n " );
41
+ return std::nullopt;
42
+ }
43
+ Value linearIdx = delinOp.getLinearIndex ();
44
+ ArrayRef<int64_t > basis = delinOp.getStaticBasis ();
45
+ unsigned resultNum = cast<OpResult>(dim).getResultNumber ();
46
+ auto linearMinMax =
47
+ getMinMaxExprWrapper (linearIdx, dims, syms, getMinMaxExpr);
48
+ if (resultNum == 0 && !delinOp.hasOuterBound ()) {
49
+ if (!linearMinMax.has_value ())
50
+ return std::nullopt;
51
+ auto [min, max] = *linearMinMax;
52
+ int64_t divisor = ShapedType::getNumElements (basis);
53
+ return std::make_pair (min.floorDiv (divisor), max.floorDiv (divisor));
54
+ }
55
+ unsigned basisArg = resultNum - (delinOp.hasOuterBound () ? 0 : 1 );
56
+ int64_t modulus = basis[basisArg];
57
+ int64_t divisor = ShapedType::getNumElements (basis.drop_front (basisArg));
58
+ if (linearMinMax.has_value ()) {
59
+ auto [min, max] = *linearMinMax;
60
+ return std::make_pair (min.floorDiv (divisor) % modulus,
61
+ max.floorDiv (divisor) % modulus);
62
+ }
63
+ if (resultNum > 0 )
64
+ return std::make_pair (
65
+ getAffineConstantExpr (0 , dim.getContext ()),
66
+ getAffineConstantExpr (modulus - 1 , dim.getContext ()));
67
+ return std::nullopt;
68
+ }
69
+
70
+ if (auto assumeOp = dim.getDefiningOp <IREE::Util::AssumeIntOp>()) {
71
+ auto [min, max] =
72
+ assumeOp.getUnionedUnsignedRange (cast<OpResult>(dim).getResultNumber ());
73
+ if (!min || !max)
74
+ return std::nullopt;
75
+ return std::make_pair (getAffineConstantExpr (*min, dim.getContext ()),
76
+ getAffineConstantExpr (*max, dim.getContext ()));
77
+ }
78
+ return getMinMaxExpr (dim, dims, syms);
79
+ }
80
+
27
81
// / Compose map with apply affine ops and try to simplify it.
28
82
static void combineAndSimplifyMap (AffineMap &map, SmallVectorImpl<Value> &dims,
29
83
SmallVectorImpl<Value> &symbols) {
@@ -52,7 +106,7 @@ static AffineMap substituteMin(AffineMap map, SmallVectorImpl<Value> &dims,
52
106
substituted = false ;
53
107
for (unsigned dimIdx = 0 ; dimIdx < dims.size (); ++dimIdx) {
54
108
Value dim = dims[dimIdx];
55
- auto minMax = getMinMaxExpr (dim, dims, symbols);
109
+ auto minMax = getMinMaxExprWrapper (dim, dims, symbols, getMinMaxExpr );
56
110
if (!minMax)
57
111
continue ;
58
112
AffineExpr dimExpr = getAffineDimExpr (dimIdx, expr.getContext ());
@@ -70,7 +124,7 @@ static AffineMap substituteMin(AffineMap map, SmallVectorImpl<Value> &dims,
70
124
// Substitute symbols
71
125
for (unsigned symIdx = 0 ; symIdx < symbols.size (); ++symIdx) {
72
126
Value sym = symbols[symIdx];
73
- auto minMax = getMinMaxExpr (sym, dims, symbols);
127
+ auto minMax = getMinMaxExprWrapper (sym, dims, symbols, getMinMaxExpr );
74
128
if (!minMax)
75
129
continue ;
76
130
AffineExpr symExpr = getAffineSymbolExpr (symIdx, expr.getContext ());
0 commit comments