Skip to content

Commit b29ff50

Browse files
committed
format the verifier
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
1 parent 6f36c8e commit b29ff50

File tree

5 files changed

+67
-111
lines changed

5 files changed

+67
-111
lines changed

compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
8383
0, hasConsumedSequences ? kArgConsumedAttrName : kArgReadOnlyAttrName,
8484
builder.getUnitAttr());
8585
newSpec->setAttr(kTuningSpecEntrypointAttrName, builder.getUnitAttr());
86-
// TODO: re-enable default attribute as below once new linking lands.
86+
// TODO: Re-enable default attribute as below once new linking lands.
8787
// module->setAttr(kTuningSpecDefaultEntrypointAttrName,
8888
// builder.getUnitAttr());
8989

compiler/src/iree/compiler/Codegen/Common/test/materialize_tuning_specs.mlir

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
// SKIPLINK-LABEL: module @user_spec
3434
// SKIPLINK-SAME: iree_codegen.tuning_spec_with_default_entrypoint
3535
// SKIPLINK-SAME: transform.with_named_sequence
36+
// SKIPLINK: transform.print {name = "Hello Tuning Spec"}
3637
// SKIPLINK-NOT: module @{{.+}}
3738
// SKIPLINK: module attributes
3839
// SKIPLINK-SAME: iree_codegen.tuning_spec_mlirbc = dense<{{.+}}> : vector<{{[0-9]+}}xi8>

