JackAILab commited on
Commit
f970466
·
verified ·
1 Parent(s): 0b15b54

Update pipline_StableDiffusion_ConsistentID.py

Browse files
pipline_StableDiffusion_ConsistentID.py CHANGED
@@ -5,7 +5,8 @@ import numpy as np
5
  from PIL import Image
6
  import torch
7
  from torchvision import transforms
8
- from insightface.app import FaceAnalysis
 
9
  from safetensors import safe_open
10
  from huggingface_hub.utils import validate_hf_hub_args
11
  from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -15,16 +16,20 @@ from diffusers.utils import _get_model_file
15
  from functions import process_text_with_markers, masks_for_unique_values, fetch_mask_raw_image, tokenize_and_mask_noun_phrases_ends, prepare_image_token_idx
16
  from functions import ProjPlusModel, masks_for_unique_values
17
  from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder
18
- from modelscope.outputs import OutputKeys
19
- from modelscope.pipelines import pipeline
20
 
21
- #TODO 引入BiSeNet库路径
 
22
  import sys
23
  sys.path.append("./models/BiSeNet")
 
 
 
 
 
 
 
24
  from model import BiSeNet
25
 
26
-
27
-
28
  PipelineImageInput = Union[
29
  PIL.Image.Image,
30
  torch.FloatTensor,
@@ -43,13 +48,13 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
43
  subfolder: str = '',
44
  trigger_word_ID: str = '<|image|>',
45
  trigger_word_facial: str = '<|facial|>',
46
- image_encoder_path: str = '/data2/huangjiehui_m22/pretrained_model/CLIP-ViT-H-14-laion2B-s32B-b79K', # TODO CLIP路径
47
  torch_dtype = torch.float16,
48
  num_tokens = 4,
49
  lora_rank= 128,
50
  **kwargs,
51
  ):
52
- self.lora_rank = lora_rank
53
  self.torch_dtype = torch_dtype
54
  self.num_tokens = num_tokens
55
  self.set_ip_adapter()
@@ -68,7 +73,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
68
  ### BiSeNet
69
  self.bise_net = BiSeNet(n_classes = 19)
70
  self.bise_net.cuda()
71
- self.bise_net_cp='./models/BiSeNet_pretrained_for_ConsistentID.pth' #TODO BiSeNet的checkpoint
72
  self.bise_net.load_state_dict(torch.load(self.bise_net_cp))
73
  self.bise_net.eval()
74
  # Colors for all 20 parts
@@ -83,7 +88,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
83
  [0, 255, 255], [85, 255, 255], [170, 255, 255]]
84
 
85
  ### LLVA Optional
86
- self.llva_model_path = "/data6/huangjiehui_m22/pretrained_model/llava-v1.5-7b" #TODO llava模型路径
87
  self.llva_prompt = "Describe this person's facial features for me, including face, ears, eyes, nose, and mouth."
88
  self.llva_tokenizer, self.llva_model, self.llva_image_processor, self.llva_context_len = None,None,None,None #load_pretrained_model(self.llva_model_path)
89
 
@@ -95,9 +100,6 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
95
  ).to(self.device, dtype=self.torch_dtype)
96
  self.FacialEncoder = FacialEncoder(self.image_encoder).to(self.device, dtype=self.torch_dtype)
97
 
98
- # Modelscope 美肤用
99
- self.skin_retouching = pipeline('skin-retouching-torch', model='damo/cv_unet_skin_retouching_torch', model_revision='v1.0.2')
100
-
101
  # Load the main state dict first.
102
  cache_dir = kwargs.pop("cache_dir", None)
103
  force_download = kwargs.pop("force_download", False)
@@ -183,7 +185,6 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
183
  hidden_states = []
184
  uncond_hidden_states = []
185
  for facial_clip_image in facial_clip_images:
186
- # 分别把这几个裁剪出来的五官局部照用CLIP提一次
187
  hidden_state = self.image_encoder(facial_clip_image.to(self.device, dtype=self.torch_dtype), output_hidden_states=True).hidden_states[-2]
188
  uncond_hidden_state = self.image_encoder(torch.zeros_like(facial_clip_image, dtype=self.torch_dtype).to(self.device), output_hidden_states=True).hidden_states[-2]
189
  hidden_states.append(hidden_state)
@@ -191,7 +192,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
191
  multi_facial_embeds = torch.stack(hidden_states)
192
  uncond_multi_facial_embeds = torch.stack(uncond_hidden_states)
193
 
194
- # condition 这个关键!FacialEncoder怎么设计的
195
  facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
196
 
197
  # uncondition
