Commit 29f6bb7

feat: audio pipeline
1 parent 960aebe commit 29f6bb7

19 files changed: +179 −123 lines changed

nodes/audio_utils/__init__.py  (+2 −1)

@@ -1,7 +1,8 @@
 from .apply_whisper import ApplyWhisper
 from .load_audio_tensor import LoadAudioTensor
 from .save_asr_response import SaveASRResponse
+from .save_audio_tensor import SaveAudioTensor

-NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveASRResponse": SaveASRResponse, "ApplyWhisper": ApplyWhisper}
+NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveASRResponse": SaveASRResponse, "ApplyWhisper": ApplyWhisper, "SaveAudioTensor": SaveAudioTensor}

 __all__ = ["NODE_CLASS_MAPPINGS"]

nodes/audio_utils/apply_whisper.py  (+3 −2)

@@ -21,17 +21,18 @@ def __init__(self):
         # TO:DO to get them as params
         self.sample_rate = 16000
         self.min_duration = 1.0
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"

     def apply_whisper(self, audio, model):
         if self.model is None:
-            self.model = whisper.load_model(model).cuda()
+            self.model = whisper.load_model(model).to(self.device)

         self.audio_buffer.append(audio)
         total_duration = sum(chunk.shape[0] / self.sample_rate for chunk in self.audio_buffer)
         if total_duration < self.min_duration:
             return {"text": "", "segments_alignment": [], "words_alignment": []}

-        concatenated_audio = torch.cat(self.audio_buffer, dim=0).cuda()
+        concatenated_audio = torch.cat(self.audio_buffer, dim=0).to(self.device)
         self.audio_buffer = []
         result = self.model.transcribe(concatenated_audio.float(), fp16=True, word_timestamps=True)
         segments = result["segments"]

nodes/audio_utils/load_audio_tensor.py  (+1 −1)

@@ -14,5 +14,5 @@ def IS_CHANGED():
         return float("nan")

     def execute(self):
-        audio = tensor_cache.inputs.pop()
+        audio = tensor_cache.audio_inputs.pop()
         return (audio,)

nodes/audio_utils/save_asr_response.py  (+1 −1)

@@ -19,6 +19,6 @@ def IS_CHANGED(s):
         return float("nan")

     def execute(self, data: dict):
-        fut = tensor_cache.outputs.pop()
+        fut = tensor_cache.audio_outputs.pop()
         fut.set_result(data)
         return data
nodes/audio_utils/save_audio_tensor.py  (new file, +25)

@@ -0,0 +1,25 @@
+from comfystream import tensor_cache
+
+
+class SaveAudioTensor:
+    CATEGORY = "audio_utils"
+    RETURN_TYPES = ()
+    FUNCTION = "execute"
+    OUTPUT_NODE = True
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "audio": ("AUDIO",),
+            }
+        }
+
+    @classmethod
+    def IS_CHANGED(s):
+        return float("nan")
+
+    def execute(self, audio):
+        fut = tensor_cache.audio_outputs.pop()
+        fut.set_result(audio)
+        return audio
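SaveAudioTensor is the output node that resolves the future queued by the client, mirroring how LoadAudioTensor pops the pending input. A hypothetical pass-through audio graph using only this commit's nodes could look like the following (the node ids and the ComfyUI API prompt format are assumptions for illustration, not part of the commit):

audio_prompt = {
    "1": {"class_type": "LoadAudioTensor", "inputs": {}},
    "2": {"class_type": "SaveAudioTensor", "inputs": {"audio": ["1", 0]}},
}

With a graph shaped like this, convert_prompt (see src/comfystream/utils.py below) would count LoadAudioTensor as the input node and SaveAudioTensor as the output node.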

nodes/tensor_utils/load_tensor.py  (+1 −1)

@@ -15,5 +15,5 @@ def IS_CHANGED():
         return float("nan")

     def execute(self):
-        input = tensor_cache.inputs.pop()
+        input = tensor_cache.image_inputs.pop()
         return (input,)

nodes/tensor_utils/save_tensor.py  (+1 −1)

@@ -22,6 +22,6 @@ def IS_CHANGED(s):
         return float("nan")

     def execute(self, images: torch.Tensor):
-        fut = tensor_cache.outputs.pop()
+        fut = tensor_cache.image_outputs.pop()
         fut.set_result(images)
         return images

server/app.py  (+19 −89)

@@ -3,8 +3,6 @@
 import os
 import json
 import logging
-import wave
-import numpy as np

 from twilio.rest import Client
 from aiohttp import web

@@ -17,7 +15,7 @@
 )
 from aiortc.rtcrtpsender import RTCRtpSender
 from aiortc.codecs import h264
-from pipeline import Pipeline
+from pipeline import VideoPipeline, AudioPipeline
 from utils import patch_loop_datagram

 logger = logging.getLogger(__name__)

@@ -39,93 +37,16 @@ async def recv(self):
         return await self.pipeline(frame)

 class AudioStreamTrack(MediaStreamTrack):
-    """
-    This custom audio track wraps an incoming audio MediaStreamTrack.
-    It continuously records frames in 10-second chunks and saves each chunk
-    as a separate WAV file with an incrementing index.
-    """
-
     kind = "audio"

-    def __init__(self, track: MediaStreamTrack):
+    def __init__(self, track: MediaStreamTrack, pipeline):
         super().__init__()
         self.track = track
-        self.start_time = None
-        self.frames = []
-        self._recording_duration = 10.0  # in seconds
-        self._chunk_index = 0
-        self._saving = False
-        self._lock = asyncio.Lock()
+        self.pipeline = pipeline

     async def recv(self):
         frame = await self.track.recv()
-        return frame
-
-    # async def recv(self):
-    #     frame = await self.source.recv()
-
-    #     # On the first frame, record the start time.
-    #     if self.start_time is None:
-    #         self.start_time = frame.time
-    #         logger.info(f"Audio recording started at time: {self.start_time:.3f}")
-
-    #     elapsed = frame.time - self.start_time
-    #     self.frames.append(frame)
-
-    #     logger.info(f"Received audio frame at time: {frame.time:.3f}, total frames: {len(self.frames)}")
-
-    #     # Check if we've hit 10 seconds and we're not currently saving.
-    #     if elapsed >= self._recording_duration and not self._saving:
-    #         logger.info(f"10 second chunk reached (elapsed: {elapsed:.3f}s). Preparing to save chunk {self._chunk_index}.")
-    #         self._saving = True
-    #         # Handle saving in a background task so we don't block the recv loop.
-    #         asyncio.create_task(self.save_audio())
-
-    #     return frame
-
-    async def save_audio(self):
-        logger.info(f"Starting to save audio chunk {self._chunk_index}...")
-        async with self._lock:
-            # Extract properties from the first frame
-            if not self.frames:
-                logger.warning("No frames to save, skipping.")
-                self._saving = False
-                return
-
-            sample_rate = self.frames[0].sample_rate
-            layout = self.frames[0].layout
-            channels = len(layout.channels)
-
-            logger.info(f"Audio chunk {self._chunk_index}: sample_rate={sample_rate}, channels={channels}, frames_count={len(self.frames)}")
-
-            # Convert all frames to ndarray and concatenate
-            data_arrays = [f.to_ndarray() for f in self.frames]
-            data = np.concatenate(data_arrays, axis=1)  # shape: (channels, total_samples)
-
-            # Interleave channels (if multiple) since WAV expects interleaved samples.
-            interleaved = data.T.flatten()
-
-            # If needed, convert float frames to int16
-            # interleaved = (interleaved * 32767).astype(np.int16)
-
-            filename = f"output_{self._chunk_index}.wav"
-            logger.info(f"Writing audio chunk {self._chunk_index} to file: {filename}")
-            with wave.open(filename, 'wb') as wf:
-                wf.setnchannels(channels)
-                wf.setsampwidth(2)  # 16-bit PCM
-                wf.setframerate(sample_rate)
-                wf.writeframes(interleaved.tobytes())
-
-            logger.info(f"Audio chunk {self._chunk_index} saved successfully as {filename}")
-
-            # Increment the chunk index for the next segment
-            self._chunk_index += 1
-
-            # Reset for next recording chunk
-            self.frames.clear()
-            self.start_time = None
-            self._saving = False
-            logger.info(f"Ready to record next 10-second chunk. Current chunk index: {self._chunk_index}")
+        return await self.pipeline(frame)


 def force_codec(pc, sender, forced_codec):

@@ -169,13 +90,19 @@ def get_ice_servers():


 async def offer(request):
-    pipeline = request.app["pipeline"]
+    video_pipeline = request.app["video_pipeline"]
+    audio_pipeline = request.app["audio_pipeline"]
     pcs = request.app["pcs"]

     params = await request.json()

-    pipeline.set_prompt(params["prompt"])
-    await pipeline.warm()
+    print("VIDEO PROMPT", params["video_prompt"])
+    print("AUDIO PROMPT", params["audio_prompt"])
+
+    video_pipeline.set_prompt(params["video_prompt"])
+    await video_pipeline.warm()
+    audio_pipeline.set_prompt(params["audio_prompt"])
+    await audio_pipeline.warm()

     offer_params = params["offer"]
     offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"])

@@ -206,14 +133,14 @@ async def offer(request):
     def on_track(track):
         logger.info(f"Track received: {track.kind}")
         if track.kind == "video":
-            videoTrack = VideoStreamTrack(track, pipeline)
+            videoTrack = VideoStreamTrack(track, video_pipeline)
             tracks["video"] = videoTrack
             sender = pc.addTrack(videoTrack)

             codec = "video/H264"
             force_codec(pc, sender, codec)
         elif track.kind == "audio":
-            audioTrack = AudioStreamTrack(track)
+            audioTrack = AudioStreamTrack(track, audio_pipeline)
             tracks["audio"] = audioTrack
             pc.addTrack(audioTrack)


@@ -261,7 +188,10 @@ async def on_startup(app: web.Application):
     if app["media_ports"]:
         patch_loop_datagram(app["media_ports"])

-    app["pipeline"] = Pipeline(
+    app["video_pipeline"] = VideoPipeline(
+        cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
+    )
+    app["audio_pipeline"] = AudioPipeline(
         cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
     )
     app["pcs"] = set()

server/pipeline.py  (+40 −3)

@@ -2,15 +2,15 @@
 import av
 import numpy as np

-from typing import Any, Dict
+from typing import Any, Dict, Optional, Union
 from comfystream.client import ComfyStreamClient

 WARMUP_RUNS = 5


-class Pipeline:
+class VideoPipeline:
     def __init__(self, **kwargs):
-        self.client = ComfyStreamClient(**kwargs)
+        self.client = ComfyStreamClient(**kwargs, type="image")

     async def warm(self):
         frame = torch.randn(1, 512, 512, 3)

@@ -42,3 +42,40 @@ async def __call__(self, frame: av.VideoFrame) -> av.VideoFrame:
         post_output.time_base = frame.time_base

         return post_output
+
+
+class AudioPipeline:
+    def __init__(self, **kwargs):
+        self.client = ComfyStreamClient(**kwargs, type="audio")
+
+    async def warm(self):
+        dummy_audio = torch.randn(16000)
+        for _ in range(WARMUP_RUNS):
+            await self.predict(dummy_audio)
+
+    def set_prompt(self, prompt: Dict[Any, Any]):
+        self.client.set_prompt(prompt)
+
+    def preprocess(self, frame: av.AudioFrame) -> torch.Tensor:
+        self.sample_rate = frame.sample_rate
+        samples = frame.to_ndarray(format="s16", layout="stereo")
+        samples = samples.astype(np.float32) / 32768.0
+        return torch.from_numpy(samples)
+
+    def postprocess(self, output: torch.Tensor) -> Optional[Union[av.AudioFrame, str]]:
+        out_np = output.cpu().numpy()
+        out_np = np.clip(out_np * 32768.0, -32768, 32767).astype(np.int16)
+        audio_frame = av.AudioFrame.from_ndarray(out_np, format="s16", layout="stereo")
+        return audio_frame
+
+    async def predict(self, frame: torch.Tensor) -> torch.Tensor:
+        return await self.client.queue_prompt(frame)
+
+    async def __call__(self, frame: av.AudioFrame):
+        pre_output = self.preprocess(frame)
+        pred_output = await self.predict(pre_output)
+        post_output = self.postprocess(pred_output)
+        post_output.sample_rate = self.sample_rate
+        post_output.pts = frame.pts
+        post_output.time_base = frame.time_base
+        return post_output
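As a quick sanity check on the s16 ↔ float32 scaling used by preprocess and postprocess above (the numbers here are illustrative only, not part of the commit):

import numpy as np

s16 = np.array([-32768, 0, 16384, 32767], dtype=np.int16)
f32 = s16.astype(np.float32) / 32768.0                          # preprocess: [-1.0, 0.0, 0.5, ~0.99997]
back = np.clip(f32 * 32768.0, -32768, 32767).astype(np.int16)   # postprocess
assert np.array_equal(s16, back)                                 # the round trip is exact for s16 input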

src/comfystream/client.py  (+12 −3)

@@ -9,20 +9,29 @@


 class ComfyStreamClient:
-    def __init__(self, **kwargs):
+    def __init__(self, type: str = "image", **kwargs):
         config = Configuration(**kwargs)
         # TODO: Need to handle cleanup for EmbeddedComfyClient if not using async context manager?
         self.comfy_client = EmbeddedComfyClient(config)
         self.prompt = None
+        self.type = type.lower()
+        if self.type not in {"image", "audio"}:
+            raise ValueError(f"Unsupported type: {self.type}. Supported types are 'image' and 'audio'.")
+
+        self.input_cache = getattr(tensor_cache, f"{self.type}_inputs", None)
+        self.output_cache = getattr(tensor_cache, f"{self.type}_outputs", None)
+
+        if self.input_cache is None or self.output_cache is None:
+            raise AttributeError(f"tensor_cache does not have attributes for type '{self.type}'.")

     def set_prompt(self, prompt: PromptDictInput):
         self.prompt = convert_prompt(prompt)

     async def queue_prompt(self, input: torch.Tensor) -> torch.Tensor:
-        tensor_cache.inputs.append(input)
+        self.input_cache.append(input)

         output_fut = asyncio.Future()
-        tensor_cache.outputs.append(output_fut)
+        self.output_cache.append(output_fut)

         await self.comfy_client.queue_prompt(self.prompt)

src/comfystream/tensor_cache.py  (+5 −2)

@@ -2,5 +2,8 @@
 import torch
 from typing import List

-inputs: List[torch.Tensor] = []
-outputs: List[asyncio.Future] = []
+image_inputs: List[torch.Tensor] = []
+image_outputs: List[asyncio.Future] = []
+
+audio_inputs: List[torch.Tensor] = []
+audio_outputs: List[asyncio.Future] = []
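Taken together with the client change above, the flow through these per-type lists is roughly the following. This is a simplified sketch: the real queue_prompt also drives EmbeddedComfyClient, and the final await of the future happens outside the hunk shown above, so the return here is an assumption.

import asyncio
import torch
from comfystream import tensor_cache

async def queue_audio(client, samples: torch.Tensor):
    # Approximately what ComfyStreamClient.queue_prompt does when constructed with type="audio"
    tensor_cache.audio_inputs.append(samples)    # popped by LoadAudioTensor.execute()
    fut: asyncio.Future = asyncio.Future()
    tensor_cache.audio_outputs.append(fut)       # resolved by SaveAudioTensor / SaveASRResponse
    await client.comfy_client.queue_prompt(client.prompt)
    return await fut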

src/comfystream/utils.py  (+1 −1)

@@ -48,7 +48,7 @@ def convert_prompt(prompt: PromptDictInput) -> Prompt:
             num_primary_inputs += 1
         elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor"]:
             num_inputs += 1
-        elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveASRResponse"]:
+        elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveASRResponse", "SaveAudioTensor"]:
             num_outputs += 1

         # Only handle single primary input

ui/src/app/api/offer/route.ts  (+2 −2)

@@ -1,14 +1,14 @@
 import { NextRequest, NextResponse } from "next/server";

 export const POST = async function POST(req: NextRequest) {
-  const { endpoint, prompt, offer } = await req.json();
+  const { endpoint, video_prompt, audio_prompt, offer } = await req.json();

   const res = await fetch(endpoint + "/offer", {
     method: "POST",
     headers: {
       "Content-Type": "application/json",
     },
-    body: JSON.stringify({ prompt, offer }),
+    body: JSON.stringify({ video_prompt, audio_prompt, offer }),
   });

   return NextResponse.json(await res.json(), { status: res.status });
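The body forwarded to <endpoint>/offer therefore now carries both prompts, matching what the offer() handler in server/app.py reads. Sketched here as a Python dict for reference; the prompt values and SDP are placeholders, not from the commit:

payload = {
    "video_prompt": video_prompt,   # ComfyUI prompt dict for the video graph (placeholder)
    "audio_prompt": audio_prompt,   # ComfyUI prompt dict for the audio graph (placeholder)
    "offer": {"sdp": local_sdp, "type": "offer"},   # the browser's local session description
}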

ui/src/components/peer.tsx  (+2 −1)

@@ -4,7 +4,8 @@ import { PeerContext } from "@/context/peer-context";

 export interface PeerProps extends React.HTMLAttributes<HTMLDivElement> {
   url: string;
-  prompt: any;
+  videoPrompt: any;
+  audioPrompt: any;
   connect: boolean;
   onConnected: () => void;
   onDisconnected: () => void;
