feat: add stream stats endpoint #48

Merged
merged 13 commits on Mar 5, 2025
server/app.py (142 changes: 127 additions & 15 deletions)
@@ -3,6 +3,7 @@
import os
import json
import logging
from collections import deque
import sys

import torch
@@ -12,13 +13,6 @@
torch.cuda.init()


import torch

# Initialize CUDA before any other imports to prevent core dump.
if torch.cuda.is_available():
torch.cuda.init()


from twilio.rest import Client
from aiohttp import web
from aiortc import (
@@ -27,12 +21,12 @@
RTCConfiguration,
RTCIceServer,
MediaStreamTrack,
RTCDataChannel,
)
from aiortc.rtcrtpsender import RTCRtpSender
from aiortc.codecs import h264
from pipeline import Pipeline
from utils import patch_loop_datagram
from utils import patch_loop_datagram, StreamStats, add_prefix_to_app_routes
import time

logger = logging.getLogger(__name__)
logging.getLogger('aiortc.rtcrtpsender').setLevel(logging.WARNING)
@@ -44,13 +38,38 @@


class VideoStreamTrack(MediaStreamTrack):
"""video stream track that processes video frames using a pipeline.

Attributes:
kind (str): The kind of media, which is "video" for this class.
track (MediaStreamTrack): The underlying media stream track.
pipeline (Pipeline): The processing pipeline to apply to each video frame.
"""
kind = "video"
def __init__(self, track: MediaStreamTrack, pipeline):
def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
"""Initialize the VideoStreamTrack.

Args:
track: The underlying media stream track.
pipeline: The processing pipeline to apply to each video frame.
"""
super().__init__()
self.track = track
self.pipeline = pipeline

self._lock = asyncio.Lock()
self._fps_interval_frame_count = 0
self._last_fps_calculation_time = None
self._fps_loop_start_time = time.monotonic()
self._fps = 0.0
self._fps_measurements = deque(maxlen=60)
self._running_event = asyncio.Event()

asyncio.create_task(self.collect_frames())

# Start metrics collection tasks.
self._fps_stats_task = asyncio.create_task(self._calculate_fps_loop())

async def collect_frames(self):
while True:
try:
@@ -60,9 +79,83 @@ async def collect_frames(self):
await self.pipeline.cleanup()
raise Exception(f"Error collecting video frames: {str(e)}")

async def _calculate_fps_loop(self):
"""Loop to calculate FPS periodically."""
await self._running_event.wait()
self._fps_loop_start_time = time.monotonic()
while self.readyState != "ended":
async with self._lock:
current_time = time.monotonic()
if self._last_fps_calculation_time is not None:
time_diff = current_time - self._last_fps_calculation_time
self._fps = self._fps_interval_frame_count / time_diff
self._fps_measurements.append(
{
"timestamp": current_time - self._fps_loop_start_time,
"fps": self._fps,
}
) # Store the FPS measurement with timestamp

                # Reset the last calculation time and frame count for the next interval.
self._last_fps_calculation_time = current_time
self._fps_interval_frame_count = 0
await asyncio.sleep(1) # Calculate FPS every second.

@property
async def fps(self) -> float:
"""Get the current output frames per second (FPS).

Returns:
The current output FPS.
"""
async with self._lock:
return self._fps

@property
async def fps_measurements(self) -> list:
"""Get the array of FPS measurements for the last minute.

Returns:
The array of FPS measurements for the last minute.
"""
async with self._lock:
return list(self._fps_measurements)

@property
async def average_fps(self) -> float:
"""Calculate the average FPS from the measurements taken in the last minute.

Returns:
The average FPS over the last minute.
"""
async with self._lock:
if not self._fps_measurements:
return 0.0
return sum(
measurement["fps"] for measurement in self._fps_measurements
) / len(self._fps_measurements)

@property
async def last_fps_calculation_time(self) -> float:
"""Get the elapsed time since the last FPS calculation.

Returns:
The elapsed time in seconds since the last FPS calculation.
"""
async with self._lock:
return self._last_fps_calculation_time - self._fps_loop_start_time

async def recv(self):
return await self.pipeline.get_processed_video_frame()

processed_frame = await self.pipeline.get_processed_video_frame()

# Increment frame count for FPS calculation.
async with self._lock:
self._fps_interval_frame_count += 1
if not self._running_event.is_set():
self._running_event.set()

return processed_frame


class AudioStreamTrack(MediaStreamTrack):
kind = "audio"
@@ -168,7 +261,7 @@ def on_datachannel(channel):
async def on_message(message):
try:
params = json.loads(message)

if params.get("type") == "get_nodes":
nodes_info = await pipeline.get_nodes_info()
response = {
@@ -201,6 +294,10 @@ def on_track(track):
tracks["video"] = videoTrack
sender = pc.addTrack(videoTrack)

# Store video track in app for stats.
stream_id = track.id
request.app["video_tracks"][stream_id] = videoTrack

codec = "video/H264"
force_codec(pc, sender, codec)
elif track.kind == "audio":
@@ -211,6 +308,7 @@ def on_track(track):
@track.on("ended")
async def on_ended():
logger.info(f"{track.kind} track ended")
request.app["video_tracks"].pop(track.id, None)

@pc.on("connectionstatechange")
async def on_connectionstatechange():
@@ -261,6 +359,7 @@ async def on_startup(app: web.Application):
cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
)
app["pcs"] = set()
app["video_tracks"] = {}


async def on_shutdown(app: web.Application):
@@ -301,11 +400,24 @@ async def on_shutdown(app: web.Application):
app.on_startup.append(on_startup)
app.on_shutdown.append(on_shutdown)

app.router.add_post("/offer", offer)
app.router.add_post("/prompt", set_prompt)
app.router.add_get("/", health)
app.router.add_get("/health", health)

# WebRTC signalling and control routes.
app.router.add_post("/offer", offer)
app.router.add_post("/prompt", set_prompt)

# Add routes for getting stream statistics.
stream_stats = StreamStats(app)
app.router.add_get("/streams/stats", stream_stats.collect_all_stream_metrics)
app.router.add_get(
"/stream/{stream_id}/stats", stream_stats.collect_stream_metrics_by_id
)

# Add hosted platform route prefix.
# NOTE: This ensures that the local and hosted experiences have consistent routes.
add_prefix_to_app_routes(app, "/live")

def force_print(*args, **kwargs):
print(*args, **kwargs, flush=True)
sys.stdout.flush()
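With the `/live` prefix applied, the new statistics endpoints are served at `/live/streams/stats` and `/live/stream/{stream_id}/stats`. Below is a minimal client sketch for polling them; the base URL, port, and fallback stream ID are assumptions for illustration, not values defined by this PR.

```python
import asyncio

import aiohttp

BASE_URL = "http://localhost:8889"  # assumed host/port; adjust to your deployment


async def poll_stream_stats() -> None:
    async with aiohttp.ClientSession() as session:
        # Metrics for all active streams, keyed by stream ID.
        async with session.get(f"{BASE_URL}/live/streams/stats") as resp:
            all_stats = await resp.json()
            print(all_stats)

        # Metrics for a single stream (falls back to a hypothetical ID).
        stream_id = next(iter(all_stats), "example-stream-id")
        async with session.get(f"{BASE_URL}/live/stream/{stream_id}/stats") as resp:
            print(await resp.json())


if __name__ == "__main__":
    asyncio.run(poll_stream_stats())
```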
server/utils.py (88 changes: 86 additions & 2 deletions)
@@ -1,9 +1,12 @@
"""Utility functions for the server."""
import asyncio
import random
import types
import logging

from typing import List, Tuple
import json
from aiohttp import web
from aiortc import MediaStreamTrack
from typing import List, Tuple, Any, Dict

logger = logging.getLogger(__name__)

@@ -48,3 +51,84 @@ async def create_datagram_endpoint(

loop.create_datagram_endpoint = types.MethodType(create_datagram_endpoint, loop)
loop._patch_done = True


def add_prefix_to_app_routes(app: web.Application, prefix: str):
"""Add a prefix to all routes in the given application.

Args:
app: The web application whose routes will be prefixed.
prefix: The prefix to add to all routes.
"""
prefix = prefix.rstrip("/")
for route in list(app.router.routes()):
new_path = prefix + route.resource.canonical
app.router.add_route(route.method, new_path, route.handler)
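Note that `add_prefix_to_app_routes` registers additional prefixed routes rather than replacing the existing ones, so each handler stays reachable at both paths. A minimal sketch of that behaviour, using a stand-in `health` handler defined only for this example:

```python
from aiohttp import web


async def health(_request: web.Request) -> web.Response:
    # Stand-in handler used only to illustrate the prefixing behaviour.
    return web.Response(text="OK")


app = web.Application()
app.router.add_get("/health", health)
add_prefix_to_app_routes(app, "/live")

# Both GET /health and GET /live/health now resolve to the same handler.
for route in app.router.routes():
    print(route.method, route.resource.canonical)
```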


class StreamStats:
"""Handles real-time video stream statistics collection."""

def __init__(self, app: web.Application):
"""Initializes the StreamMetrics class.

Args:
app: The web application instance storing video streams under the
"video_tracks" key.
"""
self._app = app

async def collect_video_metrics(self, video_track: MediaStreamTrack) -> Dict[str, Any]:
"""Collects real-time statistics for a video track.

Args:
video_track: The video stream track instance.

Returns:
A dictionary containing FPS-related statistics.
"""
return {
"timestamp": await video_track.last_fps_calculation_time,
"fps": await video_track.fps,
"minute_avg_fps": await video_track.average_fps,
"minute_fps_array": await video_track.fps_measurements,
}

async def collect_all_stream_metrics(self, _) -> web.Response:
"""Retrieves real-time metrics for all active video streams.

Returns:
A JSON response containing FPS statistics for all streams.
"""
video_tracks = self._app.get("video_tracks", {})
all_stats = {
stream_id: await self.collect_video_metrics(track)
for stream_id, track in video_tracks.items()
}

return web.Response(
content_type="application/json",
text=json.dumps(all_stats),
)

async def collect_stream_metrics_by_id(self, request: web.Request) -> web.Response:
"""Retrieves real-time metrics for a specific video stream by ID.

Args:
request: The HTTP request containing the stream ID.

Returns:
A JSON response with stream metrics or an error message.
"""
stream_id = request.match_info.get("stream_id")
video_track = self._app["video_tracks"].get(stream_id)

if video_track:
stats = await self.collect_video_metrics(video_track)
else:
stats = {"error": "Stream not found"}

return web.Response(
content_type="application/json",
text=json.dumps(stats),
)