1
- import asyncio
2
1
import argparse
3
- import os
2
+ import asyncio
4
3
import json
5
4
import logging
6
- from collections import deque
5
+ import os
7
6
import sys
8
7
9
8
import torch
13
12
torch .cuda .init ()
14
13
15
14
16
- from twilio .rest import Client
17
15
from aiohttp import web
18
16
from aiortc import (
19
- RTCPeerConnection ,
20
- RTCSessionDescription ,
17
+ MediaStreamTrack ,
21
18
RTCConfiguration ,
22
19
RTCIceServer ,
23
- MediaStreamTrack ,
20
+ RTCPeerConnection ,
21
+ RTCSessionDescription ,
24
22
)
25
- from aiortc .rtcrtpsender import RTCRtpSender
26
23
from aiortc .codecs import h264
24
+ from aiortc .rtcrtpsender import RTCRtpSender
27
25
from pipeline import Pipeline
28
- from utils import patch_loop_datagram , StreamStats , add_prefix_to_app_routes
29
- import time
26
+ from twilio . rest import Client
27
+ from utils import FPSMeter , StreamStats , add_prefix_to_app_routes , patch_loop_datagram
30
28
31
29
logger = logging .getLogger (__name__ )
32
- logging .getLogger (' aiortc.rtcrtpsender' ).setLevel (logging .WARNING )
33
- logging .getLogger (' aiortc.rtcrtpreceiver' ).setLevel (logging .WARNING )
30
+ logging .getLogger (" aiortc.rtcrtpsender" ).setLevel (logging .WARNING )
31
+ logging .getLogger (" aiortc.rtcrtpreceiver" ).setLevel (logging .WARNING )
34
32
35
33
36
34
MAX_BITRATE = 2000000
@@ -45,7 +43,9 @@ class VideoStreamTrack(MediaStreamTrack):
45
43
track (MediaStreamTrack): The underlying media stream track.
46
44
pipeline (Pipeline): The processing pipeline to apply to each video frame.
47
45
"""
46
+
48
47
kind = "video"
48
+
49
49
def __init__ (self , track : MediaStreamTrack , pipeline : Pipeline ):
50
50
"""Initialize the VideoStreamTrack.
51
51
@@ -56,21 +56,14 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
56
56
super ().__init__ ()
57
57
self .track = track
58
58
self .pipeline = pipeline
59
-
60
- self ._lock = asyncio .Lock ()
61
- self ._fps_interval_frame_count = 0
62
- self ._last_fps_calculation_time = None
63
- self ._fps_loop_start_time = time .monotonic ()
64
- self ._fps = 0.0
65
- self ._fps_measurements = deque (maxlen = 60 )
66
- self ._running_event = asyncio .Event ()
59
+ self .fps_meter = FPSMeter ()
67
60
68
61
asyncio .create_task (self .collect_frames ())
69
62
70
- # Start metrics collection tasks.
71
- self ._fps_stats_task = asyncio .create_task (self ._calculate_fps_loop ())
72
-
73
63
async def collect_frames (self ):
64
+ """Continuously collect video frames from the underlying track and pass them to
65
+ the processing pipeline.
66
+ """
74
67
while True :
75
68
try :
76
69
frame = await self .track .recv ()
@@ -79,86 +72,21 @@ async def collect_frames(self):
79
72
await self .pipeline .cleanup ()
80
73
raise Exception (f"Error collecting video frames: { str (e )} " )
81
74
82
- async def _calculate_fps_loop (self ):
83
- """Loop to calculate FPS periodically."""
84
- await self ._running_event .wait ()
85
- self ._fps_loop_start_time = time .monotonic ()
86
- while self .readyState != "ended" :
87
- async with self ._lock :
88
- current_time = time .monotonic ()
89
- if self ._last_fps_calculation_time is not None :
90
- time_diff = current_time - self ._last_fps_calculation_time
91
- self ._fps = self ._fps_interval_frame_count / time_diff
92
- self ._fps_measurements .append (
93
- {
94
- "timestamp" : current_time - self ._fps_loop_start_time ,
95
- "fps" : self ._fps ,
96
- }
97
- ) # Store the FPS measurement with timestamp
98
-
99
- # Reset start_time and frame_count for the next interval.
100
- self ._last_fps_calculation_time = current_time
101
- self ._fps_interval_frame_count = 0
102
- await asyncio .sleep (1 ) # Calculate FPS every second.
103
-
104
- @property
105
- async def fps (self ) -> float :
106
- """Get the current output frames per second (FPS).
107
-
108
- Returns:
109
- The current output FPS.
110
- """
111
- async with self ._lock :
112
- return self ._fps
113
-
114
- @property
115
- async def fps_measurements (self ) -> list :
116
- """Get the array of FPS measurements for the last minute.
117
-
118
- Returns:
119
- The array of FPS measurements for the last minute.
120
- """
121
- async with self ._lock :
122
- return list (self ._fps_measurements )
123
-
124
- @property
125
- async def average_fps (self ) -> float :
126
- """Calculate the average FPS from the measurements taken in the last minute.
127
-
128
- Returns:
129
- The average FPS over the last minute.
130
- """
131
- async with self ._lock :
132
- if not self ._fps_measurements :
133
- return 0.0
134
- return sum (
135
- measurement ["fps" ] for measurement in self ._fps_measurements
136
- ) / len (self ._fps_measurements )
137
-
138
- @property
139
- async def last_fps_calculation_time (self ) -> float :
140
- """Get the elapsed time since the last FPS calculation.
141
-
142
- Returns:
143
- The elapsed time in seconds since the last FPS calculation.
144
- """
145
- async with self ._lock :
146
- return self ._last_fps_calculation_time - self ._fps_loop_start_time
147
-
148
75
async def recv (self ):
76
+ """Receive a processed video frame from the pipeline, increment the frame
77
+ count for FPS calculation and return the processed frame to the client.
78
+ """
149
79
processed_frame = await self .pipeline .get_processed_video_frame ()
150
80
151
- # Increment frame count for FPS calculation.
152
- async with self ._lock :
153
- self ._fps_interval_frame_count += 1
154
- if not self ._running_event .is_set ():
155
- self ._running_event .set ()
81
+ # Increment the frame count to calculate FPS.
82
+ await self .fps_meter .increment_frame_count ()
156
83
157
84
return processed_frame
158
85
159
86
160
87
class AudioStreamTrack (MediaStreamTrack ):
161
88
kind = "audio"
89
+
162
90
def __init__ (self , track : MediaStreamTrack , pipeline ):
163
91
super ().__init__ ()
164
92
self .track = track
@@ -257,30 +185,29 @@ async def offer(request):
257
185
@pc .on ("datachannel" )
258
186
def on_datachannel (channel ):
259
187
if channel .label == "control" :
188
+
260
189
@channel .on ("message" )
261
190
async def on_message (message ):
262
191
try :
263
192
params = json .loads (message )
264
193
265
194
if params .get ("type" ) == "get_nodes" :
266
195
nodes_info = await pipeline .get_nodes_info ()
267
- response = {
268
- "type" : "nodes_info" ,
269
- "nodes" : nodes_info
270
- }
196
+ response = {"type" : "nodes_info" , "nodes" : nodes_info }
271
197
channel .send (json .dumps (response ))
272
198
elif params .get ("type" ) == "update_prompts" :
273
199
if "prompts" not in params :
274
- logger .warning ("[Control] Missing prompt in update_prompt message" )
200
+ logger .warning (
201
+ "[Control] Missing prompt in update_prompt message"
202
+ )
275
203
return
276
204
await pipeline .update_prompts (params ["prompts" ])
277
- response = {
278
- "type" : "prompts_updated" ,
279
- "success" : True
280
- }
205
+ response = {"type" : "prompts_updated" , "success" : True }
281
206
channel .send (json .dumps (response ))
282
207
else :
283
- logger .warning ("[Server] Invalid message format - missing required fields" )
208
+ logger .warning (
209
+ "[Server] Invalid message format - missing required fields"
210
+ )
284
211
except json .JSONDecodeError :
285
212
logger .error ("[Server] Invalid JSON received" )
286
213
except Exception as e :
@@ -389,8 +316,8 @@ async def on_shutdown(app: web.Application):
389
316
390
317
logging .basicConfig (
391
318
level = args .log_level .upper (),
392
- format = ' %(asctime)s [%(levelname)s] %(message)s' ,
393
- datefmt = ' %H:%M:%S'
319
+ format = " %(asctime)s [%(levelname)s] %(message)s" ,
320
+ datefmt = " %H:%M:%S" ,
394
321
)
395
322
396
323
app = web .Application ()
0 commit comments