File size: 16,733 Bytes
c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 2f2c44c 28dec30 2f2c44c 28dec30 c2ecfb5 28dec30 340f0a0 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 2f2c44c c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 28dec30 c2ecfb5 c4a514f 2f2c44c c2ecfb5 2f2c44c 28dec30 2f2c44c c2ecfb5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 |
import io
import math
from typing import TYPE_CHECKING, Callable, List

from PIL import Image
import numpy as np
import torch
from einops import rearrange
from flux_emphasis import get_weighted_text_embeddings_flux

# Process-wide numeric/perf settings, applied at import time:
# allow TF32 math for float32 matmuls and cuDNN ops, let cuDNN benchmark
# convolution algorithms (trying at most 20 per configuration), and use
# "high" float32 matmul precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 20
torch.set_float32_matmul_precision("high")

# NOTE: the module-level name `config` is torch._dynamo's config module; it is
# unrelated to the `config: ModelSpec` parameters used by FluxPipeline.
from torch._dynamo import config
from torch._inductor import config as ind_config
from pybase64 import standard_b64decode

# Raise dynamo's recompilation cache limit to an effectively unbounded value
# and enable inductor shape padding (both used by the compiled flow blocks).
config.cache_size_limit = 10000000000
ind_config.shape_padding = True

from loguru import logger
from image_encoder import ImageEncoder
from torchvision.transforms import functional as TF
from tqdm import tqdm
from util import (
    ModelSpec,
    into_device,
    into_dtype,
    load_config_from_path,
    load_models_from_config,
)

# Imported only for type hints, so the runtime import cost/cycles are avoided.
if TYPE_CHECKING:
    from modules.conditioner import HFEmbedder
    from modules.flux_model import Flux
    from modules.autoencoder import AutoEncoder
