Commit 31bed67

change some dtype behaviors based on community feedback

This only affects older devices such as the 1080/1070/1060/1050. If you are on one of these cards and previously used many cmd flags to tune performance, please remove those flags.
1 parent 2b1e736 commit 31bed67

4 files changed (+30 -87 lines)


backend/loader.py (+5 -5)
@@ -107,10 +107,10 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)

         unet_config = guess.unet_config.copy()
-        state_dict_size = memory_management.state_dict_size(state_dict)
+        state_dict_parameters = memory_management.state_dict_parameters(state_dict)
         state_dict_dtype = memory_management.state_dict_dtype(state_dict)

-        storage_dtype = memory_management.unet_dtype(model_params=state_dict_size, supported_dtypes=guess.supported_inference_dtypes)
+        storage_dtype = memory_management.unet_dtype(model_params=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)

         unet_storage_dtype_overwrite = backend.args.dynamic_args.get('forge_unet_storage_dtype')

@@ -140,15 +140,15 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
             print(f'Using GGUF state dict: {type_counts}')

         load_device = memory_management.get_torch_device()
-        computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=guess.supported_inference_dtypes)
+        computation_dtype = memory_management.get_computation_dtype(load_device, parameters=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)
         offload_device = memory_management.unet_offload_device()

         if storage_dtype in ['nf4', 'fp4', 'gguf']:
-            initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=computation_dtype)
+            initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=computation_dtype)
             with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype):
                 model = model_loader(unet_config)
         else:
-            initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=storage_dtype)
+            initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=storage_dtype)
             need_manual_cast = storage_dtype != computation_dtype
             to_args = dict(device=initial_device, dtype=storage_dtype)
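Note: the change above replaces a byte count with a raw parameter count when picking dtypes. A minimal sketch of the difference, with toy tensors and simplified helpers (not the repository's implementations):

import torch

def state_dict_size(sd):
    # bytes occupied by the tensors (depends on the stored dtype)
    return sum(v.nelement() * v.element_size() for v in sd.values())

def state_dict_parameters(sd):
    # raw element count (independent of the stored dtype)
    return sum(v.nelement() for v in sd.values())

sd = {'weight': torch.zeros(1024, 1024, dtype=torch.float16),
      'bias': torch.zeros(1024, dtype=torch.float16)}

print(state_dict_size(sd))        # 2099200 bytes for fp16 storage
print(state_dict_parameters(sd))  # 1049600 parameters regardless of dtype

Since the dtype pickers compare model_params * 4 (an fp32-sized estimate) against free memory, a parameter count is the natural input regardless of which precision the checkpoint happens to be stored in.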

backend/memory_management.py (+21 -78)
@@ -301,6 +301,13 @@ def state_dict_size(sd, exclude_device=None):
     return module_mem


+def state_dict_parameters(sd):
+    module_mem = 0
+    for k, v in sd.items():
+        module_mem += v.nelement()
+    return module_mem
+
+
 def state_dict_dtype(state_dict):
     for k, v in state_dict.items():
         if hasattr(v, 'is_gguf'):

@@ -653,44 +660,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor

     for candidate in supported_dtypes:
         if candidate == torch.float16:
-            if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
+            if should_use_fp16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
                 return candidate
         if candidate == torch.bfloat16:
-            if should_use_bf16(device, model_params=model_params, manual_cast=True):
+            if should_use_bf16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
                 return candidate

     return torch.float32


-# None means no manual cast
-def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
-    if weight_dtype == torch.float32:
-        return None
-
-    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
-    if fp16_supported and weight_dtype == torch.float16:
-        return None
-
-    bf16_supported = should_use_bf16(inference_device)
-    if bf16_supported and weight_dtype == torch.bfloat16:
-        return None
-
-    if fp16_supported and torch.float16 in supported_dtypes:
-        return torch.float16
-
-    elif bf16_supported and torch.bfloat16 in supported_dtypes:
-        return torch.bfloat16
-    else:
-        return torch.float32
-
-
-def get_computation_dtype(inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+def get_computation_dtype(inference_device, parameters=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     for candidate in supported_dtypes:
         if candidate == torch.float16:
-            if should_use_fp16(inference_device, prioritize_performance=False):
+            if should_use_fp16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
                 return candidate
         if candidate == torch.bfloat16:
-            if should_use_bf16(inference_device):
+            if should_use_bf16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
                 return candidate

     return torch.float32

@@ -1020,19 +1005,17 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if props.major < 6:
         return False

-    fp16_works = False
-    # FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
-    # when the model doesn't actually fit on the card
-    # TODO: actually test if GP106 and others have the same type of behavior
     nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
     for x in nvidia_10_series:
         if x in props.name.lower():
-            fp16_works = True
-
-    if fp16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
-        if (not prioritize_performance) or model_params * 4 > free_model_memory:
-            return True
+            if manual_cast:
+                # For storage dtype
+                free_model_memory = (get_free_memory() * 0.85 - minimum_inference_memory())
+                if (not prioritize_performance) or model_params * 4 > free_model_memory:
+                    return True
+            else:
+                # For computation dtype
+                return False  # Flux on 1080 can store model in fp16 to reduce swap, but computation must be fp32, otherwise super slow.

     if props.major < 7:
         return False

@@ -1080,7 +1063,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     bf16_works = torch.cuda.is_bf16_supported()

     if bf16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = (get_free_memory() * 0.85 - minimum_inference_memory())
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True

@@ -1116,43 +1099,3 @@ def soft_empty_cache(force=False):

 def unload_all_models():
     free_memory(1e30, get_torch_device())
-
-
-def resolve_lowvram_weight(weight, model, key):  # TODO: remove
-    return weight
-
-
-# TODO: might be cleaner to put this somewhere else
-import threading
-
-
-class InterruptProcessingException(Exception):
-    pass
-
-
-interrupt_processing_mutex = threading.RLock()
-
-interrupt_processing = False
-
-
-def interrupt_current_processing(value=True):
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        interrupt_processing = value
-
-
-def processing_interrupted():
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        return interrupt_processing
-
-
-def throw_exception_if_processing_interrupted():
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        if interrupt_processing:
-            interrupt_processing = False
-            raise InterruptProcessingException()
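For 10-series (Pascal) cards, the practical effect of the should_use_fp16 change is easiest to see in isolation. Below is a simplified, self-contained sketch of the new decision, not the backend code: get_free_memory() and minimum_inference_memory() are replaced by plain arguments, and the numbers in the usage example are illustrative.

def pascal_should_use_fp16(model_params, manual_cast, prioritize_performance,
                           free_memory_bytes, inference_reserve_bytes):
    if manual_cast:
        # Storage dtype: keep weights in fp16 only if an fp32 copy
        # (model_params * 4 bytes) would not fit in the usable VRAM budget.
        free_model_memory = free_memory_bytes * 0.85 - inference_reserve_bytes
        if (not prioritize_performance) or model_params * 4 > free_model_memory:
            return True
        return False
    # Computation dtype: these cards run fp16 math slower than fp32, so fp16 is
    # never picked for computation even when the weights are stored in fp16.
    return False

# Illustrative numbers only: a ~12B-parameter transformer on an 8 GB card.
params = 12_000_000_000
print(pascal_should_use_fp16(params, manual_cast=True, prioritize_performance=True,
                             free_memory_bytes=8 * 1024**3,
                             inference_reserve_bytes=1 * 1024**3))   # True  -> fp16 storage
print(pascal_should_use_fp16(params, manual_cast=False, prioritize_performance=True,
                             free_memory_bytes=8 * 1024**3,
                             inference_reserve_bytes=1 * 1024**3))   # False -> fp32 computation

In other words, fp16 is still allowed as a storage dtype when the model would not otherwise fit, but it is no longer offered as a computation dtype on these cards.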

backend/patcher/controlnet.py (+1 -1)
@@ -438,7 +438,7 @@ def pre_run(self, model, percent_to_timestep_function):

         self.manual_cast_dtype = model.computation_dtype

-        with using_forge_operations(operations=ControlLoraOps, dtype=dtype):
+        with using_forge_operations(operations=ControlLoraOps, dtype=dtype, manual_cast_enabled=self.manual_cast_dtype != dtype):
             self.control_model = cldm.ControlNet(**controlnet_config)

         self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)

modules_forge/supported_controlnet.py (+3 -3)
@@ -110,12 +110,12 @@ def try_build_from_state_dict(controlnet_data, ckpt_path):
         controlnet_config['dtype'] = unet_dtype

         load_device = memory_management.get_torch_device()
-        manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
+        computation_dtype = memory_management.get_computation_dtype(load_device)

         controlnet_config.pop("out_channels")
         controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]

-        with using_forge_operations(dtype=unet_dtype):
+        with using_forge_operations(dtype=unet_dtype, manual_cast_enabled=computation_dtype != unet_dtype):
             control_model = cldm.ControlNet(**controlnet_config).to(dtype=unet_dtype)

         if pth:

@@ -139,7 +139,7 @@ class WeightsLoader(torch.nn.Module):
         # TODO: smarter way of enabling global_average_pooling
         global_average_pooling = True

-        control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+        control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=computation_dtype)
         return ControlNetPatcher(control)

     def __init__(self, model_patcher):
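Both ControlNet call sites above now derive the manual-cast flag from the same comparison: the dtype the weights are stored in versus the dtype the device computes in. A hedged sketch of that pattern (pick_dtypes is a hypothetical stand-in, not a backend function; the dtypes it returns are illustrative):

import torch

def pick_dtypes(pascal_card: bool):
    # e.g. a 10-series card may store weights in fp16 but compute in fp32
    if pascal_card:
        return torch.float16, torch.float32
    return torch.float16, torch.float16

storage_dtype, computation_dtype = pick_dtypes(pascal_card=True)

# Manual casting is only needed when storage and computation dtypes differ;
# layers then cast their weights on the fly at compute time.
manual_cast_enabled = computation_dtype != storage_dtype
print(manual_cast_enabled)  # True for the Pascal example above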
