
Commit f182502

fix: merge conflicts
2 parents: 8613ea5 + 8a6b528

20 files changed: +926 -288 lines

README.md (+28-7)

````diff
@@ -8,6 +8,7 @@ This repo also includes a WebRTC server and UI that uses comfystream to support
 - [Install package](#install-package)
 - [Custom Nodes](#custom-nodes)
 - [Usage](#usage)
+- [Run tests](#run-tests)
 - [Run server](#run-server)
 - [Run UI](#run-ui)
 - [Limitations](#limitations)
@@ -48,24 +49,38 @@ pip install git+https://github.com/yondonfu/comfystream.git
 
 ## Custom Nodes
 
-**tensor_utils**
+comfystream uses a few custom nodes to support running workflows.
 
-Copy the `tensor_utils` nodes into the `custom_nodes` folder of your ComfyUI workspace:
+Copy the custom nodes into the `custom_nodes` folder of your ComfyUI workspace:
 
 ```
-cp -r nodes/tensor_utils custom_nodes
+cp -r nodes/* custom_nodes/
 ```
 
-For example, if you ComfyUI workspace is under `/home/user/ComfyUI`:
+For example, if your ComfyUI workspace is under `/home/user/ComfyUI`:
 
 ```
-cp -r nodes/tensor_utils /home/user/ComfyUI/custom_nodes
+cp -r nodes/* /home/user/ComfyUI/custom_nodes
 ```
 
 ## Usage
 
 See `example.py`.
 
+# Run tests
+
+Install dev dependencies:
+
+```
+pip install .[dev]
+```
+
+Run tests:
+
+```
+pytest
+```
+
 # Run server
 
 Install dependencies:
@@ -144,9 +159,15 @@ The Stream URL is the URL of the [server](#run-server) which defaults to http://
 
 At the moment, a workflow must fulfill the following requirements:
 
-- Single input using the LoadImage node
+- The workflow must have a single primary input node that will receive individual video frames
+- The primary input node is designated by one of the following:
+  - A single [PrimaryInputLoadImage](./nodes/video_stream_utils/primary_input_load_image.py) node (see [this workflow](./workflows/liveportait.json) for example usage)
+    - This node can be used as a drop-in replacement for a LoadImage node
+    - In this scenario, any number of additional LoadImage nodes can be used
+  - A single LoadImage node
+    - In this scenario, the workflow can only contain the single LoadImage node
   - At runtime, this node is replaced with a LoadTensor node
-- Single output using a PreviewImage or SaveImage node
+- The workflow must have a single output using a PreviewImage or SaveImage node
   - At runtime, this node is replaced with a SaveTensor node
 
 # Troubleshoot
````
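For reference, the simplest prompt satisfying these rules pairs one LoadImage input with one PreviewImage output. A minimal sketch in ComfyUI prompt-dict form; the node ids, link, and image name are illustrative only:

```python
# Minimal workflow meeting the single-input/single-output rules above.
# At runtime convert_prompt swaps LoadImage -> LoadTensor and
# PreviewImage -> SaveTensor (see src/comfystream/utils.py below).
workflow = {
    "1": {
        "inputs": {"image": "example.png"},
        "class_type": "LoadImage",
        "_meta": {"title": "Load Image"},
    },
    "2": {
        "inputs": {"images": ["1", 0]},
        "class_type": "PreviewImage",
        "_meta": {"title": "Preview Image"},
    },
}
```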

audio_example.py (+1-1)

```diff
@@ -13,7 +13,7 @@ async def main():
 
     client.set_prompt(prompt)
 
-    waveform, _ = torchaudio.load("harvard.wav")
+    waveform, _ = torchaudio.load("/home/user/harvard.wav")
     if waveform.ndim > 1:
         audio_tensor = waveform.mean(dim=0)
 
```

nodes/audio_utils/apply_whisper.py (+8-7)

```diff
@@ -11,6 +11,7 @@ def INPUT_TYPES(s):
             }
         }
 
+    CATEGORY = "audio_utils"
     RETURN_TYPES = ("DICT",)
     FUNCTION = "apply_whisper"
 
@@ -33,23 +34,23 @@ def apply_whisper(self, audio, model):
         concatenated_audio = torch.cat(self.audio_buffer, dim=0).cuda()
         self.audio_buffer = []
         result = self.model.transcribe(concatenated_audio.float(), fp16=True, word_timestamps=True)
-        segments = result['segments']
+        segments = result["segments"]
         segments_alignment = []
         words_alignment = []
 
         for segment in segments:
             segment_dict = {
-                'value': segment['text'].strip(),
-                'start': segment['start'],
-                'end': segment['end']
+                "value": segment["text"].strip(),
+                "start": segment["start"],
+                "end": segment["end"]
             }
             segments_alignment.append(segment_dict)
 
             for word in segment["words"]:
                 word_dict = {
-                    'value': word["word"].strip(),
-                    'start': word["start"],
-                    'end': word['end']
+                    "value": word["word"].strip(),
+                    "start": word["start"],
+                    "end": word["end"]
                 }
                 words_alignment.append(word_dict)
 
```
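The loop above flattens Whisper's transcription result into two parallel alignment lists. A sketch of the structures it produces; the text and timestamps are invented for illustration:

```python
# Shape of the alignments built by apply_whisper (timestamps in seconds;
# values invented, not produced by this commit).
segments_alignment = [
    {"value": "The stale smell of old beer lingers.", "start": 0.0, "end": 2.4},
]
words_alignment = [
    {"value": "The", "start": 0.0, "end": 0.18},
    {"value": "stale", "start": 0.18, "end": 0.55},
    # ... one entry per word, with word-level timestamps
]
```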

nodes/audio_utils/load_audio_tensor.py (+1-1)

```diff
@@ -1,7 +1,7 @@
 from comfystream import tensor_cache
 
 class LoadAudioTensor:
-    CATEGORY = "tensor_utils"
+    CATEGORY = "audio_utils"
     RETURN_TYPES = ("AUDIO",)
     FUNCTION = "execute"
 
```

nodes/audio_utils/save_asr_response.py (+1-2)

```diff
@@ -1,10 +1,9 @@
 from comfystream import tensor_cache
 
 class SaveASRResponse:
-    CATEGORY = "tensor_utils"
+    CATEGORY = "audio_utils"
     RETURN_TYPES = ()
     FUNCTION = "execute"
-    OUTPUT_NODE = True
 
     @classmethod
     def INPUT_TYPES(s):
```

nodes/video_stream_utils/__init__.py (+5)

```diff
@@ -0,0 +1,5 @@
+from .primary_input_load_image import PrimaryInputLoadImage
+
+NODE_CLASS_MAPPINGS = {"PrimaryInputLoadImage": PrimaryInputLoadImage}
+
+__all__ = ["NODE_CLASS_MAPPINGS"]
```
nodes/video_stream_utils/primary_input_load_image.py (+5)

```diff
@@ -0,0 +1,5 @@
+import nodes
+
+
+class PrimaryInputLoadImage(nodes.LoadImage):
+    pass
```
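The subclass body is intentionally empty: PrimaryInputLoadImage behaves exactly like ComfyUI's LoadImage and exists only to give the primary video-frame input a distinct `class_type` that `convert_prompt` can recognize. In an exported prompt it would appear roughly as follows (node id and image name are hypothetical):

```python
# Hypothetical prompt fragment using the marker node. It renders and
# behaves like LoadImage in the editor, but convert_prompt swaps it
# for a LoadTensor node at runtime.
node = {
    "inputs": {"image": "placeholder.png"},
    "class_type": "PrimaryInputLoadImage",
    "_meta": {"title": "Primary Input Load Image"},
}
```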

server/app.py (+23-2)

```diff
@@ -16,11 +16,15 @@
     MediaStreamTrack,
 )
 from aiortc.rtcrtpsender import RTCRtpSender
+from aiortc.codecs import h264
 from pipeline import Pipeline
 from utils import patch_loop_datagram
 
 logger = logging.getLogger(__name__)
 
+MAX_BITRATE = 2000000
+MIN_BITRATE = 2000000
+
 
 class VideoStreamTrack(MediaStreamTrack):
     kind = "video"
@@ -165,12 +169,12 @@ def get_ice_servers():
 
 
 async def offer(request):
+    pipeline = request.app["pipeline"]
     pcs = request.app["pcs"]
-    workspace = request.app["workspace"]
 
     params = await request.json()
 
-    pipeline = Pipeline(params["prompt"], cwd=workspace)
+    pipeline.set_prompt(params["prompt"])
     await pipeline.warm()
 
     offer_params = params["offer"]
@@ -194,6 +198,10 @@ async def offer(request):
     prefs = list(filter(lambda x: x.name == "H264", caps.codecs))
     transceiver.setCodecPreferences(prefs)
 
+    # Monkey patch max and min bitrate to ensure constant bitrate
+    h264.MAX_BITRATE = MAX_BITRATE
+    h264.MIN_BITRATE = MIN_BITRATE
+
     @pc.on("track")
     def on_track(track):
         logger.info(f"Track received: {track.kind}")
@@ -236,6 +244,15 @@ async def on_connectionstatechange():
     )
 
 
+async def set_prompt(request):
+    pipeline = request.app["pipeline"]
+
+    prompt = await request.json()
+    pipeline.set_prompt(prompt)
+
+    return web.Response(content_type="application/json", text="OK")
+
+
 def health(_):
     return web.Response(content_type="application/json", text="OK")
 
@@ -244,6 +261,9 @@ async def on_startup(app: web.Application):
     if app["media_ports"]:
         patch_loop_datagram(app["media_ports"])
 
+    app["pipeline"] = Pipeline(
+        cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
+    )
    app["pcs"] = set()
 
 
@@ -282,6 +302,7 @@ async def on_shutdown(app: web.Application):
     app.on_shutdown.append(on_shutdown)
 
     app.router.add_post("/offer", offer)
+    app.router.add_post("/prompt", set_prompt)
     app.router.add_get("/", health)
 
     web.run_app(app, host=args.host, port=int(args.port))
```
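Since the pipeline is now created once at startup and shared through `app["pipeline"]`, the new `/prompt` route lets a client swap workflows on a live server without renegotiating WebRTC. A sketch of calling it with the standard library, assuming the server listens on `http://localhost:8888` (host and port are whatever `--host`/`--port` were set to):

```python
import json
import urllib.request

# A placeholder workflow prompt; a real payload is an exported ComfyUI
# workflow dict that satisfies the input/output rules in the README.
prompt = {
    "1": {"inputs": {"image": "example.png"}, "class_type": "LoadImage", "_meta": {"title": "Load Image"}},
    "2": {"inputs": {"images": ["1", 0]}, "class_type": "PreviewImage", "_meta": {"title": "Preview Image"}},
}

# The /prompt route added in this commit reads the prompt from the JSON body.
req = urllib.request.Request(
    "http://localhost:8888/prompt",  # assumed host/port
    data=json.dumps(prompt).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode())  # the route returns "OK"
```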

server/pipeline.py (+9-3)

```diff
@@ -5,15 +5,21 @@
 from typing import Any, Dict
 from comfystream.client import ComfyStreamClient
 
+WARMUP_RUNS = 5
+
 
 class Pipeline:
-    def __init__(self, prompt: Dict[Any, Any], **kwargs):
+    def __init__(self, **kwargs):
         self.client = ComfyStreamClient(**kwargs)
-        self.client.set_prompt(prompt)
 
     async def warm(self):
         frame = torch.randn(1, 512, 512, 3)
-        await self.predict(frame)
+
+        for _ in range(WARMUP_RUNS):
+            await self.predict(frame)
+
+    def set_prompt(self, prompt: Dict[Any, Any]):
+        self.client.set_prompt(prompt)
 
     def preprocess(self, frame: av.VideoFrame) -> torch.Tensor:
         frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
```
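With the prompt removed from the constructor, a Pipeline now outlives any single workflow: it is constructed once (as `on_startup` in server/app.py now does) and re-targeted via `set_prompt`. A sketch of the lifecycle; the workspace path is a placeholder and the prompt is elided:

```python
import asyncio

from pipeline import Pipeline


async def main():
    # Kwargs are forwarded to ComfyStreamClient, mirroring on_startup in server/app.py.
    pipeline = Pipeline(cwd="/home/user/ComfyUI", disable_cuda_malloc=True, gpu_only=True)

    prompt = {}  # an exported ComfyUI workflow dict, elided here
    pipeline.set_prompt(prompt)
    await pipeline.warm()  # pushes WARMUP_RUNS random frames through predict()


asyncio.run(main())
```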

setup.py (+1)

```diff
@@ -19,5 +19,6 @@
         "opentelemetry-semantic-conventions==0.48b0",
         "comfyui @ git+https://github.com/hiddenswitch/ComfyUI.git@89d07f3adf32a6703181343bc732bd85104bb653",
     ],
+    extras_require={"dev": ["pytest"]},
     url="https://github.com/yondonfu/comfystream",
 )
```

src/comfystream/utils.py (+51-19)

```diff
@@ -1,53 +1,85 @@
 import copy
 
+from typing import Dict, Any
 from comfy.api.components.schema.prompt import Prompt, PromptDictInput
 
 
+def create_load_tensor_node():
+    return {
+        "inputs": {},
+        "class_type": "LoadTensor",
+        "_meta": {"title": "LoadTensor"},
+    }
+
+
+def create_save_tensor_node(inputs: Dict[Any, Any]):
+    return {
+        "inputs": inputs,
+        "class_type": "SaveTensor",
+        "_meta": {"title": "SaveTensor"},
+    }
+
+
 def convert_prompt(prompt: PromptDictInput) -> Prompt:
     # Validate the schema
     Prompt.validate(prompt)
 
     prompt = copy.deepcopy(prompt)
 
+    num_primary_inputs = 0
     num_inputs = 0
     num_outputs = 0
 
+    keys = {
+        "PrimaryInputLoadImage": [],
+        "LoadImage": [],
+        "PreviewImage": [],
+        "SaveImage": [],
+    }
     for key, node in prompt.items():
-        if node.get("class_type") == "LoadImage":
-            num_inputs += 1
+        class_type = node.get("class_type")
 
-            prompt[key] = {
-                "inputs": {},
-                "class_type": "LoadTensor",
-                "_meta": {"title": "LoadTensor"},
-            }
-        elif node.get("class_type") in ["PreviewImage", "SaveImage"]:
-            num_outputs += 1
+        # Collect keys for nodes that might need to be replaced
+        if class_type in keys:
+            keys[class_type].append(key)
 
-            prompt[key] = {
-                "inputs": node["inputs"],
-                "class_type": "SaveTensor",
-                "_meta": {"title": "SaveTensor"},
-            }
-        elif node.get("class_type") in ["LoadTensor", "LoadAudioTensor"]:
+        # Count inputs and outputs
+        if class_type == "PrimaryInputLoadImage":
+            num_primary_inputs += 1
+        elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor"]:
             num_inputs += 1
-        elif node.get("class_type") in ["SaveTensor", "SaveASRResponse"]:
+        elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveASRResponse"]:
             num_outputs += 1
 
-    # Only handle single input for now
-    if num_inputs > 1:
+    # Only handle single primary input
+    if num_primary_inputs > 1:
+        raise Exception("too many primary inputs in prompt")
+
+    # If there are no primary inputs, only handle single input
+    if num_primary_inputs == 0 and num_inputs > 1:
         raise Exception("too many inputs in prompt")
 
     # Only handle single output for now
     if num_outputs > 1:
         raise Exception("too many outputs in prompt")
 
-    if num_inputs == 0:
+    if num_primary_inputs + num_inputs == 0:
         raise Exception("missing input")
 
     if num_outputs == 0:
         raise Exception("missing output")
 
+    # Replace nodes
+    for key in keys["PrimaryInputLoadImage"]:
+        prompt[key] = create_load_tensor_node()
+
+    if num_primary_inputs == 0 and len(keys["LoadImage"]) == 1:
+        prompt[keys["LoadImage"][0]] = create_load_tensor_node()
+
+    for key in keys["PreviewImage"] + keys["SaveImage"]:
+        node = prompt[key]
+        prompt[key] = create_save_tensor_node(node["inputs"])
+
     # Validate the processed prompt input
     prompt = Prompt.validate(prompt)
 
```
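The refactor separates collection and counting from replacement, which is what enables the new PrimaryInputLoadImage path: when a primary input is present, additional LoadImage nodes are counted but deliberately left untouched. A sketch of the resulting behavior; node ids, inputs, and titles are invented, and a real prompt must also pass `Prompt.validate`:

```python
from comfystream.utils import convert_prompt

# Hypothetical three-node prompt: a primary input, a static image, and a preview.
prompt = {
    "1": {"inputs": {"image": "frame.png"}, "class_type": "PrimaryInputLoadImage", "_meta": {"title": "Primary"}},
    "2": {"inputs": {"image": "mask.png"}, "class_type": "LoadImage", "_meta": {"title": "Static image"}},
    "3": {"inputs": {"images": ["1", 0]}, "class_type": "PreviewImage", "_meta": {"title": "Preview"}},
}

converted = convert_prompt(prompt)
# converted["1"]["class_type"] == "LoadTensor"    (marker replaced)
# converted["2"]["class_type"] == "LoadImage"     (left in place: a primary input exists)
# converted["3"]["class_type"] == "SaveTensor"    (output replaced, inputs preserved)
```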
