ford442 commited on
Commit
17aca73
·
verified ·
1 Parent(s): b5a8096

Update ip_adapter/ip_adapter.py

Browse files
Files changed (1) hide show
  1. ip_adapter/ip_adapter.py +5 -5
ip_adapter/ip_adapter.py CHANGED
@@ -266,11 +266,11 @@ class IPAdapterXL(IPAdapter):
266
  image_prompt_embeds_list.append(image_prompt_embeds_5)
267
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_5)
268
 
269
- image_prompt_embeds = torch.cat(image_prompt_embeds_list).mean(dim=0) #.unsqueeze(0)
270
- bs_embed, seq_len = image_prompt_embeds.shape
271
- #bs_embed, seq_len, _ = image_prompt_embeds.shape
272
- #image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
273
- #image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
274
 
275
  uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
276
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
 
266
  image_prompt_embeds_list.append(image_prompt_embeds_5)
267
  uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_5)
268
 
269
+ #image_prompt_embeds = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
270
+ image_prompt_embeds = torch.cat(image_prompt_embeds_list).unsqueeze(0)
271
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
272
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
273
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
274
 
275
  uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
276
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)