Skip to content

Commit 5a7fb2e

Browse files
rickstaahjpotter92
andcommitted
feat(server): add Prometheus metrics endpoint
This commit introduces a `/metrics` endpoint for Prometheus to scrape stream metrics, including FPS and average FPS per stream. Additionally, it adds the `--stream-id-label` argument, allowing users to optionally include the `stream-id` label in Prometheus metrics. Co-authored-by: jpotter92 <git@hjpotter92.email>
1 parent c343599 commit 5a7fb2e

9 files changed

+338
-239
lines changed

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@ce3583ad42c024b8f060d0
33
aiortc
44
aiohttp
55
toml
6-
twilio
6+
twilio
7+
prometheus_client

server/app.py

+33-5
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
from aiortc.rtcrtpsender import RTCRtpSender
2525
from pipeline import Pipeline
2626
from twilio.rest import Client
27-
from utils import FPSMeter, StreamStats, add_prefix_to_app_routes, patch_loop_datagram
27+
from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter
28+
from metrics import MetricsManager, StreamStatsManager
29+
import time
2830

2931
logger = logging.getLogger(__name__)
3032
logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING)
@@ -56,7 +58,9 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
5658
super().__init__()
5759
self.track = track
5860
self.pipeline = pipeline
59-
self.fps_meter = FPSMeter()
61+
self.fps_meter = FPSMeter(
62+
metrics_manager=app["metrics_manager"], track_id=track.id
63+
)
6064

6165
asyncio.create_task(self.collect_frames())
6266

@@ -312,6 +316,18 @@ async def on_shutdown(app: web.Application):
312316
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
313317
help="Set the logging level",
314318
)
319+
parser.add_argument(
320+
"--monitor",
321+
default=False,
322+
action="store_true",
323+
help="Start a Prometheus metrics endpoint for monitoring.",
324+
)
325+
parser.add_argument(
326+
"--stream-id-label",
327+
default=False,
328+
action="store_true",
329+
help="Include stream ID as a label in Prometheus metrics.",
330+
)
315331
args = parser.parse_args()
316332

317333
logging.basicConfig(
@@ -335,11 +351,23 @@ async def on_shutdown(app: web.Application):
335351
app.router.add_post("/prompt", set_prompt)
336352

337353
# Add routes for getting stream statistics.
338-
stream_stats = StreamStats(app)
339-
app.router.add_get("/streams/stats", stream_stats.collect_all_stream_metrics)
354+
stream_stats_manager = StreamStatsManager(app)
340355
app.router.add_get(
341-
"/stream/{stream_id}/stats", stream_stats.collect_stream_metrics_by_id
356+
"/streams/stats", stream_stats_manager.collect_all_stream_metrics
342357
)
358+
app.router.add_get(
359+
"/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id
360+
)
361+
362+
# Add Prometheus metrics endpoint.
363+
app["metrics_manager"] = MetricsManager(include_stream_id=args.stream_id_label)
364+
if args.monitor:
365+
app["metrics_manager"].enable()
366+
logger.info(
367+
f"Monitoring enabled - Prometheus metrics available at: "
368+
f"http://{args.host}:{args.port}/metrics"
369+
)
370+
app.router.add_get("/metrics", app["metrics_manager"].metrics_handler)
343371

344372
# Add hosted platform route prefix.
345373
# NOTE: This ensures that the local and hosted experiences have consistent routes.

server/metrics/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .prometheus_metrics import MetricsManager
2+
from .stream_stats import StreamStatsManager

server/metrics/prometheus_metrics.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Prometheus metrics utilities."""
2+
3+
from prometheus_client import Gauge, generate_latest
4+
from aiohttp import web
5+
from typing import Optional
6+
7+
8+
class MetricsManager:
9+
"""Manages Prometheus metrics collection."""
10+
11+
def __init__(self, include_stream_id: bool = False):
12+
"""Initializes the MetricsManager class.
13+
14+
Args:
15+
include_stream_id: Whether to include the stream ID as a label in the metrics.
16+
"""
17+
self._enabled = False
18+
self._include_stream_id = include_stream_id
19+
20+
base_labels = ["stream_id"] if include_stream_id else []
21+
self._fps_gauge = Gauge(
22+
"stream_fps", "Frames per second of the stream", base_labels
23+
)
24+
25+
def enable(self):
26+
"""Enable Prometheus metrics collection."""
27+
self._enabled = True
28+
29+
def update_fps_metrics(self, fps: float, stream_id: Optional[str] = None):
30+
"""Update Prometheus metrics for a given stream.
31+
32+
Args:
33+
fps: The current frames per second.
34+
stream_id: The ID of the stream.
35+
"""
36+
if self._enabled:
37+
if self._include_stream_id:
38+
self._fps_gauge.labels(stream_id=stream_id or "").set(fps)
39+
else:
40+
self._fps_gauge.set(fps)
41+
42+
async def metrics_handler(self, _):
43+
"""Handle Prometheus metrics endpoint."""
44+
return web.Response(body=generate_latest(), content_type="text/plain")

server/metrics/stream_stats.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Handles real-time video stream statistics (non-Prometheus, JSON API)."""
2+
3+
from typing import Any, Dict
4+
import json
5+
from aiohttp import web
6+
from aiortc import MediaStreamTrack
7+
8+
9+
class StreamStatsManager:
10+
"""Handles real-time video stream statistics collection."""
11+
12+
def __init__(self, app: web.Application):
13+
"""Initializes the StreamMetrics class.
14+
15+
Args:
16+
app: The web application instance storing stream tracks.
17+
"""
18+
self._app = app
19+
20+
async def collect_video_metrics(
21+
self, video_track: MediaStreamTrack
22+
) -> Dict[str, Any]:
23+
"""Collects real-time statistics for a video track.
24+
25+
Args:
26+
video_track: The video stream track instance.
27+
28+
Returns:
29+
A dictionary containing FPS-related statistics.
30+
"""
31+
return {
32+
"timestamp": await video_track.fps_meter.last_fps_calculation_time,
33+
"fps": await video_track.fps_meter.fps,
34+
"minute_avg_fps": await video_track.fps_meter.average_fps,
35+
"minute_fps_array": await video_track.fps_meter.fps_measurements,
36+
}
37+
38+
async def collect_all_stream_metrics(self, _) -> web.Response:
39+
"""Retrieves real-time metrics for all active video streams.
40+
41+
Returns:
42+
A JSON response containing FPS statistics for all streams.
43+
"""
44+
video_tracks = self._app.get("video_tracks", {})
45+
all_stats = {
46+
stream_id: await self.collect_video_metrics(track)
47+
for stream_id, track in video_tracks.items()
48+
}
49+
50+
return web.Response(
51+
content_type="application/json",
52+
text=json.dumps(all_stats),
53+
)
54+
55+
async def collect_stream_metrics_by_id(self, request: web.Request) -> web.Response:
56+
"""Retrieves real-time metrics for a specific video stream by ID.
57+
58+
Args:
59+
request: The HTTP request containing the stream ID.
60+
61+
Returns:
62+
A JSON response with stream metrics or an error message.
63+
"""
64+
stream_id = request.match_info.get("stream_id")
65+
video_tracks = self._app.get("video_tracks", {})
66+
video_track = video_tracks.get(stream_id)
67+
68+
if video_track:
69+
stats = await self.collect_video_metrics(video_track)
70+
else:
71+
stats = {"error": "Stream not found"}
72+
73+
return web.Response(
74+
content_type="application/json",
75+
text=json.dumps(stats),
76+
)

0 commit comments

Comments
 (0)