import gradio as gr
import torch
import numpy as np
import requests
import random
import json
import copy
import gc
from io import BytesIO
from PIL import Image

from pipeline_semantic_stable_diffusion_xl_img2img_ddpm import *
from torch import inference_mode
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DiffusionPipeline, AutoencoderKL, DDIMScheduler
from share_btn import community_icon_html, loading_icon_html, share_js
from huggingface_hub import hf_hub_download
from cog_sdxl.dataset_and_utils import TokenEmbeddingsHandler
from safetensors.torch import load_file
import lora

# Load the LoRA catalog; each entry describes one LoRA and the SEGA editing
# parameters that work well with it.
with open("sdxl_loras.json", "r") as file:
    data = json.load(file)
    sdxl_loras_raw = [
        {
            "image": item["image"],
            "title": item["title"],
            "repo": item["repo"],
            "trigger_word": item["trigger_word"],
            "weights": item["weights"],
            "is_compatible": item["is_compatible"],
            "is_pivotal": item.get("is_pivotal", False),
            "text_embedding_weights": item.get("text_embedding_weights", None),
            # "likes": item.get("likes", 0),
            # "downloads": item.get("downloads", 0),
            "is_nc": item.get("is_nc", False),
            "edit_guidance_scale": item["edit_guidance_scale"],
            "threshold": item["threshold"],
        }
        for item in data
    ]

# Download every LoRA checkpoint up front and keep its state dict in memory,
# keyed by repo, so switching LoRAs at runtime needs no further downloads.
state_dicts = {}
for item in sdxl_loras_raw:
    saved_name = hf_hub_download(item["repo"], item["weights"])
    if not saved_name.endswith(".safetensors"):
        state_dict = torch.load(saved_name)
    else:
        state_dict = load_file(saved_name)
    state_dicts[item["repo"]] = {
        "saved_name": saved_name,
        "state_dict": state_dict,
    } | item

sd_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# fp16-safe VAE to avoid NaN artifacts when running the SDXL VAE in half precision.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
sd_pipe = SemanticStableDiffusionXLImg2ImgPipeline_DDPMInversion.from_pretrained(
    sd_model_id,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    vae=vae,
)
sd_pipe.scheduler = DDIMScheduler.from_pretrained(sd_model_id, subfolder="scheduler")
# Keep a pristine copy so the pipeline can be restored after destructive merges.
original_pipe = copy.deepcopy(sd_pipe)
sd_pipe.to(device)

last_lora = ""
last_merged = False
last_fused = False


def load_lora(sdxl_loras, lora_scale=1.0, progress=gr.Progress(track_tqdm=True)):
    """Pick a random LoRA from the catalog and attach it to the global pipeline."""
    global last_lora, last_merged, last_fused, sd_pipe
    randomize()
    random_lora_index = random.randrange(len(sdxl_loras))
    repo_name = sdxl_loras[random_lora_index]["repo"]
    full_path_lora = state_dicts[repo_name]["saved_name"]
    loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
    print(repo_name)
    if last_lora != repo_name:
        # Undo whatever the previous LoRA did to the pipeline.
        if last_merged:
            # A merged LoRA altered the weights in place; restore the clean copy.
            del sd_pipe
            gc.collect()
            sd_pipe = copy.deepcopy(original_pipe)
            sd_pipe.to(device)
        elif last_fused:
            sd_pipe.unfuse_lora()
            sd_pipe.unload_lora_weights()
        is_compatible = sdxl_loras[random_lora_index]["is_compatible"]
        if is_compatible:
            sd_pipe.load_lora_weights(loaded_state_dict)
            sd_pipe.fuse_lora(lora_scale)
            last_fused = True
        else:
            is_pivotal = sdxl_loras[random_lora_index]["is_pivotal"]
            if is_pivotal:
                sd_pipe.load_lora_weights(loaded_state_dict)
                sd_pipe.fuse_lora(lora_scale)
                last_fused = True
                # Add the textual inversion embeddings from pivotal tuning models.
                text_embedding_name = sdxl_loras[random_lora_index]["text_embedding_weights"]
                text_encoders = [sd_pipe.text_encoder, sd_pipe.text_encoder_2]
                tokenizers = [sd_pipe.tokenizer, sd_pipe.tokenizer_2]
                embedding_path = hf_hub_download(
                    repo_id=repo_name, filename=text_embedding_name, repo_type="model"
                )
                embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
                embhandler.load_embeddings(embedding_path)
            else:
                # Fallback for LoRAs diffusers cannot load directly;
                # `merge_incompatible_lora` must be provided elsewhere
                # (e.g. by the `lora` helper module imported above).
                merge_incompatible_lora(full_path_lora, lora_scale)
                last_fused = False
                last_merged = True
        # Remember which LoRA is attached so repeat calls can skip reloading.
        last_lora = repo_name
    print("DONE")
    return random_lora_index
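
# A minimal sketch of the fuse/unfuse cycle that `load_lora` manages, stripped
# of the demo's randomization and caching ("some/repo" and "weights.safetensors"
# are placeholders, not entries from sdxl_loras.json):
#
#   state_dict = load_file(hf_hub_download("some/repo", "weights.safetensors"))
#   sd_pipe.load_lora_weights(state_dict)
#   sd_pipe.fuse_lora(lora_scale=1.0)
#   ...  # generate/edit images with the LoRA baked into the weights
#   sd_pipe.unfuse_lora()          # undo before attaching the next LoRA
#   sd_pipe.unload_lora_weights()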


## SEGA ##

def edit(sdxl_loras, input_image, wts, zs, do_inversion):
    random_lora_index = load_lora(sdxl_loras)
    src_prompt = ""
    skip = 18
    steps = 50
    tar_cfg_scale = 15
    src_cfg_scale = 3.5
    tar_prompt = ""
    if do_inversion:
        # `load_image` comes from the custom pipeline module imported above.
        image = load_image(input_image, device=device).to(torch.float16)
        with inference_mode():
            x0 = sd_pipe.vae.encode(image).latent_dist.sample() * sd_pipe.vae.config.scaling_factor
        # Invert and retrieve the noise maps and latents.
        zs_tensor, wts_tensor = sd_pipe.invert(
            x0,
            source_prompt=src_prompt,
            # source_prompt_2=None,
            source_guidance_scale=src_cfg_scale,
            negative_prompt="blurry, ugly, bad quality",
            # negative_prompt_2=None,
            num_inversion_steps=steps,
            skip_steps=skip,
            # eta=1.0,
        )
        wts = gr.State(value=wts_tensor)
        zs = gr.State(value=zs_tensor)
        do_inversion = False

    latents = wts.value[skip].expand(1, -1, -1, -1)
    # Drive the edit with the LoRA's trigger word and its tuned SEGA parameters.
    editing_prompt = [sdxl_loras[random_lora_index]["trigger_word"]]
    reverse_editing_direction = [False]
    edit_warmup_steps = [2]
    edit_guidance_scale = [sdxl_loras[random_lora_index]["edit_guidance_scale"]]
    edit_threshold = [sdxl_loras[random_lora_index]["threshold"]]

    editing_args = dict(
        editing_prompt=editing_prompt,
        reverse_editing_direction=reverse_editing_direction,
        edit_warmup_steps=edit_warmup_steps,
        edit_guidance_scale=edit_guidance_scale,
        edit_threshold=edit_threshold,
        edit_momentum_scale=0.3,
        edit_mom_beta=0.6,
        eta=1,
    )

    sega_out = sd_pipe(
        prompt=tar_prompt,
        latents=latents,
        guidance_scale=tar_cfg_scale,
        # num_images_per_prompt=1,
        # num_inference_steps=steps,
        wts=wts.value,
        zs=zs.value[skip:],
        **editing_args,
    )
    lora_repo = sdxl_loras[random_lora_index]["repo"]
    lora_desc = f"### LoRA Used To Edit this Image: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {'(non-commercial LoRA, `cc-by-nc`)' if sdxl_loras[random_lora_index]['is_nc'] else ''}"
    lora_image = sdxl_loras[random_lora_index]["image"]
    return sega_out.images[0], wts, zs, do_inversion, lora_image, lora_desc, gr.update(visible=True), gr.update(visible=True)


def randomize_seed_fn(seed, randomize_seed):
    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    torch.manual_seed(seed)
    return seed


def randomize():
    seed = random.randint(0, np.iinfo(np.int32).max)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def crop_image(image):
    # Center-crop to a square, then resize to SDXL's native 1024x1024.
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((1024, 1024)))
    return image
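
# The inversion + SEGA edit that `edit` wires into Gradio, shown in isolation
# (a sketch only; "face.png" is a placeholder input and the keyword names
# follow the custom pipeline imported above):
#
#   image = load_image("face.png", device=device).to(torch.float16)
#   with inference_mode():
#       x0 = sd_pipe.vae.encode(image).latent_dist.sample() * sd_pipe.vae.config.scaling_factor
#   zs, wts = sd_pipe.invert(x0, source_prompt="", source_guidance_scale=3.5,
#                            num_inversion_steps=50, skip_steps=18)
#   out = sd_pipe(prompt="", latents=wts[18].expand(1, -1, -1, -1),
#                 guidance_scale=15, wts=wts, zs=zs[18:],
#                 editing_prompt=["pixel art"], edit_guidance_scale=[10.0],
#                 edit_threshold=[0.9])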
""", ) wts = gr.State() zs = gr.State() reconstruction = gr.State() do_inversion = gr.State(value=True) gr_sdxl_loras = gr.State(value=sdxl_loras_raw) with gr.Row(): input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512) sega_edited_image = gr.Image(label=f"LEDITS Edited Image", interactive=False, elem_id="output_image", height=512, width=512) # input_image.style(height=365, width=365) # sega_edited_image.style(height=365, width=365) with gr.Row(): lora_image = gr.Image(interactive=False, height=128, width=128, visible=False) lora_desc = gr.HTML(visible=False) with gr.Row(): run_button = gr.Button("Run again!", visible=True) run_button.click( fn=edit, inputs=[gr_sdxl_loras, input_image, wts, zs, do_inversion, ], outputs=[sega_edited_image, wts, zs, do_inversion, lora_image, lora_desc, lora_image, lora_desc]) input_image.change( fn = reset_do_inversion, outputs = [do_inversion], queue = False).then( fn=edit, inputs=[gr_sdxl_loras, input_image, wts, zs, do_inversion, ], outputs=[sega_edited_image, wts, zs, do_inversion, lora_image, lora_desc, lora_image, lora_desc] ) demo.queue() demo.launch(share=True)