Skip to content

Commit 9f56ec5

Browse files
varshith15eliteprox
authored andcommitted
fix: cleanup
1 parent c87f404 commit 9f56ec5

File tree

3 files changed

+46
-11
lines changed

3 files changed

+46
-11
lines changed

server/app.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,12 @@ def __init__(self, track: MediaStreamTrack, pipeline):
3838

3939
async def collect_frames(self):
4040
while True:
41-
frame = await self.track.recv()
42-
await self.pipeline.put_video_frame(frame)
41+
try:
42+
frame = await self.track.recv()
43+
await self.pipeline.put_video_frame(frame)
44+
except Exception as e:
45+
await self.pipeline.cleanup()
46+
raise Exception(f"Error collecting video frames: {str(e)}")
4347

4448
async def recv(self):
4549
return await self.pipeline.get_processed_video_frame()
@@ -55,8 +59,12 @@ def __init__(self, track: MediaStreamTrack, pipeline):
5559

5660
async def collect_frames(self):
5761
while True:
58-
frame = await self.track.recv()
59-
await self.pipeline.put_audio_frame(frame)
62+
try:
63+
frame = await self.track.recv()
64+
await self.pipeline.put_audio_frame(frame)
65+
except Exception as e:
66+
await self.pipeline.cleanup()
67+
raise Exception(f"Error collecting audio frames: {str(e)}")
6068

6169
async def recv(self):
6270
return await self.pipeline.get_processed_audio_frame()

server/pipeline.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,7 @@ async def get_processed_audio_frame(self):
106106
async def get_nodes_info(self) -> Dict[str, Any]:
107107
"""Get information about all nodes in the current prompt including metadata."""
108108
nodes_info = await self.client.get_available_nodes()
109-
return nodes_info
109+
return nodes_info
110+
111+
async def cleanup(self):
112+
await self.client.cleanup()

src/comfystream/client.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(self, max_workers: int = 1, **kwargs):
1818
self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers)
1919
self.running_prompts = {} # To be used for cancelling tasks
2020
self.current_prompts = []
21+
self.cleanup_lock = asyncio.Lock()
2122

2223
async def set_prompts(self, prompts: List[PromptDictInput]):
2324
self.current_prompts = [convert_prompt(prompt) for prompt in prompts]
@@ -38,16 +39,39 @@ async def run_prompt(self, prompt_index: int):
3839
try:
3940
await self.comfy_client.queue_prompt(self.current_prompts[prompt_index])
4041
except Exception as e:
42+
await self.cleanup()
4143
logger.error(f"Error running prompt: {str(e)}")
42-
logger.error(f"Error type: {type(e)}")
4344
raise
4445

4546
async def cleanup(self):
46-
for task in self.running_prompts.values():
47-
await task.cancel()
48-
49-
if self.comfy_client.is_running:
50-
await self.comfy_client.__aexit__()
47+
async with self.cleanup_lock:
48+
for task in self.running_prompts.values():
49+
task.cancel()
50+
try:
51+
await task
52+
except asyncio.CancelledError:
53+
pass
54+
self.running_prompts.clear()
55+
56+
if self.comfy_client.is_running:
57+
await self.comfy_client.__aexit__()
58+
59+
await self.cleanup_queues()
60+
logger.info("Client cleanup complete")
61+
62+
63+
async def cleanup_queues(self):
64+
while not tensor_cache.image_inputs.empty():
65+
tensor_cache.image_inputs.get()
66+
67+
while not tensor_cache.audio_inputs.empty():
68+
tensor_cache.audio_inputs.get()
69+
70+
while not tensor_cache.image_outputs.empty():
71+
await tensor_cache.image_outputs.get()
72+
73+
while not tensor_cache.audio_outputs.empty():
74+
await tensor_cache.audio_outputs.get()
5175

5276
def put_video_input(self, frame):
5377
if tensor_cache.image_inputs.full():

0 commit comments

Comments
 (0)