 from typing import Any, Dict, Union, List
 from comfystream.client import ComfyStreamClient

-WARMUP_RUNS = 10
+WARMUP_RUNS = 5


 class Pipeline:
     def __init__(self, **kwargs):
         self.client = ComfyStreamClient(**kwargs, max_workers=5)  # TODO: hardcoded max workers, should it be configurable?

-        self.video_futures = asyncio.Queue()
-        self.audio_futures = asyncio.Queue()
+        self.video_incoming_frames = asyncio.Queue()
+        self.audio_incoming_frames = asyncio.Queue()
+
+        self.processed_audio_buffer = np.array([], dtype=np.int16)

     async def warm_video(self):
         dummy_video_inp = torch.randn(1, 512, 512, 3)

         for _ in range(WARMUP_RUNS):
-            image_out_fut = self.client.put_video_input(dummy_video_inp)
-            await image_out_fut
+            self.client.put_video_input(dummy_video_inp)
+            await self.client.get_video_output()

     async def warm_audio(self):
-        dummy_audio_inp = np.random.randint(-32768, 32767, 48 * 20, dtype=np.int16)  # TODO: might affect the workflow, due to buffering
+        dummy_audio_inp = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16)  # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed?

-        futs = []
         for _ in range(WARMUP_RUNS):
-            audio_out_fut = self.client.put_audio_input(dummy_audio_inp)
-            futs.append(audio_out_fut)
-
-        await asyncio.gather(*futs)
+            self.client.put_audio_input((dummy_audio_inp, 48000))
+            await self.client.get_audio_output()

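For scale on the warmup change: the old dummy buffer was one 20 ms audio frame, while the new one is half a second of samples at 48 kHz, which is why the TODO worries about added delay when the size does not match the workflow's buffer size:

```python
old = 48 * 20           # 960 samples: one 20 ms frame at 48 kHz (48 samples per ms)
new = int(48000 * 0.5)  # 24000 samples: 0.5 s of audio at 48 kHz
```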
     async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
-        if isinstance(prompts, dict):
-            await self.client.set_prompts([prompts])
-        else:
+        if isinstance(prompts, list):
             await self.client.set_prompts(prompts)
+        else:
+            await self.client.set_prompts([prompts])

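The inverted `isinstance` check is equivalent for the two annotated types: lists pass through, and a single prompt dict gets wrapped. A minimal usage sketch; `single_prompt` is a hypothetical ComfyUI-style prompt graph, not something defined in this diff:

```python
# Inside an async context, both forms are accepted:
single_prompt = {"1": {"class_type": "LoadImage", "inputs": {"image": "example.png"}}}

await pipeline.set_prompts(single_prompt)    # dict: wrapped as [single_prompt]
await pipeline.set_prompts([single_prompt])  # list: forwarded unchanged
```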
     async def put_video_frame(self, frame: av.VideoFrame):
         inp_tensor = self.video_preprocess(frame)
-        out_future = self.client.put_video_input(inp_tensor)
-        await self.video_futures.put((out_future, frame.pts, frame.time_base))
+        self.client.put_video_input(inp_tensor)
+        await self.video_incoming_frames.put((frame.pts, frame.time_base))

     async def put_audio_frame(self, frame: av.AudioFrame):
-        inp_tensor = self.audio_preprocess(frame)
-        out_future = self.client.put_audio_input(inp_tensor)
-        await self.audio_futures.put((out_future, frame.pts, frame.time_base, frame.sample_rate))
+        inp_np = self.audio_preprocess(frame)
+        self.client.put_audio_input((inp_np, frame.sample_rate))
+        await self.audio_incoming_frames.put((frame.pts, frame.time_base, frame.samples, frame.sample_rate))

-    def video_preprocess(self, frame: av.VideoFrame) -> torch.Tensor:
+    def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]:
         frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
         return torch.from_numpy(frame_np).unsqueeze(0)

-    def audio_preprocess(self, frame: av.AudioFrame) -> torch.Tensor:
+    def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]:
         return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16)
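`audio_preprocess` assumes interleaved stereo input: `ravel()` flattens the frame to `[L0, R0, L1, R1, ...]`, `reshape(-1, 2)` pairs each left/right sample, and `mean(axis=1)` averages the pair into one mono sample. A standalone sketch of that downmix:

```python
import numpy as np

stereo = np.array([100, 200, -100, 300], dtype=np.int16)  # [L0, R0, L1, R1]
mono = stereo.reshape(-1, 2).mean(axis=1).astype(np.int16)
print(mono)  # [150 100]: each output sample is the average of an L/R pair
```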

-    def video_postprocess(self, output: torch.Tensor) -> av.VideoFrame:
+    def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame:
         return av.VideoFrame.from_ndarray(
             (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy()
         )

-    def audio_postprocess(self, output: torch.Tensor) -> av.AudioFrame:
-        return av.AudioFrame.from_ndarray(output.reshape(1, -1), layout="mono")
+    def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame:
+        return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1))

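With the explicit `layout="mono"` dropped, `av.AudioFrame.from_ndarray` falls back to its defaults (packed `s16`, stereo layout in current PyAV), so the mono buffer has to be expanded to interleaved stereo first; that is what the `np.repeat` does. The shape math in isolation:

```python
import numpy as np

mono = np.array([150, 100], dtype=np.int16)
stereo = np.repeat(mono, 2)     # [150, 150, 100, 100]: each sample duplicated into an L/R pair
packed = stereo.reshape(1, -1)  # shape (1, 4): one packed plane, as from_ndarray expects for s16
```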
     async def get_processed_video_frame(self):
-        out_fut, pts, time_base = await self.video_futures.get()
-        frame = self.video_postprocess(await out_fut)
-        frame.pts = pts
-        frame.time_base = time_base
-        return frame
+        # TODO: make it generic to support purely generative video cases
+        pts, time_base = await self.video_incoming_frames.get()
+        out_tensor = await self.client.get_video_output()
+
+        processed_frame = self.video_postprocess(out_tensor)
+        processed_frame.pts = pts
+        processed_frame.time_base = time_base
+
+        return processed_frame

     async def get_processed_audio_frame(self):
-        out_fut, pts, time_base, sample_rate = await self.audio_futures.get()
-        frame = self.audio_postprocess(await out_fut)
-        frame.pts = pts
-        frame.time_base = time_base
-        frame.sample_rate = sample_rate
-        return frame
+        # TODO: make it generic to support purely generative audio cases
+        pts, time_base, samples, sample_rate = await self.audio_incoming_frames.get()
+        if samples > len(self.processed_audio_buffer):
+            out_tensor = await self.client.get_audio_output()
+            self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor])
+        out_data = self.processed_audio_buffer[:samples]
+        self.processed_audio_buffer = self.processed_audio_buffer[samples:]
+
+        processed_frame = self.audio_postprocess(out_data)
+        processed_frame.pts = pts
+        processed_frame.time_base = time_base
+        processed_frame.sample_rate = sample_rate
+
+        return processed_frame

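The new `processed_audio_buffer` decouples the client's output chunk size from the outgoing frame size: each call slices exactly `samples` values off the front of the buffer and leaves the remainder for the next frame. A synchronous sketch of that FIFO slicing, with a plain list standing in for the client (all names illustrative); note it uses `while` where the diff uses `if`, which matters only if a single client output can be shorter than `samples`:

```python
import numpy as np

buffer = np.array([], dtype=np.int16)
client_outputs = [np.zeros(480, dtype=np.int16) for _ in range(4)]  # fake processed chunks

def next_frame(samples: int) -> np.ndarray:
    global buffer
    while samples > len(buffer) and client_outputs:  # refill until a full frame is available
        buffer = np.concatenate([buffer, client_outputs.pop(0)])
    out, buffer = buffer[:samples], buffer[samples:]
    return out

print([len(next_frame(320)) for _ in range(3)])  # [320, 320, 320]
```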
     async def get_nodes_info(self) -> Dict[str, Any]:
         """Get information about all nodes in the current prompt including metadata."""
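Taken together, the new flow is strictly put-then-get: producers enqueue preprocessed inputs plus timing metadata, and consumers pair the client's next output with that metadata. A hedged end-to-end sketch, assuming the `ComfyStreamClient` kwargs and prompt format are unchanged by this diff (`my_prompt` and the file name are placeholders):

```python
import asyncio
import av

async def main():
    my_prompt = {}  # placeholder: a ComfyUI-style prompt graph would go here
    pipeline = Pipeline()  # kwargs are forwarded to ComfyStreamClient
    await pipeline.set_prompts(my_prompt)
    await pipeline.warm_video()

    container = av.open("input.mp4")  # placeholder input
    for frame in container.decode(video=0):
        await pipeline.put_video_frame(frame)
        processed = await pipeline.get_processed_video_frame()
        # processed carries the original pts/time_base for downstream muxing

asyncio.run(main())
```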