Skip to content

Commit b85d01d

Browse files
committed
fix: update prompts
1 parent aa209f0 commit b85d01d

File tree

2 files changed

+95
-87
lines changed

2 files changed

+95
-87
lines changed

server/app.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ async def on_message(message):
155155
if "prompts" not in params:
156156
logger.warning("[Control] Missing prompt in update_prompt message")
157157
return
158-
pipeline.set_prompts(params["prompts"])
158+
await pipeline.update_prompts(params["prompts"])
159159
response = {
160160
"type": "prompts_updated",
161161
"success": True

src/comfystream/client.py

+94-86
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,27 @@ def __init__(self, max_workers: int = 1, **kwargs):
1717
config = Configuration(**kwargs)
1818
# TODO: Need to handle cleanup for EmbeddedComfyClient if not using async context manager?
1919
self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers)
20-
self.running_prompts = []
20+
self.running_prompts = {} # To be used for cancelling tasks
21+
self.current_prompts = []
2122

2223
async def set_prompts(self, prompts: List[PromptDictInput]):
    """Convert the given prompts and start one background task per prompt.

    Each prompt is normalized via ``convert_prompt`` and stored in
    ``self.current_prompts``; a long-running task is then created per
    prompt and kept in ``self.running_prompts`` keyed by prompt index,
    so individual tasks can be cancelled later.

    Args:
        prompts: Raw prompt dictionaries to convert and run.
    """
    self.current_prompts = [convert_prompt(prompt) for prompt in prompts]
    # BUG FIX: the original iterated `range(self.current_prompts)`, which
    # passes a list to range() and raises TypeError. We need the length.
    for idx in range(len(self.current_prompts)):
        task = asyncio.create_task(self.run_prompt(idx))
        self.running_prompts[idx] = task
2728

28-
async def cancel_running_tasks(self):
29-
while self.running_prompts:
30-
task = self.running_prompts.pop()
31-
task["task"].cancel()
32-
await task["task"]
29+
async def update_prompts(self, prompts: List[PromptDictInput]):
    """Replace the prompt definitions used by the already-running tasks.

    The running tasks read ``self.current_prompts`` by index on every
    iteration, so swapping the list contents is enough to apply updates.

    TODO: currently under the assumption that only already running
    prompts are updated.

    Args:
        prompts: Raw prompt dictionaries, one per running prompt.

    Raises:
        ValueError: If the number of prompts does not match the number
            of prompts currently running.
    """
    expected = len(self.current_prompts)
    if len(prompts) != expected:
        raise ValueError(
            "Number of updated prompts must match the number of currently running prompts."
        )
    converted = [convert_prompt(p) for p in prompts]
    self.current_prompts = converted
3336

34-
async def run_prompt(self, prompt: PromptDictInput):
37+
async def run_prompt(self, prompt_index: int):
3538
while True:
3639
try:
37-
await self.comfy_client.queue_prompt(prompt)
40+
await self.comfy_client.queue_prompt(self.current_prompts[prompt_index])
3841
except Exception as e:
3942
logger.error(f"Error running prompt: {str(e)}")
4043
logger.error(f"Error type: {type(e)}")
@@ -61,87 +64,92 @@ async def get_available_nodes(self):
6164
try:
6265
from comfy.nodes.package import import_all_nodes_in_workspace
6366
nodes = import_all_nodes_in_workspace()
67+
68+
all_prompts_nodes_info = {}
6469

