@@ -108,13 +108,9 @@ struct GPUSetEncodingOpLoweringConversion
108
108
return success ();
109
109
}
110
110
111
- FailureOr< MaterializeEncodingInfo> maybeEncodingInfo =
111
+ MaterializeEncodingInfo encodingInfo =
112
112
converter->getEncodingInfo (encodingOp.getResultType ());
113
- if (failed (maybeEncodingInfo)) {
114
- return rewriter.notifyMatchFailure (encodingOp,
115
- " unhandled result encoding" );
116
- }
117
- if (!maybeEncodingInfo->swizzle ) {
113
+ if (!encodingInfo.swizzle ) {
118
114
rewriter.replaceOp (encodingOp, packedValue.value ());
119
115
return success ();
120
116
}
@@ -128,18 +124,18 @@ struct GPUSetEncodingOpLoweringConversion
128
124
.getShape ()
129
125
.take_front (origRank));
130
126
expandShapeShape.append (
131
- getExpandedTileShape (maybeEncodingInfo-> swizzle ->expandShape ));
127
+ getExpandedTileShape (encodingInfo. swizzle ->expandShape ));
132
128
RankedTensorType expandShapeType =
133
129
encodingOp.getSourceType ().clone (expandShapeShape);
134
130
135
- SmallVector<ReassociationIndices> reassociation = getReassociationIndices (
136
- origRank, maybeEncodingInfo-> swizzle ->expandShape );
131
+ SmallVector<ReassociationIndices> reassociation =
132
+ getReassociationIndices ( origRank, encodingInfo. swizzle ->expandShape );
137
133
auto expandShapeOp = rewriter.create <tensor::ExpandShapeOp>(
138
134
loc, expandShapeType, packedValue.value (), reassociation);
139
135
140
136
SmallVector<int64_t > transposePerm =
141
137
llvm::to_vector (llvm::seq<int64_t >(0 , origRank));
142
- for (auto perm : maybeEncodingInfo-> swizzle ->permutation ) {
138
+ for (auto perm : encodingInfo. swizzle ->permutation ) {
143
139
transposePerm.push_back (origRank + perm);
144
140
}
145
141
SmallVector<OpFoldResult> transposeResultDims =
@@ -168,9 +164,9 @@ struct GPUUnsetEncodingOpLoweringConversion
168
164
auto converter = static_cast <const MaterializeEncodingTypeConverter *>(
169
165
getTypeConverter ());
170
166
171
- FailureOr< MaterializeEncodingInfo> maybeEncodingInfo =
167
+ MaterializeEncodingInfo encodingInfo =
172
168
converter->getEncodingInfo (unsetEncodingOp.getSource ().getType ());
173
- if (failed (maybeEncodingInfo )) {
169
+ if (IREE::Codegen::isIdentityLayout (encodingInfo )) {
174
170
Type targetType =
175
171
getTypeConverter ()->convertType (unsetEncodingOp.getSourceType ());
176
172
Value result = rewriter.createOrFold <tensor::CastOp>(
@@ -181,35 +177,34 @@ struct GPUUnsetEncodingOpLoweringConversion
181
177
182
178
Location loc = unsetEncodingOp.getLoc ();
183
179
Value unpackSrc = adaptor.getSource ();
184
- if (maybeEncodingInfo-> swizzle ) {
180
+ if (encodingInfo. swizzle ) {
185
181
int targetRank = unsetEncodingOp.getResultType ().getRank ();
186
182
auto srcConvertedType =
187
183
cast<RankedTensorType>(adaptor.getSource ().getType ());
188
184
SmallVector<OpFoldResult> emptyShape =
189
185
tensor::getMixedSizes (rewriter, loc, adaptor.getSource ());
190
186
emptyShape.resize (targetRank);
191
- for (auto i :
192
- getExpandedTileShape (maybeEncodingInfo->swizzle ->expandShape )) {
187
+ for (auto i : getExpandedTileShape (encodingInfo.swizzle ->expandShape )) {
193
188
emptyShape.push_back (rewriter.getIndexAttr (i));
194
189
}
195
190
auto emptyTensor = rewriter.create <tensor::EmptyOp>(
196
191
loc, emptyShape, unsetEncodingOp.getSourceType ().getElementType ());
197
192
198
193
SmallVector<int64_t > transposePerm =
199
194
llvm::to_vector (llvm::seq<int64_t >(0 , targetRank));
200
- for (auto perm : maybeEncodingInfo-> swizzle ->permutation ) {
195
+ for (auto perm : encodingInfo. swizzle ->permutation ) {
201
196
transposePerm.push_back (targetRank + perm);
202
197
}
203
198
auto invertedTransposePerm = invertPermutationVector (transposePerm);
204
199
auto transposeOp = rewriter.create <linalg::TransposeOp>(
205
200
loc, adaptor.getSource (), emptyTensor, invertedTransposePerm);
206
201
207
202
SmallVector<ReassociationIndices> reassociation = getReassociationIndices (
208
- targetRank, maybeEncodingInfo-> swizzle ->expandShape );
203
+ targetRank, encodingInfo. swizzle ->expandShape );
209
204
SmallVector<int64_t > unpackSrcShape (
210
205
srcConvertedType.getShape ().take_front (targetRank));
211
- unpackSrcShape.append (maybeEncodingInfo-> innerTileSizes .begin (),
212
- maybeEncodingInfo-> innerTileSizes .end ());
206
+ unpackSrcShape.append (encodingInfo. innerTileSizes .begin (),
207
+ encodingInfo. innerTileSizes .end ());
213
208
RankedTensorType unpackSrcType =
214
209
unsetEncodingOp.getResultType ().clone (unpackSrcShape);
215
210
unpackSrc = rewriter.create <tensor::CollapseShapeOp>(
0 commit comments