@@ -301,6 +301,13 @@ def state_dict_size(sd, exclude_device=None):
     return module_mem
 
 
+def state_dict_parameters(sd):
+    module_mem = 0
+    for k, v in sd.items():
+        module_mem += v.nelement()
+    return module_mem
+
+
 def state_dict_dtype(state_dict):
     for k, v in state_dict.items():
         if hasattr(v, 'is_gguf'):
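For reference, a minimal usage sketch of the new state_dict_parameters helper (the toy tensors below are placeholders, not a real checkpoint); it simply sums nelement() over every value in the state dict:

    import torch

    # assumes state_dict_parameters from this module is in scope
    sd = {"weight": torch.zeros(4, 8), "bias": torch.zeros(8)}
    print(state_dict_parameters(sd))  # 4*8 + 8 = 40 parameters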
@@ -653,44 +660,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
 
     for candidate in supported_dtypes:
         if candidate == torch.float16:
-            if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
+            if should_use_fp16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
                 return candidate
         if candidate == torch.bfloat16:
-            if should_use_bf16(device, model_params=model_params, manual_cast=True):
+            if should_use_bf16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
                 return candidate
 
     return torch.float32
 
 
-# None means no manual cast
-def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
-    if weight_dtype == torch.float32:
-        return None
-
-    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
-    if fp16_supported and weight_dtype == torch.float16:
-        return None
-
-    bf16_supported = should_use_bf16(inference_device)
-    if bf16_supported and weight_dtype == torch.bfloat16:
-        return None
-
-    if fp16_supported and torch.float16 in supported_dtypes:
-        return torch.float16
-
-    elif bf16_supported and torch.bfloat16 in supported_dtypes:
-        return torch.bfloat16
-    else:
-        return torch.float32
-
-
-def get_computation_dtype(inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+def get_computation_dtype(inference_device, parameters=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     for candidate in supported_dtypes:
         if candidate == torch.float16:
-            if should_use_fp16(inference_device, prioritize_performance=False):
+            if should_use_fp16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
                 return candidate
         if candidate == torch.bfloat16:
-            if should_use_bf16(inference_device):
+            if should_use_bf16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
                 return candidate
 
     return torch.float32
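A hedged caller sketch (not part of this diff) of how the two updated entry points fit together: unet_dtype chooses the storage dtype (manual_cast=True path) while get_computation_dtype chooses the computation dtype (manual_cast=False path). The glue code and variable names below are hypothetical; only the function signatures come from this diff:

    import torch

    # get_torch_device, state_dict_parameters, unet_dtype and get_computation_dtype
    # are module-level functions touched or used by this diff.
    device = get_torch_device()
    sd = {"weight": torch.zeros(1024, 1024)}   # stand-in for a real checkpoint
    params = state_dict_parameters(sd)

    storage_dtype = unet_dtype(device=device, model_params=params)
    compute_dtype = get_computation_dtype(device, parameters=params)
    # On a GTX 1080 with a large model this can yield fp16 storage (less swapping)
    # but fp32 computation, matching the comment in should_use_fp16 below.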
@@ -1020,19 +1005,17 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if props.major < 6:
         return False
 
-    fp16_works = False
-    # FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
-    # when the model doesn't actually fit on the card
-    # TODO: actually test if GP106 and others have the same type of behavior
     nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
     for x in nvidia_10_series:
         if x in props.name.lower():
-            fp16_works = True
-
-    if fp16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
-        if (not prioritize_performance) or model_params * 4 > free_model_memory:
-            return True
+            if manual_cast:
+                # For storage dtype
+                free_model_memory = (get_free_memory() * 0.85 - minimum_inference_memory())
+                if (not prioritize_performance) or model_params * 4 > free_model_memory:
+                    return True
+            else:
+                # For computation dtype
+                return False  # Flux on 1080 can store model in fp16 to reduce swap, but computation must be fp32, otherwise super slow.
 
     if props.major < 7:
         return False
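The storage-dtype branch above reduces to a memory check: the fp32 footprint of the weights (4 bytes per parameter) is compared against 85% of free memory minus the inference reserve. A standalone sketch with illustrative numbers (the helper name and values below are hypothetical, mirroring the expression in the diff):

    def would_store_in_fp16(model_params, free_memory_bytes, inference_reserve_bytes):
        # mirrors: get_free_memory() * 0.85 - minimum_inference_memory()
        free_model_memory = free_memory_bytes * 0.85 - inference_reserve_bytes
        return model_params * 4 > free_model_memory

    # e.g. ~12e9 parameters vs. ~10 GiB free: fp32 weights cannot fit, so fp16 storage is chosen
    print(would_store_in_fp16(12e9, 10 * 1024**3, 1024**3))  # True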
@@ -1080,7 +1063,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     bf16_works = torch.cuda.is_bf16_supported()
 
     if bf16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = (get_free_memory() * 0.85 - minimum_inference_memory())
        if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
 
@@ -1116,43 +1099,3 @@ def soft_empty_cache(force=False):
 
 def unload_all_models():
     free_memory(1e30, get_torch_device())
-
-
-def resolve_lowvram_weight(weight, model, key):  # TODO: remove
-    return weight
-
-
-# TODO: might be cleaner to put this somewhere else
-import threading
-
-
-class InterruptProcessingException(Exception):
-    pass
-
-
-interrupt_processing_mutex = threading.RLock()
-
-interrupt_processing = False
-
-
-def interrupt_current_processing(value=True):
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        interrupt_processing = value
-
-
-def processing_interrupted():
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        return interrupt_processing
-
-
-def throw_exception_if_processing_interrupted():
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        if interrupt_processing:
-            interrupt_processing = False
-            raise InterruptProcessingException()