Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import numpy as np | |
import requests | |
import random | |
from io import BytesIO | |
# from utils import * | |
# from constants import * | |
from pipeline_semantic_stable_diffusion_xl_img2img_ddpm import * | |
from torch import inference_mode | |
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL | |
from diffusers import DDIMScheduler | |
# from share_btn import community_icon_html, loading_icon_html, share_js | |
import torch | |
from huggingface_hub import hf_hub_download | |
from diffusers import DiffusionPipeline | |
from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler | |
import json | |
from safetensors.torch import load_file | |
# import lora | |
import copy | |
import json | |
import gc | |
import random | |
from time import sleep | |
from pathlib import Path | |
from uuid import uuid4 | |
# Every process start gets its own uniquely named folder below
# ``image_dataset``; input/output image pairs and their metadata are
# written there by ``save_preferences``.
IMAGE_DATASET_DIR = Path("image_dataset", f"train-{uuid4()}")
IMAGE_DATASET_DIR.mkdir(exist_ok=True, parents=True)
# One JSON object per line, appended for every saved example.
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR.joinpath("metadata.jsonl")
# Load the LoRA catalogue shipped next to the app.
with open("sdxl_loras.json", "r") as file:
    data = json.load(file)

# Normalise every catalogue entry to a fixed set of keys. Keys read with
# ``entry[...]`` are mandatory (a missing one raises KeyError at start-up);
# keys read with ``entry.get`` are optional and fall back to a default.
sdxl_loras_raw = []
for entry in data:
    sdxl_loras_raw.append(
        dict(
            image=entry["image"],
            title=entry["title"],
            repo=entry["repo"],
            trigger_word=entry["trigger_word"],
            weights=entry["weights"],
            is_compatible=entry["is_compatible"],
            is_pivotal=entry.get("is_pivotal", False),
            text_embedding_weights=entry.get("text_embedding_weights", None),
            is_nc=entry.get("is_nc", False),
            edit_guidance_scale=entry["edit_guidance_scale"],
            threshold=entry["threshold"],
        )
    )
# Download every LoRA once at start-up and keep its weights in memory,
# keyed by repository id. Each value holds the local file path, the loaded
# state dict and all catalogue fields of that LoRA.
state_dicts = {}
for item in sdxl_loras_raw:
    saved_name = hf_hub_download(item["repo"], item["weights"])
    if saved_name.endswith('.safetensors'):
        state_dict = load_file(saved_name)
    else:
        # NOTE(security): torch.load unpickles the downloaded file, which can
        # execute arbitrary code. Only list trusted repositories in
        # sdxl_loras.json, or prefer .safetensors weights.
        state_dict = torch.load(saved_name)
    # Catalogue fields are merged last, so on a key clash they win over
    # "saved_name"/"state_dict" (same precedence as ``dict | item``).
    state_dicts[item["repo"]] = {
        "saved_name": saved_name,
        "state_dict": state_dict,
        **item,
    }
# Base SDXL checkpoint that every LoRA is applied on top of.
sd_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The pipeline is loaded in float16 together with a separately downloaded VAE.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
sd_pipe = SemanticStableDiffusionXLImg2ImgPipeline_DDPMInversion.from_pretrained(
    sd_model_id,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    vae=vae,
)
# Load the scheduler with ``from_pretrained``: passing a model id to
# ``from_config`` is deprecated in diffusers (``from_config`` expects a
# configuration dict).
sd_pipe.scheduler = DDIMScheduler.from_pretrained(sd_model_id, subfolder="scheduler")

# Keep an untouched copy (taken before moving to ``device``) so that
# ``load_lora`` can restore a clean pipeline after a LoRA was merged in.
original_pipe = copy.deepcopy(sd_pipe)
sd_pipe.to(device)

