@@ -228,6 +228,8 @@ struct GenericOpTypePropagation
228
228
}
229
229
auto inputOperandType =
230
230
llvm::cast<RankedTensorType>(genericOp->getOperandTypes ()[index ]);
231
+ assert (inputOperandType.getElementType () == argType &&
232
+ " expected same element type" );
231
233
std::optional<Type> legalizedArgType =
232
234
legalizeStorageElementType (inputOperandType);
233
235
if (!legalizedArgType) {
@@ -259,6 +261,8 @@ struct GenericOpTypePropagation
259
261
modifyYield = true ;
260
262
OpOperand *yieldOperand =
261
263
modifiedOp.getMatchingYieldValue (modifiedOpOperand);
264
+ assert (llvm::cast<TensorType>(modifiedOpOperand->get ().getType ()).getElementType () ==
265
+ yieldOperand->get ().getType () && " expected same element type" );
262
266
std::optional<Type> legalizedType =
263
267
legalizeStorageElementType (modifiedOpOperand->get ().getType ());
264
268
if (!legalizedType) {
@@ -289,8 +293,11 @@ struct LinalgFillTypePropagation
289
293
matchAndRewrite (linalg::FillOp fillOp, OpAdaptor adaptor,
290
294
ConversionPatternRewriter &rewriter) const final {
291
295
Value value = adaptor.getInputs ().front ();
296
+ TensorType outputType = cast<TensorType>(adaptor.getOutputs ()[0 ].getType ());
297
+ assert (outputType.getElementType () == value.getType () &&
298
+ " expected same element type" );
292
299
std::optional<Type> legalizedElementType =
293
- legalizeStorageElementType (adaptor. getOutputs ()[ 0 ]. getType () );
300
+ legalizeStorageElementType (outputType );
294
301
if (!legalizedElementType) {
295
302
return fillOp.emitOpError (" failed to get legalized type for value" );
296
303
}
0 commit comments