
Commit b91faf6

adding test cases and transfer_read/write
1 parent fbbf149 commit b91faf6

File tree

2 files changed: +177 -32

compiler/src/iree/compiler/Codegen/Common/DecomposeMemRefs.cpp

+114 -19
@@ -10,13 +10,21 @@
 // upstream GPU/DecomposeMemRefs.cpp file. It adds a new option to not decompose
 // it into 0-rank memrefs but instead single-ranked memrefs.
 //
+// Question to answer at this point:
+// 1. should we disallow memrefs with non-identity layout? also cases where
+//    offset != 0 and stride != 1? if so we should update test cases.
+//
+// TODO:
+// 1. update memref.subview.
+// 2. vector dialects? masked{load|store}, transfer_{read|write}, etc?
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -120,24 +128,16 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
   return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
 }

-static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
-                           ValueRange offsets) {
-  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
-  auto &&[base, offset, ignore] =
-      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
-  MemRefType retType = inferCastResultType(base, offset);
-  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
-                                                    std::nullopt, std::nullopt);
-}
-
+/// Returns a collapsed memref and the linearized index to access the element
+/// at the specified indices.
 static std::pair<Value, OpFoldResult> getCollapsedMemref(OpBuilder &rewriter,
                                                          Location loc,
                                                          Value source,
-                                                         ValueRange offsets) {
+                                                         ValueRange indices) {
   MemRefType memrefType = cast<MemRefType>(source.getType());
-  auto &&[base, offset, ignore] = getFlatOffsetAndStrides(
-      rewriter, loc, source, getAsOpFoldResult(offsets));
-  // expand contiguous shape
+  auto &&[base, index, _] = getFlatOffsetAndStrides(
+      rewriter, loc, source, getAsOpFoldResult(indices));
+  // We do not support non-contiguous memrefs.
   int64_t collapsedShape = 1;
   for (auto dim : memrefType.getShape()) {
     collapsedShape *= dim;
@@ -146,12 +146,12 @@ static std::pair<Value, OpFoldResult> getCollapsedMemref(OpBuilder &rewriter,
       MemRefType::get({collapsedShape}, memrefType.getElementType(), nullptr,
                       memrefType.getMemorySpace());

-  // TODO: implement offset.
+  // (lialan) TODO: should we keep `offset` in the result memref?
   return std::make_pair(rewriter.create<memref::ReinterpretCastOp>(
                             loc, retType, source, /* offset = */ 0,
                             /*shapes = */ ArrayRef<int64_t>{collapsedShape},
                             /* strides = */ ArrayRef<int64_t>{1}),
-                        offset);
+                        index);
 }

 static Value getValueFromOpFoldResult(PatternRewriter &rewriter, Location loc,
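
Note (a sketch, not part of the commit): getFlatOffsetAndStrides linearizes an access as flatIndex = baseOffset + sum_i(index_i * stride_i), materialized through memref.extract_strided_metadata and affine ops. As a worked example using the layout from the @load_scalar_from_memref_static_dim test below, memref<4x8xf32, strided<[8, 12], offset: 100>> accessed at indices (1, 2) linearizes to 100 + 1*8 + 2*12 = 132. Since getCollapsedMemref creates the collapsed memref with offset 0, the base offset survives only inside the returned linearized index, which is exactly what the TODO above is asking about.
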
@@ -170,7 +170,6 @@ static bool needFlatten(Value val) {

 static bool checkLayout(Value val) {
   auto type = cast<MemRefType>(val.getType());
-  // TODO: is this correct?
   return type.getLayout().isIdentity() ||
          isa<StridedLayoutAttr>(type.getLayout());
 }
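
Note (not part of the commit): checkLayout accepts exactly the two layout kinds exercised by the tests, the identity layout (e.g. memref<4x8xf32>) and explicit strided layouts (e.g. memref<4x8xf32, strided<[?, ?], offset: ?>>). Any other layout attribute, for example a hypothetical permuting map such as memref<4x8xf32, affine_map<(d0, d1) -> (d1, d0)>>, makes the flattening patterns bail out with "unsupported layout".
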
@@ -270,6 +269,99 @@ struct FlattenVectorStore : public OpRewritePattern<vector::StoreOp> {
   }
 };

+struct FlattenVectorMaskedLoad : public OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp op,
+                                PatternRewriter &rewriter) const override {
+    Value memref = op.getBase();
+    if (!needFlatten(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+    auto &&[flatMemref, offset] =
+        getCollapsedMemref(rewriter, loc, memref, op.getIndices());
+    Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offset);
+    rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+        op, op.getType(), flatMemref, ValueRange{offsetVal}, op.getMask(),
+        op.getPassThru());
+    return success();
+  }
+};
+
+struct FlattenVectorMaskedStore : public OpRewritePattern<vector::MaskedStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedStoreOp op,
+                                PatternRewriter &rewriter) const override {
+    Value memref = op.getBase();
+    if (!needFlatten(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+    auto &&[flatMemref, offset] =
+        getCollapsedMemref(rewriter, loc, memref, op.getIndices());
+    Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offset);
+    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+        op, flatMemref, ValueRange{offsetVal}, op.getMask(),
+        op.getValueToStore());
+    return success();
+  }
+};
+struct FlattenVectorTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                PatternRewriter &rewriter) const override {
+    Value memref = op.getSource();
+    if (!needFlatten(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+
+    auto &&[flatMemref, offset] =
+        getCollapsedMemref(rewriter, loc, memref, op.getIndices());
+
+    Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offset);
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        op, op.getType(), flatMemref, ValueRange{offsetVal}, op.getPadding());
+    return success();
+  }
+};
+
+struct FlattenVectorTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+                                PatternRewriter &rewriter) const override {
+    Value memref = op.getSource();
+    if (!needFlatten(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+    auto &&[flatMemref, offset] =
+        getCollapsedMemref(rewriter, loc, memref, op.getIndices());
+
+    Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offset);
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(op, op.getVector(),
+                                                         flatMemref,
+                                                         ValueRange{offsetVal});
+    return success();
+  }
+};
+
 struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
   using OpRewritePattern::OpRewritePattern;

@@ -339,8 +431,11 @@ struct DecomposeMemrefsPass

 namespace mlir::iree_compiler {
 void populateDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FlattenMemrefLoad, FlattenVectorLoad, FlattenMemrefStore,
-                  FlattenVectorStore, FlattenSubview>(patterns.getContext());
+  patterns.insert<FlattenMemrefLoad, FlattenMemrefStore, FlattenSubview,
+                  FlattenVectorMaskedLoad, FlattenVectorMaskedStore,
+                  FlattenVectorLoad, FlattenVectorStore,
+                  FlattenVectorTransferRead, FlattenVectorTransferWrite>(
+      patterns.getContext());
 }

 std::unique_ptr<Pass> createDecomposeMemrefsPass() {
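
Note (a sketch, not part of the commit): taken together, the new patterns rewrite a multi-dimensional vector access into the same access on a collapsed 1-D memref. Below is the intended effect on the @mask_load_vector_from_memref_odd test case from the second file, assuming the linearized index folds to a constant; the pass actually materializes it through affine/arith ops, and %c10 is a hypothetical folded value (1*7 + 3 = 10 for the identity layout):

  // Before:
  %r = vector.maskedload %input[%c1, %c3], %mask, %passthru
      : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>

  // After flattening (sketch):
  %flat = memref.reinterpret_cast %input to offset: [0], sizes: [21], strides: [1]
      : memref<3x7xi2> to memref<21xi2>
  %r = vector.maskedload %flat[%c10], %mask, %passthru
      : memref<21xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>

The patterns can be exercised with the pipeline the lit test below uses: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-decompose-memrefs))".
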

compiler/src/iree/compiler/Codegen/Common/test/decompose_memref.mlir

+63 -13
@@ -1,8 +1,6 @@
 // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-decompose-memrefs))" %s | FileCheck

 // TODO: support vector dialect.
-// TODO: support stores.
-// TODO: test subviews.

 // -----

@@ -17,59 +15,59 @@ func.func @load_scalar_from_memref(%input: memref<4x8xf32>) -> f32 {
 // -----

 func.func @load_scalar_from_memref_static_dim(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>) -> f32 {
-  %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %value = memref.load %input[%c0, %c1] : memref<4x8xf32, strided<[8, 12], offset: 100>>
+  %c2 = arith.constant 2 : index
+  %value = memref.load %input[%c1, %c2] : memref<4x8xf32, strided<[8, 12], offset: 100>>
   return %value : f32
 }
 // CHECK-LABEL: func @load_scalar_from_memref_static_dim

 // -----

 func.func @load_scalar_from_memref_static_dim_2(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index) -> f32 {
-  %value = memref.load %input[%row, %col] : memref<4x8xf32, strided<[8, 12], offset: 100>>
+  %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
   return %value : f32
 }
 // CHECK-LABEL: func @load_scalar_from_memref_static_dim_2

 // -----

-func.func @load_scalar_from_memref_dynamic_dim(%input: memref<4x8xf32, strided<[?, ?], offset: ?>>) -> f32 {
+func.func @load_scalar_from_memref_dynamic_dim(%input: memref<4x8xf32, strided<[?, ?], offset: ?>>, %row : index, %col : index) -> f32 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %value = memref.load %input[%c0, %c1] : memref<4x8xf32, strided<[?, ?], offset: ?>>
+  %value = memref.load %input[%c1, %c0] : memref<4x8xf32, strided<[?, ?], offset: ?>>
   return %value : f32
 }
 // CHECK-LABEL: func @load_scalar_from_memref_dynamic_dim

 // -----

 func.func @load_scalar_from_memref_dynamic_dim_2(%input: memref<4x8xf32, strided<[?, ?], offset: ?>>, %row: index, %col: index) -> f32 {
-  %value = memref.load %input[%row, %col] : memref<4x8xf32, strided<[?, ?], offset: ?>>
+  %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[?, ?], offset: ?>>
   return %value : f32
 }
 // CHECK-LABEL: func @load_scalar_from_memref_dynamic_dim_2

 // -----

 func.func @load_scalar_from_memref_subview(%input: memref<4x8xf32>, %row: index, %col: index) -> memref<1x1xf32, strided<[8, 1], offset: ?>> {
-  %subview = memref.subview %input[%row, %col] [1, 1] [1, 1] : memref<4x8xf32> to memref<1x1xf32, strided<[8, 1], offset: ?>>
+  %subview = memref.subview %input[%col, %row] [1, 1] [1, 1] : memref<4x8xf32> to memref<1x1xf32, strided<[8, 1], offset: ?>>
   return %subview : memref<1x1xf32, strided<[8, 1], offset: ?>>
 }
 // CHECK-LABEL: func @load_scalar_from_memref_subview

 // -----

 func.func @store_scalar_from_memref_static_dim_2(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index, %value: f32) {
-  memref.store %value, %input[%row, %col] : memref<4x8xf32, strided<[8, 12], offset: 100>>
+  memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
   return
 }
 // CHECK-LABEL: func @store_scalar_from_memref_static_dim_2

 // -----

 func.func @store_scalar_from_memref_dynamic_dim_2(%input: memref<4x8xf32, strided<[?, ?], offset: ?>>, %row: index, %col: index, %value: f32) {
-  memref.store %value, %input[%row, %col] : memref<4x8xf32, strided<[?, ?], offset: ?>>
+  memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[?, ?], offset: ?>>
   return
 }
 // CHECK-LABEL: func @store_scalar_from_memref_dynamic_dim_2
@@ -97,7 +95,7 @@ func.func @load_vector_from_memref_odd(%input: memref<3x7xi2>) -> vector<3xi2> {
 // -----

 func.func @load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index) -> vector<3xi2> {
-  %value = vector.load %input[%row, %col] : memref<3x7xi2>, vector<3xi2>
+  %value = vector.load %input[%col, %row] : memref<3x7xi2>, vector<3xi2>
   return %value : vector<3xi2>
 }
 // CHECK-LABEL: func @load_vector_from_memref_dynamic
@@ -115,7 +113,59 @@ func.func @store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi
 // -----

 func.func @store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index) {
-  vector.store %value, %input[%row, %col] : memref<3x7xi2>, vector<3xi2>
+  vector.store %value, %input[%col, %row] : memref<3x7xi2>, vector<3xi2>
   return
 }
 // CHECK-LABEL: func @store_vector_to_memref_dynamic
+
+// -----
+
+func.func @mask_store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi2>, %mask: vector<3xi1>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  vector.maskedstore %input[%c1, %c3], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
+  return
+}
+// CHECK-LABEL: func @mask_store_vector_to_memref_odd
+
+// -----
+
+func.func @mask_store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index, %mask: vector<3xi1>) {
+  vector.maskedstore %input[%col, %row], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
+  return
+}
+// CHECK-LABEL: func @mask_store_vector_to_memref_dynamic
+
+// -----
+func.func @mask_load_vector_from_memref_odd(%input: memref<3x7xi2>, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %result = vector.maskedload %input[%c1, %c3], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %result : vector<3xi2>
+}
+// CHECK-LABEL: func @mask_load_vector_from_memref_odd
+
+// -----
+
+func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> {
+  %result = vector.maskedload %input[%col, %row], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %result : vector<3xi2>
+}
+// CHECK-LABEL: func @mask_load_vector_from_memref_dynamic
+
+// -----
+
+func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
+  %c0 = arith.constant 0 : i2
+  %0 = vector.transfer_read %input[%col, %row], %c0 : memref<4x8xi2>, vector<8xi2>
+  return %0 : vector<8xi2>
+}
+// CHECK-LABEL: func @transfer_read_memref
+
+// -----
+
+func.func @transfer_write_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) {
+  vector.transfer_write %value, %input[%col, %row] : vector<8xi2>, memref<4x8xi2>
+  return
+}
+// CHECK-LABEL: func @transfer_write_memref
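
Note (not part of the commit): the added tests only assert CHECK-LABEL, so they guard against crashes and verifier failures rather than checking the flattened output. A sketch of stronger checks one might add for @transfer_write_memref, illustrative only (32 = 4*8 from collapsing the shape; the exact IR depends on how the index computation folds):

  // CHECK-LABEL: func @transfer_write_memref
  //       CHECK: memref.reinterpret_cast {{.*}} to offset: [0], sizes: [32], strides: [1]
  //       CHECK: vector.transfer_write {{.*}} : vector<8xi2>, memref<32xi2>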
