Linoy Tsaban committed on
Commit
24b22ad
·
1 Parent(s): 4065064

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -35,9 +35,9 @@ def caption_image(input_image):
35
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
  return generated_caption, generated_caption
37
 
38
- def sample(zs, wts, attention_store, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
39
  latents = wts[-1].expand(1, -1, -1, -1)
40
- img, attention_store = pipe(
41
  prompt=prompt_tar,
42
  init_latents=latents,
43
  guidance_scale=cfg_scale_tar,
@@ -45,10 +45,10 @@ def sample(zs, wts, attention_store, prompt_tar="", cfg_scale_tar=15, skip=36, e
45
  # num_inference_steps=steps,
46
  # use_ddpm=True,
47
  # wts=wts.value,
48
- attention_store = attention_store,
49
  zs=zs,
50
  )
51
- return img.images[0], attention_store
52
 
53
 
54
  def reconstruct(
@@ -59,6 +59,7 @@ def reconstruct(
59
  wts,
60
  zs,
61
  attention_store,
 
62
  do_reconstruction,
63
  reconstruction,
64
  reconstruct_button,
@@ -79,8 +80,8 @@ def reconstruct(
79
  ): # if image caption was not changed, run actual reconstruction
80
  tar_prompt = ""
81
  latents = wts[-1].expand(1, -1, -1, -1)
82
- reconstruction, attention_store = sample(
83
- zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
84
  )
85
  do_reconstruction = False
86
  return (
@@ -130,7 +131,7 @@ def load_and_invert(
130
  ## SEGA ##
131
 
132
  def edit(input_image,
133
- wts, zs, attention_store,
134
  tar_prompt,
135
  image_caption,
136
  steps,
@@ -197,27 +198,27 @@ def edit(input_image,
197
  )
198
 
199
  latnets = wts[-1].expand(1, -1, -1, -1)
200
- sega_out, attention_store = pipe(prompt=tar_prompt,
201
  init_latents=latnets,
202
  guidance_scale = tar_cfg_scale,
203
  # num_images_per_prompt=1,
204
  # num_inference_steps=steps,
205
  # use_ddpm=True,
206
  # wts=wts.value,
207
- zs=zs, attention_store=attention_store, **editing_args)
208
 
209
- return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
210
 
211
 
212
  else: # if sega concepts were not added, performs regular ddpm sampling
213
 
214
  if do_reconstruction: # if ddpm sampling wasn't computed
215
- pure_ddpm_img, attention_store = sample(zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
216
  reconstruction = pure_ddpm_img
217
  do_reconstruction = False
218
- return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
219
 
220
- return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
221
 
222
 
223
  def randomize_seed_fn(seed, is_random):
@@ -461,6 +462,7 @@ with gr.Blocks(css="style.css") as demo:
461
  wts = gr.State()
462
  zs = gr.State()
463
  attention_store=gr.State()
 
464
  reconstruction = gr.State()
465
  do_inversion = gr.State(value=True)
466
  do_reconstruction = gr.State(value=True)
@@ -697,6 +699,7 @@ with gr.Blocks(css="style.css") as demo:
697
  fn=edit,
698
  inputs=[input_image,
699
  wts, zs, attention_store,
 
700
  tar_prompt,
701
  image_caption,
702
  steps,
@@ -716,7 +719,7 @@ with gr.Blocks(css="style.css") as demo:
716
 
717
 
718
  ],
719
- outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs,attention_store, do_inversion, share_btn_container])
720
  # .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
721
 
722
 
 
35
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
  return generated_caption, generated_caption
37
 
38
+ def sample(zs, wts, attention_store, text_cross_attention_maps, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
39
  latents = wts[-1].expand(1, -1, -1, -1)
40
+ img, attention_store, text_cross_attention_maps = pipe(
41
  prompt=prompt_tar,
42
  init_latents=latents,
43
  guidance_scale=cfg_scale_tar,
 
45
  # num_inference_steps=steps,
46
  # use_ddpm=True,
47
  # wts=wts.value,
48
+ attention_store = attention_store, text_cross_attention_maps=text_cross_attention_maps,
49
  zs=zs,
50
  )
51
+ return img.images[0], attention_store, text_cross_attention_maps
52
 
53
 
54
  def reconstruct(
 
59
  wts,
60
  zs,
61
  attention_store,
62
+ text_cross_attention_maps,
63
  do_reconstruction,
64
  reconstruction,
65
  reconstruct_button,
 
80
  ): # if image caption was not changed, run actual reconstruction
81
  tar_prompt = ""
82
  latents = wts[-1].expand(1, -1, -1, -1)
83
+ reconstruction, attention_store, text_cross_attention_maps = sample(
84
+ zs, wts, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps,prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
85
  )
86
  do_reconstruction = False
87
  return (
 
131
  ## SEGA ##
132
 
133
  def edit(input_image,
134
+ wts, zs, attention_store, text_cross_attention_maps,
135
  tar_prompt,
136
  image_caption,
137
  steps,
 
198
  )
199
 
200
  latnets = wts[-1].expand(1, -1, -1, -1)
201
+ sega_out, attention_store, text_cross_attention_maps = pipe(prompt=tar_prompt,
202
  init_latents=latnets,
203
  guidance_scale = tar_cfg_scale,
204
  # num_images_per_prompt=1,
205
  # num_inference_steps=steps,
206
  # use_ddpm=True,
207
  # wts=wts.value,
208
+ zs=zs, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, **editing_args)
209
 
210
+ return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
211
 
212
 
213
  else: # if sega concepts were not added, performs regular ddpm sampling
214
 
215
  if do_reconstruction: # if ddpm sampling wasn't computed
216
+ pure_ddpm_img, attention_store, text_cross_attention_maps = sample(zs, wts, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
217
  reconstruction = pure_ddpm_img
218
  do_reconstruction = False
219
+ return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
220
 
221
+ return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
222
 
223
 
224
  def randomize_seed_fn(seed, is_random):
 
462
  wts = gr.State()
463
  zs = gr.State()
464
  attention_store=gr.State()
465
+ text_cross_attention_maps = gr.State()
466
  reconstruction = gr.State()
467
  do_inversion = gr.State(value=True)
468
  do_reconstruction = gr.State(value=True)
 
699
  fn=edit,
700
  inputs=[input_image,
701
  wts, zs, attention_store,
702
+ text_cross_attention_maps,
703
  tar_prompt,
704
  image_caption,
705
  steps,
 
719
 
720
 
721
  ],
722
+ outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs,attention_store, text_cross_attention_maps, do_inversion, share_btn_container])
723
  # .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
724
 
725