import asyncio
import os
import json
import logging
from collections import deque
import sys

import torch

# Initialize CUDA before any other imports to prevent a core dump, but only
# when a CUDA device is actually available: calling torch.cuda.init() on a
# CPU-only host raises instead of initializing.
if torch.cuda.is_available():
    torch.cuda.init()

from twilio.rest import Client
from aiohttp import web
from aiortc import (
    RTCConfiguration,
    RTCIceServer,
    MediaStreamTrack,
)
from aiortc.rtcrtpsender import RTCRtpSender
from aiortc.codecs import h264
from pipeline import Pipeline
from utils import patch_loop_datagram, StreamStats, add_prefix_to_app_routes
import time

logger = logging.getLogger(__name__)
# aiortc's sender is noisy at INFO; keep only warnings and above.
logging.getLogger('aiortc.rtcrtpsender').setLevel(logging.WARNING)
44
38
45
39
46
40
class VideoStreamTrack(MediaStreamTrack):
    """Video stream track that processes video frames using a pipeline.

    Attributes:
        kind (str): The kind of media, which is "video" for this class.
        track (MediaStreamTrack): The underlying media stream track.
        pipeline (Pipeline): The processing pipeline to apply to each video frame.
    """

    kind = "video"

    def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
        """Initialize the VideoStreamTrack.

        Args:
            track: The underlying media stream track.
            pipeline: The processing pipeline to apply to each video frame.
        """
        super().__init__()
        self.track = track
        self.pipeline = pipeline

        # FPS bookkeeping, shared between recv() and the background stats
        # loop, hence guarded by a lock.
        self._lock = asyncio.Lock()
        self._fps_interval_frame_count = 0
        self._last_fps_calculation_time = None
        self._fps_loop_start_time = time.monotonic()
        self._fps = 0.0
        self._fps_measurements = deque(maxlen=60)
        self._running_event = asyncio.Event()

        # Keep a strong reference to the task: the event loop only holds
        # weak references, so an unreferenced task may be garbage-collected
        # before it completes (the original discarded this reference).
        self._collect_frames_task = asyncio.create_task(self.collect_frames())

        # Start metrics collection tasks.
        self._fps_stats_task = asyncio.create_task(self._calculate_fps_loop())
