# app.py — last updated by Linoy Tsaban ("Update app.py", commit 23c0e05).
# Copied from a web file viewer (raw | history | blame), 12.4 kB.
import gradio as gr
import torch
import numpy as np
import requests
import random
from io import BytesIO
# from utils import *
# from constants import *
from pipeline_semantic_stable_diffusion_xl_img2img_ddpm import *
from torch import inference_mode
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL
from diffusers import DDIMScheduler
# from share_btn import community_icon_html, loading_icon_html, share_js
import torch
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline
from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
import json
from safetensors.torch import load_file
# import lora
import copy
import json
import gc
import random
from time import sleep
# Load the LoRA catalog. Each JSON entry describes one SDXL LoRA on the Hub.
# Only the keys below are kept: required keys raise KeyError when missing,
# optional ones fall back to the defaults given here.
# JSON is UTF-8 by spec; be explicit so non-ASCII titles do not depend on the
# host locale's default encoding.
with open("sdxl_loras-Copy1.json", "r", encoding="utf-8") as file:
    data = json.load(file)

sdxl_loras_raw = [
    {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
        # Per-LoRA SEGA editing hyper-parameters read by `edit`.
        "edit_guidance_scale": item["edit_guidance_scale"],
        "threshold": item["threshold"],
    }
    for item in data
]
# Download every catalogued LoRA once at start-up and keep its state dict in
# memory, keyed by repo id, so `load_lora` can re-apply the weights without
# downloading them again.
state_dicts = {}
for item in sdxl_loras_raw:
    saved_name = hf_hub_download(item["repo"], item["weights"])
    if saved_name.endswith(".safetensors"):
        state_dict = load_file(saved_name)
    else:
        # `torch.load` unpickles the file. Unless `weights_only=True` is in
        # effect (the default only in newer PyTorch releases), unpickling can
        # run arbitrary code, so only list trusted repos in the catalog.
        # map_location="cpu" matches safetensors' `load_file` (CPU by default)
        # and keeps GPU-saved checkpoints loadable on CPU-only hosts.
        state_dict = torch.load(saved_name, map_location="cpu")
    # Merge the catalog entry in, so a single lookup yields weights + metadata.
    state_dicts[item["repo"]] = {
        "saved_name": saved_name,
        "state_dict": state_dict,
    } | item
sd_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# VAE from madebyollin/sdxl-vae-fp16-fix (presumably an fp16-stable SDXL VAE),
# passed to the pipeline in place of the base model's VAE.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
sd_pipe = SemanticStableDiffusionXLImg2ImgPipeline_DDPMInversion.from_pretrained(
    sd_model_id,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    vae=vae,
)
# Passing a repo id to `from_config` is deprecated in diffusers;
# `from_pretrained` loads the same scheduler config through the supported API.
sd_pipe.scheduler = DDIMScheduler.from_pretrained(sd_model_id, subfolder="scheduler")

# Pristine copy, taken before `sd_pipe.to(device)` so it is not moved to
# `device`. `load_lora` rebuilds `sd_pipe` from it after a merged LoRA.
original_pipe = copy.deepcopy(sd_pipe)
sd_pipe.to(device)

