Quite slow to load the fp8 model

#21
by gpt3eth - opened

On an NVIDIA A6000, I'm using the code below to load the fp8 model:

import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize
import time
import json

# Initialize a dictionary to store stats
stats = {}

# Measure the time taken to load and prepare the model into VRAM
start_time = time.time()
bfl_repo = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16

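# Load the fp8 single-file checkpoint; weights are cast to `dtype` (bfloat16) on load.
# Requires a diffusers version that provides FluxTransformer2DModel.from_single_file.
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors",
    torch_dtype=dtype
)
# Quantize the transformer weights to qfloat8 with optimum-quanto, then freeze them
quantize(transformer, weights=qfloat8)
freeze(transformer)

# Load the T5 text encoder in bfloat16 and quantize/freeze it the same way
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Load the remaining pipeline components, then attach the quantized modules
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.to("cuda")
stats['model_loading_time'] = time.time() - start_time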

It took around 293 s to load the model. Why is it so slow to load?
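
To find out which step accounts for the 293 s, one option is to time each stage separately instead of timing the whole block at once. Below is a minimal sketch using only the standard library; the `timed` helper and the stage names are hypothetical, and the `time.sleep` calls in the usage example are stand-ins for the real loading, quantizing and freezing calls.

```python
import time
from contextlib import contextmanager


# Hypothetical helper: records the wall-clock duration of a named stage in `stats`
@contextmanager
def timed(stats, name):
    start = time.perf_counter()
    try:
        yield
    finally:
        # Store the elapsed seconds even if the stage raised an exception
        stats[name] = time.perf_counter() - start


if __name__ == "__main__":
    stats = {}
    # Stand-ins for the real stages; replace each body with e.g.
    # FluxTransformer2DModel.from_single_file(...) or quantize(...) + freeze(...)
    with timed(stats, "load_transformer"):
        time.sleep(0.1)
    with timed(stats, "quantize_and_freeze"):
        time.sleep(0.2)
    # Print the slowest stage first
    for name, seconds in sorted(stats.items(), key=lambda kv: -kv[1]):
        print(f"{name}: {seconds:.2f} s")
```

For stages that launch GPU work asynchronously, calling torch.cuda.synchronize() at the end of the stage body makes the measured time more reliable.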

I'm getting this error when running the code:

transformer = FluxTransformer2DModel.from_single_file(...

AttributeError: type object 'FluxTransformer2DModel' has no attribute 'from_single_file'

Which version of diffusers do you have installed?

Thanks!

Thanks. I think they only just released .from_single_file.

I could make it work by installing the latest version from main (I basically just followed all the steps listed at https://huggingface.co/docs/diffusers/installation#install-from-source).
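
Before running the full script, it can help to confirm which diffusers version is installed and whether FluxTransformer2DModel actually exposes from_single_file. A minimal sketch using only importlib from the standard library; the helper names installed_version and has_single_file_loader are hypothetical.

```python
import importlib
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def has_single_file_loader(module_name="diffusers", class_name="FluxTransformer2DModel"):
    """True if module_name.class_name exists and has a `from_single_file` attribute."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Package not installed at all
        return False
    cls = getattr(module, class_name, None)
    return cls is not None and hasattr(cls, "from_single_file")


if __name__ == "__main__":
    print("diffusers version:", installed_version("diffusers"))
    print("from_single_file available:", has_single_file_loader())
```

If the second line prints False with diffusers installed, that matches the AttributeError above and suggests the installed release predates single-file support for Flux.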

885 s on a simple laptop with 16 GB RAM (no GPU support, since it's Windows with an AMD GPU).
The quantizing and freezing steps take a lot of time.

Why have you quantized it again when it is already in fp8?

You could probably just import it like so:

transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors", torch_dtype=torch.bfloat16)
# text_encoder_2=None skips loading T5, so an encoder still has to be attached afterwards, as in the original post
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder_2=None, torch_dtype=torch.bfloat16)

What dtype though?

torch.bfloat16

But it's not 16-bit, it's 8-bit. I think what you need is load_in_8bit=True:

pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder_2=None, load_in_8bit=True)
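
The dtype question comes down to bytes per weight. Assuming from_single_file casts the loaded weights to the requested torch_dtype, an fp8 checkpoint loaded with torch.bfloat16 occupies 2 bytes per weight instead of 1, which is one reason to quantize back to qfloat8 afterwards. The back-of-the-envelope sketch below assumes roughly 12 billion transformer parameters and counts weights only (no activations, text encoders or framework overhead); NUM_PARAMS and weight_memory_gb are illustrative names.

```python
# Assumption: the FLUX.1 transformer has roughly 12 billion parameters
NUM_PARAMS = 12e9

# Storage size per element for each dtype, in bytes
BYTES_PER_ELEMENT = {
    "float32": 4,
    "bfloat16": 2,
    "float8": 1,  # 8-bit float formats such as float8_e4m3fn
}


def weight_memory_gb(num_params, dtype_name):
    """Weight-only memory in GB (1 GB = 1e9 bytes); ignores activations and overhead."""
    return num_params * BYTES_PER_ELEMENT[dtype_name] / 1e9


if __name__ == "__main__":
    for name in ("bfloat16", "float8"):
        print(f"{name}: ~{weight_memory_gb(NUM_PARAMS, name):.0f} GB")
    # bfloat16: ~24 GB
    # float8: ~12 GB
```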

The Diffusers documentation suggests importing it and then using optimum-quanto to quantize the weights to qfloat8 and freeze the transformer:
https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#single-file-loading-for-the-fluxtransformer2dmodel

I am experiencing the same issue; it would be really nice if anyone could help. I'm using the example provided in the official docs: https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#single-file-loading-for-the-fluxtransformer2dmodel
