
Commit 1c1959e

fix: one-to-one mapping
1 parent 5e9e755 commit 1c1959e

7 files changed, +102 -84 lines changed

nodes/audio_utils/load_audio_tensor.py

+33 -18

@@ -4,34 +4,49 @@
 
 class LoadAudioTensor:
     CATEGORY = "audio_utils"
-    RETURN_TYPES = ("AUDIO",)
+    RETURN_TYPES = ("WAVEFORM", "INT")
     FUNCTION = "execute"
 
     def __init__(self):
-        self.audio_buffer = np.array([], dtype=np.int16)
+        self.audio_buffer = np.empty(0, dtype=np.int16)
         self.buffer_samples = None
+        self.sample_rate = None
 
     @classmethod
     def INPUT_TYPES(s):
         return {
             "required": {
                 "buffer_size": ("FLOAT", {"default": 500.0}),
-                "sample_rate": ("INT", {"default": 48000})
             }
         }
 
     @classmethod
     def IS_CHANGED():
         return float("nan")
 
-    def execute(self, buffer_size, sample_rate):
-        if not self.buffer_samples:
-            self.buffer_samples = int(buffer_size * sample_rate / 1000)
-
-        while self.audio_buffer.size < self.buffer_samples:
-            audio = tensor_cache.audio_inputs.get()
-            self.audio_buffer = np.concatenate((self.audio_buffer, audio))
-
-        buffered_audio = self.audio_buffer
-        self.audio_buffer = np.array([], dtype=np.int16)
-        return (buffered_audio,)
+    def execute(self, buffer_size):
+        if self.sample_rate is None or self.buffer_samples is None:
+            first_audio, sr = tensor_cache.audio_inputs.get(block=True)
+            self.sample_rate = sr
+            self.buffer_samples = int(sr * buffer_size / 1000)
+            self.leftover = first_audio
+
+        if self.leftover.shape[0] < self.buffer_samples:
+            chunks = [self.leftover] if self.leftover.size > 0 else []
+            total_samples = self.leftover.shape[0]
+
+            while total_samples < self.buffer_samples:
+                audio, sr = tensor_cache.audio_inputs.get(block=True)
+                if sr != self.sample_rate:
+                    raise ValueError("Sample rate mismatch")
+                chunks.append(audio)
+                total_samples += audio.shape[0]
+
+            merged_audio = np.concatenate(chunks, dtype=np.int16)
+            buffered_audio = merged_audio[:self.buffer_samples]
+            self.leftover = merged_audio[self.buffer_samples:]
+        else:
+            buffered_audio = self.leftover[:self.buffer_samples]
+            self.leftover = self.leftover[self.buffer_samples:]
+
+        return buffered_audio, self.sample_rate
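Note: the reworked node derives the sample rate from the first incoming chunk and carries surplus samples over to the next execute call instead of discarding them. A minimal standalone sketch of that buffering strategy, assuming chunks arrive as (samples, sample_rate) pairs as in this diff (fill_buffer and get_chunk are illustrative names, not part of the codebase):

import numpy as np

def fill_buffer(leftover, buffer_samples, sample_rate, get_chunk):
    # Collect chunks until at least buffer_samples are available,
    # then return (buffered_audio, new_leftover).
    chunks = [leftover] if leftover.size > 0 else []
    total = leftover.shape[0]
    while total < buffer_samples:
        audio, sr = get_chunk()  # blocking read of the next (samples, rate) pair
        if sr != sample_rate:
            raise ValueError("Sample rate mismatch")
        chunks.append(audio)
        total += audio.shape[0]
    merged = np.concatenate(chunks, dtype=np.int16)
    # exactly buffer_samples go out; the surplus is carried to the next call
    return merged[:buffer_samples], merged[buffer_samples:]

if __name__ == "__main__":
    chunk = lambda: (np.zeros(960, dtype=np.int16), 48000)  # 20 ms of silence at 48 kHz
    buffered, leftover = fill_buffer(np.empty(0, dtype=np.int16), 24000, 48000, chunk)
    print(buffered.shape, leftover.shape)  # (24000,) (0,)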

nodes/audio_utils/save_audio_tensor.py

+3 -13

@@ -6,30 +6,20 @@ class SaveAudioTensor:
     FUNCTION = "execute"
     OUTPUT_NODE = True
 
-    def __init__(self):
-        self.frame_samples = None
-
     @classmethod
     def INPUT_TYPES(s):
         return {
             "required": {
-                "audio": ("AUDIO",),
-                "frame_size": ("FLOAT", {"default": 20.0}),
-                "sample_rate": ("INT", {"default": 48000})
+                "audio": ("WAVEFORM",)
             }
         }
 
     @classmethod
     def IS_CHANGED(s):
         return float("nan")
 
-    def execute(self, audio, frame_size, sample_rate):
-        if self.frame_samples is None:
-            self.frame_samples = int(frame_size * sample_rate / 1000)
-
-        for idx in range(0, len(audio), self.frame_samples):
-            frame = audio[idx:idx + self.frame_samples]
-            fut = tensor_cache.audio_outputs.get()
-            fut.set_result(frame)
+    def execute(self, audio):
+        tensor_cache.audio_outputs.put_nowait(audio)
         return (audio,)

nodes/tensor_utils/save_tensor.py

+1 -2

@@ -22,6 +22,5 @@ def IS_CHANGED(s):
         return float("nan")
 
     def execute(self, images: torch.Tensor):
-        fut = tensor_cache.image_outputs.get()
-        fut.set_result(images)
+        tensor_cache.image_outputs.put_nowait(images)
         return images
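Note: with the futures gone, both save nodes simply enqueue their result, so every execute call produces exactly one queued output and outputs can be matched to inputs by order alone. A rough sketch of that pairing, using a bare asyncio.Queue in place of the tensor_cache queues:

import asyncio

async def main():
    outputs: asyncio.Queue = asyncio.Queue()

    def execute(result):
        # what SaveTensor / SaveAudioTensor now do: push the result directly
        outputs.put_nowait(result)

    execute("frame-1")
    execute("frame-2")
    # consumer side: results come back in submission order, one per execute call
    assert await outputs.get() == "frame-1"
    assert await outputs.get() == "frame-2"

asyncio.run(main())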

server/app.py

+4 -5

@@ -107,7 +107,6 @@ async def offer(request):
     params = await request.json()
 
     await pipeline.set_prompts(params["prompts"])
-    # await pipeline.warm()
 
     offer_params = params["offer"]
     offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"])
@@ -152,13 +151,13 @@ async def on_message(message):
                     "nodes": nodes_info
                 }
                 channel.send(json.dumps(response))
-            elif params.get("type") == "update_prompt":
-                if "prompt" not in params:
+            elif params.get("type") == "update_prompts":
+                if "prompts" not in params:
                     logger.warning("[Control] Missing prompt in update_prompt message")
                     return
-                pipeline.set_prompt(params["prompt"])
+                pipeline.set_prompts(params["prompts"])
                 response = {
-                    "type": "prompt_updated",
+                    "type": "prompts_updated",
                     "success": True
                 }
                 channel.send(json.dumps(response))
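Note: given the renamed handler above, a client now sends the plural form over the control data channel. A minimal sketch of the message shape (the prompt payload is illustrative, not a real workflow):

import json

update_msg = json.dumps({
    "type": "update_prompts",  # was "update_prompt"
    "prompts": [{"1": {"class_type": "LoadImage", "inputs": {}}}],  # illustrative prompt list
})
# expected reply on the same channel: {"type": "prompts_updated", "success": True}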

server/pipeline.py

+46 -35

@@ -6,78 +6,89 @@
 from typing import Any, Dict, Union, List
 from comfystream.client import ComfyStreamClient
 
-WARMUP_RUNS = 10
+WARMUP_RUNS = 5
 
 
 class Pipeline:
     def __init__(self, **kwargs):
         self.client = ComfyStreamClient(**kwargs, max_workers=5)  # TODO: hardcoded max workers, should it be configurable?
 
-        self.video_futures = asyncio.Queue()
-        self.audio_futures = asyncio.Queue()
+        self.video_incoming_frames = asyncio.Queue()
+        self.audio_incoming_frames = asyncio.Queue()
+
+        self.processed_audio_buffer = np.array([], dtype=np.int16)
 
     async def warm_video(self):
         dummy_video_inp = torch.randn(1, 512, 512, 3)
 
         for _ in range(WARMUP_RUNS):
-            image_out_fut = self.client.put_video_input(dummy_video_inp)
-            await image_out_fut
+            self.client.put_video_input(dummy_video_inp)
+            await self.client.get_video_output()
 
     async def warm_audio(self):
-        dummy_audio_inp = np.random.randint(-32768, 32767, 48 * 20, dtype=np.int16)  # TODO: might affect the workflow, due to buffering
+        dummy_audio_inp = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16)  # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed?
 
-        futs = []
         for _ in range(WARMUP_RUNS):
-            audio_out_fut = self.client.put_audio_input(dummy_audio_inp)
-            futs.append(audio_out_fut)
-
-        await asyncio.gather(*futs)
+            self.client.put_audio_input((dummy_audio_inp, 48000))
+            await self.client.get_audio_output()
 
     async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
-        if isinstance(prompts, dict):
-            await self.client.set_prompts([prompts])
-        else:
+        if isinstance(prompts, list):
             await self.client.set_prompts(prompts)
+        else:
+            await self.client.set_prompts([prompts])
 
     async def put_video_frame(self, frame: av.VideoFrame):
         inp_tensor = self.video_preprocess(frame)
-        out_future = self.client.put_video_input(inp_tensor)
-        await self.video_futures.put((out_future, frame.pts, frame.time_base))
+        self.client.put_video_input(inp_tensor)
+        await self.video_incoming_frames.put((frame.pts, frame.time_base))
 
     async def put_audio_frame(self, frame: av.AudioFrame):
-        inp_tensor = self.audio_preprocess(frame)
-        out_future = self.client.put_audio_input(inp_tensor)
-        await self.audio_futures.put((out_future, frame.pts, frame.time_base, frame.sample_rate))
+        inp_np = self.audio_preprocess(frame)
+        self.client.put_audio_input((inp_np, frame.sample_rate))
+        await self.audio_incoming_frames.put((frame.pts, frame.time_base, frame.samples, frame.sample_rate))
 
-    def video_preprocess(self, frame: av.VideoFrame) -> torch.Tensor:
+    def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]:
         frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
         return torch.from_numpy(frame_np).unsqueeze(0)
 
-    def audio_preprocess(self, frame: av.AudioFrame) -> torch.Tensor:
+    def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]:
         return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16)
 
-    def video_postprocess(self, output: torch.Tensor) -> av.VideoFrame:
+    def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame:
         return av.VideoFrame.from_ndarray(
             (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy()
         )
 
-    def audio_postprocess(self, output: torch.Tensor) -> av.AudioFrame:
-        return av.AudioFrame.from_ndarray(output.reshape(1, -1), layout="mono")
+    def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame:
+        return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1))
 
     async def get_processed_video_frame(self):
-        out_fut, pts, time_base = await self.video_futures.get()
-        frame = self.video_postprocess(await out_fut)
-        frame.pts = pts
-        frame.time_base = time_base
-        return frame
+        # TODO: make it generic to support purely generative video cases
+        pts, time_base = await self.video_incoming_frames.get()
+        out_tensor = await self.client.get_video_output()
+
+        processed_frame = self.video_postprocess(out_tensor)
+        processed_frame.pts = pts
+        processed_frame.time_base = time_base
+
+        return processed_frame
 
     async def get_processed_audio_frame(self):
-        out_fut, pts, time_base, sample_rate = await self.audio_futures.get()
-        frame = self.audio_postprocess(await out_fut)
-        frame.pts = pts
-        frame.time_base = time_base
-        frame.sample_rate = sample_rate
-        return frame
+        # TODO: make it generic to support purely generative audio cases
+        pts, time_base, samples, sample_rate = await self.audio_incoming_frames.get()
+        if samples > len(self.processed_audio_buffer):
+            out_tensor = await self.client.get_audio_output()
+            self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor])
+        out_data = self.processed_audio_buffer[:samples]
+        self.processed_audio_buffer = self.processed_audio_buffer[samples:]
+
+        processed_frame = self.audio_postprocess(out_data)
+        processed_frame.pts = pts
+        processed_frame.time_base = time_base
+        processed_frame.sample_rate = sample_rate
+
+        return processed_frame
 
     async def get_nodes_info(self) -> Dict[str, Any]:
         """Get information about all nodes in the current prompt including metadata."""

src/comfystream/client.py

+6 -6

@@ -41,16 +41,16 @@ async def run_prompt(self, prompt: PromptDictInput):
             raise
 
     def put_video_input(self, inp_tensor):
-        out_future = asyncio.Future()
-        tensor_cache.image_outputs.put(out_future)
         tensor_cache.image_inputs.put(inp_tensor)
-        return out_future
 
     def put_audio_input(self, inp_tensor):
-        out_future = asyncio.Future()
-        tensor_cache.audio_outputs.put(out_future)
         tensor_cache.audio_inputs.put(inp_tensor)
-        return out_future
+
+    async def get_video_output(self):
+        return await tensor_cache.image_outputs.get()
+
+    async def get_audio_output(self):
+        return await tensor_cache.audio_outputs.get()
 
     async def get_available_nodes(self):
         """Get metadata and available nodes info in a single pass"""

src/comfystream/tensor_cache.py

+9 -5

@@ -1,9 +1,13 @@
-import asyncio
 import torch
+import numpy as np
+
 from queue import Queue
+from asyncio import Queue as AsyncQueue
+
+from typing import Union
 
-image_inputs: Queue[torch.Tensor] = Queue()
-image_outputs: Queue[asyncio.Future] = Queue()
+image_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue()
+image_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue()
 
-audio_inputs: Queue[torch.Tensor] = Queue()
-audio_outputs: Queue[asyncio.Future] = Queue()
+audio_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue()
+audio_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue()
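Note: inputs stay on a thread-safe queue.Queue because the ComfyUI nodes read them with a blocking get() off the event loop, while outputs move to an asyncio queue that the pipeline awaits. A small self-contained sketch of that split, with node_execute standing in for a ComfyUI node (run inline here for simplicity):

import asyncio
import numpy as np
from queue import Queue

async def main():
    inputs: Queue = Queue()                   # thread-safe; nodes read it with a blocking get()
    outputs: asyncio.Queue = asyncio.Queue()  # async; the pipeline awaits get()

    def node_execute():
        # roughly what LoadAudioTensor / SaveAudioTensor now do
        audio = inputs.get(block=True)
        outputs.put_nowait(audio)

    inputs.put(np.zeros(960, dtype=np.int16))  # producer side (put_audio_frame)
    node_execute()                             # the node consumes one input, emits one output
    print((await outputs.get()).shape)         # consumer side (get_processed_audio_frame)

asyncio.run(main())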
