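"""Text-to-image pipeline for FLUX.1-schnell.

Loads a custom T5 text encoder, a tiny VAE, and a transformer checkpoint, optionally
bins a few attention-norm weights, and exposes load_pipeline() and infer() as the
entry points used for generation.
"""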
import os
import torch
import torch._dynamo
import gc
from PIL.Image import Image
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import (
T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
)
from diffusers import (
FluxPipeline, AutoencoderKL, AutoencoderTiny, FluxTransformer2DModel, DiffusionPipeline
)
from pipelines.models import TextToImageRequest
from torch import Generator
# Set environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "True"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
torch._dynamo.config.suppress_errors = True
Pipeline = None  # Placeholder alias; used only in the type annotations below.
# Define constants
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
class QuantativeAnalysis:
    """Uniform min-max binning of all trainable parameters (not invoked in this file)."""

    def __init__(self, model, num_bins=256, scale_ratio=1.0):
        self.model = model
        self.num_bins = num_bins
        self.scale_ratio = scale_ratio

    def apply(self):
        # Mirror the per-parameter binning used by AttentionQuant below, applied uniformly.
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                with torch.no_grad():
                    param_min = param.min()
                    param_max = param.max()
                    param_range = param_max - param_min
                    if param_range > 0:
                        normalized = (param - param_min) / param_range
                        binned = torch.round(normalized * (self.num_bins - 1)) / (self.num_bins - 1)
                        param.data.copy_((binned * param_range + param_min) * self.scale_ratio)
        return self.model
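# QuantativeAnalysis is defined but never called below; a hypothetical invocation would be
#   model = QuantativeAnalysis(model, num_bins=128, scale_ratio=1.0).apply()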
class AttentionQuant:
    """Bins selected parameters to a fixed number of levels and rescales them."""

    def __init__(self, model, att_config):
        self.model = model
        # att_config maps full parameter names (e.g. "transformer_blocks.15.attn.norm_added_k.weight")
        # to a (num_bins, scale_factor) pair.
        self.att_config = att_config

    def apply(self):
        for name, param in self.model.named_parameters():
            # Match against the full parameter name so the keys in att_config take effect.
            if param.requires_grad and name in self.att_config:
                num_bins, scale_factor = self.att_config[name]
                with torch.no_grad():
                    param_min = param.min()
                    param_max = param.max()
                    param_range = param_max - param_min
                    if param_range > 0:
                        # Min-max normalize, round to num_bins levels, restore the original
                        # range, then apply the per-parameter scale factor.
                        normalized = (param - param_min) / param_range
                        binned = torch.round(normalized * (num_bins - 1)) / (num_bins - 1)
                        rescaled = binned * param_range + param_min
                        param.data.copy_(rescaled * scale_factor)
                    else:
                        # A constant tensor has no spread to bin; zero it out.
                        param.data.zero_()
        return self.model
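# Worked example of the binning above: with num_bins=64, a tensor spanning [-1.0, 1.0] is
# normalized to [0, 1], rounded to multiples of 1/63, mapped back onto [-1.0, 1.0]
# (64 distinct levels), and finally multiplied by scale_factor (0.1 in the config below).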
def load_pipeline() -> Pipeline:
# Load T5 model
__t5_model = T5EncoderModel.from_pretrained(
"TrendForge/extra1manQ1",
revision="d302b6e39214ed4532be34ec337f93c7eef3eaa6",
torch_dtype=torch.bfloat16
).to(memory_format=torch.channels_last)
__text_encoder_2 = __t5_model
# Load VAE
base_vae = AutoencoderTiny.from_pretrained(
"TrendForge/extra2manQ2",
revision="cef012d2db2f5a006567e797a0b9130aea5449c1",
torch_dtype=torch.bfloat16
)
    # Load the transformer from a locally cached snapshot (assumes the weights are already
    # present in the Hugging Face hub cache).
    path = os.path.join(
        HF_HUB_CACHE,
        "models--TrendForge--extra0manQ0/snapshots/dc2cda167b8f53792a98020a3ef2f21808b09bb4"
    )
base_trans = FluxTransformer2DModel.from_pretrained(
path, torch_dtype=torch.bfloat16, use_safetensors=False
).to(memory_format=torch.channels_last)
    # Bin the configured attention-norm weights; fall back to the unmodified transformer
    # if anything goes wrong.
    try:
att_config = {
"transformer_blocks.15.attn.norm_added_k.weight": (64, 0.1),
"transformer_blocks.15.attn.norm_added_q.weight": (64, 0.1),
"transformer_blocks.15.attn.norm_added_v.weight": (64, 0.1)
}
transformer = AttentionQuant(base_trans, att_config).apply()
except Exception:
transformer = base_trans
# Load pipeline
pipeline = DiffusionPipeline.from_pretrained(
CHECKPOINT,
revision=REVISION,
vae=base_vae,
transformer=transformer,
text_encoder_2=__text_encoder_2,
torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
    # Warmup passes so CUDA kernels and caches are initialized before the first real request.
for _ in range(3):
pipeline(
prompt="forswearer, skullcap, Juglandales, bluelegs, cunila, carbro, Ammonites",
width=1024,
height=1024,
guidance_scale=0.0,
num_inference_steps=4,
max_sequence_length=256
)
return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
generator = Generator(pipeline.device).manual_seed(request.seed)
return pipeline(
request.prompt,
generator=generator,
guidance_scale=0.0,
num_inference_steps=4,
max_sequence_length=256,
height=request.height,
width=request.width
).images[0]
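# Illustrative local usage (not one of the original entry points); assumes
# TextToImageRequest accepts these fields as keyword arguments:
#
#   if __name__ == "__main__":
#       pipe = load_pipeline()
#       request = TextToImageRequest(prompt="a lighthouse at dusk", seed=0, width=1024, height=1024)
#       infer(request, pipe).save("output.png")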