54
73
async def collect_frames (self ):
55
74
while True :
56
75
try :
@@ -60,9 +79,83 @@ async def collect_frames(self):
60
79
await self .pipeline .cleanup ()
61
80
raise Exception (f"Error collecting video frames: { str (e )} " )
62
81
82
+ async def _calculate_fps_loop (self ):
83
+ """Loop to calculate FPS periodically."""
84
+ await self ._running_event .wait ()
85
+ self ._fps_loop_start_time = time .monotonic ()
86
+ while self .readyState != "ended" :
87
+ async with self ._lock :
88
+ current_time = time .monotonic ()
89
+ if self ._last_fps_calculation_time is not None :
90
+ time_diff = current_time - self ._last_fps_calculation_time
91
+ self ._fps = self ._fps_interval_frame_count / time_diff
92
+ self ._fps_measurements .append (
93
+ {
94
+ "timestamp" : current_time - self ._fps_loop_start_time ,
95
+ "fps" : self ._fps ,
96
+ }
97
+ ) # Store the FPS measurement with timestamp
98
+
99
+ # Reset start_time and frame_count for the next interval.
100
+ self ._last_fps_calculation_time = current_time
101
+ self ._fps_interval_frame_count = 0
102
+ await asyncio .sleep (1 ) # Calculate FPS every second.
103
+
104
+ @property
105
+ async def fps (self ) -> float :
106
+ """Get the current output frames per second (FPS).
107
+
108
+ Returns:
109
+ The current output FPS.
110
+ """
111
+ async with self ._lock :
112
+ return self ._fps
113
+
114
+ @property
115
+ async def fps_measurements (self ) -> list :
116
+ """Get the array of FPS measurements for the last minute.
117
+
118
+ Returns:
119
+ The array of FPS measurements for the last minute.
120
+ """
121
+ async with self ._lock :
122
+ return list (self ._fps_measurements )
123
+
124
+ @property
125
+ async def average_fps (self ) -> float :
126
+ """Calculate the average FPS from the measurements taken in the last minute.
127
+
128
+ Returns:
129
+ The average FPS over the last minute.
130
+ """
131
+ async with self ._lock :
132
+ if not self ._fps_measurements :
133
+ return 0.0
134
+ return sum (
135
+ measurement ["fps" ] for measurement in self ._fps_measurements
136
+ ) / len (self ._fps_measurements )
137
+
138
+ @property
139
+ async def last_fps_calculation_time (self ) -> float :
140
+ """Get the elapsed time since the last FPS calculation.
141
+
142
+ Returns:
143
+ The elapsed time in seconds since the last FPS calculation.
144
+ """
145
+ async with self ._lock :
146
+ return self ._last_fps_calculation_time - self ._fps_loop_start_time
147
+
63
148
async def recv (self ):
64
- return await self .pipeline .get_processed_video_frame ()
65
-
149
+ processed_frame = await self .pipeline .get_processed_video_frame ()
150
+
151
+ # Increment frame count for FPS calculation.
152
+ async with self ._lock :
153
+ self ._fps_interval_frame_count += 1
154
+ if not self ._running_event .is_set ():
155
+ self ._running_event .set ()
156
+
157
+ return processed_frame
158
+
66
159
67
160
class AudioStreamTrack (MediaStreamTrack ):
68
161
kind = "audio"
@@ -168,7 +261,7 @@ def on_datachannel(channel):
168
261
async def on_message (message ):
169
262
try :
170
263
params = json .loads (message )
171
-
264
+
172
265
if params .get ("type" ) == "get_nodes" :
173
266
nodes_info = await pipeline .get_nodes_info ()
174
267
response = {
@@ -201,6 +294,10 @@ def on_track(track):
201
294
tracks ["video" ] = videoTrack
202
295
sender = pc .addTrack (videoTrack )
203
296
297
+ # Store video track in app for stats.
298
+ stream_id = track .id
299
+ request .app ["video_tracks" ][stream_id ] = videoTrack
300
+
204
301
codec = "video/H264"
205
302
force_codec (pc , sender , codec )
206
303
elif track .kind == "audio" :
@@ -211,6 +308,7 @@ def on_track(track):
211
308
@track .on ("ended" )
212
309
async def on_ended ():
213
310
logger .info (f"{ track .kind } track ended" )
311
+ request .app ["video_tracks" ].pop (track .id , None )
214
312
215
313
@pc .on ("connectionstatechange" )
216
314
async def on_connectionstatechange ():
@@ -261,6 +359,7 @@ async def on_startup(app: web.Application):
261
359
cwd = app ["workspace" ], disable_cuda_malloc = True , gpu_only = True
262
360
)
263
361
app ["pcs" ] = set ()
362
+ app ["video_tracks" ] = {}
264
363
265
364
266
365
async def on_shutdown (app : web .Application ):
@@ -301,11 +400,24 @@ async def on_shutdown(app: web.Application):
301
400
app .on_startup .append (on_startup )
302
401
app .on_shutdown .append (on_shutdown )
303
402
304
- app .router .add_post ("/offer" , offer )
305
- app .router .add_post ("/prompt" , set_prompt )
306
403
app .router .add_get ("/" , health )
307
404
app .router .add_get ("/health" , health )
308
405
406
+ # WebRTC signalling and control routes.
407
+ app .router .add_post ("/offer" , offer )
408
+ app .router .add_post ("/prompt" , set_prompt )
409
+
410
+ # Add routes for getting stream statistics.
411
+ stream_stats = StreamStats (app )
412
+ app .router .add_get ("/streams/stats" , stream_stats .collect_all_stream_metrics )
413
+ app .router .add_get (
414
+ "/stream/{stream_id}/stats" , stream_stats .collect_stream_metrics_by_id
415
+ )
416
+
417
+ # Add hosted platform route prefix.
418
+ # NOTE: This ensures that the local and hosted experiences have consistent routes.
419
+ add_prefix_to_app_routes (app , "/live" )
420
+
309
421
def force_print(*args, **kwargs):
    """Print immediately, defeating any stdout buffering."""
    print(*args, flush=True, **kwargs)
    sys.stdout.flush()
0 commit comments