# Bookkeeping shared with ``load_lora``: which LoRA repo is currently applied
# to ``sd_pipe`` and how it was applied.
last_lora = ""
last_merged = False
last_fused = False
def load_lora(sdxl_loras, random_lora_index, lora_scale=1.0, progress=gr.Progress(track_tqdm=True)):
    """Make sure the LoRA at ``random_lora_index`` is applied to ``sd_pipe``.

    Args:
        sdxl_loras: List of LoRA catalogue entries (dicts with at least the
            keys "repo", "is_compatible", "is_pivotal" and
            "text_embedding_weights").
        random_lora_index: Index into ``sdxl_loras`` of the LoRA to apply.
        lora_scale: Scale passed to ``fuse_lora`` / ``merge_incompatible_lora``.
        progress: Gradio progress tracker; not used directly in the body.

    Side effects:
        Rebinds/mutates the module globals ``sd_pipe``, ``last_lora``,
        ``last_merged`` and ``last_fused``.

    If the requested repo equals ``last_lora`` nothing is reloaded. The
    original code never assigned ``last_lora``, so every call unloaded and
    reloaded the weights even when the same LoRA was requested again.
    """
    global last_lora, last_merged, last_fused, sd_pipe
    print(random_lora_index)
    lora = sdxl_loras[random_lora_index]
    repo_name = lora["repo"]
    print(repo_name)

    if last_lora != repo_name:
        cached = state_dicts[repo_name]
        full_path_lora = cached["saved_name"]
        # Work on a copy so the cached weights are never modified by loading.
        loaded_state_dict = copy.deepcopy(cached["state_dict"])

        # First undo whatever the previously applied LoRA did to the pipeline.
        if last_merged:
            # Merged weights / added embeddings are undone by starting over
            # from the pristine copy of the pipeline.
            del sd_pipe
            gc.collect()
            sd_pipe = copy.deepcopy(original_pipe)
            sd_pipe.to(device)
            # The fresh copy has nothing merged or fused any more.
            last_merged = False
            last_fused = False
        elif last_fused:
            sd_pipe.unfuse_lora()
            sd_pipe.unload_lora_weights()
            last_fused = False

        if lora["is_compatible"]:
            sd_pipe.load_lora_weights(loaded_state_dict)
            sd_pipe.fuse_lora(lora_scale)
            last_fused = True
        else:
            if lora["is_pivotal"]:
                sd_pipe.load_lora_weights(loaded_state_dict)
                sd_pipe.fuse_lora(lora_scale)
                last_fused = True

                # Add the textual inversion embeddings from pivotal tuning models
                text_embedding_name = lora["text_embedding_weights"]
                text_encoders = [sd_pipe.text_encoder, sd_pipe.text_encoder_2]
                tokenizers = [sd_pipe.tokenizer, sd_pipe.tokenizer_2]
                embedding_path = hf_hub_download(
                    repo_id=repo_name, filename=text_embedding_name, repo_type="model"
                )
                embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
                embhandler.load_embeddings(embedding_path)
            else:
                # NOTE(review): merge_incompatible_lora is not defined in the
                # visible part of this file; presumably it comes from the star
                # import at the top - confirm.
                merge_incompatible_lora(full_path_lora, lora_scale)
                last_fused = False
            # Either branch above changed the pipeline in a way that
            # unfuse_lora alone does not undo, so request a full restore on
            # the next switch.
            last_merged = True

        # Remember what is loaded only after everything succeeded, so a
        # failed load is retried on the next call.
        last_lora = repo_name
        print("DONE MERGING")
