18
18
#include " mlir/IR/BuiltinAttributes.h"
19
19
#include " mlir/IR/BuiltinOps.h"
20
20
#include " mlir/IR/Location.h"
21
+ #include " mlir/IR/Verifier.h"
21
22
22
23
#define DEBUG_TYPE " iree-codegen-link-tuning-specs"
23
24
#define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
@@ -53,35 +54,14 @@ static SmallVector<NamedSequenceOp> findTuningSpecs(ModuleOp module) {
53
54
});
54
55
}
55
56
56
- // Returns true iff the entrypoint has the following signature:
57
- // ```
58
- // transform.named_sequence @name(%arg0: !transform.any_op) ->
59
- // (!transform.any_op)
60
- // ```
61
- static LogicalResult validateTuningSpec (NamedSequenceOp op) {
62
- ArrayRef<Type> resTypes = op.getFunctionType ().getResults ();
63
- if (resTypes.size () != 1 || !isa<transform::AnyOpType>(resTypes[0 ])) {
64
- return op.emitWarning ()
65
- << " Tuning spec entry point expected to return any_op" ;
66
- }
67
-
68
- ArrayRef<Type> argTypes = op.getArgumentTypes ();
69
- if (argTypes.size () != 1 || !isa<transform::AnyOpType>(argTypes[0 ])) {
70
- return op.emitWarning () << " Tuning spec entry point expected to have a "
71
- " single any_op argument" ;
72
- }
73
-
74
- return success ();
75
- }
76
-
77
57
static bool consumesInputOp (NamedSequenceOp op) {
78
58
if (op.getArgAttr (0 , kArgConsumedAttrName )) {
79
59
return true ;
80
60
}
81
61
return false ;
82
62
}
83
63
84
- static NamedSequenceOp
64
+ static FailureOr< NamedSequenceOp>
85
65
emitLinkedTuningSpec (ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
86
66
OpBuilder builder (module->getContext ());
87
67
builder.setInsertionPointToEnd (module.getBody ());
@@ -144,6 +124,11 @@ emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
144
124
}
145
125
146
126
builder.create <transform::YieldOp>(loc, operand);
127
+
128
+ if (failed (mlir::verify (module))) {
129
+ return module.emitError (" Linked tuning spec failed to verify" );
130
+ }
131
+
147
132
return newSpec;
148
133
}
149
134
@@ -169,13 +154,6 @@ FailureOr<NamedSequenceOp> linkTuningSpecs(ModuleOp module) {
169
154
llvm::append_range (tuningSpecs, findTuningSpecs (nested));
170
155
}
171
156
172
- for (NamedSequenceOp spec : tuningSpecs) {
173
- LDBG (" Found tuning spec: " << spec.getSymName ());
174
- if (failed (validateTuningSpec (spec))) {
175
- return failure ();
176
- }
177
- }
178
-
179
157
size_t numConsumedSpecs = llvm::count_if (tuningSpecs, consumesInputOp);
180
158
if (numConsumedSpecs > 0 && numConsumedSpecs != tuningSpecs.size ()) {
181
159
LDBG (" Only " << numConsumedSpecs << " tuning specs out of "
0 commit comments