class FluxPipeline:
def __init__(
self,
name: str,
offload: bool = False,
clip: "HFEmbedder" = None,
t5: "HFEmbedder" = None,
model: "Flux" = None,
ae: "AutoEncoder" = None,
dtype: torch.dtype = torch.float16,
verbose: bool = False,
flux_device: torch.device | str = "cuda:0",
ae_device: torch.device | str = "cuda:1",
clip_device: torch.device | str = "cuda:1",
t5_device: torch.device | str = "cuda:1",
config: ModelSpec = None,
):
self.name = name
self.device_flux = (
flux_device
if isinstance(flux_device, torch.device)
else torch.device(flux_device)
)
self.device_ae = (
ae_device
if isinstance(ae_device, torch.device)
else torch.device(ae_device)
)
self.device_clip = (
clip_device
if isinstance(clip_device, torch.device)
else torch.device(clip_device)
)
self.device_t5 = (
t5_device
if isinstance(t5_device, torch.device)
else torch.device(t5_device)
)
self.dtype = dtype
self.offload = offload
self.clip: "HFEmbedder" = clip
self.t5: "HFEmbedder" = t5
self.model: "Flux" = model
self.ae: "AutoEncoder" = ae
self.rng = torch.Generator(device="cpu")
self.img_encoder = ImageEncoder()
self.verbose = verbose
self.ae_dtype = torch.bfloat16
self.config = config
self.offload_text_encoder = config.offload_text_encoder
self.offload_vae = config.offload_vae
self.offload_flow = config.offload_flow
if not self.offload_flow:
self.model.to(self.device_flux)
if not self.offload_vae:
self.ae.to(self.device_ae)
if not self.offload_text_encoder:
self.clip.to(self.device_clip)
self.t5.to(self.device_t5)
if self.config.compile_blocks or self.config.compile_extras:
if not self.config.prequantized_flow:
print("Warmups for compile...")
warmup_dict = dict(
prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
height=1024,
width=1024,
num_steps=30,
guidance=3.5,
seed=10,
)
self.generate(**warmup_dict)
to_gpu_extras = [
"vector_in",
"img_in",
"txt_in",
"time_in",
"guidance_in",
"final_layer",
"pe_embedder",
]
if self.config.compile_blocks:
for block in self.model.double_blocks:
block.compile()
for block in self.model.single_blocks:
block.compile()
if self.config.compile_extras:
for extra in to_gpu_extras:
getattr(self.model, extra).compile()
@torch.inference_mode()
def prepare(
self,
img: torch.Tensor,
prompt: str | list[str],
target_device: torch.device = torch.device("cuda:0"),
target_dtype: torch.dtype = torch.float16,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = img.unfold(2, 2, 2).unfold(3, 2, 2).permute(0, 2, 3, 1, 4, 5)
img = img.reshape(img.shape[0], -1, img.shape[3] * img.shape[4] * img.shape[5])
assert img.shape == (
bs,
(h // 2) * (w // 2),
c * 2 * 2,
), f"{img.shape} != {(bs, (h//2)*(w//2), c*2*2)}"
if img.shape[0] == 1 and bs > 1:
img = img[None].repeat_interleave(bs, dim=0)
img_ids = torch.zeros(
h // 2, w // 2, 3, device=target_device, dtype=target_dtype
)
img_ids[..., 1] = (
img_ids[..., 1]
+ torch.arange(h // 2, device=target_device, dtype=target_dtype)[:, None]
)
img_ids[..., 2] = (
img_ids[..., 2]
+ torch.arange(w // 2, device=target_device, dtype=target_dtype)[None, :]
)
img_ids = img_ids[None].repeat(bs, 1, 1, 1).flatten(1, 2)
if self.offload_text_encoder:
self.clip.cuda(self.device_clip)
self.t5.cuda(self.device_t5)
vec, txt, txt_ids = get_weighted_text_embeddings_flux(
self,
prompt,
num_images_per_prompt=bs,
device=self.device_clip,
target_device=target_device,
target_dtype=target_dtype,
)
if self.offload_text_encoder:
self.clip.to("cpu")
self.t5.to("cpu")
torch.cuda.empty_cache()
return img, img_ids, vec, txt, txt_ids
@torch.inference_mode()
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(
self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
@torch.inference_mode()
def get_schedule(
self,
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = self.get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = self.time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
@torch.inference_mode()
def get_noise(
self,
num_samples: int,
height: int,
width: int,
generator: torch.Generator,
dtype=None,
device=None,
):
if device is None:
device = self.device_flux
if dtype is None:
dtype = self.dtype
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=device,
dtype=dtype,
generator=generator,
requires_grad=False,
)
@torch.inference_mode()
def into_bytes(self, x: torch.Tensor) -> io.BytesIO:
# bring into PIL format and save
torch.cuda.synchronize()
x = x.contiguous()
x = x.clamp(-1, 1)
num_images = x.shape[0]
images: List[torch.Tensor] = []
for i in range(num_images):
x = x[i].add(1.0).mul(127.5).clamp(0, 255).contiguous().type(torch.uint8)
images.append(x)
if len(images) == 1:
im = images[0]
else:
im = torch.vstack(images)
torch.cuda.synchronize()
im = self.img_encoder.encode_torch(im, quality=99)
images.clear()
return io.BytesIO(im)
@torch.inference_mode()
def vae_decode(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
if self.offload_vae:
self.ae.to(self.device_ae)
x = x.to(self.device_ae)
else:
x = x.to(self.device_ae)
x = self.unpack(x.float(), height, width)
with torch.autocast(
device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False
):
x = self.ae.decode(x)
if self.offload_vae:
self.ae.to("cpu")
torch.cuda.empty_cache()
return x
def unpack(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
@torch.inference_mode()
def resize_center_crop(
self, img: torch.Tensor, height: int, width: int
) -> torch.Tensor:
img = TF.resize(img, min(width, height))
img = TF.center_crop(img, (height, width))
return img
@torch.inference_mode()
def preprocess_latent(
self,
init_image: torch.Tensor | np.ndarray = None,
height: int = 720,
width: int = 1024,
num_steps: int = 20,
strength: float = 1.0,
generator: torch.Generator = None,
num_images: int = 1,
) -> tuple[torch.Tensor, List[float]]:
# prepare input
if init_image is not None:
if isinstance(init_image, np.ndarray):
init_image = torch.from_numpy(init_image)
init_image = (
init_image.permute(2, 0, 1)
.contiguous()
.to(self.device_ae, dtype=self.ae_dtype)
.div(127.5)
.sub(1)[None, ...]
)
init_image = self.resize_center_crop(init_image, height, width)
with torch.autocast(
device_type=self.device_ae.type,
dtype=torch.bfloat16,
cache_enabled=False,
):
if self.offload_vae:
self.ae.to(self.device_ae)
init_image = (
self.ae.encode(init_image)
.to(dtype=self.dtype, device=self.device_flux)
.repeat(num_images, 1, 1, 1)
)
if self.offload_vae:
self.ae.to("cpu")
torch.cuda.empty_cache()
x = self.get_noise(
num_images,
height,
width,
device=self.device_flux,
dtype=self.dtype,
generator=generator,
)
timesteps = self.get_schedule(
num_steps=num_steps,
image_seq_len=x.shape[-1] * x.shape[-2] // 4,
shift=(self.name != "flux-schnell"),
)
if init_image is not None:
t_idx = int((1 - strength) * num_steps)
t = timesteps[t_idx]
timesteps = timesteps[t_idx:]
x = t * x + (1.0 - t) * init_image
return x, timesteps
@torch.inference_mode()
def generate(
self,
prompt: str,
width: int = 720,
height: int = 1024,
num_steps: int = 24,
guidance: float = 3.5,
seed: int | None = None,
init_image: torch.Tensor | str | None = None,
strength: float = 1.0,
silent: bool = False,
num_images: int = 1,
return_seed: bool = False,
) -> io.BytesIO:
num_steps = 4 if self.name == "flux-schnell" else num_steps
if isinstance(init_image, str):
try:
init_image = Image.open(init_image)
except Exception as e:
init_image = Image.open(io.BytesIO(standard_b64decode(init_image)))
init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)
# allow for packing and conversion to latent space
height = 16 * (height // 16)
width = 16 * (width // 16)
if isinstance(seed, str):
seed = int(seed)
if seed is None:
seed = self.rng.seed()
logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")
generator = torch.Generator(device=self.device_flux).manual_seed(seed)
img, timesteps = self.preprocess_latent(
init_image=init_image,
height=height,
width=width,
num_steps=num_steps,
strength=strength,
generator=generator,
num_images=num_images,
)
img, img_ids, vec, txt, txt_ids = map(
lambda x: x.contiguous(),
self.prepare(
img=img,
prompt=prompt,
target_device=self.device_flux,
target_dtype=self.dtype,
),
)
# this is ignored for schnell
guidance_vec = torch.full(
(img.shape[0],), guidance, device=self.device_flux, dtype=self.dtype
)
t_vec = None
if self.offload_flow:
self.model.to(self.device_flux)
for t_curr, t_prev in tqdm(
zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1, disable=silent
):
if t_vec is None:
t_vec = torch.full(
(img.shape[0],),
t_curr,
dtype=self.dtype,
device=self.device_flux,
)
else:
t_vec = t_vec.reshape((img.shape[0],)).fill_(t_curr)
pred = self.model.forward(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
if self.offload_flow:
self.model.to("cpu")
torch.cuda.empty_cache()
# decode latents to pixel space
img = self.vae_decode(img, height, width)
if return_seed:
return self.into_bytes(img), seed
return self.into_bytes(img)
@classmethod
def load_pipeline_from_config_path(
cls, path: str, flow_model_path: str = None
) -> "FluxPipeline":
with torch.inference_mode():
config = load_config_from_path(path)
if flow_model_path:
config.ckpt_path = flow_model_path
return cls.load_pipeline_from_config(config)
@classmethod
def load_pipeline_from_config(cls, config: ModelSpec) -> "FluxPipeline":
from float8_quantize import quantize_flow_transformer_and_dispatch_float8
with torch.inference_mode():
print("flow_quantization_dtype", config.flow_quantization_dtype)
print("prequantized_flow?", config.prequantized_flow)
models = load_models_from_config(config)
config = models.config
flux_device = into_device(config.flux_device)
ae_device = into_device(config.ae_device)
clip_device = into_device(config.text_enc_device)
t5_device = into_device(config.text_enc_device)
flux_dtype = into_dtype(config.flow_dtype)
flow_model = models.flow
if not config.prequantized_flow:
flow_model = quantize_flow_transformer_and_dispatch_float8(
flow_model, flux_device, offload_flow=config.offload_flow
)
else:
flow_model.eval().requires_grad_(False)
return cls(
name=config.version,
clip=models.clip,
t5=models.t5,
model=flow_model,
ae=models.ae,
dtype=flux_dtype,
verbose=False,
flux_device=flux_device,
ae_device=ae_device,
clip_device=clip_device,
t5_device=t5_device,
config=config,
)
|