Linoy Tsaban
Update app.py
ec70079
raw
history blame
9.93 kB
import gradio as gr
import torch
import numpy as np
import requests
import random
from io import BytesIO
from pipeline_semantic_stable_diffusion_xl_img2img_ddpm import *
from torch import inference_mode
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL
from diffusers import DDIMScheduler
# from share_btn import community_icon_html, loading_icon_html, share_js
import torch
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline
from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
import json
from safetensors.torch import load_file
# import lora
import copy
import json
import gc
import random
# Load the LoRA catalogue shipped next to the app. JSON is UTF-8 by
# specification, so the encoding is pinned instead of depending on the
# host locale's default.
with open("sdxl_loras.json", "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalise every catalogue entry to a fixed set of keys. Entries must
# provide the fields read with item[...]; the fields read with
# item.get(...) are optional and fall back to the given default.
sdxl_loras_raw = [
    {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        # "likes": item.get("likes", 0),
        # "downloads": item.get("downloads", 0),
        "is_nc": item.get("is_nc", False),
        "edit_guidance_scale": item["edit_guidance_scale"],
        "threshold": item["threshold"],
    }
    for item in data
]
# Download every LoRA weight file once at start-up and keep its tensors in
# memory, keyed by repo id. Each cache entry is the catalogue entry itself
# plus "saved_name" (local path of the downloaded file) and "state_dict"
# (the loaded tensors).
state_dicts = {}
for item in sdxl_loras_raw:
    saved_name = hf_hub_download(item["repo"], item["weights"])
    if saved_name.endswith(".safetensors"):
        state_dict = load_file(saved_name)
    else:
        # SECURITY: torch.load unpickles the file, which can execute
        # arbitrary code. Only list repositories you trust in
        # sdxl_loras.json, or convert their weights to .safetensors.
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # CPU-only hosts and keeps the cached tensors off the GPU.
        state_dict = torch.load(saved_name, map_location="cpu")
    state_dicts[item["repo"]] = {
        "saved_name": saved_name,
        "state_dict": state_dict,
    } | item
# Base model and compute device used by the whole app.
sd_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The pipeline is built in half precision with a separately loaded VAE
# checkpoint instead of the one bundled with the base model.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
sd_pipe = SemanticStableDiffusionXLImg2ImgPipeline_DDPMInversion.from_pretrained(
    sd_model_id,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    vae=vae,
)
# Schedulers are loaded from a hub repo with from_pretrained; passing a
# model id to from_config is the deprecated form of the same call.
sd_pipe.scheduler = DDIMScheduler.from_pretrained(sd_model_id, subfolder="scheduler")

# Pristine copy taken before the pipeline is moved to the device;
# load_lora deep-copies it to get back to a LoRA-free pipeline.
original_pipe = copy.deepcopy(sd_pipe)
sd_pipe.to(device)

# Bookkeeping shared with load_lora: which LoRA repo is currently applied
# to sd_pipe and how it was applied.
last_lora = ""
last_merged = False
last_fused = False
def load_lora(sdxl_loras, lora_scale = 1.0, progress=gr.Progress(track_tqdm=True)):
    """Pick a random LoRA from ``sdxl_loras`` and apply it to ``sd_pipe``.

    The global RNGs are re-seeded via ``randomize()`` before the draw. If
    the drawn LoRA is already the one applied (tracked in ``last_lora``)
    the pipeline is left as it is. Otherwise the previous LoRA is removed
    first: the pipeline is rebuilt from ``original_pipe`` when
    ``last_merged`` is set, or unfused/unloaded when only ``last_fused``
    is set. The new LoRA is then applied in one of three ways:

    * ``is_compatible``: load the cached state dict and fuse it.
    * ``is_pivotal``: load and fuse, then load the repo's text embeddings
      into both text encoders through ``TokenEmbeddingsHandler``.
    * otherwise: hand the weight file to ``merge_incompatible_lora``.

    Args:
        sdxl_loras: Non-empty list of LoRA catalogue entries (dicts with at
            least ``repo`` and ``is_compatible``; see ``sdxl_loras_raw``).
        lora_scale: Scale passed to ``fuse_lora`` /
            ``merge_incompatible_lora``. Note the "already applied" check
            keys on the repo only, so a different scale for the same repo
            does not trigger a reload.
        progress: Gradio progress tracker; not read here, presumably kept
            in the signature so Gradio mirrors tqdm progress in the UI.

    Returns:
        Index into ``sdxl_loras`` of the LoRA that is now applied.

    Raises:
        ValueError: If ``sdxl_loras`` is empty.
    """
    global last_lora, last_merged, last_fused, sd_pipe

    if not sdxl_loras:
        raise ValueError("sdxl_loras must contain at least one LoRA entry")

    randomize()
    random_lora_index = random.randrange(len(sdxl_loras))
    lora = sdxl_loras[random_lora_index]
    repo_name = lora["repo"]
    print(repo_name)

    if last_lora != repo_name:
        # --- remove whatever the previous call applied -------------------
        if last_merged:
            # The previous LoRA cannot be unfused, so start again from the
            # pristine copy of the pipeline.
            del sd_pipe
            gc.collect()
            sd_pipe = copy.deepcopy(original_pipe)
            sd_pipe.to(device)
            # The fresh pipeline carries no LoRA; clear both flags so the
            # expensive rebuild is not repeated on every later switch.
            last_merged = False
            last_fused = False
        elif last_fused:
            sd_pipe.unfuse_lora()
            sd_pipe.unload_lora_weights()
            last_fused = False

        # --- apply the newly drawn LoRA ----------------------------------
        cached = state_dicts[repo_name]
        if lora["is_compatible"]:
            # A private copy is passed so the cached tensors are never
            # touched by the loader.
            sd_pipe.load_lora_weights(copy.deepcopy(cached["state_dict"]))
            sd_pipe.fuse_lora(lora_scale)
            last_fused = True
        elif lora["is_pivotal"]:
            sd_pipe.load_lora_weights(copy.deepcopy(cached["state_dict"]))
            sd_pipe.fuse_lora(lora_scale)
            last_fused = True

            # Add the textual inversion embeddings from pivotal tuning models
            text_embedding_name = lora["text_embedding_weights"]
            text_encoders = [sd_pipe.text_encoder, sd_pipe.text_encoder_2]
            tokenizers = [sd_pipe.tokenizer, sd_pipe.tokenizer_2]
            embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
            embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
            embhandler.load_embeddings(embedding_path)
            # NOTE(review): the embeddings loaded above presumably change the
            # text encoders/tokenizers in a way unfuse_lora() does not undo,
            # so the next switch rebuilds the pipeline - confirm against the
            # original, un-mangled indentation of this block.
            last_merged = True
        else:
            # NOTE(review): merge_incompatible_lora is not defined in this
            # file; presumably it comes from the wildcard import - verify.
            merge_incompatible_lora(cached["saved_name"], lora_scale)
            last_fused = False
            last_merged = True

        # Remember what is applied so drawing the same LoRA again is free.
        last_lora = repo_name

    print("DONE")
    return random_lora_index
## SEGA ##
def edit(sdxl_loras,
         input_image,
         wts, zs,
         do_inversion,
         ):
    """Edit ``input_image`` with a randomly chosen LoRA.

    A random LoRA is applied to ``sd_pipe`` through ``load_lora``. When
    ``do_inversion`` is true the image is encoded with the pipeline's VAE
    and inverted with ``sd_pipe.invert``; the resulting tensors are wrapped
    in ``gr.State`` objects. The pipeline is then run with the LoRA's
    trigger word as the editing prompt.

    Args:
        sdxl_loras: List of LoRA catalogue entries (see ``sdxl_loras_raw``).
        input_image: Image to edit, as delivered by the ``gr.Image`` input.
        wts, zs: Results of a previous inversion. They are read through
            ``.value``, i.e. they are expected to be the ``gr.State``
            wrappers this function returned earlier; ignored and replaced
            when ``do_inversion`` is true.
        do_inversion: Whether the inversion has to be (re)computed.

    Returns:
        An 8-tuple matching the event's ``outputs``: edited image, ``wts``,
        ``zs``, the updated ``do_inversion`` flag, the LoRA preview image,
        the LoRA description text, and two ``gr.update(visible=True)``
        values that reveal the preview image and the description.

    Raises:
        gr.Error: If no input image was provided.
    """
    # A cleared/missing image would otherwise only fail deep inside
    # load_image, after a LoRA swap has already been paid for.
    if input_image is None:
        raise gr.Error("Please provide an input image.")

    random_lora_index = load_lora(sdxl_loras)
    lora = sdxl_loras[random_lora_index]

    src_prompt = ""
    skip = 18
    steps = 50
    tar_cfg_scale = 15
    src_cfg_scale = 3.5
    tar_prompt = ""

    if do_inversion:
        image = load_image(input_image, device=device).to(torch.float16)
        with inference_mode():
            x0 = sd_pipe.vae.encode(image).latent_dist.sample() * sd_pipe.vae.config.scaling_factor
        # invert and retrieve noise maps and latent
        zs_tensor, wts_tensor = sd_pipe.invert(
            x0,
            source_prompt=src_prompt,
            # source_prompt_2 = None,
            source_guidance_scale=src_cfg_scale,
            negative_prompt="blurry, ugly, bad quality",
            # negative_prompt_2 = None,
            num_inversion_steps=steps,
            skip_steps=skip,
            # eta = 1.0,
        )
        # Stored as gr.State wrappers: the code below (and the next call,
        # which receives these back) reads the tensors through `.value`.
        wts = gr.State(value=wts_tensor)
        zs = gr.State(value=zs_tensor)
        do_inversion = False

    # Start from the inversion entry at index `skip`, the same index that
    # was passed to invert() as skip_steps.
    latents = wts.value[skip].expand(1, -1, -1, -1)

    editing_args = dict(
        editing_prompt=[lora["trigger_word"]],
        reverse_editing_direction=[False],
        edit_warmup_steps=[2],
        edit_guidance_scale=[lora["edit_guidance_scale"]],
        edit_threshold=[lora["threshold"]],
        edit_momentum_scale=0.3,
        edit_mom_beta=0.6,
        eta=1,
    )

    sega_out = sd_pipe(prompt=tar_prompt, latents=latents, guidance_scale=tar_cfg_scale,
                       # num_images_per_prompt=1,
                       # num_inference_steps=steps,
                       wts=wts.value, zs=zs.value[skip:], **editing_args)

    lora_repo = lora["repo"]
    nc_note = '(non-commercial LoRA, `cc-by-nc`)' if lora['is_nc'] else ''
    lora_desc = f"### LoRA Used To Edit this Image: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {nc_note}"
    lora_image = lora["image"]

    return sega_out.images[0], wts, zs, do_inversion, lora_image, lora_desc, gr.update(visible=True), gr.update(visible=True)
def randomize_seed_fn(seed, randomize_seed):
    """Seed torch's global RNG and return the seed that was used.

    Args:
        seed: Seed to use when ``randomize_seed`` is falsy.
        randomize_seed: When truthy, ``seed`` is ignored and a fresh one is
            drawn uniformly from ``[0, np.iinfo(np.int32).max]``.

    Returns:
        The seed passed to ``torch.manual_seed``.
    """
    max_seed = np.iinfo(np.int32).max
    chosen_seed = random.randint(0, max_seed) if randomize_seed else seed
    torch.manual_seed(chosen_seed)
    return chosen_seed
def randomize():
    """Re-seed every global RNG used by the app with one fresh seed.

    A seed is drawn uniformly from ``[0, np.iinfo(np.int32).max]`` and
    applied to torch (CPU and CUDA), Python's ``random`` and NumPy.

    Returns:
        The seed that was applied, so callers can log it or reproduce a
        run (mirrors ``randomize_seed_fn``, which also returns its seed).
    """
    seed = random.randint(0, np.iinfo(np.int32).max)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    return seed
def crop_image(image):
    """Center-crop an image array to a square and resize it to 1024x1024.

    Args:
        image: Array whose first two axes are height and width. A trailing
            channel axis is optional, so 2-D (grayscale) arrays are
            accepted as well as H x W x C ones.

    Returns:
        A new NumPy array holding the cropped image resized to 1024 x 1024.
    """
    # Only the spatial axes are needed; reading shape[:2] avoids failing on
    # arrays without a channel axis.
    h, w = image.shape[:2]
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    image = image[top:top + side, left:left + side]
    # NOTE(review): `Image` (PIL.Image) is not imported explicitly in this
    # file; presumably it is provided by the wildcard import - verify.
    return np.array(Image.fromarray(image).resize((1024, 1024)))
########
# demo #
########
with gr.Blocks(css="style.css") as demo:

    def reset_do_inversion():
        """Return True; wired to the ``do_inversion`` state when the input image changes."""
        return True

    gr.HTML(
        """<h1><img src="https://i.imgur.com/jpMRW5y.png" alt="LEDITS LoRA Photobooth"></h1>""",
    )

    # Per-session state handed to and returned from `edit`.
    wts = gr.State()
    zs = gr.State()
    reconstruction = gr.State()
    do_inversion = gr.State(value=True)
    gr_sdxl_loras = gr.State(value=sdxl_loras_raw)

    # Webcam input on the left, edited result on the right.
    with gr.Row():
        input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512)
        sega_edited_image = gr.Image(label="LEDITS Edited Image", interactive=False, elem_id="output_image", height=512, width=512)

    # Preview and description of the LoRA that was used; created hidden.
    with gr.Row():
        lora_image = gr.Image(interactive=False, height=128, width=128, visible=False)
        lora_desc = gr.HTML(visible=False)

    with gr.Row():
        run_button = gr.Button("Run again!", visible=True)

    def _edit_event():
        """Build the keyword arguments shared by both events that run `edit`."""
        # lora_image / lora_desc are listed twice on purpose: `edit` returns
        # a value for each and then a gr.update(visible=True) for each.
        return dict(
            fn=edit,
            inputs=[gr_sdxl_loras, input_image, wts, zs, do_inversion],
            outputs=[sega_edited_image, wts, zs, do_inversion,
                     lora_image, lora_desc, lora_image, lora_desc],
        )

    # Re-run the edit on the current image.
    run_button.click(**_edit_event())

    # A new image first resets do_inversion to True, then runs the edit.
    image_changed = input_image.change(
        fn=reset_do_inversion,
        outputs=[do_inversion],
        queue=False,
    )
    image_changed.then(**_edit_event())

demo.queue()
demo.launch(share=True)