Skip to content

Commit a38cb3e

Browse files
authored
refactor(server): improve FPS Stats collection logic (#141)
This commit extracts the FPS statistics collection into its own class to keep the `VideoStreamTrack` implementation cleaner and more maintainable. This also makes the logic reusable across different components.
1 parent 8622a31 commit a38cb3e

File tree

2 files changed

+137
-111
lines changed

2 files changed

+137
-111
lines changed

server/app.py

+33-106
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import asyncio
21
import argparse
3-
import os
2+
import asyncio
43
import json
54
import logging
6-
from collections import deque
5+
import os
76
import sys
87

98
import torch
@@ -13,24 +12,23 @@
1312
torch.cuda.init()
1413

1514

16-
from twilio.rest import Client
1715
from aiohttp import web
1816
from aiortc import (
19-
RTCPeerConnection,
20-
RTCSessionDescription,
17+
MediaStreamTrack,
2118
RTCConfiguration,
2219
RTCIceServer,
23-
MediaStreamTrack,
20+
RTCPeerConnection,
21+
RTCSessionDescription,
2422
)
25-
from aiortc.rtcrtpsender import RTCRtpSender
2623
from aiortc.codecs import h264
24+
from aiortc.rtcrtpsender import RTCRtpSender
2725
from pipeline import Pipeline
28-
from utils import patch_loop_datagram, StreamStats, add_prefix_to_app_routes
29-
import time
26+
from twilio.rest import Client
27+
from utils import FPSMeter, StreamStats, add_prefix_to_app_routes, patch_loop_datagram
3028

3129
logger = logging.getLogger(__name__)
32-
logging.getLogger('aiortc.rtcrtpsender').setLevel(logging.WARNING)
33-
logging.getLogger('aiortc.rtcrtpreceiver').setLevel(logging.WARNING)
30+
logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING)
31+
logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING)
3432

3533

3634
MAX_BITRATE = 2000000
@@ -45,7 +43,9 @@ class VideoStreamTrack(MediaStreamTrack):
4543
track (MediaStreamTrack): The underlying media stream track.
4644
pipeline (Pipeline): The processing pipeline to apply to each video frame.
4745
"""
46+
4847
kind = "video"
48+
4949
def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
5050
"""Initialize the VideoStreamTrack.
5151
@@ -56,21 +56,14 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
5656
super().__init__()
5757
self.track = track
5858
self.pipeline = pipeline
59-
60-
self._lock = asyncio.Lock()
61-
self._fps_interval_frame_count = 0
62-
self._last_fps_calculation_time = None
63-
self._fps_loop_start_time = time.monotonic()
64-
self._fps = 0.0
65-
self._fps_measurements = deque(maxlen=60)
66-
self._running_event = asyncio.Event()
59+
self.fps_meter = FPSMeter()
6760

6861
asyncio.create_task(self.collect_frames())
6962

70-
# Start metrics collection tasks.
71-
self._fps_stats_task = asyncio.create_task(self._calculate_fps_loop())
72-
7363
async def collect_frames(self):
64+
"""Continuously collect video frames from the underlying track and pass them to
65+
the processing pipeline.
66+
"""
7467
while True:
7568
try:
7669
frame = await self.track.recv()
@@ -79,86 +72,21 @@ async def collect_frames(self):
7972
await self.pipeline.cleanup()
8073
raise Exception(f"Error collecting video frames: {str(e)}")
8174

