"""LEDITS SDXL LoRA Photobooth (Gradio app).

A webcam picture is inverted with the SDXL pipeline's ``invert`` method and then
regenerated with editing guidance toward a LoRA's trigger word plus a random
"spooky" concept, with that LoRA applied to the pipeline. Every input/result
pair is saved to a local image dataset (see ``save_preferences``).
"""

# Project-local pipeline module. Imported first so that the explicit imports
# below take precedence over anything it re-exports under the same name.
# NOTE(review): this star import is presumably where
# SemanticStableDiffusionXLImg2ImgPipeline_DDPMInversion, load_image, Image and
# merge_incompatible_lora come from (none are imported explicitly) — confirm.
from pipeline_semantic_stable_diffusion_xl_img2img_ddpm import *  # noqa: F403

import copy
import gc
import json
import random
from io import BytesIO
from pathlib import Path
from time import sleep
from uuid import uuid4

import gradio as gr
import numpy as np
import requests
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torch import inference_mode

from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler

# from utils import *
# from constants import *
# import lora
# from share_btn import community_icon_html, loading_icon_html, share_js

# Each app run writes its input/result pairs into a fresh, uniquely named folder.
IMAGE_DATASET_DIR = Path("image_dataset") / f"train-{uuid4()}"
IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl"

# Halloween concepts appended (note the leading space) to a LoRA's trigger word
# to build the editing prompt.
SPOOKY_CONCEPTS = [
    " spooky witch",
    " spooky vampire",
    " spooky werewolf",
    " spooky ghost",
    " spooky wizard",
    " spooky pumpkin",
]

