ford442 committed
Commit 2438e6b · verified · 1 Parent(s): a2c6ec3

Update ip_adapter/ip_adapter.py

Files changed (1)
  1. ip_adapter/ip_adapter.py +6 -6
ip_adapter/ip_adapter.py CHANGED
@@ -39,7 +39,7 @@ class IPAdapter:
         self.set_ip_adapter()

         # load image encoder
-        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float32)
+        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.bfloat16)
         self.clip_image_processor = CLIPImageProcessor()
         # image proj model
         self.image_proj_model = self.init_proj()
@@ -50,7 +50,7 @@ class IPAdapter:
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.projection_dim,
             clip_extra_context_tokens=self.num_tokens,
-        ).to(self.device, dtype=torch.float32)
+        ).to(self.device, dtype=torch.bfloat16)
         return image_proj_model

     def set_ip_adapter(self):
@@ -70,7 +70,7 @@ class IPAdapter:
                 attn_procs[name] = AttnProcessor()
             else:
                 attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
-                                                   scale=1.0, num_tokens=self.num_tokens).to(self.device, dtype=torch.float32)
+                                                   scale=1.0, num_tokens=self.num_tokens).to(self.device, dtype=torch.bfloat16)
         unet.set_attn_processor(attn_procs)
         if hasattr(self.pipe, "controlnet"):
             if isinstance(self.pipe.controlnet, MultiControlNetModel):
@@ -108,7 +108,7 @@ class IPAdapter:
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         print('clip_image_processor shape:', clip_image.shape)
-        clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float32)).image_embeds
+        clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.bfloat16)).image_embeds
         print('image_encoder shape:', clip_image_embeds.shape)
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         print('image_proj_model shape:', image_prompt_embeds.shape)
@@ -317,7 +317,7 @@ class IPAdapterPlus(IPAdapter):
             embedding_dim=self.image_encoder.config.hidden_size,
             output_dim=self.pipe.unet.config.cross_attention_dim,
             ff_mult=4
-        ).to(self.device, dtype=torch.float32)
+        ).to(self.device, dtype=torch.bfloat16)
         return image_proj_model

     @torch.inference_mode()
@@ -325,7 +325,7 @@ class IPAdapterPlus(IPAdapter):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float32)
+        clip_image = clip_image.to(self.device, dtype=torch.bfloat16)
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
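
The only functional change in this commit is the dtype used when the IP-Adapter modules (CLIP image encoder, image projection model, and IPAttnProcessor weights) are moved to the device: torch.float32 becomes torch.bfloat16. The image-embedding path touched by the get_image_embeds hunk can be sketched standalone as below; this is a minimal sketch, assuming a CUDA device, with the checkpoint name and input image as placeholders rather than values from this repo:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

device = "cuda"

# Placeholder checkpoint; the class loads self.image_encoder_path instead.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-large-patch14"
).to(device, dtype=torch.bfloat16)
clip_image_processor = CLIPImageProcessor()

pil_image = Image.open("example.jpg")  # placeholder input image
clip_image = clip_image_processor(images=[pil_image], return_tensors="pt").pixel_values

# The processor returns float32 pixel values, so they are cast to bfloat16
# to match the encoder weights, mirroring the changed line in get_image_embeds.
clip_image_embeds = image_encoder(clip_image.to(device, dtype=torch.bfloat16)).image_embeds
print(clip_image_embeds.shape, clip_image_embeds.dtype)

For the cast to hold end to end, the wrapped diffusers pipeline would presumably also need to run in bfloat16 (e.g. loaded with torch_dtype=torch.bfloat16), since the bfloat16 image prompt embeddings and IPAttnProcessor weights feed directly into the UNet's cross-attention.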