|
"""FluxPipeline: text-to-image and image-to-image inference for FLUX models.

Wires together the CLIP and T5 text encoders, the Flux flow transformer and the
autoencoder, with optional per-component CPU offload, torch.compile warmup of the
transformer blocks, and float8 quantization of the flow model.
"""

import io
import math
from typing import TYPE_CHECKING, Callable, List
from PIL import Image
import numpy as np

import torch

from einops import rearrange
from flux_emphasis import get_weighted_text_embeddings_flux

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 20
torch.set_float32_matmul_precision("high")
from torch._dynamo import config
from torch._inductor import config as ind_config
from pybase64 import standard_b64decode

config.cache_size_limit = 10_000_000_000
ind_config.shape_padding = True
from loguru import logger
from image_encoder import ImageEncoder
from torchvision.transforms import functional as TF
from tqdm import tqdm
from util import (
    ModelSpec,
    into_device,
    into_dtype,
    load_config_from_path,
    load_models_from_config,
)


if TYPE_CHECKING:
    from modules.conditioner import HFEmbedder
    from modules.flux_model import Flux
    from modules.autoencoder import AutoEncoder


class FluxPipeline:
    def __init__(
        self,
        name: str,
        offload: bool = False,
        clip: "HFEmbedder" = None,
        t5: "HFEmbedder" = None,
        model: "Flux" = None,
        ae: "AutoEncoder" = None,
        dtype: torch.dtype = torch.float16,
        verbose: bool = False,
        flux_device: torch.device | str = "cuda:0",
        ae_device: torch.device | str = "cuda:1",
        clip_device: torch.device | str = "cuda:1",
        t5_device: torch.device | str = "cuda:1",
        config: ModelSpec = None,
    ):
        self.name = name
        self.device_flux = (
            flux_device
            if isinstance(flux_device, torch.device)
            else torch.device(flux_device)
        )
        self.device_ae = (
            ae_device
            if isinstance(ae_device, torch.device)
            else torch.device(ae_device)
        )
        self.device_clip = (
            clip_device
            if isinstance(clip_device, torch.device)
            else torch.device(clip_device)
        )
        self.device_t5 = (
            t5_device
            if isinstance(t5_device, torch.device)
            else torch.device(t5_device)
        )
        self.dtype = dtype
        self.offload = offload
        self.clip: "HFEmbedder" = clip
        self.t5: "HFEmbedder" = t5
        self.model: "Flux" = model
        self.ae: "AutoEncoder" = ae
        self.rng = torch.Generator(device="cpu")
        self.img_encoder = ImageEncoder()
        self.verbose = verbose
        self.ae_dtype = torch.bfloat16
        self.config = config
        self.offload_text_encoder = config.offload_text_encoder
        self.offload_vae = config.offload_vae
        self.offload_flow = config.offload_flow
        # Keep each component resident on its device unless offload was requested.
        if not self.offload_flow:
            self.model.to(self.device_flux)
        if not self.offload_vae:
            self.ae.to(self.device_ae)
        if not self.offload_text_encoder:
            self.clip.to(self.device_clip)
            self.t5.to(self.device_t5)

        if self.config.compile_blocks or self.config.compile_extras:
            if not self.config.prequantized_flow:
                logger.info("Running warmup generation before compiling...")
                warmup_dict = dict(
                    prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
                    height=1024,
                    width=1024,
                    num_steps=30,
                    guidance=3.5,
                    seed=10,
                )
                self.generate(**warmup_dict)
            to_gpu_extras = [
                "vector_in",
                "img_in",
                "txt_in",
                "time_in",
                "guidance_in",
                "final_layer",
                "pe_embedder",
            ]
            if self.config.compile_blocks:
                for block in self.model.double_blocks:
                    block.compile()
                for block in self.model.single_blocks:
                    block.compile()
            if self.config.compile_extras:
                for extra in to_gpu_extras:
                    getattr(self.model, extra).compile()

    @torch.inference_mode()
    def prepare(
        self,
        img: torch.Tensor,
        prompt: str | list[str],
        target_device: torch.device = torch.device("cuda:0"),
        target_dtype: torch.dtype = torch.float16,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        bs, c, h, w = img.shape
        if bs == 1 and not isinstance(prompt, str):
            bs = len(prompt)
        # Pack the latent image into a sequence of 2x2 patches
        # (equivalent to rearrange "b c (h ph) (w pw) -> b (h w) (c ph pw)").
        img = img.unfold(2, 2, 2).unfold(3, 2, 2).permute(0, 2, 3, 1, 4, 5)
        img = img.reshape(img.shape[0], -1, img.shape[3] * img.shape[4] * img.shape[5])
        assert img.shape == (
            bs,
            (h // 2) * (w // 2),
            c * 2 * 2,
        ), f"{img.shape} != {(bs, (h//2)*(w//2), c*2*2)}"
        if img.shape[0] == 1 and bs > 1:
            img = img.repeat_interleave(bs, dim=0)

        img_ids = torch.zeros(
            h // 2, w // 2, 3, device=target_device, dtype=target_dtype
        )
        img_ids[..., 1] = (
            img_ids[..., 1]
            + torch.arange(h // 2, device=target_device, dtype=target_dtype)[:, None]
        )
        img_ids[..., 2] = (
            img_ids[..., 2]
            + torch.arange(w // 2, device=target_device, dtype=target_dtype)[None, :]
        )

        img_ids = img_ids[None].repeat(bs, 1, 1, 1).flatten(1, 2)
        if self.offload_text_encoder:
            self.clip.cuda(self.device_clip)
            self.t5.cuda(self.device_t5)
        vec, txt, txt_ids = get_weighted_text_embeddings_flux(
            self,
            prompt,
            num_images_per_prompt=bs,
            device=self.device_clip,
            target_device=target_device,
            target_dtype=target_dtype,
        )
        if self.offload_text_encoder:
            self.clip.to("cpu")
            self.t5.to("cpu")
            torch.cuda.empty_cache()
        return img, img_ids, vec, txt, txt_ids

    @torch.inference_mode()
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        # Warp a linear schedule t in (0, 1] toward higher noise levels as mu grows.
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def get_lin_function(
        self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
    ) -> Callable[[float], float]:
        # Linear map from image sequence length to the schedule shift parameter mu.
        m = (y2 - y1) / (x2 - x1)
        b = y1 - m * x1
        return lambda x: m * x + b

    @torch.inference_mode()
    def get_schedule(
        self,
        num_steps: int,
        image_seq_len: int,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        shift: bool = True,
    ) -> list[float]:
        # num_steps + 1 boundaries from t=1 (pure noise) down to t=0 (clean latents).
        timesteps = torch.linspace(1, 0, num_steps + 1)

        if shift:
            # Estimate mu from the image sequence length and shift the schedule,
            # so larger images spend more of the trajectory at high noise levels.
            mu = self.get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
            timesteps = self.time_shift(mu, 1.0, timesteps)

        return timesteps.tolist()

    @torch.inference_mode()
    def get_noise(
        self,
        num_samples: int,
        height: int,
        width: int,
        generator: torch.Generator,
        dtype=None,
        device=None,
    ):
        if device is None:
            device = self.device_flux
        if dtype is None:
            dtype = self.dtype
        return torch.randn(
            num_samples,
            16,
            # latent height/width are image size / 8, rounded so the 2x2 packing fits
            2 * math.ceil(height / 16),
            2 * math.ceil(width / 16),
            device=device,
            dtype=dtype,
            generator=generator,
            requires_grad=False,
        )

    @torch.inference_mode()
    def into_bytes(self, x: torch.Tensor) -> io.BytesIO:
        torch.cuda.synchronize()
        x = x.contiguous()
        x = x.clamp(-1, 1)
        num_images = x.shape[0]
        images: List[torch.Tensor] = []
        for i in range(num_images):
            # Map from [-1, 1] to uint8 [0, 255] without overwriting the batch tensor.
            img = x[i].add(1.0).mul(127.5).clamp(0, 255).contiguous().type(torch.uint8)
            images.append(img)
        if len(images) == 1:
            im = images[0]
        else:
            im = torch.vstack(images)

        torch.cuda.synchronize()
        im = self.img_encoder.encode_torch(im, quality=99)
        images.clear()
        return io.BytesIO(im)

    @torch.inference_mode()
    def vae_decode(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        if self.offload_vae:
            self.ae.to(self.device_ae)
        x = x.to(self.device_ae)
        x = self.unpack(x.float(), height, width)
        with torch.autocast(
            device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False
        ):
            x = self.ae.decode(x)
        if self.offload_vae:
            self.ae.to("cpu")
            torch.cuda.empty_cache()
        return x

    def unpack(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        # Inverse of the 2x2 patch packing done in prepare().
        return rearrange(
            x,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=math.ceil(height / 16),
            w=math.ceil(width / 16),
            ph=2,
            pw=2,
        )

    @torch.inference_mode()
    def resize_center_crop(
        self, img: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        img = TF.resize(img, min(width, height))
        img = TF.center_crop(img, (height, width))
        return img

    @torch.inference_mode()
    def preprocess_latent(
        self,
        init_image: torch.Tensor | np.ndarray = None,
        height: int = 720,
        width: int = 1024,
        num_steps: int = 20,
        strength: float = 1.0,
        generator: torch.Generator = None,
        num_images: int = 1,
    ) -> tuple[torch.Tensor, List[float]]:
        # Encode the init image (if any) into latents for image-to-image generation.
        if init_image is not None:
            if isinstance(init_image, np.ndarray):
                init_image = torch.from_numpy(init_image)

            init_image = (
                init_image.permute(2, 0, 1)
                .contiguous()
                .to(self.device_ae, dtype=self.ae_dtype)
                .div(127.5)
                .sub(1)[None, ...]
            )
            init_image = self.resize_center_crop(init_image, height, width)
            with torch.autocast(
                device_type=self.device_ae.type,
                dtype=torch.bfloat16,
                cache_enabled=False,
            ):
                if self.offload_vae:
                    self.ae.to(self.device_ae)
                init_image = (
                    self.ae.encode(init_image)
                    .to(dtype=self.dtype, device=self.device_flux)
                    .repeat(num_images, 1, 1, 1)
                )
                if self.offload_vae:
                    self.ae.to("cpu")
                    torch.cuda.empty_cache()

        x = self.get_noise(
            num_images,
            height,
            width,
            device=self.device_flux,
            dtype=self.dtype,
            generator=generator,
        )
        timesteps = self.get_schedule(
            num_steps=num_steps,
            image_seq_len=x.shape[-1] * x.shape[-2] // 4,
            shift=(self.name != "flux-schnell"),
        )
        if init_image is not None:
            # Start partway through the schedule and blend noise with the encoded
            # init image according to the denoising strength.
            t_idx = int((1 - strength) * num_steps)
            t = timesteps[t_idx]
            timesteps = timesteps[t_idx:]
            x = t * x + (1.0 - t) * init_image
        return x, timesteps

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        width: int = 720,
        height: int = 1024,
        num_steps: int = 24,
        guidance: float = 3.5,
        seed: int | None = None,
        init_image: torch.Tensor | str | None = None,
        strength: float = 1.0,
        silent: bool = False,
        num_images: int = 1,
        return_seed: bool = False,
    ) -> io.BytesIO:
        num_steps = 4 if self.name == "flux-schnell" else num_steps

        # Accept either a file path or a base64-encoded image string.
        if isinstance(init_image, str):
            try:
                init_image = Image.open(init_image)
            except Exception:
                init_image = Image.open(io.BytesIO(standard_b64decode(init_image)))
            init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)

        # Dimensions must be multiples of 16 for the latent packing.
        height = 16 * (height // 16)
        width = 16 * (width // 16)
        if isinstance(seed, str):
            seed = int(seed)
        if seed is None:
            seed = self.rng.seed()
        logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")

        generator = torch.Generator(device=self.device_flux).manual_seed(seed)
        img, timesteps = self.preprocess_latent(
            init_image=init_image,
            height=height,
            width=width,
            num_steps=num_steps,
            strength=strength,
            generator=generator,
            num_images=num_images,
        )
        img, img_ids, vec, txt, txt_ids = map(
            lambda x: x.contiguous(),
            self.prepare(
                img=img,
                prompt=prompt,
                target_device=self.device_flux,
                target_dtype=self.dtype,
            ),
        )

        guidance_vec = torch.full(
            (img.shape[0],), guidance, device=self.device_flux, dtype=self.dtype
        )
        t_vec = None
        if self.offload_flow:
            self.model.to(self.device_flux)
        for t_curr, t_prev in tqdm(
            zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1, disable=silent
        ):
            if t_vec is None:
                t_vec = torch.full(
                    (img.shape[0],),
                    t_curr,
                    dtype=self.dtype,
                    device=self.device_flux,
                )
            else:
                t_vec = t_vec.reshape((img.shape[0],)).fill_(t_curr)
            pred = self.model.forward(
                img=img,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
                guidance=guidance_vec,
            )
            # Euler step of the rectified-flow ODE from t_curr to t_prev.
            img = img + (t_prev - t_curr) * pred

        if self.offload_flow:
            self.model.to("cpu")
            torch.cuda.empty_cache()

        img = self.vae_decode(img, height, width)

        if return_seed:
            return self.into_bytes(img), seed
        return self.into_bytes(img)

    @classmethod
    def load_pipeline_from_config_path(
        cls, path: str, flow_model_path: str | None = None
    ) -> "FluxPipeline":
        with torch.inference_mode():
            config = load_config_from_path(path)
            if flow_model_path:
                config.ckpt_path = flow_model_path
            return cls.load_pipeline_from_config(config)

    @classmethod
    def load_pipeline_from_config(cls, config: ModelSpec) -> "FluxPipeline":
        from float8_quantize import quantize_flow_transformer_and_dispatch_float8

        with torch.inference_mode():
            logger.info(f"flow_quantization_dtype: {config.flow_quantization_dtype}")
            logger.info(f"prequantized_flow: {config.prequantized_flow}")

            models = load_models_from_config(config)
            config = models.config
            flux_device = into_device(config.flux_device)
            ae_device = into_device(config.ae_device)
            clip_device = into_device(config.text_enc_device)
            t5_device = into_device(config.text_enc_device)
            flux_dtype = into_dtype(config.flow_dtype)
            flow_model = models.flow

            if not config.prequantized_flow:
                # Quantize the flow transformer to float8 and dispatch it to its device.
                flow_model = quantize_flow_transformer_and_dispatch_float8(
                    flow_model, flux_device, offload_flow=config.offload_flow
                )
            else:
                flow_model.eval().requires_grad_(False)

            return cls(
                name=config.version,
                clip=models.clip,
                t5=models.t5,
                model=flow_model,
                ae=models.ae,
                dtype=flux_dtype,
                verbose=False,
                flux_device=flux_device,
                ae_device=ae_device,
                clip_device=clip_device,
                t5_device=t5_device,
                config=config,
            )
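

# Minimal usage sketch. The config path below is an assumption: point
# load_pipeline_from_config_path at whichever ModelSpec JSON your deployment
# actually uses. The output is written as .jpg on the assumption that
# ImageEncoder.encode_torch returns JPEG-encoded bytes.
if __name__ == "__main__":
    pipe = FluxPipeline.load_pipeline_from_config_path("configs/config-dev.json")
    image_bytes = pipe.generate(
        prompt="a watercolor painting of a lighthouse at dusk",
        width=1024,
        height=1024,
        num_steps=24,
        guidance=3.5,
        seed=42,
    )
    with open("output.jpg", "wb") as f:
        f.write(image_bytes.getvalue())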
|
|