
Commit 171ec81

Merge branch 'main' into add_comfyui_log_level_arg
2 parents: dc48a86 + 7513393

16 files changed: +381 -295 lines

.devcontainer/post-create.sh (+7 -1)

@@ -7,10 +7,16 @@ cd /workspace/comfystream
 echo -e "\e[32mInstalling Comfystream in editable mode...\e[0m"
 /workspace/miniconda3/envs/comfystream/bin/python3 -m pip install -e . --root-user-action=ignore > /dev/null
 
+# Install npm packages if needed
+if [ ! -d "/workspace/comfystream/ui/node_modules" ]; then
+    echo -e "\e[32mInstalling npm packages for Comfystream UI...\e[0m"
+    cd /workspace/comfystream/ui
+    npm install
+fi
+
 if [ ! -d "/workspace/comfystream/nodes/web/static" ]; then
     echo -e "\e[32mBuilding web assets...\e[0m"
     cd /workspace/comfystream/ui
-    npm install
     npm run build
 fi
 

.github/workflows/comfyui-base.yaml (+1 -1)

@@ -40,7 +40,7 @@ jobs:
           ref: ${{ github.event.pull_request.head.sha }}
 
       - name: Login to DockerHub
-        uses: docker/login-action@v2
+        uses: docker/login-action@v3
         with:
           username: ${{ secrets.CI_DOCKERHUB_USERNAME }}
           password: ${{ secrets.CI_DOCKERHUB_TOKEN }}

.github/workflows/docker.yaml (+1 -1)

@@ -30,7 +30,7 @@ jobs:
           ref: ${{ github.event.pull_request.head.sha }}
 
       - name: Login to DockerHub
-        uses: docker/login-action@v2
+        uses: docker/login-action@v3
         with:
          username: ${{ secrets.CI_DOCKERHUB_USERNAME }}
          password: ${{ secrets.CI_DOCKERHUB_TOKEN }}

docker/entrypoint.sh (-5)

@@ -62,11 +62,6 @@ if [ "$1" = "--build-engines" ]; then
     shift
 fi
 
-# Install npm packages if needed
-cd /workspace/comfystream/ui
-if [ ! -d "node_modules" ]; then
-    npm install --legacy-peer-deps
-fi
 
 if [ "$1" = "--server" ]; then
     /usr/bin/supervisord -c /etc/supervisor/supervisord.conf

pyproject.toml (+2 -1)

@@ -10,10 +10,11 @@ license = { file = "LICENSE" }
 dependencies = [
     "asyncio",
     "comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@ce3583ad42c024b8f060d0002cbe20c265da6dc8",
-    "toml",
     "aiortc",
     "aiohttp",
+    "toml",
     "twilio",
+    "prometheus_client",
 ]
 
 [project.optional-dependencies]

requirements.txt (+2 -1)

@@ -3,4 +3,5 @@ comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@ce3583ad42c024b8f060d0
 aiortc
 aiohttp
 toml
-twilio
+twilio
+prometheus_client

server/app.py (+33 -5)

@@ -24,7 +24,9 @@
 from aiortc.rtcrtpsender import RTCRtpSender
 from pipeline import Pipeline
 from twilio.rest import Client
-from utils import FPSMeter, StreamStats, add_prefix_to_app_routes, patch_loop_datagram
+from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter
+from metrics import MetricsManager, StreamStatsManager
+import time
 
 logger = logging.getLogger(__name__)
 logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING)
@@ -56,7 +58,9 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
         super().__init__()
         self.track = track
         self.pipeline = pipeline
-        self.fps_meter = FPSMeter()
+        self.fps_meter = FPSMeter(
+            metrics_manager=app["metrics_manager"], track_id=track.id
+        )
 
         asyncio.create_task(self.collect_frames())
 
@@ -327,6 +331,18 @@ async def on_shutdown(app: web.Application):
         choices=logging._nameToLevel.keys(),
         help="Set the logging level for ComfyUI inference",
     )
+    parser.add_argument(
+        "--monitor",
+        default=False,
+        action="store_true",
+        help="Start a Prometheus metrics endpoint for monitoring.",
+    )
+    parser.add_argument(
+        "--stream-id-label",
+        default=False,
+        action="store_true",
+        help="Include stream ID as a label in Prometheus metrics.",
+    )
     args = parser.parse_args()
 
     logging.basicConfig(
@@ -350,11 +366,23 @@ async def on_shutdown(app: web.Application):
     app.router.add_post("/prompt", set_prompt)
 
     # Add routes for getting stream statistics.
-    stream_stats = StreamStats(app)
-    app.router.add_get("/streams/stats", stream_stats.collect_all_stream_metrics)
+    stream_stats_manager = StreamStatsManager(app)
     app.router.add_get(
-        "/stream/{stream_id}/stats", stream_stats.collect_stream_metrics_by_id
+        "/streams/stats", stream_stats_manager.collect_all_stream_metrics
     )
+    app.router.add_get(
+        "/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id
+    )
+
+    # Add Prometheus metrics endpoint.
+    app["metrics_manager"] = MetricsManager(include_stream_id=args.stream_id_label)
+    if args.monitor:
+        app["metrics_manager"].enable()
+        logger.info(
+            f"Monitoring enabled - Prometheus metrics available at: "
+            f"http://{args.host}:{args.port}/metrics"
+        )
+    app.router.add_get("/metrics", app["metrics_manager"].metrics_handler)
 
     # Add hosted platform route prefix.
    # NOTE: This ensures that the local and hosted experiences have consistent routes.
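
Note: with these changes, monitoring is opt-in via the new flags. A minimal sketch of scraping the new /metrics route once the server is running; the base URL below is a placeholder for whatever --host/--port the server was actually started with:

# Hypothetical invocation (flags as added above):
#   python server/app.py --monitor --stream-id-label
import urllib.request

BASE = "http://127.0.0.1:8889"  # placeholder; use the server's real --host/--port

# Fetch the Prometheus text-format scrape output served by metrics_handler.
with urllib.request.urlopen(f"{BASE}/metrics") as resp:
    print(resp.read().decode())  # e.g. lines like: stream_fps 30.0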

server/metrics/__init__.py (+2)

@@ -0,0 +1,2 @@
+from .prometheus_metrics import MetricsManager
+from .stream_stats import StreamStatsManager

server/metrics/prometheus_metrics.py (+44)

@@ -0,0 +1,44 @@
+"""Prometheus metrics utilities."""
+
+from prometheus_client import Gauge, generate_latest
+from aiohttp import web
+from typing import Optional
+
+
+class MetricsManager:
+    """Manages Prometheus metrics collection."""
+
+    def __init__(self, include_stream_id: bool = False):
+        """Initializes the MetricsManager class.
+
+        Args:
+            include_stream_id: Whether to include the stream ID as a label in the metrics.
+        """
+        self._enabled = False
+        self._include_stream_id = include_stream_id
+
+        base_labels = ["stream_id"] if include_stream_id else []
+        self._fps_gauge = Gauge(
+            "stream_fps", "Frames per second of the stream", base_labels
+        )
+
+    def enable(self):
+        """Enable Prometheus metrics collection."""
+        self._enabled = True
+
+    def update_fps_metrics(self, fps: float, stream_id: Optional[str] = None):
+        """Update Prometheus metrics for a given stream.
+
+        Args:
+            fps: The current frames per second.
+            stream_id: The ID of the stream.
+        """
+        if self._enabled:
+            if self._include_stream_id:
+                self._fps_gauge.labels(stream_id=stream_id or "").set(fps)
+            else:
+                self._fps_gauge.set(fps)
+
+    async def metrics_handler(self, _):
+        """Handle Prometheus metrics endpoint."""
+        return web.Response(body=generate_latest(), content_type="text/plain")
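
For reference, a minimal usage sketch of the MetricsManager added above; the fps value and stream ID are illustrative only:

from metrics import MetricsManager

manager = MetricsManager(include_stream_id=True)
manager.enable()  # update_fps_metrics() is a no-op until enable() is called
manager.update_fps_metrics(fps=29.7, stream_id="stream-1")  # illustrative values
# server/app.py exposes the scrape output via:
#   app.router.add_get("/metrics", manager.metrics_handler)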

server/metrics/stream_stats.py (+76)

@@ -0,0 +1,76 @@
+"""Handles real-time video stream statistics (non-Prometheus, JSON API)."""
+
+from typing import Any, Dict
+import json
+from aiohttp import web
+from aiortc import MediaStreamTrack
+
+
+class StreamStatsManager:
+    """Handles real-time video stream statistics collection."""
+
+    def __init__(self, app: web.Application):
+        """Initializes the StreamStatsManager class.
+
+        Args:
+            app: The web application instance storing stream tracks.
+        """
+        self._app = app
+
+    async def collect_video_metrics(
+        self, video_track: MediaStreamTrack
+    ) -> Dict[str, Any]:
+        """Collects real-time statistics for a video track.
+
+        Args:
+            video_track: The video stream track instance.
+
+        Returns:
+            A dictionary containing FPS-related statistics.
+        """
+        return {
+            "timestamp": await video_track.fps_meter.last_fps_calculation_time,
+            "fps": await video_track.fps_meter.fps,
+            "minute_avg_fps": await video_track.fps_meter.average_fps,
+            "minute_fps_array": await video_track.fps_meter.fps_measurements,
+        }
+
+    async def collect_all_stream_metrics(self, _) -> web.Response:
+        """Retrieves real-time metrics for all active video streams.
+
+        Returns:
+            A JSON response containing FPS statistics for all streams.
+        """
+        video_tracks = self._app.get("video_tracks", {})
+        all_stats = {
+            stream_id: await self.collect_video_metrics(track)
+            for stream_id, track in video_tracks.items()
+        }
+
+        return web.Response(
+            content_type="application/json",
+            text=json.dumps(all_stats),
+        )
+
+    async def collect_stream_metrics_by_id(self, request: web.Request) -> web.Response:
+        """Retrieves real-time metrics for a specific video stream by ID.
+
+        Args:
+            request: The HTTP request containing the stream ID.
+
+        Returns:
+            A JSON response with stream metrics or an error message.
+        """
+        stream_id = request.match_info.get("stream_id")
+        video_tracks = self._app.get("video_tracks", {})
+        video_track = video_tracks.get(stream_id)
+
+        if video_track:
+            stats = await self.collect_video_metrics(video_track)
+        else:
+            stats = {"error": "Stream not found"}
+
+        return web.Response(
+            content_type="application/json",
+            text=json.dumps(stats),
+        )
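
These two handlers back the JSON routes registered in server/app.py. A minimal sketch of querying them; the base URL is a placeholder for the server's actual address:

import json
import urllib.request

BASE = "http://127.0.0.1:8889"  # placeholder; use the server's real --host/--port

# All active streams, keyed by stream ID.
with urllib.request.urlopen(f"{BASE}/streams/stats") as resp:
    all_stats = json.load(resp)

# A single stream by ID; unknown IDs return {"error": "Stream not found"}.
for stream_id in all_stats:
    with urllib.request.urlopen(f"{BASE}/stream/{stream_id}/stats") as resp:
        print(stream_id, json.load(resp))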

server/pipeline.py (+10 -25)

@@ -40,9 +40,7 @@ async def warm_video(self):
 
     async def warm_audio(self):
         dummy_frame = av.AudioFrame()
-        dummy_frame.side_data.input = np.random.randint(
-            -32768, 32767, int(48000 * 0.5), dtype=np.int16
-        ) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed?
+        dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed?
         dummy_frame.sample_rate = 48000
 
         for _ in range(WARMUP_RUNS):
@@ -55,9 +53,7 @@ async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]
         else:
             await self.client.set_prompts([prompts])
 
-    async def update_prompts(
-        self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]
-    ):
+    async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
         if isinstance(prompts, list):
             await self.client.update_prompts(prompts)
         else:
@@ -82,21 +78,12 @@ def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarr
     def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]:
         return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16)
 
-    def video_postprocess(
-        self, output: Union[torch.Tensor, np.ndarray]
-    ) -> av.VideoFrame:
+    def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame:
         return av.VideoFrame.from_ndarray(
-            (output * 255.0)
-            .clamp(0, 255)
-            .to(dtype=torch.uint8)
-            .squeeze(0)
-            .cpu()
-            .numpy()
+            (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy()
         )
 
-    def audio_postprocess(
-        self, output: Union[torch.Tensor, np.ndarray]
-    ) -> av.AudioFrame:
+    def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame:
         return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1))
 
     async def get_processed_video_frame(self):
@@ -107,7 +94,7 @@ async def get_processed_video_frame(self):
         while frame.side_data.skipped:
             frame = await self.video_incoming_frames.get()
 
-        processed_frame = self.video_postprocess(out_tensor)
+        processed_frame = self.video_postprocess(out_tensor)
         processed_frame.pts = frame.pts
         processed_frame.time_base = frame.time_base
 
@@ -119,17 +106,15 @@ async def get_processed_audio_frame(self):
         if frame.samples > len(self.processed_audio_buffer):
             async with temporary_log_level("comfy", self._comfyui_inference_log_level):
                 out_tensor = await self.client.get_audio_output()
-            self.processed_audio_buffer = np.concatenate(
-                [self.processed_audio_buffer, out_tensor]
-            )
-        out_data = self.processed_audio_buffer[: frame.samples]
-        self.processed_audio_buffer = self.processed_audio_buffer[frame.samples :]
+            self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor])
+        out_data = self.processed_audio_buffer[:frame.samples]
+        self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:]
 
         processed_frame = self.audio_postprocess(out_data)
         processed_frame.pts = frame.pts
         processed_frame.time_base = frame.time_base
         processed_frame.sample_rate = frame.sample_rate
-
+
         return processed_frame
 
     async def get_nodes_info(self) -> Dict[str, Any]:
