Skip to content

Commit 3cc311a

Browse files
[mlir][Transforms] Dialect Conversion: No target mat. for 1:N replacement (llvm#117513)
During a 1:N replacement (`applySignatureConversion` or `replaceOpWithMultiple`), the dialect conversion driver used to insert two materializations: * Argument materialization: convert N replacement values to 1 SSA value of the original type `S`. * Target materialization: convert original type to legalized type `T`. The target materialization is unnecessary. Subsequent patterns receive the replacement values via their adaptors. These patterns have their own type converter. When they see a replacement value of type `S`, they will automatically insert a target materialization to type `T`. There is no reason to do this already during the 1:N replacement. (The functionality used to be duplicated in `remapValues` and `insertNTo1Materialization`.) Special case: If a subsequent pattern does not have a type converter, it does *not* insert any target materializations. That's because the absence of a type converter indicates that the pattern does not care about type legality. Therefore, it is correct to pass an SSA value of type `S` (or any other type) to the pattern. Note: Most patterns in `TestPatterns.cpp` run without a type converter. To make sure that the tests still behave the same, some of these patterns now have a type converter. This commit is in preparation of adding 1:N support to the conversion value mapping. Before making any further changes to the mapping infrastructure, I'd like to make sure that the code base around it (that uses the mapping) is robust.
1 parent c660b28 commit 3cc311a

File tree

4 files changed

+128
-103
lines changed

4 files changed

+128
-103
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

+86-44
Original file line numberDiff line numberDiff line change
@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156+
// Add generic source and target materializations to handle cases where
157+
// non-LLVM types persist after an LLVM conversion.
158+
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
159+
ValueRange inputs, Location loc) {
160+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
161+
.getResult(0);
162+
});
163+
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
164+
ValueRange inputs, Location loc) {
165+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
166+
.getResult(0);
167+
});
168+
156169
// Helper function that checks if the given value range is a bare pointer.
157170
auto isBarePointer = [](ValueRange values) {
158171
return values.size() == 1 &&
159172
isa<LLVM::LLVMPointerType>(values.front().getType());
160173
};
161174

162-
// Argument materializations convert from the new block argument types
163-
// (multiple SSA values that make up a memref descriptor) back to the
164-
// original block argument type. The dialect conversion framework will then
165-
// insert a target materialization from the original block argument type to
166-
// a legal type.
167-
addArgumentMaterialization([&](OpBuilder &builder,
168-
UnrankedMemRefType resultType,
169-
ValueRange inputs, Location loc) {
175+
// TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
176+
// must be passed explicitly.
177+
auto packUnrankedMemRefDesc =
178+
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
179+
Location loc, LLVMTypeConverter &converter) -> Value {
170180
// Note: Bare pointers are not supported for unranked memrefs because a
171181
// memref descriptor cannot be built just from a bare pointer.
172-
if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
182+
if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
173183
return Value();
174-
Value desc =
175-
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
184+
return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
185+
inputs);
186+
};
187+
188+
// MemRef descriptor elements -> UnrankedMemRefType
189+
auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
190+
UnrankedMemRefType resultType,
191+
ValueRange inputs, Location loc) {
176192
// An argument materialization must return a value of type
177193
// `resultType`, so insert a cast from the memref descriptor type
178194
// (!llvm.struct) to the original memref type.
179-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
180-
.getResult(0);
181-
});
182-
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
183-
ValueRange inputs, Location loc) {
184-
Value desc;
185-
if (isBarePointer(inputs)) {
186-
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
187-
inputs[0]);
188-
} else if (TypeRange(inputs) ==
189-
getMemRefDescriptorFields(resultType,
190-
/*unpackAggregates=*/true)) {
191-
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
192-
} else {
193-
// The inputs are neither a bare pointer nor an unpacked memref
194-
// descriptor. This materialization function cannot be used.
195+
Value packed =
196+
packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
197+
if (!packed)
195198
return Value();
196-
}
199+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
200+
.getResult(0);
201+
};
202+
203+
// TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
204+
// must be passed explicitly.
205+
auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
206+
ValueRange inputs, Location loc,
207+
LLVMTypeConverter &converter) -> Value {
208+
assert(resultType && "expected non-null result type");
209+
if (isBarePointer(inputs))
210+
return MemRefDescriptor::fromStaticShape(builder, loc, converter,
211+
resultType, inputs[0]);
212+
if (TypeRange(inputs) ==
213+
converter.getMemRefDescriptorFields(resultType,
214+
/*unpackAggregates=*/true))
215+
return MemRefDescriptor::pack(builder, loc, converter, resultType,
216+
inputs);
217+
// The inputs are neither a bare pointer nor an unpacked memref descriptor.
218+
// This materialization function cannot be used.
219+
return Value();
220+
};
221+
222+
// MemRef descriptor elements -> MemRefType
223+
auto rankedMemRefMaterialization = [&](OpBuilder &builder,
224+
MemRefType resultType,
225+
ValueRange inputs, Location loc) {
197226
// An argument materialization must return a value of type `resultType`,
198227
// so insert a cast from the memref descriptor type (!llvm.struct) to the
199228
// original memref type.
200-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
201-
.getResult(0);
202-
});
203-
// Add generic source and target materializations to handle cases where
204-
// non-LLVM types persist after an LLVM conversion.
205-
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
206-
ValueRange inputs, Location loc) {
207-
if (inputs.size() != 1)
229+
Value packed =
230+
packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
231+
if (!packed)
208232
return Value();
209-
210-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
233+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
211234
.getResult(0);
212-
});
235+
};
236+
237+
// Argument materializations convert from the new block argument types
238+
// (multiple SSA values that make up a memref descriptor) back to the
239+
// original block argument type.
240+
addArgumentMaterialization(unrakedMemRefMaterialization);
241+
addArgumentMaterialization(rankedMemRefMaterialization);
242+
addSourceMaterialization(unrakedMemRefMaterialization);
243+
addSourceMaterialization(rankedMemRefMaterialization);
244+
245+
// Bare pointer -> Packed MemRef descriptor
213246
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
214-
ValueRange inputs, Location loc) {
215-
if (inputs.size() != 1)
247+
ValueRange inputs, Location loc,
248+
Type originalType) -> Value {
249+
// The original MemRef type is required to build a MemRef descriptor
250+
// because the sizes/strides of the MemRef cannot be inferred from just the
251+
// bare pointer.
252+
if (!originalType)
216253
return Value();
217-
218-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
219-
.getResult(0);
254+
if (resultType != convertType(originalType))
255+
return Value();
256+
if (auto memrefType = dyn_cast<MemRefType>(originalType))
257+
return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
258+
if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
259+
return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
260+
*this);
261+
return Value();
220262
});
221263