## SEGA ## | |
def shuffle_lora(sdxl_loras, selected_lora=None, chosen_prompt=""):
    """Choose the LoRA and prompt for the next edit and build its description.

    Args:
        sdxl_loras: List of LoRA catalogue entries (dicts with "repo",
            "title" and "image").
        selected_lora: Index of a LoRA picked by the user, or None to draw
            one at random.
        chosen_prompt: Prompt typed by the user; when empty a random spooky
            concept is used instead.

    Returns:
        A 7-tuple matching the outputs wired to this callback: a no-op
        update, the LoRA index, the LoRA preview image, a Markdown
        description, an update making the LoRA row visible, an update
        setting a height of 369, and the prompt/concept that was chosen.
    """
    print("selected_lora in shuffle_lora", selected_lora)
    if selected_lora is not None:
        random_lora_index = selected_lora
    else:
        random_lora_index = random.randrange(0, len(sdxl_loras), 1)
    print("random_lora_index in shuffle_lora: ", random_lora_index)

    if chosen_prompt:
        spooky_concept = chosen_prompt
    else:
        # Every concept starts with a space because it is appended directly
        # to the LoRA trigger word. The original list had 'spooky skeleton'
        # without that space and listed ' spooky wizard' twice.
        spooky_concept = random.choice([
            ' spooky witch',
            ' spooky vampire',
            ' spooky werewolf',
            ' spooky ghost',
            ' spooky wizard',
            ' spooky pumpkin',
            ' spooky skeleton',
        ])

    lora = sdxl_loras[random_lora_index]
    lora_repo = lora["repo"]
    lora_title = lora["title"]
    # Only the owner part ("owner/name" -> "owner") of the repo id is shown.
    lora_owner = lora_repo.split('/')[0]
    lora_desc = (
        "#### LoRA used to edit this image:\n"
        f"### {lora_title}\n"
        f"prompt: {spooky_concept}\n"
        f"by `{lora_owner}`\n"
    )
    lora_image = lora["image"]
    return gr.update(), random_lora_index, lora_image, lora_desc, gr.update(visible=True), gr.update(height=369), spooky_concept
def check_if_removed(input_image):
    """Reset the result area when the input photo has been removed.

    Returns five values, one per output wired to this callback (LoRA row,
    output column, edited image, picked-LoRA state, buttons column). When a
    photo is present, default components/None are returned; when it is
    None, the row and buttons column are hidden, the output column gets the
    "output_column" class, and the edited image is cleared.
    """
    if input_image is not None:
        return gr.Row(), gr.Column(), gr.Image(), None, gr.Column()

    hidden_lora_row = gr.Row(visible=False)
    plain_output_column = gr.Column(elem_classes="output_column")
    cleared_image = gr.Image(value=None)
    cleared_pick = gr.State(value=None)
    hidden_buttons = gr.Column(visible=False)
    return hidden_lora_row, plain_output_column, cleared_image, cleared_pick, hidden_buttons
def block_if_removed(input_image):
    """Stop the event chain when the input photo has been removed.

    Args:
        input_image: Current value of the input image component, or None
            when the user removed the photo.

    Raises:
        gr.Error: If ``input_image`` is None, so that follow-up steps
            registered with ``.success(...)`` are skipped.
    """
    if input_image is None:
        # gr.Warning is a helper that only displays a message and returns
        # None, so ``raise gr.Warning(...)`` fails with a TypeError.
        # gr.Error is the exception meant to be raised from event handlers.
        raise gr.Error("Photo removed. Upload a new one!")
def select_lora(selected_state: gr.SelectData, sdxl_loras, chosen_prompt):
    """Handle a click on a LoRA in the gallery.

    Args:
        selected_state: Gradio selection event; its ``index`` is used as the
            index into ``sdxl_loras``.
        sdxl_loras: List of LoRA catalogue entries (dicts with "repo",
            "title" and "image").
        chosen_prompt: Prompt typed by the user; when empty a random spooky
            concept is used instead.

    Returns:
        Tuple ``(lora_index, lora_image, lora_description_markdown,
        spooky_concept)``.
    """
    random_lora_index = selected_state.index

    if chosen_prompt:
        spooky_concept = chosen_prompt
    else:
        # Every concept starts with a space because it is appended directly
        # to the LoRA trigger word. The original list had 'spooky skeleton'
        # without that space and listed ' spooky wizard' twice.
        spooky_concept = random.choice([
            ' spooky witch',
            ' spooky vampire',
            ' spooky werewolf',
            ' spooky ghost',
            ' spooky wizard',
            ' spooky pumpkin',
            ' spooky skeleton',
        ])

    lora = sdxl_loras[random_lora_index]
    lora_repo = lora["repo"]
    lora_title = lora["title"]
    # Only the owner part ("owner/name" -> "owner") of the repo id is shown.
    lora_owner = lora_repo.split('/')[0]
    lora_desc = (
        "#### LoRA used to edit this image:\n"
        f"### {lora_title}\n"
        f"prompt: {spooky_concept}\n"
        f"by `{lora_owner}`\n"
    )
    lora_image = lora["image"]
    return random_lora_index, lora_image, lora_desc, spooky_concept
