Skip to content

Commit f4b5379

Browse files
committed
Adds asserts
1 parent 4a626d6 commit f4b5379

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ struct GenericOpTypePropagation
228228
}
229229
auto inputOperandType =
230230
llvm::cast<RankedTensorType>(genericOp->getOperandTypes()[index]);
231+
assert(inputOperandType.getElementType() == argType &&
232+
"expected same element type");
231233
std::optional<Type> legalizedArgType =
232234
legalizeStorageElementType(inputOperandType);
233235
if (!legalizedArgType) {
@@ -259,6 +261,8 @@ struct GenericOpTypePropagation
259261
modifyYield = true;
260262
OpOperand *yieldOperand =
261263
modifiedOp.getMatchingYieldValue(modifiedOpOperand);
264+
assert(llvm::cast<TensorType>(modifiedOpOperand->get().getType()).getElementType() ==
265+
yieldOperand->get().getType() && "expected same element type");
262266
std::optional<Type> legalizedType =
263267
legalizeStorageElementType(modifiedOpOperand->get().getType());
264268
if (!legalizedType) {
@@ -289,8 +293,11 @@ struct LinalgFillTypePropagation
289293
matchAndRewrite(linalg::FillOp fillOp, OpAdaptor adaptor,
290294
ConversionPatternRewriter &rewriter) const final {
291295
Value value = adaptor.getInputs().front();
296+
TensorType outputType = cast<TensorType>(adaptor.getOutputs()[0].getType());
297+
assert(outputType.getElementType() == value.getType() &&
298+
"expected same element type");
292299
std::optional<Type> legalizedElementType =
293-
legalizeStorageElementType(adaptor.getOutputs()[0].getType());
300+
legalizeStorageElementType(outputType);
294301
if (!legalizedElementType) {
295302
return fillOp.emitOpError("failed to get legalized type for value");
296303
}

0 commit comments

Comments
 (0)