222264
// Integer memory spaces map to themselves.

mlir/lib/Transforms/Utils/DialectConversion.cpp

+14-32
Original file line numberDiff line numberDiff line change
@@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
849849
/// function will be deleted when full 1:N support has been added.
850850
///
851851
/// This function inserts an argument materialization back to the original
852-
/// type, followed by a target materialization to the legalized type (if
853-
/// applicable).
852+
/// type.
854853
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
855854
ValueRange replacements, Value originalValue,
856855
const TypeConverter *converter);
@@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13761375
// used as a replacement.
13771376
auto replArgs =
13781377
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1379-
insertNTo1Materialization(
1380-
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1381-
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
1378+
if (replArgs.size() == 1) {
1379+
mapping.map(origArg, replArgs.front());
1380+
} else {
1381+
insertNTo1Materialization(
1382+
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1383+
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
1384+
}
13821385
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
13831386
}
13841387

@@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
14371440
// Insert argument materialization back to the original type.
14381441
Type originalType = originalValue.getType();
14391442
UnrealizedConversionCastOp argCastOp;
1440-
Value argMat = buildUnresolvedMaterialization(
1443+
buildUnresolvedMaterialization(
14411444
MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
1442-
/*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
1443-
&argCastOp);
1445+
/*inputs=*/replacements, originalType,
1446+
/*originalType=*/Type(), converter, &argCastOp);
14441447
if (argCastOp)
14451448
nTo1TempMaterializations.insert(argCastOp);
1446-
1447-
// Insert target materialization to the legalized type.
1448-
Type legalOutputType;
1449-
if (converter) {
1450-
legalOutputType = converter->convertType(originalType);
1451-
} else if (replacements.size() == 1) {
1452-
// When there is no type converter, assume that the replacement value
1453-
// types are legal. This is reasonable to assume because they were
1454-
// specified by the user.
1455-
// FIXME: This won't work for 1->N conversions because multiple output
1456-
// types are not supported in parts of the dialect conversion. In such a
1457-
// case, we currently use the original value type.
1458-
legalOutputType = replacements[0].getType();
1459-
}
1460-
if (legalOutputType && legalOutputType != originalType) {
1461-
UnrealizedConversionCastOp targetCastOp;
1462-
buildUnresolvedMaterialization(
1463-
MaterializationKind::Target, computeInsertPoint(argMat), loc,
1464-
/*valueToMap=*/argMat, /*inputs=*/argMat,
1465-
/*outputType=*/legalOutputType, /*originalType=*/originalType,
1466-
converter, &targetCastOp);
1467-
if (targetCastOp)
1468-
nTo1TempMaterializations.insert(targetCastOp);
1469-
}
14701449
}
14711450

14721451
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
@@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
28642843

28652844
LogicalResult TypeConverter::convertType(Type t,
28662845
SmallVectorImpl<Type> &results) const {
2846+
assert(this && "expected non-null type converter");
2847+
assert(t && "expected non-null type");
2848+
28672849
{
28682850
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
28692851
std::defer_lock);

mlir/test/Transforms/test-legalizer.mlir

+4-4
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
124124
// CHECK-NEXT: "foo.region"
125125
// expected-remark@+1 {{op 'foo.region' is not legalizable}}
126126
"foo.region"() ({
127-
// CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
128-
^bb0(%i0: i64, %unused: i16, %i1: i64):
129-
// CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
130-
"test.invalid"(%i0, %i1) : (i64, i64) -> ()
127+
// CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
128+
^bb0(%i0: f64, %unused: i16, %i1: f64):
129+
// CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
130+
"test.invalid"(%i0, %i1) : (f64, f64) -> ()
131131
}) : () -> ()
132132
// expected-remark@+1 {{op 'func.return' is not legalizable}}
133133
return

mlir/test/lib/Dialect/Test/TestPatterns.cpp

+24-23
Original file line numberDiff line numberDiff line change
@@ -985,8 +985,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
985985
};
986986
/// This pattern simply updates the operands of the given operation.
987987
struct TestPassthroughInvalidOp : public ConversionPattern {
988-
TestPassthroughInvalidOp(MLIRContext *ctx)
989-
: ConversionPattern("test.invalid", 1, ctx) {}
988+
TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
989+
: ConversionPattern(converter, "test.invalid", 1, ctx) {}
990990
LogicalResult
991991
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
992992
ConversionPatternRewriter &rewriter) const final {
@@ -1307,19 +1307,19 @@ struct TestLegalizePatternDriver
13071307
TestTypeConverter converter;
13081308
mlir::RewritePatternSet patterns(&getContext());
13091309
populateWithGenerated(patterns);
1310-
patterns.add<
1311-
TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1312-
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1313-
TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
1314-
TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1315-
TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1316-
TestUpdateConsumerType, TestNonRootReplacement,
1317-
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1318-
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1319-
TestUndoPropertiesModification, TestEraseOp,
1320-
TestRepetitive1ToNConsumer>(&getContext());
1321-
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
1322-
&getContext(), converter);
1310+
patterns
1311+
.add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1312+
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1313+
TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
1314+
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
1315+
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
1316+
TestNonRootReplacement, TestBoundedRecursiveRewrite,
1317+
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1318+
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1319+
TestUndoPropertiesModification, TestEraseOp,
1320+
TestRepetitive1ToNConsumer>(&getContext());
1321+
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1322+
TestPassthroughInvalidOp>(&getContext(), converter);
13231323
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
13241324
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
13251325
converter);
@@ -1755,8 +1755,9 @@ struct TestTypeConversionAnotherProducer
17551755
};
17561756

17571757
struct TestReplaceWithLegalOp : public ConversionPattern {
1758-
TestReplaceWithLegalOp(MLIRContext *ctx)
1759-
: ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
1758+
TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
1759+
: ConversionPattern(converter, "test.replace_with_legal_op",
1760+
/*benefit=*/1, ctx) {}
17601761
LogicalResult
17611762
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
17621763
ConversionPatternRewriter &rewriter) const final {
@@ -1878,12 +1879,12 @@ struct TestTypeConversionDriver
18781879

18791880
// Initialize the set of rewrite patterns.
18801881
RewritePatternSet patterns(&getContext());
1881-
patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1882-
TestSignatureConversionUndo,
1883-
TestTestSignatureConversionNoConverter>(converter,
1884-
&getContext());
1885-
patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
1886-
&getContext());
1882+
patterns
1883+
.add<TestTypeConsumerForward, TestTypeConversionProducer,
1884+
TestSignatureConversionUndo,
1885+
TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
1886+
converter, &getContext());
1887+
patterns.add<TestTypeConversionAnotherProducer>(&getContext());
18871888
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
18881889
converter);
18891890

0 commit comments

Comments
 (0)