Skip to content

Commit 49deb2f

Browse files
committed
fix: server
1 parent 2a3d086 commit 49deb2f

File tree

5 files changed

+82
-58
lines changed

5 files changed

+82
-58
lines changed

nodes/audio_utils/save_audio_tensor.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@ def INPUT_TYPES(s):
1212
return {
1313
"required": {
1414
"audio": ("AUDIO",),
15+
"text": ("TEXT",)
1516
}
1617
}
1718

1819
@classmethod
1920
def IS_CHANGED(s):
2021
return float("nan")
2122

22-
def execute(self, audio):
23+
def execute(self, audio, text):
2324
fut = tensor_cache.audio_outputs.pop()
24-
fut.set_result(audio)
25-
return audio
25+
fut.set_result((audio, text))
26+
return (audio, text)

nodes/whisper_utils/apply_whisper.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
# TODO: move it to a separate repo
2-
1+
import librosa
2+
import numpy as np
33
from .whisper_online import FasterWhisperASR, VACOnlineASRProcessor
44

55
class ApplyWhisper:
@@ -12,7 +12,7 @@ def INPUT_TYPES(s):
1212
}
1313

1414
CATEGORY = "whisper_utils"
15-
RETURN_TYPES = ("RESULT",)
15+
RETURN_TYPES = ("TEXT",)
1616
FUNCTION = "apply_whisper"
1717

1818
def __init__(self):
@@ -34,6 +34,7 @@ def __init__(self):
3434
)
3535

3636
def apply_whisper(self, audio):
37+
audio = librosa.resample(audio.astype(np.float32), 48000, 16000)
3738
self.online.insert_audio_chunk(audio)
38-
result = self.online.process_iter()
39-
return (result,)
39+
text = self.online.process_iter()
40+
return (text,)

server/app.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,51 @@
2626

2727
class VideoStreamTrack(MediaStreamTrack):
2828
kind = "video"
29-
3029
def __init__(self, track: MediaStreamTrack, pipeline):
3130
super().__init__()
3231
self.track = track
3332
self.pipeline = pipeline
33+
self.processed_frames = asyncio.Queue()
34+
asyncio.create_task(self.collect_frames())
35+
36+
async def collect_frames(self):
37+
while True:
38+
frame = await self.track.recv()
39+
processed = await self.pipeline(frame)
40+
await self.processed_frames.put(processed)
3441

3542
async def recv(self):
36-
frame = await self.track.recv()
37-
return await self.pipeline(frame)
43+
return await self.processed_frames.get()
3844

45+
3946
class AudioStreamTrack(MediaStreamTrack):
4047
kind = "audio"
41-
4248
def __init__(self, track: MediaStreamTrack, pipeline):
4349
super().__init__()
4450
self.track = track
4551
self.pipeline = pipeline
52+
self.incoming_frames = asyncio.Queue()
53+
self.processed_frames = asyncio.Queue()
54+
asyncio.create_task(self.collect_frames())
55+
asyncio.create_task(self.process_frames())
56+
self.started = False
57+
58+
async def collect_frames(self):
59+
while True:
60+
frame = await self.track.recv()
61+
await self.incoming_frames.put(frame)
62+
63+
async def process_frames(self):
64+
while True:
65+
frames = []
66+
while len(frames) < 25:
67+
frames.append(await self.incoming_frames.get())
68+
processed_frames = await self.pipeline(frames)
69+
for processed_frame in processed_frames:
70+
await self.processed_frames.put(processed_frame)
4671

4772
async def recv(self):
48-
frame = await self.track.recv()
49-
return await self.pipeline(frame)
73+
return await self.processed_frames.get()
5074

5175

5276
def force_codec(pc, sender, forced_codec):
@@ -96,9 +120,6 @@ async def offer(request):
96120

97121
params = await request.json()
98122

99-
print("VIDEO PROMPT", params["video_prompt"])
100-
print("AUDIO PROMPT", params["audio_prompt"])
101-
102123
video_pipeline.set_prompt(params["video_prompt"])
103124
await video_pipeline.warm()
104125
audio_pipeline.set_prompt(params["audio_prompt"])

server/pipeline.py

+36-36
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import torch
22
import av
33
import numpy as np
4+
import fractions
45

5-
from typing import Any, Dict, Optional, Union
6+
from av import AudioFrame
7+
from typing import Any, Dict, Optional, Union, List
68
from comfystream.client import ComfyStreamClient
79

