Spaces:
Runtime error
Runtime error
Commit
·
c537c3e
1
Parent(s):
fb3dc43
Update app.py
Browse files
app.py
CHANGED
@@ -10,21 +10,21 @@ from pipeline_semantic_stable_diffusion_xl_img2img_ddpm import *
|
|
10 |
from torch import inference_mode
|
11 |
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL
|
12 |
from diffusers import DDIMScheduler
|
13 |
-
|
14 |
import torch
|
15 |
from huggingface_hub import hf_hub_download
|
16 |
from diffusers import DiffusionPipeline
|
17 |
-
from
|
18 |
import json
|
19 |
from safetensors.torch import load_file
|
20 |
-
|
21 |
import copy
|
22 |
import json
|
23 |
import gc
|
24 |
import random
|
25 |
from time import sleep
|
26 |
|
27 |
-
with open("sdxl_loras.json", "r") as file:
|
28 |
data = json.load(file)
|
29 |
sdxl_loras_raw = [
|
30 |
{
|
@@ -79,10 +79,9 @@ last_fused = False
|
|
79 |
|
80 |
def load_lora(sdxl_loras, random_lora_index, lora_scale = 1.0, progress=gr.Progress(track_tqdm=True)):
|
81 |
global last_lora, last_merged, last_fused, sd_pipe
|
82 |
-
|
83 |
-
randomize()
|
84 |
#random_lora_index = random.randrange(0, len(sdxl_loras), 1)
|
85 |
-
|
|
|
86 |
repo_name = sdxl_loras[random_lora_index]["repo"]
|
87 |
weight_name = sdxl_loras[random_lora_index]["weights"]
|
88 |
|
@@ -130,9 +129,13 @@ def load_lora(sdxl_loras, random_lora_index, lora_scale = 1.0, progress=gr.Progr
|
|
130 |
|
131 |
|
132 |
## SEGA ##
|
133 |
-
def shuffle_lora(sdxl_loras):
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
136 |
lora_repo = sdxl_loras[random_lora_index]["repo"]
|
137 |
lora_title = sdxl_loras[random_lora_index]["title"]
|
138 |
lora_desc = f"""#### LoRA used to edit this image:
|
@@ -140,18 +143,22 @@ def shuffle_lora(sdxl_loras):
|
|
140 |
by `{lora_repo.split('/')[0]}`
|
141 |
"""
|
142 |
lora_image = sdxl_loras[random_lora_index]["image"]
|
143 |
-
return random_lora_index, lora_image, lora_desc, gr.update(visible=True), gr.update(height=369)
|
144 |
|
145 |
def check_if_removed(input_image):
|
146 |
if(input_image is None):
|
147 |
-
return gr.Row(visible=False), gr.Column(elem_classes="output_column"), gr.Image(value=None)
|
148 |
else:
|
149 |
-
return gr.Row(), gr.Column(), gr.Image()
|
150 |
|
151 |
def block_if_removed(input_image):
|
152 |
if(input_image is None):
|
153 |
raise gr.Warning("Photo removed. Upload a new one!")
|
154 |
|
|
|
|
|
|
|
|
|
155 |
def edit(sdxl_loras,
|
156 |
input_image,
|
157 |
wts, zs,
|
@@ -160,7 +167,7 @@ def edit(sdxl_loras,
|
|
160 |
progress=gr.Progress(track_tqdm=True)
|
161 |
):
|
162 |
show_share_button = gr.update(visible=True)
|
163 |
-
|
164 |
load_lora(sdxl_loras, random_lora_index)
|
165 |
|
166 |
src_prompt = ""
|
@@ -169,7 +176,7 @@ def edit(sdxl_loras,
|
|
169 |
tar_cfg_scale = 15
|
170 |
src_cfg_scale = 3.5
|
171 |
tar_prompt = ""
|
172 |
-
|
173 |
if do_inversion:
|
174 |
image = load_image(input_image, device=device).to(torch.float16)
|
175 |
with inference_mode():
|
@@ -209,7 +216,7 @@ def edit(sdxl_loras,
|
|
209 |
edit_momentum_scale=0.3,
|
210 |
edit_mom_beta=0.6,
|
211 |
eta=1,)
|
212 |
-
|
213 |
sega_out = sd_pipe(prompt=tar_prompt, latents=latnets, guidance_scale = tar_cfg_scale,
|
214 |
# num_images_per_prompt=1,
|
215 |
# num_inference_steps=steps,
|
@@ -219,7 +226,7 @@ def edit(sdxl_loras,
|
|
219 |
#lora_desc = f"### LoRA Used To Edit this Image: {lora_repo}' }"
|
220 |
#lora_image = sdxl_loras[random_lora_index]["image"]
|
221 |
|
222 |
-
return sega_out.images[0], wts, zs, do_inversion, gr.update(height=512), gr.Column(elem_classes="output_column_reverse")
|
223 |
|
224 |
|
225 |
|
@@ -252,7 +259,7 @@ def crop_image(image):
|
|
252 |
|
253 |
|
254 |
|
255 |
-
########
|
256 |
# demo #
|
257 |
########
|
258 |
|
@@ -273,6 +280,7 @@ with gr.Blocks(css="style.css") as demo:
|
|
273 |
do_inversion = gr.State(value=True)
|
274 |
gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
|
275 |
gr_lora_index = gr.State()
|
|
|
276 |
with gr.Row():
|
277 |
input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512, elem_id="input_image")
|
278 |
with gr.Column(elem_classes="output_column") as output_column:
|
@@ -285,13 +293,26 @@ with gr.Blocks(css="style.css") as demo:
|
|
285 |
|
286 |
|
287 |
with gr.Row():
|
288 |
-
run_button = gr.Button("Rerun with the same picture",
|
289 |
-
|
290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
run_button.click(
|
292 |
fn=shuffle_lora,
|
293 |
-
inputs=[gr_sdxl_loras],
|
294 |
-
outputs=[gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
|
295 |
queue=False
|
296 |
).then(
|
297 |
fn=edit,
|
@@ -300,14 +321,13 @@ with gr.Blocks(css="style.css") as demo:
|
|
300 |
wts, zs,
|
301 |
do_inversion,
|
302 |
gr_lora_index
|
303 |
-
|
304 |
],
|
305 |
-
outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column])
|
306 |
|
307 |
input_image.change(
|
308 |
fn = check_if_removed,
|
309 |
inputs = [input_image],
|
310 |
-
outputs = [loaded_lora, output_column, sega_edited_image],
|
311 |
queue=False,
|
312 |
show_progress=False
|
313 |
).then(
|
@@ -320,8 +340,8 @@ with gr.Blocks(css="style.css") as demo:
|
|
320 |
outputs = [do_inversion],
|
321 |
queue = False).then(
|
322 |
fn=shuffle_lora,
|
323 |
-
inputs=[gr_sdxl_loras],
|
324 |
-
outputs=[gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
|
325 |
queue=False
|
326 |
).then(
|
327 |
fn=edit,
|
@@ -331,11 +351,12 @@ with gr.Blocks(css="style.css") as demo:
|
|
331 |
do_inversion,
|
332 |
gr_lora_index
|
333 |
],
|
334 |
-
outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column]
|
335 |
)
|
336 |
gr.HTML('''
|
337 |
-
<
|
338 |
-
|
|
|
339 |
|
340 |
|
341 |
demo.queue()
|
|
|
10 |
from torch import inference_mode
|
11 |
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL
|
12 |
from diffusers import DDIMScheduler
|
13 |
+
from share_btn import community_icon_html, loading_icon_html, share_js
|
14 |
import torch
|
15 |
from huggingface_hub import hf_hub_download
|
16 |
from diffusers import DiffusionPipeline
|
17 |
+
from cog_sdxl.dataset_and_utils import TokenEmbeddingsHandler
|
18 |
import json
|
19 |
from safetensors.torch import load_file
|
20 |
+
import lora
|
21 |
import copy
|
22 |
import json
|
23 |
import gc
|
24 |
import random
|
25 |
from time import sleep
|
26 |
|
27 |
+
with open("sdxl_loras-Copy1.json", "r") as file:
|
28 |
data = json.load(file)
|
29 |
sdxl_loras_raw = [
|
30 |
{
|
|
|
79 |
|
80 |
def load_lora(sdxl_loras, random_lora_index, lora_scale = 1.0, progress=gr.Progress(track_tqdm=True)):
|
81 |
global last_lora, last_merged, last_fused, sd_pipe
|
|
|
|
|
82 |
#random_lora_index = random.randrange(0, len(sdxl_loras), 1)
|
83 |
+
print(random_lora_index)
|
84 |
+
#print(sdxl_loras)
|
85 |
repo_name = sdxl_loras[random_lora_index]["repo"]
|
86 |
weight_name = sdxl_loras[random_lora_index]["weights"]
|
87 |
|
|
|
129 |
|
130 |
|
131 |
## SEGA ##
|
132 |
+
def shuffle_lora(sdxl_loras, selected_lora=None):
|
133 |
+
print("selected_lora in shuffle_lora", selected_lora)
|
134 |
+
if(selected_lora is not None):
|
135 |
+
random_lora_index = selected_lora
|
136 |
+
else:
|
137 |
+
random_lora_index = random.randrange(0, len(sdxl_loras), 1)
|
138 |
+
print("random_lora_index in shuffle_lora: ", random_lora_index)
|
139 |
lora_repo = sdxl_loras[random_lora_index]["repo"]
|
140 |
lora_title = sdxl_loras[random_lora_index]["title"]
|
141 |
lora_desc = f"""#### LoRA used to edit this image:
|
|
|
143 |
by `{lora_repo.split('/')[0]}`
|
144 |
"""
|
145 |
lora_image = sdxl_loras[random_lora_index]["image"]
|
146 |
+
return gr.update(), random_lora_index, lora_image, lora_desc, gr.update(visible=True), gr.update(height=369)
|
147 |
|
148 |
def check_if_removed(input_image):
|
149 |
if(input_image is None):
|
150 |
+
return gr.Row(visible=False), gr.Column(elem_classes="output_column"), gr.Image(value=None), gr.State(value=None), gr.Button(visible=False)
|
151 |
else:
|
152 |
+
return gr.Row(), gr.Column(), gr.Image(), None, gr.Button()
|
153 |
|
154 |
def block_if_removed(input_image):
|
155 |
if(input_image is None):
|
156 |
raise gr.Warning("Photo removed. Upload a new one!")
|
157 |
|
158 |
+
def select_lora(selected_state: gr.SelectData):
|
159 |
+
return selected_state.index
|
160 |
+
|
161 |
+
|
162 |
def edit(sdxl_loras,
|
163 |
input_image,
|
164 |
wts, zs,
|
|
|
167 |
progress=gr.Progress(track_tqdm=True)
|
168 |
):
|
169 |
show_share_button = gr.update(visible=True)
|
170 |
+
print("random_lora_index in edit: ", random_lora_index)
|
171 |
load_lora(sdxl_loras, random_lora_index)
|
172 |
|
173 |
src_prompt = ""
|
|
|
176 |
tar_cfg_scale = 15
|
177 |
src_cfg_scale = 3.5
|
178 |
tar_prompt = ""
|
179 |
+
print("Is do_inversion?", do_inversion)
|
180 |
if do_inversion:
|
181 |
image = load_image(input_image, device=device).to(torch.float16)
|
182 |
with inference_mode():
|
|
|
216 |
edit_momentum_scale=0.3,
|
217 |
edit_mom_beta=0.6,
|
218 |
eta=1,)
|
219 |
+
torch.manual_seed(torch.seed())
|
220 |
sega_out = sd_pipe(prompt=tar_prompt, latents=latnets, guidance_scale = tar_cfg_scale,
|
221 |
# num_images_per_prompt=1,
|
222 |
# num_inference_steps=steps,
|
|
|
226 |
#lora_desc = f"### LoRA Used To Edit this Image: {lora_repo}' }"
|
227 |
#lora_image = sdxl_loras[random_lora_index]["image"]
|
228 |
|
229 |
+
return sega_out.images[0], wts, zs, do_inversion, gr.update(height=512), gr.Column(elem_classes="output_column_reverse"), gr.Button(visible=True)
|
230 |
|
231 |
|
232 |
|
|
|
259 |
|
260 |
|
261 |
|
262 |
+
########r
|
263 |
# demo #
|
264 |
########
|
265 |
|
|
|
280 |
do_inversion = gr.State(value=True)
|
281 |
gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
|
282 |
gr_lora_index = gr.State()
|
283 |
+
gr_picked_lora = gr.State()
|
284 |
with gr.Row():
|
285 |
input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512, elem_id="input_image")
|
286 |
with gr.Column(elem_classes="output_column") as output_column:
|
|
|
293 |
|
294 |
|
295 |
with gr.Row():
|
296 |
+
run_button = gr.Button("Rerun with the same picture", elem_id="run_again", visible=False)
|
297 |
+
with gr.Accordion("Tired of randomizing? Pick your LoRA", open=False, elem_id="pick"):
|
298 |
+
choose_gallery = gr.Gallery(
|
299 |
+
value=[(item["image"], item["title"]) for item in sdxl_loras_raw],
|
300 |
+
allow_preview=False,
|
301 |
+
columns=6,
|
302 |
+
elem_id="gallery",
|
303 |
+
show_share_button=False
|
304 |
+
)
|
305 |
+
|
306 |
+
choose_gallery.select(
|
307 |
+
fn=select_lora,
|
308 |
+
outputs=[gr_picked_lora],
|
309 |
+
queue=False
|
310 |
+
)
|
311 |
+
|
312 |
run_button.click(
|
313 |
fn=shuffle_lora,
|
314 |
+
inputs=[gr_sdxl_loras, gr_picked_lora],
|
315 |
+
outputs=[sega_edited_image, gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
|
316 |
queue=False
|
317 |
).then(
|
318 |
fn=edit,
|
|
|
321 |
wts, zs,
|
322 |
do_inversion,
|
323 |
gr_lora_index
|
|
|
324 |
],
|
325 |
+
outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column, run_button])
|
326 |
|
327 |
input_image.change(
|
328 |
fn = check_if_removed,
|
329 |
inputs = [input_image],
|
330 |
+
outputs = [loaded_lora, output_column, sega_edited_image, gr_picked_lora, run_button],
|
331 |
queue=False,
|
332 |
show_progress=False
|
333 |
).then(
|
|
|
340 |
outputs = [do_inversion],
|
341 |
queue = False).then(
|
342 |
fn=shuffle_lora,
|
343 |
+
inputs=[gr_sdxl_loras, gr_picked_lora],
|
344 |
+
outputs=[sega_edited_image, gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
|
345 |
queue=False
|
346 |
).then(
|
347 |
fn=edit,
|
|
|
351 |
do_inversion,
|
352 |
gr_lora_index
|
353 |
],
|
354 |
+
outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column, run_button]
|
355 |
)
|
356 |
gr.HTML('''
|
357 |
+
<h1>Hugging Face at</h1>
|
358 |
+
<img src="https://iccv2023.thecvf.com/img/LogoICCV23V04.svg" width="400" style="margin: 0 auto;" />
|
359 |
+
''', elem_id="iccv_logo")
|
360 |
|
361 |
|
362 |
demo.queue()
|