Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import gc | |
import json | |
import pathlib | |
import sys | |
import gradio as gr | |
import PIL.Image | |
import torch | |
from diffusers import StableDiffusionPipeline | |
from peft import LoraModel, LoraConfig, set_peft_model_state_dict | |
class InferencePipeline:
    """Build, cache and run a Stable Diffusion pipeline with LoRA weights.

    The most recently loaded pipeline is kept on ``self.pipe`` together with
    the base model id and LoRA weight path it was built from, so repeated
    calls to :meth:`run` with the same settings reuse it instead of
    reloading.
    """

    def __init__(self) -> None:
        self.pipe = None
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Cache key of the pipeline currently held in ``self.pipe``.
        self.weight_path = None
        self.model_id = None

    def clear(self) -> None:
        """Drop the cached pipeline and release the memory it was holding."""
        self.weight_path = None
        self.model_id = None
        self.pipe = None
        # Collect first so tensors kept alive only by reference cycles are
        # freed before the CUDA caching allocator is asked to release blocks.
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def get_lora_weight_path(name: str) -> tuple[pathlib.Path, pathlib.Path]:
        """Return ``(weight_path, config_path)`` for a LoRA checkpoint name.

        Both paths are resolved relative to the directory containing this
        file. The config file name is derived from ``name`` by replacing
        ``".pt"`` with ``"_config.json"``.
        """
        curr_dir = pathlib.Path(__file__).parent
        config_name = name.replace(".pt", "_config.json")
        return curr_dir / name, curr_dir / config_name

    def load_and_set_lora_ckpt(self, pipe, weight_path, config_path, dtype):
        """Wrap ``pipe``'s UNet (and optionally text encoder) with LoRA weights.

        Args:
            pipe: Pipeline whose ``unet`` / ``text_encoder`` are wrapped in place.
            weight_path: Path to the checkpoint holding the LoRA state dict.
            config_path: Path to the JSON file holding ``"peft_config"`` and,
                optionally, ``"text_encoder_peft_config"``.
            dtype: Target dtype; ``torch.float16`` and ``torch.bfloat16``
                trigger a cast of the UNet and text encoder.

        Returns:
            The same ``pipe`` object, moved to ``self.device``.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            lora_config = json.load(f)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        lora_checkpoint_sd = torch.load(weight_path, map_location=self.device)

        # Keys containing "text_encoder_" belong to the text encoder (with the
        # marker stripped); everything else belongs to the UNet.
        marker = "text_encoder_"
        unet_lora_ds = {}
        text_encoder_lora_ds = {}
        for key, value in lora_checkpoint_sd.items():
            if marker in key:
                text_encoder_lora_ds[key.replace(marker, "")] = value
            else:
                unet_lora_ds[key] = value

        unet_config = LoraConfig(**lora_config["peft_config"])
        pipe.unet = LoraModel(unet_config, pipe.unet)
        set_peft_model_state_dict(pipe.unet, unet_lora_ds)

        if "text_encoder_peft_config" in lora_config:
            text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
            pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
            set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)

        if dtype in (torch.float16, torch.bfloat16):
            # Cast to the requested dtype; ``.half()`` would turn a bfloat16
            # request into float16.
            pipe.unet.to(dtype=dtype)
            pipe.text_encoder.to(dtype=dtype)

        pipe.to(self.device)
        return pipe

    def load_pipe(self, model_id: str, lora_filename: str) -> None:
        """Ensure ``self.pipe`` matches ``model_id`` and ``lora_filename``.

        Does nothing when the cached pipeline was already built from the same
        base model and LoRA weight file; otherwise frees the old pipeline and
        builds a new one in float16.
        """
        weight_path, config_path = self.get_lora_weight_path(lora_filename)
        already_loaded = (
            self.pipe is not None
            and weight_path == self.weight_path
            and model_id == self.model_id
        )
        if already_loaded:
            return

        # Release the previous pipeline first so two models never have to fit
        # in device memory at the same time.
        self.clear()

        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        pipe = pipe.to(self.device)
        self.pipe = self.load_and_set_lora_ckpt(pipe, weight_path, config_path, torch.float16)

        # Record the cache key only after a successful load, so a failed load
        # is retried on the next call.
        self.weight_path = weight_path
        self.model_id = model_id

    def run(
        self,
        base_model: str,
        lora_weight_name: str,
        prompt: str,
        negative_prompt: str,
        seed: int,
        n_steps: int,
        guidance_scale: float,
    ) -> PIL.Image.Image:
        """Generate one image for ``prompt`` and return it.

        Raises:
            gr.Error: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise gr.Error("CUDA is not available.")

        self.load_pipe(base_model, lora_weight_name)

        # UI widgets may deliver numeric values as floats; manual_seed needs an int.
        generator = torch.Generator(device=self.device).manual_seed(int(seed))
        out = self.pipe(
            prompt,
            num_inference_steps=int(n_steps),
            guidance_scale=guidance_scale,
            generator=generator,
            # An empty negative prompt is passed as None.
            negative_prompt=negative_prompt or None,
        )  # type: ignore
        return out.images[0]