Skip to content

Commit 0e4d8af

Browse files
committed
fix: audio frame skipping
1 parent 44df170 commit 0e4d8af

File tree

4 files changed

+26
-24
lines changed

4 files changed

+26
-24
lines changed

nodes/audio_utils/load_audio_tensor.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,21 @@ def IS_CHANGED():
2626

2727
def execute(self, buffer_size):
2828
if self.sample_rate is None or self.buffer_samples is None:
29-
first_audio, sr = tensor_cache.audio_inputs.get(block=True)
30-
self.sample_rate = sr
31-
self.buffer_samples = int(sr * buffer_size / 1000)
32-
self.leftover = first_audio
29+
frame = tensor_cache.audio_inputs.get(block=True)
30+
self.sample_rate = frame.sample_rate
31+
self.buffer_samples = int(self.sample_rate * buffer_size / 1000)
32+
self.leftover = frame.side_data.input
3333

3434
if self.leftover.shape[0] < self.buffer_samples:
3535
chunks = [self.leftover] if self.leftover.size > 0 else []
3636
total_samples = self.leftover.shape[0]
3737

3838
while total_samples < self.buffer_samples:
39-
audio, sr = tensor_cache.audio_inputs.get(block=True)
40-
if sr != self.sample_rate:
39+
frame = tensor_cache.audio_inputs.get(block=True)
40+
if frame.sample_rate != self.sample_rate:
4141
raise ValueError("Sample rate mismatch")
42-
chunks.append(audio)
43-
total_samples += audio.shape[0]
42+
chunks.append(frame.side_data.input)
43+
total_samples += frame.side_data.input.shape[0]
4444

4545
merged_audio = np.concatenate(chunks, dtype=np.int16)
4646
buffered_audio = merged_audio[:self.buffer_samples]

nodes/tensor_utils/load_tensor.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import time
21
from comfystream import tensor_cache
32

43

server/pipeline.py

+16-13
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ async def warm_video(self):
2727
await self.client.get_video_output()
2828

2929
async def warm_audio(self):
30-
dummy_audio_inp = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed?
30+
dummy_frame = av.AudioFrame()
31+
dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed?
32+
dummy_frame.sample_rate = 48000
3133

3234
for _ in range(WARMUP_RUNS):
33-
self.client.put_audio_input((dummy_audio_inp, 48000))
35+
self.client.put_audio_input(dummy_frame)
3436
await self.client.get_audio_output()
3537

3638
async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
@@ -52,9 +54,10 @@ async def put_video_frame(self, frame: av.VideoFrame):
5254
await self.video_incoming_frames.put(frame)
5355

5456
async def put_audio_frame(self, frame: av.AudioFrame):
55-
inp_np = self.audio_preprocess(frame)
56-
self.client.put_audio_input((inp_np, frame.sample_rate))
57-
await self.audio_incoming_frames.put((frame.pts, frame.time_base, frame.samples, frame.sample_rate))
57+
frame.side_data.input = self.audio_preprocess(frame)
58+
frame.side_data.skipped = True
59+
self.client.put_audio_input(frame)
60+
await self.audio_incoming_frames.put(frame)
5861

5962
def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]:
6063
frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
@@ -85,18 +88,18 @@ async def get_processed_video_frame(self):
8588
return processed_frame
8689

8790
async def get_processed_audio_frame(self):
88-
# TODO: make it generic to support purely generative audio cases
89-
pts, time_base, samples, sample_rate = await self.audio_incoming_frames.get()
90-
if samples > len(self.processed_audio_buffer):
91+
# TODO: make it generic to support purely generative audio cases and also add frame skipping
92+
frame = await self.audio_incoming_frames.get()
93+
if frame.samples > len(self.processed_audio_buffer):
9194
out_tensor = await self.client.get_audio_output()
9295
self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor])
93-
out_data = self.processed_audio_buffer[:samples]
94-
self.processed_audio_buffer = self.processed_audio_buffer[samples:]
96+
out_data = self.processed_audio_buffer[:frame.samples]
97+
self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:]
9598

9699
processed_frame = self.audio_postprocess(out_data)
97-
processed_frame.pts = pts
98-
processed_frame.time_base = time_base
99-
processed_frame.sample_rate = sample_rate
100+
processed_frame.pts = frame.pts
101+
processed_frame.time_base = frame.time_base
102+
processed_frame.sample_rate = frame.sample_rate
100103

101104
return processed_frame
102105

src/comfystream/client.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def put_video_input(self, frame):
4848
tensor_cache.image_inputs.get(block=True)
4949
tensor_cache.image_inputs.put(frame)
5050

51-
def put_audio_input(self, inp_tensor):
52-
tensor_cache.audio_inputs.put(inp_tensor)
51+
def put_audio_input(self, frame):
52+
tensor_cache.audio_inputs.put(frame)
5353

5454
async def get_video_output(self):
5555
return await tensor_cache.image_outputs.get()

0 commit comments

Comments (0)