@@ -52,9 +52,11 @@ void IREECodegenDialect::initialize() {
52
52
addTypes<IREE::Codegen::NullPointerType>();
53
53
}
54
54
55
- static LogicalResult validateTuningEntrypoint (
56
- ModuleOp moduleOp, transform::NamedSequenceOp &kernelConfigOp,
57
- int &numTuningEntryPoints, const std::string &kRequiredByDefaultAttr ) {
55
+ static LogicalResult
56
+ validateTuningEntrypoint (ModuleOp moduleOp,
57
+ transform::NamedSequenceOp &kernelConfigOp,
58
+ StringRef requiredByDefaultAttrMessage) {
59
+ int numTuningEntryPoints = 0 ;
58
60
for (Operation &op : moduleOp.getBody ()->getOperations ()) {
59
61
if (auto namedSeqOp = dyn_cast<transform::NamedSequenceOp>(&op)) {
60
62
if (namedSeqOp.getName () == kKernelConfigSpecName ) {
@@ -70,7 +72,7 @@ static LogicalResult validateTuningEntrypoint(
70
72
if (!kernelConfigOp) {
71
73
return moduleOp->emitError ()
72
74
<< " Missing named sequence '" << kKernelConfigSpecName << " '"
73
- << kRequiredByDefaultAttr ;
75
+ << requiredByDefaultAttrMessage ;
74
76
}
75
77
76
78
// Verify that the kernelConfigOp has the attribute
@@ -79,73 +81,63 @@ static LogicalResult validateTuningEntrypoint(
79
81
return kernelConfigOp.emitError ()
80
82
<< " Missing attribute '" << kTuningSpecEntrypointAttrName
81
83
<< " ' in named sequence '" << kKernelConfigSpecName << " '"
82
- << kRequiredByDefaultAttr ;
84
+ << requiredByDefaultAttrMessage ;
83
85
}
84
86
85
87
if (numTuningEntryPoints != 1 ) {
86
88
return moduleOp.emitError ()
87
89
<< " Expected one named sequence with '"
88
90
<< kTuningSpecEntrypointAttrName << " ', but found "
89
- << numTuningEntryPoints << kRequiredByDefaultAttr ;
91
+ << numTuningEntryPoints << requiredByDefaultAttrMessage ;
90
92
}
91
93
92
94
return success ();
93
95
}
94
96
95
97
LogicalResult
96
98
validateKernelConfigContents (transform::NamedSequenceOp kernelConfigOp,
97
- const std::string &kRequiredByDefaultAttr ) {
98
- transform::ForeachMatchOp foreachMatchOp;
99
- bool hasYieldOp = false ;
100
- for (Block &block : kernelConfigOp.getBlocks ()) {
101
- for (Operation &op : block) {
102
- if (auto foreachOp = dyn_cast<transform::ForeachMatchOp>(op)) {
103
- if (foreachMatchOp) {
104
- return kernelConfigOp.emitError ()
105
- << " '" << kKernelConfigSpecName
106
- << " ' must contain exactly one 'ForeachMatchOp'"
107
- << kRequiredByDefaultAttr ;
108
- }
109
- foreachMatchOp = foreachOp;
110
- } else if (isa<transform::YieldOp>(op)) {
111
- if (hasYieldOp) {
112
- return kernelConfigOp.emitError ()
113
- << " '" << kKernelConfigSpecName
114
- << " ' must contain exactly one 'transform::YieldOp'"
115
- << kRequiredByDefaultAttr ;
116
- }
117
- hasYieldOp = true ;
118
- } else {
119
- return kernelConfigOp.emitError ()
120
- << " Unexpected op '" << op.getName () << " ' in '"
121
- << kKernelConfigSpecName << " '" << kRequiredByDefaultAttr ;
122
- }
123
- }
99
+ StringRef requiredByDefaultAttrMessage) {
100
+ // Ensure there is exactly one block.
101
+ if (kernelConfigOp.getBlocks ().size () != 1 ) {
102
+ return kernelConfigOp.emitError ()
103
+ << " '" << kKernelConfigSpecName << " ' must contain exactly one block"
104
+ << requiredByDefaultAttrMessage;
124
105
}
125
106
126
- // Ensure both required operations are present.
107
+ Block &block = kernelConfigOp.getBlocks ().front ();
108
+ // Ensure there are exactly two operations.
109
+ if (block.getOperations ().size () != 2 ) {
110
+ return kernelConfigOp.emitError () << " '" << kKernelConfigSpecName
111
+ << " ' must contain exactly two operations"
112
+ << requiredByDefaultAttrMessage;
113
+ }
114
+
115
+ auto opIt = block.begin ();
116
+ auto foreachMatchOp = dyn_cast<transform::ForeachMatchOp>(&*opIt);
127
117
if (!foreachMatchOp) {
128
- return kernelConfigOp.emitError ()
129
- << " Missing 'ForeachMatchOp' in ' " << kKernelConfigSpecName << " '"
130
- << kRequiredByDefaultAttr ;
118
+ return kernelConfigOp.emitError () << " ' " << kKernelConfigSpecName
119
+ << " ' must start with 'ForeachMatchOp '"
120
+ << requiredByDefaultAttrMessage ;
131
121
}
132
122
133
- if (!hasYieldOp) {
134
- return kernelConfigOp.emitError ()
135
- << " Missing 'transform::YieldOp' in '" << kKernelConfigSpecName
136
- << " '" << kRequiredByDefaultAttr ;
123
+ ++opIt;
124
+ auto yieldOp = dyn_cast<transform::YieldOp>(&*opIt);
125
+ if (!yieldOp) {
126
+ return kernelConfigOp.emitError () << " '" << kKernelConfigSpecName
127
+ << " ' must end with 'transform::YieldOp'"
128
+ << requiredByDefaultAttrMessage;
137
129
}
138
130
139
131
if (foreachMatchOp.getRestrictRootAttr ()) {
140
132
return foreachMatchOp.emitError ()
141
133
<< " 'ForeachMatchOp' must not have 'restrict_root' attribute"
142
- << kRequiredByDefaultAttr ;
134
+ << requiredByDefaultAttrMessage ;
143
135
}
144
136
145
137
if (foreachMatchOp.getFlattenResultsAttr ()) {
146
138
return foreachMatchOp.emitError ()
147
139
<< " 'ForeachMatchOp' must not have 'flatten_results' attribute"
148
- << kRequiredByDefaultAttr ;
140
+ << requiredByDefaultAttrMessage ;
149
141
}
150
142
151
143
Type anyOpType = transform::AnyOpType::get (kernelConfigOp.getContext ());
@@ -154,15 +146,15 @@ validateKernelConfigContents(transform::NamedSequenceOp kernelConfigOp,
154
146
if (argTypes.size () != 1 || argTypes.front () != anyOpType) {
155
147
return foreachMatchOp.emitError ()
156
148
<< " 'ForeachMatchOp' must take exactly one 'any_op' argument"
157
- << kRequiredByDefaultAttr ;
149
+ << requiredByDefaultAttrMessage ;
158
150
}
159
151
160
152
SmallVector<Type> resultTypes (foreachMatchOp.getResultTypes ());
161
153
// Ensure the operation has exactly one result of type any_op.
162
154
if (resultTypes.size () != 1 || resultTypes.front () != anyOpType) {
163
155
return foreachMatchOp.emitError ()
164
156
<< " 'ForeachMatchOp' must return exactly one 'any_op' result"
165
- << kRequiredByDefaultAttr ;
157
+ << requiredByDefaultAttrMessage ;
166
158
}
167
159
168
160
return success ();
@@ -199,20 +191,18 @@ IREECodegenDialect::verifyOperationAttribute(Operation *op,
199
191
// of type `transform::AnyOpType`.
200
192
201
193
if (symbol == kTuningSpecDefaultEntrypointAttrName ) {
202
- const std::string kRequiredByDefaultAttr =
194
+ const std::string requiredByDefaultAttrMessage =
203
195
" (required by '" + std::string (kTuningSpecDefaultEntrypointAttrName ) +
204
- " '). " ;
196
+ " ')" ;
205
197
if (auto moduleOp = dyn_cast<ModuleOp>(op)) {
206
198
transform::NamedSequenceOp kernelConfigOp;
207
- int numTuningEntryPoints = 0 ;
208
199
if (failed (validateTuningEntrypoint (moduleOp, kernelConfigOp,
209
- numTuningEntryPoints,
210
- kRequiredByDefaultAttr ))) {
200
+ requiredByDefaultAttrMessage))) {
211
201
return failure ();
212
202
}
213
203
214
204
if (failed (validateKernelConfigContents (kernelConfigOp,
215
- kRequiredByDefaultAttr ))) {
205
+ requiredByDefaultAttrMessage ))) {
216
206
return failure ();
217
207
}
218
208
}
@@ -221,8 +211,8 @@ IREECodegenDialect::verifyOperationAttribute(Operation *op,
221
211
if (symbol != kTuningSpecEntrypointAttrName )
222
212
return success ();
223
213
224
- const std::string kRequiredByEntrypoint =
225
- " (required by '" + std::string (kTuningSpecEntrypointAttrName ) + " '). " ;
214
+ const std::string requiredByEntrypointMessage =
215
+ " (required by '" + std::string (kTuningSpecEntrypointAttrName ) + " ')" ;
226
216
if (!isa<UnitAttr>(attr)) {
227
217
return op->emitError (" '" ) << symbol << " ' attribute must be a UnitAttr" ;
228
218
}
@@ -231,13 +221,13 @@ IREECodegenDialect::verifyOperationAttribute(Operation *op,
231
221
ArrayRef<Type> resTypes = namedSeqOp.getFunctionType ().getResults ();
232
222
if (resTypes.size () != 1 || !isa<transform::AnyOpType>(resTypes[0 ])) {
233
223
return namedSeqOp.emitError ()
234
- << " Must return one 'any_op'" << kRequiredByEntrypoint ;
224
+ << " Must return one 'any_op'" << requiredByEntrypointMessage ;
235
225
}
236
226
237
227
ArrayRef<Type> argTypes = namedSeqOp.getArgumentTypes ();
238
228
if (argTypes.size () != 1 || !isa<transform::AnyOpType>(argTypes[0 ])) {
239
229
return namedSeqOp.emitError ()
240
- << " Must take one 'any_op'" << kRequiredByEntrypoint ;
230
+ << " Must take one 'any_op'" << requiredByEntrypointMessage ;
241
231
}
242
232
}
243
233
0 commit comments