@@ -204,15 +205,13 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
204
 
205
  clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
206
  clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
207
- # 先处理,变成1x3x224x224的clip_image,然后用图像编码器,编成1x257x1280
208
  clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
209
  uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
210
- # uncond_clip_image_embeds居然是用零矩阵编码出来的,用来做什么呢?用来cf guidence
211
  faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
212
  image_prompt_tokens = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
213
  uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
214
- # image_prompt_tokens感觉像是faceID与图片在CLIP提取的特征做过注意力之后的faceid_embeds
215
- # 而uncond_image_prompt_embeds像是假faceID与假图片在CLIP提取的特征做注意力
216
  return image_prompt_tokens, uncond_image_prompt_embeds
217
 
218
  def set_scale(self, scale):
@@ -223,13 +222,12 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
223
  @torch.inference_mode()
224
  def get_prepare_faceid(self, face_image):
225
  faceid_image = np.array(face_image)
226
- # 下面这句是用insightmodel获取faceid
227
  faces = self.app.get(faceid_image)
228
  if faces==[]:
229
  faceid_embeds = torch.zeros_like(torch.empty((1, 512)))
230
- else:# 这个insightmodel获得的是512的embedding,转成torch,头部加一维
231
  faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
232
- # 可能获取不出来,那么他会是一个空的ID
233
  return faceid_embeds
234
 
235
  @torch.inference_mode()
@@ -247,14 +245,13 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
247
  img = to_tensor(image)
248
  img = torch.unsqueeze(img, 0)
249
  img = img.float().cuda()
250
- out = self.bise_net(img)[0] #1,19,512,512
251
- parsing_anno = out.squeeze(0).cpu().numpy().argmax(0) #512,512 每个位置上是19个通道谁最大
252
 
253
  im = np.array(image_resize_PIL)
254
  vis_im = im.copy().astype(np.uint8)
255
  stride=1
256
- vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
257
- #下句我就不明白了,一比一缩放插值一下,不是没什么用嘛
258
  vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
259
  vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
260
 
@@ -264,8 +261,8 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
264
  index = np.where(vis_parsing_anno == pi)
265
  vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi]
266
 
267
- vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) #染了色的mask,只有颜色
268
- vis_parsing_anno_color = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)#与原图叠加一下
269
 
270
  return vis_parsing_anno_color, vis_parsing_anno
271
 
@@ -293,24 +290,20 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
293
 
294
  return face_caption
295
 
296
-
297
-
298
  @torch.inference_mode()
299
  def get_prepare_facemask(self, input_image_file):
300
- #先获取一下mask
301
  vis_parsing_anno_color, vis_parsing_anno = self.parsing_face_mask(input_image_file)
302
  parsing_mask_list = masks_for_unique_values(vis_parsing_anno)
303
 
304
  key_parsing_mask_list = {}
305
  key_list = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
306
- # TODO 背景信息还没有用上,看看有没有必要
307
-
308
  processed_keys = set()
309
  for key, mask_image in parsing_mask_list.items():
310
  if key in key_list:
311
  if "_" in key:
312
  prefix = key.split("_")[1]
313
- if prefix in processed_keys: # 左耳右耳只处理一次?都是耳朵再遇到就不处理了?
314
  continue
315
  else:
316
  key_parsing_mask_list[key] = mask_image
@@ -332,15 +325,12 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
332
  device: Optional[torch.device] = None,
333
  ):
334
  device = device or self._execution_device
335
- #这一步是比较paper比较关键的步骤,但是推理过程中我怎么觉得没什么用
336
- #就是说改这个prompt,让他的关键词顺序与key_parsing_mask_list_align中Eye Ear Nose的出现顺序一致
337
- #并且在这些关键词后加上<|facial|>的文本标记
338
  face_caption_align, key_parsing_mask_list_align = process_text_with_markers(face_caption, key_parsing_mask_list)
339
 
340
- # 与用户输入的prompt结合
341
  prompt_face = prompt + "Detail:" + face_caption_align
342
 
343
- max_text_length=330 # 如果用户输入的prompt太长了,会把包含facial关键字的face_caption_align提前到Detail:,防止之后被截断
344
  if len(self.tokenizer(prompt_face, max_length=self.tokenizer.model_max_length, padding="max_length",truncation=False,return_tensors="pt").input_ids[0])!=77:
345
  prompt_face = "Detail:" + face_caption_align + " Caption:" + prompt
346
 
@@ -350,17 +340,12 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
350
 
351
  prompt_text_only = prompt_face.replace("<|facial|>", "").replace("<|image|>", "")