# LoRA bookkeeping read and updated by `load_lora`.
last_lora = ""
last_merged = False
last_fused = False
def load_lora(sdxl_loras, random_lora_index, lora_scale = 1.0, progress=gr.Progress(track_tqdm=True)):
    """Apply the LoRA at ``sdxl_loras[random_lora_index]`` to the global ``sd_pipe``.

    The loading strategy depends on the catalog entry:

    * ``is_compatible``: load the cached state dict and fuse it at ``lora_scale``.
    * ``is_pivotal``: same, plus the textual-inversion embeddings stored in the
      LoRA repo under ``text_embedding_weights``.
    * otherwise: merge the LoRA into the model weights with
      ``merge_incompatible_lora``. This cannot be unloaded, so the next switch
      rebuilds ``sd_pipe`` from ``original_pipe``.

    Whatever the previous call applied is removed first. If the same LoRA is
    already applied at the same scale, the pipeline is left untouched.

    Args:
        sdxl_loras: catalog entries (dicts with ``repo``, ``is_compatible``,
            ``is_pivotal``, ``text_embedding_weights``, ...).
        random_lora_index: index of the entry to apply.
        lora_scale: scale passed to ``fuse_lora`` / ``merge_incompatible_lora``.
        progress: Gradio progress tracker (mirrors tqdm bars while loading).

    Side effects: reseeds all RNGs via ``randomize()`` and updates the globals
    ``sd_pipe``, ``last_lora``, ``last_merged`` and ``last_fused``.
    """
    global last_lora, last_merged, last_fused, sd_pipe

    randomize()

    lora = sdxl_loras[random_lora_index]
    repo_name = lora["repo"]
    print(repo_name)

    # The same LoRA at the same scale is already in `sd_pipe`: nothing to do.
    lora_key = (repo_name, lora_scale)
    if last_lora != lora_key:
        # 1. Remove whatever the previous call applied.
        if last_merged:
            # Merged weights cannot be unloaded: rebuild from the pristine copy.
            del sd_pipe
            gc.collect()
            sd_pipe = copy.deepcopy(original_pipe)
            sd_pipe.to(device)
        elif last_fused:
            sd_pipe.unfuse_lora()
            sd_pipe.unload_lora_weights()
        # `sd_pipe` now carries no LoRA. Keep the flags in sync so the next
        # call does not needlessly rebuild or unfuse a clean pipeline, and
        # forget `last_lora` in case loading below fails.
        last_merged = False
        last_fused = False
        last_lora = ""

        # 2. Apply the requested LoRA.
        if lora["is_compatible"] or lora["is_pivotal"]:
            # Work on a copy: the cached state dict in `state_dicts` is reused
            # by later calls (presumably the loader may mutate its input).
            loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
            sd_pipe.load_lora_weights(loaded_state_dict)
            sd_pipe.fuse_lora(lora_scale)
            last_fused = True

            if not lora["is_compatible"]:
                # Pivotal-tuning LoRAs also ship textual-inversion embeddings.
                text_embedding_name = lora["text_embedding_weights"]
                text_encoders = [sd_pipe.text_encoder, sd_pipe.text_encoder_2]
                tokenizers = [sd_pipe.tokenizer, sd_pipe.tokenizer_2]
                embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
                embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
                embhandler.load_embeddings(embedding_path)
        else:
            # NOTE(review): `merge_incompatible_lora` is not defined or imported
            # explicitly in this file (`# import lora` is commented out);
            # presumably it arrives via a star import — confirm.
            merge_incompatible_lora(state_dicts[repo_name]["saved_name"], lora_scale)
            last_merged = True

        last_lora = lora_key
    print("DONE MERGING")
## SEGA ##
def shuffle_lora(sdxl_loras):
    """Pick a random catalog entry and build its UI presentation.

    Returns:
        A 5-tuple: the chosen index, the entry's preview image, a Markdown blurb
        naming the LoRA and its author (owner part of the repo id), an update
        that shows the LoRA info row, and an update that sets the edited
        image's height to 369.
    """
    chosen = random.randrange(0, len(sdxl_loras), 1)
    entry = sdxl_loras[chosen]
    author = entry["repo"].split('/')[0]
    title = entry["title"]
    description = f"""#### LoRA used to edit this image:
## {title}
by `{author}`
"""
    preview = entry["image"]
    return chosen, preview, description, gr.update(visible=True), gr.update(height=369)
def check_if_removed(input_image):
    """Reset the output area when the input photo has been cleared.

    Returns updates for (LoRA info row, output column, edited image). With no
    photo, the row is hidden, the column gets the plain ``output_column``
    class, and the edited image is emptied. Otherwise bare component instances
    are returned for all three outputs.
    """
    if input_image is not None:
        return gr.Row(), gr.Column(), gr.Image()
    hidden_row = gr.Row(visible=False)
    plain_column = gr.Column(elem_classes="output_column")
    empty_image = gr.Image(value=None)
    return hidden_row, plain_column, empty_image
