add flux support #356
Conversation
EDIT 2: I was just being stupid, nevermind |
Which model did you use and how much VRAM does your GPU have? Could be a memory issue since these models are pretty large. |
I was trying to run it on CPU. With 32 GB of RAM + lots of swap |
Anyways, I figured out I was just on the wrong commit for the ggml submodule, checking out the correct one fixed the compilation and now it works! |
It's almost twice as fast as ComfyUI's implementation of GGUF support for Flux. On my Ryzen 9 5900X (no GPU) with the q4_1 model (512² resolution):
Great job @leejet ! |
Looks like conversion on CPU with 32 gigs of RAM + swap is not enough.
update: added a 32gig swapfile, conversion works now |
Worked with CPU-only build, crashed with core dump when built with cuda, probably due to low 4GB VRAM on my GTX GPU. Thanks for the nice work! |
@leejet it would be nice if sd.cpp supported llama.cpp tensor naming conventions. Since text encoders have exploded in size and now consume substantial amounts of resources, making use of ggml quantizations would be very useful. So I went and tried loading the q8_0 t5xxl from here https://huggingface.co/city96/t5-v1_1-xxl-encoder-gguf/tree/main but it does not load. Looking at the log, it quickly becomes obvious why:
I think city96's conversion is using llama.cpp's tensor name convention. |
I agree being able to use llama.cpp quants could be great, though you can always quantize the t5 encoder with stable-diffusion.cpp yourself and get a working gguf. |
Something seems to be really wrong with flux rendering on the stable-diffusion.cpp backend. With q4_0 quantization I run out of VRAM and the program crashes. I have 8gb of VRAM, and on ComfyUI I can render resolutions of 1152x896 without problems at 6s/it, without crashes. I can even use the q5_0 quantized flux model without crashing. The weird thing is that with sd models, including sdxl, everything renders fine and fast with memory efficiency on stable-diffusion.cpp, so I really wonder why flux acts this way on this backend. Another thing I have noticed is that when the nvidia driver tries to send a part of the model to shared vram, it looks like the clip models get unloaded from ram, causing the program to just hang at the sampling stage. |
You can perform the quantization yourself. |
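For example, something like the following should work (a sketch assuming sd.cpp's convert mode with the -M/-m/-o/--type flags; adjust the paths to your setup):
sd.exe -M convert -m ./models/t5xxl_fp16.safetensors -o ./models/t5xxl_q8_0.gguf --type q8_0 -v
The resulting gguf keeps sd.cpp's own tensor naming, so it loads without the remapping issue above.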
You are right, I tried it the wrong way first. I can not spot a difference to the f16 t5xxl, so I recommend this over f16 in any case. However, it does look like it is not using less memory.
edit: using bit still acceptable. |
I wanted to test a fine tune and merge(?) of flux.1 schnell and dev (??), but it contained f8_e4m3 tensors; will make a pr in a bit. edit: pr here #359 |
In the same way it is done for bf16: just like bf16 converts losslessly to fp32, f8_e4m3 converts losslessly to fp16, since e4m3's exponent range and 3-bit mantissa both fit inside fp16, so every finite e4m3 value has an exact fp16 representation.
I tried uploading some quantized models to Hugging Face, but no matter which network I use, the upload speed is limited to 3 Mbps. |
LoRA support has been added! |
Yea I also had issues with uploads being canceled all the time... no idea why. |
I also uploaded a f16 conversion of the vae, it looks almost lossless to me. |
Even a q2_k vae looks good enough. |
Does this also work with AMD ROCm? |
If you look at the file sizes, it blocks anything lower than f16, so you are looking at f16. |
Not sure if anyone tried yet, but you can grab a build from here https://github.com/leejet/stable-diffusion.cpp/releases/tag/master-64d231f (if you run windows) |
Thank you very much for uploading flux schnell gguf. Could you upload clip_l.safetensors or clip_l.gguf for this model, please? |
Sure, I uploaded a gguf. If you want the safetensors, check the OP for a link. edit: I am not seeing much of a difference between |
Thanks. Tested it with these command-line parameters: sd.exe --diffusion-model ./models/flux1-schnell-q2_k.gguf --vae ./models/ae-f16.gguf --clip_l ./models/clip_l-f16.gguf --t5xxl ./models/t5xxl_q2_k.gguf -p "a lovely cat holding a sign says 'flux.cpp'" -t 8 --steps 4 --cfg-scale 1.0 --sampling-method euler -v My system configuration: Ryzen 7 4700U, iGPU Vega 7, 16gb RAM, SSD, Windows 11. Image generation took 520 seconds or so; each step took 110 seconds. I hope that kobold.cpp will upgrade its stable-diffusion plugin to support flux, because kobold.cpp uses Vulkan acceleration, which makes generation much faster. Are there any plans to add a Vulkan build to the next releases of stable-diffusion.cpp? By the way, I recently downloaded the Amuse windows app (https://www.amuse-ai.com/) and it generates images very fast because it uses DirectML acceleration, ONNX and SD-Turbo technologies. A 512x512, 4-step image generation takes only 7 seconds! I'm very sad that there is no DirectML acceleration in stable-diffusion.cpp and llama.cpp. Another thing which makes me cry is the fact that the flux ONNX model can't be quantized to fit into my 16gb RAM. Or do I not know something and it can be done? |
If you want Vulkan support, take a look at the discussion here: #291 |
Interestingly, the file sizes are very close, but still slightly different.
There's also some very slight artifacting (a bit like jpeg) with the q2_k autoencoder that isn't noticeable with the other quants I tested (q8 and f16). I'm not sure if saving only a few kilobytes is worth a barely noticeable difference in output; that's a strange dilemma. Quantized is definitely worth it compared to full size though. |
Only a very small number of tensors of ae will be quantized. |
Would it be possible to package and use the flux unet, clip, ae, etc into a single file like with the SD models? |
Figured I'd chime in. I've been doing some work over at ComfyUI-GGUF to support flux quantization for image gen. I've noticed some differences between my version and the version by @Green-Sky higher up in the thread. The most obvious thing is that the bias weights are quantized. These can be kept in FP32 without adding more than at most 40MB to the final model file, and doing this should increase both quality and speed (since fewer tensors have to be dequantized overall, though this should be relatively fast on small tensors like that). The second issue I noticed is that there's no logic for keeping more vital tensors in higher precision, the way llama.cpp does with LLMs. From my short tests, these benefit the most from doing so while only adding ~100MB: For the text encoder, I've used the default llama.cpp binary to create them, as both the full encoder/decoder and the encoder-only model are supported natively now. Assuming your code can handle mixed quantization, I recommend using this method, since keeping the token_embed and the norms/biases in higher precision makes the effects of quantization a lot less severe. Mapping the keys back to the original names is fairly straightforward. This is the mapping I ended up with for the replacement:
clip_sd_map = {
"enc.": "encoder.",
".blk.": ".block.",
"token_embd": "shared",
"output_norm": "final_layer_norm",
"attn_q": "layer.0.SelfAttention.q",
"attn_k": "layer.0.SelfAttention.k",
"attn_v": "layer.0.SelfAttention.v",
"attn_o": "layer.0.SelfAttention.o",
"attn_norm": "layer.0.layer_norm",
"attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
"ffn_up": "layer.1.DenseReluDense.wi_1",
"ffn_down": "layer.1.DenseReluDense.wo",
"ffn_gate": "layer.1.DenseReluDense.wi_0",
"ffn_norm": "layer.1.layer_norm",
}
new_state_dict = {}  # illustrative; the original snippet was truncated here ("...")
for k, v in state_dict.items():
    for s, d in clip_sd_map.items():
        k = k.replace(s, d)  # map llama.cpp T5 names back to the original names
    new_state_dict[k] = v

Hope this helps! |
In my tests, not converting the bias didn't make anything better. Moreover, if I convert the |
Like for sd3, there exists a tiny autoencoder (taesd) that does not work with sd.cpp yet. edit: this is not a priority, since vae speed has improved since taesd was first implemented in sd.cpp, and the vae uses little compute compared to flux diffusion anyway. |
flux.1-schnell 1024x1024 4step using the new also using quants for:
(also i hate comic sans 🙈 ) |
I am seeing "unknown tensor " using 58d5473
I am using t5xxl from https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors |
.\bin\Release\sd.exe |
Hello, I want the quantized versions of T5 and CLIP to also use video memory. Is there any way to achieve this? |
Yeah, sure. Just remove those lines and compile it again. |
Ah, I thought it would work; I guess I was wrong. Glancing at the code of the CUDA backend, it looks like the GET_ROWS operation isn't supported for k-quants? Do you have enough VRAM to test with a q4_0 quant instead? |
Thanks. I've tried it all, and many types of quant have the same error. |
Although the architecture is similar to sd3, flux actually has a lot of additional things to implement, so adding flux support took me a bit longer. After merging this pr, I will take some time to merge the PRs of other contributors.
How to Use
Download weights
Convert flux weights
Using fp16 will lead to overflow, but ggml's support for bf16 is not yet fully developed. Therefore, we need to convert flux to gguf format here, which also saves VRAM. For example:
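A sketch of the conversion step (file paths here are illustrative, and this assumes the convert mode with the -M/-m/-o/--type flags):
sd.exe -M convert -m ./models/flux1-dev.safetensors -o ./models/flux1-dev-q8_0.gguf --type q8_0 -v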
Run
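A sketch of a typical invocation, based on the flags used elsewhere in this thread (--diffusion-model, --vae, --clip_l, --t5xxl); file names are illustrative:
sd.exe --diffusion-model ./models/flux1-dev-q8_0.gguf --vae ./models/ae.safetensors --clip_l ./models/clip_l.safetensors --t5xxl ./models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v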
--cfg-scale is recommended to be set to 1.
Flux-dev q8_0
Flux-dev q4_0
Flux-dev q3_k
Flux-dev q2_k
Flux-schnell q8_0
Run with LoRA
Since many flux LoRA training libraries have used various LoRA naming formats, it is possible that not all flux LoRA naming formats are supported. It is recommended to use LoRA with naming formats compatible with ComfyUI.
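A sketch of a LoRA invocation, assuming the --lora-model-dir flag and the <lora:name:strength> prompt syntax from the sd.cpp README; the LoRA file name is a placeholder:
sd.exe --diffusion-model ./models/flux1-dev-q8_0.gguf --vae ./models/ae.safetensors --clip_l ./models/clip_l.safetensors --t5xxl ./models/t5xxl_fp16.safetensors --lora-model-dir ./models -p "a lovely cat holding a sign says 'flux.cpp' <lora:my_flux_lora:1>" --cfg-scale 1.0 --sampling-method euler -v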
Flux dev q8_0 with LoRA