Linoy Tsaban committed on
Commit
23c0e05
·
1 Parent(s): ec70079

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -71
app.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
  import requests
5
  import random
6
  from io import BytesIO
 
 
7
  from pipeline_semantic_stable_diffusion_xl_img2img_ddpm import *
8
  from torch import inference_mode
9
  from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL
@@ -20,8 +22,9 @@ import copy
20
  import json
21
  import gc
22
  import random
 
23
 
24
- with open("sdxl_loras.json", "r") as file:
25
  data = json.load(file)
26
  sdxl_loras_raw = [
27
  {
@@ -74,12 +77,12 @@ last_merged = False
74
  last_fused = False
75
 
76
 
77
- def load_lora(sdxl_loras, lora_scale = 1.0, progress=gr.Progress(track_tqdm=True)):
78
  global last_lora, last_merged, last_fused, sd_pipe
79
 
80
  randomize()
81
- random_lora_index = random.randrange(0, len(sdxl_loras), 1)
82
-
83
  repo_name = sdxl_loras[random_lora_index]["repo"]
84
  weight_name = sdxl_loras[random_lora_index]["weights"]
85
 
@@ -121,22 +124,44 @@ def load_lora(sdxl_loras, lora_scale = 1.0, progress=gr.Progress(track_tqdm=True
121
  merge_incompatible_lora(full_path_lora, lora_scale)
122
  last_fused = False
123
  last_merged = True
124
- print("DONE")
125
- return random_lora_index
126
 
127
 
128
 
129
  ## SEGA ##
 
 
 
 
 
 
 
 
 
 
 
130
 
 
 
 
 
 
 
 
 
 
 
131
  def edit(sdxl_loras,
132
  input_image,
133
  wts, zs,
134
  do_inversion,
135
-
 
136
  ):
137
  show_share_button = gr.update(visible=True)
138
 
139
- random_lora_index = load_lora(sdxl_loras)
140
 
141
  src_prompt = ""
142
  skip = 18
@@ -161,12 +186,12 @@ def edit(sdxl_loras,
161
  # eta = 1.0,
162
  )
163
 
164
- wts = gr.State(value=wts_tensor)
165
- zs = gr.State(value=zs_tensor)
166
  do_inversion = False
167
 
168
 
169
- latnets = wts.value[skip].expand(1, -1, -1, -1)
170
 
171
  editing_prompt = [sdxl_loras[random_lora_index]["trigger_word"]]
172
  reverse_editing_direction = [False]
@@ -188,13 +213,13 @@ def edit(sdxl_loras,
188
  sega_out = sd_pipe(prompt=tar_prompt, latents=latnets, guidance_scale = tar_cfg_scale,
189
  # num_images_per_prompt=1,
190
  # num_inference_steps=steps,
191
- wts=wts.value, zs=zs.value[skip:], **editing_args)
192
 
193
- lora_repo = sdxl_loras[random_lora_index]["repo"]
194
- lora_desc = f"### LoRA Used To Edit this Image: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {'(non-commercial LoRA, `cc-by-nc`)' if sdxl_loras[random_lora_index]['is_nc'] else '' }"
195
- lora_image = sdxl_loras[random_lora_index]["image"]
196
 
197
- return sega_out.images[0], wts, zs, do_inversion, lora_image, lora_desc, gr.update(visible=True), gr.update(visible=True)
198
 
199
 
200
 
@@ -232,64 +257,86 @@ def crop_image(image):
232
  ########
233
 
234
  with gr.Blocks(css="style.css") as demo:
235
-
236
-
237
  def reset_do_inversion():
238
  return True
239
 
240
-
241
- gr.HTML(
242
- """<h1><img src="https://i.imgur.com/jpMRW5y.png" alt="LEDITS LoRA Photobooth"></h1>""",
243
- )
244
- wts = gr.State()
245
- zs = gr.State()
246
- reconstruction = gr.State()
247
- do_inversion = gr.State(value=True)
248
- gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
249
-
250
- with gr.Row():
251
- input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512)
252
- sega_edited_image = gr.Image(label=f"LEDITS Edited Image", interactive=False, elem_id="output_image", height=512, width=512)
253
- # input_image.style(height=365, width=365)
254
- # sega_edited_image.style(height=365, width=365)
255
-
256
- with gr.Row():
257
- lora_image = gr.Image(interactive=False, height=128, width=128, visible=False)
258
- lora_desc = gr.HTML(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
-
261
- with gr.Row():
262
- run_button = gr.Button("Run again!", visible=True)
263
-
264
-
265
- run_button.click(
266
- fn=edit,
267
- inputs=[gr_sdxl_loras,
268
- input_image,
269
- wts, zs,
270
- do_inversion,
271
-
272
- ],
273
- outputs=[sega_edited_image, wts, zs, do_inversion, lora_image, lora_desc, lora_image, lora_desc])
274
 
275
- input_image.change(
276
- fn = reset_do_inversion,
277
- outputs = [do_inversion],
278
- queue = False).then(
279
- fn=edit,
280
- inputs=[gr_sdxl_loras,
281
- input_image,
282
- wts, zs,
283
- do_inversion,
284
-
285
-
286
- ],
287
- outputs=[sega_edited_image, wts, zs, do_inversion, lora_image, lora_desc, lora_image, lora_desc]
288
- )
289
-
290
-
291
-
292
-
293
-
294
  demo.queue()
295
- demo.launch(share=True)
 
4
  import requests
5
  import random
6
  from io import BytesIO
7
+ # from utils import *
8
+ # from constants import *
9
  from pipeline_semantic_stable_diffusion_xl_img2img_ddpm import *
10
  from torch import inference_mode
11
  from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL
 
22
  import json
23
  import gc
24
  import random
25
+ from time import sleep
26
 
27
+ with open("sdxl_loras-Copy1.json", "r") as file:
28
  data = json.load(file)
29
  sdxl_loras_raw = [
30
  {
 
77
  last_fused = False
78
 
79
 
80
+ def load_lora(sdxl_loras, random_lora_index, lora_scale = 1.0, progress=gr.Progress(track_tqdm=True)):
81
  global last_lora, last_merged, last_fused, sd_pipe
82
 
83
  randomize()
84
+ #random_lora_index = random.randrange(0, len(sdxl_loras), 1)
85
+
86
  repo_name = sdxl_loras[random_lora_index]["repo"]
87
  weight_name = sdxl_loras[random_lora_index]["weights"]
88
 
 
124
  merge_incompatible_lora(full_path_lora, lora_scale)
125
  last_fused = False
126
  last_merged = True
127
+ print("DONE MERGING")
128
+ #return random_lora_index
129
 
130
 
131
 
132
  ## SEGA ##
133
+ def shuffle_lora(sdxl_loras):
134
+ #random_lora_index = load_lora(sdxl_loras)
135
+ random_lora_index = random.randrange(0, len(sdxl_loras), 1)
136
+ lora_repo = sdxl_loras[random_lora_index]["repo"]
137
+ lora_title = sdxl_loras[random_lora_index]["title"]
138
+ lora_desc = f"""#### LoRA used to edit this image:
139
+ ## {lora_title}
140
+ by `{lora_repo.split('/')[0]}`
141
+ """
142
+ lora_image = sdxl_loras[random_lora_index]["image"]
143
+ return random_lora_index, lora_image, lora_desc, gr.update(visible=True), gr.update(height=369)
144
 
145
+ def check_if_removed(input_image):
146
+ if(input_image is None):
147
+ return gr.Row(visible=False), gr.Column(elem_classes="output_column"), gr.Image(value=None)
148
+ else:
149
+ return gr.Row(), gr.Column(), gr.Image()
150
+
151
+ def block_if_removed(input_image):
152
+ if(input_image is None):
153
+ raise gr.Warning("Photo removed. Upload a new one!")
154
+
155
  def edit(sdxl_loras,
156
  input_image,
157
  wts, zs,
158
  do_inversion,
159
+ random_lora_index,
160
+ progress=gr.Progress(track_tqdm=True)
161
  ):
162
  show_share_button = gr.update(visible=True)
163
 
164
+ load_lora(sdxl_loras, random_lora_index)
165
 
166
  src_prompt = ""
167
  skip = 18
 
186
  # eta = 1.0,
187
  )
188
 
189
+ wts = wts_tensor
190
+ zs = zs_tensor
191
  do_inversion = False
192
 
193
 
194
+ latnets = wts[skip].expand(1, -1, -1, -1)
195
 
196
  editing_prompt = [sdxl_loras[random_lora_index]["trigger_word"]]
197
  reverse_editing_direction = [False]
 
213
  sega_out = sd_pipe(prompt=tar_prompt, latents=latnets, guidance_scale = tar_cfg_scale,
214
  # num_images_per_prompt=1,
215
  # num_inference_steps=steps,
216
+ wts=wts, zs=zs[skip:], **editing_args)
217
 
218
+ #lora_repo = sdxl_loras[random_lora_index]["repo"]
219
+ #lora_desc = f"### LoRA Used To Edit this Image: {lora_repo}' }"
220
+ #lora_image = sdxl_loras[random_lora_index]["image"]
221
 
222
+ return sega_out.images[0], wts, zs, do_inversion, gr.update(height=512), gr.Column(elem_classes="output_column_reverse")
223
 
224
 
225
 
 
257
  ########
258
 
259
  with gr.Blocks(css="style.css") as demo:
 
 
260
  def reset_do_inversion():
261
  return True
262
 
263
+ gr.HTML("""<img style="margin: 0 auto; width: 180px; margin-bottom: .5em" src="https://i.imgur.com/A4BP6Lx.png" alt="LEDITS SDXL LoRA Photobooth">""")
264
+ with gr.Box(elem_id="total_box"):
265
+ gr.HTML(
266
+ """<h1>LEDITS SDXL LoRA Photobooth</h1>
267
+ <h3>Smile, take a pic 📷✨ and <code>it'll be inverted on SDXL and a random SDXL LoRA will be applied</code></h3>
268
+ """,
269
+ )
270
+ wts = gr.State()
271
+ zs = gr.State()
272
+ reconstruction = gr.State()
273
+ do_inversion = gr.State(value=True)
274
+ gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
275
+ gr_lora_index = gr.State()
276
+ with gr.Row():
277
+ input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512, elem_id="input_image")
278
+ with gr.Column(elem_classes="output_column") as output_column:
279
+ with gr.Row(visible=False) as loaded_lora:
280
+ lora_image = gr.Image(interactive=False, height=128, width=128, elem_id="lora_image", show_label=False, show_download_button=False)
281
+ lora_desc = gr.Markdown()
282
+ sega_edited_image = gr.Image(label=f"LEDITS Edited Image", interactive=False, elem_id="output_image", height=512, width=512)
283
+
284
+
285
+
286
+
287
+ with gr.Row():
288
+ run_button = gr.Button("Rerun with the same picture", visible=True, elem_id="run_again")
289
+
290
+
291
+ run_button.click(
292
+ fn=shuffle_lora,
293
+ inputs=[gr_sdxl_loras],
294
+ outputs=[gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
295
+ queue=False
296
+ ).then(
297
+ fn=edit,
298
+ inputs=[gr_sdxl_loras,
299
+ input_image,
300
+ wts, zs,
301
+ do_inversion,
302
+ gr_lora_index
303
+
304
+ ],
305
+ outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column])
306
+
307
+ input_image.change(
308
+ fn = check_if_removed,
309
+ inputs = [input_image],
310
+ outputs = [loaded_lora, output_column, sega_edited_image],
311
+ queue=False,
312
+ show_progress=False
313
+ ).then(
314
+ fn = block_if_removed,
315
+ inputs = [input_image],
316
+ queue=False,
317
+ show_progress=False
318
+ ).success(
319
+ fn = reset_do_inversion,
320
+ outputs = [do_inversion],
321
+ queue = False).then(
322
+ fn=shuffle_lora,
323
+ inputs=[gr_sdxl_loras],
324
+ outputs=[gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
325
+ queue=False
326
+ ).then(
327
+ fn=edit,
328
+ inputs=[gr_sdxl_loras,
329
+ input_image,
330
+ wts, zs,
331
+ do_inversion,
332
+ gr_lora_index
333
+ ],
334
+ outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column]
335
+ )
336
+ gr.HTML('''
337
+ <img src="https://iccv2023.thecvf.com/img/LogoICCV23V04.svg" width="400" style="margin: 0 auto; display: none" id='iccv_logo' />
338
+ ''')
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  demo.queue()
342
+ demo.launch()