82-
async def _calculate_fps_loop(self):
83-
"""Loop to calculate FPS periodically."""
84-
await self._running_event.wait()
85-
self._fps_loop_start_time = time.monotonic()
86-
while self.readyState != "ended":
87-
async with self._lock:
88-
current_time = time.monotonic()
89-
if self._last_fps_calculation_time is not None:
90-
time_diff = current_time - self._last_fps_calculation_time
91-
self._fps = self._fps_interval_frame_count / time_diff
92-
self._fps_measurements.append(
93-
{
94-
"timestamp": current_time - self._fps_loop_start_time,
95-
"fps": self._fps,
96-
}
97-
) # Store the FPS measurement with timestamp
98-
99-
# Reset start_time and frame_count for the next interval.
100-
self._last_fps_calculation_time = current_time
101-
self._fps_interval_frame_count = 0
102-
await asyncio.sleep(1) # Calculate FPS every second.
103-
104-
@property
105-
async def fps(self) -> float:
106-
"""Get the current output frames per second (FPS).
107-
108-
Returns:
109-
The current output FPS.
110-
"""
111-
async with self._lock:
112-
return self._fps
113-
114-
@property
115-
async def fps_measurements(self) -> list:
116-
"""Get the array of FPS measurements for the last minute.
117-
118-
Returns:
119-
The array of FPS measurements for the last minute.
120-
"""
121-
async with self._lock:
122-
return list(self._fps_measurements)
123-
124-
@property
125-
async def average_fps(self) -> float:
126-
"""Calculate the average FPS from the measurements taken in the last minute.
127-
128-
Returns:
129-
The average FPS over the last minute.
130-
"""
131-
async with self._lock:
132-
if not self._fps_measurements:
133-
return 0.0
134-
return sum(
135-
measurement["fps"] for measurement in self._fps_measurements
136-
) / len(self._fps_measurements)
137-
138-
@property
139-
async def last_fps_calculation_time(self) -> float:
140-
"""Get the elapsed time since the last FPS calculation.
141-
142-
Returns:
143-
The elapsed time in seconds since the last FPS calculation.
144-
"""
145-
async with self._lock:
146-
return self._last_fps_calculation_time - self._fps_loop_start_time
147-
14875
async def recv(self):
76+
"""Receive a processed video frame from the pipeline, increment the frame
77+
count for FPS calculation and return the processed frame to the client.
78+
"""
14979
processed_frame = await self.pipeline.get_processed_video_frame()
15080

151-
# Increment frame count for FPS calculation.
152-
async with self._lock:
153-
self._fps_interval_frame_count += 1
154-
if not self._running_event.is_set():
155-
self._running_event.set()
81+
# Increment the frame count to calculate FPS.
82+
await self.fps_meter.increment_frame_count()
15683

15784
return processed_frame
15885

15986

16087
class AudioStreamTrack(MediaStreamTrack):
16188
kind = "audio"
89+
16290
def __init__(self, track: MediaStreamTrack, pipeline):
16391
super().__init__()
16492
self.track = track
@@ -257,30 +185,29 @@ async def offer(request):
257185
@pc.on("datachannel")
258186
def on_datachannel(channel):
259187
if channel.label == "control":
188+
260189
@channel.on("message")
261190
async def on_message(message):
262191
try:
263192
params = json.loads(message)
264193

265194
if params.get("type") == "get_nodes":
266195
nodes_info = await pipeline.get_nodes_info()
267-
response = {
268-
"type": "nodes_info",
269-
"nodes": nodes_info
270-
}
196+
response = {"type": "nodes_info", "nodes": nodes_info}
271197
channel.send(json.dumps(response))
272198
elif params.get("type") == "update_prompts":
273199
if "prompts" not in params:
274-
logger.warning("[Control] Missing prompt in update_prompt message")
200+
logger.warning(
201+
"[Control] Missing prompt in update_prompt message"
202+
)
275203
return
276204
await pipeline.update_prompts(params["prompts"])
277-
response = {
278-
"type": "prompts_updated",
279-
"success": True
280-
}
205+
response = {"type": "prompts_updated", "success": True}
281206
channel.send(json.dumps(response))
282207
else:
283-
logger.warning("[Server] Invalid message format - missing required fields")
208+
logger.warning(
209+
"[Server] Invalid message format - missing required fields"
210+
)
284211
except json.JSONDecodeError:
285212
logger.error("[Server] Invalid JSON received")
286213
except Exception as e:
@@ -389,8 +316,8 @@ async def on_shutdown(app: web.Application):
389316

390317
logging.basicConfig(
391318
level=args.log_level.upper(),
392-
format='%(asctime)s [%(levelname)s] %(message)s',
393-
datefmt='%H:%M:%S'
319+
format="%(asctime)s [%(levelname)s] %(message)s",
320+
datefmt="%H:%M:%S",
394321
)
395322

396323
app = web.Application()

server/utils.py

+104-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Utility functions for the server."""
2+
23
import asyncio
34
import random
45
import types
@@ -7,6 +8,8 @@
78
from aiohttp import web
89
from aiortc import MediaStreamTrack
910
from typing import List, Tuple, Any, Dict
11+
import time
12+
from collections import deque
1013

1114
logger = logging.getLogger(__name__)
1215

@@ -78,7 +81,9 @@ def __init__(self, app: web.Application):
7881
"""
7982
self._app = app
8083

