3
3
import os
4
4
import json
5
5
import logging
6
- import wave
7
- import numpy as np
8
6
9
7
from twilio .rest import Client
10
8
from aiohttp import web
17
15
)
18
16
from aiortc .rtcrtpsender import RTCRtpSender
19
17
from aiortc .codecs import h264
20
- from pipeline import Pipeline
18
+ from pipeline import VideoPipeline , AudioPipeline
21
19
from utils import patch_loop_datagram
22
20
23
21
logger = logging .getLogger (__name__ )
@@ -39,93 +37,16 @@ async def recv(self):
39
37
return await self .pipeline (frame )
40
38
41
39
class AudioStreamTrack (MediaStreamTrack ):
42
- """
43
- This custom audio track wraps an incoming audio MediaStreamTrack.
44
- It continuously records frames in 10-second chunks and saves each chunk
45
- as a separate WAV file with an incrementing index.
46
- """
47
-
48
40
kind = "audio"
49
41
50
- def __init__ (self , track : MediaStreamTrack ):
42
+ def __init__ (self , track : MediaStreamTrack , pipeline ):
51
43
super ().__init__ ()
52
44
self .track = track
53
- self .start_time = None
54
- self .frames = []
55
- self ._recording_duration = 10.0 # in seconds
56
- self ._chunk_index = 0
57
- self ._saving = False
58
- self ._lock = asyncio .Lock ()
45
+ self .pipeline = pipeline
59
46
60
47
async def recv (self ):
61
48
frame = await self .track .recv ()
62
- return frame
63
-
64
- # async def recv(self):
65
- # frame = await self.source.recv()
66
-
67
- # # On the first frame, record the start time.
68
- # if self.start_time is None:
69
- # self.start_time = frame.time
70
- # logger.info(f"Audio recording started at time: {self.start_time:.3f}")
71
-
72
- # elapsed = frame.time - self.start_time
73
- # self.frames.append(frame)
74
-
75
- # logger.info(f"Received audio frame at time: {frame.time:.3f}, total frames: {len(self.frames)}")
76
-
77
- # # Check if we've hit 10 seconds and we're not currently saving.
78
- # if elapsed >= self._recording_duration and not self._saving:
79
- # logger.info(f"10 second chunk reached (elapsed: {elapsed:.3f}s). Preparing to save chunk {self._chunk_index}.")
80
- # self._saving = True
81
- # # Handle saving in a background task so we don't block the recv loop.
82
- # asyncio.create_task(self.save_audio())
83
-
84
- # return frame
85
-
86
- async def save_audio (self ):
87
- logger .info (f"Starting to save audio chunk { self ._chunk_index } ..." )
88
- async with self ._lock :
89
- # Extract properties from the first frame
90
- if not self .frames :
91
- logger .warning ("No frames to save, skipping." )
92
- self ._saving = False
93
- return
94
-
95
- sample_rate = self .frames [0 ].sample_rate
96
- layout = self .frames [0 ].layout
97
- channels = len (layout .channels )
98
-
99
- logger .info (f"Audio chunk { self ._chunk_index } : sample_rate={ sample_rate } , channels={ channels } , frames_count={ len (self .frames )} " )
100
-
101
- # Convert all frames to ndarray and concatenate
102
- data_arrays = [f .to_ndarray () for f in self .frames ]
103
- data = np .concatenate (data_arrays , axis = 1 ) # shape: (channels, total_samples)
104
-
105
- # Interleave channels (if multiple) since WAV expects interleaved samples.
106
- interleaved = data .T .flatten ()
107
-
108
- # If needed, convert float frames to int16
109
- # interleaved = (interleaved * 32767).astype(np.int16)
110
-
111
- filename = f"output_{ self ._chunk_index } .wav"
112
- logger .info (f"Writing audio chunk { self ._chunk_index } to file: { filename } " )
113
- with wave .open (filename , 'wb' ) as wf :
114
- wf .setnchannels (channels )
115
- wf .setsampwidth (2 ) # 16-bit PCM
116
- wf .setframerate (sample_rate )
117
- wf .writeframes (interleaved .tobytes ())
118
-
119
- logger .info (f"Audio chunk { self ._chunk_index } saved successfully as { filename } " )
120
-
121
- # Increment the chunk index for the next segment
122
- self ._chunk_index += 1
123
-
124
- # Reset for next recording chunk
125
- self .frames .clear ()
126
- self .start_time = None
127
- self ._saving = False
128
- logger .info (f"Ready to record next 10-second chunk. Current chunk index: { self ._chunk_index } " )
49
+ return await self .pipeline (frame )
129
50
130
51
131
52
def force_codec (pc , sender , forced_codec ):
@@ -169,13 +90,19 @@ def get_ice_servers():
169
90
170
91
171
92
async def offer (request ):
172
- pipeline = request .app ["pipeline" ]
93
+ video_pipeline = request .app ["video_pipeline" ]
94
+ audio_pipeline = request .app ["audio_pipeline" ]
173
95
pcs = request .app ["pcs" ]
174
96
175
97
params = await request .json ()
176
98
177
- pipeline .set_prompt (params ["prompt" ])
178
- await pipeline .warm ()
99
+ print ("VIDEO PROMPT" , params ["video_prompt" ])
100
+ print ("AUDIO PROMPT" , params ["audio_prompt" ])
101
+
102
+ video_pipeline .set_prompt (params ["video_prompt" ])
103
+ await video_pipeline .warm ()
104
+ audio_pipeline .set_prompt (params ["audio_prompt" ])
105
+ await audio_pipeline .warm ()
179
106
180
107
offer_params = params ["offer" ]
181
108
offer = RTCSessionDescription (sdp = offer_params ["sdp" ], type = offer_params ["type" ])
@@ -206,14 +133,14 @@ async def offer(request):
206
133
def on_track (track ):
207
134
logger .info (f"Track received: { track .kind } " )
208
135
if track .kind == "video" :
209
- videoTrack = VideoStreamTrack (track , pipeline )
136
+ videoTrack = VideoStreamTrack (track , video_pipeline )
210
137
tracks ["video" ] = videoTrack
211
138
sender = pc .addTrack (videoTrack )
212
139
213
140
codec = "video/H264"
214
141
force_codec (pc , sender , codec )
215
142
elif track .kind == "audio" :
216
- audioTrack = AudioStreamTrack (track )
143
+ audioTrack = AudioStreamTrack (track , audio_pipeline )
217
144
tracks ["audio" ] = audioTrack
218
145
pc .addTrack (audioTrack )
219
146
@@ -261,7 +188,10 @@ async def on_startup(app: web.Application):
261
188
if app ["media_ports" ]:
262
189
patch_loop_datagram (app ["media_ports" ])
263
190
264
- app ["pipeline" ] = Pipeline (
191
+ app ["video_pipeline" ] = VideoPipeline (
192
+ cwd = app ["workspace" ], disable_cuda_malloc = True , gpu_only = True
193
+ )
194
+ app ["audio_pipeline" ] = AudioPipeline (
265
195
cwd = app ["workspace" ], disable_cuda_malloc = True , gpu_only = True
266
196
)
267
197
app ["pcs" ] = set ()
0 commit comments