@@ -17,24 +17,27 @@ def __init__(self, max_workers: int = 1, **kwargs):
17
17
config = Configuration (** kwargs )
18
18
# TODO: Need to handle cleanup for EmbeddedComfyClient if not using async context manager?
19
19
self .comfy_client = EmbeddedComfyClient (config , max_workers = max_workers )
20
- self .running_prompts = []
20
+ self .running_prompts = {} # To be used for cancelling tasks
21
+ self .current_prompts = []
21
22
22
23
async def set_prompts (self , prompts : List [PromptDictInput ]):
23
- await self .cancel_running_tasks ()
24
- for prompt in [ convert_prompt ( prompt ) for prompt in prompts ] :
25
- task = asyncio .create_task (self .run_prompt (prompt ))
26
- self .running_prompts . append ({ "task" : task , "prompt" : prompt })
24
+ self .current_prompts = [ convert_prompt ( prompt ) for prompt in prompts ]
25
+ for idx in range ( self . current_prompts ) :
26
+ task = asyncio .create_task (self .run_prompt (idx ))
27
+ self .running_prompts [ idx ] = task
27
28
28
- async def cancel_running_tasks (self ):
29
- while self .running_prompts :
30
- task = self .running_prompts .pop ()
31
- task ["task" ].cancel ()
32
- await task ["task" ]
29
+ async def update_prompts (self , prompts : List [PromptDictInput ]):
30
+ # TODO: currently under the assumption that only already running prompts are updated
31
+ if len (prompts ) != len (self .current_prompts ):
32
+ raise ValueError (
33
+ "Number of updated prompts must match the number of currently running prompts."
34
+ )
35
+ self .current_prompts = [convert_prompt (prompt ) for prompt in prompts ]
33
36
34
- async def run_prompt (self , prompt : PromptDictInput ):
37
+ async def run_prompt (self , prompt_index : int ):
35
38
while True :
36
39
try :
37
- await self .comfy_client .queue_prompt (prompt )
40
+ await self .comfy_client .queue_prompt (self . current_prompts [ prompt_index ] )
38
41
except Exception as e :
39
42
logger .error (f"Error running prompt: { str (e )} " )
40
43
logger .error (f"Error type: { type (e )} " )
@@ -61,87 +64,92 @@ async def get_available_nodes(self):
61
64
try :
62
65
from comfy .nodes .package import import_all_nodes_in_workspace
63
66
nodes = import_all_nodes_in_workspace ()
67
+
68
+ all_prompts_nodes_info = {}
64
69
65
- # Get set of class types we need metadata for, excluding LoadTensor and SaveTensor
66
- needed_class_types = {
67
- node .get ('class_type' )
68
- for node in self .prompt .values ()
69
- if node .get ('class_type' ) not in ('LoadTensor' , 'SaveTensor' )
70
- }
71
- remaining_nodes = {
72
- node_id
73
- for node_id , node in self .prompt .items ()
74
- if node .get ('class_type' ) not in ('LoadTensor' , 'SaveTensor' )
75
- }
76
- nodes_info = {}
77
-
78
- # Only process nodes until we've found all the ones we need
79
- for class_type , node_class in nodes .NODE_CLASS_MAPPINGS .items ():
80
- if not remaining_nodes : # Exit early if we've found all needed nodes
81
- break
82
-
83
- if class_type not in needed_class_types :
84
- continue
85
-
86
- # Get metadata for this node type (same as original get_node_metadata)
87
- input_data = node_class .INPUT_TYPES () if hasattr (node_class , 'INPUT_TYPES' ) else {}
88
- input_info = {}
89
-
90
- # Process required inputs
91
- if 'required' in input_data :
92
- for name , value in input_data ['required' ].items ():
93
- if isinstance (value , tuple ) and len (value ) == 2 :
94
- input_type , config = value
95
- input_info [name ] = {
96
- 'type' : input_type ,
97
- 'required' : True ,
98
- 'min' : config .get ('min' , None ),
99
- 'max' : config .get ('max' , None ),
100
- 'widget' : config .get ('widget' , None )
101
- }
102
- else :
103
- logger .error (f"Unexpected structure for required input { name } : { value } " )
104
-
105
- # Process optional inputs
106
- if 'optional' in input_data :
107
- for name , value in input_data ['optional' ].items ():
108
- if isinstance (value , tuple ) and len (value ) == 2 :
109
- input_type , config = value
110
- input_info [name ] = {
111
- 'type' : input_type ,
112
- 'required' : False ,
113
- 'min' : config .get ('min' , None ),
114
- 'max' : config .get ('max' , None ),
115
- 'widget' : config .get ('widget' , None )
116
- }
117
- else :
118
- logger .error (f"Unexpected structure for optional input { name } : { value } " )
70
+ for prompt_index , prompt in enumerate (self .current_prompts ):
71
+ # Get set of class types we need metadata for, excluding LoadTensor and SaveTensor
72
+ needed_class_types = {
73
+ node .get ('class_type' )
74
+ for node in prompt .values ()
75
+ if node .get ('class_type' ) not in ('LoadTensor' , 'SaveTensor' )
76
+ }
77
+ remaining_nodes = {
78
+ node_id
79
+ for node_id , node in prompt .items ()
80
+ if node .get ('class_type' ) not in ('LoadTensor' , 'SaveTensor' )
81
+ }
82
+ nodes_info = {}
119
83
120
- # Now process any nodes in our prompt that use this class_type
121
- for node_id in list (remaining_nodes ):
122
- node = self .prompt [node_id ]
123
- if node .get ('class_type' ) != class_type :
84
+ # Only process nodes until we've found all the ones we need
85
+ for class_type , node_class in nodes .NODE_CLASS_MAPPINGS .items ():
86
+ if not remaining_nodes : # Exit early if we've found all needed nodes
87
+ break
88
+
89
+ if class_type not in needed_class_types :
124
90
continue
125
91
126
- node_info = {
127
- 'class_type' : class_type ,
128
- 'inputs' : {}
129
- }
92
+ # Get metadata for this node type (same as original get_node_metadata)
93
+ input_data = node_class .INPUT_TYPES () if hasattr (node_class , 'INPUT_TYPES' ) else {}
94
+ input_info = {}
95
+
96
+ # Process required inputs
97
+ if 'required' in input_data :
98
+ for name , value in input_data ['required' ].items ():
99
+ if isinstance (value , tuple ) and len (value ) == 2 :
100
+ input_type , config = value
101
+ input_info [name ] = {
102
+ 'type' : input_type ,
103
+ 'required' : True ,
104
+ 'min' : config .get ('min' , None ),
105
+ 'max' : config .get ('max' , None ),
106
+ 'widget' : config .get ('widget' , None )
107
+ }
108
+ else :
109
+ logger .error (f"Unexpected structure for required input { name } : { value } " )
130
110
131
- if 'inputs' in node :
132
- for input_name , input_value in node ['inputs' ].items ():
133
- node_info ['inputs' ][input_name ] = {
134
- 'value' : input_value ,
135
- 'type' : input_info .get (input_name , {}).get ('type' , 'unknown' ),
136
- 'min' : input_info .get (input_name , {}).get ('min' , None ),
137
- 'max' : input_info .get (input_name , {}).get ('max' , None ),
138
- 'widget' : input_info .get (input_name , {}).get ('widget' , None )
139
- }
111
+ # Process optional inputs
112
+ if 'optional' in input_data :
113
+ for name , value in input_data ['optional' ].items ():
114
+ if isinstance (value , tuple ) and len (value ) == 2 :
115
+ input_type , config = value
116
+ input_info [name ] = {
117
+ 'type' : input_type ,
118
+ 'required' : False ,
119
+ 'min' : config .get ('min' , None ),
120
+ 'max' : config .get ('max' , None ),
121
+ 'widget' : config .get ('widget' , None )
122
+ }
123
+ else :
124
+ logger .error (f"Unexpected structure for optional input { name } : { value } " )
140
125
141
- nodes_info [node_id ] = node_info
142
- remaining_nodes .remove (node_id )
126
+ # Now process any nodes in our prompt that use this class_type
127
+ for node_id in list (remaining_nodes ):
128
+ node = self .prompt [node_id ]
129
+ if node .get ('class_type' ) != class_type :
130
+ continue
131
+
132
+ node_info = {
133
+ 'class_type' : class_type ,
134
+ 'inputs' : {}
135
+ }
136
+
137
+ if 'inputs' in node :
138
+ for input_name , input_value in node ['inputs' ].items ():
139
+ node_info ['inputs' ][input_name ] = {
140
+ 'value' : input_value ,
141
+ 'type' : input_info .get (input_name , {}).get ('type' , 'unknown' ),
142
+ 'min' : input_info .get (input_name , {}).get ('min' , None ),
143
+ 'max' : input_info .get (input_name , {}).get ('max' , None ),
144
+ 'widget' : input_info .get (input_name , {}).get ('widget' , None )
145
+ }
146
+
147
+ nodes_info [node_id ] = node_info
148
+ remaining_nodes .remove (node_id )
149
+
150
+ all_prompts_nodes_info [prompt_index ] = nodes_info
143
151
144
- return nodes_info
152
+ return all_prompts_nodes_info
145
153
146
154
except Exception as e :
147
155
logger .error (f"Error getting node info: { str (e )} " )
0 commit comments