352
  tokenizer = self.tokenizer
353
- # level 3 设定触发词 并获取"<|facial|>"触发词 id-49409
354
  facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
355
- image_token_id = None # TODO level2要做的事情,这个要做什么没理解
356
 
357
- # clean_input_id就是1x77长经典的SD用的tokens,里面没有触发词的编码
358
- # image_token_mask是1x77长的false,好像暂时没什么用,
359
- # facial_token_mask是1x77长的false中间有几个true,true是触发词的位置
360
- # 还有一个问题是长度就77,怎么做prompt engineering?TODO
361
  clean_input_id, image_token_mask, facial_token_mask = tokenize_and_mask_noun_phrases_ends(
362
  prompt_face, image_token_id, facial_token_id, tokenizer)
363
- # 下面这个也是没懂这做什么,image_token_idx好像没用,facial_token_idx有用,获得了facial token的位置索引,mask就感觉没什么用了
364
  image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = prepare_image_token_idx(
365
  image_token_mask, facial_token_mask, num_id_images, max_num_facials )
366
 
@@ -375,18 +360,17 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
375
  clip_image_processor = CLIPImageProcessor()
376
 
377
  num_facial_part = len(key_parsing_mask_list)
378
- #这个循环就是取mask与原图的与,放到facial_clip_image里存着
379
  for key in key_parsing_mask_list:
380
  key_mask=key_parsing_mask_list[key]
381
  facial_mask.append(transform_mask(key_mask))
382
- # key_mask_raw_image就是按照五官的mask截取出原图的一小部分区域
383
  key_mask_raw_image = fetch_mask_raw_image(input_image_file,key_mask)
384
  parsing_clip_image = clip_image_processor(images=key_mask_raw_image, return_tensors="pt").pixel_values
385
  facial_clip_image.append(parsing_clip_image)
386
 
387
  padding_ficial_clip_image = torch.zeros_like(torch.zeros([1, 3, 224, 224]))
388
  padding_ficial_mask = torch.zeros_like(torch.zeros([1, image_size, image_size]))
389
- # facial_clip_image与facial_mask补上到max_num_facials,感觉这个没什么用
390
  if num_facial_part < max_num_facials:
391
  facial_clip_image += [torch.zeros_like(padding_ficial_clip_image) for _ in range(max_num_facials - num_facial_part) ]
392
  facial_mask += [ torch.zeros_like(padding_ficial_mask) for _ in range(max_num_facials - num_facial_part)]
@@ -394,9 +378,8 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
394
  facial_clip_image = torch.stack(facial_clip_image, dim=1).squeeze(0)
395
  facial_mask = torch.stack(facial_mask, dim=0).squeeze(dim=1)
396
 
397
- return facial_clip_image, facial_mask # facial_mask 在训练过程中是用来做 loss 的, 推理过程不需要
398
 
399
- # pipe入口是这里
400
  @torch.no_grad()
401
  def __call__(
402
  self,
@@ -420,12 +403,9 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
420
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
421
  callback_steps: int = 1,
422
  input_id_images: PipelineImageInput = None,
423
- reference_id_images: PipelineImageInput =None,
424
  start_merge_step: int = 0,
425
  class_tokens_mask: Optional[torch.LongTensor] = None,
426
  prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
427
- retouching: bool=False,
428
- need_safetycheck: bool=True,
429
  ):
430
  # 0. Default height and width to unet
431
  height = height or self.unet.config.sample_size * self.vae_scale_factor
@@ -451,7 +431,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
451
  if prompt is not None and isinstance(prompt, str):
452
  batch_size = 1
453
  elif prompt is not None and isinstance(prompt, list):
454
- batch_size = len(prompt) #TODO
455
  else:
456
  batch_size = prompt_embeds.shape[0]
457
 
@@ -459,42 +439,26 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
459
  do_classifier_free_guidance = guidance_scale >= 1.0
460
  input_image_file = input_id_images[0]
461
 
462
- # *************4-14,使用多照片的混合faceid,发现并没有很大影响
463
- if reference_id_images:
464
- references_faceid_embeds=[]
465
- for reference_image in reference_id_images:
466
- references_faceid_embeds.append(self.get_prepare_faceid(face_image=reference_image))
467
- references_faceid_embeds = torch.stack(references_faceid_embeds, dim=0) #torch.Size([16, 1, 512])
468
- references_faceid_embeds_mean=torch.mean(references_faceid_embeds, dim=0)
469
- # references_faceid_embeds_var=torch.var(references_faceid_embeds, dim=0)
470
- # references_faceid_embeds_sample=torch.normal(references_faceid_embeds_mean, references_faceid_embeds_var)
471
-
472
- # 这里不是很理解,faceid推理时是哪里来的
473
- # insightface提的,1X512
474
- faceid_embeds = self.get_prepare_faceid(face_image=input_image_file) #TODO 用gradio的时候打开这句关掉下句
475
- # faceid_embeds = references_faceid_embeds_mean # 用参考人像集中的采样来做id
476
- # 推理的时候没有用到llava的detailed面部描述嘛?
477
- # 无
478
  face_caption = self.get_prepare_llva_caption(input_image_file)
479
- # 问题有,没识别到左眼左耳;这右耳的mask实在是太小了,聊胜于无,
480
  key_parsing_mask_list, vis_parsing_anno_color = self.get_prepare_facemask(input_image_file)
481
 
482
- # 这个是断言语句,就是guidance_scale >= 1.0时继续允许,否则抛出报错
483
  assert do_classifier_free_guidance
484
 
485
  # 3. Encode input prompt
486
  num_id_images = len(input_id_images)
487
 
488
  (
489
- prompt_text_only, # 用户输入的prompt与预制的prompt的拼接,没有facial关键词
490
- clean_input_id, # 对prompt_text_only token化得到
491
- key_parsing_mask_list_align, #似乎只有眼睛,耳朵,鼻子,嘴,是数组,里面放了四个PIL的mask
492
- facial_token_mask, # 大部分False小部分
493
- facial_token_idx, # 似乎就这个有用,得到了facial token在tokens中的索引,是一个5长度的数组
494
  facial_token_idx_mask,
495
  ) = self.encode_prompt_with_trigger_word(
496
  prompt = prompt,
497
- face_caption = face_caption,#这个是固定的
 
498
  key_parsing_mask_list=key_parsing_mask_list,
499
  device=device,
500
  max_num_facials = 5,
@@ -506,41 +470,35 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
506
 
507
  # 4. Encode input prompt without the trigger word for delayed conditioning
508
  encoder_hidden_states = self.text_encoder(clean_input_id.to(device))[0]
509
- # 上面这玩意把clean_input_id encoder了一下,变成了1x77x768张量
510
- # 这个就是CLIP的encoder,没有做修改。
511
  prompt_embeds = self._encode_prompt(
512
  prompt_text_only,
513
  device=device,
514
  num_images_per_prompt=num_images_per_prompt,
515
  do_classifier_free_guidance=True,
516
- negative_prompt=negative_prompt, #这个函数是SD的pipeline自带的,
517
- ) #这玩意是2x77x768,第一份给了negative_encoder_hidden_states_text_only,第二份给了encoder_hidden_states_text_only
518
  negative_encoder_hidden_states_text_only = prompt_embeds[0:num_images_per_prompt]
519
  encoder_hidden_states_text_only = prompt_embeds[num_images_per_prompt:]
520
 
521
  # 5. Prepare the input ID images
522
- prompt_tokens_faceid, uncond_prompt_tokens_faceid = self.get_image_embeds(faceid_embeds, face_image=input_image_file, s_scale=0.0, shortcut=True)
523
- # TODO s_scale这里被我改了,改成faceid_embeds的残差,试了一下,不行,动不了这个模块,后面的参数已经
524
- # 上面两个编码完之后是1x4x768,prompt_tokens_faceid是与整个图像做完注意力的faceid
525
- # uncond_prompt_tokens_faceid,之所以要这个uncond的,是CF guidance的公式需要,需要保留一定的多样性,
526
  facial_clip_image, facial_mask = self.get_prepare_clip_image(input_image_file, key_parsing_mask_list_align, image_size=512, max_num_facials=5)
527
- # 上面这两个处理完是5x3x224x224,5x512x512,推理只用到facial_clip_image,就是原图与mask的与,并且所放到了224x224
528
  facial_clip_images = facial_clip_image.unsqueeze(0).to(device, dtype=self.torch_dtype)
529
- # 这里,有必要把facial_clip_images的图片印出来看看,看看覆盖面积大不大
530
  facial_token_mask = facial_token_mask.to(device)
531
  facial_token_idx_mask = facial_token_idx_mask.to(device)
532
  negative_encoder_hidden_states = negative_encoder_hidden_states_text_only
533
 
534
  cross_attention_kwargs = {}
535
 
536
- # 6. Get the update text embeddingx
537
  prompt_embeds_facial, uncond_prompt_embeds_facial = self.get_facial_embeds(encoder_hidden_states, negative_encoder_hidden_states, \
538
  facial_clip_images, facial_token_mask, facial_token_idx_mask)
539
- # prompt_embeds_facial本是textprompt,在标记位用五官局部图的CLIP特征做了替换
540
- # prompt_tokens_faceid本是insightface提取的特征的,也就是faceid eb,再融合了用全图的CLIP特征,融合的时候有注意力机制
541
  prompt_embeds = torch.cat([prompt_embeds_facial, prompt_tokens_faceid], dim=1)
542
  negative_prompt_embeds = torch.cat([uncond_prompt_embeds_facial, uncond_prompt_tokens_faceid], dim=1)
543
- # TODO 这一步没懂啊,都已经获得prompt_embeds了,怎么又过了一次_encode_prompt,这个是SD自带的交叉注意力机制,具体内部是怎么样的还没看
544
  prompt_embeds = self._encode_prompt(
545
  prompt,
546
  device,
@@ -550,8 +508,6 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
550
  prompt_embeds=prompt_embeds,
551
  negative_prompt_embeds=negative_prompt_embeds,
552
  )
553
- # 从SD这个出来的出来得到的prompt_embeds torch.Size([2, 81, 768]),我猜就是在第一维把有无条件的两个prompt_embeds cat了一下
554
- # 下面这两句后prompt_embeds torch.Size([3, 81, 768]),又在后面第一维加了纯文本与faceid eb的融合
555
  prompt_embeds_text_only = torch.cat([encoder_hidden_states_text_only, prompt_tokens_faceid], dim=1)
556
  prompt_embeds = torch.cat([prompt_embeds, prompt_embeds_text_only], dim=0)
557
 
@@ -574,9 +530,9 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
574
 
575
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
576
  (
577
- null_prompt_embeds, #无条件prompt 也是negative prompt
578
- augmented_prompt_embeds, #增强的文本prompt+ id prompt
579
- text_prompt_embeds, #文本prompt+id prompt
580
  ) = prompt_embeds.chunk(3)
581
 
582
  # 9. Denoising loop
@@ -597,7 +553,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
597
  [null_prompt_embeds, augmented_prompt_embeds], dim=0
598
  )