compiler/src/iree/compiler/Codegen/Common/test/tuning_spec_default.mlir

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
module @user_spec attributes { transform.with_named_sequence, iree_codegen.tuning_spec_with_default_entrypoint } {
44
transform.named_sequence @match(%arg: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
5+
transform.print {name = "Hello Tuning Spec"}
56
transform.yield %arg : !transform.any_op
67
}
78

compiler/src/iree/compiler/Codegen/Common/test/verify_tuning_specs.mlir

+19-55
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,22 @@ module @iree_default_tuning_spec attributes { iree_codegen.tuning_spec_with_defa
7979

8080
// -----
8181

82+
module @iree_default_tuning_spec attributes { iree_codegen.tuning_spec_with_default_entrypoint } {
83+
transform.named_sequence @match(%arg: !transform.any_op) -> (!transform.any_op) {
84+
transform.yield %arg : !transform.any_op
85+
}
86+
87+
transform.named_sequence @apply_op_config(%op: !transform.any_op) {
88+
transform.yield
89+
}
90+
91+
// expected-error @+1{{'__kernel_config' must contain exactly one block (required by 'iree_codegen.tuning_spec_with_default_entrypoint')}}
92+
transform.named_sequence @__kernel_config(%arg0: !transform.any_op)
93+
-> (!transform.any_op) attributes { iree_codegen.tuning_spec_entrypoint }
94+
}
95+
96+
// -----
97+
8298
module @iree_default_tuning_spec attributes { iree_codegen.tuning_spec_with_default_entrypoint } {
8399
transform.named_sequence @match(%arg: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
84100
transform.yield %arg : !transform.any_op
@@ -88,7 +104,7 @@ module @iree_default_tuning_spec attributes { iree_codegen.tuning_spec_with_defa
88104
transform.yield
89105
}
90106

91-
// expected-error @+1{{'__kernel_config' must contain exactly one 'ForeachMatchOp' (required by 'iree_codegen.tuning_spec_with_default_entrypoint')}}
107+
// expected-error @+1{{'__kernel_config' must contain exactly two operations (required by 'iree_codegen.tuning_spec_with_default_entrypoint')}}
92108
transform.named_sequence @__kernel_config(%arg0: !transform.any_op {transform.consumed})
93109
-> (!transform.any_op) attributes { iree_codegen.tuning_spec_entrypoint } {
94110

@@ -128,15 +144,11 @@ module @iree_default_tuning_spec attributes { iree_codegen.tuning_spec_with_defa
128144
transform.yield
129145
}
130146

131-
// expected-error @+1{{Unexpected op 'transform.print' in '__kernel_config' (required by 'iree_codegen.tuning_spec_with_default_entrypoint')}}
147+
// expected-error @+1{{'__kernel_config' must start with 'ForeachMatchOp' (required by 'iree_codegen.tuning_spec_with_default_entrypoint')}}
132148
transform.named_sequence @__kernel_config(%arg0: !transform.any_op {transform.consumed})
133149
-> (!transform.any_op) attributes { iree_codegen.tuning_spec_entrypoint } {
134-
%res = transform.foreach_match in %arg0
135-
@match -> @apply_op_config
136-
: (!transform.any_op) -> (!transform.any_op)
137-
138-
transform.yield %res : !transform.any_op
139150
transform.print {name = "Hello"}
151+
transform.yield %arg0 : !transform.any_op
140152
}
141153
}
142154

@@ -156,54 +168,6 @@ module @iree_default_tuning_spec attributes { iree_codegen.tuning_spec_with_defa
156168

157169
// -----
158170

159-
module @iree_default_tuning_spec attributes { iree_codegen.tuning_spec_with_default_entrypoint } {
160-
transform.named_sequence private @dummy_func(!transform.any_op {transform.consumed}) -> !transform.any_op
161-
transform.named_sequence @match(%arg: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
162-
transform.yield %arg : !transform.any_op
163-
}
164-
165-
transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly}) {
166-
transform.yield
167-
}
168-
169-
// expected-error @+1{{Unexpected op 'transform.include' in '__kernel_config' (required by 'iree_codegen.tuning_spec_with_default_entrypoint')}}
170-
transform.named_sequence @__kernel_config(%arg0: !transform.any_op {transform.consumed})
171-
-> (!transform.any_op) attributes { iree_codegen.tuning_spec_entrypoint } {
172-
%tmp = transform.include @dummy_func failures(suppress) (%arg0) : (!transform.any_op) -> (!transform.any_op)
173-
%res = transform.foreach_match in %tmp
174-
@match -> @apply_op_config
175-
: (!transform.any_op) -> (!transform.any_op)
176-
177-
transform.yield %res : !transform.any_op
178-
}
179-
}
180-
181-
// -----
182-
183-
module @iree_default_tuning_spec attributes { iree_codegen.tuning_spec_with_default_entrypoint } {
184-
transform.named_sequence private @dummy_func(!transform.any_op {transform.consumed}) -> !transform.any_op
185-
transform.named_sequence @match(%arg: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
186-
transform.yield %arg : !transform.any_op
187-
}
188-
189-
transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly}) {
190-
transform.yield
191-
}
192-
193-
// expected-error @+1{{Unexpected op 'transform.print' in '__kernel_config' (required by 'iree_codegen.tuning_spec_with_default_entrypoint')}}
194-
transform.named_sequence @__kernel_config(%arg0: !transform.any_op {transform.consumed})
195-
-> (!transform.any_op) attributes { iree_codegen.tuning_spec_entrypoint } {
196-
transform.print {name = "Hello"}
197-
%res = transform.foreach_match in
198-
%arg0 @match -> @apply_op_config
199-
: (!transform.any_op) -> (!transform.any_op)
200-
201-
transform.yield %res : !transform.any_op
202-
}
203-
}
204-
205-
// -----
206-
207171
module @iree_default_tuning_spec attributes { transform.with_named_sequence, iree_codegen.tuning_spec_with_default_entrypoint } {
208172
transform.named_sequence @match(%arg: !transform.any_op {transform.readonly})
209173
-> (!transform.any_op, !transform.any_op) {

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp

+45-55
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ void IREECodegenDialect::initialize() {
5252
addTypes<IREE::Codegen::NullPointerType>();
5353
}
5454

55-
static LogicalResult validateTuningEntrypoint(
56-
ModuleOp moduleOp, transform::NamedSequenceOp &kernelConfigOp,
57-
int &numTuningEntryPoints, const std::string &kRequiredByDefaultAttr) {
55+
static LogicalResult
56+
validateTuningEntrypoint(ModuleOp moduleOp,
57+
transform::NamedSequenceOp &kernelConfigOp,
58+
StringRef requiredByDefaultAttrMessage) {
59+
int numTuningEntryPoints = 0;
5860
for (Operation &op : moduleOp.getBody()->getOperations()) {
5961
if (auto namedSeqOp = dyn_cast<transform::NamedSequenceOp>(&op)) {
6062
if (namedSeqOp.getName() == kKernelConfigSpecName) {
@@ -70,7 +72,7 @@ static LogicalResult validateTuningEntrypoint(
7072
if (!kernelConfigOp) {
7173
return moduleOp->emitError()
7274
<< "Missing named sequence '" << kKernelConfigSpecName << "'"
73-
<< kRequiredByDefaultAttr;
75+
<< requiredByDefaultAttrMessage;
7476
}
7577

7678
// Verify that the kernelConfigOp has the attribute
@@ -79,73 +81,63 @@ static LogicalResult validateTuningEntrypoint(
7981
return kernelConfigOp.emitError()
8082
<< "Missing attribute '" << kTuningSpecEntrypointAttrName
8183
<< "' in named sequence '" << kKernelConfigSpecName << "'"
82-
<< kRequiredByDefaultAttr;
84+
<< requiredByDefaultAttrMessage;
8385
}
8486

8587
if (numTuningEntryPoints != 1) {
8688
return moduleOp.emitError()
8789
<< "Expected one named sequence with '"
8890
<< kTuningSpecEntrypointAttrName << "', but found "
89-
<< numTuningEntryPoints << kRequiredByDefaultAttr;
91+
<< numTuningEntryPoints << requiredByDefaultAttrMessage;
9092
}
9193

9294
return success();
9395
}
9496

9597
LogicalResult
9698
validateKernelConfigContents(transform::NamedSequenceOp kernelConfigOp,
97-
const std::string &kRequiredByDefaultAttr) {
98-
transform::ForeachMatchOp foreachMatchOp;
99-
bool hasYieldOp = false;
100-
for (Block &block : kernelConfigOp.getBlocks()) {
101-
for (Operation &op : block) {
102-
if (auto foreachOp = dyn_cast<transform::ForeachMatchOp>(op)) {
103-
if (foreachMatchOp) {
104-
return kernelConfigOp.emitError()
105-
<< "'" << kKernelConfigSpecName
106-
<< "' must contain exactly one 'ForeachMatchOp'"
107-
<< kRequiredByDefaultAttr;
108-
}
109-
foreachMatchOp = foreachOp;
110-
} else if (isa<transform::YieldOp>(op)) {
111-
if (hasYieldOp) {
112-
return kernelConfigOp.emitError()
113-
<< "'" << kKernelConfigSpecName
114-
<< "' must contain exactly one 'transform::YieldOp'"
115-
<< kRequiredByDefaultAttr;
116-
}
117-
hasYieldOp = true;
118-
} else {
119-
return kernelConfigOp.emitError()
120-
<< "Unexpected op '" << op.getName() << "' in '"
121-
<< kKernelConfigSpecName << "'" << kRequiredByDefaultAttr;
122-
}
123-
}
99+
StringRef requiredByDefaultAttrMessage) {
100+
// Ensure there is exactly one block.
101+
if (kernelConfigOp.getBlocks().size() != 1) {
102+
return kernelConfigOp.emitError()
103+
<< "'" << kKernelConfigSpecName << "' must contain exactly one block"
104+
<< requiredByDefaultAttrMessage;
124105
}
125106

126-
// Ensure both required operations are present.
107+
Block &block = kernelConfigOp.getBlocks().front();
108+
// Ensure there are exactly two operations.
109+
if (block.getOperations().size() != 2) {
110+
return kernelConfigOp.emitError() << "'" << kKernelConfigSpecName
111+
<< "' must contain exactly two operations"
112+
<< requiredByDefaultAttrMessage;
113+
}
114+
115+
auto opIt = block.begin();
116+
auto foreachMatchOp = dyn_cast<transform::ForeachMatchOp>(&*opIt);
127117
if (!foreachMatchOp) {
128-
return kernelConfigOp.emitError()
129-
<< "Missing 'ForeachMatchOp' in '" << kKernelConfigSpecName << "'"
130-
<< kRequiredByDefaultAttr;
118+
return kernelConfigOp.emitError() << "'" << kKernelConfigSpecName
119+
<< "' must start with 'ForeachMatchOp'"
120+
<< requiredByDefaultAttrMessage;
131121
}
132122

133-
if (!hasYieldOp) {
134-
return kernelConfigOp.emitError()
135-
<< "Missing 'transform::YieldOp' in '" << kKernelConfigSpecName
136-
<< "'" << kRequiredByDefaultAttr;
123+
++opIt;
124+
auto yieldOp = dyn_cast<transform::YieldOp>(&*opIt);
125+
if (!yieldOp) {
126+
return kernelConfigOp.emitError() << "'" << kKernelConfigSpecName
127+
<< "' must end with 'transform::YieldOp'"
128+
<< requiredByDefaultAttrMessage;
137129
}
138130

139131
if (foreachMatchOp.getRestrictRootAttr()) {
140132
return foreachMatchOp.emitError()
141133
<< "'ForeachMatchOp' must not have 'restrict_root' attribute"
142-
<< kRequiredByDefaultAttr;
134+
<< requiredByDefaultAttrMessage;
143135
}
144136

145137
if (foreachMatchOp.getFlattenResultsAttr()) {
146138
return foreachMatchOp.emitError()
147139
<< "'ForeachMatchOp' must not have 'flatten_results' attribute"
148-
<< kRequiredByDefaultAttr;
140+
<< requiredByDefaultAttrMessage;
149141
}
150142

151143
Type anyOpType = transform::AnyOpType::get(kernelConfigOp.getContext());
@@ -154,15 +146,15 @@ validateKernelConfigContents(transform::NamedSequenceOp kernelConfigOp,
154146
if (argTypes.size() != 1 || argTypes.front() != anyOpType) {
155147
return foreachMatchOp.emitError()
156148
<< "'ForeachMatchOp' must take exactly one 'any_op' argument"
157-
<< kRequiredByDefaultAttr;
149+
<< requiredByDefaultAttrMessage;
158150
}
159151

160152
SmallVector<Type> resultTypes(foreachMatchOp.getResultTypes());
161153
// Ensure the operation has exactly one result of type any_op.
162154
if (resultTypes.size() != 1 || resultTypes.front() != anyOpType) {
163155
return foreachMatchOp.emitError()
164156
<< "'ForeachMatchOp' must return exactly one 'any_op' result"
165-
<< kRequiredByDefaultAttr;
157+
<< requiredByDefaultAttrMessage;
166158
}
167159

168160
return success();
@@ -199,20 +191,18 @@ IREECodegenDialect::verifyOperationAttribute(Operation *op,
199191
// of type `transform::AnyOpType`.
200192

201193
if (symbol == kTuningSpecDefaultEntrypointAttrName) {
202-
const std::string kRequiredByDefaultAttr =
194+
const std::string requiredByDefaultAttrMessage =
203195
" (required by '" + std::string(kTuningSpecDefaultEntrypointAttrName) +
204-
"').";
196+
"')";
205197
if (auto moduleOp = dyn_cast<ModuleOp>(op)) {
206198
transform::NamedSequenceOp kernelConfigOp;
207-
int numTuningEntryPoints = 0;
208199
if (failed(validateTuningEntrypoint(moduleOp, kernelConfigOp,
209-
numTuningEntryPoints,
210-
kRequiredByDefaultAttr))) {
200+
requiredByDefaultAttrMessage))) {
211201
return failure();
212202
}
213203

214204
if (failed(validateKernelConfigContents(kernelConfigOp,
215-
kRequiredByDefaultAttr))) {
205+
requiredByDefaultAttrMessage))) {
216206
return failure();
217207
}
218208
}
@@ -221,8 +211,8 @@ IREECodegenDialect::verifyOperationAttribute(Operation *op,
221211
if (symbol != kTuningSpecEntrypointAttrName)
222212
return success();
223213

224-
const std::string kRequiredByEntrypoint =
225-
" (required by '" + std::string(kTuningSpecEntrypointAttrName) + "').";
214+
const std::string requiredByEntrypointMessage =
215+
" (required by '" + std::string(kTuningSpecEntrypointAttrName) + "')";
226216
if (!isa<UnitAttr>(attr)) {
227217
return op->emitError("'") << symbol << "' attribute must be a UnitAttr";
228218
}
@@ -231,13 +221,13 @@ IREECodegenDialect::verifyOperationAttribute(Operation *op,
231221
ArrayRef<Type> resTypes = namedSeqOp.getFunctionType().getResults();
232222
if (resTypes.size() != 1 || !isa<transform::AnyOpType>(resTypes[0])) {
233223
return namedSeqOp.emitError()
234-
<< "Must return one 'any_op'" << kRequiredByEntrypoint;
224+
<< "Must return one 'any_op'" << requiredByEntrypointMessage;
235225
}
236226

237227
ArrayRef<Type> argTypes = namedSeqOp.getArgumentTypes();
238228
if (argTypes.size() != 1 || !isa<transform::AnyOpType>(argTypes[0])) {
239229
return namedSeqOp.emitError()
240-
<< "Must take one 'any_op'" << kRequiredByEntrypoint;
230+
<< "Must take one 'any_op'" << requiredByEntrypointMessage;
241231
}
242232
}
243233

0 commit comments

Comments
 (0)