810
WARMUP_RUNS = 5
9-
10-
# TODO: remove, was just for temp UI
1111
import logging
12-
1312
display_logger = logging.getLogger('display_logger')
1413
display_logger.setLevel(logging.INFO)
1514
handler = logging.FileHandler('display_logs.txt')
@@ -57,46 +56,47 @@ async def __call__(self, frame: av.VideoFrame) -> av.VideoFrame:
5756
class AudioPipeline:
5857
def __init__(self, **kwargs):
5958
self.client = ComfyStreamClient(**kwargs, type="audio")
60-
self.resampler = av.audio.resampler.AudioResampler(format='s16', layout='mono', rate=16000)
59+
self.resampler = av.audio.resampler.AudioResampler(format='s16', layout='mono', rate=48000)
60+
self.sample_rate = 48000
61+
self.frame_size = int(self.sample_rate * 0.02)
62+
self.time_base = fractions.Fraction(1, self.sample_rate)
63+
self.curr_pts = 0
6164

6265
async def warm(self):
63-
dummy_audio = torch.randn(16000)
66+
dummy_audio = np.random.randint(-32768, 32767, 48000 * 1, dtype=np.int16)
6467
for _ in range(WARMUP_RUNS):
6568
await self.predict(dummy_audio)
6669

6770
def set_prompt(self, prompt: Dict[Any, Any]):
6871
self.client.set_prompt(prompt)
6972

70-
def preprocess(self, frame: av.AudioFrame) -> torch.Tensor:
71-
resampled_frame = self.resampler.resample(frame)[0]
72-
samples = resampled_frame.to_ndarray()
73-
samples = samples.astype(np.float32) / 32768.0
74-
return samples
75-
76-
def postprocess(self, output: torch.Tensor) -> Optional[Union[av.AudioFrame, str]]:
77-
out_np = output.cpu().numpy()
78-
out_np = np.clip(out_np * 32768.0, -32768, 32767).astype(np.int16)
79-
audio_frame = av.AudioFrame.from_ndarray(out_np, format="s16", layout="stereo")
80-
return audio_frame
73+
def preprocess(self, frames: List[av.AudioFrame]) -> torch.Tensor:
74+
audio_arrays = []
75+
for frame in frames:
76+
audio_arrays.append(self.resampler.resample(frame)[0].to_ndarray())
77+
return np.concatenate(audio_arrays, axis=1).flatten()
78+
79+
def postprocess(self, out_np) -> Optional[Union[av.AudioFrame, str]]:
80+
frames = []
81+
for idx in range(0, len(out_np), self.frame_size):
82+
frame_samples = out_np[idx:idx + self.frame_size]
83+
frame_samples = frame_samples.reshape(1, -1)
84+
frame = AudioFrame.from_ndarray(frame_samples, layout="mono")
85+
frame.sample_rate = self.sample_rate
86+
frame.pts = self.curr_pts
87+
frame.time_base = self.time_base
88+
self.curr_pts += 960
89+
90+
frames.append(frame)
91+
return frames
8192

82-
async def predict(self, frame: torch.Tensor) -> torch.Tensor:
93+
async def predict(self, frame) -> torch.Tensor:
8394
return await self.client.queue_prompt(frame)
8495

85-
async def __call__(self, frame: av.AudioFrame):
86-
# TODO: clean this up later for audio-to-text and audio-to-audio
87-
pre_output = self.preprocess(frame)
88-
pred_output = await self.predict(pre_output)
89-
if type(pred_output) == tuple:
90-
if pred_output[0] is not None:
91-
await self.log_text(f"{pred_output[0]} {pred_output[1]} {pred_output[2]}")
92-
return frame
93-
else:
94-
post_output = self.postprocess(pred_output)
95-
post_output.sample_rate = frame.sample_rate
96-
post_output.pts = frame.pts
97-
post_output.time_base = frame.time_base
98-
return post_output
99-
100-
async def log_text(self, text: str):
101-
# TODO: remove, was just for temp UI
102-
display_logger.info(text)
96+
async def __call__(self, frames: List[av.AudioFrame]):
97+
pre_audio = self.preprocess(frames)
98+
pred_audio, text = await self.predict(pre_audio)
99+
if text[-1] != "":
100+
display_logger.info(f"{text[0]} {text[1]} {text[2]}")
101+
pred_audios = self.postprocess(pred_audio)
102+
return pred_audios

ui/src/components/webcam.tsx

+6-5
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,12 @@ export function Webcam({ onStreamReady, deviceId, frameRate, selectedAudioDevice
149149
},
150150
audio: {
151151
...(selectedAudioDeviceId ? { deviceId: { exact: selectedAudioDeviceId } } : {}),
152-
sampleRate: { ideal: 16000 },
153-
sampleSize: { ideal: 16 },
154-
channelCount: { ideal: 1 },
155-
echoCancellation: true,
156-
noiseSuppression: true,
152+
sampleRate: 48000,
153+
channelCount: 2,
154+
sampleSize: 16,
155+
echoCancellation: false,
156+
noiseSuppression: false,
157+
autoGainControl: false,
157158
},
158159
});
159160
return newStream;

0 commit comments

Comments
 (0)