599
 
600
- # predict the noise residual 这一步魔改了一点东西
601
  noise_pred = self.unet(
602
  latent_model_input,
603
  t,
@@ -630,27 +586,17 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
630
  if output_type == "latent":
631
  image = latents
632
  has_nsfw_concept = None
633
- elif output_type == "pil": #默认的
634
  # 9.1 Post-processing
635
  image = self.decode_latents(latents)
636
 
637
  # 9.2 Run safety checker
638
- if need_safetycheck:
639
- image, has_nsfw_concept = self.run_safety_checker(
640
- image, device, prompt_embeds.dtype
641
- )
642
- else:
643
- has_nsfw_concept = None
644
 
645
- # 9.3 Convert to PIL list
646
- image = self.numpy_to_pil(image)
647
-
648
- # 临时添加的,美肤效果,modelscope接收PIL对象,给一个BGR矩阵
649
- # 用了一下还是不要了,这个美肤模型失败概率有点大
650
- if retouching:
651
- after_retouching = self.skin_retouching(image[0])
652
- if OutputKeys.OUTPUT_IMG in after_retouching:
653
- image = [Image.fromarray(cv2.cvtColor(after_retouching[OutputKeys.OUTPUT_IMG], cv2.COLOR_BGR2RGB))]
654
  else:
655
  # 9.1 Post-processing
656
  image = self.decode_latents(latents)
@@ -660,7 +606,6 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
660
  image, device, prompt_embeds.dtype
661
  )
662
 
663
-
664
  # Offload last model to CPU
665
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
666
  self.final_offload_hook.offload()
@@ -672,3 +617,10 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
672
  images=image, nsfw_content_detected=has_nsfw_concept
673
  )
674
 
 
 
 
 
 
 
 
 
5
  from PIL import Image
6
  import torch
7
  from torchvision import transforms
8
+ from insightface.app import FaceAnalysis
9
+ ### insight-face installation can be found at https://github.com/deepinsight/insightface
10
  from safetensors import safe_open
11
  from huggingface_hub.utils import validate_hf_hub_args
12
  from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
16
  from functions import process_text_with_markers, masks_for_unique_values, fetch_mask_raw_image, tokenize_and_mask_noun_phrases_ends, prepare_image_token_idx
17
  from functions import ProjPlusModel, masks_for_unique_values
18
  from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder
 
 
19
 
20
+ <<<<<<< HEAD
21
+ #Import BiSeNet's model file
22
  import sys
