Skip to content

Commit c48ac38

Browse files
rickstaahjpotter92
andcommitted
feat(server): add Prometheus metrics endpoint
This commit adds a Prometheus metrics endpoint (`/metrics`) that allows Prometheus to scrape stream metrics, such as FPS andaverage FPS per stream. Co-authored-by: jpotter92 <git@hjpotter92.email>
1 parent c343599 commit c48ac38

9 files changed

+320
-239
lines changed

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@ce3583ad42c024b8f060d0
33
aiortc
44
aiohttp
55
toml
6-
twilio
6+
twilio
7+
prometheus_client

server/app.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
from aiortc.rtcrtpsender import RTCRtpSender
2525
from pipeline import Pipeline
2626
from twilio.rest import Client
27-
from utils import FPSMeter, StreamStats, add_prefix_to_app_routes, patch_loop_datagram
27+
from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter
28+
from metrics import MetricsManager, StreamStatsManager
29+
import time
2830

2931
logger = logging.getLogger(__name__)
3032
logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING)
@@ -56,7 +58,9 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
5658
super().__init__()
5759
self.track = track
5860
self.pipeline = pipeline
59-
self.fps_meter = FPSMeter()
61+
self.fps_meter = FPSMeter(
62+
metrics_manager=app["metrics_manager"], track_id=track.id
63+
)
6064

6165
asyncio.create_task(self.collect_frames())
6266

@@ -312,6 +316,12 @@ async def on_shutdown(app: web.Application):
312316
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
313317
help="Set the logging level",
314318
)
319+
parser.add_argument(
320+
"--monitor",
321+
default=False,
322+
action="store_true",
323+
help="Start a Prometheus metrics endpoint for monitoring.",
324+
)
315325
args = parser.parse_args()
316326

317327
logging.basicConfig(
@@ -335,11 +345,23 @@ async def on_shutdown(app: web.Application):
335345
app.router.add_post("/prompt", set_prompt)
336346

337347
# Add routes for getting stream statistics.
338-
stream_stats = StreamStats(app)
339-
app.router.add_get("/streams/stats", stream_stats.collect_all_stream_metrics)
348+
stream_stats_manager = StreamStatsManager(video_tracks=app["video_tracks"])
340349
app.router.add_get(
341-
"/stream/{stream_id}/stats", stream_stats.collect_stream_metrics_by_id
350+
"/streams/stats", stream_stats_manager.collect_all_stream_metrics
342351
)
352+
app.router.add_get(
353+
"/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id
354+
)
355+
356+
# Add Prometheus metrics endpoint.
357+
app["metrics_manager"] = MetricsManager()
358+
if args.monitor:
359+
app["metrics_manager"].enable()
360+
logger.info(
361+
f"Monitoring enabled - Prometheus metrics available at: "
362+
f"http://{args.host}:{args.port}/metrics"
363+
)
364+
app.router.add_get("/metrics", app["metrics_manager"].metrics_handler)
343365

344366
# Add hosted platform route prefix.
345367
# NOTE: This ensures that the local and hosted experiences have consistent routes.

server/metrics/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .prometheus_metrics import MetricsManager
2+
from .stream_stats import StreamStatsManager

server/metrics/prometheus_metrics.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Prometheus metrics utilities."""
2+
3+
from prometheus_client import Gauge, generate_latest
4+
from aiohttp import web
5+
6+
7+
class MetricsManager:
8+
"""Manages Prometheus metrics collection."""
9+
10+
def __init__(self):
11+
self._enabled = False
12+
self._fps_gauge = Gauge(
13+
"stream_fps", "Frames per second of the stream", ["stream_id"]
14+
)
15+
16+
def enable(self):
17+
"""Enable Prometheus metrics collection."""
18+
self._enabled = True
19+
20+
def update_metrics(self, stream_id: str, fps: float):
21+
"""Update Prometheus metrics for a given stream.
22+
23+
Args:
24+
stream_id: The ID of the stream.
25+
fps: The current frames per second.
26+
avg_fps: The average frames per second per minute.
27+
"""
28+
if self._enabled:
29+
self._fps_gauge.labels(stream_id=stream_id).set(fps)
30+
31+
async def metrics_handler(self, _):
32+
"""Handle Prometheus metrics endpoint."""
33+
return web.Response(body=generate_latest(), content_type="text/plain")

server/metrics/stream_stats.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Handles real-time video stream statistics (non-Prometheus, JSON API)."""
2+
3+
from typing import Any, Dict
4+
import json
5+
from aiohttp import web
6+
from aiortc import MediaStreamTrack
7+
8+
9+
class StreamStatsManager:
10+
"""Handles real-time video stream statistics collection."""
11+
12+
def __init__(self, video_tracks: Dict[str, MediaStreamTrack]):
13+
"""Initializes the StreamMetrics class.
14+
15+
Args:
16+
video_tracks: A dictionary that is updated with the current video streams
17+
by their IDs.
18+
"""
19+
self._video_tracks = video_tracks
20+
21+
async def collect_video_metrics(
22+
self, video_track: MediaStreamTrack
23+
) -> Dict[str, Any]:
24+
"""Collects real-time statistics for a video track.
25+
26+
Args:
27+
video_track: The video stream track instance.
28+
29+
Returns:
30+
A dictionary containing FPS-related statistics.
31+
"""
32+
return {
33+
"timestamp": await video_track.fps_meter.last_fps_calculation_time,
34+
"fps": await video_track.fps_meter.fps,
35+
"minute_avg_fps": await video_track.fps_meter.average_fps,
36+
"minute_fps_array": await video_track.fps_meter.fps_measurements,
37+
}
38+
39+
async def collect_all_stream_metrics(self, _) -> web.Response:
40+
"""Retrieves real-time metrics for all active video streams.
41+
42+
Returns:
43+
A JSON response containing FPS statistics for all streams.
44+
"""
45+
all_stats = {
46+
stream_id: await self.collect_video_metrics(track)
47+
for stream_id, track in self._video_tracks.items()
48+
}
49+
50+
return web.Response(
51+
content_type="application/json",
52+
text=json.dumps(all_stats),
53+
)
54+
55+
async def collect_stream_metrics_by_id(self, request: web.Request) -> web.Response:
56+
"""Retrieves real-time metrics for a specific video stream by ID.
57+
58+
Args:
59+
request: The HTTP request containing the stream ID.
60+
61+
Returns:
62+
A JSON response with stream metrics or an error message.
63+
"""
64+
stream_id = request.match_info.get("stream_id")
65+
video_track = self._video_tracks.get(stream_id)
66+
67+
if video_track:
68+
stats = await self.collect_video_metrics(video_track)
69+
else:
70+
stats = {"error": "Stream not found"}
71+
72+
return web.Response(
73+
content_type="application/json",
74+
text=json.dumps(stats),
75+
)

0 commit comments

Comments
 (0)