ford442 committed
Commit 717b099 · verified · 1 Parent(s): 1471520

Update ip_adapter/ip_adapter.py

Files changed (1)
  1. ip_adapter/ip_adapter.py +5 -1
ip_adapter/ip_adapter.py CHANGED
@@ -109,6 +109,7 @@ class IPAdapter:
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.bfloat16)).image_embeds
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
+        print('image_proj_model shape:', image_prompt_embeds.shape)
         uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
         return image_prompt_embeds, uncond_image_prompt_embeds
 
@@ -238,6 +239,7 @@ class IPAdapterXL(IPAdapter):
         print('Using primary image.')
 
         image_prompt_embeds_1, uncond_image_prompt_embeds_1 = self.get_image_embeds(pil_image_1)
+        image_prompt_embeds_1 = image_prompt_embeds_1 * scale_1
         image_prompt_embeds_list.append(image_prompt_embeds_1)
         uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_1)
 
@@ -267,10 +269,12 @@ class IPAdapterXL(IPAdapter):
         uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_5)
 
         image_prompt_embeds = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
+        print('catted embeds list with mean and unsqueeze shape: ', image_prompt_embeds.shape)
         bs_embed, seq_len, _ = image_prompt_embeds.shape
         image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+        print('catted embeds repeat: ', image_prompt_embeds.shape)
         image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
+        print('viewed embeds: ', image_prompt_embeds.shape)
         uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
         uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
         uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
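For reference, a minimal standalone sketch of the shape flow the new print calls expose, assuming each get_image_embeds() call returns embeds of shape (1, seq_len, dim); seq_len=4, dim=1280, num_samples=2, and scale_1=0.8 are placeholder values, not taken from the commit or the model:

import torch

num_samples = 2
seq_len, dim = 4, 1280  # placeholder dimensions
scale_1 = 0.8           # hypothetical per-image weight, mirroring scale_1 in the diff

# Stand-ins for the five per-image embeds collected into image_prompt_embeds_list
embeds_list = [torch.randn(1, seq_len, dim) for _ in range(5)]
embeds_list[0] = embeds_list[0] * scale_1  # the scaling the commit applies to image 1

# cat -> (5, seq_len, dim); mean(dim=0) -> (seq_len, dim); unsqueeze(0) -> (1, seq_len, dim)
image_prompt_embeds = torch.cat(embeds_list).mean(dim=0).unsqueeze(0)
print('after cat/mean/unsqueeze:', image_prompt_embeds.shape)  # torch.Size([1, 4, 1280])

bs_embed, s, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
print('after repeat:', image_prompt_embeds.shape)              # torch.Size([1, 8, 1280])

image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, s, -1)
print('after view:', image_prompt_embeds.shape)                # torch.Size([2, 4, 1280])

The repeat/view pair duplicates the averaged embedding once per requested sample, so the conditional and unconditional embeds can be batched alongside the text embeddings for each generated image.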