def edit(sdxl_loras,
         input_image,
         wts, zs,
         do_inversion,
         random_lora_index,
         spooky_concept,
         progress=gr.Progress(track_tqdm=True)
         ):
    """Apply the selected LoRA to the input photo with a semantic edit.

    Args:
        sdxl_loras: List of LoRA catalogue entries (dicts with
            "trigger_word", "edit_guidance_scale" and "threshold").
        input_image: Photo to edit; only read when ``do_inversion`` is true.
        wts, zs: Results of a previous ``sd_pipe.invert`` call, reused when
            ``do_inversion`` is false.
        do_inversion: Whether the photo must be inverted (again) first.
        random_lora_index: Index into ``sdxl_loras`` of the LoRA to use.
        spooky_concept: Text appended to the LoRA trigger word to form the
            editing prompt.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        A 7-tuple matching the outputs wired to this callback: the edited
        image, ``wts``, ``zs``, the updated ``do_inversion`` flag (False
        after an inversion ran), an update setting a height of 512, a
        column with class "output_column_reverse" and a visible row.
    """
    print("random_lora_index in edit: ", random_lora_index)
    load_lora(sdxl_loras, random_lora_index)

    src_prompt = ""
    skip = 18
    steps = 50
    tar_cfg_scale = 15
    src_cfg_scale = 3.5
    tar_prompt = ""

    print("Is do_inversion?", do_inversion)
    if do_inversion:
        image = load_image(input_image, device=device).to(torch.float16)
        with inference_mode():
            x0 = sd_pipe.vae.encode(image).latent_dist.sample() * sd_pipe.vae.config.scaling_factor
        # invert and retrieve noise maps and latent
        zs, wts = sd_pipe.invert(
            x0,
            source_prompt=src_prompt,
            source_guidance_scale=src_cfg_scale,
            negative_prompt="blurry, bad quality",
            num_inversion_steps=steps,
            skip_steps=skip,
        )
        # The stored results stay valid until a new photo arrives.
        do_inversion = False

    # Start from the entry at index ``skip``, expanded to a leading
    # dimension of size 1.
    latents = wts[skip].expand(1, -1, -1, -1)

    print("spooky concept is: ", spooky_concept)
    lora = sdxl_loras[random_lora_index]
    trigger_word = lora["trigger_word"]
    # The built-in concepts bring their own leading space, but a prompt typed
    # by the user usually does not; without a separator it would fuse with
    # the trigger word ("<trigger>spooky cat").
    needs_separator = (
        bool(trigger_word)
        and bool(spooky_concept)
        and not trigger_word[-1].isspace()
        and not spooky_concept[0].isspace()
    )
    separator = " " if needs_separator else ""

    editing_args = dict(
        editing_prompt=[trigger_word + separator + spooky_concept],
        reverse_editing_direction=[False],
        edit_warmup_steps=[2],
        edit_guidance_scale=[lora["edit_guidance_scale"]],
        edit_threshold=[lora["threshold"]],
        edit_momentum_scale=0.3,
        edit_mom_beta=0.6,
        eta=1,
    )

    # Re-seed torch with a fresh non-deterministic seed for this generation.
    torch.manual_seed(torch.seed())
    sega_out = sd_pipe(prompt=tar_prompt, latents=latents, guidance_scale=tar_cfg_scale,
                       wts=wts, zs=zs[skip:], **editing_args)

    return sega_out.images[0], wts, zs, do_inversion, gr.update(height=512), gr.Column(elem_classes="output_column_reverse"), gr.Row(visible=True)