with open("sdxl_loras.json", "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalized LoRA catalogue: required keys raise KeyError if missing, optional
# keys fall back to the defaults given here.
sdxl_loras_raw = [
    {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
        "edit_guidance_scale": item["edit_guidance_scale"],
        "threshold": item["threshold"],
    }
    for item in data
]

# Download every LoRA once at startup and keep its state dict in CPU memory, keyed
# by repo id, so load_lora() can switch LoRAs without downloading again.
state_dicts = {}
for item in sdxl_loras_raw:
    saved_name = hf_hub_download(item["repo"], item["weights"])
    if saved_name.endswith(".safetensors"):
        state_dict = load_file(saved_name)
    else:
        # NOTE(security): torch.load unpickles arbitrary objects. Only point
        # sdxl_loras.json at repositories you trust.
        state_dict = torch.load(saved_name, map_location="cpu")
    state_dicts[item["repo"]] = {"saved_name": saved_name, "state_dict": state_dict} | item

sd_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
sd_pipe = SemanticStableDiffusionXLImg2ImgPipeline_DDPMInversion.from_pretrained(
    sd_model_id,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    vae=vae,
)
# Passing a repo id to from_config() is deprecated in diffusers; from_pretrained
# reads the same config from the repo's "scheduler" subfolder.
sd_pipe.scheduler = DDIMScheduler.from_pretrained(sd_model_id, subfolder="scheduler")
# Untouched copy, taken before .to(device), used to undo merged LoRAs.
original_pipe = copy.deepcopy(sd_pipe)
sd_pipe.to(device)

last_lora = ""  # repo id of the LoRA currently applied to sd_pipe ("" = none known)
last_lora_scale = None  # lora_scale that LoRA was applied with
last_merged = False  # sd_pipe weights were changed by merge_incompatible_lora
last_fused = False  # sd_pipe has a LoRA fused via fuse_lora


def load_lora(sdxl_loras, random_lora_index, lora_scale=1.0, progress=gr.Progress(track_tqdm=True)):
    """Apply the LoRA ``sdxl_loras[random_lora_index]`` to the global ``sd_pipe``.

    The LoRA currently on the pipeline is removed first. A merged LoRA is removed
    by restoring a copy of ``original_pipe``; a fused one is removed with
    ``unfuse_lora``. The new LoRA is then applied in one of three ways:

    * ``is_compatible``: ``load_lora_weights`` followed by ``fuse_lora``.
    * ``is_pivotal``: the same, plus the repo's textual-inversion embeddings.
    * otherwise: ``merge_incompatible_lora`` on the downloaded weights file.

    Nothing is done if the same LoRA is already applied at the same scale.

    Args:
        sdxl_loras: LoRA entries shaped like ``sdxl_loras_raw``.
        random_lora_index: index of the LoRA to apply.
        lora_scale: scale passed to ``fuse_lora`` / ``merge_incompatible_lora``.
        progress: Gradio progress tracker (tracks tqdm bars; not used directly).
    """
    global last_lora, last_lora_scale, last_merged, last_fused, sd_pipe

    print(random_lora_index)
    entry = sdxl_loras[random_lora_index]
    repo_name = entry["repo"]
    print(repo_name)

    # last_lora used to be declared global but never assigned. That made this check
    # useless, so the active LoRA was unloaded and reloaded on every request.
    if repo_name == last_lora and lora_scale == last_lora_scale:
        return

    if last_merged:
        # Merged weights cannot be unloaded, so swap in a fresh copy of the base pipeline.
        del sd_pipe
        gc.collect()
        sd_pipe = copy.deepcopy(original_pipe)
        sd_pipe.to(device)
        # The fresh copy holds no LoRA. Without this reset, every later switch would
        # deep-copy the whole pipeline again.
        last_merged = False
    elif last_fused:
        sd_pipe.unfuse_lora()
    sd_pipe.unload_lora_weights()
    last_fused = False
    # Claim no active LoRA until the new one is fully applied, so that a failed
    # load is retried on the next request.
    last_lora = ""
    last_lora_scale = None

    if entry["is_compatible"] or entry["is_pivotal"]:
        # Pass a copy so the cached state dict stays reusable, in case
        # load_lora_weights modifies the dict it is given.
        sd_pipe.load_lora_weights(copy.deepcopy(state_dicts[repo_name]["state_dict"]))
        sd_pipe.fuse_lora(lora_scale)
        last_fused = True
        if not entry["is_compatible"]:
            # Pivotal-tuning LoRAs also need their textual-inversion embeddings,
            # which are loaded into both text encoders / tokenizers.
            text_encoders = [sd_pipe.text_encoder, sd_pipe.text_encoder_2]
            tokenizers = [sd_pipe.tokenizer, sd_pipe.tokenizer_2]
            embedding_path = hf_hub_download(
                repo_id=repo_name,
                filename=entry["text_embedding_weights"],
                repo_type="model",
            )
            embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
            embhandler.load_embeddings(embedding_path)
    else:
        # NOTE(review): merge_incompatible_lora is not defined in this file; it
        # presumably modifies the global sd_pipe in place — confirm.
        merge_incompatible_lora(state_dicts[repo_name]["saved_name"], lora_scale)
        last_merged = True
        print("DONE MERGING")

    last_lora = repo_name
    last_lora_scale = lora_scale


## SEGA ##


def shuffle_lora(sdxl_loras, selected_lora=None):
    """Choose the LoRA for the next edit and build its description.

    Uses ``selected_lora`` (a gallery pick) when given, otherwise a random index.

    Returns:
        A 7-tuple: output-image update, LoRA index, LoRA preview image, Markdown
        description, visibility update for the LoRA row, output-image height
        update, and the chosen spooky concept.
    """
    print("selected_lora in shuffle_lora", selected_lora)
    if selected_lora is not None:
        random_lora_index = selected_lora
    else:
        random_lora_index = random.randrange(0, len(sdxl_loras), 1)
    print("random_lora_index in shuffle_lora: ", random_lora_index)
    spooky_concept = random.choice(SPOOKY_CONCEPTS)
    lora_repo = sdxl_loras[random_lora_index]["repo"]
    lora_title = sdxl_loras[random_lora_index]["title"]
    # Markdown headings only render at the start of a line, hence the explicit "\n".
    lora_desc = (
        "#### LoRA used to edit this image:\n"
        f"prompt: {spooky_concept}\n"
        f"### {lora_title} by `{lora_repo.split('/')[0]}`\n"
    )
    lora_image = sdxl_loras[random_lora_index]["image"]
    return (
        gr.update(),
        random_lora_index,
        lora_image,
        lora_desc,
        gr.update(visible=True),
        gr.update(height=369),
        spooky_concept,
    )


def check_if_removed(input_image):
    """Reset the output area when the input photo has been cleared.

    In both cases the picked-LoRA state is cleared to ``None``. The previous code
    returned a ``gr.State`` component in one branch, which a State output may store
    as its raw value instead of ``None``.
    """
    if input_image is None:
        return (
            gr.Row(visible=False),
            gr.Column(elem_classes="output_column"),
            gr.Image(value=None),
            None,
            gr.Column(visible=False),
        )
    return gr.Row(), gr.Column(), gr.Image(), None, gr.Column()


def block_if_removed(input_image):
    """Stop the event chain (the ``.success()`` steps) when there is no input photo.

    ``gr.Warning`` is a function that shows a toast, not an exception class.
    ``raise gr.Warning(...)`` therefore failed with a TypeError. Show the warning,
    then raise an explicit ``gr.Error``.
    """
    if input_image is None:
        message = "Photo removed. Upload a new one!"
        gr.Warning(message)
        raise gr.Error(message)


def select_lora(selected_state: gr.SelectData, sdxl_loras):
    """Handle a gallery click: pick that LoRA and a random spooky concept.

    Returns:
        The LoRA index, its preview image, a Markdown description, and the concept.
    """
    random_lora_index = selected_state.index
    spooky_concept = random.choice(SPOOKY_CONCEPTS)
    lora_repo = sdxl_loras[random_lora_index]["repo"]
    lora_title = sdxl_loras[random_lora_index]["title"]
    # Markdown headings only render at the start of a line, hence the explicit "\n".
    lora_desc = (
        "#### LoRA used to edit this image:\n"
        f"## {lora_title} by `{lora_repo.split('/')[0]}`\n"
        f"### halloween concept: {spooky_concept}\n"
    )
    lora_image = sdxl_loras[random_lora_index]["image"]
    return random_lora_index, lora_image, lora_desc, spooky_concept


def edit(sdxl_loras, input_image, wts, zs, do_inversion, random_lora_index, spooky_concept, progress=gr.Progress(track_tqdm=True)):
    """Edit ``input_image`` toward the chosen LoRA's trigger word plus ``spooky_concept``.

    When ``do_inversion`` is true, the image is VAE-encoded and inverted with
    ``sd_pipe.invert``. Otherwise the ``wts``/``zs`` from the previous inversion
    are reused, which is how "Regenerate with the same picture" works.

    Returns:
        The edited image, ``wts``, ``zs``, ``do_inversion`` (False after an
        inversion), and three layout updates.
    """
    print("random_lora_index in edit: ", random_lora_index)
    load_lora(sdxl_loras, random_lora_index)

    src_prompt = ""
    skip = 18
    steps = 50
    tar_cfg_scale = 15
    src_cfg_scale = 3.5
    tar_prompt = ""

    print("Is do_inversion?", do_inversion)
    if do_inversion:
        image = load_image(input_image, device=device).to(torch.float16)
        with inference_mode():
            x0 = sd_pipe.vae.encode(image).latent_dist.sample() * sd_pipe.vae.config.scaling_factor
        # Invert and retrieve the noise maps and latents.
        zs_tensor, wts_tensor = sd_pipe.invert(
            x0,
            source_prompt=src_prompt,
            source_guidance_scale=src_cfg_scale,
            negative_prompt="blurry, bad quality",
            num_inversion_steps=steps,
            skip_steps=skip,
        )
        wts = wts_tensor
        zs = zs_tensor
        do_inversion = False

    latents = wts[skip].expand(1, -1, -1, -1)
    print("spooky concept is: ", spooky_concept)
    entry = sdxl_loras[random_lora_index]
    editing_args = dict(
        editing_prompt=[entry["trigger_word"] + spooky_concept],
        reverse_editing_direction=[False],
        edit_warmup_steps=[2],
        edit_guidance_scale=[entry["edit_guidance_scale"]],
        edit_threshold=[entry["threshold"]],
        edit_momentum_scale=0.3,
        edit_mom_beta=0.6,
        eta=1,
    )
    # Reseed from a non-deterministic source so each regeneration differs.
    # torch.seed() already seeds every device; the old manual_seed(torch.seed())
    # wrapper was redundant.
    torch.seed()
    sega_out = sd_pipe(
        prompt=tar_prompt,
        latents=latents,
        guidance_scale=tar_cfg_scale,
        wts=wts,
        zs=zs[skip:],
        **editing_args,
    )
    return (
        sega_out.images[0],
        wts,
        zs,
        do_inversion,
        gr.update(height=512),
        gr.Column(elem_classes="output_column_reverse"),
        gr.Row(visible=True),
    )


def randomize_seed_fn(seed, randomize_seed):
    """Return ``seed`` (replaced by a random int32 if requested) after seeding torch with it."""
    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    torch.manual_seed(seed)
    return seed


def randomize():
    """Seed torch (CPU and CUDA), ``random`` and NumPy with one fresh random int32."""
    seed = random.randint(0, np.iinfo(np.int32).max)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def crop_image(image):
    """Center-crop an HxW[xC] image array to a square and resize it to 1024x1024."""
    h, w = image.shape[:2]  # also accepts 2-D (grayscale) arrays
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((1024, 1024)))
    return image