23
  sys.path.append("./models/BiSeNet")
24
+ =======
25
+ ###TODO Import BiSeNet's model file
26
+ ### Model can be import from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
27
+ ### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
28
+ ### Thanks for the open source of face-parsing model.
29
+ sys.path.append("")
30
+ >>>>>>> 6f06fd81331aaed15193b840b17e221773a1abe2
31
  from model import BiSeNet
32
 
 
 
33
  PipelineImageInput = Union[
34
  PIL.Image.Image,
35
  torch.FloatTensor,
 
48
  subfolder: str = '',
49
  trigger_word_ID: str = '<|image|>',
50
  trigger_word_facial: str = '<|facial|>',
51
+ image_encoder_path: str = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K', # TODO Import CLIP pretrained model
52
  torch_dtype = torch.float16,
53
  num_tokens = 4,
54
  lora_rank= 128,
55
  **kwargs,
56
  ):
57
+ self.lora_rank = lora_rank
58
  self.torch_dtype = torch_dtype
59
  self.num_tokens = num_tokens
60
  self.set_ip_adapter()
 
73
  ### BiSeNet
74
  self.bise_net = BiSeNet(n_classes = 19)
75
  self.bise_net.cuda()
76
+ self.bise_net_cp='JackAILab/ConsistentID/face_parsing.pth' # Import BiSeNet model
77
  self.bise_net.load_state_dict(torch.load(self.bise_net_cp))
78
  self.bise_net.eval()
79
  # Colors for all 20 parts
 
88
  [0, 255, 255], [85, 255, 255], [170, 255, 255]]
89
 
90
  ### LLVA Optional
91
+ self.llva_model_path = "" #TODO import llava weights
92
  self.llva_prompt = "Describe this person's facial features for me, including face, ears, eyes, nose, and mouth."
93
  self.llva_tokenizer, self.llva_model, self.llva_image_processor, self.llva_context_len = None,None,None,None #load_pretrained_model(self.llva_model_path)
94
 
 
100
  ).to(self.device, dtype=self.torch_dtype)
101
  self.FacialEncoder = FacialEncoder(self.image_encoder).to(self.device, dtype=self.torch_dtype)
102
 
 
 
 
103
  # Load the main state dict first.
104
  cache_dir = kwargs.pop("cache_dir", None)
105
  force_download = kwargs.pop("force_download", False)
 
185
  hidden_states = []
186
  uncond_hidden_states = []
187
  for facial_clip_image in facial_clip_images:
 
188
  hidden_state = self.image_encoder(facial_clip_image.to(self.device, dtype=self.torch_dtype), output_hidden_states=True).hidden_states[-2]
189
  uncond_hidden_state = self.image_encoder(torch.zeros_like(facial_clip_image, dtype=self.torch_dtype).to(self.device), output_hidden_states=True).hidden_states[-2]
190
  hidden_states.append(hidden_state)
 
192
  multi_facial_embeds = torch.stack(hidden_states)
193
  uncond_multi_facial_embeds = torch.stack(uncond_hidden_states)
194
 
195
+ # condition
196
  facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
197
 
198
  # uncondition
 
205
 
206
  clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
207
  clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
 
208
  clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
209
  uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
210
+
211
  faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
212
  image_prompt_tokens = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
213
  uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
214
+
 
215
  return image_prompt_tokens, uncond_image_prompt_embeds
216
 
217
  def set_scale(self, scale):
 
222
  @torch.inference_mode()
223
  def get_prepare_faceid(self, face_image):
224
  faceid_image = np.array(face_image)
 
225
  faces = self.app.get(faceid_image)
226
  if faces==[]:
227
  faceid_embeds = torch.zeros_like(torch.empty((1, 512)))
228
+ else:
229
  faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
230
+
231
  return faceid_embeds
232
 
233
  @torch.inference_mode()
 
245
  img = to_tensor(image)
246
  img = torch.unsqueeze(img, 0)
247
  img = img.float().cuda()
248
+ out = self.bise_net(img)[0]
249
+ parsing_anno = out.squeeze(0).cpu().numpy().argmax(0)
250
 
251
  im = np.array(image_resize_PIL)
252
  vis_im = im.copy().astype(np.uint8)
253
  stride=1
254
+ vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
 
255
  vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
256
  vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
257
 
 
261
  index = np.where(vis_parsing_anno == pi)
262
  vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi]
263
 
264
+ vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
265
+ vis_parsing_anno_color = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
266
 
267
  return vis_parsing_anno_color, vis_parsing_anno
268
 
 
290
 
