Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import numpy as np | |
import requests | |
import random | |
from io import BytesIO | |
# from utils import * | |
# from constants import * | |
from pipeline_semantic_stable_diffusion_xl_img2img_ddpm import * | |
from torch import inference_mode | |
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL | |
from diffusers import DDIMScheduler | |
# from share_btn import community_icon_html, loading_icon_html, share_js | |
import torch | |
from huggingface_hub import hf_hub_download | |
from diffusers import DiffusionPipeline | |
from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler | |
import json | |
from safetensors.torch import load_file | |
# import lora | |
import copy | |
import json | |
import gc | |
import random | |
from time import sleep | |
# Load the LoRA catalogue and keep only the fields the app uses.
# The encoding is explicit so parsing does not depend on the host locale.
with open("sdxl_loras-Copy1.json", "r", encoding="utf-8") as file:
    data = json.load(file)

# Required keys raise KeyError if missing; optional keys fall back to defaults.
sdxl_loras_raw = [
    {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
        "edit_guidance_scale": item["edit_guidance_scale"],
        "threshold": item["threshold"],
    }
    for item in data
]
# Download every catalogue LoRA once at start-up and cache its weights in memory,
# keyed by repo id. Each entry is {"saved_name", "state_dict"} merged with the
# catalogue fields of that LoRA (dict union: catalogue keys win on collision).
state_dicts = {}
for item in sdxl_loras_raw:
    saved_name = hf_hub_download(item["repo"], item["weights"])
    if saved_name.endswith('.safetensors'):
        state_dict = load_file(saved_name)
    else:
        # Map to CPU so non-safetensors weights end up on the same device as the
        # safetensors branch (load_file defaults to CPU). This also avoids failing
        # on CPU-only hosts when a file was saved from CUDA.
        # SECURITY NOTE(review): torch.load unpickles the downloaded file and can
        # execute arbitrary code; only list trusted repos in the catalogue.
        state_dict = torch.load(saved_name, map_location="cpu")
    state_dicts[item["repo"]] = {
        "saved_name": saved_name,
        "state_dict": state_dict,
    } | item
sd_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# fp16-fix VAE replacing the stock SDXL VAE in the pipeline below.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
sd_pipe = SemanticStableDiffusionXLImg2ImgPipeline_DDPMInversion.from_pretrained(
    sd_model_id,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    vae=vae,
)
# from_config expects a config dict; passing a repo id there is deprecated.
# Loading a scheduler config from the Hub is from_pretrained's job.
sd_pipe.scheduler = DDIMScheduler.from_pretrained(sd_model_id, subfolder="scheduler")

# Pristine copy taken BEFORE moving to `device`. load_lora restores from it when
# a previously merged (non-unfusable) LoRA has to be removed.
original_pipe = copy.deepcopy(sd_pipe)
sd_pipe.to(device)