def randomize_seed_fn(seed, randomize_seed):
    """Seed torch and return the seed that was used.

    When ``randomize_seed`` is truthy a new seed is drawn uniformly from
    ``[0, int32 max]``; otherwise ``seed`` is used unchanged.
    """
    chosen_seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else seed
    torch.manual_seed(chosen_seed)
    return chosen_seed
def randomize():
    """Draw one random seed and apply it to torch (CPU and CUDA), random and numpy."""
    seed = random.randint(0, np.iinfo(np.int32).max)
    # Same order as before: torch, torch.cuda, random, numpy.
    for apply_seed in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
        apply_seed(seed)
def crop_image(image):
    """Center-crop ``image`` to a square and resize it to 1024x1024.

    Args:
        image: Array indexed as ``image[row, column, ...]``. Both
            height x width x channels and plain height x width arrays are
            accepted (the original unpacked exactly three dimensions and
            failed on 2-D input).

    Returns:
        The cropped and resized image as a numpy array.

    NOTE(review): ``Image`` is not imported explicitly in the visible part
    of this file; presumably PIL's ``Image`` arrives through the star
    import at the top - confirm.
    """
    h, w = image.shape[:2]
    if h < w:
        # Wider than tall: keep the middle ``h`` columns.
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        # Taller than wide: keep the middle ``w`` rows.
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((1024, 1024)))
    return image
def save_preferences(sdxl_loras, selected_lora, input_image, result_image):
    """Store one input/result image pair and log it in the metadata file.

    Both images are written as PNG files into ``IMAGE_DATASET_DIR`` under a
    shared random id, and one JSON line naming the LoRA repo and the two
    file names is appended to ``IMAGE_JSONL_PATH``.
    """
    lora_id = sdxl_loras[selected_lora]["repo"]
    record_id = uuid4()
    input_image_path = IMAGE_DATASET_DIR / f"{record_id}-input.png"
    output_image_path = IMAGE_DATASET_DIR / f"{record_id}-output.png"

    for pixels, destination in ((input_image, input_image_path), (result_image, output_image_path)):
        Image.fromarray(pixels).save(destination)

    record = {
        "selected_lora": lora_id,
        "input_image": input_image_path.name,
        "result_image": output_image_path.name,
    }
    with IMAGE_JSONL_PATH.open("a") as f:
        f.write(json.dumps(record) + "\n")