65-
# Get set of class types we need metadata for, excluding LoadTensor and SaveTensor
66-
needed_class_types = {
67-
node.get('class_type')
68-
for node in self.prompt.values()
69-
if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
70-
}
71-
remaining_nodes = {
72-
node_id
73-
for node_id, node in self.prompt.items()
74-
if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
75-
}
76-
nodes_info = {}
77-
78-
# Only process nodes until we've found all the ones we need
79-
for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items():
80-
if not remaining_nodes: # Exit early if we've found all needed nodes
81-
break
82-
83-
if class_type not in needed_class_types:
84-
continue
85-
86-
# Get metadata for this node type (same as original get_node_metadata)
87-
input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {}
88-
input_info = {}
89-
90-
# Process required inputs
91-
if 'required' in input_data:
92-
for name, value in input_data['required'].items():
93-
if isinstance(value, tuple) and len(value) == 2:
94-
input_type, config = value
95-
input_info[name] = {
96-
'type': input_type,
97-
'required': True,
98-
'min': config.get('min', None),
99-
'max': config.get('max', None),
100-
'widget': config.get('widget', None)
101-
}
102-
else:
103-
logger.error(f"Unexpected structure for required input {name}: {value}")
104-
105-
# Process optional inputs
106-
if 'optional' in input_data:
107-
for name, value in input_data['optional'].items():
108-
if isinstance(value, tuple) and len(value) == 2:
109-
input_type, config = value
110-
input_info[name] = {
111-
'type': input_type,
112-
'required': False,
113-
'min': config.get('min', None),
114-
'max': config.get('max', None),
115-
'widget': config.get('widget', None)
116-
}
117-
else:
118-
logger.error(f"Unexpected structure for optional input {name}: {value}")
70+
for prompt_index, prompt in enumerate(self.current_prompts):
71+
# Get set of class types we need metadata for, excluding LoadTensor and SaveTensor
72+
needed_class_types = {
73+
node.get('class_type')
74+
for node in prompt.values()
75+
if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
76+
}
77+
remaining_nodes = {
78+
node_id
79+
for node_id, node in prompt.items()
80+
if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
81+
}
82+
nodes_info = {}
11983

120-
# Now process any nodes in our prompt that use this class_type
121-
for node_id in list(remaining_nodes):
122-
node = self.prompt[node_id]
123-
if node.get('class_type') != class_type:
84+
# Only process nodes until we've found all the ones we need
85+
for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items():
86+
if not remaining_nodes: # Exit early if we've found all needed nodes
87+
break
88+
89+
if class_type not in needed_class_types:
12490
continue
12591

126-
node_info = {
127-
'class_type': class_type,
128-
'inputs': {}
129-
}
92+
# Get metadata for this node type (same as original get_node_metadata)
93+
input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {}
94+
input_info = {}
95+
96+
# Process required inputs
97+
if 'required' in input_data:
98+
for name, value in input_data['required'].items():
99+
if isinstance(value, tuple) and len(value) == 2:
100+
input_type, config = value
101+
input_info[name] = {
102+
'type': input_type,
103+
'required': True,
104+
'min': config.get('min', None),
105+
'max': config.get('max', None),
106+
'widget': config.get('widget', None)
107+
}
108+
else:
109+
logger.error(f"Unexpected structure for required input {name}: {value}")
130110

131-
if 'inputs' in node:
132-
for input_name, input_value in node['inputs'].items():
133-
node_info['inputs'][input_name] = {
134-
'value': input_value,
135-
'type': input_info.get(input_name, {}).get('type', 'unknown'),
136-
'min': input_info.get(input_name, {}).get('min', None),
137-
'max': input_info.get(input_name, {}).get('max', None),
138-
'widget': input_info.get(input_name, {}).get('widget', None)
139-
}
111+
# Process optional inputs
112+
if 'optional' in input_data:
113+
for name, value in input_data['optional'].items():
114+
if isinstance(value, tuple) and len(value) == 2:
115+
input_type, config = value
116+
input_info[name] = {
117+
'type': input_type,
118+
'required': False,
119+
'min': config.get('min', None),
120+
'max': config.get('max', None),
121+
'widget': config.get('widget', None)
122+
}
123+
else:
124+
logger.error(f"Unexpected structure for optional input {name}: {value}")
140125

141-
nodes_info[node_id] = node_info
142-
remaining_nodes.remove(node_id)
126+
# Now process any nodes in our prompt that use this class_type
127+
for node_id in list(remaining_nodes):
128+
node = self.prompt[node_id]
129+
if node.get('class_type') != class_type:
130+
continue
131+
132+
node_info = {
133+
'class_type': class_type,
134+
'inputs': {}
135+
}
136+
137+
if 'inputs' in node:
138+
for input_name, input_value in node['inputs'].items():
139+
node_info['inputs'][input_name] = {
140+
'value': input_value,
141+
'type': input_info.get(input_name, {}).get('type', 'unknown'),
142+
'min': input_info.get(input_name, {}).get('min', None),
143+
'max': input_info.get(input_name, {}).get('max', None),
144+
'widget': input_info.get(input_name, {}).get('widget', None)
145+
}
146+
147+
nodes_info[node_id] = node_info
148+
remaining_nodes.remove(node_id)
149+
150+
all_prompts_nodes_info[prompt_index] = nodes_info
143151

144-
return nodes_info
152+
return all_prompts_nodes_info
145153

146154
except Exception as e:
147155
logger.error(f"Error getting node info: {str(e)}")

0 commit comments

Comments (0)