Skip to content

Commit 7e5e673

Browse files
committed
add UI for reordering callbacks
1 parent 0411ece commit 7e5e673

File tree

5 files changed

+192
-44
lines changed

5 files changed

+192
-44
lines changed

modules/script_callbacks.py

+67-25
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from fastapi import FastAPI
99
from gradio import Blocks
1010

11-
from modules import errors, timer, extensions
11+
from modules import errors, timer, extensions, shared
1212

1313

1414
def report_exception(c, job):
@@ -124,9 +124,10 @@ class ScriptCallback:
124124
name: str = None
125125

126126

127-
def add_callback(callbacks, fun, *, name=None, category='unknown'):
128-
stack = [x for x in inspect.stack() if x.filename != __file__]
129-
filename = stack[0].filename if stack else 'unknown file'
127+
def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None):
128+
if filename is None:
129+
stack = [x for x in inspect.stack() if x.filename != __file__]
130+
filename = stack[0].filename if stack else 'unknown file'
130131

131132
extension = extensions.find_extension(filename)
132133
extension_name = extension.canonical_name if extension else 'base'
@@ -146,6 +147,43 @@ def add_callback(callbacks, fun, *, name=None, category='unknown'):
146147
callbacks.append(ScriptCallback(filename, fun, unique_callback_name))
147148

148149

150+
def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
151+
callbacks = unordered_callbacks.copy()
152+
153+
if enable_user_sort:
154+
for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
155+
index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None)
156+
if index is not None:
157+
callbacks.insert(0, callbacks.pop(index))
158+
159+
return callbacks
160+
161+
162+
def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True):
163+
if unordered_callbacks is None:
164+
unordered_callbacks = callback_map.get('callbacks_' + category, [])
165+
166+
if not enable_user_sort:
167+
return sort_callbacks(category, unordered_callbacks, enable_user_sort=False)
168+
169+
callbacks = ordered_callbacks_map.get(category)
170+
if callbacks is not None and len(callbacks) == len(unordered_callbacks):
171+
return callbacks
172+
173+
callbacks = sort_callbacks(category, unordered_callbacks)
174+
175+
ordered_callbacks_map[category] = callbacks
176+
return callbacks
177+
178+
179+
def enumerate_callbacks():
180+
for category, callbacks in callback_map.items():
181+
if category.startswith('callbacks_'):
182+
category = category[10:]
183+
184+
yield category, callbacks
185+
186+
149187
callback_map = dict(
150188
callbacks_app_started=[],
151189
callbacks_model_loaded=[],
@@ -170,14 +208,18 @@ def add_callback(callbacks, fun, *, name=None, category='unknown'):
170208
callbacks_before_token_counter=[],
171209
)
172210

211+
ordered_callbacks_map = {}
212+
173213

174214
def clear_callbacks():
175215
for callback_list in callback_map.values():
176216
callback_list.clear()
177217

218+
ordered_callbacks_map.clear()
219+
178220

179221
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
180-
for c in callback_map['callbacks_app_started']:
222+
for c in ordered_callbacks('app_started'):
181223
try:
182224
c.callback(demo, app)
183225
timer.startup_timer.record(os.path.basename(c.script))
@@ -186,15 +228,15 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
186228

187229

188230
def app_reload_callback():
189-
for c in callback_map['callbacks_on_reload']:
231+
for c in ordered_callbacks('on_reload'):
190232
try:
191233
c.callback()
192234
except Exception:
193235
report_exception(c, 'callbacks_on_reload')
194236

195237

196238
def model_loaded_callback(sd_model):
197-
for c in callback_map['callbacks_model_loaded']:
239+
for c in ordered_callbacks('model_loaded'):
198240
try:
199241
c.callback(sd_model)
200242
except Exception:
@@ -204,7 +246,7 @@ def model_loaded_callback(sd_model):
204246
def ui_tabs_callback():
205247
res = []
206248

207-
for c in callback_map['callbacks_ui_tabs']:
249+
for c in ordered_callbacks('ui_tabs'):
208250
try:
209251
res += c.callback() or []
210252
except Exception:
@@ -214,111 +256,111 @@ def ui_tabs_callback():
214256

215257

216258
def ui_train_tabs_callback(params: UiTrainTabParams):
217-
for c in callback_map['callbacks_ui_train_tabs']:
259+
for c in ordered_callbacks('ui_train_tabs'):
218260
try:
219261
c.callback(params)
220262
except Exception:
221263
report_exception(c, 'callbacks_ui_train_tabs')
222264

223265

224266
def ui_settings_callback():
225-
for c in callback_map['callbacks_ui_settings']:
267+
for c in ordered_callbacks('ui_settings'):
226268
try:
227269
c.callback()
228270
except Exception:
229271
report_exception(c, 'ui_settings_callback')
230272