# Bookkeeping read and written by load_lora to track what is applied to sd_pipe.
last_lora = ""        # repo id of the LoRA currently applied ("" = none)
last_merged = False   # True if it was applied via merge_incompatible_lora
last_fused = False    # True if it was applied via load_lora_weights + fuse_lora
def load_lora(sdxl_loras, random_lora_index, lora_scale = 1.0, progress=gr.Progress(track_tqdm=True)):
    """Make the global ``sd_pipe`` carry the LoRA at ``sdxl_loras[random_lora_index]``.

    Every call first reseeds all RNGs via ``randomize()``. If the requested LoRA
    is not the one recorded in ``last_lora``, the previously applied LoRA is
    removed first:

    * a merged LoRA is undone by replacing ``sd_pipe`` with a fresh deep copy
      of ``original_pipe``;
    * a fused LoRA is undone with ``unfuse_lora()`` + ``unload_lora_weights()``.

    The new LoRA is then applied. Compatible and pivotal LoRAs are loaded from
    the cached ``state_dicts`` and fused; pivotal ones also load their
    textual-inversion embeddings. Any other LoRA goes through
    ``merge_incompatible_lora``.

    Args:
        sdxl_loras: catalogue entries (dicts with "repo", "is_compatible",
            "is_pivotal", "text_embedding_weights").
        random_lora_index: index into ``sdxl_loras`` of the LoRA to apply.
        lora_scale: strength passed to ``fuse_lora`` / ``merge_incompatible_lora``.
        progress: Gradio progress tracker; not referenced in the body.
    """
    global last_lora, last_merged, last_fused, sd_pipe

    randomize()

    lora = sdxl_loras[random_lora_index]
    repo_name = lora["repo"]
    print(repo_name)

    # The original never updated last_lora, so this check was always true and
    # every call reloaded. Keeping last_lora current means a repeated pick keeps
    # the already-applied weights instead of re-applying them.
    # NOTE(review): the cache key is the repo id only; calling again with the same
    # repo but a different lora_scale will not re-apply it. Visible callers always
    # use the default scale.
    if last_lora != repo_name:
        # --- Remove whatever the previous call applied. ---
        if last_merged:
            # Merged weights cannot be unfused: start over from the pristine copy.
            del sd_pipe
            gc.collect()
            sd_pipe = copy.deepcopy(original_pipe)
            sd_pipe.to(device)
        elif last_fused:
            sd_pipe.unfuse_lora()
            sd_pipe.unload_lora_weights()
        # The pipeline is LoRA-free now. Resetting the flags here also stops every
        # later call from re-copying the whole pipeline after a single merge.
        last_lora = ""
        last_merged = False
        last_fused = False

        # --- Apply the requested LoRA. ---
        if lora["is_compatible"] or lora["is_pivotal"]:
            # Hand over a copy so the cached state dict stays intact for later
            # reuse (presumably load_lora_weights may mutate the dict it receives).
            sd_pipe.load_lora_weights(copy.deepcopy(state_dicts[repo_name]["state_dict"]))
            # fuse_lora's first positional parameter is fuse_unet, so the scale
            # must be passed by keyword to take effect.
            sd_pipe.fuse_lora(lora_scale=lora_scale)
            last_fused = True

            if not lora["is_compatible"]:
                # Pivotal-tuning LoRAs also ship textual-inversion embeddings.
                # NOTE(review): these tokens are not removed by unfuse_lora() /
                # unload_lora_weights(); confirm TokenEmbeddingsHandler tolerates
                # loading the same embeddings again later.
                text_embedding_name = lora["text_embedding_weights"]
                text_encoders = [sd_pipe.text_encoder, sd_pipe.text_encoder_2]
                tokenizers = [sd_pipe.tokenizer, sd_pipe.tokenizer_2]
                embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
                embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
                embhandler.load_embeddings(embedding_path)
        else:
            # NOTE(review): merge_incompatible_lora is neither defined nor imported
            # by name in this file (`# import lora` is commented out). It must come
            # from the star import; otherwise this branch raises NameError.
            merge_incompatible_lora(state_dicts[repo_name]["saved_name"], lora_scale)
            last_merged = True

        last_lora = repo_name
    print("DONE MERGING")
## SEGA ## | |
def shuffle_lora(sdxl_loras):
    """Pick one catalogue entry uniformly at random and build its UI pieces.

    Returns:
        A 5-tuple:

        * the chosen index;
        * the entry's preview image;
        * a Markdown description naming the LoRA and its repo owner;
        * ``gr.update(visible=True)`` for the LoRA info row;
        * ``gr.update(height=369)`` for the output image.
    """
    chosen_index = random.randrange(len(sdxl_loras))
    entry = sdxl_loras[chosen_index]
    # Owner = text before the first "/" of the "owner/name" repo id.
    owner = entry["repo"].partition("/")[0]
    description = (
        "#### LoRA used to edit this image:\n"
        f"## {entry['title']}\n"
        f"by `{owner}`\n"
    )
    return (
        chosen_index,
        entry["image"],
        description,
        gr.update(visible=True),
        gr.update(height=369),
    )
def check_if_removed(input_image):
    """Return layout updates for (LoRA info row, output column, output image).

    While an image is present the three components are returned unchanged.
    When the image has been cleared, the LoRA row is hidden, the column goes
    back to the ``output_column`` class, and the edited image is blanked.
    """
    if input_image is not None:
        return gr.Row(), gr.Column(), gr.Image()
    return (
        gr.Row(visible=False),
        gr.Column(elem_classes="output_column"),
        gr.Image(value=None),
    )
