@@ -18,6 +18,7 @@ def __init__(self, max_workers: int = 1, **kwargs):
18
18
self .comfy_client = EmbeddedComfyClient (config , max_workers = max_workers )
19
19
self .running_prompts = {} # To be used for cancelling tasks
20
20
self .current_prompts = []
21
+ self .cleanup_lock = asyncio .Lock ()
21
22
22
23
async def set_prompts (self , prompts : List [PromptDictInput ]):
23
24
self .current_prompts = [convert_prompt (prompt ) for prompt in prompts ]
@@ -38,16 +39,39 @@ async def run_prompt(self, prompt_index: int):
38
39
try :
39
40
await self .comfy_client .queue_prompt (self .current_prompts [prompt_index ])
40
41
except Exception as e :
42
+ await self .cleanup ()
41
43
logger .error (f"Error running prompt: { str (e )} " )
42
- logger .error (f"Error type: { type (e )} " )
43
44
raise
44
45
45
46
async def cleanup (self ):
46
- for task in self .running_prompts .values ():
47
- await task .cancel ()
48
-
49
- if self .comfy_client .is_running :
50
- await self .comfy_client .__aexit__ ()
47
+ async with self .cleanup_lock :
48
+ for task in self .running_prompts .values ():
49
+ task .cancel ()
50
+ try :
51
+ await task
52
+ except asyncio .CancelledError :
53
+ pass
54
+ self .running_prompts .clear ()
55
+
56
+ if self .comfy_client .is_running :
57
+ await self .comfy_client .__aexit__ ()
58
+
59
+ await self .cleanup_queues ()
60
+ logger .info ("Client cleanup complete" )
61
+
62
+
63
+ async def cleanup_queues (self ):
64
+ while not tensor_cache .image_inputs .empty ():
65
+ tensor_cache .image_inputs .get ()
66
+
67
+ while not tensor_cache .audio_inputs .empty ():
68
+ tensor_cache .audio_inputs .get ()
69
+
70
+ while not tensor_cache .image_outputs .empty ():
71
+ await tensor_cache .image_outputs .get ()
72
+
73
+ while not tensor_cache .audio_outputs .empty ():
74
+ await tensor_cache .audio_outputs .get ()
51
75
52
76
def put_video_input (self , frame ):
53
77
if tensor_cache .image_inputs .full ():
0 commit comments