231273

232274
def before_image_saved_callback(params: ImageSaveParams):
233-
for c in callback_map['callbacks_before_image_saved']:
275+
for c in ordered_callbacks('before_image_saved'):
234276
try:
235277
c.callback(params)
236278
except Exception:
237279
report_exception(c, 'before_image_saved_callback')
238280

239281

240282
def image_saved_callback(params: ImageSaveParams):
241-
for c in callback_map['callbacks_image_saved']:
283+
for c in ordered_callbacks('image_saved'):
242284
try:
243285
c.callback(params)
244286
except Exception:
245287
report_exception(c, 'image_saved_callback')
246288

247289

248290
def extra_noise_callback(params: ExtraNoiseParams):
249-
for c in callback_map['callbacks_extra_noise']:
291+
for c in ordered_callbacks('extra_noise'):
250292
try:
251293
c.callback(params)
252294
except Exception:
253295
report_exception(c, 'callbacks_extra_noise')
254296

255297

256298
def cfg_denoiser_callback(params: CFGDenoiserParams):
257-
for c in callback_map['callbacks_cfg_denoiser']:
299+
for c in ordered_callbacks('cfg_denoiser'):
258300
try:
259301
c.callback(params)
260302
except Exception:
261303
report_exception(c, 'cfg_denoiser_callback')
262304

263305

264306
def cfg_denoised_callback(params: CFGDenoisedParams):
265-
for c in callback_map['callbacks_cfg_denoised']:
307+
for c in ordered_callbacks('cfg_denoised'):
266308
try:
267309
c.callback(params)
268310
except Exception:
269311
report_exception(c, 'cfg_denoised_callback')
270312

271313

272314
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
273-
for c in callback_map['callbacks_cfg_after_cfg']:
315+
for c in ordered_callbacks('cfg_after_cfg'):
274316
try:
275317
c.callback(params)
276318
except Exception:
277319
report_exception(c, 'cfg_after_cfg_callback')
278320

279321

280322
def before_component_callback(component, **kwargs):
281-
for c in callback_map['callbacks_before_component']:
323+
for c in ordered_callbacks('before_component'):
282324
try:
283325
c.callback(component, **kwargs)
284326
except Exception:
285327
report_exception(c, 'before_component_callback')
286328

287329

288330
def after_component_callback(component, **kwargs):
289-
for c in callback_map['callbacks_after_component']:
331+
for c in ordered_callbacks('after_component'):
290332
try:
291333
c.callback(component, **kwargs)
292334
except Exception:
293335
report_exception(c, 'after_component_callback')
294336

295337

296338
def image_grid_callback(params: ImageGridLoopParams):
297-
for c in callback_map['callbacks_image_grid']:
339+
for c in ordered_callbacks('image_grid'):
298340
try:
299341
c.callback(params)
300342
except Exception:
301343
report_exception(c, 'image_grid')
302344

303345

304346
def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
305-
for c in callback_map['callbacks_infotext_pasted']:
347+
for c in ordered_callbacks('infotext_pasted'):
306348
try:
307349
c.callback(infotext, params)
308350
except Exception:
309351
report_exception(c, 'infotext_pasted')
310352

311353

312354
def script_unloaded_callback():
313-
for c in reversed(callback_map['callbacks_script_unloaded']):
355+
for c in reversed(ordered_callbacks('script_unloaded')):
314356
try:
315357
c.callback()
316358
except Exception:
317359
report_exception(c, 'script_unloaded')
318360

319361

320362
def before_ui_callback():
321-
for c in reversed(callback_map['callbacks_before_ui']):
363+
for c in reversed(ordered_callbacks('before_ui')):
322364
try:
323365
c.callback()
324366
except Exception:
@@ -328,7 +370,7 @@ def before_ui_callback():
328370
def list_optimizers_callback():
329371
res = []
330372

331-
for c in callback_map['callbacks_list_optimizers']:
373+
for c in ordered_callbacks('list_optimizers'):
332374
try:
333375
c.callback(res)
334376
except Exception:
@@ -340,7 +382,7 @@ def list_optimizers_callback():
340382
def list_unets_callback():
341383
res = []
342384

343-
for c in callback_map['callbacks_list_unets']:
385+
for c in ordered_callbacks('list_unets'):
344386
try:
345387
c.callback(res)
346388
except Exception:
@@ -350,7 +392,7 @@ def list_unets_callback():
350392

351393

352394
def before_token_counter_callback(params: BeforeTokenCounterParams):
353-
for c in callback_map['callbacks_before_token_counter']:
395+
for c in ordered_callbacks('before_token_counter'):
354396
try:
355397
c.callback(params)
356398
except Exception:

0 commit comments

Comments
 (0)