Linoy Tsaban commited on
Commit
ec15161
·
1 Parent(s): 968ec9f

Update pipeline_semantic_stable_diffusion_img2img_solver.py

Browse files
pipeline_semantic_stable_diffusion_img2img_solver.py CHANGED
@@ -499,6 +499,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
499
  verbose=True,
500
  use_cross_attn_mask: bool = False,
501
  # Attention store (just for visualization purposes)
 
502
  attn_store_steps: Optional[List[int]] = [],
503
  store_averaged_over_steps: bool = True,
504
  use_intersect_mask: bool = False,
@@ -771,8 +772,8 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
771
  timesteps = timesteps[-zs.shape[0]:]
772
 
773
  if use_cross_attn_mask:
774
- self.attention_store = AttentionStore(average=store_averaged_over_steps, batch_size=batch_size)
775
- self.prepare_unet(self.attention_store, PnP=False)
776
  # 5. Prepare latent variables
777
  num_channels_latents = self.unet.config.in_channels
778
  latents = self.prepare_latents(
@@ -917,8 +918,8 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
917
  noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
918
 
919
  if use_cross_attn_mask:
920
- out = self.attention_store.aggregate_attention(
921
- attention_maps=self.attention_store.step_store,
922
  prompts=self.text_cross_attention_maps,
923
  res=16,
924
  from_where=["up", "down"],
@@ -1080,7 +1081,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1080
  store_step = i in attn_store_steps
1081
  if store_step:
1082
  print(f"storing attention for step {i}")
1083
- self.attention_store.between_steps(store_step)
1084
 
1085
  # call the callback, if provided
1086
  if callback is not None and i % callback_steps == 0:
 
499
  verbose=True,
500
  use_cross_attn_mask: bool = False,
501
  # Attention store (just for visualization purposes)
502
+ attention_store = None,
503
  attn_store_steps: Optional[List[int]] = [],
504
  store_averaged_over_steps: bool = True,
505
  use_intersect_mask: bool = False,
 
772
  timesteps = timesteps[-zs.shape[0]:]
773
 
774
  if use_cross_attn_mask:
775
+ attention_store = AttentionStore(average=store_averaged_over_steps, batch_size=batch_size)
776
+ self.prepare_unet(attention_store, PnP=False)
777
  # 5. Prepare latent variables
778
  num_channels_latents = self.unet.config.in_channels
779
  latents = self.prepare_latents(
 
918
  noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
919
 
920
  if use_cross_attn_mask:
921
+ out = attention_store.aggregate_attention(
922
+ attention_maps=attention_store.step_store,
923
  prompts=self.text_cross_attention_maps,
924
  res=16,
925
  from_where=["up", "down"],
 
1081
  store_step = i in attn_store_steps
1082
  if store_step:
1083
  print(f"storing attention for step {i}")
1084
+ attention_store.between_steps(store_step)
1085
 
1086
  # call the callback, if provided
1087
  if callback is not None and i % callback_steps == 0: