Update ip_adapter/ip_adapter.py
ip_adapter/ip_adapter.py CHANGED (+5 -1)
@@ -109,6 +109,7 @@ class IPAdapter:
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.bfloat16)).image_embeds
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
+        print('image_proj_model shape:', image_prompt_embeds.shape)
         uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
         return image_prompt_embeds, uncond_image_prompt_embeds
 
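For context, the method this first hunk instruments runs the PIL image through the CLIP image processor and encoder, projects the pooled embedding into prompt tokens, and derives the unconditional embeds from a zero tensor. A minimal runnable sketch of that flow follows; the openai/clip-vit-large-patch14 checkpoint, the Linear stand-in for image_proj_model, and the token sizes are assumptions for illustration, not necessarily what this Space loads.

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Assumed encoder checkpoint; the hunk does not show which one the Space uses.
processor = CLIPImageProcessor()
encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# Hypothetical stand-in for image_proj_model: a Linear that fans the pooled
# CLIP embedding out into num_tokens prompt tokens of width cross_attention_dim.
num_tokens, cross_attention_dim = 4, 2048
proj = torch.nn.Linear(encoder.config.projection_dim, num_tokens * cross_attention_dim)

pil_image = Image.new("RGB", (512, 512))  # placeholder input image
clip_image = processor(images=pil_image, return_tensors="pt").pixel_values
image_embeds = encoder(clip_image).image_embeds  # (1, projection_dim)

# Conditional tokens from the image, unconditional tokens from a zero tensor,
# mirroring torch.zeros_like(clip_image_embeds) above (bfloat16 cast omitted).
cond = proj(image_embeds).view(1, num_tokens, cross_attention_dim)
uncond = proj(torch.zeros_like(image_embeds)).view(1, num_tokens, cross_attention_dim)
print(cond.shape, uncond.shape)  # torch.Size([1, 4, 2048]) for both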
@@ -238,6 +239,7 @@ class IPAdapterXL(IPAdapter):
         print('Using primary image.')
 
         image_prompt_embeds_1, uncond_image_prompt_embeds_1 = self.get_image_embeds(pil_image_1)
+        image_prompt_embeds_1 = image_prompt_embeds_1 * scale_1
         image_prompt_embeds_list.append(image_prompt_embeds_1)
         uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_1)
 
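The one functional change in the commit is the new scale_1 line: the primary image's conditional embeds are scaled before being appended, so the mean() taken later becomes a weighted average (note the unconditional embeds are left unscaled). A small sketch of the effect, with made-up sizes and an assumed scale value:

import torch

# Illustrative shapes and weight; scale_1's actual value comes from the caller.
emb_1 = torch.randn(1, 4, 2048)  # primary image tokens
emb_2 = torch.randn(1, 4, 2048)  # a secondary image's tokens
scale_1 = 1.5

weighted = torch.cat([emb_1 * scale_1, emb_2]).mean(dim=0)

# The mean is now (scale_1 * emb_1 + emb_2) / 2: the primary image pulls the
# blended embedding toward itself.
print(torch.allclose(weighted, (emb_1[0] * scale_1 + emb_2[0]) / 2))  # True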
@@ -267,10 +269,12 @@ class IPAdapterXL(IPAdapter):
         uncond_image_prompt_embeds_list.append(uncond_image_prompt_embeds_5)
 
         image_prompt_embeds = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
+        print('catted embeds list with mean and unsqueeze shape: ', image_prompt_embeds.shape)
         bs_embed, seq_len, _ = image_prompt_embeds.shape
         image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+        print('catted embeds repeat: ', image_prompt_embeds.shape)
         image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
+        print('viewed embeds: ', image_prompt_embeds.shape)
         uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
         uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
         uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
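The three new prints trace shapes through the mean/repeat/view chain, which averages the five per-image embeds (pil_image_1 through pil_image_5) and then tiles the result so each of num_samples generations receives an identical copy. A shape-only sketch under assumed sizes:

import torch

# Illustrative sizes: five per-image token embeds, two requested samples.
embeds_list = [torch.randn(1, 4, 2048) for _ in range(5)]
num_samples = 2

# Average over the five images, then restore the batch dim -> (1, 4, 2048).
image_prompt_embeds = torch.cat(embeds_list).mean(dim=0).unsqueeze(0)

bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)  # (1, 8, 2048)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
print(image_prompt_embeds.shape)  # torch.Size([2, 4, 2048])

# Both batch entries are identical copies of the averaged embedding.
print(torch.equal(image_prompt_embeds[0], image_prompt_embeds[1]))  # True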