multimodalart HF staff commited on
Commit
c537c3e
·
1 Parent(s): fb3dc43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -31
app.py CHANGED
@@ -10,21 +10,21 @@ from pipeline_semantic_stable_diffusion_xl_img2img_ddpm import *
10
  from torch import inference_mode
11
  from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL
12
  from diffusers import DDIMScheduler
13
- # from share_btn import community_icon_html, loading_icon_html, share_js
14
  import torch
15
  from huggingface_hub import hf_hub_download
16
  from diffusers import DiffusionPipeline
17
- from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
18
  import json
19
  from safetensors.torch import load_file
20
- # import lora
21
  import copy
22
  import json
23
  import gc
24
  import random
25
  from time import sleep
26
 
27
- with open("sdxl_loras.json", "r") as file:
28
  data = json.load(file)
29
  sdxl_loras_raw = [
30
  {
@@ -79,10 +79,9 @@ last_fused = False
79
 
80
  def load_lora(sdxl_loras, random_lora_index, lora_scale = 1.0, progress=gr.Progress(track_tqdm=True)):
81
  global last_lora, last_merged, last_fused, sd_pipe
82
-
83
- randomize()
84
  #random_lora_index = random.randrange(0, len(sdxl_loras), 1)
85
-
 
86
  repo_name = sdxl_loras[random_lora_index]["repo"]
87
  weight_name = sdxl_loras[random_lora_index]["weights"]
88
 
@@ -130,9 +129,13 @@ def load_lora(sdxl_loras, random_lora_index, lora_scale = 1.0, progress=gr.Progr
130
 
131
 
132
  ## SEGA ##
133
- def shuffle_lora(sdxl_loras):
134
- #random_lora_index = load_lora(sdxl_loras)
135
- random_lora_index = random.randrange(0, len(sdxl_loras), 1)
 
 
 
 
136
  lora_repo = sdxl_loras[random_lora_index]["repo"]
137
  lora_title = sdxl_loras[random_lora_index]["title"]
138
  lora_desc = f"""#### LoRA used to edit this image:
@@ -140,18 +143,22 @@ def shuffle_lora(sdxl_loras):
140
  by `{lora_repo.split('/')[0]}`
141
  """
142
  lora_image = sdxl_loras[random_lora_index]["image"]
143
- return random_lora_index, lora_image, lora_desc, gr.update(visible=True), gr.update(height=369)
144
 
145
  def check_if_removed(input_image):
146
  if(input_image is None):
147
- return gr.Row(visible=False), gr.Column(elem_classes="output_column"), gr.Image(value=None)
148
  else:
149
- return gr.Row(), gr.Column(), gr.Image()
150
 
151
  def block_if_removed(input_image):
152
  if(input_image is None):
153
  raise gr.Warning("Photo removed. Upload a new one!")
154
 
 
 
 
 
155
  def edit(sdxl_loras,
156
  input_image,
157
  wts, zs,
@@ -160,7 +167,7 @@ def edit(sdxl_loras,
160
  progress=gr.Progress(track_tqdm=True)
161
  ):
162
  show_share_button = gr.update(visible=True)
163
-
164
  load_lora(sdxl_loras, random_lora_index)
165
 
166
  src_prompt = ""
@@ -169,7 +176,7 @@ def edit(sdxl_loras,
169
  tar_cfg_scale = 15
170
  src_cfg_scale = 3.5
171
  tar_prompt = ""
172
-
173
  if do_inversion:
174
  image = load_image(input_image, device=device).to(torch.float16)
175
  with inference_mode():
@@ -209,7 +216,7 @@ def edit(sdxl_loras,
209
  edit_momentum_scale=0.3,
210
  edit_mom_beta=0.6,
211
  eta=1,)
212
-
213
  sega_out = sd_pipe(prompt=tar_prompt, latents=latnets, guidance_scale = tar_cfg_scale,
214
  # num_images_per_prompt=1,
215
  # num_inference_steps=steps,
@@ -219,7 +226,7 @@ def edit(sdxl_loras,
219
  #lora_desc = f"### LoRA Used To Edit this Image: {lora_repo}' }"
220
  #lora_image = sdxl_loras[random_lora_index]["image"]
221
 
222
- return sega_out.images[0], wts, zs, do_inversion, gr.update(height=512), gr.Column(elem_classes="output_column_reverse")
223
 
224
 
225
 
@@ -252,7 +259,7 @@ def crop_image(image):
252
 
253
 
254
 
255
- ########
256
  # demo #
257
  ########
258
 
@@ -273,6 +280,7 @@ with gr.Blocks(css="style.css") as demo:
273
  do_inversion = gr.State(value=True)
274
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
275
  gr_lora_index = gr.State()
 
276
  with gr.Row():
277
  input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512, elem_id="input_image")
278
  with gr.Column(elem_classes="output_column") as output_column:
@@ -285,13 +293,26 @@ with gr.Blocks(css="style.css") as demo:
285
 
286
 
287
  with gr.Row():
288
- run_button = gr.Button("Rerun with the same picture", visible=True, elem_id="run_again")
289
-
290
-
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  run_button.click(
292
  fn=shuffle_lora,
293
- inputs=[gr_sdxl_loras],
294
- outputs=[gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
295
  queue=False
296
  ).then(
297
  fn=edit,
@@ -300,14 +321,13 @@ with gr.Blocks(css="style.css") as demo:
300
  wts, zs,
301
  do_inversion,
302
  gr_lora_index
303
-
304
  ],
305
- outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column])
306
 
307
  input_image.change(
308
  fn = check_if_removed,
309
  inputs = [input_image],
310
- outputs = [loaded_lora, output_column, sega_edited_image],
311
  queue=False,
312
  show_progress=False
313
  ).then(
@@ -320,8 +340,8 @@ with gr.Blocks(css="style.css") as demo:
320
  outputs = [do_inversion],
321
  queue = False).then(
322
  fn=shuffle_lora,
323
- inputs=[gr_sdxl_loras],
324
- outputs=[gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
325
  queue=False
326
  ).then(
327
  fn=edit,
@@ -331,11 +351,12 @@ with gr.Blocks(css="style.css") as demo:
331
  do_inversion,
332
  gr_lora_index
333
  ],
334
- outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column]
335
  )
336
  gr.HTML('''
337
- <img src="https://iccv2023.thecvf.com/img/LogoICCV23V04.svg" width="400" style="margin: 0 auto; display: none" id='iccv_logo' />
338
- ''')
 
339
 
340
 
341
  demo.queue()
 
10
  from torch import inference_mode
11
  from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL
12
  from diffusers import DDIMScheduler
13
+ from share_btn import community_icon_html, loading_icon_html, share_js
14
  import torch
15
  from huggingface_hub import hf_hub_download
16
  from diffusers import DiffusionPipeline
17
+ from cog_sdxl.dataset_and_utils import TokenEmbeddingsHandler
18
  import json
19
  from safetensors.torch import load_file
20
+ import lora
21
  import copy
22
  import json
23
  import gc
24
  import random
25
  from time import sleep
26
 
27
+ with open("sdxl_loras-Copy1.json", "r") as file:
28
  data = json.load(file)
29
  sdxl_loras_raw = [
30
  {
 
79
 
80
  def load_lora(sdxl_loras, random_lora_index, lora_scale = 1.0, progress=gr.Progress(track_tqdm=True)):
81
  global last_lora, last_merged, last_fused, sd_pipe
 
 
82
  #random_lora_index = random.randrange(0, len(sdxl_loras), 1)
83
+ print(random_lora_index)
84
+ #print(sdxl_loras)
85
  repo_name = sdxl_loras[random_lora_index]["repo"]
86
  weight_name = sdxl_loras[random_lora_index]["weights"]
87
 
 
129
 
130
 
131
  ## SEGA ##
132
+ def shuffle_lora(sdxl_loras, selected_lora=None):
133
+ print("selected_lora in shuffle_lora", selected_lora)
134
+ if(selected_lora is not None):
135
+ random_lora_index = selected_lora
136
+ else:
137
+ random_lora_index = random.randrange(0, len(sdxl_loras), 1)
138
+ print("random_lora_index in shuffle_lora: ", random_lora_index)
139
  lora_repo = sdxl_loras[random_lora_index]["repo"]
140
  lora_title = sdxl_loras[random_lora_index]["title"]
141
  lora_desc = f"""#### LoRA used to edit this image:
 
143
  by `{lora_repo.split('/')[0]}`
144
  """
145
  lora_image = sdxl_loras[random_lora_index]["image"]
146
+ return gr.update(), random_lora_index, lora_image, lora_desc, gr.update(visible=True), gr.update(height=369)
147
 
148
  def check_if_removed(input_image):
149
  if(input_image is None):
150
+ return gr.Row(visible=False), gr.Column(elem_classes="output_column"), gr.Image(value=None), gr.State(value=None), gr.Button(visible=False)
151
  else:
152
+ return gr.Row(), gr.Column(), gr.Image(), None, gr.Button()
153
 
154
  def block_if_removed(input_image):
155
  if(input_image is None):
156
  raise gr.Warning("Photo removed. Upload a new one!")
157
 
158
+ def select_lora(selected_state: gr.SelectData):
159
+ return selected_state.index
160
+
161
+
162
  def edit(sdxl_loras,
163
  input_image,
164
  wts, zs,
 
167
  progress=gr.Progress(track_tqdm=True)
168
  ):
169
  show_share_button = gr.update(visible=True)
170
+ print("random_lora_index in edit: ", random_lora_index)
171
  load_lora(sdxl_loras, random_lora_index)
172
 
173
  src_prompt = ""
 
176
  tar_cfg_scale = 15
177
  src_cfg_scale = 3.5
178
  tar_prompt = ""
179
+ print("Is do_inversion?", do_inversion)
180
  if do_inversion:
181
  image = load_image(input_image, device=device).to(torch.float16)
182
  with inference_mode():
 
216
  edit_momentum_scale=0.3,
217
  edit_mom_beta=0.6,
218
  eta=1,)
219
+ torch.manual_seed(torch.seed())
220
  sega_out = sd_pipe(prompt=tar_prompt, latents=latnets, guidance_scale = tar_cfg_scale,
221
  # num_images_per_prompt=1,
222
  # num_inference_steps=steps,
 
226
  #lora_desc = f"### LoRA Used To Edit this Image: {lora_repo}' }"
227
  #lora_image = sdxl_loras[random_lora_index]["image"]
228
 
229
+ return sega_out.images[0], wts, zs, do_inversion, gr.update(height=512), gr.Column(elem_classes="output_column_reverse"), gr.Button(visible=True)
230
 
231
 
232
 
 
259
 
260
 
261
 
262
+ ########r
263
  # demo #
264
  ########
265
 
 
280
  do_inversion = gr.State(value=True)
281
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
282
  gr_lora_index = gr.State()
283
+ gr_picked_lora = gr.State()
284
  with gr.Row():
285
  input_image = gr.Image(label="Input Image", interactive=True, source="webcam", height=512, width=512, elem_id="input_image")
286
  with gr.Column(elem_classes="output_column") as output_column:
 
293
 
294
 
295
  with gr.Row():
296
+ run_button = gr.Button("Rerun with the same picture", elem_id="run_again", visible=False)
297
+ with gr.Accordion("Tired of randomizing? Pick your LoRA", open=False, elem_id="pick"):
298
+ choose_gallery = gr.Gallery(
299
+ value=[(item["image"], item["title"]) for item in sdxl_loras_raw],
300
+ allow_preview=False,
301
+ columns=6,
302
+ elem_id="gallery",
303
+ show_share_button=False
304
+ )
305
+
306
+ choose_gallery.select(
307
+ fn=select_lora,
308
+ outputs=[gr_picked_lora],
309
+ queue=False
310
+ )
311
+
312
  run_button.click(
313
  fn=shuffle_lora,
314
+ inputs=[gr_sdxl_loras, gr_picked_lora],
315
+ outputs=[sega_edited_image, gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
316
  queue=False
317
  ).then(
318
  fn=edit,
 
321
  wts, zs,
322
  do_inversion,
323
  gr_lora_index
 
324
  ],
325
+ outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column, run_button])
326
 
327
  input_image.change(
328
  fn = check_if_removed,
329
  inputs = [input_image],
330
+ outputs = [loaded_lora, output_column, sega_edited_image, gr_picked_lora, run_button],
331
  queue=False,
332
  show_progress=False
333
  ).then(
 
340
  outputs = [do_inversion],
341
  queue = False).then(
342
  fn=shuffle_lora,
343
+ inputs=[gr_sdxl_loras, gr_picked_lora],
344
+ outputs=[sega_edited_image, gr_lora_index, lora_image, lora_desc, loaded_lora, sega_edited_image],
345
  queue=False
346
  ).then(
347
  fn=edit,
 
351
  do_inversion,
352
  gr_lora_index
353
  ],
354
+ outputs=[sega_edited_image, wts, zs, do_inversion, sega_edited_image, output_column, run_button]
355
  )
356
  gr.HTML('''
357
+ <h1>Hugging Face at</h1>
358
+ <img src="https://iccv2023.thecvf.com/img/LogoICCV23V04.svg" width="400" style="margin: 0 auto;" />
359
+ ''', elem_id="iccv_logo")
360
 
361
 
362
  demo.queue()