|
import json |
|
from pathlib import Path |
|
from typing import Literal, Optional |
|
|
|
import torch |
|
from modules.autoencoder import AutoEncoder, AutoEncoderParams |
|
from modules.conditioner import HFEmbedder |
|
from modules.flux_model import Flux, FluxParams |
|
from modules.flux_model_f8 import Flux as FluxF8 |
|
from safetensors.torch import load_file as load_sft |
|
from enum import StrEnum |
|
from pydantic import BaseModel, ConfigDict |
|
from loguru import logger |
|
|
|
|
|
class ModelVersion(StrEnum):
    """FLUX model variants understood by this module.

    Each member's value is the variant's identifier string. ``load_config``
    branches on ``flux_dev`` and treats every other value as schnell.
    """

    flux_dev = "flux-dev"
    flux_schnell = "flux-schnell"
|
|
|
|
|
class QuantizationDtype(StrEnum):
    """Quantization formats that a ModelSpec can request per component.

    NOTE(review): going by the names, these are presumably float8 and
    2/4/8-bit integer formats. Confirm against the quantization backend that
    consumes them.
    """

    qfloat8 = "qfloat8"
    qint2 = "qint2"
    qint4 = "qint4"
    qint8 = "qint8"
|
|
|
|
|
class ModelSpec(BaseModel):
    """Complete description of which FLUX models to load and how.

    Built either by ``load_config`` or from a JSON file via
    ``load_config_from_path``. Fields annotated ``str | None`` without a
    default are required keys (their value may be ``None``).
    """

    # Which FLUX variant this spec describes.
    version: ModelVersion
    # Architecture hyperparameters for the flow transformer.
    params: FluxParams
    # Architecture hyperparameters for the autoencoder.
    ae_params: AutoEncoderParams
    # Safetensors checkpoint for the flow model; None skips weight loading.
    ckpt_path: str | None
    # Safetensors checkpoint for the autoencoder; None skips weight loading.
    ae_path: str | None
    # Hub repository id and file names. Not read by the loaders in this
    # file; presumably used for downloading elsewhere (confirm).
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None
    # Max token length passed to the T5 text encoder.
    text_enc_max_length: int = 512
    # Model id/path for the T5 text encoder.
    text_enc_path: str | None
    # Target devices per component. Strings/devices are accepted; the loaders
    # normalize them with into_device.
    text_enc_device: str | torch.device | None = "cuda:0"
    ae_device: str | torch.device | None = "cuda:0"
    flux_device: str | torch.device | None = "cuda:0"
    # dtype names; must be one accepted by into_dtype
    # ("float16", "bfloat16", "float32").
    flow_dtype: str = "float16"
    ae_dtype: str = "bfloat16"
    text_enc_dtype: str = "bfloat16"

    # The following options are not consumed by the loaders in this file;
    # NOTE(review): their exact semantics live in the pipeline code — confirm.
    num_to_quant: Optional[int] = 20
    quantize_extras: bool = False
    compile_extras: bool = False
    compile_blocks: bool = False
    flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
    # Passed to the T5 HFEmbedder by load_text_encoders.
    text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
    # When not None, load_autoencoder swaps the AE's linear layers via
    # float8_quantize.recursive_swap_linears.
    ae_quantization_dtype: Optional[QuantizationDtype] = None
    # Passed to the CLIP HFEmbedder by load_text_encoders.
    clip_quantization_dtype: Optional[QuantizationDtype] = None
    # Only offload_vae is acted on in this file (load_autoencoder moves the
    # AE to CPU); the other offload flags are read elsewhere.
    offload_text_encoder: bool = False
    offload_vae: bool = False
    offload_flow: bool = False
    # True -> load_flow_model builds FluxF8 and skips the post-load dtype cast.
    prequantized_flow: bool = False

    # arbitrary_types_allowed: permits torch.device and the project param
    # classes as field types. use_enum_values: enum fields store their plain
    # string value rather than the enum member.
    model_config: ConfigDict = {
        "arbitrary_types_allowed": True,
        "use_enum_values": True,
    }
