
Commit bce4e75

feat(server): add Prometheus metrics endpoint
This commit adds a Prometheus metrics endpoint (`/metrics`) that allows Prometheus to scrape stream metrics, such as FPS and average FPS per stream.
1 parent f1b0fb1 commit bce4e75

8 files changed: +250 -161 lines

requirements.txt (+2 -1)

@@ -3,4 +3,5 @@ comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@ce3583ad42c024b8f060d0
 aiortc
 aiohttp
 toml
-twilio
+twilio
+prometheus_client

server/app.py (+62 -26)
@@ -25,12 +25,13 @@
 from aiortc.rtcrtpsender import RTCRtpSender
 from aiortc.codecs import h264
 from pipeline import Pipeline
-from utils import patch_loop_datagram, StreamStats, add_prefix_to_app_routes
+from utils import patch_loop_datagram, add_prefix_to_app_routes
+from metrics import MetricsManager, StreamStatsManager
 import time

 logger = logging.getLogger(__name__)
-logging.getLogger('aiortc.rtcrtpsender').setLevel(logging.WARNING)
-logging.getLogger('aiortc.rtcrtpreceiver').setLevel(logging.WARNING)
+logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING)
+logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING)


 MAX_BITRATE = 2000000
@@ -45,7 +46,9 @@ class VideoStreamTrack(MediaStreamTrack):
         track (MediaStreamTrack): The underlying media stream track.
         pipeline (Pipeline): The processing pipeline to apply to each video frame.
     """
+
     kind = "video"
+
     def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
         """Initialize the VideoStreamTrack.

@@ -63,6 +66,7 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
         self._fps_loop_start_time = time.monotonic()
         self._fps = 0.0
         self._fps_measurements = deque(maxlen=60)
+        self._average_fps = 0.0
         self._running_event = asyncio.Event()

         asyncio.create_task(self.collect_frames())
@@ -88,18 +92,36 @@ async def _calculate_fps_loop(self):
             current_time = time.monotonic()
             if self._last_fps_calculation_time is not None:
                 time_diff = current_time - self._last_fps_calculation_time
-                self._fps = self._fps_interval_frame_count / time_diff
+                self._fps = (
+                    self._fps_interval_frame_count / time_diff
+                    if time_diff > 0
+                    else 0.0
+                )
                 self._fps_measurements.append(
                     {
                         "timestamp": current_time - self._fps_loop_start_time,
                         "fps": self._fps,
                     }
                 )  # Store the FPS measurement with timestamp

-            # Reset start_time and frame_count for the next interval.
+                # Store the average FPS over the last minute.
+                self._average_fps = (
+                    sum(m["fps"] for m in self._fps_measurements)
+                    / len(self._fps_measurements)
+                    if self._fps_measurements
+                    else self._fps
+                )
+
+            # Reset tracking variables for the next interval.
             self._last_fps_calculation_time = current_time
             self._fps_interval_frame_count = 0
-            await asyncio.sleep(1)  # Calculate FPS every second.
+
+            # Update Prometheus metrics if enabled.
+            app["metrics_manager"].update_metrics(
+                self.track.id, self._fps, self._average_fps
+            )
+
+            await asyncio.sleep(1)  # Calculate FPS every second

     @property
     async def fps(self) -> float:
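The average here is a plain rolling mean over the per-second samples kept in the deque(maxlen=60), so it covers at most the last minute of measurements. A stand-alone sketch of that logic, with made-up sample values:

from collections import deque

measurements = deque(maxlen=60)  # same shape as self._fps_measurements
for second, fps in enumerate([30.0, 29.5, 30.2]):
    measurements.append({"timestamp": float(second), "fps": fps})

# Mirrors the update above: mean of the stored samples, with a fallback when empty.
average_fps = (
    sum(m["fps"] for m in measurements) / len(measurements)
    if measurements
    else 0.0
)
print(round(average_fps, 2))  # 29.9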
@@ -129,11 +151,7 @@ async def average_fps(self) -> float:
             The average FPS over the last minute.
         """
         async with self._lock:
-            if not self._fps_measurements:
-                return 0.0
-            return sum(
-                measurement["fps"] for measurement in self._fps_measurements
-            ) / len(self._fps_measurements)
+            return self._average_fps

     @property
     async def last_fps_calculation_time(self) -> float:
@@ -159,6 +177,7 @@ async def recv(self):

 class AudioStreamTrack(MediaStreamTrack):
     kind = "audio"
+
     def __init__(self, track: MediaStreamTrack, pipeline):
         super().__init__()
         self.track = track
@@ -257,30 +276,29 @@ async def offer(request):
     @pc.on("datachannel")
     def on_datachannel(channel):
         if channel.label == "control":
+
             @channel.on("message")
             async def on_message(message):
                 try:
                     params = json.loads(message)

                     if params.get("type") == "get_nodes":
                         nodes_info = await pipeline.get_nodes_info()
-                        response = {
-                            "type": "nodes_info",
-                            "nodes": nodes_info
-                        }
+                        response = {"type": "nodes_info", "nodes": nodes_info}
                         channel.send(json.dumps(response))
                     elif params.get("type") == "update_prompts":
                         if "prompts" not in params:
-                            logger.warning("[Control] Missing prompt in update_prompt message")
+                            logger.warning(
+                                "[Control] Missing prompt in update_prompt message"
+                            )
                             return
                         await pipeline.update_prompts(params["prompts"])
-                        response = {
-                            "type": "prompts_updated",
-                            "success": True
-                        }
+                        response = {"type": "prompts_updated", "success": True}
                         channel.send(json.dumps(response))
                     else:
-                        logger.warning("[Server] Invalid message format - missing required fields")
+                        logger.warning(
+                            "[Server] Invalid message format - missing required fields"
+                        )
                 except json.JSONDecodeError:
                     logger.error("[Server] Invalid JSON received")
                 except Exception as e:
@@ -385,12 +403,18 @@ async def on_shutdown(app: web.Application):
         choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
         help="Set the logging level",
     )
+    parser.add_argument(
+        "--monitor",
+        default=False,
+        action="store_true",
+        help="Start a Prometheus metrics endpoint for monitoring.",
+    )
     args = parser.parse_args()

     logging.basicConfig(
         level=args.log_level.upper(),
-        format='%(asctime)s [%(levelname)s] %(message)s',
-        datefmt='%H:%M:%S'
+        format="%(asctime)s [%(levelname)s] %(message)s",
+        datefmt="%H:%M:%S",
     )

     app = web.Application()
@@ -408,11 +432,23 @@ async def on_shutdown(app: web.Application):
     app.router.add_post("/prompt", set_prompt)

     # Add routes for getting stream statistics.
-    stream_stats = StreamStats(app)
-    app.router.add_get("/streams/stats", stream_stats.collect_all_stream_metrics)
+    stream_stats_manager = StreamStatsManager(app)
     app.router.add_get(
-        "/stream/{stream_id}/stats", stream_stats.collect_stream_metrics_by_id
+        "/streams/stats", stream_stats_manager.collect_all_stream_metrics
     )
+    app.router.add_get(
+        "/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id
+    )
+
+    # Add Prometheus metrics endpoint.
+    app["metrics_manager"] = MetricsManager()
+    if args.monitor:
+        app["metrics_manager"].enable()
+        logger.info(
+            f"Monitoring enabled - Prometheus metrics available at: "
+            f"http://{args.host}:{args.port}/metrics"
+        )
+    app.router.add_get("/metrics", app["metrics_manager"].metrics_handler)

     # Add hosted platform route prefix.
     # NOTE: This ensures that the local and hosted experiences have consistent routes.
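Once the server is started with --monitor, the new gauges can be spot-checked by fetching /metrics directly (Prometheus scrapes the same endpoint over HTTP). A minimal sketch using aiohttp, which is already a dependency; the URL is a placeholder for whatever --host/--port the server was started with:

import asyncio
import aiohttp

METRICS_URL = "http://localhost:8889/metrics"  # placeholder host/port

async def main():
    async with aiohttp.ClientSession() as session:
        async with session.get(METRICS_URL) as resp:
            body = await resp.text()
    # Exposition lines look like: stream_fps{stream_id="<track id>"} 29.9
    for line in body.splitlines():
        if line.startswith(("stream_fps", "stream_avg_fps")):
            print(line)

asyncio.run(main())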

server/metrics/__init__.py (+2)

@@ -0,0 +1,2 @@
+from .prometheus_metrics import MetricsManager
+from .stream_stats import StreamStatsManager

server/metrics/prometheus_metrics.py (+42)

@@ -0,0 +1,42 @@
+"""Prometheus metrics utilities."""
+
+from prometheus_client import Gauge, generate_latest
+from aiohttp import web
+
+
+class MetricsManager:
+    """Manages Prometheus metrics collection.
+
+    Attributes:
+        fps_gauge: A Prometheus gauge for the current frames per second.
+        avg_fps_gauge: A Prometheus gauge for the average frames per second per minute.
+    """
+
+    def __init__(self):
+        self._enabled = False
+        self._fps_gauge = Gauge(
+            "stream_fps", "Frames per second of the stream", ["stream_id"]
+        )
+        self._avg_fps_gauge = Gauge(
+            "stream_avg_fps", "Average frames per second per minute", ["stream_id"]
+        )
+
+    def enable(self):
+        """Enable Prometheus metrics collection."""
+        self._enabled = True
+
+    def update_metrics(self, stream_id: str, fps: float, avg_fps: float):
+        """Update Prometheus metrics for a given stream.
+
+        Args:
+            stream_id: The ID of the stream.
+            fps: The current frames per second.
+            avg_fps: The average frames per second per minute.
+        """
+        if self._enabled:
+            self._fps_gauge.labels(stream_id=stream_id).set(fps)
+            self._avg_fps_gauge.labels(stream_id=stream_id).set(avg_fps)
+
+    async def metrics_handler(self, _):
+        """Handle Prometheus metrics endpoint."""
+        return web.Response(body=generate_latest(), content_type="text/plain")
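Because the gauges register with prometheus_client's default registry, the class can be exercised on its own; generate_latest() returns the same exposition text the /metrics handler serves. A minimal sketch, with a made-up stream id and values (the import path assumes server/ is on sys.path):

from prometheus_client import generate_latest

from metrics import MetricsManager  # assumes server/ is on sys.path

manager = MetricsManager()
manager.enable()  # without enable(), update_metrics() is a no-op
manager.update_metrics(stream_id="demo-stream", fps=29.9, avg_fps=30.1)

# The output contains, among other series:
#   stream_fps{stream_id="demo-stream"} 29.9
#   stream_avg_fps{stream_id="demo-stream"} 30.1
print(generate_latest().decode())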

server/metrics/stream_stats.py (+76)

@@ -0,0 +1,76 @@
+"""Handles real-time video stream statistics (non-Prometheus, JSON API)."""
+
+from typing import Any, Dict
+import json
+from aiohttp import web
+from aiortc import MediaStreamTrack
+
+
+class StreamStatsManager:
+    """Manages real-time video stream statistics retrieval."""
+
+    def __init__(self, app: web.Application):
+        """Initializes the StreamMetrics class.
+
+        Args:
+            app: The web application instance storing video streams under the
+                "video_tracks" key.
+        """
+        self._app = app
+
+    async def collect_video_metrics(
+        self, video_track: MediaStreamTrack
+    ) -> Dict[str, Any]:
+        """Collects real-time statistics for a video track.
+
+        Args:
+            video_track: The video stream track instance.
+
+        Returns:
+            A dictionary containing FPS-related statistics.
+        """
+        return {
+            "timestamp": await video_track.last_fps_calculation_time,
+            "fps": await video_track.fps,
+            "minute_avg_fps": await video_track.average_fps,
+            "minute_fps_array": await video_track.fps_measurements,
+        }
+
+    async def collect_all_stream_metrics(self, _) -> web.Response:
+        """Retrieves real-time metrics for all active video streams.
+
+        Returns:
+            A JSON response containing FPS statistics for all streams.
+        """
+        video_tracks = self._app.get("video_tracks", {})
+        all_stats = {
+            stream_id: await self.collect_video_metrics(track)
+            for stream_id, track in video_tracks.items()
+        }
+
+        return web.Response(
+            content_type="application/json",
+            text=json.dumps(all_stats),
+        )
+
+    async def collect_stream_metrics_by_id(self, request: web.Request) -> web.Response:
+        """Retrieves real-time metrics for a specific video stream by ID.
+
+        Args:
+            request: The HTTP request containing the stream ID.
+
+        Returns:
+            A JSON response with stream metrics or an error message.
+        """
+        stream_id = request.match_info.get("stream_id")
+        video_track = self._app["video_tracks"].get(stream_id)
+
+        if video_track:
+            stats = await self.collect_video_metrics(video_track)
+        else:
+            stats = {"error": "Stream not found"}
+
+        return web.Response(
+            content_type="application/json",
+            text=json.dumps(stats),
+        )
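These handlers back the /streams/stats and /stream/{stream_id}/stats routes registered in app.py above, and return plain JSON rather than Prometheus exposition text. A hypothetical client sketch; the host and port are placeholders:

import asyncio
import aiohttp

BASE_URL = "http://localhost:8889"  # placeholder; use the server's --host/--port

async def main():
    async with aiohttp.ClientSession() as session:
        # All active streams, keyed by stream id.
        async with session.get(f"{BASE_URL}/streams/stats") as resp:
            all_stats = await resp.json()
        print(all_stats)

        # Per-stream stats carry: timestamp, fps, minute_avg_fps, minute_fps_array.
        for stream_id in all_stats:
            async with session.get(f"{BASE_URL}/stream/{stream_id}/stats") as resp:
                print(await resp.json())

asyncio.run(main())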
