# multimodalart's picture
# Update app.py
# d271579
# raw
# history blame
# 17.2 kB
import gradio as gr
import torch
import numpy as np
import requests
import random
from io import BytesIO
# from utils import *
# from constants import *
from pipeline_semantic_stable_diffusion_xl_img2img_ddpm import *
from torch import inference_mode
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL
from diffusers import DDIMScheduler
# from share_btn import community_icon_html, loading_icon_html, share_js
import torch
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline
from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
import json
from safetensors.torch import load_file
# import lora
import copy
import json
import gc
import random
from time import sleep
from pathlib import Path
from uuid import uuid4
# Each process start gets its own uniquely named folder under
# "image_dataset"; save_preferences() writes the image pairs and the
# metadata.jsonl log there.
IMAGE_DATASET_DIR = Path("image_dataset", f"train-{uuid4()}")
IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR.joinpath("metadata.jsonl")
# Load the catalogue of SDXL LoRAs offered by the app.
# The encoding is pinned to UTF-8 so that decoding does not depend on the
# host locale.
with open("sdxl_loras.json", "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalise every catalogue entry to a fixed set of keys. Entries must
# provide image/title/repo/trigger_word/weights/is_compatible/
# edit_guidance_scale/threshold (a missing one raises KeyError at startup);
# the remaining keys are optional and fall back to the defaults below.
sdxl_loras_raw = [
    {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
        "edit_guidance_scale": item["edit_guidance_scale"],
        "threshold": item["threshold"],
    }
    for item in data
]
# Download every LoRA once at startup and keep its weights in memory, keyed
# by repo id, so load_lora() never has to hit the Hub or the disk again.
# Each entry is the catalogue item plus "saved_name" (local file path) and
# "state_dict" (the loaded weights).
state_dicts = {}
for item in sdxl_loras_raw:
    saved_name = hf_hub_download(item["repo"], item["weights"])
    if saved_name.endswith(".safetensors"):
        state_dict = load_file(saved_name)
    else:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only list repos you trust in sdxl_loras.json, or
        # prefer .safetensors weights.
        # map_location="cpu" keeps the cached weights in host memory and
        # lets checkpoints that were saved from a GPU load on CPU-only hosts.
        state_dict = torch.load(saved_name, map_location="cpu")
    state_dicts[item["repo"]] = {
        "saved_name": saved_name,
        "state_dict": state_dict,
        **item,
    }
# Build the SDXL editing pipeline once at import time.
sd_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
sd_pipe = SemanticStableDiffusionXLImg2ImgPipeline_DDPMInversion.from_pretrained(
    sd_model_id,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    vae=vae,
)
# Schedulers are loaded from a Hub repo with from_pretrained(); handing a
# repo id to from_config() is the deprecated spelling of the same thing.
sd_pipe.scheduler = DDIMScheduler.from_pretrained(sd_model_id, subfolder="scheduler")
# Pristine copy taken before the pipeline is moved to `device`; load_lora()
# clones it to undo LoRAs that cannot simply be unfused.
original_pipe = copy.deepcopy(sd_pipe)
sd_pipe.to(device)
# Bookkeeping for whatever LoRA is currently applied to `sd_pipe`:
#   last_lora / last_lora_scale - repo id and scale of the applied LoRA
#   last_fused  - it was applied with fuse_lora() and can be unfused
#   last_merged - it changed the pipeline in a way that is undone by
#                 restoring a copy of `original_pipe`
last_lora = ""
last_lora_scale = None
last_merged = False
last_fused = False


def load_lora(sdxl_loras, random_lora_index, lora_scale = 1.0, progress=gr.Progress(track_tqdm=True)):
    """Apply the LoRA at ``random_lora_index`` to the global ``sd_pipe``.

    The previously applied LoRA is removed first: by unfusing it when it was
    fused, or by replacing ``sd_pipe`` with a fresh copy of ``original_pipe``
    when it was merged. If the requested LoRA (same repo, same scale) is
    already applied, nothing is reloaded.

    Args:
        sdxl_loras: List of LoRA catalogue dicts (keys ``repo``,
            ``is_compatible``, ``is_pivotal``, ``text_embedding_weights``).
        random_lora_index: Index into ``sdxl_loras`` of the LoRA to apply.
        lora_scale: Strength passed to ``fuse_lora`` /
            ``merge_incompatible_lora``.
        progress: Gradio progress tracker; not used in the body.

    Raises:
        KeyError: If the LoRA's repo was not preloaded into ``state_dicts``.
    """
    global last_lora, last_lora_scale, last_merged, last_fused, sd_pipe

    print(random_lora_index)
    lora = sdxl_loras[random_lora_index]
    repo_name = lora["repo"]
    cached = state_dicts[repo_name]
    print(repo_name)

    # Previously `last_lora` was never assigned, so this check always failed
    # and every edit unfused and re-fused the same weights.
    if last_lora == repo_name and last_lora_scale == lora_scale:
        return

    # Forget the current identity before touching the pipeline so a failure
    # half-way through cannot leave a stale "already loaded" marker behind.
    last_lora = ""
    last_lora_scale = None

    if last_merged:
        del sd_pipe
        gc.collect()
        sd_pipe = copy.deepcopy(original_pipe)
        sd_pipe.to(device)
        # The restored pipeline carries no LoRA at all.
        last_merged = False
        last_fused = False
    elif last_fused:
        sd_pipe.unfuse_lora()
        sd_pipe.unload_lora_weights()
        last_fused = False

    if lora["is_compatible"]:
        # Load from a copy so the cached state dict is never modified
        # (presumably load_lora_weights may consume the dict it is given).
        sd_pipe.load_lora_weights(copy.deepcopy(cached["state_dict"]))
        sd_pipe.fuse_lora(lora_scale)
        last_fused = True
    elif lora["is_pivotal"]:
        sd_pipe.load_lora_weights(copy.deepcopy(cached["state_dict"]))
        sd_pipe.fuse_lora(lora_scale)
        last_fused = True
        # Add the textual inversion embeddings from pivotal tuning models.
        text_encoders = [sd_pipe.text_encoder, sd_pipe.text_encoder_2]
        tokenizers = [sd_pipe.tokenizer, sd_pipe.tokenizer_2]
        embedding_path = hf_hub_download(
            repo_id=repo_name,
            filename=lora["text_embedding_weights"],
            repo_type="model",
        )
        embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
        embhandler.load_embeddings(embedding_path)
        # NOTE(review): the source copy had lost its indentation; marking
        # pivotal LoRAs as "merged" is the conservative reading, since the
        # embeddings loaded above are not removed by unfuse_lora(). Confirm.
        last_merged = True
    else:
        # NOTE(review): merge_incompatible_lora is not defined in this file;
        # presumably it comes from the star import - verify.
        merge_incompatible_lora(cached["saved_name"], lora_scale)
        last_fused = False
        last_merged = True

    last_lora = repo_name
    last_lora_scale = lora_scale
    print("DONE MERGING")
## SEGA ##
def shuffle_lora(sdxl_loras, selected_lora=None, chosen_prompt=""):
    """Choose the LoRA and the concept for the next edit and build its caption.

    Args:
        sdxl_loras: List of LoRA catalogue dicts (keys ``repo``, ``title``,
            ``image``).
        selected_lora: Index picked by the user, or ``None`` to draw one at
            random.
        chosen_prompt: Concept typed by the user; when empty a random spooky
            concept is drawn.

    Returns:
        A 7-tuple matching the outputs wired to this handler: a no-op update,
        the LoRA index, the LoRA preview image, the markdown caption, an
        update making a component visible, an update setting ``height=369``,
        and the concept.
    """
    print("selected_lora in shuffle_lora", selected_lora)
    if selected_lora is not None:
        random_lora_index = selected_lora
    else:
        random_lora_index = random.randrange(len(sdxl_loras))
    print("random_lora_index in shuffle_lora: ", random_lora_index)

    if chosen_prompt:
        spooky_concept = chosen_prompt
    else:
        # Every concept starts with a space because edit() appends it directly
        # to the LoRA trigger word. 'spooky skeleton' used to lack that space.
        # NOTE(review): ' spooky wizard' is listed twice, which doubles its
        # odds; kept as found - confirm whether that is intended.
        spooky_concept = random.choice([' spooky witch', ' spooky vampire', ' spooky werewolf', ' spooky ghost', ' spooky wizard', ' spooky pumpkin', ' spooky wizard', ' spooky skeleton'])

    lora = sdxl_loras[random_lora_index]
    lora_repo = lora["repo"]
    lora_title = lora["title"]
    lora_desc = f"""#### LoRA used to edit this image:
### {lora_title}
prompt: {spooky_concept}
by `{lora_repo.split('/')[0]}`
"""
    lora_image = lora["image"]
    return gr.update(), random_lora_index, lora_image, lora_desc, gr.update(visible=True), gr.update(height=369), spooky_concept
def check_if_removed(input_image):
    """Reset the result area when the input photo has been cleared.

    Returns five values, in the order of the outputs wired to this handler.
    While a photo is present they are argument-less components plus ``None``;
    once the photo is gone they hide the first row and the last column, give
    the first column the ``output_column`` class and clear the image and the
    state.
    """
    if input_image is not None:
        return gr.Row(), gr.Column(), gr.Image(), None, gr.Column()

    hidden_row = gr.Row(visible=False)
    plain_column = gr.Column(elem_classes="output_column")
    empty_image = gr.Image(value=None)
    empty_state = gr.State(value=None)
    hidden_column = gr.Column(visible=False)
    return hidden_row, plain_column, empty_image, empty_state, hidden_column
def block_if_removed(input_image):
    """Abort the event chain when there is no input photo.

    Raises:
        gr.Error: If ``input_image`` is ``None``. The failure keeps the
            ``.success()`` step chained after this handler from running.
    """
    if input_image is None:
        # gr.Warning is a notification helper in Gradio, not an exception
        # class, so `raise gr.Warning(...)` ended in a TypeError instead of a
        # clean abort. gr.Error is the exception Gradio shows to the user.
        raise gr.Error("Photo removed. Upload a new one!")
def select_lora(selected_state: gr.SelectData, sdxl_loras, chosen_prompt):
    """Handle a click on the LoRA gallery.

    Args:
        selected_state: Gradio selection event; its ``index`` is used as the
            position of the clicked LoRA in ``sdxl_loras``.
        sdxl_loras: List of LoRA catalogue dicts (keys ``repo``, ``title``,
            ``image``).
        chosen_prompt: Concept typed by the user; when empty a random spooky
            concept is drawn.

    Returns:
        ``(lora_index, lora_image, lora_desc, spooky_concept)``.
    """
    random_lora_index = selected_state.index

    if chosen_prompt:
        spooky_concept = chosen_prompt
    else:
        # Every concept starts with a space because edit() appends it directly
        # to the LoRA trigger word. 'spooky skeleton' used to lack that space.
        # NOTE(review): ' spooky wizard' is listed twice, which doubles its
        # odds; kept as found - confirm whether that is intended.
        spooky_concept = random.choice([' spooky witch', ' spooky vampire', ' spooky werewolf', ' spooky ghost', ' spooky wizard', ' spooky pumpkin', ' spooky wizard', ' spooky skeleton'])

    lora = sdxl_loras[random_lora_index]
    lora_repo = lora["repo"]
    lora_title = lora["title"]
    lora_desc = f"""#### LoRA used to edit this image:
### {lora_title}
prompt: {spooky_concept}
by `{lora_repo.split('/')[0]}`
"""
    lora_image = lora["image"]
    return random_lora_index, lora_image, lora_desc, spooky_concept
def edit(sdxl_loras,
         input_image,
         wts, zs,
         do_inversion,
         random_lora_index,
         spooky_concept,
         progress=gr.Progress(track_tqdm=True)
         ):
    """Invert the photo (when needed) and regenerate it steered by the LoRA.

    Args:
        sdxl_loras: List of LoRA catalogue dicts (keys ``trigger_word``,
            ``edit_guidance_scale``, ``threshold``).
        input_image: The photo to edit; handed to ``load_image``.
        wts, zs: Results of a previous inversion, reused when
            ``do_inversion`` is false.
        do_inversion: Whether ``input_image`` still has to be inverted.
        random_lora_index: Index into ``sdxl_loras`` of the LoRA to apply.
        spooky_concept: Text appended to the LoRA trigger word to form the
            editing prompt.
        progress: Gradio progress tracker; not used in the body.

    Returns:
        A 7-tuple: the edited image, ``wts``, ``zs``, the updated
        ``do_inversion`` flag (``False`` after an inversion), an update
        setting ``height=512``, a ``gr.Column`` with class
        ``output_column_reverse`` and a visible ``gr.Row``.
    """
    print("random_lora_index in edit: ", random_lora_index)
    load_lora(sdxl_loras, random_lora_index)

    src_prompt = ""
    skip = 18
    steps = 50
    tar_cfg_scale = 15
    src_cfg_scale = 3.5
    tar_prompt = ""

    print("Is do_inversion?", do_inversion)
    if do_inversion:
        image = load_image(input_image, device=device).to(torch.float16)
        # NOTE(review): the source copy had lost its indentation; only the
        # VAE encode is assumed to sit inside inference_mode() - confirm.
        with inference_mode():
            x0 = sd_pipe.vae.encode(image).latent_dist.sample() * sd_pipe.vae.config.scaling_factor
        # Invert and retrieve noise maps and latents.
        zs, wts = sd_pipe.invert(
            x0,
            source_prompt=src_prompt,
            source_guidance_scale=src_cfg_scale,
            negative_prompt="blurry, bad quality",
            num_inversion_steps=steps,
            skip_steps=skip,
        )
        do_inversion = False

    latents = wts[skip].expand(1, -1, -1, -1)
    print("spooky concept is: ", spooky_concept)

    lora = sdxl_loras[random_lora_index]
    # Join trigger word and concept with exactly one space. Plain
    # concatenation glued them together whenever the concept had no leading
    # space, which is the case for prompts typed by the user.
    trigger_word = lora["trigger_word"].rstrip()
    concept = spooky_concept.lstrip()
    editing_args = {
        "editing_prompt": [f"{trigger_word} {concept}"],
        "reverse_editing_direction": [False],
        "edit_warmup_steps": [2],
        "edit_guidance_scale": [lora["edit_guidance_scale"]],
        "edit_threshold": [lora["threshold"]],
        "edit_momentum_scale": 0.3,
        "edit_mom_beta": 0.6,
        "eta": 1,
    }

    # Reseed from a fresh random seed before sampling.
    torch.manual_seed(torch.seed())
    sega_out = sd_pipe(prompt=tar_prompt, latents=latents, guidance_scale=tar_cfg_scale,
                       wts=wts, zs=zs[skip:], **editing_args)

    return sega_out.images[0], wts, zs, do_inversion, gr.update(height=512), gr.Column(elem_classes="output_column_reverse"), gr.Row(visible=True)
def randomize_seed_fn(seed, randomize_seed):
    """Seed torch and return the seed that was used.

    When ``randomize_seed`` is truthy a new seed is drawn from
    ``[0, int32 max]``; otherwise ``seed`` is used unchanged.
    """
    chosen_seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else seed
    torch.manual_seed(chosen_seed)
    return chosen_seed
def randomize():
    """Draw one seed in ``[0, int32 max]`` and apply it to torch (CPU and
    CUDA), ``random`` and NumPy, in that order."""
    fresh_seed = random.randint(0, np.iinfo(np.int32).max)
    seeders = (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed)
    for apply_seed in seeders:
        apply_seed(fresh_seed)
def crop_image(image):
    """Center-crop ``image`` to a square and resize it to 1024x1024.

    Args:
        image: Array whose first two axes are height and width. An optional
            third axis is carried through, so 2-D arrays work as well;
            previously they failed when the shape was unpacked into three
            values.

    Returns:
        The cropped and resized image as a NumPy array.
    """
    # NOTE(review): `Image` is not imported explicitly in this file;
    # presumably it arrives through the star import - verify.
    height, width = image.shape[:2]
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    square = image[top:top + side, left:left + side]
    return np.array(Image.fromarray(square).resize((1024, 1024)))
def save_preferences(sdxl_loras, selected_lora, input_image, result_image):
    """Store one input/result image pair and log it in the metadata file.

    Both images are written as PNG files into ``IMAGE_DATASET_DIR`` under a
    shared random id, and one JSON line naming the LoRA repo and the two
    files is appended to ``IMAGE_JSONL_PATH``.

    Args:
        sdxl_loras: List of LoRA catalogue dicts (key ``repo``).
        selected_lora: Index into ``sdxl_loras`` of the LoRA that was used.
        input_image: The source photo, passed to ``Image.fromarray``.
        result_image: The edited photo, passed to ``Image.fromarray``.
    """
    # This handler is chained with .then(), so it can be invoked with a
    # missing image; there is nothing to record in that case.
    if input_image is None or result_image is None:
        return

    lora_id = sdxl_loras[selected_lora]["repo"]
    sample_id = uuid4()
    input_image_path = IMAGE_DATASET_DIR / f"{sample_id}-input.png"
    output_image_path = IMAGE_DATASET_DIR / f"{sample_id}-output.png"

    Image.fromarray(input_image).save(input_image_path)
    Image.fromarray(result_image).save(output_image_path)

    record = {
        "selected_lora": lora_id,
        "input_image": input_image_path.name,
        "result_image": output_image_path.name,
    }
    with IMAGE_JSONL_PATH.open("a") as f:
        json.dump(record, f)
        f.write("\n")
########
# demo #
########
# Gradio UI: a webcam photo is inverted on SDXL and re-rendered with a LoRA.
# NOTE(review): this block was recovered from a copy without indentation, so
# the layout nesting below (which component sits in which Box/Row/Column) is
# a reconstruction - confirm against the deployed app.
with gr.Blocks(css="style.css") as demo:
    # Used in the input_image.change chain to set the do_inversion state back
    # to True.
    def reset_do_inversion():
        return True
    gr.HTML("""<img style="margin: 0 auto; width: 220px;" src="https://i.imgur.com/zDcvSbg.png" alt="LEDITS SDXL LoRA Photobooth">""")
    gr.HTML("""<img style="margin: 0 auto; width: 300px; margin-bottom: .5em; margin-top:-35px" src="https://i.imgur.com/hvZOBzY.png" alt="LEDITS SDXL LoRA Photobooth">""")
    with gr.Box(elem_id="total_box"):
        gr.HTML(
            """
<h3>Smile, take a pic 📷✨ and <code>it'll be inverted on SDXL and a random SDXL LoRA will be applied</code></h3>
""",
        )
        # State values passed between the event handlers below.
        wts = gr.State()
        zs = gr.State()
        reconstruction = gr.State()
        do_inversion = gr.State(value=True)
        gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
        gr_lora_index = gr.State()
        gr_picked_lora = gr.State()
        spooky_concept = gr.State()
        with gr.Row():
            input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512, elem_id="input_image")
            with gr.Column(elem_classes="output_column") as output_column:
                # Preview and caption of the LoRA in use; hidden at start.
                with gr.Row(visible=False) as loaded_lora:
                    lora_image = gr.Image(interactive=False, height=128, width=128, elem_id="lora_image", show_label=False, show_download_button=False)
                    lora_desc = gr.Markdown()
                sega_edited_image = gr.Image(label=f"LEDITS Edited Image", interactive=False, elem_id="output_image", height=512, width=512)
        # Controls that are hidden at start; edit() returns a visible
        # component for this slot.
        with gr.Column(visible=False) as buttons_area:
            with gr.Row(elem_id="buttons_area"):
                print_button = gr.HTML('<button onclick="window.print()" class="lg secondary svelte-cmf5ev" id="print" style="width: 100%;margin-bottom:0">Save PDF/Print 🖨️</button>')
                run_button = gr.Button("Regenerate with the same picture 🖼️🎲", elem_id="run_again")
            with gr.Accordion("Tired of randomizing? Pick your prompt and LoRA", open=False, elem_id="pick", ):
                choose_prompt = gr.Textbox(label="Spooky Prompt", value="")
                choose_gallery = gr.Gallery(
                    value=[(item["image"], item["title"]) for item in sdxl_loras_raw],
                    allow_preview=False,
                    columns=6,
                    elem_id="gallery",
                    show_share_button=False
                )
    # Clicking a gallery tile stores the picked LoRA and refreshes its
    # preview, caption and concept.
    choose_gallery.select(
        fn=select_lora,
        inputs=[gr_sdxl_loras, choose_prompt],
        outputs=[gr_picked_lora, lora_image, lora_desc, spooky_concept],
        queue=False
    )
    # "Regenerate": choose a LoRA and concept, run the edit, then log the
    # image pair. sega_edited_image is listed twice in the outputs because the
    # handlers return two separate updates for it.
    run_button.click(
        fn=shuffle_lora,
        inputs=[gr_sdxl_loras, gr_picked_lora, choose_prompt],
        outputs=[sega_edited_image, gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image, spooky_concept],
        queue=False
    ).then(
        fn=edit,
        inputs=[gr_sdxl_loras,
                input_image,
                wts, zs,
                do_inversion,
                gr_lora_index, spooky_concept
                ],
        outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column, buttons_area]
    ).then(
        fn=save_preferences,
        inputs=[gr_sdxl_loras, gr_lora_index, input_image, sega_edited_image],
        queue=False
    )
    # New or removed photo: reset the result area, stop if the photo is gone,
    # otherwise request a fresh inversion and run the same
    # shuffle -> edit -> save sequence as the button above.
    input_image.change(
        fn = check_if_removed,
        inputs = [input_image],
        outputs = [loaded_lora, output_column, sega_edited_image, gr_picked_lora, buttons_area],
        queue=False,
        show_progress=False
    ).then(
        fn = block_if_removed,
        inputs = [input_image],
        queue=False,
        show_progress=False
    ).success(
        fn = reset_do_inversion,
        outputs = [do_inversion],
        queue = False).then(
        fn=shuffle_lora,
        inputs=[gr_sdxl_loras, gr_picked_lora, choose_prompt],
        outputs=[sega_edited_image, gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image, spooky_concept],
        queue=False
    ).then(
        fn=edit,
        inputs=[gr_sdxl_loras,
                input_image,
                wts, zs,
                do_inversion,
                gr_lora_index,
                spooky_concept
                ],
        outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column, buttons_area]
    ).then(
        fn=save_preferences,
        inputs=[gr_sdxl_loras, gr_lora_index, input_image, sega_edited_image],
        queue=False
    )
    # On page load, reload once with ?__theme=dark unless that query
    # parameter is already set to dark.
    demo.load(None,
              _js="""async () => {
let gradioURL = new URL(window.self.location.href);
let params = new URLSearchParams(gradioURL.search);
if (!params.has('__theme') || params.get('__theme') !== 'dark') {
params.set('__theme', 'dark');
gradioURL.search = params.toString();
window.self.location.replace(gradioURL.toString());
}
}""")
# Enable the request queue, then start the app.
demo.queue()
demo.launch()