8
8
from fastapi import FastAPI
9
9
from gradio import Blocks
10
10
11
- from modules import errors , timer , extensions
11
+ from modules import errors , timer , extensions , shared
12
12
13
13
14
14
def report_exception (c , job ):
@@ -124,9 +124,10 @@ class ScriptCallback:
124
124
name : str = None
125
125
126
126
127
- def add_callback (callbacks , fun , * , name = None , category = 'unknown' ):
128
- stack = [x for x in inspect .stack () if x .filename != __file__ ]
129
- filename = stack [0 ].filename if stack else 'unknown file'
127
+ def add_callback (callbacks , fun , * , name = None , category = 'unknown' , filename = None ):
128
+ if filename is None :
129
+ stack = [x for x in inspect .stack () if x .filename != __file__ ]
130
+ filename = stack [0 ].filename if stack else 'unknown file'
130
131
131
132
extension = extensions .find_extension (filename )
132
133
extension_name = extension .canonical_name if extension else 'base'
@@ -146,6 +147,43 @@ def add_callback(callbacks, fun, *, name=None, category='unknown'):
146
147
callbacks .append (ScriptCallback (filename , fun , unique_callback_name ))
147
148
148
149
150
+ def sort_callbacks (category , unordered_callbacks , * , enable_user_sort = True ):
151
+ callbacks = unordered_callbacks .copy ()
152
+
153
+ if enable_user_sort :
154
+ for name in reversed (getattr (shared .opts , 'prioritized_callbacks_' + category , [])):
155
+ index = next ((i for i , callback in enumerate (callbacks ) if callback .name == name ), None )
156
+ if index is not None :
157
+ callbacks .insert (0 , callbacks .pop (index ))
158
+
159
+ return callbacks
160
+
161
+
162
+ def ordered_callbacks (category , unordered_callbacks = None , * , enable_user_sort = True ):
163
+ if unordered_callbacks is None :
164
+ unordered_callbacks = callback_map .get ('callbacks_' + category , [])
165
+
166
+ if not enable_user_sort :
167
+ return sort_callbacks (category , unordered_callbacks , enable_user_sort = False )
168
+
169
+ callbacks = ordered_callbacks_map .get (category )
170
+ if callbacks is not None and len (callbacks ) == len (unordered_callbacks ):
171
+ return callbacks
172
+
173
+ callbacks = sort_callbacks (category , unordered_callbacks )
174
+
175
+ ordered_callbacks_map [category ] = callbacks
176
+ return callbacks
177
+
178
+
179
+ def enumerate_callbacks ():
180
+ for category , callbacks in callback_map .items ():
181
+ if category .startswith ('callbacks_' ):
182
+ category = category [10 :]
183
+
184
+ yield category , callbacks
185
+
186
+
149
187
callback_map = dict (
150
188
callbacks_app_started = [],
151
189
callbacks_model_loaded = [],
@@ -170,14 +208,18 @@ def add_callback(callbacks, fun, *, name=None, category='unknown'):
170
208
callbacks_before_token_counter = [],
171
209
)
172
210
211
+ ordered_callbacks_map = {}
212
+
173
213
174
214
def clear_callbacks ():
175
215
for callback_list in callback_map .values ():
176
216
callback_list .clear ()
177
217
218
+ ordered_callbacks_map .clear ()
219
+
178
220
179
221
def app_started_callback (demo : Optional [Blocks ], app : FastAPI ):
180
- for c in callback_map [ 'callbacks_app_started' ] :
222
+ for c in ordered_callbacks ( 'app_started' ) :
181
223
try :
182
224
c .callback (demo , app )
183
225
timer .startup_timer .record (os .path .basename (c .script ))
@@ -186,15 +228,15 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
186
228
187
229
188
230
def app_reload_callback ():
189
- for c in callback_map [ 'callbacks_on_reload' ] :
231
+ for c in ordered_callbacks ( 'on_reload' ) :
190
232
try :
191
233
c .callback ()
192
234
except Exception :
193
235
report_exception (c , 'callbacks_on_reload' )
194
236
195
237
196
238
def model_loaded_callback (sd_model ):
197
- for c in callback_map [ 'callbacks_model_loaded' ] :
239
+ for c in ordered_callbacks ( 'model_loaded' ) :
198
240
try :
199
241
c .callback (sd_model )
200
242
except Exception :
@@ -204,7 +246,7 @@ def model_loaded_callback(sd_model):
204
246
def ui_tabs_callback ():
205
247
res = []
206
248
207
- for c in callback_map [ 'callbacks_ui_tabs' ] :
249
+ for c in ordered_callbacks ( 'ui_tabs' ) :
208
250
try :
209
251
res += c .callback () or []
210
252
except Exception :
@@ -214,111 +256,111 @@ def ui_tabs_callback():
214
256
215
257
216
258
def ui_train_tabs_callback (params : UiTrainTabParams ):
217
- for c in callback_map [ 'callbacks_ui_train_tabs' ] :
259
+ for c in ordered_callbacks ( 'ui_train_tabs' ) :
218
260
try :
219
261
c .callback (params )
220
262
except Exception :
221
263
report_exception (c , 'callbacks_ui_train_tabs' )
222
264
223
265
224
266
def ui_settings_callback ():
225
- for c in callback_map [ 'callbacks_ui_settings' ] :
267
+ for c in ordered_callbacks ( 'ui_settings' ) :
226
268
try :
227
269
c .callback ()
228
270
except Exception :
229
271
report_exception (c , 'ui_settings_callback' )
230
272
231
273
232
274
def before_image_saved_callback (params : ImageSaveParams ):
233
- for c in callback_map [ 'callbacks_before_image_saved' ] :
275
+ for c in ordered_callbacks ( 'before_image_saved' ) :
234
276
try :
235
277
c .callback (params )
236
278
except Exception :
237
279
report_exception (c , 'before_image_saved_callback' )
238
280
239
281
240
282
def image_saved_callback (params : ImageSaveParams ):
241
- for c in callback_map [ 'callbacks_image_saved' ] :
283
+ for c in ordered_callbacks ( 'image_saved' ) :
242
284
try :
243
285
c .callback (params )
244
286
except Exception :
245
287
report_exception (c , 'image_saved_callback' )
246
288
247
289
248
290
def extra_noise_callback (params : ExtraNoiseParams ):
249
- for c in callback_map [ 'callbacks_extra_noise' ] :
291
+ for c in ordered_callbacks ( 'extra_noise' ) :
250
292
try :
251
293
c .callback (params )
252
294
except Exception :
253
295
report_exception (c , 'callbacks_extra_noise' )
254
296
255
297
256
298
def cfg_denoiser_callback (params : CFGDenoiserParams ):
257
- for c in callback_map [ 'callbacks_cfg_denoiser' ] :
299
+ for c in ordered_callbacks ( 'cfg_denoiser' ) :
258
300
try :
259
301
c .callback (params )
260
302
except Exception :
261
303
report_exception (c , 'cfg_denoiser_callback' )
262
304
263
305
264
306
def cfg_denoised_callback (params : CFGDenoisedParams ):
265
- for c in callback_map [ 'callbacks_cfg_denoised' ] :
307
+ for c in ordered_callbacks ( 'cfg_denoised' ) :
266
308
try :
267
309
c .callback (params )
268
310
except Exception :
269
311
report_exception (c , 'cfg_denoised_callback' )
270
312
271
313
272
314
def cfg_after_cfg_callback (params : AfterCFGCallbackParams ):
273
- for c in callback_map [ 'callbacks_cfg_after_cfg' ] :
315
+ for c in ordered_callbacks ( 'cfg_after_cfg' ) :
274
316
try :
275
317
c .callback (params )
276
318
except Exception :
277
319
report_exception (c , 'cfg_after_cfg_callback' )
278
320
279
321
280
322
def before_component_callback (component , ** kwargs ):
281
- for c in callback_map [ 'callbacks_before_component' ] :
323
+ for c in ordered_callbacks ( 'before_component' ) :
282
324
try :
283
325
c .callback (component , ** kwargs )
284
326
except Exception :
285
327
report_exception (c , 'before_component_callback' )
286
328
287
329
288
330
def after_component_callback (component , ** kwargs ):
289
- for c in callback_map [ 'callbacks_after_component' ] :
331
+ for c in ordered_callbacks ( 'after_component' ) :
290
332
try :
291
333
c .callback (component , ** kwargs )
292
334
except Exception :
293
335
report_exception (c , 'after_component_callback' )
294
336
295
337
296
338
def image_grid_callback (params : ImageGridLoopParams ):
297
- for c in callback_map [ 'callbacks_image_grid' ] :
339
+ for c in ordered_callbacks ( 'image_grid' ) :
298
340
try :
299
341
c .callback (params )
300
342
except Exception :
301
343
report_exception (c , 'image_grid' )
302
344
303
345
304
346
def infotext_pasted_callback (infotext : str , params : dict [str , Any ]):
305
- for c in callback_map [ 'callbacks_infotext_pasted' ] :
347
+ for c in ordered_callbacks ( 'infotext_pasted' ) :
306
348
try :
307
349
c .callback (infotext , params )
308
350
except Exception :
309
351
report_exception (c , 'infotext_pasted' )
310
352
311
353
312
354
def script_unloaded_callback ():
313
- for c in reversed (callback_map [ 'callbacks_script_unloaded' ] ):
355
+ for c in reversed (ordered_callbacks ( 'script_unloaded' ) ):
314
356
try :
315
357
c .callback ()
316
358
except Exception :
317
359
report_exception (c , 'script_unloaded' )
318
360
319
361
320
362
def before_ui_callback ():
321
- for c in reversed (callback_map [ 'callbacks_before_ui' ] ):
363
+ for c in reversed (ordered_callbacks ( 'before_ui' ) ):
322
364
try :
323
365
c .callback ()
324
366
except Exception :
@@ -328,7 +370,7 @@ def before_ui_callback():
328
370
def list_optimizers_callback ():
329
371
res = []
330
372
331
- for c in callback_map [ 'callbacks_list_optimizers' ] :
373
+ for c in ordered_callbacks ( 'list_optimizers' ) :
332
374
try :
333
375
c .callback (res )
334
376
except Exception :
@@ -340,7 +382,7 @@ def list_optimizers_callback():
340
382
def list_unets_callback ():
341
383
res = []
342
384
343
- for c in callback_map [ 'callbacks_list_unets' ] :
385
+ for c in ordered_callbacks ( 'list_unets' ) :
344
386
try :
345
387
c .callback (res )
346
388
except Exception :
@@ -350,7 +392,7 @@ def list_unets_callback():
350
392
351
393
352
394
def before_token_counter_callback (params : BeforeTokenCounterParams ):
353
- for c in callback_map [ 'callbacks_before_token_counter' ] :
395
+ for c in ordered_callbacks ( 'before_token_counter' ) :
354
396
try :
355
397
c .callback (params )
356
398
except Exception :
0 commit comments