|
|
|
|
|
def load_models(
    config: ModelSpec,
) -> tuple[Flux | FluxF8, AutoEncoder, HFEmbedder, HFEmbedder]:
    """Load every model described by *config*.

    Models are loaded in this order: flow transformer, autoencoder, then the
    CLIP and T5 text encoders.

    Returns:
        ``(flow, ae, clip, t5)``. ``flow`` is a ``FluxF8`` instead of a
        ``Flux`` when ``config.prequantized_flow`` is set.
    """
    flow = load_flow_model(config)
    ae = load_autoencoder(config)
    clip, t5 = load_text_encoders(config)
    return flow, ae, clip, t5
|
|
|
|
|
def parse_device(device: str | torch.device | None) -> torch.device: |
|
if isinstance(device, str): |
|
return torch.device(device) |
|
elif isinstance(device, torch.device): |
|
return device |
|
else: |
|
return torch.device("cuda:0") |
|
|
|
|
|
def into_dtype(dtype: str) -> torch.dtype: |
|
if dtype == "float16": |
|
return torch.float16 |
|
elif dtype == "bfloat16": |
|
return torch.bfloat16 |
|
elif dtype == "float32": |
|
return torch.float32 |
|
else: |
|
raise ValueError(f"Invalid dtype: {dtype}") |
|
|
|
|
|
def into_device(device: str | torch.device | None) -> torch.device: |
|
if isinstance(device, str): |
|
return torch.device(device) |
|
elif isinstance(device, torch.device): |
|
return device |
|
elif isinstance(device, int): |
|
return torch.device(f"cuda:{device}") |
|
else: |
|
return torch.device("cuda:0") |
|
|
|
|
|
def load_config(
    name: ModelVersion = ModelVersion.flux_dev,
    flux_path: str | None = None,
    ae_path: str | None = None,
    text_enc_path: str | None = None,
    text_enc_device: str | torch.device | None = None,
    ae_device: str | torch.device | None = None,
    flux_device: str | torch.device | None = None,
    flow_dtype: str = "float16",
    ae_dtype: str = "bfloat16",
    text_enc_dtype: str = "bfloat16",
    num_to_quant: Optional[int] = 20,
    compile_extras: bool = False,
    compile_blocks: bool = False,
    offload_text_enc: bool = False,
    offload_ae: bool = False,
    offload_flow: bool = False,
    quant_text_enc: Optional[Literal["float8", "qint2", "qint4", "qint8"]] = None,
    quant_ae: bool = False,
    prequantized_flow: bool = False,
) -> ModelSpec:
    """Build a ModelSpec for the requested FLUX variant.

    The flow-transformer and autoencoder hyperparameters are fixed here. Only
    the guidance embedding depends on the variant: it is enabled for flux-dev
    only.

    Variant-dependent values: flux-dev uses the FLUX.1-dev repo/file and a
    T5 max length of 512. Any other ``name`` uses FLUX.1-schnell and 256.

    Device arguments are normalized with ``parse_device`` (``None`` ->
    ``cuda:0``) and stored as strings. ``quant_text_enc`` maps ``"float8"`` to
    ``QuantizationDtype.qfloat8`` and the ``"qint*"`` names to their enum
    members; ``None`` disables T5 quantization. ``quant_ae=True`` selects
    ``qfloat8`` for the autoencoder. ``flow_quantization_dtype`` is not set
    here, so the ModelSpec default applies.
    """
    text_enc_device = str(parse_device(text_enc_device))
    ae_device = str(parse_device(ae_device))
    flux_device = str(parse_device(flux_device))

    is_dev = name == ModelVersion.flux_dev

    flux_params = FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=is_dev,
    )
    autoencoder_params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )

    text_enc_quant_by_name = {
        "float8": QuantizationDtype.qfloat8,
        "qint2": QuantizationDtype.qint2,
        "qint4": QuantizationDtype.qint4,
        "qint8": QuantizationDtype.qint8,
    }
    text_enc_quant = text_enc_quant_by_name.get(quant_text_enc, None)

    if is_dev:
        repo_id = "black-forest-labs/FLUX.1-dev"
        repo_flow = "flux1-dev.sft"
        max_length = 512
    else:
        repo_id = "black-forest-labs/FLUX.1-schnell"
        repo_flow = "flux1-schnell.sft"
        max_length = 256

    return ModelSpec(
        version=name,
        repo_id=repo_id,
        repo_flow=repo_flow,
        repo_ae="ae.sft",
        ckpt_path=flux_path,
        params=flux_params,
        ae_path=ae_path,
        ae_params=autoencoder_params,
        text_enc_path=text_enc_path,
        text_enc_device=text_enc_device,
        ae_device=ae_device,
        flux_device=flux_device,
        flow_dtype=flow_dtype,
        ae_dtype=ae_dtype,
        text_enc_dtype=text_enc_dtype,
        text_enc_max_length=max_length,
        num_to_quant=num_to_quant,
        compile_extras=compile_extras,
        compile_blocks=compile_blocks,
        offload_flow=offload_flow,
        offload_text_encoder=offload_text_enc,
        offload_vae=offload_ae,
        text_enc_quantization_dtype=text_enc_quant,
        ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None,
        prequantized_flow=prequantized_flow,
    )