def block_if_removed(input_image):
    """Abort the current event chain when the input image has been cleared.

    Raising marks the Gradio event as failed, so handlers chained with
    ``.success(...)`` do not run.

    Raises:
        gr.Error: if ``input_image`` is None.
    """
    if input_image is None:
        # gr.Warning is a notification helper that returns None, not an exception
        # class, so `raise gr.Warning(...)` actually raised a TypeError.
        # gr.Error is the exception Gradio shows to the user.
        raise gr.Error("Photo removed. Upload a new one!")
def edit(sdxl_loras,
         input_image,
         wts, zs,
         do_inversion,
         random_lora_index,
         progress=gr.Progress(track_tqdm=True)
         ):
    """Edit ``input_image`` with the LoRA at ``sdxl_loras[random_lora_index]``.

    The LoRA is applied via ``load_lora``. The picture is then inverted with
    ``sd_pipe.invert``, unless ``do_inversion`` is falsy, in which case the
    ``wts``/``zs`` from a previous call are reused. Finally the LoRA's trigger
    word is used as a SEGA editing prompt while regenerating from the inverted
    latents.

    Args:
        sdxl_loras: catalogue entries. The selected one must provide
            "trigger_word", "edit_guidance_scale" and "threshold".
        input_image: value of the Gradio image component, passed to ``load_image``.
        wts, zs: inversion results cached from a previous call. Only used when
            ``do_inversion`` is falsy.
        do_inversion: whether ``input_image`` still needs to be inverted.
        random_lora_index: index of the LoRA to apply.
        progress: Gradio progress tracker (created with ``track_tqdm=True``);
            not referenced in the body.

    Returns:
        A 6-tuple:

        * the edited image;
        * ``wts`` and ``zs``;
        * ``do_inversion`` (False once an inversion has run);
        * ``gr.update(height=512)`` for the output image;
        * a ``gr.Column`` switching the output column to ``output_column_reverse``.
    """
    load_lora(sdxl_loras, random_lora_index)
    lora = sdxl_loras[random_lora_index]

    # Fixed LEDITS hyper-parameters. Both text prompts are empty; the edit comes
    # from the SEGA editing prompt below.
    src_prompt = ""
    tar_prompt = ""
    skip = 18            # passed as skip_steps and used to index wts / zs
    steps = 50           # num_inversion_steps
    src_cfg_scale = 3.5
    tar_cfg_scale = 15

    if do_inversion:
        image = load_image(input_image, device=device).to(torch.float16)
        with inference_mode():
            x0 = sd_pipe.vae.encode(image).latent_dist.sample() * sd_pipe.vae.config.scaling_factor
        # Invert and retrieve noise maps and latents. They are returned so a rerun
        # on the same picture can skip this step.
        # NOTE(review): load_lora runs first, so the inversion is computed on the
        # LoRA-modified pipeline; a rerun with another LoRA reuses these results.
        zs, wts = sd_pipe.invert(
            x0,
            source_prompt=src_prompt,
            source_guidance_scale=src_cfg_scale,
            negative_prompt="blurry, ugly, bad quality",
            num_inversion_steps=steps,
            skip_steps=skip,
        )
        do_inversion = False

    # Starting latent for the edit: the entry of wts at the first non-skipped step.
    latents = wts[skip].expand(1, -1, -1, -1)

    editing_args = dict(
        editing_prompt=[lora["trigger_word"]],
        reverse_editing_direction=[False],
        edit_warmup_steps=[2],
        edit_guidance_scale=[lora["edit_guidance_scale"]],
        edit_threshold=[lora["threshold"]],
        edit_momentum_scale=0.3,
        edit_mom_beta=0.6,
        eta=1,
    )
    sega_out = sd_pipe(prompt=tar_prompt, latents=latents, guidance_scale=tar_cfg_scale,
                       wts=wts, zs=zs[skip:], **editing_args)

    return sega_out.images[0], wts, zs, do_inversion, gr.update(height=512), gr.Column(elem_classes="output_column_reverse")
def randomize_seed_fn(seed, randomize_seed):
    """Seed torch's global RNG and return the seed that was used.

    If ``randomize_seed`` is truthy, a fresh seed is drawn from ``random`` in
    [0, int32 max]; otherwise ``seed`` is used as given.
    """
    chosen_seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else seed
    torch.manual_seed(chosen_seed)
    return chosen_seed
def randomize():
    """Draw one seed from ``random`` and apply it to every RNG the app touches.

    Seeds, in order: torch (CPU), torch CUDA, Python's ``random``, NumPy's global RNG.
    """
    new_seed = random.randint(0, np.iinfo(np.int32).max)
    for apply_seed in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
        apply_seed(new_seed)
