@@ -197,14 +197,30 @@ struct MaterializeTuningSpecsPass final
197
197
return ;
198
198
}
199
199
200
- // If only the default tuning spec is available, use it directly and skip
201
- // the linking stage.
202
- if (!hasUserTuningSpec) {
203
- if (failed (dumpFinalTuningSpecToDir (*defaultTuningSpec))) {
200
+ // When the user tuning spec and default spec are available, link all
201
+ // available libraries into a single module. We insert the default tuning
202
+ // spec last, so that any user-specified tuning configurations take
203
+ // precedence.
204
+ SmallVector<ModuleOp, 2 > allSpecs;
205
+ if (hasUserTuningSpec) {
206
+ allSpecs.push_back (*userTuningSpec);
207
+ }
208
+ if (hasDefaultTuningSpec) {
209
+ allSpecs.push_back (*defaultTuningSpec);
210
+ }
211
+
212
+ // Determine if the linking pass should be skipped.
213
+ // Skip if there is only one tuning spec (either user-provided or default)
214
+ // with the default attribute.
215
+ if (allSpecs.size () == 1 &&
216
+ allSpecs[0 ]->hasAttr (kTuningSpecDefaultEntrypointAttrName )) {
217
+ // Use the appropriate tuning spec (user or default).
218
+ ModuleOp tuningSpecWithDefaultAttr = allSpecs[0 ];
219
+ if (failed (dumpFinalTuningSpecToDir (tuningSpecWithDefaultAttr))) {
204
220
return signalPassFailure ();
205
221
}
206
222
FailureOr<DenseElementsAttr> serializedSpec =
207
- serializeTuningSpecToAttr (*defaultTuningSpec );
223
+ serializeTuningSpecToAttr (tuningSpecWithDefaultAttr );
208
224
if (failed (serializedSpec)) {
209
225
module->emitError (" Failed to serialize default tuning specs" );
210
226
return signalPassFailure ();
@@ -213,14 +229,6 @@ struct MaterializeTuningSpecsPass final
213
229
return ;
214
230
}
215
231
216
- // When the user tuning spec is available, link all available libraries into
217
- // a single module. We insert the default tuning spec last, so that any
218
- // user-specified tuning configurations take precedence.
219
- SmallVector<ModuleOp, 2 > allSpecs = {*userTuningSpec};
220
- if (hasDefaultTuningSpec) {
221
- allSpecs.push_back (*defaultTuningSpec);
222
- }
223
-
224
232
Location loc =
225
233
FusedLoc::get (ctx, llvm::map_to_vector<2 >(allSpecs, [](ModuleOp m) {
226
234
return m.getLoc ();
0 commit comments