81-
async def collect_video_metrics(self, video_track: MediaStreamTrack) -> Dict[str, Any]:
84+
async def collect_video_metrics(
85+
self, video_track: MediaStreamTrack
86+
) -> Dict[str, Any]:
8287
"""Collects real-time statistics for a video track.
8388
8489
Args:
@@ -88,10 +93,10 @@ async def collect_video_metrics(self, video_track: MediaStreamTrack) -> Dict[str
8893
A dictionary containing FPS-related statistics.
8994
"""
9095
return {
91-
"timestamp": await video_track.last_fps_calculation_time,
92-
"fps": await video_track.fps,
93-
"minute_avg_fps": await video_track.average_fps,
94-
"minute_fps_array": await video_track.fps_measurements,
96+
"timestamp": await video_track.fps_meter.last_fps_calculation_time,
97+
"fps": await video_track.fps_meter.fps,
98+
"minute_avg_fps": await video_track.fps_meter.average_fps,
99+
"minute_fps_array": await video_track.fps_meter.fps_measurements,
95100
}
96101

97102
async def collect_all_stream_metrics(self, _) -> web.Response:
@@ -132,3 +137,97 @@ async def collect_stream_metrics_by_id(self, request: web.Request) -> web.Respon
132137
content_type="application/json",
133138
text=json.dumps(stats),
134139
)
140+
141+
142+
class FPSMeter:
143+
"""Class to calculate and store the framerate of a stream by counting frames."""
144+
145+
def __init__(self):
146+
"""Initializes the FPSMeter class."""
147+
self._lock = asyncio.Lock()
148+
self._fps_interval_frame_count = 0
149+
self._last_fps_calculation_time = None
150+
self._fps_loop_start_time = None
151+
self._fps = 0.0
152+
self._fps_measurements = deque(maxlen=60)
153+
self._running_event = asyncio.Event()
154+
155+
asyncio.create_task(self._calculate_fps_loop())
156+
157+
async def _calculate_fps_loop(self):
158+
"""Loop to calculate FPS periodically."""
159+
await self._running_event.wait()
160+
self._fps_loop_start_time = time.monotonic()
161+
while True:
162+
async with self._lock:
163+
current_time = time.monotonic()
164+
if self._last_fps_calculation_time is not None:
165+
time_diff = current_time - self._last_fps_calculation_time
166+
self._fps = self._fps_interval_frame_count / time_diff
167+
self._fps_measurements.append(
168+
{
169+
"timestamp": current_time - self._fps_loop_start_time,
170+
"fps": self._fps,
171+
}
172+
) # Store the FPS measurement with timestamp
173+
174+
# Reset start_time and frame_count for the next interval.
175+
self._last_fps_calculation_time = current_time
176+
self._fps_interval_frame_count = 0
177+
await asyncio.sleep(1) # Calculate FPS every second.
178+
179+
async def increment_frame_count(self):
180+
"""Increment the frame count to calculate FPS."""
181+
async with self._lock:
182+
self._fps_interval_frame_count += 1
183+
if not self._running_event.is_set():
184+
self._running_event.set()
185+
186+
@property
187+
async def fps(self) -> float:
188+
"""Get the current output frames per second (FPS).
189+
190+
Returns:
191+
The current output FPS.
192+
"""
193+
async with self._lock:
194+
return self._fps
195+
196+
@property
197+
async def fps_measurements(self) -> list:
198+
"""Get the array of FPS measurements for the last minute.
199+
200+
Returns:
201+
The array of FPS measurements for the last minute.
202+
"""
203+
async with self._lock:
204+
return list(self._fps_measurements)
205+
206+
@property
207+
async def average_fps(self) -> float:
208+
"""Calculate the average FPS from the measurements taken in the last minute.
209+
210+
Returns:
211+
The average FPS over the last minute.
212+
"""
213+
async with self._lock:
214+
if not self._fps_measurements:
215+
return 0.0
216+
return sum(
217+
measurement["fps"] for measurement in self._fps_measurements
218+
) / len(self._fps_measurements)
219+
220+
@property
221+
async def last_fps_calculation_time(self) -> float:
222+
"""Get the elapsed time since the last FPS calculation.
223+
224+
Returns:
225+
The elapsed time in seconds since the last FPS calculation.
226+
"""
227+
async with self._lock:
228+
if (
229+
self._last_fps_calculation_time is None
230+
or self._fps_loop_start_time is None
231+
):
232+
return 0.0
233+
return self._last_fps_calculation_time - self._fps_loop_start_time

0 commit comments

Comments
 (0)