291
  return face_caption
292
 
 
 
293
  @torch.inference_mode()
294
  def get_prepare_facemask(self, input_image_file):
295
+
296
  vis_parsing_anno_color, vis_parsing_anno = self.parsing_face_mask(input_image_file)
297
  parsing_mask_list = masks_for_unique_values(vis_parsing_anno)
298
 
299
  key_parsing_mask_list = {}
300
  key_list = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
 
 
301
  processed_keys = set()
302
  for key, mask_image in parsing_mask_list.items():
303
  if key in key_list:
304
  if "_" in key:
305
  prefix = key.split("_")[1]
306
+ if prefix in processed_keys:
307
  continue
308
  else:
309
  key_parsing_mask_list[key] = mask_image
 
325
  device: Optional[torch.device] = None,
326
  ):
327
  device = device or self._execution_device
328
+
 
 
329
  face_caption_align, key_parsing_mask_list_align = process_text_with_markers(face_caption, key_parsing_mask_list)
330
 
 
331
  prompt_face = prompt + "Detail:" + face_caption_align
332
 
333
+ max_text_length=330
334
  if len(self.tokenizer(prompt_face, max_length=self.tokenizer.model_max_length, padding="max_length",truncation=False,return_tensors="pt").input_ids[0])!=77:
335
  prompt_face = "Detail:" + face_caption_align + " Caption:" + prompt
336
 
 
340
 
341
  prompt_text_only = prompt_face.replace("<|facial|>", "").replace("<|image|>", "")
342
  tokenizer = self.tokenizer
 
343
  facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
344
+ image_token_id = None
345
 
 
 
 
 
346
  clean_input_id, image_token_mask, facial_token_mask = tokenize_and_mask_noun_phrases_ends(
347
  prompt_face, image_token_id, facial_token_id, tokenizer)
348
+
349
  image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = prepare_image_token_idx(
350
  image_token_mask, facial_token_mask, num_id_images, max_num_facials )
351
 
 
360
  clip_image_processor = CLIPImageProcessor()
361
 
362
  num_facial_part = len(key_parsing_mask_list)
363
+
364
  for key in key_parsing_mask_list:
365
  key_mask=key_parsing_mask_list[key]
366
  facial_mask.append(transform_mask(key_mask))
 
367
  key_mask_raw_image = fetch_mask_raw_image(input_image_file,key_mask)
368
  parsing_clip_image = clip_image_processor(images=key_mask_raw_image, return_tensors="pt").pixel_values
369
  facial_clip_image.append(parsing_clip_image)
370
 
371
  padding_ficial_clip_image = torch.zeros_like(torch.zeros([1, 3, 224, 224]))
372
  padding_ficial_mask = torch.zeros_like(torch.zeros([1, image_size, image_size]))
373
+
374
  if num_facial_part < max_num_facials:
375
  facial_clip_image += [torch.zeros_like(padding_ficial_clip_image) for _ in range(max_num_facials - num_facial_part) ]
376
  facial_mask += [ torch.zeros_like(padding_ficial_mask) for _ in range(max_num_facials - num_facial_part)]
 
378
  facial_clip_image = torch.stack(facial_clip_image, dim=1).squeeze(0)
379
  facial_mask = torch.stack(facial_mask, dim=0).squeeze(dim=1)
380
 
381
+ return facial_clip_image, facial_mask
382
 
 
383
  @torch.no_grad()
384
  def __call__(
385
  self,
 
403
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
404
  callback_steps: int = 1,
405
  input_id_images: PipelineImageInput = None,
 
406
  start_merge_step: int = 0,
407
  class_tokens_mask: Optional[torch.LongTensor] = None,
408
  prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
 
 
409
  ):
410
  # 0. Default height and width to unet
411
  height = height or self.unet.config.sample_size * self.vae_scale_factor
 
431
  if prompt is not None and isinstance(prompt, str):
432
  batch_size = 1
433
  elif prompt is not None and isinstance(prompt, list):
434
+ batch_size = len(prompt)
435
  else:
436
  batch_size = prompt_embeds.shape[0]
437
 
 
439
  do_classifier_free_guidance = guidance_scale >= 1.0
440
  input_image_file = input_id_images[0]
441
 
