from diffusers import AutoencoderKL, FluxPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel
import torch
import torch._dynamo
import gc
import os
from PIL import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only

# preconfigs: allocator and kernel settings applied at import time
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
torch._dynamo.config.suppress_errors = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
# torch.backends.cudnn.benchmark = True

# globals
Pipeline = FluxPipeline
ckpt_id = "black-forest-labs/FLUX.1-schnell"
ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"


def empty_cache():
    """Release cached CUDA memory and reset peak-memory statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def load_pipeline() -> Pipeline:
    # Fine-tuned T5 text encoder, pinned to a specific revision.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "manbeast3b/flux.1-schnell-full1",
        revision="cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)

    # Load the transformer directly from its snapshot path in the local HF cache.
    path = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/"
        "cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        path, torch_dtype=torch.bfloat16, use_safetensors=False
    ).to(memory_format=torch.channels_last)

    # Quantize the VAE weights to int8 and hand the quantized module to the pipeline.
    vae = AutoencoderKL.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        subfolder="vae",
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    quantize_(vae, int8_weight_only())

    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # Warm-up pass so CUDA kernels and caches are initialized before the first request.
    with torch.inference_mode():
        pipeline(
            prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline


sample = 1


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image.Image:
    global sample
    # One-time cache cleanup on the first request after warm-up.
    if sample:
        sample = 0
        empty_cache()
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
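

# A minimal smoke-test sketch showing how the module's two entry points
# compose. Assumptions not guaranteed by this file: TextToImageRequest
# exposes `prompt`, `width`, `height`, and `seed` fields, and the
# checkpoints referenced above are already present in the local HF cache.
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",  # hypothetical example prompt
        width=1024,
        height=1024,
        seed=42,  # assumed field; used to seed the CUDA generator below
    )
    generator = Generator("cuda").manual_seed(request.seed)
    image = infer(request, pipe, generator)
    image.save("sample.png")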