Skip to content

Commit 59754f7

Browse files
rickstaaeliteprox
andcommitted
fix: ensure patched torch graph is always synced on inference errors (#129)
* fix: ensure patched torch graph is always synced on inference errors This commit ensures that the patched torch graph remains synchronized even when an error occurs during inference, preventing potential inconsistencies and adds logging for controlnet tensor cloning errors --------- Co-authored-by: Elite <john@eliteencoder.net>
1 parent dac2a64 commit 59754f7

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

nodes/tensor_utils/prestartup_script.py

+19-16
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,29 @@ def wrapped_control_merge(self, control, control_prev, output_dtype):
2828
}
2929

3030
# Get result from original merge function
31-
result = original_control_merge(self, control, control_prev, output_dtype)
32-
33-
# Clone all output tensors
34-
result = {
35-
k: [t.clone() if t is not None else None for t in v]
36-
for k, v in result.items()
37-
}
38-
39-
# Mark CUDA graph step at end
40-
if torch.cuda.is_available() and hasattr(torch.compiler, 'cudagraph_mark_step_begin'):
41-
torch.compiler.cudagraph_mark_step_begin()
42-
torch.cuda.synchronize()
43-
44-
return result
45-
31+
try:
32+
result = original_control_merge(self, control, control_prev, output_dtype)
33+
34+
# Clone all output tensors
35+
result = {
36+
k: [t.clone() if t is not None else None for t in v]
37+
for k, v in result.items()
38+
}
39+
return result
40+
except Exception as e:
41+
print(f"Error: Failed to clone ControlNet during inference: {str(e)}")
42+
raise e
43+
finally:
44+
# Mark CUDA graph step at end
45+
if torch.cuda.is_available() and hasattr(torch.compiler, 'cudagraph_mark_step_begin'):
46+
torch.compiler.cudagraph_mark_step_begin()
47+
torch.cuda.synchronize()
48+
4649
# Apply the patch
4750
ControlBase.control_merge = wrapped_control_merge
4851
print("Successfully patched ControlNet for torch.compile() compatibility")
4952
except Exception as e:
5053
print(f"Warning: Failed to patch ControlNet: {str(e)}")
5154

5255
# Apply patch when module is imported
53-
patch_controlnet_for_stream()
56+
patch_controlnet_for_stream()

0 commit comments

Comments
 (0)