def block_if_removed(input_image):
    """Stop the event chain when the input photo has been cleared.

    Used as a ``.then()`` step followed by ``.success()``: raising here keeps
    the ``.success()`` steps (LoRA pick, inversion, editing) from running
    without an image.

    Raises:
        gr.Error: if ``input_image`` is None.
    """
    if input_image is None:
        # `gr.Warning` is a notification function returning None, not an
        # exception class, so `raise gr.Warning(...)` would only abort via a
        # TypeError. `gr.Error` is the exception Gradio handles for this.
        raise gr.Error("Photo removed. Upload a new one!")
def edit(sdxl_loras,
         input_image,
         wts, zs,
         do_inversion,
         random_lora_index,
         progress=gr.Progress(track_tqdm=True)
):
    """Edit ``input_image`` towards the LoRA at ``sdxl_loras[random_lora_index]``.

    The LoRA is applied to ``sd_pipe`` first (``load_lora``).

    * When ``do_inversion`` is true, the photo is VAE-encoded and inverted with
      ``sd_pipe.invert``, producing the latents ``wts`` and noise maps ``zs``.
    * Otherwise the ``wts`` / ``zs`` from a previous call are reused.

    SEGA guidance towards the LoRA's trigger word then generates the result.

    Returns:
        (edited image, wts, zs, do_inversion -- always False on return,
        a height=512 update for the edited image, the output column with the
        ``output_column_reverse`` class).

    Raises:
        gr.Error: if an inversion is needed but no photo was provided.
    """
    if do_inversion and input_image is None:
        # Fail fast with a clear message, before the expensive LoRA swap.
        raise gr.Error("Please take a picture first!")

    load_lora(sdxl_loras, random_lora_index)

    # Fixed inversion / editing hyper-parameters.
    src_prompt = ""
    tar_prompt = ""
    skip = 18  # passed as skip_steps; editing starts from wts[skip] / zs[skip:]
    steps = 50
    src_cfg_scale = 3.5
    tar_cfg_scale = 15

    if do_inversion:
        image = load_image(input_image, device=device).to(torch.float16)
        with inference_mode():
            x0 = sd_pipe.vae.encode(image).latent_dist.sample() * sd_pipe.vae.config.scaling_factor
        # Invert and retrieve the noise maps and latents.
        zs, wts = sd_pipe.invert(
            x0,
            source_prompt=src_prompt,
            source_guidance_scale=src_cfg_scale,
            negative_prompt="blurry, ugly, bad quality",
            num_inversion_steps=steps,
            skip_steps=skip,
        )
        do_inversion = False

    lora = sdxl_loras[random_lora_index]
    latents = wts[skip].expand(1, -1, -1, -1)
    editing_args = dict(
        editing_prompt=[lora["trigger_word"]],
        reverse_editing_direction=[False],
        edit_warmup_steps=[2],
        edit_guidance_scale=[lora["edit_guidance_scale"]],
        edit_threshold=[lora["threshold"]],
        edit_momentum_scale=0.3,
        edit_mom_beta=0.6,
        eta=1,
    )
    sega_out = sd_pipe(prompt=tar_prompt, latents=latents, guidance_scale=tar_cfg_scale,
                       wts=wts, zs=zs[skip:], **editing_args)
    return sega_out.images[0], wts, zs, do_inversion, gr.update(height=512), gr.Column(elem_classes="output_column_reverse")
def randomize_seed_fn(seed, randomize_seed):
    """Choose the generation seed and apply it to torch's global RNG.

    If ``randomize_seed`` is truthy, ``seed`` is ignored and a fresh value is
    drawn from ``[0, np.iinfo(np.int32).max]``. Returns the seed actually used.
    """
    chosen_seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else seed
    torch.manual_seed(chosen_seed)
    return chosen_seed
def randomize():
    """Draw one random seed and apply it to torch (CPU and CUDA), `random` and NumPy."""
    new_seed = random.randint(0, np.iinfo(np.int32).max)
    for apply_seed in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
        apply_seed(new_seed)
def crop_image(image, size=1024):
    """Center-crop ``image`` to a square, then resize it to ``size`` x ``size``.

    Args:
        image: NumPy array of shape (H, W) or (H, W, C). The shorter side is
            kept and the longer side is trimmed equally from both ends.
        size: side length of the output in pixels (defaults to 1024).

    Returns:
        The cropped and resized image as a NumPy array.
    """
    # Only the two spatial dimensions matter; slicing them (instead of
    # unpacking exactly three dims) also accepts 2-D grayscale arrays.
    h, w = image.shape[:2]
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    # NOTE(review): `Image` (PIL) is not imported explicitly in this file;
    # presumably it arrives via a star import — confirm.
    return np.array(Image.fromarray(image).resize((size, size)))
