Skip to content

Commit d98e1b5

Browse files
rickstaaecmulli
andauthored
feat: add stream stats endpoint (yondonfu#48)
Adds a new stream stats endpoint which can be used to retrieve the fps metrics in a way that doesn't affect performance. --------- Co-authored-by: Evan Mullins <evancmullins@gmail.com>
1 parent c5d5e48 commit d98e1b5

File tree

2 files changed

+213
-17
lines changed

2 files changed

+213
-17
lines changed

server/app.py

+127-15
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import json
55
import logging
6+
from collections import deque
67
import sys
78

89
import torch
@@ -12,13 +13,6 @@
1213
torch.cuda.init()
1314

1415

15-
import torch
16-
17-
# Initialize CUDA before any other imports to prevent core dump.
18-
if torch.cuda.is_available():
19-
torch.cuda.init()
20-
21-
2216
from twilio.rest import Client
2317
from aiohttp import web
2418
from aiortc import (
@@ -27,12 +21,12 @@
2721
RTCConfiguration,
2822
RTCIceServer,
2923
MediaStreamTrack,
30-
RTCDataChannel,
3124
)
3225
from aiortc.rtcrtpsender import RTCRtpSender
3326
from aiortc.codecs import h264
3427
from pipeline import Pipeline
35-
from utils import patch_loop_datagram
28+
from utils import patch_loop_datagram, StreamStats, add_prefix_to_app_routes
29+
import time
3630

3731
logger = logging.getLogger(__name__)
3832
logging.getLogger('aiortc.rtcrtpsender').setLevel(logging.WARNING)
@@ -44,13 +38,38 @@
4438

4539

4640
class VideoStreamTrack(MediaStreamTrack):
41+
"""video stream track that processes video frames using a pipeline.
42+
43+
Attributes:
44+
kind (str): The kind of media, which is "video" for this class.
45+
track (MediaStreamTrack): The underlying media stream track.
46+
pipeline (Pipeline): The processing pipeline to apply to each video frame.
47+
"""
4748
kind = "video"
48-
def __init__(self, track: MediaStreamTrack, pipeline):
49+
def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
50+
"""Initialize the VideoStreamTrack.
51+
52+
Args:
53+
track: The underlying media stream track.
54+
pipeline: The processing pipeline to apply to each video frame.
55+
"""
4956
super().__init__()
5057
self.track = track
5158
self.pipeline = pipeline
59+
60+
self._lock = asyncio.Lock()
61+
self._fps_interval_frame_count = 0
62+
self._last_fps_calculation_time = None
63+
self._fps_loop_start_time = time.monotonic()
64+
self._fps = 0.0
65+
self._fps_measurements = deque(maxlen=60)
66+
self._running_event = asyncio.Event()
67+
5268
asyncio.create_task(self.collect_frames())
5369

70+
# Start metrics collection tasks.
71+
self._fps_stats_task = asyncio.create_task(self._calculate_fps_loop())
72+
5473
async def collect_frames(self):
5574
while True:
5675
try:
@@ -60,9 +79,83 @@ async def collect_frames(self):
6079
await self.pipeline.cleanup()
6180
raise Exception(f"Error collecting video frames: {str(e)}")
6281

82+
async def _calculate_fps_loop(self):
83+
"""Loop to calculate FPS periodically."""
84+
await self._running_event.wait()
85+
self._fps_loop_start_time = time.monotonic()
86+
while self.readyState != "ended":
87+
async with self._lock:
88+
current_time = time.monotonic()
89+
if self._last_fps_calculation_time is not None:
90+
time_diff = current_time - self._last_fps_calculation_time
91+
self._fps = self._fps_interval_frame_count / time_diff
92+
self._fps_measurements.append(
93+
{
94+
"timestamp": current_time - self._fps_loop_start_time,
95+
"fps": self._fps,
96+
}
97+
) # Store the FPS measurement with timestamp
98+
99+
# Reset start_time and frame_count for the next interval.
100+
self._last_fps_calculation_time = current_time
101+
self._fps_interval_frame_count = 0
102+
await asyncio.sleep(1) # Calculate FPS every second.
103+
104+
@property
105+
async def fps(self) -> float:
106+
"""Get the current output frames per second (FPS).
107+
108+
Returns:
109+
The current output FPS.
110+
"""
111+
async with self._lock:
112+
return self._fps
113+
114+
@property
115+
async def fps_measurements(self) -> list:
116+
"""Get the array of FPS measurements for the last minute.
117+
118+
Returns:
119+
The array of FPS measurements for the last minute.
120+
"""
121+
async with self._lock:
122+
return list(self._fps_measurements)
123+
124+
@property
125+
async def average_fps(self) -> float:
126+
"""Calculate the average FPS from the measurements taken in the last minute.
127+
128+
Returns:
129+
The average FPS over the last minute.
130+
"""
131+
async with self._lock:
132+
if not self._fps_measurements:
133+
return 0.0
134+
return sum(
135+
measurement["fps"] for measurement in self._fps_measurements
136+
) / len(self._fps_measurements)
137+
138+
@property
139+
async def last_fps_calculation_time(self) -> float:
140+
"""Get the elapsed time since the last FPS calculation.
141+
142+
Returns:
143+
The elapsed time in seconds since the last FPS calculation.
144+
"""
145+
async with self._lock:
146+
return self._last_fps_calculation_time - self._fps_loop_start_time
147+
63148
async def recv(self):
64-
return await self.pipeline.get_processed_video_frame()
65-
149+
processed_frame = await self.pipeline.get_processed_video_frame()
150+
151+
# Increment frame count for FPS calculation.
152+
async with self._lock:
153+
self._fps_interval_frame_count += 1
154+
if not self._running_event.is_set():
155+
self._running_event.set()
156+
157+
return processed_frame
158+
66159

67160
class AudioStreamTrack(MediaStreamTrack):
68161
kind = "audio"
@@ -168,7 +261,7 @@ def on_datachannel(channel):
168261
async def on_message(message):
169262
try:
170263
params = json.loads(message)
171-
264+
172265
if params.get("type") == "get_nodes":
173266
nodes_info = await pipeline.get_nodes_info()
174267
response = {
@@ -201,6 +294,10 @@ def on_track(track):
201294
tracks["video"] = videoTrack
202295
sender = pc.addTrack(videoTrack)
203296

297+
# Store video track in app for stats.
298+
stream_id = track.id
299+
request.app["video_tracks"][stream_id] = videoTrack
300+
204301
codec = "video/H264"
205302
force_codec(pc, sender, codec)
206303
elif track.kind == "audio":
@@ -211,6 +308,7 @@ def on_track(track):
211308
@track.on("ended")
212309
async def on_ended():
213310
logger.info(f"{track.kind} track ended")
311+
request.app["video_tracks"].pop(track.id, None)
214312

215313
@pc.on("connectionstatechange")
216314
async def on_connectionstatechange():
@@ -261,6 +359,7 @@ async def on_startup(app: web.Application):
261359
cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
262360
)
263361
app["pcs"] = set()
362+
app["video_tracks"] = {}
264363

265364

266365
async def on_shutdown(app: web.Application):
@@ -301,11 +400,24 @@ async def on_shutdown(app: web.Application):
301400
app.on_startup.append(on_startup)
302401
app.on_shutdown.append(on_shutdown)
303402

304-
app.router.add_post("/offer", offer)
305-
app.router.add_post("/prompt", set_prompt)
306403
app.router.add_get("/", health)
307404
app.router.add_get("/health", health)
308405

406+
# WebRTC signalling and control routes.
407+
app.router.add_post("/offer", offer)
408+
app.router.add_post("/prompt", set_prompt)
409+
410+
# Add routes for getting stream statistics.
411+
stream_stats = StreamStats(app)
412+
app.router.add_get("/streams/stats", stream_stats.collect_all_stream_metrics)
413+
app.router.add_get(
414+
"/stream/{stream_id}/stats", stream_stats.collect_stream_metrics_by_id
415+
)
416+
417+
# Add hosted platform route prefix.
418+
# NOTE: This ensures that the local and hosted experiences have consistent routes.
419+
add_prefix_to_app_routes(app, "/live")
420+
309421
def force_print(*args, **kwargs):
310422
print(*args, **kwargs, flush=True)
311423
sys.stdout.flush()

server/utils.py

+86-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
"""Utility functions for the server."""
12
import asyncio
23
import random
34
import types
45
import logging
5-
6-
from typing import List, Tuple
6+
import json
7+
from aiohttp import web
8+
from aiortc import MediaStreamTrack
9+
from typing import List, Tuple, Any, Dict
710

811
logger = logging.getLogger(__name__)
912

@@ -48,3 +51,84 @@ async def create_datagram_endpoint(
4851

4952
loop.create_datagram_endpoint = types.MethodType(create_datagram_endpoint, loop)
5053
loop._patch_done = True
54+
55+
56+
def add_prefix_to_app_routes(app: web.Application, prefix: str):
57+
"""Add a prefix to all routes in the given application.
58+
59+
Args:
60+
app: The web application whose routes will be prefixed.
61+
prefix: The prefix to add to all routes.
62+
"""
63+
prefix = prefix.rstrip("/")
64+
for route in list(app.router.routes()):
65+
new_path = prefix + route.resource.canonical
66+
app.router.add_route(route.method, new_path, route.handler)
67+
68+
69+
class StreamStats:
70+
"""Handles real-time video stream statistics collection."""
71+
72+
def __init__(self, app: web.Application):
73+
"""Initializes the StreamMetrics class.
74+
75+
Args:
76+
app: The web application instance storing video streams under the
77+
"video_tracks" key.
78+
"""
79+
self._app = app
80+
81+
async def collect_video_metrics(self, video_track: MediaStreamTrack) -> Dict[str, Any]:
82+
"""Collects real-time statistics for a video track.
83+
84+
Args:
85+
video_track: The video stream track instance.
86+
87+
Returns:
88+
A dictionary containing FPS-related statistics.
89+
"""
90+
return {
91+
"timestamp": await video_track.last_fps_calculation_time,
92+
"fps": await video_track.fps,
93+
"minute_avg_fps": await video_track.average_fps,
94+
"minute_fps_array": await video_track.fps_measurements,
95+
}
96+
97+
async def collect_all_stream_metrics(self, _) -> web.Response:
98+
"""Retrieves real-time metrics for all active video streams.
99+
100+
Returns:
101+
A JSON response containing FPS statistics for all streams.
102+
"""
103+
video_tracks = self._app.get("video_tracks", {})
104+
all_stats = {
105+
stream_id: await self.collect_video_metrics(track)
106+
for stream_id, track in video_tracks.items()
107+
}
108+
109+
return web.Response(
110+
content_type="application/json",
111+
text=json.dumps(all_stats),
112+
)
113+
114+
async def collect_stream_metrics_by_id(self, request: web.Request) -> web.Response:
115+
"""Retrieves real-time metrics for a specific video stream by ID.
116+
117+
Args:
118+
request: The HTTP request containing the stream ID.
119+
120+
Returns:
121+
A JSON response with stream metrics or an error message.
122+
"""
123+
stream_id = request.match_info.get("stream_id")
124+
video_track = self._app["video_tracks"].get(stream_id)
125+
126+
if video_track:
127+
stats = await self.collect_video_metrics(video_track)
128+
else:
129+
stats = {"error": "Stream not found"}
130+
131+
return web.Response(
132+
content_type="application/json",
133+
text=json.dumps(stats),
134+
)

0 commit comments

Comments
 (0)