########
# demo # | |
######## | |
with gr.Blocks(css="style.css") as demo:
    # NOTE(review): the source this was taken from had lost all indentation.
    # The nesting of the layout containers below (and any leading whitespace
    # inside the multi-line HTML/JS strings) is a reconstruction - confirm
    # against the deployed app.

    def reset_do_inversion():
        """Return True so the next edit inverts the newly supplied photo."""
        return True

    # The three steps below are shared by the "regenerate" button and by the
    # photo-change event. Each helper returns fresh keyword arguments for
    # one event registration.
    def _shuffle_step():
        """Keyword arguments for the step that picks the LoRA and prompt."""
        return dict(
            fn=shuffle_lora,
            inputs=[gr_sdxl_loras, gr_picked_lora, choose_prompt],
            outputs=[sega_edited_image, gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image, spooky_concept],
            queue=False,
        )

    def _edit_step():
        """Keyword arguments for the step that runs the actual edit."""
        return dict(
            fn=edit,
            inputs=[gr_sdxl_loras, input_image, wts, zs, do_inversion, gr_lora_index, spooky_concept],
            outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column, buttons_area],
        )

    def _save_step():
        """Keyword arguments for the step that stores the image pair."""
        return dict(
            fn=save_preferences,
            inputs=[gr_sdxl_loras, gr_lora_index, input_image, sega_edited_image],
            queue=False,
        )

    gr.HTML("""<img style="margin: 0 auto; width: 220px;" src="https://i.imgur.com/zDcvSbg.png" alt="LEDITS SDXL LoRA Photobooth">""")
    gr.HTML("""<img style="margin: 0 auto; width: 300px; margin-bottom: .5em; margin-top:-35px" src="https://i.imgur.com/hvZOBzY.png" alt="LEDITS SDXL LoRA Photobooth">""")
    with gr.Box(elem_id="total_box"):
        gr.HTML(
            """
<h3>Smile, take a pic 📷✨ and <code>it'll be inverted on SDXL and a random SDXL LoRA will be applied</code></h3>
""",
        )

        # Per-session state shared between the callbacks.
        wts = gr.State()
        zs = gr.State()
        reconstruction = gr.State()
        do_inversion = gr.State(value=True)
        gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
        gr_lora_index = gr.State()
        gr_picked_lora = gr.State()
        spooky_concept = gr.State()

        with gr.Row():
            input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512, elem_id="input_image")
            with gr.Column(elem_classes="output_column") as output_column:
                with gr.Row(visible=False) as loaded_lora:
                    lora_image = gr.Image(interactive=False, height=128, width=128, elem_id="lora_image", show_label=False, show_download_button=False)
                    lora_desc = gr.Markdown()
                sega_edited_image = gr.Image(label="LEDITS Edited Image", interactive=False, elem_id="output_image", height=512, width=512)

        with gr.Column(visible=False) as buttons_area:
            with gr.Row(elem_id="buttons_area"):
                print_button = gr.HTML('<button onclick="window.print()" class="lg secondary svelte-cmf5ev" id="print" style="width: 100%;margin-bottom:0">Save PDF/Print 🖨️</button>')
                run_button = gr.Button("Regenerate with the same picture 🖼️🎲", elem_id="run_again")

        with gr.Accordion("Tired of randomizing? Pick your prompt and LoRA", open=False, elem_id="pick", ):
            choose_prompt = gr.Textbox(label="Spooky Prompt", value="")
            choose_gallery = gr.Gallery(
                value=[(item["image"], item["title"]) for item in sdxl_loras_raw],
                allow_preview=False,
                columns=6,
                elem_id="gallery",
                show_share_button=False
            )

    # Clicking a gallery tile stores the picked LoRA and previews it.
    choose_gallery.select(
        fn=select_lora,
        inputs=[gr_sdxl_loras, choose_prompt],
        outputs=[gr_picked_lora, lora_image, lora_desc, spooky_concept],
        queue=False
    )

    # Regenerate with the current photo: pick LoRA/prompt -> edit -> save.
    (
        run_button.click(**_shuffle_step())
        .then(**_edit_step())
        .then(**_save_step())
    )

    # New or removed photo: reset the result area, stop if the photo was
    # removed, otherwise force a new inversion and run the same three steps.
    (
        input_image.change(
            fn=check_if_removed,
            inputs=[input_image],
            outputs=[loaded_lora, output_column, sega_edited_image, gr_picked_lora, buttons_area],
            queue=False,
            show_progress=False
        )
        .then(
            fn=block_if_removed,
            inputs=[input_image],
            queue=False,
            show_progress=False
        )
        .success(
            fn=reset_do_inversion,
            outputs=[do_inversion],
            queue=False
        )
        .then(**_shuffle_step())
        .then(**_edit_step())
        .then(**_save_step())
    )

    # On page load, reload once with ?__theme=dark unless it is already set.
    demo.load(None,
              _js="""async () => {
let gradioURL = new URL(window.self.location.href);
let params = new URLSearchParams(gradioURL.search);
if (!params.has('__theme') || params.get('__theme') !== 'dark') {
params.set('__theme', 'dark');
gradioURL.search = params.toString();
window.self.location.replace(gradioURL.toString());
}
}""")

demo.queue()
demo.launch()