def crop_image(image, size=1024):
    """Center-crop ``image`` to a square and resize it to ``size`` x ``size``.

    Args:
        image: numpy array of shape (H, W) or (H, W, C) that ``Image.fromarray``
            accepts. The longer spatial side is trimmed equally from both ends
            (an odd remainder is dropped from the end).
        size: side length in pixels of the output square. Defaults to 1024, the
            previously hard-coded value.

    Returns:
        numpy array of the resized square image.
    """
    # Only the first two axes are spatial. The original `h, w, c = image.shape`
    # rejected 2-D (grayscale) arrays and never used `c`.
    height, width = image.shape[:2]
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    square = image[top:top + side, left:left + side]
    # NOTE(review): `Image` is not imported by name in this file; presumably it is
    # PIL.Image re-exported by the star import — confirm.
    return np.array(Image.fromarray(square).resize((size, size)))
######## | |
# demo # | |
######## | |
with gr.Blocks(css="style.css") as demo:
    def reset_do_inversion():
        """Flag that the next edit must invert the (new) input image again."""
        return True

    gr.HTML("""<img style="margin: 0 auto; width: 180px; margin-bottom: .5em" src="https://i.imgur.com/A4BP6Lx.png" alt="LEDITS SDXL LoRA Photobooth">""")
    with gr.Box(elem_id="total_box"):
        gr.HTML(
            """<h1>LEDITS SDXL LoRA Photobooth</h1>
<h3>Smile, take a pic 📷✨ and <code>it'll be inverted on SDXL and a random SDXL LoRA will be applied</code></h3>
""",
        )
        # Per-session state shared between the event handlers below.
        wts = gr.State()
        zs = gr.State()
        reconstruction = gr.State()  # not wired to any event in this file
        do_inversion = gr.State(value=True)
        gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
        gr_lora_index = gr.State()
        with gr.Row():
            input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512, elem_id="input_image")
            with gr.Column(elem_classes="output_column") as output_column:
                with gr.Row(visible=False) as loaded_lora:
                    lora_image = gr.Image(interactive=False, height=128, width=128, elem_id="lora_image", show_label=False, show_download_button=False)
                    lora_desc = gr.Markdown()
                sega_edited_image = gr.Image(label="LEDITS Edited Image", interactive=False, elem_id="output_image", height=512, width=512)
                with gr.Row():
                    run_button = gr.Button("Rerun with the same picture", visible=True, elem_id="run_again")

    # Rerun: same guard as the upload chain, so clicking without a picture stops
    # here instead of failing inside `edit`. The `.success` step only runs when
    # block_if_removed did not raise.
    run_button.click(
        fn=block_if_removed,
        inputs=[input_image],
        queue=False,
        show_progress=False
    ).success(
        fn=shuffle_lora,
        inputs=[gr_sdxl_loras],
        outputs=[gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
        queue=False
    ).then(
        fn=edit,
        inputs=[gr_sdxl_loras,
                input_image,
                wts, zs,
                do_inversion,
                gr_lora_index
                ],
        outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column])

    # New or cleared picture: update the layout, stop if the image was removed,
    # otherwise force a fresh inversion, pick a LoRA and edit.
    input_image.change(
        fn=check_if_removed,
        inputs=[input_image],
        outputs=[loaded_lora, output_column, sega_edited_image],
        queue=False,
        show_progress=False
    ).then(
        fn=block_if_removed,
        inputs=[input_image],
        queue=False,
        show_progress=False
    ).success(
        fn=reset_do_inversion,
        outputs=[do_inversion],
        queue=False
    ).then(
        fn=shuffle_lora,
        inputs=[gr_sdxl_loras],
        outputs=[gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
        queue=False
    ).then(
        fn=edit,
        inputs=[gr_sdxl_loras,
                input_image,
                wts, zs,
                do_inversion,
                gr_lora_index
                ],
        outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column]
    )

    gr.HTML('''
<img src="https://iccv2023.thecvf.com/img/LogoICCV23V04.svg" width="400" style="margin: 0 auto; display: none" id='iccv_logo' />
''')
# Enable the request queue, then start the server (Blocks.queue returns the Blocks itself).
demo.queue().launch()