|
|
|
|
|
def load_config_from_path(path: str) -> ModelSpec:
    """Load a ModelSpec from a JSON file.

    The file must contain a JSON object whose keys are ModelSpec field names;
    it is unpacked directly into the ModelSpec constructor.

    Args:
        path: Filesystem path to the JSON config.

    Returns:
        The validated ModelSpec.

    Raises:
        ValueError: if *path* does not exist or is not a regular file, or if
            the content is not valid JSON (``json.JSONDecodeError`` subclasses
            ``ValueError``). Pydantic validation errors propagate unchanged.
    """
    path_path = Path(path)
    if not path_path.exists():
        raise ValueError(f"Path {path} does not exist")
    if not path_path.is_file():
        raise ValueError(f"Path {path} is not a file")
    # JSON text is UTF-8 (RFC 8259); without an explicit encoding read_text()
    # would use the platform's locale encoding (e.g. cp1252 on Windows).
    return ModelSpec(**json.loads(path_path.read_text(encoding="utf-8")))
|
|
|
|
|
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Log the missing and unexpected state-dict keys from a checkpoint load.

    Each non-empty list produces one warning that gives the count and lists
    the keys, one per tab-indented line. When both lists are non-empty, a
    79-dash separator is logged between the two reports. Nothing is logged
    when both are empty.
    """
    has_missing = len(missing) > 0
    has_unexpected = len(unexpected) > 0

    if has_missing:
        logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if has_missing and has_unexpected:
        logger.warning("\n" + "-" * 79 + "\n")
    if has_unexpected:
        logger.warning(
            f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
        )
|
|
|
|
|
def load_flow_model(config: ModelSpec) -> Flux | FluxF8:
    """Construct the flow transformer and, if configured, load its weights.

    The model class is ``FluxF8`` when ``config.prequantized_flow`` is set,
    otherwise ``Flux``. It is instantiated on the meta device in
    ``config.flow_dtype``. If ``config.ckpt_path`` is set, the safetensors
    checkpoint is read onto CPU and assigned into the model
    (``strict=False``). Key mismatches are logged. For non-prequantized
    models, the weights are then cast to ``config.flow_dtype``.

    NOTE(review): with no checkpoint, the returned model is still on the
    meta device. The model is not moved to ``config.flux_device`` here;
    presumably the caller does that — confirm.
    """
    model_cls = FluxF8 if config.prequantized_flow else Flux
    target_dtype = into_dtype(config.flow_dtype)

    # Building under the meta device allocates no real storage; real tensors
    # are attached below by load_state_dict(..., assign=True).
    with torch.device("meta"):
        model = model_cls(config.params, dtype=target_dtype).type(target_dtype)

    checkpoint = config.ckpt_path
    if checkpoint is None:
        return model

    # The checkpoint is loaded onto CPU, with the device given as a plain string.
    state_dict = load_sft(checkpoint, device="cpu")
    # strict=False: mismatched keys are reported instead of raising, and any
    # parameters listed as missing keep their meta-device placeholders.
    missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
    print_load_warning(missing, unexpected)

    # assign=True adopts the checkpoint tensors' dtypes, so cast back unless
    # the weights are prequantized.
    if not config.prequantized_flow:
        model.type(target_dtype)
    return model
|
|
|
|
|
def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
    """Instantiate the CLIP and T5 text encoders described by *config*.

    CLIP is always ``openai/clip-vit-large-patch14`` with max length 77. T5
    comes from ``config.text_enc_path`` with ``config.text_enc_max_length``.
    Both share ``config.text_enc_dtype`` and the device index derived from
    ``config.text_enc_device``. Each encoder gets its own quantization dtype.

    Returns:
        ``(clip, t5)``.
    """
    encoder_dtype = into_dtype(config.text_enc_dtype)
    # HFEmbedder receives a device *index*. A device without an index maps
    # to 0. NOTE(review): this includes "cpu", which also becomes index 0;
    # confirm how HFEmbedder interprets that.
    encoder_device_index = into_device(config.text_enc_device).index or 0

    clip = HFEmbedder(
        "openai/clip-vit-large-patch14",
        max_length=77,
        torch_dtype=encoder_dtype,
        device=encoder_device_index,
        quantization_dtype=config.clip_quantization_dtype,
    )
    t5 = HFEmbedder(
        config.text_enc_path,
        max_length=config.text_enc_max_length,
        torch_dtype=encoder_dtype,
        device=encoder_device_index,
        quantization_dtype=config.text_enc_quantization_dtype,
    )
    return clip, t5
|
|
|
|
|
def load_autoencoder(config: ModelSpec) -> AutoEncoder:
    """Construct the autoencoder and, if configured, load and quantize it.

    With ``config.ae_path`` set, the model is built on the meta device. The
    safetensors checkpoint is then loaded directly onto the target device and
    assigned into the model (``strict=False``, mismatches logged). The model
    is moved to the target device/dtype, and its linear layers are swapped via
    ``float8_quantize.recursive_swap_linears`` when
    ``config.ae_quantization_dtype`` is set. Without a checkpoint, the model
    is built directly on the target device.

    ``config.ae_device`` is normalized with ``into_device``, so ``None``
    resolves to ``cuda:0`` instead of failing. If ``config.offload_vae`` is
    set, the model is moved to CPU and the CUDA cache is emptied before
    returning.
    """
    ckpt_path = config.ae_path
    target_device = into_device(config.ae_device)
    target_dtype = into_dtype(config.ae_dtype)

    # Meta-device construction avoids allocating weights that the checkpoint
    # is about to replace.
    init_device = torch.device("meta") if ckpt_path is not None else target_device
    with init_device:
        ae = AutoEncoder(config.ae_params).to(target_dtype)

    if ckpt_path is not None:
        # The device is passed to load_sft as a string, e.g. "cuda:0".
        sd = load_sft(ckpt_path, device=str(target_device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
        ae.to(device=target_device, dtype=target_dtype)
        if config.ae_quantization_dtype is not None:
            # Imported lazily so the dependency is only needed when AE
            # quantization is actually requested.
            from float8_quantize import recursive_swap_linears

            recursive_swap_linears(ae)

    if config.offload_vae:
        ae.to("cpu")
        torch.cuda.empty_cache()
    return ae
|
|
|
|
|
class LoadedModels(BaseModel):
    """Bundle of every model instantiated from one ModelSpec, plus that spec."""

    # Flow transformer (Flux, or its FluxF8 variant).
    flow: Flux | FluxF8
    # Image autoencoder.
    ae: AutoEncoder
    # CLIP text encoder.
    clip: HFEmbedder
    # T5 text encoder.
    t5: HFEmbedder
    # The spec the models were loaded from.
    config: ModelSpec

    # Model and encoder classes are not pydantic types, so arbitrary types
    # must be allowed.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        use_enum_values=True,
    )
|
|
|
|
|
def load_models_from_config_path(
    path: str,
) -> LoadedModels:
    """Read a JSON ModelSpec from *path* and load every model it describes.

    This is equivalent to ``load_models_from_config(load_config_from_path(path))``:
    the text encoders are loaded first, then the flow model, then the
    autoencoder.
    """
    spec = load_config_from_path(path)
    return load_models_from_config(spec)
|
|
|
|
|
def load_models_from_config(config: ModelSpec) -> LoadedModels:
    """Load every model described by *config* into a LoadedModels bundle.

    Load order: the CLIP and T5 text encoders, then the flow transformer,
    then the autoencoder.
    """
    clip_encoder, t5_encoder = load_text_encoders(config)
    flow_model = load_flow_model(config)
    autoencoder = load_autoencoder(config)
    return LoadedModels(
        flow=flow_model,
        ae=autoencoder,
        clip=clip_encoder,
        t5=t5_encoder,
        config=config,
    )
|
|