442
+ faceid_embeds = self.get_prepare_faceid(face_image=input_image_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  face_caption = self.get_prepare_llva_caption(input_image_file)
 
444
  key_parsing_mask_list, vis_parsing_anno_color = self.get_prepare_facemask(input_image_file)
445
 
 
446
  assert do_classifier_free_guidance
447
 
448
  # 3. Encode input prompt
449
  num_id_images = len(input_id_images)
450
 
451
  (
452
+ prompt_text_only,
453
+ clean_input_id,
454
+ key_parsing_mask_list_align,
455
+ facial_token_mask,
456
+ facial_token_idx,
457
  facial_token_idx_mask,
458
  ) = self.encode_prompt_with_trigger_word(
459
  prompt = prompt,
460
+ face_caption = face_caption,
461
+ # prompt_2=None,
462
  key_parsing_mask_list=key_parsing_mask_list,
463
  device=device,
464
  max_num_facials = 5,
 
470
 
471
  # 4. Encode input prompt without the trigger word for delayed conditioning
472
  encoder_hidden_states = self.text_encoder(clean_input_id.to(device))[0]
473
+
 
474
  prompt_embeds = self._encode_prompt(
475
  prompt_text_only,
476
  device=device,
477
  num_images_per_prompt=num_images_per_prompt,
478
  do_classifier_free_guidance=True,
479
+ negative_prompt=negative_prompt,
480
+ )
481
  negative_encoder_hidden_states_text_only = prompt_embeds[0:num_images_per_prompt]
482
  encoder_hidden_states_text_only = prompt_embeds[num_images_per_prompt:]
483
 
484
  # 5. Prepare the input ID images
485
+ prompt_tokens_faceid, uncond_prompt_tokens_faceid = self.get_image_embeds(faceid_embeds, face_image=input_image_file, s_scale=1.0, shortcut=False)
486
+
 
 
487
  facial_clip_image, facial_mask = self.get_prepare_clip_image(input_image_file, key_parsing_mask_list_align, image_size=512, max_num_facials=5)
 
488
  facial_clip_images = facial_clip_image.unsqueeze(0).to(device, dtype=self.torch_dtype)
 
489
  facial_token_mask = facial_token_mask.to(device)
490
  facial_token_idx_mask = facial_token_idx_mask.to(device)
491
  negative_encoder_hidden_states = negative_encoder_hidden_states_text_only
492
 
493
  cross_attention_kwargs = {}
494
 
495
+ # 6. Get the update text embedding
496
  prompt_embeds_facial, uncond_prompt_embeds_facial = self.get_facial_embeds(encoder_hidden_states, negative_encoder_hidden_states, \
497
  facial_clip_images, facial_token_mask, facial_token_idx_mask)
498
+
 
499
  prompt_embeds = torch.cat([prompt_embeds_facial, prompt_tokens_faceid], dim=1)
500
  negative_prompt_embeds = torch.cat([uncond_prompt_embeds_facial, uncond_prompt_tokens_faceid], dim=1)
501
+
502
  prompt_embeds = self._encode_prompt(
503
  prompt,
504
  device,
 
508
  prompt_embeds=prompt_embeds,
509
  negative_prompt_embeds=negative_prompt_embeds,
510
  )
 
 
511
  prompt_embeds_text_only = torch.cat([encoder_hidden_states_text_only, prompt_tokens_faceid], dim=1)
512
  prompt_embeds = torch.cat([prompt_embeds, prompt_embeds_text_only], dim=0)
513
 
 
530
 
531
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
532
  (
533
+ null_prompt_embeds,
534
+ augmented_prompt_embeds,
535
+ text_prompt_embeds,
536
  ) = prompt_embeds.chunk(3)
537
 
538
  # 9. Denoising loop
 
553
  [null_prompt_embeds, augmented_prompt_embeds], dim=0
554
  )
555
 
556
+ # predict the noise residual
557
  noise_pred = self.unet(
558
  latent_model_input,
559
  t,
 
586
  if output_type == "latent":
587
  image = latents
588
  has_nsfw_concept = None
589
+ elif output_type == "pil":
590
  # 9.1 Post-processing
591
  image = self.decode_latents(latents)
592
 
593
  # 9.2 Run safety checker
594
+ image, has_nsfw_concept = self.run_safety_checker(
595
+ image, device, prompt_embeds.dtype
596
+ )
 
 
 
597
 
598
+ # 9.3 Convert to PIL
599
+ image = self.numpy_to_pil(image)
 
 
 
 
 
 
 
600
  else:
601
  # 9.1 Post-processing
602
  image = self.decode_latents(latents)
 
606
  image, device, prompt_embeds.dtype
607
  )
608
 
 
609
  # Offload last model to CPU
610
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
611
  self.final_offload_hook.offload()
 
617
  images=image, nsfw_content_detected=has_nsfw_concept
618
  )
619
 
620
+
621
+
622
+
623
+
624
+
625
+
626
+