########
# demo #
########
# Gradio UI. Taking (or replacing) a webcam photo triggers a fresh inversion
# plus an edit with a randomly picked LoRA; the "Rerun" button edits the same
# photo again with a newly picked random LoRA.
with gr.Blocks(css="style.css") as demo:
    def reset_do_inversion():
        # A new photo invalidates the stored inversion, so `edit` must re-invert.
        return True

    gr.HTML("""<img style="margin: 0 auto; width: 180px; margin-bottom: .5em" src="https://i.imgur.com/A4BP6Lx.png" alt="LEDITS SDXL LoRA Photobooth">""")
    with gr.Box(elem_id="total_box"):
        gr.HTML(
            """<h1>LEDITS SDXL LoRA Photobooth</h1>
<h3>Smile, take a pic 📷✨ and <code>it'll be inverted on SDXL and a random SDXL LoRA will be applied</code></h3>
""",
        )
        # Per-session state shared between the event handlers below.
        wts = gr.State()  # inversion latents returned by `edit`
        zs = gr.State()  # inversion noise maps returned by `edit`
        reconstruction = gr.State()  # NOTE(review): not used by any handler in this file
        do_inversion = gr.State(value=True)  # True -> the next `edit` re-inverts the photo
        gr_sdxl_loras = gr.State(value=sdxl_loras_raw)  # LoRA catalog
        gr_lora_index = gr.State()  # catalog index chosen by `shuffle_lora`
        with gr.Row():
            input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512, elem_id="input_image")
            with gr.Column(elem_classes="output_column") as output_column:
                # Preview + description of the chosen LoRA; hidden until one is picked.
                with gr.Row(visible=False) as loaded_lora:
                    lora_image = gr.Image(interactive=False, height=128, width=128, elem_id="lora_image", show_label=False, show_download_button=False)
                    lora_desc = gr.Markdown()
                sega_edited_image = gr.Image(label=f"LEDITS Edited Image", interactive=False, elem_id="output_image", height=512, width=512)
                with gr.Row():
                    run_button = gr.Button("Rerun with the same picture", visible=True, elem_id="run_again")
    # Rerun: pick a random LoRA, then edit. The stored inversion is reused
    # unless `do_inversion` was reset by a new photo.
    # `sega_edited_image` is listed twice in `edit`'s outputs: once for the
    # image and once for the height update that `edit` returns.
    run_button.click(
        fn=shuffle_lora,
        inputs=[gr_sdxl_loras],
        outputs=[gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
        queue=False
    ).then(
        fn=edit,
        inputs=[gr_sdxl_loras,
                input_image,
                wts, zs,
                do_inversion,
                gr_lora_index
        ],
        outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column])
    # Photo changed: reset the output area. If the photo was cleared,
    # `block_if_removed` raises and the `.success()` chain is skipped.
    # Otherwise force a re-inversion, pick a LoRA and edit.
    input_image.change(
        fn = check_if_removed,
        inputs = [input_image],
        outputs = [loaded_lora, output_column, sega_edited_image],
        queue=False,
        show_progress=False
    ).then(
        fn = block_if_removed,
        inputs = [input_image],
        queue=False,
        show_progress=False
    ).success(
        fn = reset_do_inversion,
        outputs = [do_inversion],
        queue = False).then(
        fn=shuffle_lora,
        inputs=[gr_sdxl_loras],
        outputs=[gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
        queue=False
    ).then(
        fn=edit,
        inputs=[gr_sdxl_loras,
                input_image,
                wts, zs,
                do_inversion,
                gr_lora_index
        ],
        outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column]
    )
    # Hidden (display: none) ICCV 2023 logo.
    gr.HTML('''
<img src="https://iccv2023.thecvf.com/img/LogoICCV23V04.svg" width="400" style="margin: 0 auto; display: none" id='iccv_logo' />
''')
demo.queue()
demo.launch()