// upstream GPU/DecomposeMemRefs.cpp file. It adds a new option to decompose
// memrefs into rank-1 memrefs instead of 0-rank memrefs (see the example
// below).
//
+ // Question to answer at this point:
+ // 1. Should we disallow memrefs with a non-identity layout, and cases where
+ //    offset != 0 or stride != 1? If so, the test cases should be updated.
+
+ // TODO:
+ // 1. Update memref.subview.
+ // 2. Vector dialect ops: masked{load|store}, transfer_{read|write}, etc.
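+ //
+ // For illustration, the intended rewrite looks like this (SSA names are
+ // hypothetical, not taken from the tests):
+ //
+ //   %v = memref.load %src[%i, %j] : memref<4x8xf32>
+ //
+ // becomes, for a contiguous row-major source:
+ //
+ //   %flat = memref.reinterpret_cast %src to offset: [0], sizes: [32],
+ //       strides: [1] : memref<4x8xf32> to memref<32xf32>
+ //   %idx = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
+ //   %v = memref.load %flat[%idx] : memref<32xf32>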
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+ #include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -120,24 +128,16 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}

- static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
-                            ValueRange offsets) {
-   SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
-   auto &&[base, offset, ignore] =
-       getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
-   MemRefType retType = inferCastResultType(base, offset);
-   return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
-                                                     std::nullopt, std::nullopt);
- }
-
+ /// Returns a collapsed memref and the linearized index to access the element
+ /// at the specified indices.
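+ ///
+ /// For example (shapes and names hypothetical): given a memref<4x8xf32> and
+ /// indices (%i, %j), this returns a memref<32xf32> view of the same buffer
+ /// together with the linearized index %i * 8 + %j.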
static std::pair<Value, OpFoldResult> getCollapsedMemref(OpBuilder &rewriter,
                                                         Location loc,
                                                         Value source,
-                                                        ValueRange offsets) {
+                                                        ValueRange indices) {
  MemRefType memrefType = cast<MemRefType>(source.getType());
- auto &&[base, offset, ignore] = getFlatOffsetAndStrides(
-     rewriter, loc, source, getAsOpFoldResult(offsets));
- // expand contiguous shape
+ auto &&[base, index, _] = getFlatOffsetAndStrides(
+     rewriter, loc, source, getAsOpFoldResult(indices));
+ // We do not support non-contiguous memrefs.
  int64_t collapsedShape = 1;
  for (auto dim : memrefType.getShape()) {
    collapsedShape *= dim;
@@ -146,12 +146,12 @@ static std::pair<Value, OpFoldResult> getCollapsedMemref(OpBuilder &rewriter,
      MemRefType::get({collapsedShape}, memrefType.getElementType(), nullptr,
                      memrefType.getMemorySpace());

- // TODO: implement offset.
+ // (lialan) TODO: should we keep `offset` in the result memref?
  return std::make_pair(rewriter.create<memref::ReinterpretCastOp>(
                            loc, retType, source, /*offset=*/0,
                            /*shapes=*/ArrayRef<int64_t>{collapsedShape},
                            /*strides=*/ArrayRef<int64_t>{1}),
-                       offset);
+                       index);
}

static Value getValueFromOpFoldResult(PatternRewriter &rewriter, Location loc,
@@ -170,7 +170,6 @@ static bool needFlatten(Value val) {
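+ /// Returns true if the memref layout is the identity or an explicit strided
+ /// layout. Other layouts (e.g. arbitrary affine-map layouts) are not
+ /// supported by the flattening patterns.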
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
- // TODO: is this correct?
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}
@@ -270,6 +269,99 @@ struct FlattenVectorStore : public OpRewritePattern<vector::StoreOp> {
  }
};

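+ /// Flattens vector.maskedload on a multi-dimensional memref into a masked
+ /// load on the collapsed rank-1 memref. A sketch of the rewrite (SSA names
+ /// are hypothetical):
+ ///
+ ///   %v = vector.maskedload %src[%i, %j], %mask, %pass
+ ///       : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ /// ->
+ ///   %v = vector.maskedload %flat[%idx], %mask, %pass
+ ///       : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>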
+ struct FlattenVectorMaskedLoad
+     : public OpRewritePattern<vector::MaskedLoadOp> {
+   using OpRewritePattern::OpRewritePattern;
+
+   LogicalResult matchAndRewrite(vector::MaskedLoadOp op,
+                                 PatternRewriter &rewriter) const override {
+     Value memref = op.getBase();
+     if (!needFlatten(memref))
+       return rewriter.notifyMatchFailure(op, "nothing to do");
+
+     if (!checkLayout(memref))
+       return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+     Location loc = op.getLoc();
+     auto &&[flatMemref, offset] =
+         getCollapsedMemref(rewriter, loc, memref, op.getIndices());
+     Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offset);
+     rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+         op, op.getType(), flatMemref, ValueRange{offsetVal}, op.getMask(),
+         op.getPassThru());
+     return success();
+   }
+ };
+
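+ /// Flattens vector.maskedstore on a multi-dimensional memref into a masked
+ /// store on the collapsed rank-1 memref; the rewrite mirrors
+ /// FlattenVectorMaskedLoad above.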
+ struct FlattenVectorMaskedStore
+     : public OpRewritePattern<vector::MaskedStoreOp> {
+   using OpRewritePattern::OpRewritePattern;
+
+   LogicalResult matchAndRewrite(vector::MaskedStoreOp op,
+                                 PatternRewriter &rewriter) const override {
+     Value memref = op.getBase();
+     if (!needFlatten(memref))
+       return rewriter.notifyMatchFailure(op, "nothing to do");
+
+     if (!checkLayout(memref))
+       return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+     Location loc = op.getLoc();
+     auto &&[flatMemref, offset] =
+         getCollapsedMemref(rewriter, loc, memref, op.getIndices());
+     Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offset);
+     rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+         op, flatMemref, ValueRange{offsetVal}, op.getMask(),
+         op.getValueToStore());
+     return success();
+   }
+ };
+
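+ /// Flattens vector.transfer_read on a multi-dimensional memref into a
+ /// transfer_read on the collapsed rank-1 memref. Note: rebuilding the op
+ /// from only the flat memref, index, and padding implicitly assumes a minor
+ /// identity permutation map; transfers with other maps are not checked for
+ /// here.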
+ struct FlattenVectorTransferRead
+     : public OpRewritePattern<vector::TransferReadOp> {
+   using OpRewritePattern::OpRewritePattern;
+
+   LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                 PatternRewriter &rewriter) const override {
+     Value memref = op.getSource();
+     if (!needFlatten(memref))
+       return rewriter.notifyMatchFailure(op, "nothing to do");
+
+     if (!checkLayout(memref))
+       return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+     Location loc = op.getLoc();
+     auto &&[flatMemref, offset] =
+         getCollapsedMemref(rewriter, loc, memref, op.getIndices());
+     Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offset);
+     rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+         op, op.getType(), flatMemref, ValueRange{offsetVal}, op.getPadding());
+     return success();
+   }
+ };
+
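+ /// Flattens vector.transfer_write on a multi-dimensional memref into a
+ /// transfer_write on the collapsed rank-1 memref. As with
+ /// FlattenVectorTransferRead, a minor identity permutation map is implicitly
+ /// assumed.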
+ struct FlattenVectorTransferWrite
+     : public OpRewritePattern<vector::TransferWriteOp> {
+   using OpRewritePattern::OpRewritePattern;
+
+   LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+                                 PatternRewriter &rewriter) const override {
+     Value memref = op.getSource();
+     if (!needFlatten(memref))
+       return rewriter.notifyMatchFailure(op, "nothing to do");
+
+     if (!checkLayout(memref))
+       return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+     Location loc = op.getLoc();
+     auto &&[flatMemref, offset] =
+         getCollapsedMemref(rewriter, loc, memref, op.getIndices());
+     Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offset);
+     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+         op, op.getVector(), flatMemref, ValueRange{offsetVal});
+     return success();
+   }
+ };
+
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;
@@ -339,8 +431,11 @@ struct DecomposeMemrefsPass

namespace mlir::iree_compiler {
void populateDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
- patterns.insert<FlattenMemrefLoad, FlattenVectorLoad, FlattenMemrefStore,
-                 FlattenVectorStore, FlattenSubview>(patterns.getContext());
+ patterns.insert<FlattenMemrefLoad, FlattenMemrefStore, FlattenSubview,
+                 FlattenVectorMaskedLoad, FlattenVectorMaskedStore,
+                 FlattenVectorLoad, FlattenVectorStore,
+                 FlattenVectorTransferRead, FlattenVectorTransferWrite>(
+     patterns.getContext());
}

std::unique_ptr<Pass> createDecomposeMemrefsPass()