def save_preferences(sdxl_loras, selected_lora, input_image, result_image):
    """Save the input/result images as PNGs and append a JSON line describing them.

    The images are assumed to be arrays accepted by ``Image.fromarray``.
    """
    lora_id = sdxl_loras[selected_lora]["repo"]
    uuid = uuid4()
    input_image_path = IMAGE_DATASET_DIR / f"{uuid}-input.png"
    output_image_path = IMAGE_DATASET_DIR / f"{uuid}-output.png"
    Image.fromarray(input_image).save(input_image_path)
    Image.fromarray(result_image).save(output_image_path)
    with IMAGE_JSONL_PATH.open("a", encoding="utf-8") as f:
        json.dump({"selected_lora": lora_id, "input_image": input_image_path.name, "result_image": output_image_path.name}, f)
        f.write("\n")


########
# demo #
########

# NOTE(review): the nesting of the layout below was reconstructed from a copy
# whose indentation had been lost. Confirm it against the deployed app.
with gr.Blocks(css="style.css") as demo:

    def reset_do_inversion():
        """A new photo needs a fresh inversion."""
        return True

    # NOTE(review): two identical header blocks — confirm both are intended.
    gr.HTML("""LEDITS SDXL LoRA Photobooth""")
    gr.HTML("""LEDITS SDXL LoRA Photobooth""")
    with gr.Box(elem_id="total_box"):
        gr.HTML(
            "\n\nSmile, take a pic 📷✨ and it'll be inverted on SDXL and a random SDXL LoRA will be applied\n\n",
        )
        wts = gr.State()
        zs = gr.State()
        reconstruction = gr.State()
        do_inversion = gr.State(value=True)
        gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
        gr_lora_index = gr.State()
        gr_picked_lora = gr.State()
        spooky_concept = gr.State()
        with gr.Row():
            input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512, elem_id="input_image")
            with gr.Column(elem_classes="output_column") as output_column:
                with gr.Row(visible=False) as loaded_lora:
                    lora_image = gr.Image(interactive=False, height=128, width=128, elem_id="lora_image", show_label=False, show_download_button=False)
                    lora_desc = gr.Markdown()
                sega_edited_image = gr.Image(label="LEDITS Edited Image", interactive=False, elem_id="output_image", height=512, width=512)
        with gr.Column(visible=False) as buttons_area:
            with gr.Row(elem_id="buttons_area"):
                print_button = gr.HTML('')
                run_button = gr.Button("Regenerate with the same picture 🖼️🎲", elem_id="run_again")
            with gr.Accordion("Tired of randomizing? Pick your LoRA", open=False, elem_id="pick"):
                choose_gallery = gr.Gallery(
                    value=[(item["image"], item["title"]) for item in sdxl_loras_raw],
                    allow_preview=False,
                    columns=6,
                    elem_id="gallery",
                    show_share_button=False,
                )

    choose_gallery.select(
        fn=select_lora,
        inputs=gr_sdxl_loras,
        outputs=[gr_picked_lora, lora_image, lora_desc, spooky_concept],
        queue=False,
    )

    run_button.click(
        fn=shuffle_lora,
        inputs=[gr_sdxl_loras, gr_picked_lora],
        outputs=[sega_edited_image, gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image, spooky_concept],
        queue=False,
    ).then(
        fn=edit,
        inputs=[gr_sdxl_loras, input_image, wts, zs, do_inversion, gr_lora_index, spooky_concept],
        outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column, buttons_area],
    ).then(
        fn=save_preferences,
        inputs=[gr_sdxl_loras, gr_lora_index, input_image, sega_edited_image],
        queue=False,
    )

    input_image.change(
        fn=check_if_removed,
        inputs=[input_image],
        outputs=[loaded_lora, output_column, sega_edited_image, gr_picked_lora, buttons_area],
        queue=False,
        show_progress=False,
    ).then(
        fn=block_if_removed,
        inputs=[input_image],
        queue=False,
        show_progress=False,
    ).success(
        fn=reset_do_inversion,
        outputs=[do_inversion],
        queue=False,
    ).then(
        fn=shuffle_lora,
        inputs=[gr_sdxl_loras, gr_picked_lora],
        outputs=[sega_edited_image, gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image, spooky_concept],
        queue=False,
    ).then(
        fn=edit,
        inputs=[gr_sdxl_loras, input_image, wts, zs, do_inversion, gr_lora_index, spooky_concept],
        outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column, buttons_area],
    ).then(
        fn=save_preferences,
        inputs=[gr_sdxl_loras, gr_lora_index, input_image, sega_edited_image],
        queue=False,
    )

    # Force the dark theme by reloading with ?__theme=dark.
    demo.load(None, _js="""async () => {
        let gradioURL = new URL(window.self.location.href);
        let params = new URLSearchParams(gradioURL.search);
        if (!params.has('__theme') || params.get('__theme') !== 'dark') {
            params.set('__theme', 'dark');
            gradioURL.search = params.toString();
            window.self.location.replace(gradioURL.toString());
        }
    }""")

demo.queue()
demo.launch()