|
1 | 1 | import copy
|
2 | 2 |
|
| 3 | +from typing import Dict, Any |
3 | 4 | from comfy.api.components.schema.prompt import Prompt, PromptDictInput
|
4 | 5 |
|
5 | 6 |
|
| 7 | +def create_load_tensor_node(): |
| 8 | + return { |
| 9 | + "inputs": {}, |
| 10 | + "class_type": "LoadTensor", |
| 11 | + "_meta": {"title": "LoadTensor"}, |
| 12 | + } |
| 13 | + |
| 14 | + |
| 15 | +def create_save_tensor_node(inputs: Dict[Any, Any]): |
| 16 | + return { |
| 17 | + "inputs": inputs, |
| 18 | + "class_type": "SaveTensor", |
| 19 | + "_meta": {"title": "SaveTensor"}, |
| 20 | + } |
| 21 | + |
| 22 | + |
6 | 23 | def convert_prompt(prompt: PromptDictInput) -> Prompt:
|
7 | 24 | # Validate the schema
|
8 | 25 | Prompt.validate(prompt)
|
9 | 26 |
|
10 | 27 | prompt = copy.deepcopy(prompt)
|
11 | 28 |
|
| 29 | + num_primary_inputs = 0 |
12 | 30 | num_inputs = 0
|
13 | 31 | num_outputs = 0
|
14 | 32 |
|
| 33 | + keys = { |
| 34 | + "PrimaryInputLoadImage": [], |
| 35 | + "LoadImage": [], |
| 36 | + "PreviewImage": [], |
| 37 | + "SaveImage": [], |
| 38 | + } |
15 | 39 | for key, node in prompt.items():
|
16 |
| - if node.get("class_type") == "LoadImage": |
17 |
| - num_inputs += 1 |
| 40 | + class_type = node.get("class_type") |
18 | 41 |
|
19 |
| - prompt[key] = { |
20 |
| - "inputs": {}, |
21 |
| - "class_type": "LoadTensor", |
22 |
| - "_meta": {"title": "LoadTensor"}, |
23 |
| - } |
24 |
| - elif node.get("class_type") in ["PreviewImage", "SaveImage"]: |
25 |
| - num_outputs += 1 |
| 42 | + # Collect keys for nodes that might need to be replaced |
| 43 | + if class_type in keys: |
| 44 | + keys[class_type].append(key) |
26 | 45 |
|
27 |
| - prompt[key] = { |
28 |
| - "inputs": node["inputs"], |
29 |
| - "class_type": "SaveTensor", |
30 |
| - "_meta": {"title": "SaveTensor"}, |
31 |
| - } |
32 |
| - elif node.get("class_type") in ["LoadTensor", "LoadAudioTensor"]: |
| 46 | + # Count inputs and outputs |
| 47 | + if class_type == "PrimaryInputLoadImage": |
| 48 | + num_primary_inputs += 1 |
| 49 | + elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor"]: |
33 | 50 | num_inputs += 1
|
34 |
| - elif node.get("class_type") in ["SaveTensor", "SaveASRResponse"]: |
| 51 | + elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveASRResponse"]: |
35 | 52 | num_outputs += 1
|
36 | 53 |
|
37 |
| - # Only handle single input for now |
38 |
| - if num_inputs > 1: |
| 54 | + # Only handle single primary input |
| 55 | + if num_primary_inputs > 1: |
| 56 | + raise Exception("too many primary inputs in prompt") |
| 57 | + |
| 58 | + # If there are no primary inputs, only handle single input |
| 59 | + if num_primary_inputs == 0 and num_inputs > 1: |
39 | 60 | raise Exception("too many inputs in prompt")
|
40 | 61 |
|
41 | 62 | # Only handle single output for now
|
42 | 63 | if num_outputs > 1:
|
43 | 64 | raise Exception("too many outputs in prompt")
|
44 | 65 |
|
45 |
| - if num_inputs == 0: |
| 66 | + if num_primary_inputs + num_inputs == 0: |
46 | 67 | raise Exception("missing input")
|
47 | 68 |
|
48 | 69 | if num_outputs == 0:
|
49 | 70 | raise Exception("missing output")
|
50 | 71 |
|
| 72 | + # Replace nodes |
| 73 | + for key in keys["PrimaryInputLoadImage"]: |
| 74 | + prompt[key] = create_load_tensor_node() |
| 75 | + |
| 76 | + if num_primary_inputs == 0 and len(keys["LoadImage"]) == 1: |
| 77 | + prompt[keys["LoadImage"][0]] = create_load_tensor_node() |
| 78 | + |
| 79 | + for key in keys["PreviewImage"] + keys["SaveImage"]: |
| 80 | + node = prompt[key] |
| 81 | + prompt[key] = create_save_tensor_node(node["inputs"]) |
| 82 | + |
51 | 83 | # Validate the processed prompt input
|
52 | 84 | prompt = Prompt.validate(prompt)
|
53 | 85 |
|
|
0 commit comments