JackAILab commited on
Commit
8c7338c
·
verified ·
1 Parent(s): d078b5d

Upload pipline_StableDiffusionXL_ConsistentID.py

Browse files
pipline_StableDiffusionXL_ConsistentID.py ADDED
@@ -0,0 +1,701 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
2
+ import cv2
3
+ import PIL
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+ from torchvision import transforms
8
+ from insightface.app import FaceAnalysis
9
+ ### insight-face installation can be found at https://github.com/deepinsight/insightface
10
+ from safetensors import safe_open
11
+ from huggingface_hub.utils import validate_hf_hub_args
12
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
13
+ from diffusers.utils import _get_model_file
14
+ from functions import process_text_with_markers, masks_for_unique_values, fetch_mask_raw_image, tokenize_and_mask_noun_phrases_ends, prepare_image_token_idx
15
+ from functions import ProjPlusModel, masks_for_unique_values
16
+ from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder
17
+ ### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
18
+ ### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
19
+ ### Thanks for the open source of face-parsing model.
20
+ from models.BiSeNet.model import BiSeNet # resnet tensorflow
21
+ import pdb
22
+ ######################################
23
+ ########## add for sdxl
24
+ ######################################
25
+ from diffusers import StableDiffusionXLPipeline
26
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
27
+ ######################################
28
+ ########## add for llava
29
+ ######################################
30
+ # import sys
31
+ # sys.path.append("./Llava1.5/LLaVA")
32
+ # from llava.model.builder import load_pretrained_model
33
+ # from llava.mm_utils import get_model_name_from_path
34
+ # from llava.eval.run_llava import eval_model
35
+
36
+ PipelineImageInput = Union[
37
+ PIL.Image.Image,
38
+ torch.FloatTensor,
39
+ List[PIL.Image.Image],
40
+ List[torch.FloatTensor],
41
+ ]
42
+
43
+
44
+ class ConsistentIDStableDiffusionXLPipeline(StableDiffusionXLPipeline):
45
+
46
+ @validate_hf_hub_args
47
+ def load_ConsistentID_model(
48
+ self,
49
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
50
+ weight_name: str,
51
+ subfolder: str = '',
52
+ trigger_word_ID: str = '<|image|>',
53
+ trigger_word_facial: str = '<|facial|>',
54
+ image_encoder_path: str = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K', # Import CLIP pretrained model
55
+ bise_net_cp: str = 'JackAILab/ConsistentID/face_parsing.pth',
56
+ torch_dtype = torch.float16,
57
+ num_tokens = 4,
58
+ lora_rank= 128,
59
+ **kwargs,
60
+ ):
61
+ self.lora_rank = lora_rank
62
+ self.torch_dtype = torch_dtype
63
+ self.num_tokens = num_tokens
64
+ self.set_ip_adapter()
65
+ self.image_encoder_path = image_encoder_path
66
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
67
+ self.device, dtype=self.torch_dtype
68
+ )
69
+ self.clip_image_processor = CLIPImageProcessor()
70
+ self.id_image_processor = CLIPImageProcessor()
71
+ self.crop_size = 512
72
+
73
+ # FaceID
74
+ self.app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) ### root="/root/.insightface/models/buffalo_l"
75
+ self.app.prepare(ctx_id=0, det_size=(512, 512)) ### (640, 640)
76
+
77
+ ### BiSeNet
78
+ self.bise_net = BiSeNet(n_classes = 19)
79
+ self.bise_net.cuda()
80
+ self.bise_net_cp= bise_net_cp # Import BiSeNet model
81
+ self.bise_net.load_state_dict(torch.load(self.bise_net_cp)) # , map_location="cpu"
82
+ self.bise_net.eval()
83
+ # Colors for all 20 parts
84
+ self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
85
+ [255, 0, 85], [255, 0, 170],
86
+ [0, 255, 0], [85, 255, 0], [170, 255, 0],
87
+ [0, 255, 85], [0, 255, 170],
88
+ [0, 0, 255], [85, 0, 255], [170, 0, 255],
89
+ [0, 85, 255], [0, 170, 255],
90
+ [255, 255, 0], [255, 255, 85], [255, 255, 170],
91
+ [255, 0, 255], [255, 85, 255], [255, 170, 255],
92
+ [0, 255, 255], [85, 255, 255], [170, 255, 255]]
93
+
94
+ ### LLVA Optional
95
+ self.llva_model_path = "" #TODO import llava weights
96
+ self.llva_prompt = "Describe this person's facial features for me, including face, ears, eyes, nose, and mouth."
97
+ self.llva_tokenizer, self.llva_model, self.llva_image_processor, self.llva_context_len = None,None,None,None #load_pretrained_model(self.llva_model_path)
98
+
99
+ self.FacialEncoder = FacialEncoder(self.image_encoder, embedding_dim=1280, output_dim=2048, embed_dim=2048).to(self.device, dtype=self.torch_dtype)
100
+
101
+ # Load the main state dict first.
102
+ cache_dir = kwargs.pop("cache_dir", None)
103
+ force_download = kwargs.pop("force_download", False)
104
+ resume_download = kwargs.pop("resume_download", False)
105
+ proxies = kwargs.pop("proxies", None)
106
+ local_files_only = kwargs.pop("local_files_only", None)
107
+ token = kwargs.pop("token", None)
108
+ revision = kwargs.pop("revision", None)
109
+
110
+ user_agent = {
111
+ "file_type": "attn_procs_weights",
112
+ "framework": "pytorch",
113
+ }
114
+
115
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
116
+ model_file = _get_model_file(
117
+ pretrained_model_name_or_path_or_dict,
118
+ weights_name=weight_name,
119
+ cache_dir=cache_dir,
120
+ force_download=force_download,
121
+ resume_download=resume_download,
122
+ proxies=proxies,
123
+ local_files_only=local_files_only,
124
+ use_auth_token=token,
125
+ revision=revision,
126
+ subfolder=subfolder,
127
+ user_agent=user_agent,
128
+ )
129
+ if weight_name.endswith(".safetensors"):
130
+ state_dict = {"image_proj_model": {}, "adapter_modules": {}, "FacialEncoder": {}}
131
+ with safe_open(model_file, framework="pt", device="cpu") as f:
132
+ for key in f.keys():
133
+ if key.startswith("unet"):
134
+ pass
135
+ elif key.startswith("image_proj_model"):
136
+ state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = f.get_tensor(key)
137
+ elif key.startswith("adapter_modules"):
138
+ state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f.get_tensor(key)
139
+ elif key.startswith("FacialEncoder"):
140
+ state_dict["FacialEncoder"][key.replace("FacialEncoder.", "")] = f.get_tensor(key)
141
+ else:
142
+ state_dict = torch.load(model_file, map_location="cuda")
143
+ else:
144
+ state_dict = pretrained_model_name_or_path_or_dict
145
+
146
+
147
+ self.trigger_word_ID = trigger_word_ID
148
+ self.trigger_word_facial = trigger_word_facial
149
+
150
+ self.image_proj_model = ProjPlusModel(
151
+ cross_attention_dim=self.unet.config.cross_attention_dim,
152
+ id_embeddings_dim=512,
153
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
154
+ num_tokens=self.num_tokens, # 4
155
+ ).to(self.device, dtype=self.torch_dtype)
156
+ self.image_proj_model.load_state_dict(state_dict["image_proj_model"], strict=True)
157
+
158
+ ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
159
+ ip_layers.load_state_dict(state_dict["adapter_modules"], strict=True)
160
+ self.FacialEncoder.load_state_dict(state_dict["FacialEncoder"], strict=True)
161
+ print(f"Successfully loaded weights from checkpoint")
162
+
163
+ # Add trigger word token
164
+ if self.tokenizer is not None:
165
+ self.tokenizer.add_tokens([self.trigger_word_ID], special_tokens=True)
166
+ self.tokenizer.add_tokens([self.trigger_word_facial], special_tokens=True)
167
+
168
+ ######################################
169
+ ########## add for sdxl
170
+ ######################################
171
+ ### (1) load lora into models
172
+ # print(f"Loading ConsistentID components lora_weights from [{pretrained_model_name_or_path_or_dict}]")
173
+ # self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
174
+
175
+ ### (2) Add trigger word token for tokenizer_2
176
+ self.tokenizer_2.add_tokens([self.trigger_word_ID], special_tokens=True)
177
+
178
+ def set_ip_adapter(self):
179
+ unet = self.unet
180
+ attn_procs = {}
181
+ for name in unet.attn_processors.keys():
182
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
183
+ if name.startswith("mid_block"):
184
+ hidden_size = unet.config.block_out_channels[-1]
185
+ elif name.startswith("up_blocks"):
186
+ block_id = int(name[len("up_blocks.")])
187
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
188
+ elif name.startswith("down_blocks"):
189
+ block_id = int(name[len("down_blocks.")])
190
+ hidden_size = unet.config.block_out_channels[block_id]
191
+ if cross_attention_dim is None:
192
+ attn_procs[name] = Consistent_AttProcessor(
193
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
194
+ ).to(self.device, dtype=self.torch_dtype)
195
+ else:
196
+ attn_procs[name] = Consistent_IPAttProcessor(
197
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
198
+ ).to(self.device, dtype=self.torch_dtype)
199
+
200
+ unet.set_attn_processor(attn_procs)
201
+
202
+ @torch.inference_mode()
203
+ def get_facial_embeds(self, prompt_embeds, negative_prompt_embeds, facial_clip_images, facial_token_masks, valid_facial_token_idx_mask):
204
+
205
+ hidden_states = []
206
+ uncond_hidden_states = []
207
+ for facial_clip_image in facial_clip_images:
208
+ hidden_state = self.image_encoder(facial_clip_image.to(self.device, dtype=self.torch_dtype), output_hidden_states=True).hidden_states[-2]
209
+ uncond_hidden_state = self.image_encoder(torch.zeros_like(facial_clip_image, dtype=self.torch_dtype).to(self.device), output_hidden_states=True).hidden_states[-2]
210
+ hidden_states.append(hidden_state)
211
+ uncond_hidden_states.append(uncond_hidden_state)
212
+ multi_facial_embeds = torch.stack(hidden_states)
213
+ uncond_multi_facial_embeds = torch.stack(uncond_hidden_states)
214
+
215
+ # condition
216
+ facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
217
+
218
+ # uncondition
219
+ uncond_facial_prompt_embeds = self.FacialEncoder(negative_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
220
+
221
+ return facial_prompt_embeds, uncond_facial_prompt_embeds
222
+
223
+ @torch.inference_mode()
224
+ def get_image_embeds(self, faceid_embeds, face_image, s_scale=1.0, shortcut=False):
225
+
226
+ clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
227
+ clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
228
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
229
+ uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
230
+
231
+ faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
232
+ image_prompt_tokens = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
233
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
234
+
235
+ return image_prompt_tokens, uncond_image_prompt_embeds
236
+
237
+ def set_scale(self, scale):
238
+ for attn_processor in self.pipe.unet.attn_processors.values():
239
+ if isinstance(attn_processor, Consistent_IPAttProcessor):
240
+ attn_processor.scale = scale
241
+
242
+ @torch.inference_mode()
243
+ def get_prepare_faceid(self, input_image_path=None):
244
+ faceid_image = cv2.imread(input_image_path)
245
+ face_info = self.app.get(faceid_image)
246
+ if face_info==[]:
247
+ faceid_embeds = torch.zeros_like(torch.empty((1, 512)))
248
+ else:
249
+ faceid_embeds = torch.from_numpy(face_info[0].normed_embedding).unsqueeze(0)
250
+
251
+ # print(f"faceid_embeds is : {faceid_embeds}")
252
+ return faceid_embeds
253
+
254
+ @torch.inference_mode()
255
+ def parsing_face_mask(self, raw_image_refer):
256
+
257
+ to_tensor = transforms.Compose([
258
+ transforms.ToTensor(),
259
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
260
+ ])
261
+ to_pil = transforms.ToPILImage()
262
+
263
+ with torch.no_grad():
264
+ ### change sdxl
265
+ image = raw_image_refer.resize((1280, 1280), Image.BILINEAR)
266
+ image_resize_PIL = image
267
+ img = to_tensor(image)
268
+ img = torch.unsqueeze(img, 0)
269
+ img = img.float().cuda()
270
+ out = self.bise_net(img)[0]
271
+ parsing_anno = out.squeeze(0).cpu().numpy().argmax(0)
272
+
273
+ im = np.array(image_resize_PIL)
274
+ vis_im = im.copy().astype(np.uint8)
275
+ stride=1
276
+ vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
277
+ vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
278
+ vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
279
+
280
+ num_of_class = np.max(vis_parsing_anno)
281
+
282
+ for pi in range(1, num_of_class + 1): # num_of_class=17 pi=1~16
283
+ index = np.where(vis_parsing_anno == pi)
284
+ vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi]
285
+
286
+ vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
287
+ vis_parsing_anno_color = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
288
+
289
+ return vis_parsing_anno_color, vis_parsing_anno
290
+
291
+ @torch.inference_mode()
292
+ def get_prepare_llva_caption(self, input_image_file, model_path=None, prompt=None):
293
+
294
+ ### Optional: Use the LLaVA
295
+ # args = type('Args', (), {
296
+ # "model_path": self.llva_model_path,
297
+ # "model_base": None,
298
+ # "model_name": get_model_name_from_path(self.llva_model_path),
299
+ # "query": self.llva_prompt,
300
+ # "conv_mode": None,
301
+ # "image_file": input_image_file,
302
+ # "sep": ",",
303
+ # "temperature": 0,
304
+ # "top_p": None,
305
+ # "num_beams": 1,
306
+ # "max_new_tokens": 512
307
+ # })()
308
+ # face_caption = eval_model(args, self.llva_tokenizer, self.llva_model, self.llva_image_processor)
309
+
310
+ ### Use built-in template
311
+ face_caption = "The person has one face, one nose, two eyes, two ears, and a mouth."
312
+
313
+ return face_caption
314
+
315
+ @torch.inference_mode()
316
+ def get_prepare_facemask(self, input_image_file):
317
+
318
+ vis_parsing_anno_color, vis_parsing_anno = self.parsing_face_mask(input_image_file)
319
+ parsing_mask_list = masks_for_unique_values(vis_parsing_anno)
320
+
321
+ key_parsing_mask_list = {}
322
+ key_list = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
323
+ processed_keys = set()
324
+ for key, mask_image in parsing_mask_list.items():
325
+ if key in key_list:
326
+ if "_" in key:
327
+ prefix = key.split("_")[1]
328
+ if prefix in processed_keys:
329
+ continue
330
+ else:
331
+ key_parsing_mask_list[key] = mask_image
332
+ processed_keys.add(prefix)
333
+
334
+ key_parsing_mask_list[key] = mask_image
335
+
336
+ return key_parsing_mask_list, vis_parsing_anno_color
337
+
338
+ def encode_prompt_with_trigger_word(
339
+ self,
340
+ prompt: str,
341
+ face_caption: str,
342
+ key_parsing_mask_list = None,
343
+ image_token = "<|image|>",
344
+ facial_token = "<|facial|>",
345
+ max_num_facials = 5,
346
+ num_id_images: int = 1,
347
+ device: Optional[torch.device] = None,
348
+ ):
349
+ device = device or self._execution_device
350
+
351
+ # pdb.set_trace()
352
+ face_caption_align, key_parsing_mask_list_align = process_text_with_markers(face_caption, key_parsing_mask_list)
353
+
354
+ prompt_face = prompt + "; Detail:" + face_caption_align
355
+
356
+ max_text_length=330
357
+ if len(self.tokenizer(prompt_face, max_length=self.tokenizer.model_max_length, padding="max_length",truncation=False,return_tensors="pt").input_ids[0])!=77:
358
+ prompt_face = "; Detail:" + face_caption_align + " Caption:" + prompt
359
+
360
+ if len(face_caption)>max_text_length:
361
+ prompt_face = prompt
362
+ face_caption_align = ""
363
+
364
+ prompt_text_only = prompt_face.replace("<|facial|>", "").replace("<|image|>", "")
365
+ tokenizer = self.tokenizer
366
+ facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
367
+ image_token_id = None
368
+
369
+ clean_input_id, image_token_mask, facial_token_mask = tokenize_and_mask_noun_phrases_ends(
370
+ prompt_face, image_token_id, facial_token_id, tokenizer)
371
+
372
+ image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = prepare_image_token_idx(
373
+ image_token_mask, facial_token_mask, num_id_images, max_num_facials )
374
+
375
+ ######################################
376
+ ########## add for sdxl
377
+ ######################################
378
+ tokenizer_2 = self.tokenizer_2
379
+ facial_token_id2 = tokenizer.convert_tokens_to_ids(facial_token)
380
+ image_token_id2 = None
381
+ clean_input_id2, image_token_mask2, facial_token_mask2 = tokenize_and_mask_noun_phrases_ends(
382
+ prompt_face, image_token_id2, facial_token_id2, tokenizer_2)
383
+
384
+ image_token_idx2, image_token_idx_mask2, facial_token_idx2, facial_token_idx_mask2 = prepare_image_token_idx(
385
+ image_token_mask2, facial_token_mask2, num_id_images, max_num_facials )
386
+
387
+ return prompt_text_only, clean_input_id, clean_input_id2, key_parsing_mask_list_align, facial_token_mask, facial_token_idx, facial_token_idx_mask
388
+
389
+ @torch.inference_mode()
390
+ def get_prepare_clip_image(self, input_image_file, key_parsing_mask_list, image_size=512, max_num_facials=5, change_facial=True):
391
+
392
+ facial_mask = []
393
+ facial_clip_image = []
394
+ transform_mask = transforms.Compose([transforms.CenterCrop(size=image_size), transforms.ToTensor(),])
395
+ clip_image_processor = CLIPImageProcessor()
396
+
397
+ num_facial_part = len(key_parsing_mask_list)
398
+
399
+ for key in key_parsing_mask_list:
400
+ key_mask=key_parsing_mask_list[key]
401
+ facial_mask.append(transform_mask(key_mask))
402
+ key_mask_raw_image = fetch_mask_raw_image(input_image_file,key_mask)
403
+ parsing_clip_image = clip_image_processor(images=key_mask_raw_image, return_tensors="pt").pixel_values
404
+ facial_clip_image.append(parsing_clip_image)
405
+
406
+ padding_ficial_clip_image = torch.zeros_like(torch.zeros([1, 3, 224, 224]))
407
+ padding_ficial_mask = torch.zeros_like(torch.zeros([1, image_size, image_size]))
408
+
409
+ if num_facial_part < max_num_facials:
410
+ facial_clip_image += [torch.zeros_like(padding_ficial_clip_image) for _ in range(max_num_facials - num_facial_part) ]
411
+ facial_mask += [ torch.zeros_like(padding_ficial_mask) for _ in range(max_num_facials - num_facial_part)]
412
+
413
+ facial_clip_image = torch.stack(facial_clip_image, dim=1).squeeze(0)
414
+ facial_mask = torch.stack(facial_mask, dim=0).squeeze(dim=1)
415
+
416
+ return facial_clip_image, facial_mask
417
+
418
+ @torch.no_grad()
419
+ def __call__(
420
+ self,
421
+ prompt: Union[str, List[str]] = None,
422
+ face_caption: Union[str, List[str]] = None,
423
+ height: Optional[int] = None,
424
+ width: Optional[int] = None,
425
+ num_inference_steps: int = 50,
426
+ guidance_scale: float = 7.5,
427
+ negative_prompt: Optional[Union[str, List[str]]] = None,
428
+ num_images_per_prompt: Optional[int] = 1,
429
+ eta: float = 0.0,
430
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
431
+ latents: Optional[torch.FloatTensor] = None,
432
+ prompt_embeds: Optional[torch.FloatTensor] = None,
433
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
434
+ output_type: Optional[str] = "pil",
435
+ return_dict: bool = True,
436
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
437
+ original_size: Optional[Tuple[int, int]] = None,
438
+ target_size: Optional[Tuple[int, int]] = None,
439
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
440
+ callback_steps: int = 1,
441
+ input_id_images: PipelineImageInput = None,
442
+ input_image_path: PipelineImageInput = None,
443
+ start_merge_step: int = 0,
444
+ class_tokens_mask: Optional[torch.LongTensor] = None,
445
+ prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
446
+ ### add for sdxl
447
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
448
+ prompt_2: Optional[Union[str, List[str]]] = None,
449
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
450
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
451
+ pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
452
+ guidance_rescale: float = 7.5
453
+ ):
454
+ # 0. Default height and width to unet
455
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
456
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
457
+
458
+ original_size = original_size or (height, width)
459
+ target_size = target_size or (height, width)
460
+
461
+ # 1. Check inputs. Raise error if not correct
462
+ # self.check_inputs(
463
+ # prompt,
464
+ # height,
465
+ # width,
466
+ # callback_steps,
467
+ # negative_prompt,
468
+ # prompt_embeds,
469
+ # negative_prompt_embeds,
470
+ # )
471
+
472
+ if not isinstance(input_id_images, list):
473
+ input_id_images = [input_id_images]
474
+
475
+ # 2. Define call parameters
476
+ if prompt is not None and isinstance(prompt, str):
477
+ batch_size = 1
478
+ elif prompt is not None and isinstance(prompt, list):
479
+ batch_size = len(prompt)
480
+ else:
481
+ batch_size = prompt_embeds.shape[0]
482
+
483
+ device = self._execution_device
484
+ do_classifier_free_guidance = guidance_scale >= 1.0
485
+ input_image_file = input_id_images[0]
486
+
487
+ faceid_embeds = self.get_prepare_faceid(input_image_path=input_image_path)
488
+ face_caption = self.get_prepare_llva_caption(input_image_file=input_image_file)
489
+ key_parsing_mask_list, vis_parsing_anno_color = self.get_prepare_facemask(input_image_file)
490
+
491
+ assert do_classifier_free_guidance
492
+
493
+ # 3. Encode input prompt
494
+ num_id_images = len(input_id_images)
495
+
496
+ (
497
+ prompt_text_only,
498
+ clean_input_id,
499
+ clean_input_id2, ### add for sdxl
500
+ key_parsing_mask_list_align,
501
+ facial_token_mask,
502
+ facial_token_idx,
503
+ facial_token_idx_mask,
504
+ ) = self.encode_prompt_with_trigger_word(
505
+ prompt = prompt,
506
+ face_caption = face_caption,
507
+ key_parsing_mask_list=key_parsing_mask_list,
508
+ device=device,
509
+ max_num_facials = 5,
510
+ num_id_images= num_id_images,
511
+ )
512
+
513
+ # 4. Encode input prompt without the trigger word for delayed conditioning
514
+ text_embeds = self.text_encoder(clean_input_id.to(device), output_hidden_states=True).hidden_states[-2]
515
+ ######################################
516
+ ########## add for sdxl : add pooled_text_embeds
517
+ ######################################
518
+ ### (4-1)
519
+ encoder_output_2 = self.text_encoder_2(clean_input_id2.to(device), output_hidden_states=True)
520
+ pooled_text_embeds = encoder_output_2[0]
521
+ text_embeds_2 = encoder_output_2.hidden_states[-2]
522
+
523
+ ### (4-2)
524
+ encoder_hidden_states = torch.concat([text_embeds, text_embeds_2], dim=-1) # concat
525
+
526
+ ### (4-3)
527
+ if self.text_encoder_2 is None:
528
+ text_encoder_projection_dim = int(pooled_text_embeds.shape[-1])
529
+ else:
530
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
531
+ add_time_ids = self._get_add_time_ids(
532
+ original_size,
533
+ crops_coords_top_left,
534
+ target_size,
535
+ dtype=self.torch_dtype,
536
+ text_encoder_projection_dim=text_encoder_projection_dim,
537
+ )
538
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) ### add_time_ids.Size([2, 6])
539
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
540
+
541
+ ######################################
542
+ ########## add for sdxl : add pooled_prompt_embeds
543
+ ######################################
544
+ text_encoder_lora_scale = (
545
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
546
+ )
547
+ (
548
+ prompt_embeds,
549
+ negative_prompt_embeds,
550
+ pooled_prompt_embeds_text_only,
551
+ negative_pooled_prompt_embeds,
552
+ )= self.encode_prompt(
553
+ prompt=prompt,
554
+ prompt_2=prompt_2,
555
+ device=device,
556
+ num_images_per_prompt=num_images_per_prompt,
557
+ do_classifier_free_guidance=do_classifier_free_guidance,
558
+ negative_prompt=negative_prompt,
559
+ negative_prompt_2=negative_prompt_2,
560
+ prompt_embeds=prompt_embeds_text_only,
561
+ negative_prompt_embeds=negative_prompt_embeds,
562
+ pooled_prompt_embeds=pooled_prompt_embeds_text_only,
563
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
564
+ lora_scale=text_encoder_lora_scale,
565
+ )
566
+
567
+ # 5. Prepare the input ID images
568
+ prompt_tokens_faceid, uncond_prompt_tokens_faceid = self.get_image_embeds(faceid_embeds, face_image=input_image_file, s_scale=1.0, shortcut=True)
569
+
570
+ facial_clip_image, facial_mask = self.get_prepare_clip_image(input_image_file, key_parsing_mask_list_align, image_size=1280, max_num_facials=5)
571
+ facial_clip_images = facial_clip_image.unsqueeze(0).to(device, dtype=self.torch_dtype)
572
+ facial_token_mask = facial_token_mask.to(device)
573
+ facial_token_idx_mask = facial_token_idx_mask.to(device)
574
+
575
+ cross_attention_kwargs = {}
576
+
577
+ # 6. Get the update text embedding
578
+ prompt_embeds_facial, uncond_prompt_embeds_facial = self.get_facial_embeds(encoder_hidden_states, negative_prompt_embeds, \
579
+ facial_clip_images, facial_token_mask, facial_token_idx_mask)
580
+
581
+ ########## text_facial embeds
582
+ prompt_embeds_facial = torch.cat([prompt_embeds_facial, prompt_tokens_faceid], dim=1)
583
+ negative_prompt_embeds_facial = torch.cat([uncond_prompt_embeds_facial, uncond_prompt_tokens_faceid], dim=1)
584
+
585
+ ########## text_only embeds
586
+ prompt_embeds_text_only = torch.cat([prompt_embeds, prompt_tokens_faceid], dim=1)
587
+ negative_prompt_embeds_text_only = torch.cat([negative_prompt_embeds, uncond_prompt_tokens_faceid], dim=1)
588
+
589
+ # 7. Prepare timesteps
590
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
591
+ timesteps = self.scheduler.timesteps
592
+
593
+ # 8. Prepare latent variables
594
+ num_channels_latents = self.unet.in_channels
595
+ latents = self.prepare_latents(
596
+ batch_size * num_images_per_prompt,
597
+ num_channels_latents,
598
+ height,
599
+ width,
600
+ prompt_embeds.dtype,
601
+ device,
602
+ generator,
603
+ latents,
604
+ )
605
+
606
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
607
+
608
+ # 9. Denoising loop
609
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
610
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
611
+ for i, t in enumerate(timesteps):
612
+ latent_model_input = (
613
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
614
+ )
615
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
616
+
617
+ ######################################
618
+ ########## add for sdxl : add unet_added_cond_kwargs
619
+ ######################################
620
+ if i <= start_merge_step:
621
+ current_prompt_embeds = torch.cat(
622
+ [negative_prompt_embeds_text_only, prompt_embeds_text_only], dim=0
623
+ )
624
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0)
625
+ else:
626
+ current_prompt_embeds = torch.cat(
627
+ [negative_prompt_embeds_facial, prompt_embeds_facial], dim=0
628
+ )
629
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_text_embeds], dim=0)
630
+
631
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
632
+
633
+ # predict the noise residual
634
+ noise_pred = self.unet(
635
+ latent_model_input,
636
+ t,
637
+ encoder_hidden_states=current_prompt_embeds,
638
+ cross_attention_kwargs=cross_attention_kwargs,
639
+ added_cond_kwargs=unet_added_cond_kwargs,
640
+ # return_dict=False, ### [0]
641
+ ).sample
642
+
643
+ # perform guidance
644
+ if do_classifier_free_guidance:
645
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
646
+ noise_pred = noise_pred_uncond + guidance_scale * (
647
+ noise_pred_text - noise_pred_uncond
648
+ )
649
+ else:
650
+ assert 0, "Not Implemented"
651
+
652
+ # if do_classifier_free_guidance and guidance_rescale > 0.0:
653
+ # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
654
+ # noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) ### TODO optimal noise and LCM
655
+
656
+ # compute the previous noisy sample x_t -> x_t-1
657
+ latents = self.scheduler.step(
658
+ noise_pred, t, latents, **extra_step_kwargs
659
+ ).prev_sample
660
+
661
+ # call the callback, if provided
662
+ if i == len(timesteps) - 1 or (
663
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
664
+ ):
665
+ progress_bar.update()
666
+ if callback is not None and i % callback_steps == 0:
667
+ callback(i, t, latents)
668
+
669
+ # make sure the VAE is in float32 mode, as it overflows in float16
670
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
671
+ self.upcast_vae()
672
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
673
+
674
+ if not output_type == "latent":
675
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
676
+ else:
677
+ image = latents
678
+ return StableDiffusionXLPipelineOutput(images=image)
679
+
680
+ # apply watermark if available
681
+ # if self.watermark is not None:
682
+ # image = self.watermark.apply_watermark(image)
683
+
684
+ image = self.image_processor.postprocess(image, output_type=output_type)
685
+
686
+ # Offload all models
687
+ self.maybe_free_model_hooks()
688
+
689
+ if not return_dict:
690
+ return (image,)
691
+
692
+ return StableDiffusionXLPipelineOutput(images=image)
693
+
694
+
695
+
696
+
697
+
698
+
699
+
700
+
701
+