JackAILab committed
Commit 17d73b1 · verified · 1 Parent(s): 8c7338c

Upload 2 files

Files changed (2)
  1. attention.py +294 -0
  2. functions.py +605 -0
attention.py ADDED
@@ -0,0 +1,294 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from diffusers.models.lora import LoRALinearLayer
+ from functions import AttentionMLP
+ from diffusers.utils.import_utils import is_xformers_available
+ if is_xformers_available():
+     import xformers
+
+ class FuseModule(nn.Module):
+     # Fuses the projected ID embeddings into the prompt embeddings at the token
+     # positions flagged by class_tokens_mask (concat -> MLP -> residual -> MLP -> LayerNorm).
+     def __init__(self, embed_dim):
+         super().__init__()
+         self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False)
+         self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
+         self.layer_norm = nn.LayerNorm(embed_dim)
+
+     def fuse_fn(self, prompt_embeds, id_embeds):
+         stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
+         stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
+         stacked_id_embeds = self.mlp2(stacked_id_embeds)
+         stacked_id_embeds = self.layer_norm(stacked_id_embeds)
+         return stacked_id_embeds
+
+     def forward(
+         self,
+         prompt_embeds,
+         id_embeds,
+         class_tokens_mask,
+         valid_id_mask,
+     ) -> torch.Tensor:
+
+         id_embeds = id_embeds.to(prompt_embeds.dtype)
+         batch_size, max_num_inputs = id_embeds.shape[:2]
+         seq_length = prompt_embeds.shape[1]
+         flat_id_embeds = id_embeds.view(-1, id_embeds.shape[-2], id_embeds.shape[-1])
+
+         # Keep only the ID embeddings that correspond to valid reference images.
+         valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
+
+         prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
+         class_tokens_mask = class_tokens_mask.view(-1)
+         valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
+         image_token_embeds = prompt_embeds[class_tokens_mask]
+         stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
+         assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
+         prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
+         updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
+
+         return updated_prompt_embeds
+
+ class MLP(nn.Module):
+     def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
+         super().__init__()
+         if use_residual:
+             assert in_dim == out_dim
+         self.layernorm = nn.LayerNorm(in_dim)
+         self.fc1 = nn.Linear(in_dim, hidden_dim)
+         self.fc2 = nn.Linear(hidden_dim, out_dim)
+         self.use_residual = use_residual
+         self.act_fn = nn.GELU()
+
+     def forward(self, x):
+         residual = x
+         x = self.layernorm(x)
+         x = self.fc1(x)
+         x = self.act_fn(x)
+         x = self.fc2(x)
+         if self.use_residual:
+             x = x + residual
+         return x
+
+ class FacialEncoder(nn.Module):
+     # Projects each stack of facial CLIP token embeddings to a single ID token
+     # (via AttentionMLP) and fuses it into the text prompt embeddings.
+     def __init__(self, image_CLIPModel_encoder=None, embedding_dim=1280, output_dim=768, embed_dim=768):
+         super().__init__()
+         self.visual_projection = AttentionMLP(embedding_dim=embedding_dim, output_dim=output_dim)
+         self.fuse_module = FuseModule(embed_dim=embed_dim)
+
+     def forward(self, prompt_embeds, multi_image_embeds, class_tokens_mask, valid_id_mask):
+         bs, num_inputs, token_length, image_dim = multi_image_embeds.shape
+         multi_image_embeds_view = multi_image_embeds.view(bs * num_inputs, token_length, image_dim)
+
+         id_embeds = self.visual_projection(multi_image_embeds_view)
+         id_embeds = id_embeds.view(bs, num_inputs, 1, -1)
+
+         updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask, valid_id_mask)
+
+         return updated_prompt_embeds
+
+ class Consistent_AttProcessor(nn.Module):
+     # Diffusers-style attention processor with LoRA layers added to the
+     # q/k/v/out projections.
+
+     def __init__(
+         self,
+         hidden_size=None,
+         cross_attention_dim=None,
+         rank=4,
+         network_alpha=None,
+         lora_scale=1.0,
+     ):
+         super().__init__()
+
+         self.rank = rank
+         self.lora_scale = lora_scale
+
+         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         if is_xformers_available():
+             ### xformers
+             hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+             hidden_states = hidden_states.to(query.dtype)
+         else:
+             attention_probs = attn.get_attention_scores(query, key, attention_mask)
+             hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+
+
+ class Consistent_IPAttProcessor(nn.Module):
+     # IP-Adapter-style cross-attention processor with LoRA: the last `num_tokens`
+     # entries of encoder_hidden_states are treated as image (ID) tokens, attended
+     # to through separate to_k_ip / to_v_ip projections, and added with weight `scale`.
+
+     def __init__(
+         self,
+         hidden_size,
+         cross_attention_dim=None,
+         rank=4,
+         network_alpha=None,
+         lora_scale=1.0,
+         scale=1.0,
+         num_tokens=4,
+     ):
+         super().__init__()
+
+         self.rank = rank
+         self.lora_scale = lora_scale
+         self.num_tokens = num_tokens
+
+         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+         self.hidden_size = hidden_size
+         self.cross_attention_dim = cross_attention_dim
+         self.scale = scale
+
+         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+         self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         scale=1.0,
+         temb=None,
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         else:
+             # Split the text tokens from the trailing ID/image tokens.
+             end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+             encoder_hidden_states, ip_hidden_states = (
+                 encoder_hidden_states[:, :end_pos, :],
+                 encoder_hidden_states[:, end_pos:, :],
+             )
+             if attn.norm_cross:
+                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # Attention over the ID/image tokens.
+         ip_key = self.to_k_ip(ip_hidden_states)
+         ip_value = self.to_v_ip(ip_hidden_states)
+         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         ip_hidden_states = F.scaled_dot_product_attention(
+             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+         )
+
+         ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+         hidden_states = hidden_states + self.scale * ip_hidden_states
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
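
For reference, a minimal usage sketch of the FacialEncoder defined above, run on dummy tensors. The batch size, token count (257 CLIP-H patch tokens), prompt length (77) and widths (1280-d vision features, 768-d prompt embeddings) are illustrative assumptions, not values fixed by this commit:

import torch
from attention import FacialEncoder

encoder = FacialEncoder(embedding_dim=1280, output_dim=768, embed_dim=768)

bs, num_inputs, token_length = 1, 2, 257                      # assumed CLIP-H patch-token stacks
prompt_embeds = torch.randn(bs, 77, 768)                      # assumed SD text-encoder output
multi_image_embeds = torch.randn(bs, num_inputs, token_length, 1280)

valid_id_mask = torch.tensor([[True, False]])                 # only the first reference image is valid
class_tokens_mask = torch.zeros(bs, 77, dtype=torch.bool)
class_tokens_mask[0, 5] = True                                # one marked prompt position per valid ID

updated = encoder(prompt_embeds, multi_image_embeds, class_tokens_mask, valid_id_mask)
print(updated.shape)                                          # torch.Size([1, 77, 768])

The two attention processors above follow the standard diffusers attention-processor interface, so they would typically be registered on a UNet with unet.set_attn_processor(...) before training or inference.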
functions.py ADDED
@@ -0,0 +1,605 @@
+ import numpy as np
+ import math
+ import types
+ import torch
+ import torch.nn as nn
+ import cv2
+ import re
+ import torch.nn.functional as F
+ from einops import rearrange
+ from einops.layers.torch import Rearrange
+ from PIL import Image
+
+ def extract_first_sentence(text):
+     end_index = text.find('.')
+     if end_index != -1:
+         first_sentence = text[:end_index + 1]
+         return first_sentence.strip()
+     else:
+         return text.strip()
+
+ def remove_duplicate_keywords(text, keywords): ### This function can continue to be optimized
+     # Blank out all but the first occurrence of each keyword (case-insensitive).
+     keyword_counts = {}
+
+     words = re.findall(r'\b\w+\b|[.,;!?]', text)
+
+     for keyword in keywords:
+         keyword_counts[keyword] = 0
+         for i, word in enumerate(words):
+             if word.lower() == keyword.lower():
+                 keyword_counts[keyword] += 1
+                 if keyword_counts[keyword] > 1:
+                     words[i] = ""
+     processed_text = " ".join(words)
+
+     return processed_text
+
+ def process_text_with_markers(text, parsing_mask_list):
+     # Insert a marker after the first mention of each facial keyword that has a
+     # parsing mask, drop mask entries whose keyword is absent from the caption,
+     # and return the marked sub-phrases joined together with a generic <|facial|> tag.
+     keywords = ["face", "ears", "eyes", "nose", "mouth"]
+     text = remove_duplicate_keywords(text, keywords)
+     key_parsing_mask_markers = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
+     mapping = {
+         "Face": "face",
+         "Left_Ear": "ears",
+         "Right_Ear": "ears",
+         "Left_Eye": "eyes",
+         "Right_Eye": "eyes",
+         "Nose": "nose",
+         "Upper_Lip": "mouth",
+         "Lower_Lip": "mouth",
+     }
+     facial_features_align = []
+     markers_align = []
+     for key in key_parsing_mask_markers:
+         if key in parsing_mask_list:
+             mapped_key = mapping.get(key, key.lower())
+             if mapped_key not in facial_features_align:
+                 facial_features_align.append(mapped_key)
+                 markers_align.append("<|"+mapped_key+"|>")
+
+     text_marked = text
+     align_parsing_mask_list = parsing_mask_list
+     for feature, marker in zip(facial_features_align[::-1], markers_align[::-1]):
+         pattern = rf'\b{feature}\b'
+         text_marked_new = re.sub(pattern, f'{feature} {marker}', text_marked, count=1)
+         if text_marked == text_marked_new:
+             # The keyword does not occur in the caption: drop its parsing masks.
+             for key, value in mapping.items():
+                 if value == feature:
+                     if key in align_parsing_mask_list:
+                         del align_parsing_mask_list[key]
+
+         text_marked = text_marked_new
+
+     text_marked = text_marked.replace('\n', '')
+
+     ordered_text = []
+     text_none_makers = []
+     facial_marked_count = 0
+     skip_count = 0
+     for marker in markers_align:
+         start_idx = text_marked.find(marker)
+         end_idx = start_idx + len(marker)
+
+         while start_idx > 0 and text_marked[start_idx - 1] not in [",", ".", ";"]:
+             start_idx -= 1
+
+         while end_idx < len(text_marked) and text_marked[end_idx] not in [",", ".", ";"]:
+             end_idx += 1
+
+         context = text_marked[start_idx:end_idx].strip()
+         if context == "":
+             text_none_makers.append(text_marked[:end_idx])
+         else:
+             if skip_count != 0:
+                 skip_count -= 1
+                 continue
+             else:
+                 ordered_text.append(context + ",")
+                 text_delete_makers = text_marked[:start_idx] + text_marked[end_idx:]
+                 text_marked = text_delete_makers
+                 facial_marked_count += 1
+
+     align_marked_text = " ".join(ordered_text)
+     replace_list = ["<|face|>", "<|ears|>", "<|nose|>", "<|eyes|>", "<|mouth|>"]
+     for item in replace_list:
+         align_marked_text = align_marked_text.replace(item, "<|facial|>")
+
+     return align_marked_text, align_parsing_mask_list
+
+ def tokenize_and_mask_noun_phrases_ends(text, image_token_id, facial_token_id, tokenizer):
+     # Remove the image/facial trigger tokens from the encoded prompt and return the
+     # cleaned ids together with boolean masks for the positions associated with each
+     # trigger, padded or truncated to tokenizer.model_max_length.
+     input_ids = tokenizer.encode(text)
+     image_noun_phrase_end_mask = [False for _ in input_ids]
+     facial_noun_phrase_end_mask = [False for _ in input_ids]
+     clean_input_ids = []
+     clean_index = 0
+     image_num = 0
+
+     for i, id in enumerate(input_ids):
+         if id == image_token_id:
+             image_noun_phrase_end_mask[clean_index + image_num - 1] = True
+             image_num += 1
+         elif id == facial_token_id:
+             facial_noun_phrase_end_mask[clean_index - 1] = True
+         else:
+             clean_input_ids.append(id)
+             clean_index += 1
+
+     max_len = tokenizer.model_max_length
+
+     if len(clean_input_ids) > max_len:
+         clean_input_ids = clean_input_ids[:max_len]
+     else:
+         clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
+             max_len - len(clean_input_ids)
+         )
+
+     if len(image_noun_phrase_end_mask) > max_len:
+         image_noun_phrase_end_mask = image_noun_phrase_end_mask[:max_len]
+     else:
+         image_noun_phrase_end_mask = image_noun_phrase_end_mask + [False] * (
+             max_len - len(image_noun_phrase_end_mask)
+         )
+
+     if len(facial_noun_phrase_end_mask) > max_len:
+         facial_noun_phrase_end_mask = facial_noun_phrase_end_mask[:max_len]
+     else:
+         facial_noun_phrase_end_mask = facial_noun_phrase_end_mask + [False] * (
+             max_len - len(facial_noun_phrase_end_mask)
+         )
+
+     clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long)
+     image_noun_phrase_end_mask = torch.tensor(image_noun_phrase_end_mask, dtype=torch.bool)
+     facial_noun_phrase_end_mask = torch.tensor(facial_noun_phrase_end_mask, dtype=torch.bool)
+
+     return clean_input_ids.unsqueeze(0), image_noun_phrase_end_mask.unsqueeze(0), facial_noun_phrase_end_mask.unsqueeze(0)
+
+ def prepare_image_token_idx(image_token_mask, facial_token_mask, max_num_objects=2, max_num_facials=5):
+     image_token_idx = torch.nonzero(image_token_mask, as_tuple=True)[1]
+     image_token_idx_mask = torch.ones_like(image_token_idx, dtype=torch.bool)
+     if len(image_token_idx) < max_num_objects:
+         image_token_idx = torch.cat(
+             [
+                 image_token_idx,
+                 torch.zeros(max_num_objects - len(image_token_idx), dtype=torch.long),
+             ]
+         )
+         image_token_idx_mask = torch.cat(
+             [
+                 image_token_idx_mask,
+                 torch.zeros(
+                     max_num_objects - len(image_token_idx_mask),
+                     dtype=torch.bool,
+                 ),
+             ]
+         )
+
+     facial_token_idx = torch.nonzero(facial_token_mask, as_tuple=True)[1]
+     facial_token_idx_mask = torch.ones_like(facial_token_idx, dtype=torch.bool)
+     if len(facial_token_idx) < max_num_facials:
+         facial_token_idx = torch.cat(
+             [
+                 facial_token_idx,
+                 torch.zeros(max_num_facials - len(facial_token_idx), dtype=torch.long),
+             ]
+         )
+         facial_token_idx_mask = torch.cat(
+             [
+                 facial_token_idx_mask,
+                 torch.zeros(
+                     max_num_facials - len(facial_token_idx_mask),
+                     dtype=torch.bool,
+                 ),
+             ]
+         )
+
+     image_token_idx = image_token_idx.unsqueeze(0)
+     image_token_idx_mask = image_token_idx_mask.unsqueeze(0)
+
+     facial_token_idx = facial_token_idx.unsqueeze(0)
+     facial_token_idx_mask = facial_token_idx_mask.unsqueeze(0)
+
+     return image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask
+
+ def get_object_localization_loss_for_one_layer(
+     cross_attention_scores,
+     object_segmaps,
+     object_token_idx,
+     object_token_idx_mask,
+     loss_fn,
+ ):
+     bxh, num_noise_latents, num_text_tokens = cross_attention_scores.shape
+     b, max_num_objects, _, _ = object_segmaps.shape
+     size = int(num_noise_latents**0.5)
+
+     object_segmaps = F.interpolate(object_segmaps, size=(size, size), mode="bilinear", antialias=True)
+
+     object_segmaps = object_segmaps.view(
+         b, max_num_objects, -1
+     )
+
+     num_heads = bxh // b
+     cross_attention_scores = cross_attention_scores.view(b, num_heads, num_noise_latents, num_text_tokens)
+
+     object_token_attn_prob = torch.gather(
+         cross_attention_scores,
+         dim=3,
+         index=object_token_idx.view(b, 1, 1, max_num_objects).expand(
+             b, num_heads, num_noise_latents, max_num_objects
+         ),
+     )
+     object_segmaps = (
+         object_segmaps.permute(0, 2, 1)
+         .unsqueeze(1)
+         .expand(b, num_heads, num_noise_latents, max_num_objects)
+     )
+     loss = loss_fn(object_token_attn_prob, object_segmaps)
+
+     loss = loss * object_token_idx_mask.view(b, 1, max_num_objects)
+     object_token_cnt = object_token_idx_mask.sum(dim=1).view(b, 1) + 1e-5
+     loss = (loss.sum(dim=2) / object_token_cnt).mean()
+
+     return loss
+
+
+ def get_object_localization_loss(
+     cross_attention_scores,
+     object_segmaps,
+     image_token_idx,
+     image_token_idx_mask,
+     loss_fn,
+ ):
+     num_layers = len(cross_attention_scores)
+     loss = 0
+     for k, v in cross_attention_scores.items():
+         layer_loss = get_object_localization_loss_for_one_layer(
+             v, object_segmaps, image_token_idx, image_token_idx_mask, loss_fn
+         )
+         loss += layer_loss
+     return loss / num_layers
+
+ def unet_store_cross_attention_scores(unet, attention_scores, layers=5):
+     from diffusers.models.attention_processor import Attention
+
+     UNET_LAYER_NAMES = [
+         "down_blocks.0",
+         "down_blocks.1",
+         "down_blocks.2",
+         "mid_block",
+         "up_blocks.1",
+         "up_blocks.2",
+         "up_blocks.3",
+     ]
+
+     start_layer = (len(UNET_LAYER_NAMES) - layers) // 2
+     end_layer = start_layer + layers
+     applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer]
+
+     def make_new_get_attention_scores_fn(name):
+         def new_get_attention_scores(module, query, key, attention_mask=None):
+             attention_probs = module.old_get_attention_scores(
+                 query, key, attention_mask
+             )
+             attention_scores[name] = attention_probs
+             return attention_probs
+
+         return new_get_attention_scores
+
+     for name, module in unet.named_modules():
+         if isinstance(module, Attention) and "attn1" in name:
+             if not any(layer in name for layer in applicable_layers):
+                 continue
+
+             module.old_get_attention_scores = module.get_attention_scores
+             module.get_attention_scores = types.MethodType(
+                 make_new_get_attention_scores_fn(name), module
+             )
+     return unet
+
+ class BalancedL1Loss(nn.Module):
+     def __init__(self, threshold=1.0, normalize=False):
+         super().__init__()
+         self.threshold = threshold
+         self.normalize = normalize
+
+     def forward(self, object_token_attn_prob, object_segmaps):
+         if self.normalize:
+             object_token_attn_prob = object_token_attn_prob / (
+                 object_token_attn_prob.max(dim=2, keepdim=True)[0] + 1e-5
+             )
+         background_segmaps = 1 - object_segmaps
+         background_segmaps_sum = background_segmaps.sum(dim=2) + 1e-5
+         object_segmaps_sum = object_segmaps.sum(dim=2) + 1e-5
+
+         background_loss = (object_token_attn_prob * background_segmaps).sum(
+             dim=2
+         ) / background_segmaps_sum
+
+         object_loss = (object_token_attn_prob * object_segmaps).sum(
+             dim=2
+         ) / object_segmaps_sum
+
+         return background_loss - object_loss
+
+ def fetch_mask_raw_image(raw_image, mask_image):
+
+     mask_image = mask_image.resize(raw_image.size)
+     mask_raw_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (0, 0, 0)), mask_image)
+
+     return mask_raw_image
+
+ mapping_table = [
+     {"Mask Value": 0, "Body Part": "Background", "RGB Color": [0, 0, 0]},
+     {"Mask Value": 1, "Body Part": "Face", "RGB Color": [255, 0, 0]},
+     {"Mask Value": 2, "Body Part": "Left_Eyebrow", "RGB Color": [255, 85, 0]},
+     {"Mask Value": 3, "Body Part": "Right_Eyebrow", "RGB Color": [255, 170, 0]},
+     {"Mask Value": 4, "Body Part": "Left_Eye", "RGB Color": [255, 0, 85]},
+     {"Mask Value": 5, "Body Part": "Right_Eye", "RGB Color": [255, 0, 170]},
+     {"Mask Value": 6, "Body Part": "Hair", "RGB Color": [0, 0, 255]},
+     {"Mask Value": 7, "Body Part": "Left_Ear", "RGB Color": [85, 0, 255]},
+     {"Mask Value": 8, "Body Part": "Right_Ear", "RGB Color": [170, 0, 255]},
+     {"Mask Value": 9, "Body Part": "Mouth_External Contour", "RGB Color": [0, 255, 85]},
+     {"Mask Value": 10, "Body Part": "Nose", "RGB Color": [0, 255, 0]},
+     {"Mask Value": 11, "Body Part": "Mouth_Inner_Contour", "RGB Color": [0, 255, 170]},
+     {"Mask Value": 12, "Body Part": "Upper_Lip", "RGB Color": [85, 255, 0]},
+     {"Mask Value": 13, "Body Part": "Lower_Lip", "RGB Color": [170, 255, 0]},
+     {"Mask Value": 14, "Body Part": "Neck", "RGB Color": [0, 85, 255]},
+     {"Mask Value": 15, "Body Part": "Neck_Inner Contour", "RGB Color": [0, 170, 255]},
+     {"Mask Value": 16, "Body Part": "Cloth", "RGB Color": [255, 255, 0]},
+     {"Mask Value": 17, "Body Part": "Hat", "RGB Color": [255, 0, 255]},
+     {"Mask Value": 18, "Body Part": "Earring", "RGB Color": [255, 85, 255]},
+     {"Mask Value": 19, "Body Part": "Necklace", "RGB Color": [255, 255, 85]},
+     {"Mask Value": 20, "Body Part": "Glasses", "RGB Color": [255, 170, 255]},
+     {"Mask Value": 21, "Body Part": "Hand", "RGB Color": [255, 0, 255]},
+     {"Mask Value": 22, "Body Part": "Wristband", "RGB Color": [0, 255, 255]},
+     {"Mask Value": 23, "Body Part": "Clothes_Upper", "RGB Color": [85, 255, 255]},
+     {"Mask Value": 24, "Body Part": "Clothes_Lower", "RGB Color": [170, 255, 255]}
+ ]
+
+ def masks_for_unique_values(image_raw_mask):
+
+     image_array = np.array(image_raw_mask)
+     unique_values, counts = np.unique(image_array, return_counts=True)
+     masks_dict = {}
+     for value in unique_values:
+         binary_image = np.uint8(image_array == value) * 255
+
+         contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+         mask = np.zeros_like(image_array)
+
+         for contour in contours:
+             cv2.drawContours(mask, [contour], -1, (255), thickness=cv2.FILLED)
+
+         if value == 0:
+             body_part = "WithoutBackground"
+             mask2 = np.where(mask == 255, 0, 255).astype(mask.dtype)
+             masks_dict[body_part] = Image.fromarray(mask2)
+
+         body_part = next((entry["Body Part"] for entry in mapping_table if entry["Mask Value"] == value), f"Unknown_{value}")
+         if body_part.startswith("Unknown_"):
+             continue
+
+         masks_dict[body_part] = Image.fromarray(mask)
+
+     return masks_dict
+
+ # FFN
+ def FeedForward(dim, mult=4):
+     inner_dim = int(dim * mult)
+     return nn.Sequential(
+         nn.LayerNorm(dim),
+         nn.Linear(dim, inner_dim, bias=False),
+         nn.GELU(),
+         nn.Linear(inner_dim, dim, bias=False),
+     )
+
+
+ def reshape_tensor(x, heads):
+     bs, length, width = x.shape
+     x = x.view(bs, length, heads, -1)
+     x = x.transpose(1, 2)
+     x = x.reshape(bs, heads, length, -1)
+     return x
+
+ class PerceiverAttention(nn.Module):
+     def __init__(self, *, dim, dim_head=64, heads=8):
+         super().__init__()
+         self.scale = dim_head**-0.5
+         self.dim_head = dim_head
+         self.heads = heads
+         inner_dim = dim_head * heads
+
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+
+         self.to_q = nn.Linear(dim, inner_dim, bias=False)
+         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+         self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+     def forward(self, x, latents):
+         """
+         Args:
+             x (torch.Tensor): image features
+                 shape (b, n1, D)
+             latents (torch.Tensor): latent features
+                 shape (b, n2, D)
+         """
+         x = self.norm1(x)
+         latents = self.norm2(latents)
+
+         b, l, _ = latents.shape
+
+         q = self.to_q(latents)
+         kv_input = torch.cat((x, latents), dim=-2)
+         k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+         q = reshape_tensor(q, self.heads)
+         k = reshape_tensor(k, self.heads)
+         v = reshape_tensor(v, self.heads)
+
+         # attention
+         scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+         weight = (q * scale) @ (k * scale).transpose(-2, -1)
+         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+         out = weight @ v
+
+         out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+         return self.to_out(out)
+
+ class FacePerceiverResampler(torch.nn.Module):
+     def __init__(
+         self,
+         *,
+         dim=768,
+         depth=4,
+         dim_head=64,
+         heads=16,
+         embedding_dim=1280,
+         output_dim=768,
+         ff_mult=4,
+     ):
+         super().__init__()
+
+         self.proj_in = torch.nn.Linear(embedding_dim, dim)
+         self.proj_out = torch.nn.Linear(dim, output_dim)
+         self.norm_out = torch.nn.LayerNorm(output_dim)
+         self.layers = torch.nn.ModuleList([])
+         for _ in range(depth):
+             self.layers.append(
+                 torch.nn.ModuleList(
+                     [
+                         PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                         FeedForward(dim=dim, mult=ff_mult),
+                     ]
+                 )
+             )
+
+     def forward(self, latents, x):
+         x = self.proj_in(x)
+         for attn, ff in self.layers:
+             latents = attn(x, latents) + latents
+             latents = ff(latents) + latents
+         latents = self.proj_out(latents)
+         return self.norm_out(latents)
+
+ class ProjPlusModel(torch.nn.Module):
+     def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
+         super().__init__()
+
+         self.cross_attention_dim = cross_attention_dim
+         self.num_tokens = num_tokens
+
+         self.proj = torch.nn.Sequential(
+             torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
+             torch.nn.GELU(),
+             torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
+         )
+         self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+         self.perceiver_resampler = FacePerceiverResampler(
+             dim=cross_attention_dim,
+             depth=4,
+             dim_head=64,
+             heads=cross_attention_dim // 64,
+             embedding_dim=clip_embeddings_dim,
+             output_dim=cross_attention_dim,
+             ff_mult=4,
+         )
+
+     def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
+
+         x = self.proj(id_embeds)
+         x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+         x = self.norm(x)
+         out = self.perceiver_resampler(x, clip_embeds)
+         if shortcut:
+             out = x + scale * out
+         return out
+
+ class AttentionMLP(nn.Module):
+     def __init__(
+         self,
+         dtype=torch.float16,
+         dim=1024,
+         depth=8,
+         dim_head=64,
+         heads=16,
+         single_num_tokens=1,
+         embedding_dim=1280,
+         output_dim=768,
+         ff_mult=4,
+         max_seq_len: int = 257 * 2,
+         apply_pos_emb: bool = False,
+         num_latents_mean_pooled: int = 0,
+     ):
+         super().__init__()
+         self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
+
+         self.single_num_tokens = single_num_tokens
+         self.latents = nn.Parameter(torch.randn(1, self.single_num_tokens, dim) / dim**0.5)
+
+         self.proj_in = nn.Linear(embedding_dim, dim)
+
+         self.proj_out = nn.Linear(dim, output_dim)
+         self.norm_out = nn.LayerNorm(output_dim)
+
+         self.to_latents_from_mean_pooled_seq = (
+             nn.Sequential(
+                 nn.LayerNorm(dim),
+                 nn.Linear(dim, dim * num_latents_mean_pooled),
+                 Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
+             )
+             if num_latents_mean_pooled > 0
+             else None
+         )
+
+         self.layers = nn.ModuleList([])
+         for _ in range(depth):
+             self.layers.append(
+                 nn.ModuleList(
+                     [
+                         PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                         FeedForward(dim=dim, mult=ff_mult),
+                     ]
+                 )
+             )
+
+     def forward(self, x):
+         if self.pos_emb is not None:
+             n, device = x.shape[1], x.device
+             pos_emb = self.pos_emb(torch.arange(n, device=device))
+             x = x + pos_emb
+
+         latents = self.latents.repeat(x.size(0), 1, 1)
+
+         x = self.proj_in(x)
+
+         if self.to_latents_from_mean_pooled_seq:
+             meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
+             meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
+             latents = torch.cat((meanpooled_latents, latents), dim=-2)
+
+         for attn, ff in self.layers:
+             latents = attn(x, latents) + latents
+             latents = ff(latents) + latents
+
+         latents = self.proj_out(latents)
+         return self.norm_out(latents)
+
+
+ def masked_mean(t, *, dim, mask=None):
+     if mask is None:
+         return t.mean(dim=dim)
+
+     denom = mask.sum(dim=dim, keepdim=True)
+     mask = rearrange(mask, "b n -> b n 1")
+     masked_t = t.masked_fill(~mask, 0.0)
+
+     return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
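
As a quick sanity check on the projection modules above, the following sketch runs AttentionMLP and ProjPlusModel on random inputs; the sizes (512-d ID embedding, 257 CLIP-H patch tokens of width 1280, two samples) are assumptions chosen to match the default constructor arguments:

import torch
from functions import AttentionMLP, ProjPlusModel

proj = AttentionMLP(embedding_dim=1280, output_dim=768)       # defaults: dim=1024, single_num_tokens=1
clip_tokens = torch.randn(2, 257, 1280)                       # assumed CLIP-H patch tokens
print(proj(clip_tokens).shape)                                # torch.Size([2, 1, 768])

plus = ProjPlusModel(cross_attention_dim=768, id_embeddings_dim=512,
                     clip_embeddings_dim=1280, num_tokens=4)
id_embeds = torch.randn(2, 512)                               # assumed face-recognition embedding
print(plus(id_embeds, clip_tokens, shortcut=True, scale=1.0).shape)  # torch.Size([2, 4, 768])

AttentionMLP distills each stack of patch tokens into a single output token per image, while ProjPlusModel expands one ID vector into num_tokens cross-attention tokens refined against the CLIP features by FacePerceiverResampler.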