import os
import torch
import torch._dynamo
import gc
from PIL.Image import Image
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import (
    T5EncoderModel,
    T5TokenizerFast,
    CLIPTokenizer,
    CLIPTextModel,
)
from diffusers import (
    FluxPipeline,
    AutoencoderKL,
    AutoencoderTiny,
    FluxTransformer2DModel,
    DiffusionPipeline,
)
from pipelines.models import TextToImageRequest
from torch import Generator

# Set environment variables: allow tokenizer parallelism and let the CUDA
# allocator use expandable segments to reduce memory fragmentation.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Fall back to eager execution instead of raising if dynamo/compile fails.
torch._dynamo.config.suppress_errors = True

# Placeholder alias used only in the type hints below.
Pipeline = None

# Define constants
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"


class QuantativeAnalysis:
    """Per-parameter range analysis pass.

    apply() inspects each trainable parameter's min/max range; the blended
    value it computes is never written back, so the model's weights are
    returned unchanged (num_bins and scale_ratio are currently unused).
    """

    def __init__(self, model, num_bins=256, scale_ratio=1.0):
        self.model = model
        self.num_bins = num_bins
        self.scale_ratio = scale_ratio

    def apply(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                with torch.no_grad():
                    param_min = param.min()
                    param_max = param.max()
                    param_range = param_max - param_min
                    if param_range > 0:
                        # Computed but not written back to param.data.
                        params = 0.8 * param_min + 0.2 * param_max
        return self.model


class AttentionQuant:
    """Uniform fake-quantization of selected parameters.

    att_config maps full parameter names to (num_bins, scale_factor); each
    matching parameter is snapped to num_bins levels over its own min/max
    range and then multiplied by scale_factor.
    """

    def __init__(self, model, att_config):
        self.model = model
        self.att_config = att_config

    def apply(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # att_config is keyed by full parameter names
                # (e.g. "transformer_blocks.15.attn.norm_added_k.weight"),
                # so match against the full name, not a truncated prefix.
                if name in self.att_config:
                    num_bins, scale_factor = self.att_config[name]
                    with torch.no_grad():
                        param_min = param.min()
                        param_max = param.max()
                        param_range = param_max - param_min
                        if param_range > 0:
                            # Normalize to [0, 1], snap to num_bins levels,
                            # map back to the original range, then rescale.
                            normalized = (param - param_min) / param_range
                            binned = torch.round(normalized * (num_bins - 1)) / (num_bins - 1)
                            rescaled = binned * param_range + param_min
                            param.data.copy_(rescaled * scale_factor)
                        else:
                            # Constant parameter: collapse it to zero.
                            param.data.zero_()
        return self.model


def load_pipeline() -> Pipeline:
    # Load the T5 text encoder (used as text_encoder_2 of FLUX).
    __t5_model = T5EncoderModel.from_pretrained(
        "TrendForge/extra1manQ1",
        revision="d302b6e39214ed4532be34ec337f93c7eef3eaa6",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)
    __text_encoder_2 = __t5_model

    # Load the tiny VAE.
    base_vae = AutoencoderTiny.from_pretrained(
        "TrendForge/extra2manQ2",
        revision="cef012d2db2f5a006567e797a0b9130aea5449c1",
        torch_dtype=torch.bfloat16,
    )

    # Load the FLUX transformer from a local HF cache snapshot.
    path = os.path.join(
        HF_HUB_CACHE,
        "models--TrendForge--extra0manQ0/snapshots/dc2cda167b8f53792a98020a3ef2f21808b09bb4",
    )
    base_trans = FluxTransformer2DModel.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    # Quantize a few attention norm weights in block 15; fall back to the
    # unmodified transformer if anything goes wrong.
    try:
        att_config = {
            "transformer_blocks.15.attn.norm_added_k.weight": (64, 0.1),
            "transformer_blocks.15.attn.norm_added_q.weight": (64, 0.1),
            "transformer_blocks.15.attn.norm_added_v.weight": (64, 0.1),
        }
        transformer = AttentionQuant(base_trans, att_config).apply()
    except Exception:
        transformer = base_trans

    # Assemble the pipeline around the base checkpoint, swapping in the
    # VAE, transformer, and T5 encoder loaded above.
    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=base_vae,
        transformer=transformer,
        text_encoder_2=__text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # Warm up so later requests avoid first-call compilation/allocation cost.
    for _ in range(3):
        pipeline(
            prompt="forswearer, skullcap, Juglandales, bluelegs, cunila, carbro, Ammonites",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    # Seed a generator on the pipeline's device for reproducible outputs.
    generator = Generator(pipeline.device).manual_seed(request.seed)

    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
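

# ---------------------------------------------------------------------------
# Minimal local smoke test (a sketch, not one of the serving entry points).
# It assumes TextToImageRequest can be constructed with the prompt/seed/
# height/width fields that infer() reads; adjust if the real model differs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = load_pipeline()
    sample_request = TextToImageRequest(
        prompt="a watercolor lighthouse at dusk",
        seed=0,
        height=1024,
        width=1024,
    )
    image = infer(sample_request, pipe)
    image.save("sample.png")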