zejunyang commited on
Commit
2e4e201
·
1 Parent(s): 558ddd8
src/audio2vid.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import ffmpeg
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ import numpy as np
6
+ import cv2
7
+ import torch
8
+ # import spaces
9
+ from scipy.spatial.transform import Rotation as R
10
+ from scipy.interpolate import interp1d
11
+
12
+ from diffusers import AutoencoderKL, DDIMScheduler
13
+ from einops import repeat
14
+ from omegaconf import OmegaConf
15
+ from PIL import Image
16
+ from torchvision import transforms
17
+ from transformers import CLIPVisionModelWithProjection
18
+
19
+
20
+ from src.models.pose_guider import PoseGuider
21
+ from src.models.unet_2d_condition import UNet2DConditionModel
22
+ from src.models.unet_3d import UNet3DConditionModel
23
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
24
+ from src.utils.util import save_videos_grid
25
+
26
+ from src.audio_models.model import Audio2MeshModel
27
+ from src.utils.audio_util import prepare_audio_feature
28
+ from src.utils.mp_utils import LMKExtractor
29
+ from src.utils.draw_util import FaceMeshVisualizer
30
+ from src.utils.pose_util import project_points
31
+
32
+
33
+ def matrix_to_euler_and_translation(matrix):
34
+ rotation_matrix = matrix[:3, :3]
35
+ translation_vector = matrix[:3, 3]
36
+ rotation = R.from_matrix(rotation_matrix)
37
+ euler_angles = rotation.as_euler('xyz', degrees=True)
38
+ return euler_angles, translation_vector
39
+
40
+
41
+ def smooth_pose_seq(pose_seq, window_size=5):
42
+ smoothed_pose_seq = np.zeros_like(pose_seq)
43
+
44
+ for i in range(len(pose_seq)):
45
+ start = max(0, i - window_size // 2)
46
+ end = min(len(pose_seq), i + window_size // 2 + 1)
47
+ smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0)
48
+
49
+ return smoothed_pose_seq
50
+
51
+ def get_headpose_temp(input_video):
52
+ lmk_extractor = LMKExtractor()
53
+ cap = cv2.VideoCapture(input_video)
54
+
55
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
56
+ fps = cap.get(cv2.CAP_PROP_FPS)
57
+
58
+ trans_mat_list = []
59
+ while cap.isOpened():
60
+ ret, frame = cap.read()
61
+ if not ret:
62
+ break
63
+
64
+ result = lmk_extractor(frame)
65
+ trans_mat_list.append(result['trans_mat'].astype(np.float32))
66
+ cap.release()
67
+
68
+ trans_mat_arr = np.array(trans_mat_list)
69
+
70
+ # compute delta pose
71
+ trans_mat_inv_frame_0 = np.linalg.inv(trans_mat_arr[0])
72
+ pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
73
+
74
+ for i in range(pose_arr.shape[0]):
75
+ pose_mat = trans_mat_inv_frame_0 @ trans_mat_arr[i]
76
+ euler_angles, translation_vector = matrix_to_euler_and_translation(pose_mat)
77
+ pose_arr[i, :3] = euler_angles
78
+ pose_arr[i, 3:6] = translation_vector
79
+
80
+ # interpolate to 30 fps
81
+ new_fps = 30
82
+ old_time = np.linspace(0, total_frames / fps, total_frames)
83
+ new_time = np.linspace(0, total_frames / fps, int(total_frames * new_fps / fps))
84
+
85
+ pose_arr_interp = np.zeros((len(new_time), 6))
86
+ for i in range(6):
87
+ interp_func = interp1d(old_time, pose_arr[:, i])
88
+ pose_arr_interp[:, i] = interp_func(new_time)
89
+
90
+ pose_arr_smooth = smooth_pose_seq(pose_arr_interp)
91
+
92
+ return pose_arr_smooth
93
+
94
+ # @spaces.GPU
95
+ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
96
+ fps = 30
97
+ cfg = 3.5
98
+
99
+ config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
100
+
101
+ if config.weight_dtype == "fp16":
102
+ weight_dtype = torch.float16
103
+ else:
104
+ weight_dtype = torch.float32
105
+
106
+ audio_infer_config = OmegaConf.load(config.audio_inference_config)
107
+ # prepare model
108
+ a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
109
+ a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
110
+ a2m_model.cuda().eval()
111
+
112
+ vae = AutoencoderKL.from_pretrained(
113
+ config.pretrained_vae_path,
114
+ ).to("cuda", dtype=weight_dtype)
115
+
116
+ reference_unet = UNet2DConditionModel.from_pretrained(
117
+ config.pretrained_base_model_path,
118
+ subfolder="unet",
119
+ ).to(dtype=weight_dtype, device="cuda")
120
+
121
+ inference_config_path = config.inference_config
122
+ infer_config = OmegaConf.load(inference_config_path)
123
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
124
+ config.pretrained_base_model_path,
125
+ config.motion_module_path,
126
+ subfolder="unet",
127
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
128
+ ).to(dtype=weight_dtype, device="cuda")
129
+
130
+
131
+ pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
132
+
133
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
134
+ config.image_encoder_path
135
+ ).to(dtype=weight_dtype, device="cuda")
136
+
137
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
138
+ scheduler = DDIMScheduler(**sched_kwargs)
139
+
140
+ generator = torch.manual_seed(seed)
141
+
142
+ width, height = size, size
143
+
144
+ # load pretrained weights
145
+ denoising_unet.load_state_dict(
146
+ torch.load(config.denoising_unet_path, map_location="cpu"),
147
+ strict=False,
148
+ )
149
+ reference_unet.load_state_dict(
150
+ torch.load(config.reference_unet_path, map_location="cpu"),
151
+ )
152
+ pose_guider.load_state_dict(
153
+ torch.load(config.pose_guider_path, map_location="cpu"),
154
+ )
155
+
156
+ pipe = Pose2VideoPipeline(
157
+ vae=vae,
158
+ image_encoder=image_enc,
159
+ reference_unet=reference_unet,
160
+ denoising_unet=denoising_unet,
161
+ pose_guider=pose_guider,
162
+ scheduler=scheduler,
163
+ )
164
+ pipe = pipe.to("cuda", dtype=weight_dtype)
165
+
166
+ date_str = datetime.now().strftime("%Y%m%d")
167
+ time_str = datetime.now().strftime("%H%M")
168
+ save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
169
+
170
+ save_dir = Path(f"output/{date_str}/{save_dir_name}")
171
+ save_dir.mkdir(exist_ok=True, parents=True)
172
+
173
+ lmk_extractor = LMKExtractor()
174
+ vis = FaceMeshVisualizer(forehead_edge=False)
175
+
176
+ ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
177
+ # TODO: 人脸检测+裁剪
178
+ ref_image_np = cv2.resize(ref_image_np, (size, size))
179
+ ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
180
+
181
+ face_result = lmk_extractor(ref_image_np)
182
+ if face_result is None:
183
+ return None
184
+
185
+ lmks = face_result['lmks'].astype(np.float32)
186
+ ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
187
+
188
+ sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
189
+ sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
190
+ sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
191
+
192
+ # inference
193
+ pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
194
+ pred = pred.squeeze().detach().cpu().numpy()
195
+ pred = pred.reshape(pred.shape[0], -1, 3)
196
+ pred = pred + face_result['lmks3d']
197
+
198
+ if headpose_video is not None:
199
+ pose_seq = get_headpose_temp(headpose_video)
200
+ else:
201
+ pose_seq = np.load(config['pose_temp'])
202
+ mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
203
+ cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
204
+
205
+ # project 3D mesh to 2D landmark
206
+ projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
207
+
208
+ pose_images = []
209
+ for i, verts in enumerate(projected_vertices):
210
+ lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
211
+ pose_images.append(lmk_img)
212
+
213
+ pose_list = []
214
+ pose_tensor_list = []
215
+
216
+ pose_transform = transforms.Compose(
217
+ [transforms.Resize((height, width)), transforms.ToTensor()]
218
+ )
219
+ args_L = len(pose_images) if length==0 or length > len(pose_images) else length
220
+ for pose_image_np in pose_images[: args_L]:
221
+ pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
222
+ pose_tensor_list.append(pose_transform(pose_image_pil))
223
+ pose_image_np = cv2.resize(pose_image_np, (width, height))
224
+ pose_list.append(pose_image_np)
225
+
226
+ pose_list = np.array(pose_list)
227
+
228
+ video_length = len(pose_tensor_list)
229
+
230
+ video = pipe(
231
+ ref_image_pil,
232
+ pose_list,
233
+ ref_pose,
234
+ width,
235
+ height,
236
+ video_length,
237
+ steps,
238
+ cfg,
239
+ generator=generator,
240
+ ).videos
241
+
242
+ save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
243
+ save_videos_grid(
244
+ video,
245
+ save_path,
246
+ n_rows=1,
247
+ fps=fps,
248
+ )
249
+
250
+ stream = ffmpeg.input(save_path)
251
+ audio = ffmpeg.input(input_audio)
252
+ ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
253
+ os.remove(save_path)
254
+
255
+ return save_path.replace('_noaudio.mp4', '.mp4')
src/audio_models/mish.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Applies the mish function element-wise:
3
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
4
+ """
5
+
6
+ # import pytorch
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ @torch.jit.script
12
+ def mish(input):
13
+ """
14
+ Applies the mish function element-wise:
15
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
16
+ See additional documentation for mish class.
17
+ """
18
+ return input * torch.tanh(F.softplus(input))
19
+
20
+ class Mish(nn.Module):
21
+ """
22
+ Applies the mish function element-wise:
23
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
24
+
25
+ Shape:
26
+ - Input: (N, *) where * means, any number of additional
27
+ dimensions
28
+ - Output: (N, *), same shape as the input
29
+
30
+ Examples:
31
+ >>> m = Mish()
32
+ >>> input = torch.randn(2)
33
+ >>> output = m(input)
34
+
35
+ Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
36
+ """
37
+
38
+ def __init__(self):
39
+ """
40
+ Init method.
41
+ """
42
+ super().__init__()
43
+
44
+ def forward(self, input):
45
+ """
46
+ Forward pass of the function.
47
+ """
48
+ if torch.__version__ >= "1.9":
49
+ return F.mish(input)
50
+ else:
51
+ return mish(input)
src/audio_models/model.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from transformers import Wav2Vec2Config
6
+
7
+ from .torch_utils import get_mask_from_lengths
8
+ from .wav2vec2 import Wav2Vec2Model
9
+
10
+
11
+ class Audio2MeshModel(nn.Module):
12
+ def __init__(
13
+ self,
14
+ config
15
+ ):
16
+ super().__init__()
17
+ out_dim = config['out_dim']
18
+ latent_dim = config['latent_dim']
19
+ model_path = config['model_path']
20
+ only_last_fetures = config['only_last_fetures']
21
+ from_pretrained = config['from_pretrained']
22
+
23
+ self._only_last_features = only_last_fetures
24
+
25
+ self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True)
26
+ if from_pretrained:
27
+ self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True)
28
+ else:
29
+ self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config)
30
+ self.audio_encoder.feature_extractor._freeze_parameters()
31
+
32
+ hidden_size = self.audio_encoder_config.hidden_size
33
+
34
+ self.in_fn = nn.Linear(hidden_size, latent_dim)
35
+
36
+ self.out_fn = nn.Linear(latent_dim, out_dim)
37
+ nn.init.constant_(self.out_fn.weight, 0)
38
+ nn.init.constant_(self.out_fn.bias, 0)
39
+
40
+ def forward(self, audio, label, audio_len=None):
41
+ attention_mask = ~get_mask_from_lengths(audio_len) if audio_len else None
42
+
43
+ seq_len = label.shape[1]
44
+
45
+ embeddings = self.audio_encoder(audio, seq_len=seq_len, output_hidden_states=True,
46
+ attention_mask=attention_mask)
47
+
48
+ if self._only_last_features:
49
+ hidden_states = embeddings.last_hidden_state
50
+ else:
51
+ hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)
52
+
53
+ layer_in = self.in_fn(hidden_states)
54
+ out = self.out_fn(layer_in)
55
+
56
+ return out, None
57
+
58
+ def infer(self, input_value, seq_len):
59
+ embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True)
60
+
61
+ if self._only_last_features:
62
+ hidden_states = embeddings.last_hidden_state
63
+ else:
64
+ hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)
65
+
66
+ layer_in = self.in_fn(hidden_states)
67
+ out = self.out_fn(layer_in)
68
+
69
+ return out
70
+
71
+
src/audio_models/torch_utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def get_mask_from_lengths(lengths, max_len=None):
6
+ lengths = lengths.to(torch.long)
7
+ if max_len is None:
8
+ max_len = torch.max(lengths).item()
9
+
10
+ ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
11
+ mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
12
+
13
+ return mask
14
+
15
+
16
+ def linear_interpolation(features, seq_len):
17
+ features = features.transpose(1, 2)
18
+ output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
19
+ return output_features.transpose(1, 2)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ import numpy as np
24
+ mask = ~get_mask_from_lengths(torch.from_numpy(np.array([4,6])))
25
+ import pdb; pdb.set_trace()
src/audio_models/wav2vec2.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Wav2Vec2Config, Wav2Vec2Model
2
+ from transformers.modeling_outputs import BaseModelOutput
3
+
4
+ from .torch_utils import linear_interpolation
5
+
6
+ # the implementation of Wav2Vec2Model is borrowed from
7
+ # https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
8
+ # initialize our encoder with the pre-trained wav2vec 2.0 weights.
9
+ class Wav2Vec2Model(Wav2Vec2Model):
10
+ def __init__(self, config: Wav2Vec2Config):
11
+ super().__init__(config)
12
+
13
+ def forward(
14
+ self,
15
+ input_values,
16
+ seq_len,
17
+ attention_mask=None,
18
+ mask_time_indices=None,
19
+ output_attentions=None,
20
+ output_hidden_states=None,
21
+ return_dict=None,
22
+ ):
23
+ self.config.output_attentions = True
24
+
25
+ output_hidden_states = (
26
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
27
+ )
28
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
29
+
30
+ extract_features = self.feature_extractor(input_values)
31
+ extract_features = extract_features.transpose(1, 2)
32
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
33
+
34
+ if attention_mask is not None:
35
+ # compute reduced attention_mask corresponding to feature vectors
36
+ attention_mask = self._get_feature_vector_attention_mask(
37
+ extract_features.shape[1], attention_mask, add_adapter=False
38
+ )
39
+
40
+ hidden_states, extract_features = self.feature_projection(extract_features)
41
+ hidden_states = self._mask_hidden_states(
42
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
43
+ )
44
+
45
+ encoder_outputs = self.encoder(
46
+ hidden_states,
47
+ attention_mask=attention_mask,
48
+ output_attentions=output_attentions,
49
+ output_hidden_states=output_hidden_states,
50
+ return_dict=return_dict,
51
+ )
52
+
53
+ hidden_states = encoder_outputs[0]
54
+
55
+ if self.adapter is not None:
56
+ hidden_states = self.adapter(hidden_states)
57
+
58
+ if not return_dict:
59
+ return (hidden_states, ) + encoder_outputs[1:]
60
+ return BaseModelOutput(
61
+ last_hidden_state=hidden_states,
62
+ hidden_states=encoder_outputs.hidden_states,
63
+ attentions=encoder_outputs.attentions,
64
+ )
65
+
66
+
67
+ def feature_extract(
68
+ self,
69
+ input_values,
70
+ seq_len,
71
+ ):
72
+ extract_features = self.feature_extractor(input_values)
73
+ extract_features = extract_features.transpose(1, 2)
74
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
75
+
76
+ return extract_features
77
+
78
+ def encode(
79
+ self,
80
+ extract_features,
81
+ attention_mask=None,
82
+ mask_time_indices=None,
83
+ output_attentions=None,
84
+ output_hidden_states=None,
85
+ return_dict=None,
86
+ ):
87
+ self.config.output_attentions = True
88
+
89
+ output_hidden_states = (
90
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
91
+ )
92
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
93
+
94
+ if attention_mask is not None:
95
+ # compute reduced attention_mask corresponding to feature vectors
96
+ attention_mask = self._get_feature_vector_attention_mask(
97
+ extract_features.shape[1], attention_mask, add_adapter=False
98
+ )
99
+
100
+
101
+ hidden_states, extract_features = self.feature_projection(extract_features)
102
+ hidden_states = self._mask_hidden_states(
103
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
104
+ )
105
+
106
+ encoder_outputs = self.encoder(
107
+ hidden_states,
108
+ attention_mask=attention_mask,
109
+ output_attentions=output_attentions,
110
+ output_hidden_states=output_hidden_states,
111
+ return_dict=return_dict,
112
+ )
113
+
114
+ hidden_states = encoder_outputs[0]
115
+
116
+ if self.adapter is not None:
117
+ hidden_states = self.adapter(hidden_states)
118
+
119
+ if not return_dict:
120
+ return (hidden_states, ) + encoder_outputs[1:]
121
+ return BaseModelOutput(
122
+ last_hidden_state=hidden_states,
123
+ hidden_states=encoder_outputs.hidden_states,
124
+ attentions=encoder_outputs.attentions,
125
+ )
src/models/attention.py ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward
7
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+ from diffusers.models.attention import *
12
+ from diffusers.models.attention_processor import *
13
+
14
+ class BasicTransformerBlock(nn.Module):
15
+ r"""
16
+ A basic Transformer block.
17
+
18
+ Parameters:
19
+ dim (`int`): The number of channels in the input and output.
20
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
21
+ attention_head_dim (`int`): The number of channels in each head.
22
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
23
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
24
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
25
+ num_embeds_ada_norm (:
26
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
27
+ attention_bias (:
28
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
29
+ only_cross_attention (`bool`, *optional*):
30
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
31
+ double_self_attention (`bool`, *optional*):
32
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
33
+ upcast_attention (`bool`, *optional*):
34
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
35
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
36
+ Whether to use learnable elementwise affine parameters for normalization.
37
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
38
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
39
+ final_dropout (`bool` *optional*, defaults to False):
40
+ Whether to apply a final dropout after the last feed-forward layer.
41
+ attention_type (`str`, *optional*, defaults to `"default"`):
42
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
43
+ positional_embeddings (`str`, *optional*, defaults to `None`):
44
+ The type of positional embeddings to apply to.
45
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
46
+ The maximum number of positional embeddings to apply.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ dim: int,
52
+ num_attention_heads: int,
53
+ attention_head_dim: int,
54
+ dropout=0.0,
55
+ cross_attention_dim: Optional[int] = None,
56
+ activation_fn: str = "geglu",
57
+ num_embeds_ada_norm: Optional[int] = None,
58
+ attention_bias: bool = False,
59
+ only_cross_attention: bool = False,
60
+ double_self_attention: bool = False,
61
+ upcast_attention: bool = False,
62
+ norm_elementwise_affine: bool = True,
63
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
64
+ norm_eps: float = 1e-5,
65
+ final_dropout: bool = False,
66
+ attention_type: str = "default",
67
+ positional_embeddings: Optional[str] = None,
68
+ num_positional_embeddings: Optional[int] = None,
69
+ ):
70
+ super().__init__()
71
+ self.only_cross_attention = only_cross_attention
72
+
73
+ self.use_ada_layer_norm_zero = (
74
+ num_embeds_ada_norm is not None
75
+ ) and norm_type == "ada_norm_zero"
76
+ self.use_ada_layer_norm = (
77
+ num_embeds_ada_norm is not None
78
+ ) and norm_type == "ada_norm"
79
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
80
+ self.use_layer_norm = norm_type == "layer_norm"
81
+
82
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
83
+ raise ValueError(
84
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
85
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
86
+ )
87
+
88
+ if positional_embeddings and (num_positional_embeddings is None):
89
+ raise ValueError(
90
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
91
+ )
92
+
93
+ if positional_embeddings == "sinusoidal":
94
+ self.pos_embed = SinusoidalPositionalEmbedding(
95
+ dim, max_seq_length=num_positional_embeddings
96
+ )
97
+ else:
98
+ self.pos_embed = None
99
+
100
+ # Define 3 blocks. Each block has its own normalization layer.
101
+ # 1. Self-Attn
102
+ if self.use_ada_layer_norm:
103
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
104
+ elif self.use_ada_layer_norm_zero:
105
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
106
+ else:
107
+ self.norm1 = nn.LayerNorm(
108
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
109
+ )
110
+
111
+ self.attn1 = Attention(
112
+ query_dim=dim,
113
+ heads=num_attention_heads,
114
+ dim_head=attention_head_dim,
115
+ dropout=dropout,
116
+ bias=attention_bias,
117
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
118
+ upcast_attention=upcast_attention,
119
+ )
120
+
121
+ # 2. Cross-Attn
122
+ if cross_attention_dim is not None or double_self_attention:
123
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
124
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
125
+ # the second cross attention block.
126
+ self.norm2 = (
127
+ AdaLayerNorm(dim, num_embeds_ada_norm)
128
+ if self.use_ada_layer_norm
129
+ else nn.LayerNorm(
130
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
131
+ )
132
+ )
133
+ self.attn2 = Attention(
134
+ query_dim=dim,
135
+ cross_attention_dim=cross_attention_dim
136
+ if not double_self_attention
137
+ else None,
138
+ heads=num_attention_heads,
139
+ dim_head=attention_head_dim,
140
+ dropout=dropout,
141
+ bias=attention_bias,
142
+ upcast_attention=upcast_attention,
143
+ ) # is self-attn if encoder_hidden_states is none
144
+ else:
145
+ self.norm2 = None
146
+ self.attn2 = None
147
+
148
+ # 3. Feed-forward
149
+ if not self.use_ada_layer_norm_single:
150
+ self.norm3 = nn.LayerNorm(
151
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
152
+ )
153
+
154
+ self.ff = FeedForward(
155
+ dim,
156
+ dropout=dropout,
157
+ activation_fn=activation_fn,
158
+ final_dropout=final_dropout,
159
+ )
160
+
161
+ # 4. Fuser
162
+ if attention_type == "gated" or attention_type == "gated-text-image":
163
+ self.fuser = GatedSelfAttentionDense(
164
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
165
+ )
166
+
167
+ # 5. Scale-shift for PixArt-Alpha.
168
+ if self.use_ada_layer_norm_single:
169
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
170
+
171
+ # let chunk size default to None
172
+ self._chunk_size = None
173
+ self._chunk_dim = 0
174
+
175
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
176
+ # Sets chunk feed-forward
177
+ self._chunk_size = chunk_size
178
+ self._chunk_dim = dim
179
+
180
+ def forward(
181
+ self,
182
+ hidden_states: torch.FloatTensor,
183
+ attention_mask: Optional[torch.FloatTensor] = None,
184
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
185
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
186
+ timestep: Optional[torch.LongTensor] = None,
187
+ cross_attention_kwargs: Dict[str, Any] = None,
188
+ class_labels: Optional[torch.LongTensor] = None,
189
+ ) -> torch.FloatTensor:
190
+ # Notice that normalization is always applied before the real computation in the following blocks.
191
+ # 0. Self-Attention
192
+ batch_size = hidden_states.shape[0]
193
+
194
+ if self.use_ada_layer_norm:
195
+ norm_hidden_states = self.norm1(hidden_states, timestep)
196
+ elif self.use_ada_layer_norm_zero:
197
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
198
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
199
+ )
200
+ elif self.use_layer_norm:
201
+ norm_hidden_states = self.norm1(hidden_states)
202
+ elif self.use_ada_layer_norm_single:
203
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
204
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
205
+ ).chunk(6, dim=1)
206
+ norm_hidden_states = self.norm1(hidden_states)
207
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
208
+ norm_hidden_states = norm_hidden_states.squeeze(1)
209
+ else:
210
+ raise ValueError("Incorrect norm used")
211
+
212
+ if self.pos_embed is not None:
213
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
214
+
215
+ # 1. Retrieve lora scale.
216
+ lora_scale = (
217
+ cross_attention_kwargs.get("scale", 1.0)
218
+ if cross_attention_kwargs is not None
219
+ else 1.0
220
+ )
221
+
222
+ # 2. Prepare GLIGEN inputs
223
+ cross_attention_kwargs = (
224
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
225
+ )
226
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
227
+
228
+ attn_output = self.attn1(
229
+ norm_hidden_states,
230
+ encoder_hidden_states=encoder_hidden_states
231
+ if self.only_cross_attention
232
+ else None,
233
+ attention_mask=attention_mask,
234
+ **cross_attention_kwargs,
235
+ )
236
+ if self.use_ada_layer_norm_zero:
237
+ attn_output = gate_msa.unsqueeze(1) * attn_output
238
+ elif self.use_ada_layer_norm_single:
239
+ attn_output = gate_msa * attn_output
240
+
241
+ hidden_states = attn_output + hidden_states
242
+ if hidden_states.ndim == 4:
243
+ hidden_states = hidden_states.squeeze(1)
244
+
245
+ # 2.5 GLIGEN Control
246
+ if gligen_kwargs is not None:
247
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
248
+
249
+ # 3. Cross-Attention
250
+ if self.attn2 is not None:
251
+ if self.use_ada_layer_norm:
252
+ norm_hidden_states = self.norm2(hidden_states, timestep)
253
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
254
+ norm_hidden_states = self.norm2(hidden_states)
255
+ elif self.use_ada_layer_norm_single:
256
+ # For PixArt norm2 isn't applied here:
257
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
258
+ norm_hidden_states = hidden_states
259
+ else:
260
+ raise ValueError("Incorrect norm")
261
+
262
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
263
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
264
+
265
+ attn_output = self.attn2(
266
+ norm_hidden_states,
267
+ encoder_hidden_states=encoder_hidden_states,
268
+ attention_mask=encoder_attention_mask,
269
+ **cross_attention_kwargs,
270
+ )
271
+ hidden_states = attn_output + hidden_states
272
+
273
+ # 4. Feed-forward
274
+ if not self.use_ada_layer_norm_single:
275
+ norm_hidden_states = self.norm3(hidden_states)
276
+
277
+ if self.use_ada_layer_norm_zero:
278
+ norm_hidden_states = (
279
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
280
+ )
281
+
282
+ if self.use_ada_layer_norm_single:
283
+ norm_hidden_states = self.norm2(hidden_states)
284
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
285
+
286
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
287
+
288
+ if self.use_ada_layer_norm_zero:
289
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
290
+ elif self.use_ada_layer_norm_single:
291
+ ff_output = gate_mlp * ff_output
292
+
293
+ hidden_states = ff_output + hidden_states
294
+ if hidden_states.ndim == 4:
295
+ hidden_states = hidden_states.squeeze(1)
296
+
297
+ return hidden_states
298
+
299
+
300
+ class TemporalBasicTransformerBlock(nn.Module):
301
+ def __init__(
302
+ self,
303
+ dim: int,
304
+ num_attention_heads: int,
305
+ attention_head_dim: int,
306
+ dropout=0.0,
307
+ cross_attention_dim: Optional[int] = None,
308
+ activation_fn: str = "geglu",
309
+ num_embeds_ada_norm: Optional[int] = None,
310
+ attention_bias: bool = False,
311
+ only_cross_attention: bool = False,
312
+ upcast_attention: bool = False,
313
+ unet_use_cross_frame_attention=None,
314
+ unet_use_temporal_attention=None,
315
+ ):
316
+ super().__init__()
317
+ self.only_cross_attention = only_cross_attention
318
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
319
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
320
+ self.unet_use_temporal_attention = unet_use_temporal_attention
321
+
322
+ # SC-Attn
323
+ self.attn1 = Attention(
324
+ query_dim=dim,
325
+ heads=num_attention_heads,
326
+ dim_head=attention_head_dim,
327
+ dropout=dropout,
328
+ bias=attention_bias,
329
+ upcast_attention=upcast_attention,
330
+ )
331
+ self.norm1 = (
332
+ AdaLayerNorm(dim, num_embeds_ada_norm)
333
+ if self.use_ada_layer_norm
334
+ else nn.LayerNorm(dim)
335
+ )
336
+
337
+ # Cross-Attn
338
+ if cross_attention_dim is not None:
339
+ self.attn2 = Attention(
340
+ query_dim=dim,
341
+ cross_attention_dim=cross_attention_dim,
342
+ heads=num_attention_heads,
343
+ dim_head=attention_head_dim,
344
+ dropout=dropout,
345
+ bias=attention_bias,
346
+ upcast_attention=upcast_attention,
347
+ )
348
+ else:
349
+ self.attn2 = None
350
+
351
+ if cross_attention_dim is not None:
352
+ self.norm2 = (
353
+ AdaLayerNorm(dim, num_embeds_ada_norm)
354
+ if self.use_ada_layer_norm
355
+ else nn.LayerNorm(dim)
356
+ )
357
+ else:
358
+ self.norm2 = None
359
+
360
+ # Feed-forward
361
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
362
+ self.norm3 = nn.LayerNorm(dim)
363
+ self.use_ada_layer_norm_zero = False
364
+
365
+ # Temp-Attn
366
+ assert unet_use_temporal_attention is not None
367
+ if unet_use_temporal_attention:
368
+ self.attn_temp = Attention(
369
+ query_dim=dim,
370
+ heads=num_attention_heads,
371
+ dim_head=attention_head_dim,
372
+ dropout=dropout,
373
+ bias=attention_bias,
374
+ upcast_attention=upcast_attention,
375
+ )
376
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
377
+ self.norm_temp = (
378
+ AdaLayerNorm(dim, num_embeds_ada_norm)
379
+ if self.use_ada_layer_norm
380
+ else nn.LayerNorm(dim)
381
+ )
382
+
383
+ def forward(
384
+ self,
385
+ hidden_states,
386
+ encoder_hidden_states=None,
387
+ timestep=None,
388
+ attention_mask=None,
389
+ video_length=None,
390
+ ):
391
+ norm_hidden_states = (
392
+ self.norm1(hidden_states, timestep)
393
+ if self.use_ada_layer_norm
394
+ else self.norm1(hidden_states)
395
+ )
396
+
397
+ if self.unet_use_cross_frame_attention:
398
+ hidden_states = (
399
+ self.attn1(
400
+ norm_hidden_states,
401
+ attention_mask=attention_mask,
402
+ video_length=video_length,
403
+ )
404
+ + hidden_states
405
+ )
406
+ else:
407
+ hidden_states = (
408
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
409
+ + hidden_states
410
+ )
411
+
412
+ if self.attn2 is not None:
413
+ # Cross-Attention
414
+ norm_hidden_states = (
415
+ self.norm2(hidden_states, timestep)
416
+ if self.use_ada_layer_norm
417
+ else self.norm2(hidden_states)
418
+ )
419
+ hidden_states = (
420
+ self.attn2(
421
+ norm_hidden_states,
422
+ encoder_hidden_states=encoder_hidden_states,
423
+ attention_mask=attention_mask,
424
+ )
425
+ + hidden_states
426
+ )
427
+
428
+ # Feed-forward
429
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
430
+
431
+ # Temporal-Attention
432
+ if self.unet_use_temporal_attention:
433
+ d = hidden_states.shape[1]
434
+ hidden_states = rearrange(
435
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
436
+ )
437
+ norm_hidden_states = (
438
+ self.norm_temp(hidden_states, timestep)
439
+ if self.use_ada_layer_norm
440
+ else self.norm_temp(hidden_states)
441
+ )
442
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
443
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
444
+
445
+ return hidden_states
446
+
447
+
448
+ class ResidualTemporalBasicTransformerBlock(TemporalBasicTransformerBlock):
449
+ def __init__(
450
+ self,
451
+ dim: int,
452
+ num_attention_heads: int,
453
+ attention_head_dim: int,
454
+ dropout=0.0,
455
+ cross_attention_dim: Optional[int] = None,
456
+ activation_fn: str = "geglu",
457
+ num_embeds_ada_norm: Optional[int] = None,
458
+ attention_bias: bool = False,
459
+ only_cross_attention: bool = False,
460
+ upcast_attention: bool = False,
461
+ unet_use_cross_frame_attention=None,
462
+ unet_use_temporal_attention=None,
463
+ ):
464
+ super(TemporalBasicTransformerBlock, self).__init__()
465
+ self.only_cross_attention = only_cross_attention
466
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
467
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
468
+ self.unet_use_temporal_attention = unet_use_temporal_attention
469
+
470
+ # SC-Attn
471
+ self.attn1 = ResidualAttention(
472
+ query_dim=dim,
473
+ heads=num_attention_heads,
474
+ dim_head=attention_head_dim,
475
+ dropout=dropout,
476
+ bias=attention_bias,
477
+ upcast_attention=upcast_attention,
478
+ )
479
+ self.norm1 = (
480
+ AdaLayerNorm(dim, num_embeds_ada_norm)
481
+ if self.use_ada_layer_norm
482
+ else nn.LayerNorm(dim)
483
+ )
484
+
485
+ # Cross-Attn
486
+ if cross_attention_dim is not None:
487
+ self.attn2 = ResidualAttention(
488
+ query_dim=dim,
489
+ cross_attention_dim=cross_attention_dim,
490
+ heads=num_attention_heads,
491
+ dim_head=attention_head_dim,
492
+ dropout=dropout,
493
+ bias=attention_bias,
494
+ upcast_attention=upcast_attention,
495
+ )
496
+ else:
497
+ self.attn2 = None
498
+
499
+ if cross_attention_dim is not None:
500
+ self.norm2 = (
501
+ AdaLayerNorm(dim, num_embeds_ada_norm)
502
+ if self.use_ada_layer_norm
503
+ else nn.LayerNorm(dim)
504
+ )
505
+ else:
506
+ self.norm2 = None
507
+
508
+ # Feed-forward
509
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
510
+ self.norm3 = nn.LayerNorm(dim)
511
+ self.use_ada_layer_norm_zero = False
512
+
513
+ # Temp-Attn
514
+ assert unet_use_temporal_attention is not None
515
+ if unet_use_temporal_attention:
516
+ self.attn_temp = Attention(
517
+ query_dim=dim,
518
+ heads=num_attention_heads,
519
+ dim_head=attention_head_dim,
520
+ dropout=dropout,
521
+ bias=attention_bias,
522
+ upcast_attention=upcast_attention,
523
+ )
524
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
525
+ self.norm_temp = (
526
+ AdaLayerNorm(dim, num_embeds_ada_norm)
527
+ if self.use_ada_layer_norm
528
+ else nn.LayerNorm(dim)
529
+ )
530
+
531
+ def forward(
532
+ self,
533
+ hidden_states,
534
+ encoder_hidden_states=None,
535
+ timestep=None,
536
+ attention_mask=None,
537
+ video_length=None,
538
+ block_idx: Optional[int] = None,
539
+ additional_residuals: Optional[Dict[str, torch.FloatTensor]] = None
540
+ ):
541
+ norm_hidden_states = (
542
+ self.norm1(hidden_states, timestep)
543
+ if self.use_ada_layer_norm
544
+ else self.norm1(hidden_states)
545
+ )
546
+
547
+ if self.unet_use_cross_frame_attention:
548
+ hidden_states = (
549
+ self.attn1(
550
+ norm_hidden_states,
551
+ attention_mask=attention_mask,
552
+ video_length=video_length,
553
+ block_idx=block_idx,
554
+ additional_residuals=additional_residuals,
555
+ )
556
+ + hidden_states
557
+ )
558
+ else:
559
+ hidden_states = (
560
+ self.attn1(norm_hidden_states, attention_mask=attention_mask,
561
+ block_idx=block_idx,
562
+ additional_residuals=additional_residuals
563
+ )
564
+ + hidden_states
565
+ )
566
+
567
+ if self.attn2 is not None:
568
+ # Cross-Attention
569
+ norm_hidden_states = (
570
+ self.norm2(hidden_states, timestep)
571
+ if self.use_ada_layer_norm
572
+ else self.norm2(hidden_states)
573
+ )
574
+ hidden_states = (
575
+ self.attn2(
576
+ norm_hidden_states,
577
+ encoder_hidden_states=encoder_hidden_states,
578
+ attention_mask=attention_mask,
579
+ block_idx=block_idx,
580
+ additional_residuals=additional_residuals,
581
+ )
582
+ + hidden_states
583
+ )
584
+
585
+ # Feed-forward
586
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
587
+
588
+ # Temporal-Attention
589
+ if self.unet_use_temporal_attention:
590
+ d = hidden_states.shape[1]
591
+ hidden_states = rearrange(
592
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
593
+ )
594
+ norm_hidden_states = (
595
+ self.norm_temp(hidden_states, timestep)
596
+ if self.use_ada_layer_norm
597
+ else self.norm_temp(hidden_states)
598
+ )
599
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
600
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
601
+
602
+ return hidden_states
603
+
604
+
605
+ class ResidualAttention(Attention):
606
+ def set_use_memory_efficient_attention_xformers(
607
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
608
+ ):
609
+ is_lora = hasattr(self, "processor") and isinstance(
610
+ self.processor,
611
+ LORA_ATTENTION_PROCESSORS,
612
+ )
613
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
614
+ self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
615
+ )
616
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
617
+ self.processor,
618
+ (
619
+ AttnAddedKVProcessor,
620
+ AttnAddedKVProcessor2_0,
621
+ SlicedAttnAddedKVProcessor,
622
+ XFormersAttnAddedKVProcessor,
623
+ LoRAAttnAddedKVProcessor,
624
+ ),
625
+ )
626
+
627
+ if use_memory_efficient_attention_xformers:
628
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
629
+ raise NotImplementedError(
630
+ f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
631
+ )
632
+ if not is_xformers_available():
633
+ raise ModuleNotFoundError(
634
+ (
635
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
636
+ " xformers"
637
+ ),
638
+ name="xformers",
639
+ )
640
+ elif not torch.cuda.is_available():
641
+ raise ValueError(
642
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
643
+ " only available for GPU "
644
+ )
645
+ else:
646
+ try:
647
+ # Make sure we can run the memory efficient attention
648
+ _ = xformers.ops.memory_efficient_attention(
649
+ torch.randn((1, 2, 40), device="cuda"),
650
+ torch.randn((1, 2, 40), device="cuda"),
651
+ torch.randn((1, 2, 40), device="cuda"),
652
+ )
653
+ except Exception as e:
654
+ raise e
655
+
656
+ if is_lora:
657
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
658
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
659
+ processor = LoRAXFormersAttnProcessor(
660
+ hidden_size=self.processor.hidden_size,
661
+ cross_attention_dim=self.processor.cross_attention_dim,
662
+ rank=self.processor.rank,
663
+ attention_op=attention_op,
664
+ )
665
+ processor.load_state_dict(self.processor.state_dict())
666
+ processor.to(self.processor.to_q_lora.up.weight.device)
667
+ elif is_custom_diffusion:
668
+ processor = CustomDiffusionXFormersAttnProcessor(
669
+ train_kv=self.processor.train_kv,
670
+ train_q_out=self.processor.train_q_out,
671
+ hidden_size=self.processor.hidden_size,
672
+ cross_attention_dim=self.processor.cross_attention_dim,
673
+ attention_op=attention_op,
674
+ )
675
+ processor.load_state_dict(self.processor.state_dict())
676
+ if hasattr(self.processor, "to_k_custom_diffusion"):
677
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
678
+ elif is_added_kv_processor:
679
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
680
+ # which uses this type of cross attention ONLY because the attention mask of format
681
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
682
+ # throw warning
683
+ logger.info(
684
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
685
+ )
686
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
687
+ else:
688
+ processor = ResidualXFormersAttnProcessor(attention_op=attention_op)
689
+ else:
690
+ if is_lora:
691
+ attn_processor_class = (
692
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
693
+ )
694
+ processor = attn_processor_class(
695
+ hidden_size=self.processor.hidden_size,
696
+ cross_attention_dim=self.processor.cross_attention_dim,
697
+ rank=self.processor.rank,
698
+ )
699
+ processor.load_state_dict(self.processor.state_dict())
700
+ processor.to(self.processor.to_q_lora.up.weight.device)
701
+ elif is_custom_diffusion:
702
+ processor = CustomDiffusionAttnProcessor(
703
+ train_kv=self.processor.train_kv,
704
+ train_q_out=self.processor.train_q_out,
705
+ hidden_size=self.processor.hidden_size,
706
+ cross_attention_dim=self.processor.cross_attention_dim,
707
+ )
708
+ processor.load_state_dict(self.processor.state_dict())
709
+ if hasattr(self.processor, "to_k_custom_diffusion"):
710
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
711
+ else:
712
+ # set attention processor
713
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
714
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
715
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
716
+ processor = (
717
+ AttnProcessor2_0()
718
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
719
+ else AttnProcessor()
720
+ )
721
+
722
+ self.set_processor(processor)
723
+
724
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None,
725
+ block_idx: Optional[int] = None, additional_residuals: Optional[Dict[str, torch.FloatTensor]] = None,
726
+ is_self_attn: Optional[bool] = None, **cross_attention_kwargs):
727
+ # The `Attention` class can call different attention processors / attention functions
728
+ # here we simply pass along all tensors to the selected processor class
729
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
730
+ return self.processor(
731
+ self,
732
+ hidden_states,
733
+ encoder_hidden_states=encoder_hidden_states,
734
+ attention_mask=attention_mask,
735
+ block_idx=block_idx,
736
+ additional_residuals=additional_residuals,
737
+ is_self_attn=is_self_attn,
738
+ **cross_attention_kwargs,
739
+ )
740
+
741
+ class ResidualXFormersAttnProcessor(XFormersAttnProcessor):
742
+ def __call__(
743
+ self,
744
+ attn: Attention,
745
+ hidden_states: torch.FloatTensor,
746
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
747
+ attention_mask: Optional[torch.FloatTensor] = None,
748
+ temb: Optional[torch.FloatTensor] = None,
749
+ block_idx: Optional[int] = None,
750
+ additional_residuals: Optional[Dict[str, torch.FloatTensor]] = None,
751
+ is_self_attn: Optional[bool] = None
752
+ ):
753
+ residual = hidden_states
754
+
755
+ if attn.spatial_norm is not None:
756
+ hidden_states = attn.spatial_norm(hidden_states, temb)
757
+
758
+ input_ndim = hidden_states.ndim
759
+
760
+ if input_ndim == 4:
761
+ batch_size, channel, height, width = hidden_states.shape
762
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
763
+
764
+ batch_size, key_tokens, _ = (
765
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
766
+ )
767
+
768
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
769
+ if attention_mask is not None:
770
+ # expand our mask's singleton query_tokens dimension:
771
+ # [batch*heads, 1, key_tokens] ->
772
+ # [batch*heads, query_tokens, key_tokens]
773
+ # so that it can be added as a bias onto the attention scores that xformers computes:
774
+ # [batch*heads, query_tokens, key_tokens]
775
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
776
+ _, query_tokens, _ = hidden_states.shape
777
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
778
+
779
+ if attn.group_norm is not None:
780
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
781
+
782
+ query = attn.to_q(hidden_states)
783
+
784
+ # newly added
785
+ if is_self_attn and additional_residuals and f"block_{block_idx}_self_attn_q" in additional_residuals:
786
+ query = query + additional_residuals[f"block_{block_idx}_self_attn_q"]
787
+ elif not is_self_attn and additional_residuals and f"block_{block_idx}_cross_attn_q" in additional_residuals:
788
+ query = query + additional_residuals[f"block_{block_idx}_cross_attn_q"]
789
+
790
+ if encoder_hidden_states is None:
791
+ encoder_hidden_states = hidden_states
792
+ elif attn.norm_cross:
793
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
794
+
795
+ if not is_self_attn and additional_residuals and f"block_{block_idx}_cross_attn_c" in additional_residuals:
796
+ not_uc = torch.abs(encoder_hidden_states - torch.zeros_like(encoder_hidden_states)).mean(dim=[1, 2], keepdim=True) < 1e-4
797
+ encoder_hidden_states = encoder_hidden_states + additional_residuals[f"block_{block_idx}_cross_attn_c"] * not_uc
798
+ # encoder_hidden_states[not_uc] = encoder_hidden_states[not_uc] + \
799
+ # additional_residuals[f"block_{block_idx}_cross_attn_c"][not_uc]
800
+ # encoder_hidden_states[~not_uc] = encoder_hidden_states[~not_uc] + \
801
+ # additional_residuals[f"block_{block_idx}_cross_attn_c"][~not_uc] * 0.
802
+
803
+ key = attn.to_k(encoder_hidden_states)
804
+ value = attn.to_v(encoder_hidden_states)
805
+
806
+ # newly added
807
+ if is_self_attn and additional_residuals and f"block_{block_idx}_self_attn_k" in additional_residuals:
808
+ key = key + additional_residuals[f"block_{block_idx}_self_attn_k"]
809
+ elif not is_self_attn and additional_residuals and f"block_{block_idx}_cross_attn_k" in additional_residuals:
810
+ key = key + additional_residuals[f"block_{block_idx}_cross_attn_k"]
811
+
812
+ if is_self_attn and additional_residuals and f"block_{block_idx}_self_attn_v" in additional_residuals:
813
+ value = value + additional_residuals[f"block_{block_idx}_self_attn_v"]
814
+ elif not is_self_attn and additional_residuals and f"block_{block_idx}_cross_attn_v" in additional_residuals:
815
+ value = value + additional_residuals[f"block_{block_idx}_cross_attn_v"]
816
+
817
+ query = attn.head_to_batch_dim(query).contiguous()
818
+ key = attn.head_to_batch_dim(key).contiguous()
819
+ value = attn.head_to_batch_dim(value).contiguous()
820
+
821
+ hidden_states = xformers.ops.memory_efficient_attention(
822
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
823
+ )
824
+ hidden_states = hidden_states.to(query.dtype)
825
+ hidden_states = attn.batch_to_head_dim(hidden_states)
826
+
827
+ # linear proj
828
+ hidden_states = attn.to_out[0](hidden_states)
829
+ # dropout
830
+ hidden_states = attn.to_out[1](hidden_states)
831
+
832
+ if input_ndim == 4:
833
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
834
+
835
+ if attn.residual_connection:
836
+ hidden_states = hidden_states + residual
837
+
838
+ hidden_states = hidden_states / attn.rescale_output_factor
839
+
840
+ return hidden_states
src/models/motion_module.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from diffusers.models.attention import FeedForward
8
+ from diffusers.models.attention_processor import Attention, AttnProcessor
9
+ from diffusers.utils import BaseOutput
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from einops import rearrange, repeat
12
+ from torch import nn
13
+
14
+
15
+ def zero_module(module):
16
+ # Zero out the parameters of a module and return it.
17
+ for p in module.parameters():
18
+ p.detach().zero_()
19
+ return module
20
+
21
+
22
+ @dataclass
23
+ class TemporalTransformer3DModelOutput(BaseOutput):
24
+ sample: torch.FloatTensor
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35
+ if motion_module_type == "Vanilla":
36
+ return VanillaTemporalModule(
37
+ in_channels=in_channels,
38
+ **motion_module_kwargs,
39
+ )
40
+ else:
41
+ raise ValueError
42
+
43
+
44
+ class VanillaTemporalModule(nn.Module):
45
+ def __init__(
46
+ self,
47
+ in_channels,
48
+ num_attention_heads=8,
49
+ num_transformer_block=2,
50
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
51
+ cross_frame_attention_mode=None,
52
+ temporal_position_encoding=False,
53
+ temporal_position_encoding_max_len=24,
54
+ temporal_attention_dim_div=1,
55
+ zero_initialize=True,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.temporal_transformer = TemporalTransformer3DModel(
60
+ in_channels=in_channels,
61
+ num_attention_heads=num_attention_heads,
62
+ attention_head_dim=in_channels
63
+ // num_attention_heads
64
+ // temporal_attention_dim_div,
65
+ num_layers=num_transformer_block,
66
+ attention_block_types=attention_block_types,
67
+ cross_frame_attention_mode=cross_frame_attention_mode,
68
+ temporal_position_encoding=temporal_position_encoding,
69
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70
+ )
71
+
72
+ if zero_initialize:
73
+ self.temporal_transformer.proj_out = zero_module(
74
+ self.temporal_transformer.proj_out
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ input_tensor,
80
+ temb,
81
+ encoder_hidden_states,
82
+ attention_mask=None,
83
+ anchor_frame_idx=None,
84
+ ):
85
+ hidden_states = input_tensor
86
+ hidden_states = self.temporal_transformer(
87
+ hidden_states, encoder_hidden_states, attention_mask
88
+ )
89
+
90
+ output = hidden_states
91
+ return output
92
+
93
+
94
+ class TemporalTransformer3DModel(nn.Module):
95
+ def __init__(
96
+ self,
97
+ in_channels,
98
+ num_attention_heads,
99
+ attention_head_dim,
100
+ num_layers,
101
+ attention_block_types=(
102
+ "Temporal_Self",
103
+ "Temporal_Self",
104
+ ),
105
+ dropout=0.0,
106
+ norm_num_groups=32,
107
+ cross_attention_dim=768,
108
+ activation_fn="geglu",
109
+ attention_bias=False,
110
+ upcast_attention=False,
111
+ cross_frame_attention_mode=None,
112
+ temporal_position_encoding=False,
113
+ temporal_position_encoding_max_len=24,
114
+ ):
115
+ super().__init__()
116
+
117
+ inner_dim = num_attention_heads * attention_head_dim
118
+
119
+ self.norm = torch.nn.GroupNorm(
120
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121
+ )
122
+ self.proj_in = nn.Linear(in_channels, inner_dim)
123
+
124
+ self.transformer_blocks = nn.ModuleList(
125
+ [
126
+ TemporalTransformerBlock(
127
+ dim=inner_dim,
128
+ num_attention_heads=num_attention_heads,
129
+ attention_head_dim=attention_head_dim,
130
+ attention_block_types=attention_block_types,
131
+ dropout=dropout,
132
+ norm_num_groups=norm_num_groups,
133
+ cross_attention_dim=cross_attention_dim,
134
+ activation_fn=activation_fn,
135
+ attention_bias=attention_bias,
136
+ upcast_attention=upcast_attention,
137
+ cross_frame_attention_mode=cross_frame_attention_mode,
138
+ temporal_position_encoding=temporal_position_encoding,
139
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140
+ )
141
+ for d in range(num_layers)
142
+ ]
143
+ )
144
+ self.proj_out = nn.Linear(inner_dim, in_channels)
145
+
146
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147
+ assert (
148
+ hidden_states.dim() == 5
149
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150
+ video_length = hidden_states.shape[2]
151
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152
+
153
+ batch, channel, height, weight = hidden_states.shape
154
+ residual = hidden_states
155
+
156
+ hidden_states = self.norm(hidden_states)
157
+ inner_dim = hidden_states.shape[1]
158
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159
+ batch, height * weight, inner_dim
160
+ )
161
+ hidden_states = self.proj_in(hidden_states)
162
+
163
+ # Transformer Blocks
164
+ for block in self.transformer_blocks:
165
+ hidden_states = block(
166
+ hidden_states,
167
+ encoder_hidden_states=encoder_hidden_states,
168
+ video_length=video_length,
169
+ )
170
+
171
+ # output
172
+ hidden_states = self.proj_out(hidden_states)
173
+ hidden_states = (
174
+ hidden_states.reshape(batch, height, weight, inner_dim)
175
+ .permute(0, 3, 1, 2)
176
+ .contiguous()
177
+ )
178
+
179
+ output = hidden_states + residual
180
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181
+
182
+ return output
183
+
184
+
185
+ class TemporalTransformerBlock(nn.Module):
186
+ def __init__(
187
+ self,
188
+ dim,
189
+ num_attention_heads,
190
+ attention_head_dim,
191
+ attention_block_types=(
192
+ "Temporal_Self",
193
+ "Temporal_Self",
194
+ ),
195
+ dropout=0.0,
196
+ norm_num_groups=32,
197
+ cross_attention_dim=768,
198
+ activation_fn="geglu",
199
+ attention_bias=False,
200
+ upcast_attention=False,
201
+ cross_frame_attention_mode=None,
202
+ temporal_position_encoding=False,
203
+ temporal_position_encoding_max_len=24,
204
+ ):
205
+ super().__init__()
206
+
207
+ attention_blocks = []
208
+ norms = []
209
+
210
+ for block_name in attention_block_types:
211
+ attention_blocks.append(
212
+ VersatileAttention(
213
+ attention_mode=block_name.split("_")[0],
214
+ cross_attention_dim=cross_attention_dim
215
+ if block_name.endswith("_Cross")
216
+ else None,
217
+ query_dim=dim,
218
+ heads=num_attention_heads,
219
+ dim_head=attention_head_dim,
220
+ dropout=dropout,
221
+ bias=attention_bias,
222
+ upcast_attention=upcast_attention,
223
+ cross_frame_attention_mode=cross_frame_attention_mode,
224
+ temporal_position_encoding=temporal_position_encoding,
225
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226
+ )
227
+ )
228
+ norms.append(nn.LayerNorm(dim))
229
+
230
+ self.attention_blocks = nn.ModuleList(attention_blocks)
231
+ self.norms = nn.ModuleList(norms)
232
+
233
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234
+ self.ff_norm = nn.LayerNorm(dim)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ encoder_hidden_states=None,
240
+ attention_mask=None,
241
+ video_length=None,
242
+ ):
243
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
244
+ norm_hidden_states = norm(hidden_states)
245
+ hidden_states = (
246
+ attention_block(
247
+ norm_hidden_states,
248
+ encoder_hidden_states=encoder_hidden_states
249
+ if attention_block.is_cross_attention
250
+ else None,
251
+ video_length=video_length,
252
+ )
253
+ + hidden_states
254
+ )
255
+
256
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257
+
258
+ output = hidden_states
259
+ return output
260
+
261
+
262
+ class PositionalEncoding(nn.Module):
263
+ def __init__(self, d_model, dropout=0.0, max_len=24):
264
+ super().__init__()
265
+ self.dropout = nn.Dropout(p=dropout)
266
+ position = torch.arange(max_len).unsqueeze(1)
267
+ div_term = torch.exp(
268
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269
+ )
270
+ pe = torch.zeros(1, max_len, d_model)
271
+ pe[0, :, 0::2] = torch.sin(position * div_term)
272
+ pe[0, :, 1::2] = torch.cos(position * div_term)
273
+ self.register_buffer("pe", pe)
274
+
275
+ def forward(self, x):
276
+ x = x + self.pe[:, : x.size(1)]
277
+ return self.dropout(x)
278
+
279
+
280
+ class VersatileAttention(Attention):
281
+ def __init__(
282
+ self,
283
+ attention_mode=None,
284
+ cross_frame_attention_mode=None,
285
+ temporal_position_encoding=False,
286
+ temporal_position_encoding_max_len=24,
287
+ *args,
288
+ **kwargs,
289
+ ):
290
+ super().__init__(*args, **kwargs)
291
+ assert attention_mode == "Temporal"
292
+
293
+ self.attention_mode = attention_mode
294
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295
+
296
+ self.pos_encoder = (
297
+ PositionalEncoding(
298
+ kwargs["query_dim"],
299
+ dropout=0.0,
300
+ max_len=temporal_position_encoding_max_len,
301
+ )
302
+ if (temporal_position_encoding and attention_mode == "Temporal")
303
+ else None
304
+ )
305
+
306
+ def extra_repr(self):
307
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308
+
309
+ def set_use_memory_efficient_attention_xformers(
310
+ self,
311
+ use_memory_efficient_attention_xformers: bool,
312
+ attention_op: Optional[Callable] = None,
313
+ ):
314
+ if use_memory_efficient_attention_xformers:
315
+ if not is_xformers_available():
316
+ raise ModuleNotFoundError(
317
+ (
318
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319
+ " xformers"
320
+ ),
321
+ name="xformers",
322
+ )
323
+ elif not torch.cuda.is_available():
324
+ raise ValueError(
325
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326
+ " only available for GPU "
327
+ )
328
+ else:
329
+ try:
330
+ # Make sure we can run the memory efficient attention
331
+ _ = xformers.ops.memory_efficient_attention(
332
+ torch.randn((1, 2, 40), device="cuda"),
333
+ torch.randn((1, 2, 40), device="cuda"),
334
+ torch.randn((1, 2, 40), device="cuda"),
335
+ )
336
+ except Exception as e:
337
+ raise e
338
+
339
+ # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13.
340
+ # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13.
341
+ # You don't need XFormersAttnProcessor here.
342
+ # processor = XFormersAttnProcessor(
343
+ # attention_op=attention_op,
344
+ # )
345
+ processor = AttnProcessor()
346
+ else:
347
+ processor = AttnProcessor()
348
+
349
+ self.set_processor(processor)
350
+
351
+ def forward(
352
+ self,
353
+ hidden_states,
354
+ encoder_hidden_states=None,
355
+ attention_mask=None,
356
+ video_length=None,
357
+ **cross_attention_kwargs,
358
+ ):
359
+ if self.attention_mode == "Temporal":
360
+ d = hidden_states.shape[1] # d means HxW
361
+ hidden_states = rearrange(
362
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
363
+ )
364
+
365
+ if self.pos_encoder is not None:
366
+ hidden_states = self.pos_encoder(hidden_states)
367
+
368
+ encoder_hidden_states = (
369
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370
+ if encoder_hidden_states is not None
371
+ else encoder_hidden_states
372
+ )
373
+
374
+ else:
375
+ raise NotImplementedError
376
+
377
+ hidden_states = self.processor(
378
+ self,
379
+ hidden_states,
380
+ encoder_hidden_states=encoder_hidden_states,
381
+ attention_mask=attention_mask,
382
+ **cross_attention_kwargs,
383
+ )
384
+
385
+ if self.attention_mode == "Temporal":
386
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387
+
388
+ return hidden_states
src/models/mutual_self_attention.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+
7
+ from src.models.attention import TemporalBasicTransformerBlock
8
+
9
+ from .attention import BasicTransformerBlock
10
+
11
+
12
+ def torch_dfs(model: torch.nn.Module):
13
+ result = [model]
14
+ for child in model.children():
15
+ result += torch_dfs(child)
16
+ return result
17
+
18
+
19
+ class ReferenceAttentionControl:
20
+ def __init__(
21
+ self,
22
+ unet,
23
+ mode="write",
24
+ do_classifier_free_guidance=False,
25
+ attention_auto_machine_weight=float("inf"),
26
+ gn_auto_machine_weight=1.0,
27
+ style_fidelity=1.0,
28
+ reference_attn=True,
29
+ reference_adain=False,
30
+ fusion_blocks="midup",
31
+ batch_size=1,
32
+ ) -> None:
33
+ # 10. Modify self attention and group norm
34
+ self.unet = unet
35
+ assert mode in ["read", "write"]
36
+ assert fusion_blocks in ["midup", "full"]
37
+ self.reference_attn = reference_attn
38
+ self.reference_adain = reference_adain
39
+ self.fusion_blocks = fusion_blocks
40
+ self.register_reference_hooks(
41
+ mode,
42
+ do_classifier_free_guidance,
43
+ attention_auto_machine_weight,
44
+ gn_auto_machine_weight,
45
+ style_fidelity,
46
+ reference_attn,
47
+ reference_adain,
48
+ fusion_blocks,
49
+ batch_size=batch_size,
50
+ )
51
+
52
+ def register_reference_hooks(
53
+ self,
54
+ mode,
55
+ do_classifier_free_guidance,
56
+ attention_auto_machine_weight,
57
+ gn_auto_machine_weight,
58
+ style_fidelity,
59
+ reference_attn,
60
+ reference_adain,
61
+ dtype=torch.float16,
62
+ batch_size=1,
63
+ num_images_per_prompt=1,
64
+ device=torch.device("cpu"),
65
+ fusion_blocks="midup",
66
+ ):
67
+ MODE = mode
68
+ do_classifier_free_guidance = do_classifier_free_guidance
69
+ attention_auto_machine_weight = attention_auto_machine_weight
70
+ gn_auto_machine_weight = gn_auto_machine_weight
71
+ style_fidelity = style_fidelity
72
+ reference_attn = reference_attn
73
+ reference_adain = reference_adain
74
+ fusion_blocks = fusion_blocks
75
+ num_images_per_prompt = num_images_per_prompt
76
+ dtype = dtype
77
+ if do_classifier_free_guidance:
78
+ uc_mask = (
79
+ torch.Tensor(
80
+ [1] * batch_size * num_images_per_prompt * 16
81
+ + [0] * batch_size * num_images_per_prompt * 16
82
+ )
83
+ .to(device)
84
+ .bool()
85
+ )
86
+ else:
87
+ uc_mask = (
88
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89
+ .to(device)
90
+ .bool()
91
+ )
92
+
93
+ def hacked_basic_transformer_inner_forward(
94
+ self,
95
+ hidden_states: torch.FloatTensor,
96
+ attention_mask: Optional[torch.FloatTensor] = None,
97
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
98
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
99
+ timestep: Optional[torch.LongTensor] = None,
100
+ cross_attention_kwargs: Dict[str, Any] = None,
101
+ class_labels: Optional[torch.LongTensor] = None,
102
+ video_length=None,
103
+ ):
104
+ if self.use_ada_layer_norm: # False
105
+ norm_hidden_states = self.norm1(hidden_states, timestep)
106
+ elif self.use_ada_layer_norm_zero:
107
+ (
108
+ norm_hidden_states,
109
+ gate_msa,
110
+ shift_mlp,
111
+ scale_mlp,
112
+ gate_mlp,
113
+ ) = self.norm1(
114
+ hidden_states,
115
+ timestep,
116
+ class_labels,
117
+ hidden_dtype=hidden_states.dtype,
118
+ )
119
+ else:
120
+ norm_hidden_states = self.norm1(hidden_states)
121
+
122
+ # 1. Self-Attention
123
+ # self.only_cross_attention = False
124
+ cross_attention_kwargs = (
125
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
126
+ )
127
+ if self.only_cross_attention:
128
+ attn_output = self.attn1(
129
+ norm_hidden_states,
130
+ encoder_hidden_states=encoder_hidden_states
131
+ if self.only_cross_attention
132
+ else None,
133
+ attention_mask=attention_mask,
134
+ **cross_attention_kwargs,
135
+ )
136
+ else:
137
+ if MODE == "write":
138
+ self.bank.append(norm_hidden_states.clone())
139
+ attn_output = self.attn1(
140
+ norm_hidden_states,
141
+ encoder_hidden_states=encoder_hidden_states
142
+ if self.only_cross_attention
143
+ else None,
144
+ attention_mask=attention_mask,
145
+ **cross_attention_kwargs,
146
+ )
147
+ if MODE == "read":
148
+ bank_fea = [
149
+ rearrange(
150
+ d.unsqueeze(1).repeat(1, video_length, 1, 1),
151
+ "b t l c -> (b t) l c",
152
+ )
153
+ for d in self.bank
154
+ ]
155
+ modify_norm_hidden_states = torch.cat(
156
+ [norm_hidden_states] + bank_fea, dim=1
157
+ )
158
+ hidden_states_uc = (
159
+ self.attn1(
160
+ norm_hidden_states,
161
+ encoder_hidden_states=modify_norm_hidden_states,
162
+ attention_mask=attention_mask,
163
+ )
164
+ + hidden_states
165
+ )
166
+ if do_classifier_free_guidance:
167
+ hidden_states_c = hidden_states_uc.clone()
168
+ _uc_mask = uc_mask.clone()
169
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
170
+ _uc_mask = (
171
+ torch.Tensor(
172
+ [1] * (hidden_states.shape[0] // 2)
173
+ + [0] * (hidden_states.shape[0] // 2)
174
+ )
175
+ .to(device)
176
+ .bool()
177
+ )
178
+ hidden_states_c[_uc_mask] = (
179
+ self.attn1(
180
+ norm_hidden_states[_uc_mask],
181
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
182
+ attention_mask=attention_mask,
183
+ )
184
+ + hidden_states[_uc_mask]
185
+ )
186
+ hidden_states = hidden_states_c.clone()
187
+ else:
188
+ hidden_states = hidden_states_uc
189
+
190
+ # self.bank.clear()
191
+ if self.attn2 is not None:
192
+ # Cross-Attention
193
+ norm_hidden_states = (
194
+ self.norm2(hidden_states, timestep)
195
+ if self.use_ada_layer_norm
196
+ else self.norm2(hidden_states)
197
+ )
198
+ hidden_states = (
199
+ self.attn2(
200
+ norm_hidden_states,
201
+ encoder_hidden_states=encoder_hidden_states,
202
+ attention_mask=attention_mask,
203
+ )
204
+ + hidden_states
205
+ )
206
+
207
+ # Feed-forward
208
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
209
+
210
+ # Temporal-Attention
211
+ if self.unet_use_temporal_attention:
212
+ d = hidden_states.shape[1]
213
+ hidden_states = rearrange(
214
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
215
+ )
216
+ norm_hidden_states = (
217
+ self.norm_temp(hidden_states, timestep)
218
+ if self.use_ada_layer_norm
219
+ else self.norm_temp(hidden_states)
220
+ )
221
+ hidden_states = (
222
+ self.attn_temp(norm_hidden_states) + hidden_states
223
+ )
224
+ hidden_states = rearrange(
225
+ hidden_states, "(b d) f c -> (b f) d c", d=d
226
+ )
227
+
228
+ return hidden_states
229
+
230
+ if self.use_ada_layer_norm_zero:
231
+ attn_output = gate_msa.unsqueeze(1) * attn_output
232
+ hidden_states = attn_output + hidden_states
233
+
234
+ if self.attn2 is not None:
235
+ norm_hidden_states = (
236
+ self.norm2(hidden_states, timestep)
237
+ if self.use_ada_layer_norm
238
+ else self.norm2(hidden_states)
239
+ )
240
+
241
+ # 2. Cross-Attention
242
+ attn_output = self.attn2(
243
+ norm_hidden_states,
244
+ encoder_hidden_states=encoder_hidden_states,
245
+ attention_mask=encoder_attention_mask,
246
+ **cross_attention_kwargs,
247
+ )
248
+ hidden_states = attn_output + hidden_states
249
+
250
+ # 3. Feed-forward
251
+ norm_hidden_states = self.norm3(hidden_states)
252
+
253
+ if self.use_ada_layer_norm_zero:
254
+ norm_hidden_states = (
255
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
256
+ )
257
+
258
+ ff_output = self.ff(norm_hidden_states)
259
+
260
+ if self.use_ada_layer_norm_zero:
261
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
262
+
263
+ hidden_states = ff_output + hidden_states
264
+
265
+ return hidden_states
266
+
267
+ if self.reference_attn:
268
+ if self.fusion_blocks == "midup":
269
+ attn_modules = [
270
+ module
271
+ for module in (
272
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
273
+ )
274
+ if isinstance(module, BasicTransformerBlock)
275
+ or isinstance(module, TemporalBasicTransformerBlock)
276
+ ]
277
+ elif self.fusion_blocks == "full":
278
+ attn_modules = [
279
+ module
280
+ for module in torch_dfs(self.unet)
281
+ if isinstance(module, BasicTransformerBlock)
282
+ or isinstance(module, TemporalBasicTransformerBlock)
283
+ ]
284
+ attn_modules = sorted(
285
+ attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
286
+ )
287
+
288
+ for i, module in enumerate(attn_modules):
289
+ module._original_inner_forward = module.forward
290
+ if isinstance(module, BasicTransformerBlock):
291
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
292
+ module, BasicTransformerBlock
293
+ )
294
+ if isinstance(module, TemporalBasicTransformerBlock):
295
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
296
+ module, TemporalBasicTransformerBlock
297
+ )
298
+
299
+ module.bank = []
300
+ module.attn_weight = float(i) / float(len(attn_modules))
301
+
302
+ def update(self, writer, dtype=torch.float16):
303
+ if self.reference_attn:
304
+ if self.fusion_blocks == "midup":
305
+ reader_attn_modules = [
306
+ module
307
+ for module in (
308
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
309
+ )
310
+ if isinstance(module, TemporalBasicTransformerBlock)
311
+ ]
312
+ writer_attn_modules = [
313
+ module
314
+ for module in (
315
+ torch_dfs(writer.unet.mid_block)
316
+ + torch_dfs(writer.unet.up_blocks)
317
+ )
318
+ if isinstance(module, BasicTransformerBlock)
319
+ ]
320
+ elif self.fusion_blocks == "full":
321
+ reader_attn_modules = [
322
+ module
323
+ for module in torch_dfs(self.unet)
324
+ if isinstance(module, TemporalBasicTransformerBlock)
325
+ ]
326
+ writer_attn_modules = [
327
+ module
328
+ for module in torch_dfs(writer.unet)
329
+ if isinstance(module, BasicTransformerBlock)
330
+ ]
331
+ reader_attn_modules = sorted(
332
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
333
+ )
334
+ writer_attn_modules = sorted(
335
+ writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
336
+ )
337
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
338
+ r.bank = [v.clone().to(dtype) for v in w.bank]
339
+ # w.bank.clear()
340
+
341
+ def clear(self):
342
+ if self.reference_attn:
343
+ if self.fusion_blocks == "midup":
344
+ reader_attn_modules = [
345
+ module
346
+ for module in (
347
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
348
+ )
349
+ if isinstance(module, BasicTransformerBlock)
350
+ or isinstance(module, TemporalBasicTransformerBlock)
351
+ ]
352
+ elif self.fusion_blocks == "full":
353
+ reader_attn_modules = [
354
+ module
355
+ for module in torch_dfs(self.unet)
356
+ if isinstance(module, BasicTransformerBlock)
357
+ or isinstance(module, TemporalBasicTransformerBlock)
358
+ ]
359
+ reader_attn_modules = sorted(
360
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
361
+ )
362
+ for r in reader_attn_modules:
363
+ r.bank.clear()
src/models/pose_guider.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.init as init
5
+ from einops import rearrange
6
+ import numpy as np
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+
9
+ from typing import Any, Dict, Optional
10
+ from src.models.attention import BasicTransformerBlock
11
+
12
+
13
+ class PoseGuider(ModelMixin):
14
+ def __init__(self, noise_latent_channels=320, use_ca=True):
15
+ super(PoseGuider, self).__init__()
16
+
17
+ self.use_ca = use_ca
18
+
19
+ self.conv_layers = nn.Sequential(
20
+ nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
21
+ nn.BatchNorm2d(3),
22
+ nn.ReLU(),
23
+ nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
24
+ nn.BatchNorm2d(16),
25
+ nn.ReLU(),
26
+
27
+ nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
28
+ nn.BatchNorm2d(16),
29
+ nn.ReLU(),
30
+ nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
31
+ nn.BatchNorm2d(32),
32
+ nn.ReLU(),
33
+
34
+ nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
35
+ nn.BatchNorm2d(32),
36
+ nn.ReLU(),
37
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
38
+ nn.BatchNorm2d(64),
39
+ nn.ReLU(),
40
+
41
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
42
+ nn.BatchNorm2d(64),
43
+ nn.ReLU(),
44
+ nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
45
+ nn.BatchNorm2d(128),
46
+ nn.ReLU()
47
+ )
48
+
49
+ # Final projection layer
50
+ self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)
51
+
52
+ self.conv_layers_1 = nn.Sequential(
53
+ nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, padding=1),
54
+ nn.BatchNorm2d(noise_latent_channels),
55
+ nn.ReLU(),
56
+ nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, stride=2, padding=1),
57
+ nn.BatchNorm2d(noise_latent_channels),
58
+ nn.ReLU(),
59
+ )
60
+
61
+ self.conv_layers_2 = nn.Sequential(
62
+ nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels, kernel_size=3, padding=1),
63
+ nn.BatchNorm2d(noise_latent_channels),
64
+ nn.ReLU(),
65
+ nn.Conv2d(in_channels=noise_latent_channels, out_channels=noise_latent_channels*2, kernel_size=3, stride=2, padding=1),
66
+ nn.BatchNorm2d(noise_latent_channels*2),
67
+ nn.ReLU(),
68
+ )
69
+
70
+ self.conv_layers_3 = nn.Sequential(
71
+ nn.Conv2d(in_channels=noise_latent_channels*2, out_channels=noise_latent_channels*2, kernel_size=3, padding=1),
72
+ nn.BatchNorm2d(noise_latent_channels*2),
73
+ nn.ReLU(),
74
+ nn.Conv2d(in_channels=noise_latent_channels*2, out_channels=noise_latent_channels*4, kernel_size=3, stride=2, padding=1),
75
+ nn.BatchNorm2d(noise_latent_channels*4),
76
+ nn.ReLU(),
77
+ )
78
+
79
+ self.conv_layers_4 = nn.Sequential(
80
+ nn.Conv2d(in_channels=noise_latent_channels*4, out_channels=noise_latent_channels*4, kernel_size=3, padding=1),
81
+ nn.BatchNorm2d(noise_latent_channels*4),
82
+ nn.ReLU(),
83
+ )
84
+
85
+ if self.use_ca:
86
+ self.cross_attn1 = Transformer2DModel(in_channels=noise_latent_channels)
87
+ self.cross_attn2 = Transformer2DModel(in_channels=noise_latent_channels*2)
88
+ self.cross_attn3 = Transformer2DModel(in_channels=noise_latent_channels*4)
89
+ self.cross_attn4 = Transformer2DModel(in_channels=noise_latent_channels*4)
90
+
91
+ # Initialize layers
92
+ self._initialize_weights()
93
+
94
+ self.scale = nn.Parameter(torch.ones(1) * 2)
95
+
96
+ # def _initialize_weights(self):
97
+ # # Initialize weights with Gaussian distribution and zero out the final layer
98
+ # for m in self.conv_layers:
99
+ # if isinstance(m, nn.Conv2d):
100
+ # init.normal_(m.weight, mean=0.0, std=0.02)
101
+ # if m.bias is not None:
102
+ # init.zeros_(m.bias)
103
+
104
+ # init.zeros_(self.final_proj.weight)
105
+ # if self.final_proj.bias is not None:
106
+ # init.zeros_(self.final_proj.bias)
107
+
108
+ def _initialize_weights(self):
109
+ # Initialize weights with He initialization and zero out the biases
110
+ conv_blocks = [self.conv_layers, self.conv_layers_1, self.conv_layers_2, self.conv_layers_3, self.conv_layers_4]
111
+ for block_item in conv_blocks:
112
+ for m in block_item:
113
+ if isinstance(m, nn.Conv2d):
114
+ n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
115
+ init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n))
116
+ if m.bias is not None:
117
+ init.zeros_(m.bias)
118
+
119
+ # For the final projection layer, initialize weights to zero (or you may choose to use He initialization here as well)
120
+ init.zeros_(self.final_proj.weight)
121
+ if self.final_proj.bias is not None:
122
+ init.zeros_(self.final_proj.bias)
123
+
124
+ def forward(self, x, ref_x):
125
+ fea = []
126
+ b = x.shape[0]
127
+
128
+ x = rearrange(x, "b c f h w -> (b f) c h w")
129
+ x = self.conv_layers(x)
130
+ x = self.final_proj(x)
131
+ x = x * self.scale
132
+ # x = rearrange(x, "(b f) c h w -> b c f h w", b=b)
133
+ fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
134
+
135
+ x = self.conv_layers_1(x)
136
+ if self.use_ca:
137
+ ref_x = self.conv_layers(ref_x)
138
+ ref_x = self.final_proj(ref_x)
139
+ ref_x = ref_x * self.scale
140
+ ref_x = self.conv_layers_1(ref_x)
141
+ x = self.cross_attn1(x, ref_x)
142
+ fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
143
+
144
+ x = self.conv_layers_2(x)
145
+ if self.use_ca:
146
+ ref_x = self.conv_layers_2(ref_x)
147
+ x = self.cross_attn2(x, ref_x)
148
+ fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
149
+
150
+ x = self.conv_layers_3(x)
151
+ if self.use_ca:
152
+ ref_x = self.conv_layers_3(ref_x)
153
+ x = self.cross_attn3(x, ref_x)
154
+ fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
155
+
156
+ x = self.conv_layers_4(x)
157
+ if self.use_ca:
158
+ ref_x = self.conv_layers_4(ref_x)
159
+ x = self.cross_attn4(x, ref_x)
160
+ fea.append(rearrange(x, "(b f) c h w -> b c f h w", b=b))
161
+
162
+ return fea
163
+
164
+ # @classmethod
165
+ # def from_pretrained(cls,pretrained_model_path):
166
+ # if not os.path.exists(pretrained_model_path):
167
+ # print(f"There is no model file in {pretrained_model_path}")
168
+ # print(f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ...")
169
+
170
+ # state_dict = torch.load(pretrained_model_path, map_location="cpu")
171
+ # model = Hack_PoseGuider(noise_latent_channels=320)
172
+
173
+ # m, u = model.load_state_dict(state_dict, strict=True)
174
+ # # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
175
+ # params = [p.numel() for n, p in model.named_parameters()]
176
+ # print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")
177
+
178
+ # return model
179
+
180
+
181
+ class Transformer2DModel(ModelMixin):
182
+ _supports_gradient_checkpointing = True
183
+ def __init__(
184
+ self,
185
+ num_attention_heads: int = 16,
186
+ attention_head_dim: int = 88,
187
+ in_channels: Optional[int] = None,
188
+ num_layers: int = 1,
189
+ dropout: float = 0.0,
190
+ norm_num_groups: int = 32,
191
+ cross_attention_dim: Optional[int] = None,
192
+ attention_bias: bool = False,
193
+ activation_fn: str = "geglu",
194
+ num_embeds_ada_norm: Optional[int] = None,
195
+ use_linear_projection: bool = False,
196
+ only_cross_attention: bool = False,
197
+ double_self_attention: bool = False,
198
+ upcast_attention: bool = False,
199
+ norm_type: str = "layer_norm",
200
+ norm_elementwise_affine: bool = True,
201
+ norm_eps: float = 1e-5,
202
+ attention_type: str = "default",
203
+ ):
204
+ super().__init__()
205
+ self.use_linear_projection = use_linear_projection
206
+ self.num_attention_heads = num_attention_heads
207
+ self.attention_head_dim = attention_head_dim
208
+ inner_dim = num_attention_heads * attention_head_dim
209
+
210
+ self.in_channels = in_channels
211
+
212
+ self.norm = torch.nn.GroupNorm(
213
+ num_groups=norm_num_groups,
214
+ num_channels=in_channels,
215
+ eps=1e-6,
216
+ affine=True,
217
+ )
218
+ if use_linear_projection:
219
+ self.proj_in = nn.Linear(in_channels, inner_dim)
220
+ else:
221
+ self.proj_in = nn.Conv2d(
222
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
223
+ )
224
+
225
+ # 3. Define transformers blocks
226
+ self.transformer_blocks = nn.ModuleList(
227
+ [
228
+ BasicTransformerBlock(
229
+ inner_dim,
230
+ num_attention_heads,
231
+ attention_head_dim,
232
+ dropout=dropout,
233
+ cross_attention_dim=cross_attention_dim,
234
+ activation_fn=activation_fn,
235
+ num_embeds_ada_norm=num_embeds_ada_norm,
236
+ attention_bias=attention_bias,
237
+ only_cross_attention=only_cross_attention,
238
+ double_self_attention=double_self_attention,
239
+ upcast_attention=upcast_attention,
240
+ norm_type=norm_type,
241
+ norm_elementwise_affine=norm_elementwise_affine,
242
+ norm_eps=norm_eps,
243
+ attention_type=attention_type,
244
+ )
245
+ for d in range(num_layers)
246
+ ]
247
+ )
248
+
249
+ if use_linear_projection:
250
+ self.proj_out = nn.Linear(inner_dim, in_channels)
251
+ else:
252
+ self.proj_out = nn.Conv2d(
253
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
254
+ )
255
+
256
+ self.gradient_checkpointing = False
257
+
258
+ def _set_gradient_checkpointing(self, module, value=False):
259
+ if hasattr(module, "gradient_checkpointing"):
260
+ module.gradient_checkpointing = value
261
+
262
+ def forward(
263
+ self,
264
+ hidden_states: torch.Tensor,
265
+ encoder_hidden_states: Optional[torch.Tensor] = None,
266
+ timestep: Optional[torch.LongTensor] = None,
267
+ ):
268
+ batch, _, height, width = hidden_states.shape
269
+ residual = hidden_states
270
+
271
+ hidden_states = self.norm(hidden_states)
272
+ if not self.use_linear_projection:
273
+ hidden_states = self.proj_in(hidden_states)
274
+ inner_dim = hidden_states.shape[1]
275
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
276
+ batch, height * width, inner_dim
277
+ )
278
+ else:
279
+ inner_dim = hidden_states.shape[1]
280
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
281
+ batch, height * width, inner_dim
282
+ )
283
+ hidden_states = self.proj_in(hidden_states)
284
+
285
+ for block in self.transformer_blocks:
286
+ hidden_states = block(
287
+ hidden_states,
288
+ encoder_hidden_states=encoder_hidden_states,
289
+ timestep=timestep,
290
+ )
291
+
292
+ if not self.use_linear_projection:
293
+ hidden_states = (
294
+ hidden_states.reshape(batch, height, width, inner_dim)
295
+ .permute(0, 3, 1, 2)
296
+ .contiguous()
297
+ )
298
+ hidden_states = self.proj_out(hidden_states)
299
+ else:
300
+ hidden_states = self.proj_out(hidden_states)
301
+ hidden_states = (
302
+ hidden_states.reshape(batch, height, width, inner_dim)
303
+ .permute(0, 3, 1, 2)
304
+ .contiguous()
305
+ )
306
+
307
+ output = hidden_states + residual
308
+ return output
309
+
310
+
311
+ if __name__ == '__main__':
312
+ model = PoseGuider(noise_latent_channels=320).to(device="cuda")
313
+
314
+ input_data = torch.randn(1,3,1,512,512).to(device="cuda")
315
+ input_data1 = torch.randn(1,3,512,512).to(device="cuda")
316
+
317
+ output = model(input_data, input_data1)
318
+ for item in output:
319
+ print(item.shape)
320
+
321
+ # tf_model = Transformer2DModel(
322
+ # in_channels=320
323
+ # ).to('cuda')
324
+
325
+ # input_data = torch.randn(4,320,32,32).to(device="cuda")
326
+ # # input_emb = torch.randn(4,1,768).to(device="cuda")
327
+ # input_emb = torch.randn(4,320,32,32).to(device="cuda")
328
+ # o1 = tf_model(input_data, input_emb)
329
+ # print(o1.shape)
src/models/resnet.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from typing import Dict, Optional
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
19
+
20
+
21
+ class InflatedGroupNorm(nn.GroupNorm):
22
+ def forward(self, x):
23
+ video_length = x.shape[2]
24
+
25
+ x = rearrange(x, "b c f h w -> (b f) c h w")
26
+ x = super().forward(x)
27
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
28
+
29
+ return x
30
+
31
+
32
+ class Upsample3D(nn.Module):
33
+ def __init__(
34
+ self,
35
+ channels,
36
+ use_conv=False,
37
+ use_conv_transpose=False,
38
+ out_channels=None,
39
+ name="conv",
40
+ ):
41
+ super().__init__()
42
+ self.channels = channels
43
+ self.out_channels = out_channels or channels
44
+ self.use_conv = use_conv
45
+ self.use_conv_transpose = use_conv_transpose
46
+ self.name = name
47
+
48
+ conv = None
49
+ if use_conv_transpose:
50
+ raise NotImplementedError
51
+ elif use_conv:
52
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
53
+
54
+ def forward(self, hidden_states, output_size=None):
55
+ assert hidden_states.shape[1] == self.channels
56
+
57
+ if self.use_conv_transpose:
58
+ raise NotImplementedError
59
+
60
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
61
+ dtype = hidden_states.dtype
62
+ if dtype == torch.bfloat16:
63
+ hidden_states = hidden_states.to(torch.float32)
64
+
65
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
66
+ if hidden_states.shape[0] >= 64:
67
+ hidden_states = hidden_states.contiguous()
68
+
69
+ # if `output_size` is passed we force the interpolation output
70
+ # size and do not make use of `scale_factor=2`
71
+ if output_size is None:
72
+ hidden_states = F.interpolate(
73
+ hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
74
+ )
75
+ else:
76
+ hidden_states = F.interpolate(
77
+ hidden_states, size=output_size, mode="nearest"
78
+ )
79
+
80
+ # If the input is bfloat16, we cast back to bfloat16
81
+ if dtype == torch.bfloat16:
82
+ hidden_states = hidden_states.to(dtype)
83
+
84
+ # if self.use_conv:
85
+ # if self.name == "conv":
86
+ # hidden_states = self.conv(hidden_states)
87
+ # else:
88
+ # hidden_states = self.Conv2d_0(hidden_states)
89
+ hidden_states = self.conv(hidden_states)
90
+
91
+ return hidden_states
92
+
93
+
94
+ class Downsample3D(nn.Module):
95
+ def __init__(
96
+ self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
97
+ ):
98
+ super().__init__()
99
+ self.channels = channels
100
+ self.out_channels = out_channels or channels
101
+ self.use_conv = use_conv
102
+ self.padding = padding
103
+ stride = 2
104
+ self.name = name
105
+
106
+ if use_conv:
107
+ self.conv = InflatedConv3d(
108
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
109
+ )
110
+ else:
111
+ raise NotImplementedError
112
+
113
+ def forward(self, hidden_states):
114
+ assert hidden_states.shape[1] == self.channels
115
+ if self.use_conv and self.padding == 0:
116
+ raise NotImplementedError
117
+
118
+ assert hidden_states.shape[1] == self.channels
119
+ hidden_states = self.conv(hidden_states)
120
+
121
+ return hidden_states
122
+
123
+
124
+ class ResnetBlock3D(nn.Module):
125
+ def __init__(
126
+ self,
127
+ *,
128
+ in_channels,
129
+ out_channels=None,
130
+ conv_shortcut=False,
131
+ dropout=0.0,
132
+ temb_channels=512,
133
+ groups=32,
134
+ groups_out=None,
135
+ pre_norm=True,
136
+ eps=1e-6,
137
+ non_linearity="swish",
138
+ time_embedding_norm="default",
139
+ output_scale_factor=1.0,
140
+ use_in_shortcut=None,
141
+ use_inflated_groupnorm=None,
142
+ ):
143
+ super().__init__()
144
+ self.pre_norm = pre_norm
145
+ self.pre_norm = True
146
+ self.in_channels = in_channels
147
+ out_channels = in_channels if out_channels is None else out_channels
148
+ self.out_channels = out_channels
149
+ self.use_conv_shortcut = conv_shortcut
150
+ self.time_embedding_norm = time_embedding_norm
151
+ self.output_scale_factor = output_scale_factor
152
+
153
+ if groups_out is None:
154
+ groups_out = groups
155
+
156
+ assert use_inflated_groupnorm != None
157
+ if use_inflated_groupnorm:
158
+ self.norm1 = InflatedGroupNorm(
159
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
160
+ )
161
+ else:
162
+ self.norm1 = torch.nn.GroupNorm(
163
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
164
+ )
165
+
166
+ self.conv1 = InflatedConv3d(
167
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
168
+ )
169
+
170
+ if temb_channels is not None:
171
+ if self.time_embedding_norm == "default":
172
+ time_emb_proj_out_channels = out_channels
173
+ elif self.time_embedding_norm == "scale_shift":
174
+ time_emb_proj_out_channels = out_channels * 2
175
+ else:
176
+ raise ValueError(
177
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
178
+ )
179
+
180
+ self.time_emb_proj = torch.nn.Linear(
181
+ temb_channels, time_emb_proj_out_channels
182
+ )
183
+ else:
184
+ self.time_emb_proj = None
185
+
186
+ if use_inflated_groupnorm:
187
+ self.norm2 = InflatedGroupNorm(
188
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
189
+ )
190
+ else:
191
+ self.norm2 = torch.nn.GroupNorm(
192
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
193
+ )
194
+ self.dropout = torch.nn.Dropout(dropout)
195
+ self.conv2 = InflatedConv3d(
196
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
197
+ )
198
+
199
+ if non_linearity == "swish":
200
+ self.nonlinearity = lambda x: F.silu(x)
201
+ elif non_linearity == "mish":
202
+ self.nonlinearity = Mish()
203
+ elif non_linearity == "silu":
204
+ self.nonlinearity = nn.SiLU()
205
+
206
+ self.use_in_shortcut = (
207
+ self.in_channels != self.out_channels
208
+ if use_in_shortcut is None
209
+ else use_in_shortcut
210
+ )
211
+
212
+ self.conv_shortcut = None
213
+ if self.use_in_shortcut:
214
+ self.conv_shortcut = InflatedConv3d(
215
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
216
+ )
217
+
218
+ def forward(self, input_tensor, temb):
219
+ hidden_states = input_tensor
220
+
221
+ hidden_states = self.norm1(hidden_states)
222
+ hidden_states = self.nonlinearity(hidden_states)
223
+
224
+ hidden_states = self.conv1(hidden_states)
225
+
226
+ if temb is not None:
227
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
228
+
229
+ if temb is not None and self.time_embedding_norm == "default":
230
+ hidden_states = hidden_states + temb
231
+
232
+ hidden_states = self.norm2(hidden_states)
233
+
234
+ if temb is not None and self.time_embedding_norm == "scale_shift":
235
+ scale, shift = torch.chunk(temb, 2, dim=1)
236
+ hidden_states = hidden_states * (1 + scale) + shift
237
+
238
+ hidden_states = self.nonlinearity(hidden_states)
239
+
240
+ hidden_states = self.dropout(hidden_states)
241
+ hidden_states = self.conv2(hidden_states)
242
+
243
+ if self.conv_shortcut is not None:
244
+ input_tensor = self.conv_shortcut(input_tensor)
245
+
246
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
247
+
248
+ return output_tensor
249
+
250
+ class Mish(torch.nn.Module):
251
+ def forward(self, hidden_states):
252
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
src/models/transformer_2d.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.embeddings import CaptionProjection
8
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from diffusers.models.normalization import AdaLayerNormSingle
11
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
12
+ from torch import nn
13
+
14
+ from .attention import BasicTransformerBlock
15
+
16
+
17
+ @dataclass
18
+ class Transformer2DModelOutput(BaseOutput):
19
+ """
20
+ The output of [`Transformer2DModel`].
21
+
22
+ Args:
23
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
24
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
25
+ distributions for the unnoised latent pixels.
26
+ """
27
+
28
+ sample: torch.FloatTensor
29
+ ref_feature: torch.FloatTensor
30
+
31
+
32
+ class Transformer2DModel(ModelMixin, ConfigMixin):
33
+ """
34
+ A 2D Transformer model for image-like data.
35
+
36
+ Parameters:
37
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
38
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
39
+ in_channels (`int`, *optional*):
40
+ The number of channels in the input and output (specify if the input is **continuous**).
41
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
42
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
43
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
44
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
45
+ This is fixed during training since it is used to learn a number of position embeddings.
46
+ num_vector_embeds (`int`, *optional*):
47
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
48
+ Includes the class for the masked latent pixel.
49
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
50
+ num_embeds_ada_norm ( `int`, *optional*):
51
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
52
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
53
+ added to the hidden states.
54
+
55
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
56
+ attention_bias (`bool`, *optional*):
57
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
58
+ """
59
+
60
+ _supports_gradient_checkpointing = True
61
+
62
+ @register_to_config
63
+ def __init__(
64
+ self,
65
+ num_attention_heads: int = 16,
66
+ attention_head_dim: int = 88,
67
+ in_channels: Optional[int] = None,
68
+ out_channels: Optional[int] = None,
69
+ num_layers: int = 1,
70
+ dropout: float = 0.0,
71
+ norm_num_groups: int = 32,
72
+ cross_attention_dim: Optional[int] = None,
73
+ attention_bias: bool = False,
74
+ sample_size: Optional[int] = None,
75
+ num_vector_embeds: Optional[int] = None,
76
+ patch_size: Optional[int] = None,
77
+ activation_fn: str = "geglu",
78
+ num_embeds_ada_norm: Optional[int] = None,
79
+ use_linear_projection: bool = False,
80
+ only_cross_attention: bool = False,
81
+ double_self_attention: bool = False,
82
+ upcast_attention: bool = False,
83
+ norm_type: str = "layer_norm",
84
+ norm_elementwise_affine: bool = True,
85
+ norm_eps: float = 1e-5,
86
+ attention_type: str = "default",
87
+ caption_channels: int = None,
88
+ ):
89
+ super().__init__()
90
+ self.use_linear_projection = use_linear_projection
91
+ self.num_attention_heads = num_attention_heads
92
+ self.attention_head_dim = attention_head_dim
93
+ inner_dim = num_attention_heads * attention_head_dim
94
+
95
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
96
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
97
+
98
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
99
+ # Define whether input is continuous or discrete depending on configuration
100
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
101
+ self.is_input_vectorized = num_vector_embeds is not None
102
+ self.is_input_patches = in_channels is not None and patch_size is not None
103
+
104
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
105
+ deprecation_message = (
106
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
107
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
108
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
109
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
110
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
111
+ )
112
+ deprecate(
113
+ "norm_type!=num_embeds_ada_norm",
114
+ "1.0.0",
115
+ deprecation_message,
116
+ standard_warn=False,
117
+ )
118
+ norm_type = "ada_norm"
119
+
120
+ if self.is_input_continuous and self.is_input_vectorized:
121
+ raise ValueError(
122
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
123
+ " sure that either `in_channels` or `num_vector_embeds` is None."
124
+ )
125
+ elif self.is_input_vectorized and self.is_input_patches:
126
+ raise ValueError(
127
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
128
+ " sure that either `num_vector_embeds` or `num_patches` is None."
129
+ )
130
+ elif (
131
+ not self.is_input_continuous
132
+ and not self.is_input_vectorized
133
+ and not self.is_input_patches
134
+ ):
135
+ raise ValueError(
136
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
137
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
138
+ )
139
+
140
+ # 2. Define input layers
141
+ self.in_channels = in_channels
142
+
143
+ self.norm = torch.nn.GroupNorm(
144
+ num_groups=norm_num_groups,
145
+ num_channels=in_channels,
146
+ eps=1e-6,
147
+ affine=True,
148
+ )
149
+ if use_linear_projection:
150
+ self.proj_in = linear_cls(in_channels, inner_dim)
151
+ else:
152
+ self.proj_in = conv_cls(
153
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
154
+ )
155
+
156
+ # 3. Define transformers blocks
157
+ self.transformer_blocks = nn.ModuleList(
158
+ [
159
+ BasicTransformerBlock(
160
+ inner_dim,
161
+ num_attention_heads,
162
+ attention_head_dim,
163
+ dropout=dropout,
164
+ cross_attention_dim=cross_attention_dim,
165
+ activation_fn=activation_fn,
166
+ num_embeds_ada_norm=num_embeds_ada_norm,
167
+ attention_bias=attention_bias,
168
+ only_cross_attention=only_cross_attention,
169
+ double_self_attention=double_self_attention,
170
+ upcast_attention=upcast_attention,
171
+ norm_type=norm_type,
172
+ norm_elementwise_affine=norm_elementwise_affine,
173
+ norm_eps=norm_eps,
174
+ attention_type=attention_type,
175
+ )
176
+ for d in range(num_layers)
177
+ ]
178
+ )
179
+
180
+ # 4. Define output layers
181
+ self.out_channels = in_channels if out_channels is None else out_channels
182
+ # TODO: should use out_channels for continuous projections
183
+ if use_linear_projection:
184
+ self.proj_out = linear_cls(inner_dim, in_channels)
185
+ else:
186
+ self.proj_out = conv_cls(
187
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
188
+ )
189
+
190
+ # 5. PixArt-Alpha blocks.
191
+ self.adaln_single = None
192
+ self.use_additional_conditions = False
193
+ if norm_type == "ada_norm_single":
194
+ self.use_additional_conditions = self.config.sample_size == 128
195
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
196
+ # additional conditions until we find better name
197
+ self.adaln_single = AdaLayerNormSingle(
198
+ inner_dim, use_additional_conditions=self.use_additional_conditions
199
+ )
200
+
201
+ self.caption_projection = None
202
+ if caption_channels is not None:
203
+ self.caption_projection = CaptionProjection(
204
+ in_features=caption_channels, hidden_size=inner_dim
205
+ )
206
+
207
+ self.gradient_checkpointing = False
208
+
209
+ def _set_gradient_checkpointing(self, module, value=False):
210
+ if hasattr(module, "gradient_checkpointing"):
211
+ module.gradient_checkpointing = value
212
+
213
+ def forward(
214
+ self,
215
+ hidden_states: torch.Tensor,
216
+ encoder_hidden_states: Optional[torch.Tensor] = None,
217
+ timestep: Optional[torch.LongTensor] = None,
218
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
219
+ class_labels: Optional[torch.LongTensor] = None,
220
+ cross_attention_kwargs: Dict[str, Any] = None,
221
+ attention_mask: Optional[torch.Tensor] = None,
222
+ encoder_attention_mask: Optional[torch.Tensor] = None,
223
+ return_dict: bool = True,
224
+ ):
225
+ """
226
+ The [`Transformer2DModel`] forward method.
227
+
228
+ Args:
229
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
230
+ Input `hidden_states`.
231
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
232
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
233
+ self-attention.
234
+ timestep ( `torch.LongTensor`, *optional*):
235
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
236
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
237
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
238
+ `AdaLayerZeroNorm`.
239
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
240
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
241
+ `self.processor` in
242
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
243
+ attention_mask ( `torch.Tensor`, *optional*):
244
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
245
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
246
+ negative values to the attention scores corresponding to "discard" tokens.
247
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
248
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
249
+
250
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
251
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
252
+
253
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
254
+ above. This bias will be added to the cross-attention scores.
255
+ return_dict (`bool`, *optional*, defaults to `True`):
256
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
257
+ tuple.
258
+
259
+ Returns:
260
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
261
+ `tuple` where the first element is the sample tensor.
262
+ """
263
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
264
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
265
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
266
+ # expects mask of shape:
267
+ # [batch, key_tokens]
268
+ # adds singleton query_tokens dimension:
269
+ # [batch, 1, key_tokens]
270
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
271
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
272
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
273
+ if attention_mask is not None and attention_mask.ndim == 2:
274
+ # assume that mask is expressed as:
275
+ # (1 = keep, 0 = discard)
276
+ # convert mask into a bias that can be added to attention scores:
277
+ # (keep = +0, discard = -10000.0)
278
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
279
+ attention_mask = attention_mask.unsqueeze(1)
280
+
281
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
282
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
283
+ encoder_attention_mask = (
284
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
285
+ ) * -10000.0
286
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
287
+
288
+ # Retrieve lora scale.
289
+ lora_scale = (
290
+ cross_attention_kwargs.get("scale", 1.0)
291
+ if cross_attention_kwargs is not None
292
+ else 1.0
293
+ )
294
+
295
+ # 1. Input
296
+ batch, _, height, width = hidden_states.shape
297
+ residual = hidden_states
298
+
299
+ hidden_states = self.norm(hidden_states)
300
+ if not self.use_linear_projection:
301
+ hidden_states = (
302
+ self.proj_in(hidden_states, scale=lora_scale)
303
+ if not USE_PEFT_BACKEND
304
+ else self.proj_in(hidden_states)
305
+ )
306
+ inner_dim = hidden_states.shape[1]
307
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
308
+ batch, height * width, inner_dim
309
+ )
310
+ else:
311
+ inner_dim = hidden_states.shape[1]
312
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
313
+ batch, height * width, inner_dim
314
+ )
315
+ hidden_states = (
316
+ self.proj_in(hidden_states, scale=lora_scale)
317
+ if not USE_PEFT_BACKEND
318
+ else self.proj_in(hidden_states)
319
+ )
320
+
321
+ # 2. Blocks
322
+ if self.caption_projection is not None:
323
+ batch_size = hidden_states.shape[0]
324
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
325
+ encoder_hidden_states = encoder_hidden_states.view(
326
+ batch_size, -1, hidden_states.shape[-1]
327
+ )
328
+
329
+ ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
330
+ for block in self.transformer_blocks:
331
+ if self.training and self.gradient_checkpointing:
332
+
333
+ def create_custom_forward(module, return_dict=None):
334
+ def custom_forward(*inputs):
335
+ if return_dict is not None:
336
+ return module(*inputs, return_dict=return_dict)
337
+ else:
338
+ return module(*inputs)
339
+
340
+ return custom_forward
341
+
342
+ ckpt_kwargs: Dict[str, Any] = (
343
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
344
+ )
345
+ hidden_states = torch.utils.checkpoint.checkpoint(
346
+ create_custom_forward(block),
347
+ hidden_states,
348
+ attention_mask,
349
+ encoder_hidden_states,
350
+ encoder_attention_mask,
351
+ timestep,
352
+ cross_attention_kwargs,
353
+ class_labels,
354
+ **ckpt_kwargs,
355
+ )
356
+ else:
357
+ hidden_states = block(
358
+ hidden_states,
359
+ attention_mask=attention_mask,
360
+ encoder_hidden_states=encoder_hidden_states,
361
+ encoder_attention_mask=encoder_attention_mask,
362
+ timestep=timestep,
363
+ cross_attention_kwargs=cross_attention_kwargs,
364
+ class_labels=class_labels,
365
+ )
366
+
367
+ # 3. Output
368
+ if self.is_input_continuous:
369
+ if not self.use_linear_projection:
370
+ hidden_states = (
371
+ hidden_states.reshape(batch, height, width, inner_dim)
372
+ .permute(0, 3, 1, 2)
373
+ .contiguous()
374
+ )
375
+ hidden_states = (
376
+ self.proj_out(hidden_states, scale=lora_scale)
377
+ if not USE_PEFT_BACKEND
378
+ else self.proj_out(hidden_states)
379
+ )
380
+ else:
381
+ hidden_states = (
382
+ self.proj_out(hidden_states, scale=lora_scale)
383
+ if not USE_PEFT_BACKEND
384
+ else self.proj_out(hidden_states)
385
+ )
386
+ hidden_states = (
387
+ hidden_states.reshape(batch, height, width, inner_dim)
388
+ .permute(0, 3, 1, 2)
389
+ .contiguous()
390
+ )
391
+
392
+ output = hidden_states + residual
393
+ if not return_dict:
394
+ return (output, ref_feature)
395
+
396
+ return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
src/models/transformer_3d.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Dict
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models import ModelMixin
7
+ from diffusers.utils import BaseOutput
8
+ from diffusers.utils.import_utils import is_xformers_available
9
+ from einops import rearrange, repeat
10
+ from torch import nn
11
+
12
+ from .attention import TemporalBasicTransformerBlock, ResidualTemporalBasicTransformerBlock
13
+
14
+
15
+ @dataclass
16
+ class Transformer3DModelOutput(BaseOutput):
17
+ sample: torch.FloatTensor
18
+
19
+
20
+ if is_xformers_available():
21
+ import xformers
22
+ import xformers.ops
23
+ else:
24
+ xformers = None
25
+
26
+
27
+ class Transformer3DModel(ModelMixin, ConfigMixin):
28
+ _supports_gradient_checkpointing = True
29
+
30
+ @register_to_config
31
+ def __init__(
32
+ self,
33
+ num_attention_heads: int = 16,
34
+ attention_head_dim: int = 88,
35
+ in_channels: Optional[int] = None,
36
+ num_layers: int = 1,
37
+ dropout: float = 0.0,
38
+ norm_num_groups: int = 32,
39
+ cross_attention_dim: Optional[int] = None,
40
+ attention_bias: bool = False,
41
+ activation_fn: str = "geglu",
42
+ num_embeds_ada_norm: Optional[int] = None,
43
+ use_linear_projection: bool = False,
44
+ only_cross_attention: bool = False,
45
+ upcast_attention: bool = False,
46
+ unet_use_cross_frame_attention=None,
47
+ unet_use_temporal_attention=None,
48
+ ):
49
+ super().__init__()
50
+ self.use_linear_projection = use_linear_projection
51
+ self.num_attention_heads = num_attention_heads
52
+ self.attention_head_dim = attention_head_dim
53
+ inner_dim = num_attention_heads * attention_head_dim
54
+
55
+ # Define input layers
56
+ self.in_channels = in_channels
57
+
58
+ self.norm = torch.nn.GroupNorm(
59
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60
+ )
61
+ if use_linear_projection:
62
+ self.proj_in = nn.Linear(in_channels, inner_dim)
63
+ else:
64
+ self.proj_in = nn.Conv2d(
65
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66
+ )
67
+
68
+ # Define transformers blocks
69
+ self.transformer_blocks = nn.ModuleList(
70
+ [
71
+ TemporalBasicTransformerBlock(
72
+ inner_dim,
73
+ num_attention_heads,
74
+ attention_head_dim,
75
+ dropout=dropout,
76
+ cross_attention_dim=cross_attention_dim,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ attention_bias=attention_bias,
80
+ only_cross_attention=only_cross_attention,
81
+ upcast_attention=upcast_attention,
82
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83
+ unet_use_temporal_attention=unet_use_temporal_attention,
84
+ )
85
+ for d in range(num_layers)
86
+ ]
87
+ )
88
+
89
+ # 4. Define output layers
90
+ if use_linear_projection:
91
+ self.proj_out = nn.Linear(in_channels, inner_dim)
92
+ else:
93
+ self.proj_out = nn.Conv2d(
94
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95
+ )
96
+
97
+ self.gradient_checkpointing = False
98
+
99
+ def _set_gradient_checkpointing(self, module, value=False):
100
+ if hasattr(module, "gradient_checkpointing"):
101
+ module.gradient_checkpointing = value
102
+
103
+ def forward(
104
+ self,
105
+ hidden_states,
106
+ encoder_hidden_states=None,
107
+ timestep=None,
108
+ return_dict: bool = True,
109
+ ):
110
+ # Input
111
+ assert (
112
+ hidden_states.dim() == 5
113
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114
+ video_length = hidden_states.shape[2]
115
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
117
+ encoder_hidden_states = repeat(
118
+ encoder_hidden_states, "b n c -> (b f) n c", f=video_length
119
+ )
120
+
121
+ batch, channel, height, weight = hidden_states.shape
122
+ residual = hidden_states
123
+
124
+ hidden_states = self.norm(hidden_states)
125
+ if not self.use_linear_projection:
126
+ hidden_states = self.proj_in(hidden_states)
127
+ inner_dim = hidden_states.shape[1]
128
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
129
+ batch, height * weight, inner_dim
130
+ )
131
+ else:
132
+ inner_dim = hidden_states.shape[1]
133
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134
+ batch, height * weight, inner_dim
135
+ )
136
+ hidden_states = self.proj_in(hidden_states)
137
+
138
+ # Blocks
139
+ for i, block in enumerate(self.transformer_blocks):
140
+ hidden_states = block(
141
+ hidden_states,
142
+ encoder_hidden_states=encoder_hidden_states,
143
+ timestep=timestep,
144
+ video_length=video_length,
145
+ )
146
+
147
+ # Output
148
+ if not self.use_linear_projection:
149
+ hidden_states = (
150
+ hidden_states.reshape(batch, height, weight, inner_dim)
151
+ .permute(0, 3, 1, 2)
152
+ .contiguous()
153
+ )
154
+ hidden_states = self.proj_out(hidden_states)
155
+ else:
156
+ hidden_states = self.proj_out(hidden_states)
157
+ hidden_states = (
158
+ hidden_states.reshape(batch, height, weight, inner_dim)
159
+ .permute(0, 3, 1, 2)
160
+ .contiguous()
161
+ )
162
+
163
+ output = hidden_states + residual
164
+
165
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166
+ if not return_dict:
167
+ return (output,)
168
+
169
+ return Transformer3DModelOutput(sample=output)
src/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1074 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.activations import get_activation
8
+ from diffusers.models.attention_processor import Attention
9
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
10
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
11
+ from diffusers.utils import is_torch_version, logging
12
+ from diffusers.utils.torch_utils import apply_freeu
13
+ from torch import nn
14
+
15
+ from .transformer_2d import Transformer2DModel
16
+
17
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
+
19
+
20
+ def get_down_block(
21
+ down_block_type: str,
22
+ num_layers: int,
23
+ in_channels: int,
24
+ out_channels: int,
25
+ temb_channels: int,
26
+ add_downsample: bool,
27
+ resnet_eps: float,
28
+ resnet_act_fn: str,
29
+ transformer_layers_per_block: int = 1,
30
+ num_attention_heads: Optional[int] = None,
31
+ resnet_groups: Optional[int] = None,
32
+ cross_attention_dim: Optional[int] = None,
33
+ downsample_padding: Optional[int] = None,
34
+ dual_cross_attention: bool = False,
35
+ use_linear_projection: bool = False,
36
+ only_cross_attention: bool = False,
37
+ upcast_attention: bool = False,
38
+ resnet_time_scale_shift: str = "default",
39
+ attention_type: str = "default",
40
+ resnet_skip_time_act: bool = False,
41
+ resnet_out_scale_factor: float = 1.0,
42
+ cross_attention_norm: Optional[str] = None,
43
+ attention_head_dim: Optional[int] = None,
44
+ downsample_type: Optional[str] = None,
45
+ dropout: float = 0.0,
46
+ ):
47
+ # If attn head dim is not defined, we default it to the number of heads
48
+ if attention_head_dim is None:
49
+ logger.warn(
50
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
51
+ )
52
+ attention_head_dim = num_attention_heads
53
+
54
+ down_block_type = (
55
+ down_block_type[7:]
56
+ if down_block_type.startswith("UNetRes")
57
+ else down_block_type
58
+ )
59
+ if down_block_type == "DownBlock2D":
60
+ return DownBlock2D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ dropout=dropout,
66
+ add_downsample=add_downsample,
67
+ resnet_eps=resnet_eps,
68
+ resnet_act_fn=resnet_act_fn,
69
+ resnet_groups=resnet_groups,
70
+ downsample_padding=downsample_padding,
71
+ resnet_time_scale_shift=resnet_time_scale_shift,
72
+ )
73
+ elif down_block_type == "CrossAttnDownBlock2D":
74
+ if cross_attention_dim is None:
75
+ raise ValueError(
76
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
77
+ )
78
+ return CrossAttnDownBlock2D(
79
+ num_layers=num_layers,
80
+ transformer_layers_per_block=transformer_layers_per_block,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ dropout=dropout,
85
+ add_downsample=add_downsample,
86
+ resnet_eps=resnet_eps,
87
+ resnet_act_fn=resnet_act_fn,
88
+ resnet_groups=resnet_groups,
89
+ downsample_padding=downsample_padding,
90
+ cross_attention_dim=cross_attention_dim,
91
+ num_attention_heads=num_attention_heads,
92
+ dual_cross_attention=dual_cross_attention,
93
+ use_linear_projection=use_linear_projection,
94
+ only_cross_attention=only_cross_attention,
95
+ upcast_attention=upcast_attention,
96
+ resnet_time_scale_shift=resnet_time_scale_shift,
97
+ attention_type=attention_type,
98
+ )
99
+ raise ValueError(f"{down_block_type} does not exist.")
100
+
101
+
102
+ def get_up_block(
103
+ up_block_type: str,
104
+ num_layers: int,
105
+ in_channels: int,
106
+ out_channels: int,
107
+ prev_output_channel: int,
108
+ temb_channels: int,
109
+ add_upsample: bool,
110
+ resnet_eps: float,
111
+ resnet_act_fn: str,
112
+ resolution_idx: Optional[int] = None,
113
+ transformer_layers_per_block: int = 1,
114
+ num_attention_heads: Optional[int] = None,
115
+ resnet_groups: Optional[int] = None,
116
+ cross_attention_dim: Optional[int] = None,
117
+ dual_cross_attention: bool = False,
118
+ use_linear_projection: bool = False,
119
+ only_cross_attention: bool = False,
120
+ upcast_attention: bool = False,
121
+ resnet_time_scale_shift: str = "default",
122
+ attention_type: str = "default",
123
+ resnet_skip_time_act: bool = False,
124
+ resnet_out_scale_factor: float = 1.0,
125
+ cross_attention_norm: Optional[str] = None,
126
+ attention_head_dim: Optional[int] = None,
127
+ upsample_type: Optional[str] = None,
128
+ dropout: float = 0.0,
129
+ ) -> nn.Module:
130
+ # If attn head dim is not defined, we default it to the number of heads
131
+ if attention_head_dim is None:
132
+ logger.warn(
133
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
134
+ )
135
+ attention_head_dim = num_attention_heads
136
+
137
+ up_block_type = (
138
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
139
+ )
140
+ if up_block_type == "UpBlock2D":
141
+ return UpBlock2D(
142
+ num_layers=num_layers,
143
+ in_channels=in_channels,
144
+ out_channels=out_channels,
145
+ prev_output_channel=prev_output_channel,
146
+ temb_channels=temb_channels,
147
+ resolution_idx=resolution_idx,
148
+ dropout=dropout,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ resnet_time_scale_shift=resnet_time_scale_shift,
154
+ )
155
+ elif up_block_type == "CrossAttnUpBlock2D":
156
+ if cross_attention_dim is None:
157
+ raise ValueError(
158
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
159
+ )
160
+ return CrossAttnUpBlock2D(
161
+ num_layers=num_layers,
162
+ transformer_layers_per_block=transformer_layers_per_block,
163
+ in_channels=in_channels,
164
+ out_channels=out_channels,
165
+ prev_output_channel=prev_output_channel,
166
+ temb_channels=temb_channels,
167
+ resolution_idx=resolution_idx,
168
+ dropout=dropout,
169
+ add_upsample=add_upsample,
170
+ resnet_eps=resnet_eps,
171
+ resnet_act_fn=resnet_act_fn,
172
+ resnet_groups=resnet_groups,
173
+ cross_attention_dim=cross_attention_dim,
174
+ num_attention_heads=num_attention_heads,
175
+ dual_cross_attention=dual_cross_attention,
176
+ use_linear_projection=use_linear_projection,
177
+ only_cross_attention=only_cross_attention,
178
+ upcast_attention=upcast_attention,
179
+ resnet_time_scale_shift=resnet_time_scale_shift,
180
+ attention_type=attention_type,
181
+ )
182
+
183
+ raise ValueError(f"{up_block_type} does not exist.")
184
+
185
+
186
+ class AutoencoderTinyBlock(nn.Module):
187
+ """
188
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
189
+ blocks.
190
+
191
+ Args:
192
+ in_channels (`int`): The number of input channels.
193
+ out_channels (`int`): The number of output channels.
194
+ act_fn (`str`):
195
+ ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
196
+
197
+ Returns:
198
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
199
+ `out_channels`.
200
+ """
201
+
202
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
203
+ super().__init__()
204
+ act_fn = get_activation(act_fn)
205
+ self.conv = nn.Sequential(
206
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
207
+ act_fn,
208
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
209
+ act_fn,
210
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
211
+ )
212
+ self.skip = (
213
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
214
+ if in_channels != out_channels
215
+ else nn.Identity()
216
+ )
217
+ self.fuse = nn.ReLU()
218
+
219
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
220
+ return self.fuse(self.conv(x) + self.skip(x))
221
+
222
+
223
+ class UNetMidBlock2D(nn.Module):
224
+ """
225
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
226
+
227
+ Args:
228
+ in_channels (`int`): The number of input channels.
229
+ temb_channels (`int`): The number of temporal embedding channels.
230
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
231
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
232
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
233
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
234
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
235
+ model on tasks with long-range temporal dependencies.
236
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
237
+ resnet_groups (`int`, *optional*, defaults to 32):
238
+ The number of groups to use in the group normalization layers of the resnet blocks.
239
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
240
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
241
+ Whether to use pre-normalization for the resnet blocks.
242
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
243
+ attention_head_dim (`int`, *optional*, defaults to 1):
244
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
245
+ the number of input channels.
246
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
247
+
248
+ Returns:
249
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
250
+ in_channels, height, width)`.
251
+
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ in_channels: int,
257
+ temb_channels: int,
258
+ dropout: float = 0.0,
259
+ num_layers: int = 1,
260
+ resnet_eps: float = 1e-6,
261
+ resnet_time_scale_shift: str = "default", # default, spatial
262
+ resnet_act_fn: str = "swish",
263
+ resnet_groups: int = 32,
264
+ attn_groups: Optional[int] = None,
265
+ resnet_pre_norm: bool = True,
266
+ add_attention: bool = True,
267
+ attention_head_dim: int = 1,
268
+ output_scale_factor: float = 1.0,
269
+ ):
270
+ super().__init__()
271
+ resnet_groups = (
272
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
273
+ )
274
+ self.add_attention = add_attention
275
+
276
+ if attn_groups is None:
277
+ attn_groups = (
278
+ resnet_groups if resnet_time_scale_shift == "default" else None
279
+ )
280
+
281
+ # there is always at least one resnet
282
+ resnets = [
283
+ ResnetBlock2D(
284
+ in_channels=in_channels,
285
+ out_channels=in_channels,
286
+ temb_channels=temb_channels,
287
+ eps=resnet_eps,
288
+ groups=resnet_groups,
289
+ dropout=dropout,
290
+ time_embedding_norm=resnet_time_scale_shift,
291
+ non_linearity=resnet_act_fn,
292
+ output_scale_factor=output_scale_factor,
293
+ pre_norm=resnet_pre_norm,
294
+ )
295
+ ]
296
+ attentions = []
297
+
298
+ if attention_head_dim is None:
299
+ logger.warn(
300
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
301
+ )
302
+ attention_head_dim = in_channels
303
+
304
+ for _ in range(num_layers):
305
+ if self.add_attention:
306
+ attentions.append(
307
+ Attention(
308
+ in_channels,
309
+ heads=in_channels // attention_head_dim,
310
+ dim_head=attention_head_dim,
311
+ rescale_output_factor=output_scale_factor,
312
+ eps=resnet_eps,
313
+ norm_num_groups=attn_groups,
314
+ spatial_norm_dim=temb_channels
315
+ if resnet_time_scale_shift == "spatial"
316
+ else None,
317
+ residual_connection=True,
318
+ bias=True,
319
+ upcast_softmax=True,
320
+ _from_deprecated_attn_block=True,
321
+ )
322
+ )
323
+ else:
324
+ attentions.append(None)
325
+
326
+ resnets.append(
327
+ ResnetBlock2D(
328
+ in_channels=in_channels,
329
+ out_channels=in_channels,
330
+ temb_channels=temb_channels,
331
+ eps=resnet_eps,
332
+ groups=resnet_groups,
333
+ dropout=dropout,
334
+ time_embedding_norm=resnet_time_scale_shift,
335
+ non_linearity=resnet_act_fn,
336
+ output_scale_factor=output_scale_factor,
337
+ pre_norm=resnet_pre_norm,
338
+ )
339
+ )
340
+
341
+ self.attentions = nn.ModuleList(attentions)
342
+ self.resnets = nn.ModuleList(resnets)
343
+
344
+ def forward(
345
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
346
+ ) -> torch.FloatTensor:
347
+ hidden_states = self.resnets[0](hidden_states, temb)
348
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
349
+ if attn is not None:
350
+ hidden_states = attn(hidden_states, temb=temb)
351
+ hidden_states = resnet(hidden_states, temb)
352
+
353
+ return hidden_states
354
+
355
+
356
+ class UNetMidBlock2DCrossAttn(nn.Module):
357
+ def __init__(
358
+ self,
359
+ in_channels: int,
360
+ temb_channels: int,
361
+ dropout: float = 0.0,
362
+ num_layers: int = 1,
363
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
364
+ resnet_eps: float = 1e-6,
365
+ resnet_time_scale_shift: str = "default",
366
+ resnet_act_fn: str = "swish",
367
+ resnet_groups: int = 32,
368
+ resnet_pre_norm: bool = True,
369
+ num_attention_heads: int = 1,
370
+ output_scale_factor: float = 1.0,
371
+ cross_attention_dim: int = 1280,
372
+ dual_cross_attention: bool = False,
373
+ use_linear_projection: bool = False,
374
+ upcast_attention: bool = False,
375
+ attention_type: str = "default",
376
+ ):
377
+ super().__init__()
378
+
379
+ self.has_cross_attention = True
380
+ self.num_attention_heads = num_attention_heads
381
+ resnet_groups = (
382
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
383
+ )
384
+
385
+ # support for variable transformer layers per block
386
+ if isinstance(transformer_layers_per_block, int):
387
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
388
+
389
+ # there is always at least one resnet
390
+ resnets = [
391
+ ResnetBlock2D(
392
+ in_channels=in_channels,
393
+ out_channels=in_channels,
394
+ temb_channels=temb_channels,
395
+ eps=resnet_eps,
396
+ groups=resnet_groups,
397
+ dropout=dropout,
398
+ time_embedding_norm=resnet_time_scale_shift,
399
+ non_linearity=resnet_act_fn,
400
+ output_scale_factor=output_scale_factor,
401
+ pre_norm=resnet_pre_norm,
402
+ )
403
+ ]
404
+ attentions = []
405
+
406
+ for i in range(num_layers):
407
+ if not dual_cross_attention:
408
+ attentions.append(
409
+ Transformer2DModel(
410
+ num_attention_heads,
411
+ in_channels // num_attention_heads,
412
+ in_channels=in_channels,
413
+ num_layers=transformer_layers_per_block[i],
414
+ cross_attention_dim=cross_attention_dim,
415
+ norm_num_groups=resnet_groups,
416
+ use_linear_projection=use_linear_projection,
417
+ upcast_attention=upcast_attention,
418
+ attention_type=attention_type,
419
+ )
420
+ )
421
+ else:
422
+ attentions.append(
423
+ DualTransformer2DModel(
424
+ num_attention_heads,
425
+ in_channels // num_attention_heads,
426
+ in_channels=in_channels,
427
+ num_layers=1,
428
+ cross_attention_dim=cross_attention_dim,
429
+ norm_num_groups=resnet_groups,
430
+ )
431
+ )
432
+ resnets.append(
433
+ ResnetBlock2D(
434
+ in_channels=in_channels,
435
+ out_channels=in_channels,
436
+ temb_channels=temb_channels,
437
+ eps=resnet_eps,
438
+ groups=resnet_groups,
439
+ dropout=dropout,
440
+ time_embedding_norm=resnet_time_scale_shift,
441
+ non_linearity=resnet_act_fn,
442
+ output_scale_factor=output_scale_factor,
443
+ pre_norm=resnet_pre_norm,
444
+ )
445
+ )
446
+
447
+ self.attentions = nn.ModuleList(attentions)
448
+ self.resnets = nn.ModuleList(resnets)
449
+
450
+ self.gradient_checkpointing = False
451
+
452
+ def forward(
453
+ self,
454
+ hidden_states: torch.FloatTensor,
455
+ temb: Optional[torch.FloatTensor] = None,
456
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
457
+ attention_mask: Optional[torch.FloatTensor] = None,
458
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
459
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
460
+ ) -> torch.FloatTensor:
461
+ lora_scale = (
462
+ cross_attention_kwargs.get("scale", 1.0)
463
+ if cross_attention_kwargs is not None
464
+ else 1.0
465
+ )
466
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
467
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
468
+ if self.training and self.gradient_checkpointing:
469
+
470
+ def create_custom_forward(module, return_dict=None):
471
+ def custom_forward(*inputs):
472
+ if return_dict is not None:
473
+ return module(*inputs, return_dict=return_dict)
474
+ else:
475
+ return module(*inputs)
476
+
477
+ return custom_forward
478
+
479
+ ckpt_kwargs: Dict[str, Any] = (
480
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
481
+ )
482
+ hidden_states, ref_feature = attn(
483
+ hidden_states,
484
+ encoder_hidden_states=encoder_hidden_states,
485
+ cross_attention_kwargs=cross_attention_kwargs,
486
+ attention_mask=attention_mask,
487
+ encoder_attention_mask=encoder_attention_mask,
488
+ return_dict=False,
489
+ )
490
+ hidden_states = torch.utils.checkpoint.checkpoint(
491
+ create_custom_forward(resnet),
492
+ hidden_states,
493
+ temb,
494
+ **ckpt_kwargs,
495
+ )
496
+ else:
497
+ hidden_states, ref_feature = attn(
498
+ hidden_states,
499
+ encoder_hidden_states=encoder_hidden_states,
500
+ cross_attention_kwargs=cross_attention_kwargs,
501
+ attention_mask=attention_mask,
502
+ encoder_attention_mask=encoder_attention_mask,
503
+ return_dict=False,
504
+ )
505
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
506
+
507
+ return hidden_states
508
+
509
+
510
+ class CrossAttnDownBlock2D(nn.Module):
511
+ def __init__(
512
+ self,
513
+ in_channels: int,
514
+ out_channels: int,
515
+ temb_channels: int,
516
+ dropout: float = 0.0,
517
+ num_layers: int = 1,
518
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
519
+ resnet_eps: float = 1e-6,
520
+ resnet_time_scale_shift: str = "default",
521
+ resnet_act_fn: str = "swish",
522
+ resnet_groups: int = 32,
523
+ resnet_pre_norm: bool = True,
524
+ num_attention_heads: int = 1,
525
+ cross_attention_dim: int = 1280,
526
+ output_scale_factor: float = 1.0,
527
+ downsample_padding: int = 1,
528
+ add_downsample: bool = True,
529
+ dual_cross_attention: bool = False,
530
+ use_linear_projection: bool = False,
531
+ only_cross_attention: bool = False,
532
+ upcast_attention: bool = False,
533
+ attention_type: str = "default",
534
+ ):
535
+ super().__init__()
536
+ resnets = []
537
+ attentions = []
538
+
539
+ self.has_cross_attention = True
540
+ self.num_attention_heads = num_attention_heads
541
+ if isinstance(transformer_layers_per_block, int):
542
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
543
+
544
+ for i in range(num_layers):
545
+ in_channels = in_channels if i == 0 else out_channels
546
+ resnets.append(
547
+ ResnetBlock2D(
548
+ in_channels=in_channels,
549
+ out_channels=out_channels,
550
+ temb_channels=temb_channels,
551
+ eps=resnet_eps,
552
+ groups=resnet_groups,
553
+ dropout=dropout,
554
+ time_embedding_norm=resnet_time_scale_shift,
555
+ non_linearity=resnet_act_fn,
556
+ output_scale_factor=output_scale_factor,
557
+ pre_norm=resnet_pre_norm,
558
+ )
559
+ )
560
+ if not dual_cross_attention:
561
+ attentions.append(
562
+ Transformer2DModel(
563
+ num_attention_heads,
564
+ out_channels // num_attention_heads,
565
+ in_channels=out_channels,
566
+ num_layers=transformer_layers_per_block[i],
567
+ cross_attention_dim=cross_attention_dim,
568
+ norm_num_groups=resnet_groups,
569
+ use_linear_projection=use_linear_projection,
570
+ only_cross_attention=only_cross_attention,
571
+ upcast_attention=upcast_attention,
572
+ attention_type=attention_type,
573
+ )
574
+ )
575
+ else:
576
+ attentions.append(
577
+ DualTransformer2DModel(
578
+ num_attention_heads,
579
+ out_channels // num_attention_heads,
580
+ in_channels=out_channels,
581
+ num_layers=1,
582
+ cross_attention_dim=cross_attention_dim,
583
+ norm_num_groups=resnet_groups,
584
+ )
585
+ )
586
+ self.attentions = nn.ModuleList(attentions)
587
+ self.resnets = nn.ModuleList(resnets)
588
+
589
+ if add_downsample:
590
+ self.downsamplers = nn.ModuleList(
591
+ [
592
+ Downsample2D(
593
+ out_channels,
594
+ use_conv=True,
595
+ out_channels=out_channels,
596
+ padding=downsample_padding,
597
+ name="op",
598
+ )
599
+ ]
600
+ )
601
+ else:
602
+ self.downsamplers = None
603
+
604
+ self.gradient_checkpointing = False
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.FloatTensor,
609
+ temb: Optional[torch.FloatTensor] = None,
610
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
+ attention_mask: Optional[torch.FloatTensor] = None,
612
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
614
+ additional_residuals: Optional[torch.FloatTensor] = None,
615
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
616
+ output_states = ()
617
+
618
+ lora_scale = (
619
+ cross_attention_kwargs.get("scale", 1.0)
620
+ if cross_attention_kwargs is not None
621
+ else 1.0
622
+ )
623
+
624
+ blocks = list(zip(self.resnets, self.attentions))
625
+
626
+ for i, (resnet, attn) in enumerate(blocks):
627
+ if self.training and self.gradient_checkpointing:
628
+
629
+ def create_custom_forward(module, return_dict=None):
630
+ def custom_forward(*inputs):
631
+ if return_dict is not None:
632
+ return module(*inputs, return_dict=return_dict)
633
+ else:
634
+ return module(*inputs)
635
+
636
+ return custom_forward
637
+
638
+ ckpt_kwargs: Dict[str, Any] = (
639
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
640
+ )
641
+ hidden_states = torch.utils.checkpoint.checkpoint(
642
+ create_custom_forward(resnet),
643
+ hidden_states,
644
+ temb,
645
+ **ckpt_kwargs,
646
+ )
647
+ hidden_states, ref_feature = attn(
648
+ hidden_states,
649
+ encoder_hidden_states=encoder_hidden_states,
650
+ cross_attention_kwargs=cross_attention_kwargs,
651
+ attention_mask=attention_mask,
652
+ encoder_attention_mask=encoder_attention_mask,
653
+ return_dict=False,
654
+ )
655
+ else:
656
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
657
+ hidden_states, ref_feature = attn(
658
+ hidden_states,
659
+ encoder_hidden_states=encoder_hidden_states,
660
+ cross_attention_kwargs=cross_attention_kwargs,
661
+ attention_mask=attention_mask,
662
+ encoder_attention_mask=encoder_attention_mask,
663
+ return_dict=False,
664
+ )
665
+
666
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
667
+ if i == len(blocks) - 1 and additional_residuals is not None:
668
+ hidden_states = hidden_states + additional_residuals
669
+
670
+ output_states = output_states + (hidden_states,)
671
+
672
+ if self.downsamplers is not None:
673
+ for downsampler in self.downsamplers:
674
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
675
+
676
+ output_states = output_states + (hidden_states,)
677
+
678
+ return hidden_states, output_states
679
+
680
+
681
+ class DownBlock2D(nn.Module):
682
+ def __init__(
683
+ self,
684
+ in_channels: int,
685
+ out_channels: int,
686
+ temb_channels: int,
687
+ dropout: float = 0.0,
688
+ num_layers: int = 1,
689
+ resnet_eps: float = 1e-6,
690
+ resnet_time_scale_shift: str = "default",
691
+ resnet_act_fn: str = "swish",
692
+ resnet_groups: int = 32,
693
+ resnet_pre_norm: bool = True,
694
+ output_scale_factor: float = 1.0,
695
+ add_downsample: bool = True,
696
+ downsample_padding: int = 1,
697
+ ):
698
+ super().__init__()
699
+ resnets = []
700
+
701
+ for i in range(num_layers):
702
+ in_channels = in_channels if i == 0 else out_channels
703
+ resnets.append(
704
+ ResnetBlock2D(
705
+ in_channels=in_channels,
706
+ out_channels=out_channels,
707
+ temb_channels=temb_channels,
708
+ eps=resnet_eps,
709
+ groups=resnet_groups,
710
+ dropout=dropout,
711
+ time_embedding_norm=resnet_time_scale_shift,
712
+ non_linearity=resnet_act_fn,
713
+ output_scale_factor=output_scale_factor,
714
+ pre_norm=resnet_pre_norm,
715
+ )
716
+ )
717
+
718
+ self.resnets = nn.ModuleList(resnets)
719
+
720
+ if add_downsample:
721
+ self.downsamplers = nn.ModuleList(
722
+ [
723
+ Downsample2D(
724
+ out_channels,
725
+ use_conv=True,
726
+ out_channels=out_channels,
727
+ padding=downsample_padding,
728
+ name="op",
729
+ )
730
+ ]
731
+ )
732
+ else:
733
+ self.downsamplers = None
734
+
735
+ self.gradient_checkpointing = False
736
+
737
+ def forward(
738
+ self,
739
+ hidden_states: torch.FloatTensor,
740
+ temb: Optional[torch.FloatTensor] = None,
741
+ scale: float = 1.0,
742
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
743
+ output_states = ()
744
+
745
+ for resnet in self.resnets:
746
+ if self.training and self.gradient_checkpointing:
747
+
748
+ def create_custom_forward(module):
749
+ def custom_forward(*inputs):
750
+ return module(*inputs)
751
+
752
+ return custom_forward
753
+
754
+ if is_torch_version(">=", "1.11.0"):
755
+ hidden_states = torch.utils.checkpoint.checkpoint(
756
+ create_custom_forward(resnet),
757
+ hidden_states,
758
+ temb,
759
+ use_reentrant=False,
760
+ )
761
+ else:
762
+ hidden_states = torch.utils.checkpoint.checkpoint(
763
+ create_custom_forward(resnet), hidden_states, temb
764
+ )
765
+ else:
766
+ hidden_states = resnet(hidden_states, temb, scale=scale)
767
+
768
+ output_states = output_states + (hidden_states,)
769
+
770
+ if self.downsamplers is not None:
771
+ for downsampler in self.downsamplers:
772
+ hidden_states = downsampler(hidden_states, scale=scale)
773
+
774
+ output_states = output_states + (hidden_states,)
775
+
776
+ return hidden_states, output_states
777
+
778
+
779
+ class CrossAttnUpBlock2D(nn.Module):
780
+ def __init__(
781
+ self,
782
+ in_channels: int,
783
+ out_channels: int,
784
+ prev_output_channel: int,
785
+ temb_channels: int,
786
+ resolution_idx: Optional[int] = None,
787
+ dropout: float = 0.0,
788
+ num_layers: int = 1,
789
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
790
+ resnet_eps: float = 1e-6,
791
+ resnet_time_scale_shift: str = "default",
792
+ resnet_act_fn: str = "swish",
793
+ resnet_groups: int = 32,
794
+ resnet_pre_norm: bool = True,
795
+ num_attention_heads: int = 1,
796
+ cross_attention_dim: int = 1280,
797
+ output_scale_factor: float = 1.0,
798
+ add_upsample: bool = True,
799
+ dual_cross_attention: bool = False,
800
+ use_linear_projection: bool = False,
801
+ only_cross_attention: bool = False,
802
+ upcast_attention: bool = False,
803
+ attention_type: str = "default",
804
+ ):
805
+ super().__init__()
806
+ resnets = []
807
+ attentions = []
808
+
809
+ self.has_cross_attention = True
810
+ self.num_attention_heads = num_attention_heads
811
+
812
+ if isinstance(transformer_layers_per_block, int):
813
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
814
+
815
+ for i in range(num_layers):
816
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
817
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
818
+
819
+ resnets.append(
820
+ ResnetBlock2D(
821
+ in_channels=resnet_in_channels + res_skip_channels,
822
+ out_channels=out_channels,
823
+ temb_channels=temb_channels,
824
+ eps=resnet_eps,
825
+ groups=resnet_groups,
826
+ dropout=dropout,
827
+ time_embedding_norm=resnet_time_scale_shift,
828
+ non_linearity=resnet_act_fn,
829
+ output_scale_factor=output_scale_factor,
830
+ pre_norm=resnet_pre_norm,
831
+ )
832
+ )
833
+ if not dual_cross_attention:
834
+ attentions.append(
835
+ Transformer2DModel(
836
+ num_attention_heads,
837
+ out_channels // num_attention_heads,
838
+ in_channels=out_channels,
839
+ num_layers=transformer_layers_per_block[i],
840
+ cross_attention_dim=cross_attention_dim,
841
+ norm_num_groups=resnet_groups,
842
+ use_linear_projection=use_linear_projection,
843
+ only_cross_attention=only_cross_attention,
844
+ upcast_attention=upcast_attention,
845
+ attention_type=attention_type,
846
+ )
847
+ )
848
+ else:
849
+ attentions.append(
850
+ DualTransformer2DModel(
851
+ num_attention_heads,
852
+ out_channels // num_attention_heads,
853
+ in_channels=out_channels,
854
+ num_layers=1,
855
+ cross_attention_dim=cross_attention_dim,
856
+ norm_num_groups=resnet_groups,
857
+ )
858
+ )
859
+ self.attentions = nn.ModuleList(attentions)
860
+ self.resnets = nn.ModuleList(resnets)
861
+
862
+ if add_upsample:
863
+ self.upsamplers = nn.ModuleList(
864
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
865
+ )
866
+ else:
867
+ self.upsamplers = None
868
+
869
+ self.gradient_checkpointing = False
870
+ self.resolution_idx = resolution_idx
871
+
872
+ def forward(
873
+ self,
874
+ hidden_states: torch.FloatTensor,
875
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
876
+ temb: Optional[torch.FloatTensor] = None,
877
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
878
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
879
+ upsample_size: Optional[int] = None,
880
+ attention_mask: Optional[torch.FloatTensor] = None,
881
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
882
+ ) -> torch.FloatTensor:
883
+ lora_scale = (
884
+ cross_attention_kwargs.get("scale", 1.0)
885
+ if cross_attention_kwargs is not None
886
+ else 1.0
887
+ )
888
+ is_freeu_enabled = (
889
+ getattr(self, "s1", None)
890
+ and getattr(self, "s2", None)
891
+ and getattr(self, "b1", None)
892
+ and getattr(self, "b2", None)
893
+ )
894
+
895
+ for resnet, attn in zip(self.resnets, self.attentions):
896
+ # pop res hidden states
897
+ res_hidden_states = res_hidden_states_tuple[-1]
898
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
899
+
900
+ # FreeU: Only operate on the first two stages
901
+ if is_freeu_enabled:
902
+ hidden_states, res_hidden_states = apply_freeu(
903
+ self.resolution_idx,
904
+ hidden_states,
905
+ res_hidden_states,
906
+ s1=self.s1,
907
+ s2=self.s2,
908
+ b1=self.b1,
909
+ b2=self.b2,
910
+ )
911
+
912
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
913
+
914
+ if self.training and self.gradient_checkpointing:
915
+
916
+ def create_custom_forward(module, return_dict=None):
917
+ def custom_forward(*inputs):
918
+ if return_dict is not None:
919
+ return module(*inputs, return_dict=return_dict)
920
+ else:
921
+ return module(*inputs)
922
+
923
+ return custom_forward
924
+
925
+ ckpt_kwargs: Dict[str, Any] = (
926
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
927
+ )
928
+ hidden_states = torch.utils.checkpoint.checkpoint(
929
+ create_custom_forward(resnet),
930
+ hidden_states,
931
+ temb,
932
+ **ckpt_kwargs,
933
+ )
934
+ hidden_states, ref_feature = attn(
935
+ hidden_states,
936
+ encoder_hidden_states=encoder_hidden_states,
937
+ cross_attention_kwargs=cross_attention_kwargs,
938
+ attention_mask=attention_mask,
939
+ encoder_attention_mask=encoder_attention_mask,
940
+ return_dict=False,
941
+ )
942
+ else:
943
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
944
+ hidden_states, ref_feature = attn(
945
+ hidden_states,
946
+ encoder_hidden_states=encoder_hidden_states,
947
+ cross_attention_kwargs=cross_attention_kwargs,
948
+ attention_mask=attention_mask,
949
+ encoder_attention_mask=encoder_attention_mask,
950
+ return_dict=False,
951
+ )
952
+
953
+ if self.upsamplers is not None:
954
+ for upsampler in self.upsamplers:
955
+ hidden_states = upsampler(
956
+ hidden_states, upsample_size, scale=lora_scale
957
+ )
958
+
959
+ return hidden_states
960
+
961
+
962
+ class UpBlock2D(nn.Module):
963
+ def __init__(
964
+ self,
965
+ in_channels: int,
966
+ prev_output_channel: int,
967
+ out_channels: int,
968
+ temb_channels: int,
969
+ resolution_idx: Optional[int] = None,
970
+ dropout: float = 0.0,
971
+ num_layers: int = 1,
972
+ resnet_eps: float = 1e-6,
973
+ resnet_time_scale_shift: str = "default",
974
+ resnet_act_fn: str = "swish",
975
+ resnet_groups: int = 32,
976
+ resnet_pre_norm: bool = True,
977
+ output_scale_factor: float = 1.0,
978
+ add_upsample: bool = True,
979
+ ):
980
+ super().__init__()
981
+ resnets = []
982
+
983
+ for i in range(num_layers):
984
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
985
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
986
+
987
+ resnets.append(
988
+ ResnetBlock2D(
989
+ in_channels=resnet_in_channels + res_skip_channels,
990
+ out_channels=out_channels,
991
+ temb_channels=temb_channels,
992
+ eps=resnet_eps,
993
+ groups=resnet_groups,
994
+ dropout=dropout,
995
+ time_embedding_norm=resnet_time_scale_shift,
996
+ non_linearity=resnet_act_fn,
997
+ output_scale_factor=output_scale_factor,
998
+ pre_norm=resnet_pre_norm,
999
+ )
1000
+ )
1001
+
1002
+ self.resnets = nn.ModuleList(resnets)
1003
+
1004
+ if add_upsample:
1005
+ self.upsamplers = nn.ModuleList(
1006
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1007
+ )
1008
+ else:
1009
+ self.upsamplers = None
1010
+
1011
+ self.gradient_checkpointing = False
1012
+ self.resolution_idx = resolution_idx
1013
+
1014
+ def forward(
1015
+ self,
1016
+ hidden_states: torch.FloatTensor,
1017
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1018
+ temb: Optional[torch.FloatTensor] = None,
1019
+ upsample_size: Optional[int] = None,
1020
+ scale: float = 1.0,
1021
+ ) -> torch.FloatTensor:
1022
+ is_freeu_enabled = (
1023
+ getattr(self, "s1", None)
1024
+ and getattr(self, "s2", None)
1025
+ and getattr(self, "b1", None)
1026
+ and getattr(self, "b2", None)
1027
+ )
1028
+
1029
+ for resnet in self.resnets:
1030
+ # pop res hidden states
1031
+ res_hidden_states = res_hidden_states_tuple[-1]
1032
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1033
+
1034
+ # FreeU: Only operate on the first two stages
1035
+ if is_freeu_enabled:
1036
+ hidden_states, res_hidden_states = apply_freeu(
1037
+ self.resolution_idx,
1038
+ hidden_states,
1039
+ res_hidden_states,
1040
+ s1=self.s1,
1041
+ s2=self.s2,
1042
+ b1=self.b1,
1043
+ b2=self.b2,
1044
+ )
1045
+
1046
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1047
+
1048
+ if self.training and self.gradient_checkpointing:
1049
+
1050
+ def create_custom_forward(module):
1051
+ def custom_forward(*inputs):
1052
+ return module(*inputs)
1053
+
1054
+ return custom_forward
1055
+
1056
+ if is_torch_version(">=", "1.11.0"):
1057
+ hidden_states = torch.utils.checkpoint.checkpoint(
1058
+ create_custom_forward(resnet),
1059
+ hidden_states,
1060
+ temb,
1061
+ use_reentrant=False,
1062
+ )
1063
+ else:
1064
+ hidden_states = torch.utils.checkpoint.checkpoint(
1065
+ create_custom_forward(resnet), hidden_states, temb
1066
+ )
1067
+ else:
1068
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1069
+
1070
+ if self.upsamplers is not None:
1071
+ for upsampler in self.upsamplers:
1072
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1073
+
1074
+ return hidden_states
src/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.models.activations import get_activation
11
+ from diffusers.models.attention_processor import (
12
+ ADDED_KV_ATTENTION_PROCESSORS,
13
+ CROSS_ATTENTION_PROCESSORS,
14
+ AttentionProcessor,
15
+ AttnAddedKVProcessor,
16
+ AttnProcessor,
17
+ )
18
+ from diffusers.models.embeddings import (
19
+ GaussianFourierProjection,
20
+ ImageHintTimeEmbedding,
21
+ ImageProjection,
22
+ ImageTimeEmbedding,
23
+ PositionNet,
24
+ TextImageProjection,
25
+ TextImageTimeEmbedding,
26
+ TextTimeEmbedding,
27
+ TimestepEmbedding,
28
+ Timesteps,
29
+ )
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+ from diffusers.utils import (
32
+ USE_PEFT_BACKEND,
33
+ BaseOutput,
34
+ deprecate,
35
+ logging,
36
+ scale_lora_layers,
37
+ unscale_lora_layers,
38
+ )
39
+
40
+ from .unet_2d_blocks import (
41
+ UNetMidBlock2D,
42
+ UNetMidBlock2DCrossAttn,
43
+ get_down_block,
44
+ get_up_block,
45
+ )
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ @dataclass
51
+ class UNet2DConditionOutput(BaseOutput):
52
+ """
53
+ The output of [`UNet2DConditionModel`].
54
+
55
+ Args:
56
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
57
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
58
+ """
59
+
60
+ sample: torch.FloatTensor = None
61
+ ref_features: Tuple[torch.FloatTensor] = None
62
+
63
+
64
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
65
+ r"""
66
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
67
+ shaped output.
68
+
69
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
70
+ for all models (such as downloading or saving).
71
+
72
+ Parameters:
73
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
74
+ Height and width of input/output sample.
75
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
76
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
77
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
78
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
79
+ Whether to flip the sin to cos in the time embedding.
80
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
81
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
82
+ The tuple of downsample blocks to use.
83
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
84
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
85
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
86
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
87
+ The tuple of upsample blocks to use.
88
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
89
+ Whether to include self-attention in the basic transformer blocks, see
90
+ [`~models.attention.BasicTransformerBlock`].
91
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
92
+ The tuple of output channels for each block.
93
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
94
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
95
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
96
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
97
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
98
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
99
+ If `None`, normalization and activation layers is skipped in post-processing.
100
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
101
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
102
+ The dimension of the cross attention features.
103
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
104
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
105
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
106
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
107
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
108
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
109
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
110
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
111
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
112
+ encoder_hid_dim (`int`, *optional*, defaults to None):
113
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
114
+ dimension to `cross_attention_dim`.
115
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
116
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
117
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
118
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
119
+ num_attention_heads (`int`, *optional*):
120
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
121
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
122
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
123
+ class_embed_type (`str`, *optional*, defaults to `None`):
124
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
125
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
126
+ addition_embed_type (`str`, *optional*, defaults to `None`):
127
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
128
+ "text". "text" will use the `TextTimeEmbedding` layer.
129
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
130
+ Dimension for the timestep embeddings.
131
+ num_class_embeds (`int`, *optional*, defaults to `None`):
132
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
133
+ class conditioning with `class_embed_type` equal to `None`.
134
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
135
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
136
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
137
+ An optional override for the dimension of the projected time embedding.
138
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
139
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
140
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
141
+ timestep_post_act (`str`, *optional*, defaults to `None`):
142
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
143
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
144
+ The dimension of `cond_proj` layer in the timestep embedding.
145
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
146
+ *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
147
+ *optional*): The dimension of the `class_labels` input when
148
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
149
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
150
+ embeddings with the class embeddings.
151
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
152
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
153
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
154
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
155
+ otherwise.
156
+ """
157
+
158
+ _supports_gradient_checkpointing = True
159
+
160
+ @register_to_config
161
+ def __init__(
162
+ self,
163
+ sample_size: Optional[int] = None,
164
+ in_channels: int = 4,
165
+ out_channels: int = 4,
166
+ center_input_sample: bool = False,
167
+ flip_sin_to_cos: bool = True,
168
+ freq_shift: int = 0,
169
+ down_block_types: Tuple[str] = (
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "CrossAttnDownBlock2D",
173
+ "DownBlock2D",
174
+ ),
175
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
176
+ up_block_types: Tuple[str] = (
177
+ "UpBlock2D",
178
+ "CrossAttnUpBlock2D",
179
+ "CrossAttnUpBlock2D",
180
+ "CrossAttnUpBlock2D",
181
+ ),
182
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
183
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
184
+ layers_per_block: Union[int, Tuple[int]] = 2,
185
+ downsample_padding: int = 1,
186
+ mid_block_scale_factor: float = 1,
187
+ dropout: float = 0.0,
188
+ act_fn: str = "silu",
189
+ norm_num_groups: Optional[int] = 32,
190
+ norm_eps: float = 1e-5,
191
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
192
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
193
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
194
+ encoder_hid_dim: Optional[int] = None,
195
+ encoder_hid_dim_type: Optional[str] = None,
196
+ attention_head_dim: Union[int, Tuple[int]] = 8,
197
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
198
+ dual_cross_attention: bool = False,
199
+ use_linear_projection: bool = False,
200
+ class_embed_type: Optional[str] = None,
201
+ addition_embed_type: Optional[str] = None,
202
+ addition_time_embed_dim: Optional[int] = None,
203
+ num_class_embeds: Optional[int] = None,
204
+ upcast_attention: bool = False,
205
+ resnet_time_scale_shift: str = "default",
206
+ resnet_skip_time_act: bool = False,
207
+ resnet_out_scale_factor: int = 1.0,
208
+ time_embedding_type: str = "positional",
209
+ time_embedding_dim: Optional[int] = None,
210
+ time_embedding_act_fn: Optional[str] = None,
211
+ timestep_post_act: Optional[str] = None,
212
+ time_cond_proj_dim: Optional[int] = None,
213
+ conv_in_kernel: int = 3,
214
+ conv_out_kernel: int = 3,
215
+ projection_class_embeddings_input_dim: Optional[int] = None,
216
+ attention_type: str = "default",
217
+ class_embeddings_concat: bool = False,
218
+ mid_block_only_cross_attention: Optional[bool] = None,
219
+ cross_attention_norm: Optional[str] = None,
220
+ addition_embed_type_num_heads=64,
221
+ ):
222
+ super().__init__()
223
+
224
+ self.sample_size = sample_size
225
+
226
+ if num_attention_heads is not None:
227
+ raise ValueError(
228
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
229
+ )
230
+
231
+ # If `num_attention_heads` is not defined (which is the case for most models)
232
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
233
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
234
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
235
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
236
+ # which is why we correct for the naming here.
237
+ num_attention_heads = num_attention_heads or attention_head_dim
238
+
239
+ # Check inputs
240
+ if len(down_block_types) != len(up_block_types):
241
+ raise ValueError(
242
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
243
+ )
244
+
245
+ if len(block_out_channels) != len(down_block_types):
246
+ raise ValueError(
247
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
248
+ )
249
+
250
+ if not isinstance(only_cross_attention, bool) and len(
251
+ only_cross_attention
252
+ ) != len(down_block_types):
253
+ raise ValueError(
254
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
255
+ )
256
+
257
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
258
+ down_block_types
259
+ ):
260
+ raise ValueError(
261
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
262
+ )
263
+
264
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
265
+ down_block_types
266
+ ):
267
+ raise ValueError(
268
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
269
+ )
270
+
271
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
272
+ down_block_types
273
+ ):
274
+ raise ValueError(
275
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
276
+ )
277
+
278
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
279
+ down_block_types
280
+ ):
281
+ raise ValueError(
282
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
283
+ )
284
+ if (
285
+ isinstance(transformer_layers_per_block, list)
286
+ and reverse_transformer_layers_per_block is None
287
+ ):
288
+ for layer_number_per_block in transformer_layers_per_block:
289
+ if isinstance(layer_number_per_block, list):
290
+ raise ValueError(
291
+ "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
292
+ )
293
+
294
+ # input
295
+ conv_in_padding = (conv_in_kernel - 1) // 2
296
+ self.conv_in = nn.Conv2d(
297
+ in_channels,
298
+ block_out_channels[0],
299
+ kernel_size=conv_in_kernel,
300
+ padding=conv_in_padding,
301
+ )
302
+
303
+ # time
304
+ if time_embedding_type == "fourier":
305
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
306
+ if time_embed_dim % 2 != 0:
307
+ raise ValueError(
308
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
309
+ )
310
+ self.time_proj = GaussianFourierProjection(
311
+ time_embed_dim // 2,
312
+ set_W_to_weight=False,
313
+ log=False,
314
+ flip_sin_to_cos=flip_sin_to_cos,
315
+ )
316
+ timestep_input_dim = time_embed_dim
317
+ elif time_embedding_type == "positional":
318
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
319
+
320
+ self.time_proj = Timesteps(
321
+ block_out_channels[0], flip_sin_to_cos, freq_shift
322
+ )
323
+ timestep_input_dim = block_out_channels[0]
324
+ else:
325
+ raise ValueError(
326
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
327
+ )
328
+
329
+ self.time_embedding = TimestepEmbedding(
330
+ timestep_input_dim,
331
+ time_embed_dim,
332
+ act_fn=act_fn,
333
+ post_act_fn=timestep_post_act,
334
+ cond_proj_dim=time_cond_proj_dim,
335
+ )
336
+
337
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
338
+ encoder_hid_dim_type = "text_proj"
339
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
340
+ logger.info(
341
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
342
+ )
343
+
344
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
345
+ raise ValueError(
346
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
347
+ )
348
+
349
+ if encoder_hid_dim_type == "text_proj":
350
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
351
+ elif encoder_hid_dim_type == "text_image_proj":
352
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
353
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
354
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
355
+ self.encoder_hid_proj = TextImageProjection(
356
+ text_embed_dim=encoder_hid_dim,
357
+ image_embed_dim=cross_attention_dim,
358
+ cross_attention_dim=cross_attention_dim,
359
+ )
360
+ elif encoder_hid_dim_type == "image_proj":
361
+ # Kandinsky 2.2
362
+ self.encoder_hid_proj = ImageProjection(
363
+ image_embed_dim=encoder_hid_dim,
364
+ cross_attention_dim=cross_attention_dim,
365
+ )
366
+ elif encoder_hid_dim_type is not None:
367
+ raise ValueError(
368
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
369
+ )
370
+ else:
371
+ self.encoder_hid_proj = None
372
+
373
+ # class embedding
374
+ if class_embed_type is None and num_class_embeds is not None:
375
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
376
+ elif class_embed_type == "timestep":
377
+ self.class_embedding = TimestepEmbedding(
378
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
379
+ )
380
+ elif class_embed_type == "identity":
381
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
382
+ elif class_embed_type == "projection":
383
+ if projection_class_embeddings_input_dim is None:
384
+ raise ValueError(
385
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
386
+ )
387
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
388
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
389
+ # 2. it projects from an arbitrary input dimension.
390
+ #
391
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
392
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
393
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
394
+ self.class_embedding = TimestepEmbedding(
395
+ projection_class_embeddings_input_dim, time_embed_dim
396
+ )
397
+ elif class_embed_type == "simple_projection":
398
+ if projection_class_embeddings_input_dim is None:
399
+ raise ValueError(
400
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
401
+ )
402
+ self.class_embedding = nn.Linear(
403
+ projection_class_embeddings_input_dim, time_embed_dim
404
+ )
405
+ else:
406
+ self.class_embedding = None
407
+
408
+ if addition_embed_type == "text":
409
+ if encoder_hid_dim is not None:
410
+ text_time_embedding_from_dim = encoder_hid_dim
411
+ else:
412
+ text_time_embedding_from_dim = cross_attention_dim
413
+
414
+ self.add_embedding = TextTimeEmbedding(
415
+ text_time_embedding_from_dim,
416
+ time_embed_dim,
417
+ num_heads=addition_embed_type_num_heads,
418
+ )
419
+ elif addition_embed_type == "text_image":
420
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
421
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
422
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
423
+ self.add_embedding = TextImageTimeEmbedding(
424
+ text_embed_dim=cross_attention_dim,
425
+ image_embed_dim=cross_attention_dim,
426
+ time_embed_dim=time_embed_dim,
427
+ )
428
+ elif addition_embed_type == "text_time":
429
+ self.add_time_proj = Timesteps(
430
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
431
+ )
432
+ self.add_embedding = TimestepEmbedding(
433
+ projection_class_embeddings_input_dim, time_embed_dim
434
+ )
435
+ elif addition_embed_type == "image":
436
+ # Kandinsky 2.2
437
+ self.add_embedding = ImageTimeEmbedding(
438
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
439
+ )
440
+ elif addition_embed_type == "image_hint":
441
+ # Kandinsky 2.2 ControlNet
442
+ self.add_embedding = ImageHintTimeEmbedding(
443
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
444
+ )
445
+ elif addition_embed_type is not None:
446
+ raise ValueError(
447
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
448
+ )
449
+
450
+ if time_embedding_act_fn is None:
451
+ self.time_embed_act = None
452
+ else:
453
+ self.time_embed_act = get_activation(time_embedding_act_fn)
454
+
455
+ self.down_blocks = nn.ModuleList([])
456
+ self.up_blocks = nn.ModuleList([])
457
+
458
+ if isinstance(only_cross_attention, bool):
459
+ if mid_block_only_cross_attention is None:
460
+ mid_block_only_cross_attention = only_cross_attention
461
+
462
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
463
+
464
+ if mid_block_only_cross_attention is None:
465
+ mid_block_only_cross_attention = False
466
+
467
+ if isinstance(num_attention_heads, int):
468
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
469
+
470
+ if isinstance(attention_head_dim, int):
471
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
472
+
473
+ if isinstance(cross_attention_dim, int):
474
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
475
+
476
+ if isinstance(layers_per_block, int):
477
+ layers_per_block = [layers_per_block] * len(down_block_types)
478
+
479
+ if isinstance(transformer_layers_per_block, int):
480
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
481
+ down_block_types
482
+ )
483
+
484
+ if class_embeddings_concat:
485
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
486
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
487
+ # regular time embeddings
488
+ blocks_time_embed_dim = time_embed_dim * 2
489
+ else:
490
+ blocks_time_embed_dim = time_embed_dim
491
+
492
+ # down
493
+ output_channel = block_out_channels[0]
494
+ for i, down_block_type in enumerate(down_block_types):
495
+ input_channel = output_channel
496
+ output_channel = block_out_channels[i]
497
+ is_final_block = i == len(block_out_channels) - 1
498
+
499
+ down_block = get_down_block(
500
+ down_block_type,
501
+ num_layers=layers_per_block[i],
502
+ transformer_layers_per_block=transformer_layers_per_block[i],
503
+ in_channels=input_channel,
504
+ out_channels=output_channel,
505
+ temb_channels=blocks_time_embed_dim,
506
+ add_downsample=not is_final_block,
507
+ resnet_eps=norm_eps,
508
+ resnet_act_fn=act_fn,
509
+ resnet_groups=norm_num_groups,
510
+ cross_attention_dim=cross_attention_dim[i],
511
+ num_attention_heads=num_attention_heads[i],
512
+ downsample_padding=downsample_padding,
513
+ dual_cross_attention=dual_cross_attention,
514
+ use_linear_projection=use_linear_projection,
515
+ only_cross_attention=only_cross_attention[i],
516
+ upcast_attention=upcast_attention,
517
+ resnet_time_scale_shift=resnet_time_scale_shift,
518
+ attention_type=attention_type,
519
+ resnet_skip_time_act=resnet_skip_time_act,
520
+ resnet_out_scale_factor=resnet_out_scale_factor,
521
+ cross_attention_norm=cross_attention_norm,
522
+ attention_head_dim=attention_head_dim[i]
523
+ if attention_head_dim[i] is not None
524
+ else output_channel,
525
+ dropout=dropout,
526
+ )
527
+ self.down_blocks.append(down_block)
528
+
529
+ # mid
530
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
531
+ self.mid_block = UNetMidBlock2DCrossAttn(
532
+ transformer_layers_per_block=transformer_layers_per_block[-1],
533
+ in_channels=block_out_channels[-1],
534
+ temb_channels=blocks_time_embed_dim,
535
+ dropout=dropout,
536
+ resnet_eps=norm_eps,
537
+ resnet_act_fn=act_fn,
538
+ output_scale_factor=mid_block_scale_factor,
539
+ resnet_time_scale_shift=resnet_time_scale_shift,
540
+ cross_attention_dim=cross_attention_dim[-1],
541
+ num_attention_heads=num_attention_heads[-1],
542
+ resnet_groups=norm_num_groups,
543
+ dual_cross_attention=dual_cross_attention,
544
+ use_linear_projection=use_linear_projection,
545
+ upcast_attention=upcast_attention,
546
+ attention_type=attention_type,
547
+ )
548
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
549
+ raise NotImplementedError(f"Unsupport mid_block_type: {mid_block_type}")
550
+ elif mid_block_type == "UNetMidBlock2D":
551
+ self.mid_block = UNetMidBlock2D(
552
+ in_channels=block_out_channels[-1],
553
+ temb_channels=blocks_time_embed_dim,
554
+ dropout=dropout,
555
+ num_layers=0,
556
+ resnet_eps=norm_eps,
557
+ resnet_act_fn=act_fn,
558
+ output_scale_factor=mid_block_scale_factor,
559
+ resnet_groups=norm_num_groups,
560
+ resnet_time_scale_shift=resnet_time_scale_shift,
561
+ add_attention=False,
562
+ )
563
+ elif mid_block_type is None:
564
+ self.mid_block = None
565
+ else:
566
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
567
+
568
+ # count how many layers upsample the images
569
+ self.num_upsamplers = 0
570
+
571
+ # up
572
+ reversed_block_out_channels = list(reversed(block_out_channels))
573
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
574
+ reversed_layers_per_block = list(reversed(layers_per_block))
575
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
576
+ reversed_transformer_layers_per_block = (
577
+ list(reversed(transformer_layers_per_block))
578
+ if reverse_transformer_layers_per_block is None
579
+ else reverse_transformer_layers_per_block
580
+ )
581
+ only_cross_attention = list(reversed(only_cross_attention))
582
+
583
+ output_channel = reversed_block_out_channels[0]
584
+ for i, up_block_type in enumerate(up_block_types):
585
+ is_final_block = i == len(block_out_channels) - 1
586
+
587
+ prev_output_channel = output_channel
588
+ output_channel = reversed_block_out_channels[i]
589
+ input_channel = reversed_block_out_channels[
590
+ min(i + 1, len(block_out_channels) - 1)
591
+ ]
592
+
593
+ # add upsample block for all BUT final layer
594
+ if not is_final_block:
595
+ add_upsample = True
596
+ self.num_upsamplers += 1
597
+ else:
598
+ add_upsample = False
599
+
600
+ up_block = get_up_block(
601
+ up_block_type,
602
+ num_layers=reversed_layers_per_block[i] + 1,
603
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
604
+ in_channels=input_channel,
605
+ out_channels=output_channel,
606
+ prev_output_channel=prev_output_channel,
607
+ temb_channels=blocks_time_embed_dim,
608
+ add_upsample=add_upsample,
609
+ resnet_eps=norm_eps,
610
+ resnet_act_fn=act_fn,
611
+ resolution_idx=i,
612
+ resnet_groups=norm_num_groups,
613
+ cross_attention_dim=reversed_cross_attention_dim[i],
614
+ num_attention_heads=reversed_num_attention_heads[i],
615
+ dual_cross_attention=dual_cross_attention,
616
+ use_linear_projection=use_linear_projection,
617
+ only_cross_attention=only_cross_attention[i],
618
+ upcast_attention=upcast_attention,
619
+ resnet_time_scale_shift=resnet_time_scale_shift,
620
+ attention_type=attention_type,
621
+ resnet_skip_time_act=resnet_skip_time_act,
622
+ resnet_out_scale_factor=resnet_out_scale_factor,
623
+ cross_attention_norm=cross_attention_norm,
624
+ attention_head_dim=attention_head_dim[i]
625
+ if attention_head_dim[i] is not None
626
+ else output_channel,
627
+ dropout=dropout,
628
+ )
629
+ self.up_blocks.append(up_block)
630
+ prev_output_channel = output_channel
631
+
632
+ # out
633
+ if norm_num_groups is not None:
634
+ self.conv_norm_out = nn.GroupNorm(
635
+ num_channels=block_out_channels[0],
636
+ num_groups=norm_num_groups,
637
+ eps=norm_eps,
638
+ )
639
+
640
+ self.conv_act = get_activation(act_fn)
641
+
642
+ else:
643
+ self.conv_norm_out = None
644
+ self.conv_act = None
645
+ self.conv_norm_out = None
646
+
647
+ conv_out_padding = (conv_out_kernel - 1) // 2
648
+ # self.conv_out = nn.Conv2d(
649
+ # block_out_channels[0],
650
+ # out_channels,
651
+ # kernel_size=conv_out_kernel,
652
+ # padding=conv_out_padding,
653
+ # )
654
+
655
+ if attention_type in ["gated", "gated-text-image"]:
656
+ positive_len = 768
657
+ if isinstance(cross_attention_dim, int):
658
+ positive_len = cross_attention_dim
659
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
660
+ cross_attention_dim, list
661
+ ):
662
+ positive_len = cross_attention_dim[0]
663
+
664
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
665
+ self.position_net = PositionNet(
666
+ positive_len=positive_len,
667
+ out_dim=cross_attention_dim,
668
+ feature_type=feature_type,
669
+ )
670
+
671
+ @property
672
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
673
+ r"""
674
+ Returns:
675
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
676
+ indexed by its weight name.
677
+ """
678
+ # set recursively
679
+ processors = {}
680
+
681
+ def fn_recursive_add_processors(
682
+ name: str,
683
+ module: torch.nn.Module,
684
+ processors: Dict[str, AttentionProcessor],
685
+ ):
686
+ if hasattr(module, "get_processor"):
687
+ processors[f"{name}.processor"] = module.get_processor(
688
+ return_deprecated_lora=True
689
+ )
690
+
691
+ for sub_name, child in module.named_children():
692
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
693
+
694
+ return processors
695
+
696
+ for name, module in self.named_children():
697
+ fn_recursive_add_processors(name, module, processors)
698
+
699
+ return processors
700
+
701
+ def set_attn_processor(
702
+ self,
703
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
704
+ _remove_lora=False,
705
+ ):
706
+ r"""
707
+ Sets the attention processor to use to compute attention.
708
+
709
+ Parameters:
710
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
711
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
712
+ for **all** `Attention` layers.
713
+
714
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
715
+ processor. This is strongly recommended when setting trainable attention processors.
716
+
717
+ """
718
+ count = len(self.attn_processors.keys())
719
+
720
+ if isinstance(processor, dict) and len(processor) != count:
721
+ raise ValueError(
722
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
723
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
724
+ )
725
+
726
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
727
+ if hasattr(module, "set_processor"):
728
+ if not isinstance(processor, dict):
729
+ module.set_processor(processor, _remove_lora=_remove_lora)
730
+ else:
731
+ module.set_processor(
732
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
733
+ )
734
+
735
+ for sub_name, child in module.named_children():
736
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
737
+
738
+ for name, module in self.named_children():
739
+ fn_recursive_attn_processor(name, module, processor)
740
+
741
+ def set_default_attn_processor(self):
742
+ """
743
+ Disables custom attention processors and sets the default attention implementation.
744
+ """
745
+ if all(
746
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
747
+ for proc in self.attn_processors.values()
748
+ ):
749
+ processor = AttnAddedKVProcessor()
750
+ elif all(
751
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
752
+ for proc in self.attn_processors.values()
753
+ ):
754
+ processor = AttnProcessor()
755
+ else:
756
+ raise ValueError(
757
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
758
+ )
759
+
760
+ self.set_attn_processor(processor, _remove_lora=True)
761
+
762
+ def set_attention_slice(self, slice_size):
763
+ r"""
764
+ Enable sliced attention computation.
765
+
766
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
767
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
768
+
769
+ Args:
770
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
771
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
772
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
773
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
774
+ must be a multiple of `slice_size`.
775
+ """
776
+ sliceable_head_dims = []
777
+
778
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
779
+ if hasattr(module, "set_attention_slice"):
780
+ sliceable_head_dims.append(module.sliceable_head_dim)
781
+
782
+ for child in module.children():
783
+ fn_recursive_retrieve_sliceable_dims(child)
784
+
785
+ # retrieve number of attention layers
786
+ for module in self.children():
787
+ fn_recursive_retrieve_sliceable_dims(module)
788
+
789
+ num_sliceable_layers = len(sliceable_head_dims)
790
+
791
+ if slice_size == "auto":
792
+ # half the attention head size is usually a good trade-off between
793
+ # speed and memory
794
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
795
+ elif slice_size == "max":
796
+ # make smallest slice possible
797
+ slice_size = num_sliceable_layers * [1]
798
+
799
+ slice_size = (
800
+ num_sliceable_layers * [slice_size]
801
+ if not isinstance(slice_size, list)
802
+ else slice_size
803
+ )
804
+
805
+ if len(slice_size) != len(sliceable_head_dims):
806
+ raise ValueError(
807
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
808
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
809
+ )
810
+
811
+ for i in range(len(slice_size)):
812
+ size = slice_size[i]
813
+ dim = sliceable_head_dims[i]
814
+ if size is not None and size > dim:
815
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
816
+
817
+ # Recursively walk through all the children.
818
+ # Any children which exposes the set_attention_slice method
819
+ # gets the message
820
+ def fn_recursive_set_attention_slice(
821
+ module: torch.nn.Module, slice_size: List[int]
822
+ ):
823
+ if hasattr(module, "set_attention_slice"):
824
+ module.set_attention_slice(slice_size.pop())
825
+
826
+ for child in module.children():
827
+ fn_recursive_set_attention_slice(child, slice_size)
828
+
829
+ reversed_slice_size = list(reversed(slice_size))
830
+ for module in self.children():
831
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
832
+
833
+ def _set_gradient_checkpointing(self, module, value=False):
834
+ if hasattr(module, "gradient_checkpointing"):
835
+ module.gradient_checkpointing = value
836
+
837
+ def enable_freeu(self, s1, s2, b1, b2):
838
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
839
+
840
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
841
+
842
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
843
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
844
+
845
+ Args:
846
+ s1 (`float`):
847
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
848
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
849
+ s2 (`float`):
850
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
851
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
852
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
853
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
854
+ """
855
+ for i, upsample_block in enumerate(self.up_blocks):
856
+ setattr(upsample_block, "s1", s1)
857
+ setattr(upsample_block, "s2", s2)
858
+ setattr(upsample_block, "b1", b1)
859
+ setattr(upsample_block, "b2", b2)
860
+
861
+ def disable_freeu(self):
862
+ """Disables the FreeU mechanism."""
863
+ freeu_keys = {"s1", "s2", "b1", "b2"}
864
+ for i, upsample_block in enumerate(self.up_blocks):
865
+ for k in freeu_keys:
866
+ if (
867
+ hasattr(upsample_block, k)
868
+ or getattr(upsample_block, k, None) is not None
869
+ ):
870
+ setattr(upsample_block, k, None)
871
+
872
+ def forward(
873
+ self,
874
+ sample: torch.FloatTensor,
875
+ timestep: Union[torch.Tensor, float, int],
876
+ encoder_hidden_states: torch.Tensor,
877
+ class_labels: Optional[torch.Tensor] = None,
878
+ timestep_cond: Optional[torch.Tensor] = None,
879
+ attention_mask: Optional[torch.Tensor] = None,
880
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
881
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
882
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
883
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
884
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
885
+ encoder_attention_mask: Optional[torch.Tensor] = None,
886
+ return_dict: bool = True,
887
+ ) -> Union[UNet2DConditionOutput, Tuple]:
888
+ r"""
889
+ The [`UNet2DConditionModel`] forward method.
890
+
891
+ Args:
892
+ sample (`torch.FloatTensor`):
893
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
894
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
895
+ encoder_hidden_states (`torch.FloatTensor`):
896
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
897
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
898
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
899
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
900
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
901
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
902
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
903
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
904
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
905
+ negative values to the attention scores corresponding to "discard" tokens.
906
+ cross_attention_kwargs (`dict`, *optional*):
907
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
908
+ `self.processor` in
909
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
910
+ added_cond_kwargs: (`dict`, *optional*):
911
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
912
+ are passed along to the UNet blocks.
913
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
914
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
915
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
916
+ A tensor that if specified is added to the residual of the middle unet block.
917
+ encoder_attention_mask (`torch.Tensor`):
918
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
919
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
920
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
921
+ return_dict (`bool`, *optional*, defaults to `True`):
922
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
923
+ tuple.
924
+ cross_attention_kwargs (`dict`, *optional*):
925
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
926
+ added_cond_kwargs: (`dict`, *optional*):
927
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
928
+ are passed along to the UNet blocks.
929
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
930
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
931
+ example from ControlNet side model(s)
932
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
933
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
934
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
935
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
936
+
937
+ Returns:
938
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
939
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
940
+ a `tuple` is returned where the first element is the sample tensor.
941
+ """
942
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
943
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
944
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
945
+ # on the fly if necessary.
946
+ default_overall_up_factor = 2**self.num_upsamplers
947
+
948
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
949
+ forward_upsample_size = False
950
+ upsample_size = None
951
+
952
+ for dim in sample.shape[-2:]:
953
+ if dim % default_overall_up_factor != 0:
954
+ # Forward upsample size to force interpolation output size.
955
+ forward_upsample_size = True
956
+ break
957
+
958
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
959
+ # expects mask of shape:
960
+ # [batch, key_tokens]
961
+ # adds singleton query_tokens dimension:
962
+ # [batch, 1, key_tokens]
963
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
964
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
965
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
966
+ if attention_mask is not None:
967
+ # assume that mask is expressed as:
968
+ # (1 = keep, 0 = discard)
969
+ # convert mask into a bias that can be added to attention scores:
970
+ # (keep = +0, discard = -10000.0)
971
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
972
+ attention_mask = attention_mask.unsqueeze(1)
973
+
974
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
975
+ if encoder_attention_mask is not None:
976
+ encoder_attention_mask = (
977
+ 1 - encoder_attention_mask.to(sample.dtype)
978
+ ) * -10000.0
979
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
980
+
981
+ # 0. center input if necessary
982
+ if self.config.center_input_sample:
983
+ sample = 2 * sample - 1.0
984
+
985
+ # 1. time
986
+ timesteps = timestep
987
+ if not torch.is_tensor(timesteps):
988
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
989
+ # This would be a good case for the `match` statement (Python 3.10+)
990
+ is_mps = sample.device.type == "mps"
991
+ if isinstance(timestep, float):
992
+ dtype = torch.float32 if is_mps else torch.float64
993
+ else:
994
+ dtype = torch.int32 if is_mps else torch.int64
995
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
996
+ elif len(timesteps.shape) == 0:
997
+ timesteps = timesteps[None].to(sample.device)
998
+
999
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1000
+ timesteps = timesteps.expand(sample.shape[0])
1001
+
1002
+ t_emb = self.time_proj(timesteps)
1003
+
1004
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1005
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1006
+ # there might be better ways to encapsulate this.
1007
+ t_emb = t_emb.to(dtype=sample.dtype)
1008
+
1009
+ emb = self.time_embedding(t_emb, timestep_cond)
1010
+ aug_emb = None
1011
+
1012
+ if self.class_embedding is not None:
1013
+ if class_labels is None:
1014
+ raise ValueError(
1015
+ "class_labels should be provided when num_class_embeds > 0"
1016
+ )
1017
+
1018
+ if self.config.class_embed_type == "timestep":
1019
+ class_labels = self.time_proj(class_labels)
1020
+
1021
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1022
+ # there might be better ways to encapsulate this.
1023
+ class_labels = class_labels.to(dtype=sample.dtype)
1024
+
1025
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1026
+
1027
+ if self.config.class_embeddings_concat:
1028
+ emb = torch.cat([emb, class_emb], dim=-1)
1029
+ else:
1030
+ emb = emb + class_emb
1031
+
1032
+ if self.config.addition_embed_type == "text":
1033
+ aug_emb = self.add_embedding(encoder_hidden_states)
1034
+ elif self.config.addition_embed_type == "text_image":
1035
+ # Kandinsky 2.1 - style
1036
+ if "image_embeds" not in added_cond_kwargs:
1037
+ raise ValueError(
1038
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1039
+ )
1040
+
1041
+ image_embs = added_cond_kwargs.get("image_embeds")
1042
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1043
+ aug_emb = self.add_embedding(text_embs, image_embs)
1044
+ elif self.config.addition_embed_type == "text_time":
1045
+ # SDXL - style
1046
+ if "text_embeds" not in added_cond_kwargs:
1047
+ raise ValueError(
1048
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1049
+ )
1050
+ text_embeds = added_cond_kwargs.get("text_embeds")
1051
+ if "time_ids" not in added_cond_kwargs:
1052
+ raise ValueError(
1053
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1054
+ )
1055
+ time_ids = added_cond_kwargs.get("time_ids")
1056
+ time_embeds = self.add_time_proj(time_ids.flatten())
1057
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1058
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1059
+ add_embeds = add_embeds.to(emb.dtype)
1060
+ aug_emb = self.add_embedding(add_embeds)
1061
+ elif self.config.addition_embed_type == "image":
1062
+ # Kandinsky 2.2 - style
1063
+ if "image_embeds" not in added_cond_kwargs:
1064
+ raise ValueError(
1065
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1066
+ )
1067
+ image_embs = added_cond_kwargs.get("image_embeds")
1068
+ aug_emb = self.add_embedding(image_embs)
1069
+ elif self.config.addition_embed_type == "image_hint":
1070
+ # Kandinsky 2.2 - style
1071
+ if (
1072
+ "image_embeds" not in added_cond_kwargs
1073
+ or "hint" not in added_cond_kwargs
1074
+ ):
1075
+ raise ValueError(
1076
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1077
+ )
1078
+ image_embs = added_cond_kwargs.get("image_embeds")
1079
+ hint = added_cond_kwargs.get("hint")
1080
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1081
+ sample = torch.cat([sample, hint], dim=1)
1082
+
1083
+ emb = emb + aug_emb if aug_emb is not None else emb
1084
+
1085
+ if self.time_embed_act is not None:
1086
+ emb = self.time_embed_act(emb)
1087
+
1088
+ if (
1089
+ self.encoder_hid_proj is not None
1090
+ and self.config.encoder_hid_dim_type == "text_proj"
1091
+ ):
1092
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1093
+ elif (
1094
+ self.encoder_hid_proj is not None
1095
+ and self.config.encoder_hid_dim_type == "text_image_proj"
1096
+ ):
1097
+ # Kadinsky 2.1 - style
1098
+ if "image_embeds" not in added_cond_kwargs:
1099
+ raise ValueError(
1100
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1101
+ )
1102
+
1103
+ image_embeds = added_cond_kwargs.get("image_embeds")
1104
+ encoder_hidden_states = self.encoder_hid_proj(
1105
+ encoder_hidden_states, image_embeds
1106
+ )
1107
+ elif (
1108
+ self.encoder_hid_proj is not None
1109
+ and self.config.encoder_hid_dim_type == "image_proj"
1110
+ ):
1111
+ # Kandinsky 2.2 - style
1112
+ if "image_embeds" not in added_cond_kwargs:
1113
+ raise ValueError(
1114
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1115
+ )
1116
+ image_embeds = added_cond_kwargs.get("image_embeds")
1117
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1118
+ elif (
1119
+ self.encoder_hid_proj is not None
1120
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
1121
+ ):
1122
+ if "image_embeds" not in added_cond_kwargs:
1123
+ raise ValueError(
1124
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1125
+ )
1126
+ image_embeds = added_cond_kwargs.get("image_embeds")
1127
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
1128
+ encoder_hidden_states.dtype
1129
+ )
1130
+ encoder_hidden_states = torch.cat(
1131
+ [encoder_hidden_states, image_embeds], dim=1
1132
+ )
1133
+
1134
+ # 2. pre-process
1135
+ sample = self.conv_in(sample)
1136
+
1137
+ # 2.5 GLIGEN position net
1138
+ if (
1139
+ cross_attention_kwargs is not None
1140
+ and cross_attention_kwargs.get("gligen", None) is not None
1141
+ ):
1142
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1143
+ gligen_args = cross_attention_kwargs.pop("gligen")
1144
+ cross_attention_kwargs["gligen"] = {
1145
+ "objs": self.position_net(**gligen_args)
1146
+ }
1147
+
1148
+ # 3. down
1149
+ lora_scale = (
1150
+ cross_attention_kwargs.get("scale", 1.0)
1151
+ if cross_attention_kwargs is not None
1152
+ else 1.0
1153
+ )
1154
+ if USE_PEFT_BACKEND:
1155
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1156
+ scale_lora_layers(self, lora_scale)
1157
+
1158
+ is_controlnet = (
1159
+ mid_block_additional_residual is not None
1160
+ and down_block_additional_residuals is not None
1161
+ )
1162
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1163
+ is_adapter = down_intrablock_additional_residuals is not None
1164
+ # maintain backward compatibility for legacy usage, where
1165
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1166
+ # but can only use one or the other
1167
+ if (
1168
+ not is_adapter
1169
+ and mid_block_additional_residual is None
1170
+ and down_block_additional_residuals is not None
1171
+ ):
1172
+ deprecate(
1173
+ "T2I should not use down_block_additional_residuals",
1174
+ "1.3.0",
1175
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1176
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1177
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
1178
+ standard_warn=False,
1179
+ )
1180
+ down_intrablock_additional_residuals = down_block_additional_residuals
1181
+ is_adapter = True
1182
+
1183
+ down_block_res_samples = (sample,)
1184
+ tot_referece_features = ()
1185
+ for downsample_block in self.down_blocks:
1186
+ if (
1187
+ hasattr(downsample_block, "has_cross_attention")
1188
+ and downsample_block.has_cross_attention
1189
+ ):
1190
+ # For t2i-adapter CrossAttnDownBlock2D
1191
+ additional_residuals = {}
1192
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1193
+ additional_residuals[
1194
+ "additional_residuals"
1195
+ ] = down_intrablock_additional_residuals.pop(0)
1196
+
1197
+ sample, res_samples = downsample_block(
1198
+ hidden_states=sample,
1199
+ temb=emb,
1200
+ encoder_hidden_states=encoder_hidden_states,
1201
+ attention_mask=attention_mask,
1202
+ cross_attention_kwargs=cross_attention_kwargs,
1203
+ encoder_attention_mask=encoder_attention_mask,
1204
+ **additional_residuals,
1205
+ )
1206
+ else:
1207
+ sample, res_samples = downsample_block(
1208
+ hidden_states=sample, temb=emb, scale=lora_scale
1209
+ )
1210
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1211
+ sample += down_intrablock_additional_residuals.pop(0)
1212
+
1213
+ down_block_res_samples += res_samples
1214
+
1215
+ if is_controlnet:
1216
+ new_down_block_res_samples = ()
1217
+
1218
+ for down_block_res_sample, down_block_additional_residual in zip(
1219
+ down_block_res_samples, down_block_additional_residuals
1220
+ ):
1221
+ down_block_res_sample = (
1222
+ down_block_res_sample + down_block_additional_residual
1223
+ )
1224
+ new_down_block_res_samples = new_down_block_res_samples + (
1225
+ down_block_res_sample,
1226
+ )
1227
+
1228
+ down_block_res_samples = new_down_block_res_samples
1229
+
1230
+ # 4. mid
1231
+ if self.mid_block is not None:
1232
+ if (
1233
+ hasattr(self.mid_block, "has_cross_attention")
1234
+ and self.mid_block.has_cross_attention
1235
+ ):
1236
+ sample = self.mid_block(
1237
+ sample,
1238
+ emb,
1239
+ encoder_hidden_states=encoder_hidden_states,
1240
+ attention_mask=attention_mask,
1241
+ cross_attention_kwargs=cross_attention_kwargs,
1242
+ encoder_attention_mask=encoder_attention_mask,
1243
+ )
1244
+ else:
1245
+ sample = self.mid_block(sample, emb)
1246
+
1247
+ # To support T2I-Adapter-XL
1248
+ if (
1249
+ is_adapter
1250
+ and len(down_intrablock_additional_residuals) > 0
1251
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1252
+ ):
1253
+ sample += down_intrablock_additional_residuals.pop(0)
1254
+
1255
+ if is_controlnet:
1256
+ sample = sample + mid_block_additional_residual
1257
+
1258
+ # 5. up
1259
+ for i, upsample_block in enumerate(self.up_blocks):
1260
+ is_final_block = i == len(self.up_blocks) - 1
1261
+
1262
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1263
+ down_block_res_samples = down_block_res_samples[
1264
+ : -len(upsample_block.resnets)
1265
+ ]
1266
+
1267
+ # if we have not reached the final block and need to forward the
1268
+ # upsample size, we do it here
1269
+ if not is_final_block and forward_upsample_size:
1270
+ upsample_size = down_block_res_samples[-1].shape[2:]
1271
+
1272
+ if (
1273
+ hasattr(upsample_block, "has_cross_attention")
1274
+ and upsample_block.has_cross_attention
1275
+ ):
1276
+ sample = upsample_block(
1277
+ hidden_states=sample,
1278
+ temb=emb,
1279
+ res_hidden_states_tuple=res_samples,
1280
+ encoder_hidden_states=encoder_hidden_states,
1281
+ cross_attention_kwargs=cross_attention_kwargs,
1282
+ upsample_size=upsample_size,
1283
+ attention_mask=attention_mask,
1284
+ encoder_attention_mask=encoder_attention_mask,
1285
+ )
1286
+ else:
1287
+ sample = upsample_block(
1288
+ hidden_states=sample,
1289
+ temb=emb,
1290
+ res_hidden_states_tuple=res_samples,
1291
+ upsample_size=upsample_size,
1292
+ scale=lora_scale,
1293
+ )
1294
+
1295
+ # 6. post-process
1296
+ # if self.conv_norm_out:
1297
+ # sample = self.conv_norm_out(sample)
1298
+ # sample = self.conv_act(sample)
1299
+ # sample = self.conv_out(sample)
1300
+
1301
+ if USE_PEFT_BACKEND:
1302
+ # remove `lora_scale` from each PEFT layer
1303
+ unscale_lora_layers(self, lora_scale)
1304
+
1305
+ if not return_dict:
1306
+ return (sample,)
1307
+
1308
+ return UNet2DConditionOutput(sample=sample)
src/models/unet_3d.py ADDED
@@ -0,0 +1,673 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
2
+
3
+ from collections import OrderedDict
4
+ from dataclasses import dataclass
5
+ import pdb
6
+ from os import PathLike
7
+ from pathlib import Path
8
+ from typing import Dict, List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.utils.checkpoint
13
+ import torch.nn.functional as F
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.attention_processor import AttentionProcessor
16
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17
+ from diffusers.models.modeling_utils import ModelMixin
18
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
19
+ from safetensors.torch import load_file
20
+
21
+ from .resnet import InflatedConv3d, InflatedGroupNorm
22
+ from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+
27
+ @dataclass
28
+ class UNet3DConditionOutput(BaseOutput):
29
+ sample: torch.FloatTensor
30
+
31
+
32
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
33
+ _supports_gradient_checkpointing = True
34
+
35
+ @register_to_config
36
+ def __init__(
37
+ self,
38
+ sample_size: Optional[int] = None,
39
+ in_channels: int = 4,
40
+ out_channels: int = 4,
41
+ center_input_sample: bool = False,
42
+ flip_sin_to_cos: bool = True,
43
+ freq_shift: int = 0,
44
+ down_block_types: Tuple[str] = (
45
+ "CrossAttnDownBlock3D",
46
+ "CrossAttnDownBlock3D",
47
+ "CrossAttnDownBlock3D",
48
+ "DownBlock3D",
49
+ ),
50
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
51
+ up_block_types: Tuple[str] = (
52
+ "UpBlock3D",
53
+ "CrossAttnUpBlock3D",
54
+ "CrossAttnUpBlock3D",
55
+ "CrossAttnUpBlock3D",
56
+ ),
57
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
58
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
59
+ layers_per_block: int = 2,
60
+ downsample_padding: int = 1,
61
+ mid_block_scale_factor: float = 1,
62
+ act_fn: str = "silu",
63
+ norm_num_groups: int = 32,
64
+ norm_eps: float = 1e-5,
65
+ cross_attention_dim: int = 1280,
66
+ attention_head_dim: Union[int, Tuple[int]] = 8,
67
+ dual_cross_attention: bool = False,
68
+ use_linear_projection: bool = False,
69
+ class_embed_type: Optional[str] = None,
70
+ num_class_embeds: Optional[int] = None,
71
+ upcast_attention: bool = False,
72
+ resnet_time_scale_shift: str = "default",
73
+ use_inflated_groupnorm=False,
74
+ # Additional
75
+ use_motion_module=False,
76
+ motion_module_resolutions=(1, 2, 4, 8),
77
+ motion_module_mid_block=False,
78
+ motion_module_decoder_only=False,
79
+ motion_module_type=None,
80
+ motion_module_kwargs={},
81
+ unet_use_cross_frame_attention=None,
82
+ unet_use_temporal_attention=None,
83
+ ):
84
+ super().__init__()
85
+
86
+ self.sample_size = sample_size
87
+ time_embed_dim = block_out_channels[0] * 4
88
+
89
+ # input
90
+ self.conv_in = InflatedConv3d(
91
+ in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
92
+ )
93
+
94
+ # time
95
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
96
+ timestep_input_dim = block_out_channels[0]
97
+
98
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
99
+
100
+ # class embedding
101
+ if class_embed_type is None and num_class_embeds is not None:
102
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
103
+ elif class_embed_type == "timestep":
104
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
105
+ elif class_embed_type == "identity":
106
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
107
+ else:
108
+ self.class_embedding = None
109
+
110
+ self.down_blocks = nn.ModuleList([])
111
+ self.mid_block = None
112
+ self.up_blocks = nn.ModuleList([])
113
+
114
+ if isinstance(only_cross_attention, bool):
115
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
116
+
117
+ if isinstance(attention_head_dim, int):
118
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
119
+
120
+ # down
121
+ output_channel = block_out_channels[0]
122
+ for i, down_block_type in enumerate(down_block_types):
123
+ res = 2**i
124
+ input_channel = output_channel
125
+ output_channel = block_out_channels[i]
126
+ is_final_block = i == len(block_out_channels) - 1
127
+
128
+ down_block = get_down_block(
129
+ down_block_type,
130
+ num_layers=layers_per_block,
131
+ in_channels=input_channel,
132
+ out_channels=output_channel,
133
+ temb_channels=time_embed_dim,
134
+ add_downsample=not is_final_block,
135
+ resnet_eps=norm_eps,
136
+ resnet_act_fn=act_fn,
137
+ resnet_groups=norm_num_groups,
138
+ cross_attention_dim=cross_attention_dim,
139
+ attn_num_head_channels=attention_head_dim[i],
140
+ downsample_padding=downsample_padding,
141
+ dual_cross_attention=dual_cross_attention,
142
+ use_linear_projection=use_linear_projection,
143
+ only_cross_attention=only_cross_attention[i],
144
+ upcast_attention=upcast_attention,
145
+ resnet_time_scale_shift=resnet_time_scale_shift,
146
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
147
+ unet_use_temporal_attention=unet_use_temporal_attention,
148
+ use_inflated_groupnorm=use_inflated_groupnorm,
149
+ use_motion_module=use_motion_module
150
+ and (res in motion_module_resolutions)
151
+ and (not motion_module_decoder_only),
152
+ motion_module_type=motion_module_type,
153
+ motion_module_kwargs=motion_module_kwargs,
154
+ )
155
+ self.down_blocks.append(down_block)
156
+
157
+ # mid
158
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
159
+ self.mid_block = UNetMidBlock3DCrossAttn(
160
+ in_channels=block_out_channels[-1],
161
+ temb_channels=time_embed_dim,
162
+ resnet_eps=norm_eps,
163
+ resnet_act_fn=act_fn,
164
+ output_scale_factor=mid_block_scale_factor,
165
+ resnet_time_scale_shift=resnet_time_scale_shift,
166
+ cross_attention_dim=cross_attention_dim,
167
+ attn_num_head_channels=attention_head_dim[-1],
168
+ resnet_groups=norm_num_groups,
169
+ dual_cross_attention=dual_cross_attention,
170
+ use_linear_projection=use_linear_projection,
171
+ upcast_attention=upcast_attention,
172
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
173
+ unet_use_temporal_attention=unet_use_temporal_attention,
174
+ use_inflated_groupnorm=use_inflated_groupnorm,
175
+ use_motion_module=use_motion_module and motion_module_mid_block,
176
+ motion_module_type=motion_module_type,
177
+ motion_module_kwargs=motion_module_kwargs,
178
+ )
179
+ else:
180
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
181
+
182
+ # count how many layers upsample the videos
183
+ self.num_upsamplers = 0
184
+
185
+ # up
186
+ reversed_block_out_channels = list(reversed(block_out_channels))
187
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
188
+ only_cross_attention = list(reversed(only_cross_attention))
189
+ output_channel = reversed_block_out_channels[0]
190
+ for i, up_block_type in enumerate(up_block_types):
191
+ res = 2 ** (3 - i)
192
+ is_final_block = i == len(block_out_channels) - 1
193
+
194
+ prev_output_channel = output_channel
195
+ output_channel = reversed_block_out_channels[i]
196
+ input_channel = reversed_block_out_channels[
197
+ min(i + 1, len(block_out_channels) - 1)
198
+ ]
199
+
200
+ # add upsample block for all BUT final layer
201
+ if not is_final_block:
202
+ add_upsample = True
203
+ self.num_upsamplers += 1
204
+ else:
205
+ add_upsample = False
206
+
207
+ up_block = get_up_block(
208
+ up_block_type,
209
+ num_layers=layers_per_block + 1,
210
+ in_channels=input_channel,
211
+ out_channels=output_channel,
212
+ prev_output_channel=prev_output_channel,
213
+ temb_channels=time_embed_dim,
214
+ add_upsample=add_upsample,
215
+ resnet_eps=norm_eps,
216
+ resnet_act_fn=act_fn,
217
+ resnet_groups=norm_num_groups,
218
+ cross_attention_dim=cross_attention_dim,
219
+ attn_num_head_channels=reversed_attention_head_dim[i],
220
+ dual_cross_attention=dual_cross_attention,
221
+ use_linear_projection=use_linear_projection,
222
+ only_cross_attention=only_cross_attention[i],
223
+ upcast_attention=upcast_attention,
224
+ resnet_time_scale_shift=resnet_time_scale_shift,
225
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
226
+ unet_use_temporal_attention=unet_use_temporal_attention,
227
+ use_inflated_groupnorm=use_inflated_groupnorm,
228
+ use_motion_module=use_motion_module
229
+ and (res in motion_module_resolutions),
230
+ motion_module_type=motion_module_type,
231
+ motion_module_kwargs=motion_module_kwargs,
232
+ )
233
+ self.up_blocks.append(up_block)
234
+ prev_output_channel = output_channel
235
+
236
+ # out
237
+ if use_inflated_groupnorm:
238
+ self.conv_norm_out = InflatedGroupNorm(
239
+ num_channels=block_out_channels[0],
240
+ num_groups=norm_num_groups,
241
+ eps=norm_eps,
242
+ )
243
+ else:
244
+ self.conv_norm_out = nn.GroupNorm(
245
+ num_channels=block_out_channels[0],
246
+ num_groups=norm_num_groups,
247
+ eps=norm_eps,
248
+ )
249
+ self.conv_act = nn.SiLU()
250
+ self.conv_out = InflatedConv3d(
251
+ block_out_channels[0], out_channels, kernel_size=3, padding=1
252
+ )
253
+
254
+ @property
255
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
256
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
257
+ r"""
258
+ Returns:
259
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
260
+ indexed by its weight name.
261
+ """
262
+ # set recursively
263
+ processors = {}
264
+
265
+ def fn_recursive_add_processors(
266
+ name: str,
267
+ module: torch.nn.Module,
268
+ processors: Dict[str, AttentionProcessor],
269
+ ):
270
+ if hasattr(module, "set_processor"):
271
+ processors[f"{name}.processor"] = module.processor
272
+
273
+ for sub_name, child in module.named_children():
274
+ if "temporal_transformer" not in sub_name:
275
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
276
+
277
+ return processors
278
+
279
+ for name, module in self.named_children():
280
+ if "temporal_transformer" not in name:
281
+ fn_recursive_add_processors(name, module, processors)
282
+
283
+ return processors
284
+
285
+ def set_attention_slice(self, slice_size):
286
+ r"""
287
+ Enable sliced attention computation.
288
+
289
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
290
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
291
+
292
+ Args:
293
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
294
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
295
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
296
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
297
+ must be a multiple of `slice_size`.
298
+ """
299
+ sliceable_head_dims = []
300
+
301
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
302
+ if hasattr(module, "set_attention_slice"):
303
+ sliceable_head_dims.append(module.sliceable_head_dim)
304
+
305
+ for child in module.children():
306
+ fn_recursive_retrieve_slicable_dims(child)
307
+
308
+ # retrieve number of attention layers
309
+ for module in self.children():
310
+ fn_recursive_retrieve_slicable_dims(module)
311
+
312
+ num_slicable_layers = len(sliceable_head_dims)
313
+
314
+ if slice_size == "auto":
315
+ # half the attention head size is usually a good trade-off between
316
+ # speed and memory
317
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
318
+ elif slice_size == "max":
319
+ # make smallest slice possible
320
+ slice_size = num_slicable_layers * [1]
321
+
322
+ slice_size = (
323
+ num_slicable_layers * [slice_size]
324
+ if not isinstance(slice_size, list)
325
+ else slice_size
326
+ )
327
+
328
+ if len(slice_size) != len(sliceable_head_dims):
329
+ raise ValueError(
330
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
331
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
332
+ )
333
+
334
+ for i in range(len(slice_size)):
335
+ size = slice_size[i]
336
+ dim = sliceable_head_dims[i]
337
+ if size is not None and size > dim:
338
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
339
+
340
+ # Recursively walk through all the children.
341
+ # Any children which exposes the set_attention_slice method
342
+ # gets the message
343
+ def fn_recursive_set_attention_slice(
344
+ module: torch.nn.Module, slice_size: List[int]
345
+ ):
346
+ if hasattr(module, "set_attention_slice"):
347
+ module.set_attention_slice(slice_size.pop())
348
+
349
+ for child in module.children():
350
+ fn_recursive_set_attention_slice(child, slice_size)
351
+
352
+ reversed_slice_size = list(reversed(slice_size))
353
+ for module in self.children():
354
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
355
+
356
+ def _set_gradient_checkpointing(self, module, value=False):
357
+ if hasattr(module, "gradient_checkpointing"):
358
+ module.gradient_checkpointing = value
359
+
360
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
361
+ def set_attn_processor(
362
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
363
+ ):
364
+ r"""
365
+ Sets the attention processor to use to compute attention.
366
+
367
+ Parameters:
368
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
369
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
370
+ for **all** `Attention` layers.
371
+
372
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
373
+ processor. This is strongly recommended when setting trainable attention processors.
374
+
375
+ """
376
+ count = len(self.attn_processors.keys())
377
+
378
+ if isinstance(processor, dict) and len(processor) != count:
379
+ raise ValueError(
380
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
381
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
382
+ )
383
+
384
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
385
+ if hasattr(module, "set_processor"):
386
+ if not isinstance(processor, dict):
387
+ module.set_processor(processor)
388
+ else:
389
+ module.set_processor(processor.pop(f"{name}.processor"))
390
+
391
+ for sub_name, child in module.named_children():
392
+ if "temporal_transformer" not in sub_name:
393
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
394
+
395
+ for name, module in self.named_children():
396
+ if "temporal_transformer" not in name:
397
+ fn_recursive_attn_processor(name, module, processor)
398
+
399
+ def forward(
400
+ self,
401
+ sample: torch.FloatTensor,
402
+ timestep: Union[torch.Tensor, float, int],
403
+ encoder_hidden_states: torch.Tensor,
404
+ class_labels: Optional[torch.Tensor] = None,
405
+ pose_cond_fea = None,
406
+ attention_mask: Optional[torch.Tensor] = None,
407
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
408
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
409
+ return_dict: bool = True,
410
+ ) -> Union[UNet3DConditionOutput, Tuple]:
411
+ r"""
412
+ Args:
413
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
414
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
415
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
416
+ return_dict (`bool`, *optional*, defaults to `True`):
417
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
418
+
419
+ Returns:
420
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
421
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
422
+ returning a tuple, the first element is the sample tensor.
423
+ """
424
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
425
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
426
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
427
+ # on the fly if necessary.
428
+ default_overall_up_factor = 2**self.num_upsamplers
429
+
430
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
431
+ forward_upsample_size = False
432
+ upsample_size = None
433
+
434
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
435
+ logger.info("Forward upsample size to force interpolation output size.")
436
+ forward_upsample_size = True
437
+
438
+ # prepare attention_mask
439
+ if attention_mask is not None:
440
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
441
+ attention_mask = attention_mask.unsqueeze(1)
442
+
443
+ # center input if necessary
444
+ if self.config.center_input_sample:
445
+ sample = 2 * sample - 1.0
446
+
447
+ # time
448
+ timesteps = timestep
449
+ if not torch.is_tensor(timesteps):
450
+ # This would be a good case for the `match` statement (Python 3.10+)
451
+ is_mps = sample.device.type == "mps"
452
+ if isinstance(timestep, float):
453
+ dtype = torch.float32 if is_mps else torch.float64
454
+ else:
455
+ dtype = torch.int32 if is_mps else torch.int64
456
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
457
+ elif len(timesteps.shape) == 0:
458
+ timesteps = timesteps[None].to(sample.device)
459
+
460
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
461
+ timesteps = timesteps.expand(sample.shape[0])
462
+
463
+ t_emb = self.time_proj(timesteps)
464
+
465
+ # timesteps does not contain any weights and will always return f32 tensors
466
+ # but time_embedding might actually be running in fp16. so we need to cast here.
467
+ # there might be better ways to encapsulate this.
468
+ t_emb = t_emb.to(dtype=self.dtype)
469
+ emb = self.time_embedding(t_emb)
470
+
471
+ if self.class_embedding is not None:
472
+ if class_labels is None:
473
+ raise ValueError(
474
+ "class_labels should be provided when num_class_embeds > 0"
475
+ )
476
+
477
+ if self.config.class_embed_type == "timestep":
478
+ class_labels = self.time_proj(class_labels)
479
+
480
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
481
+ emb = emb + class_emb
482
+
483
+ # pre-process
484
+ sample = self.conv_in(sample)
485
+ if pose_cond_fea is not None:
486
+ sample = sample + pose_cond_fea[0]
487
+
488
+ # down
489
+ down_block_res_samples = (sample,)
490
+ block_count = 1
491
+ for downsample_block in self.down_blocks:
492
+ if (
493
+ hasattr(downsample_block, "has_cross_attention")
494
+ and downsample_block.has_cross_attention
495
+ ):
496
+ sample, res_samples = downsample_block(
497
+ hidden_states=sample,
498
+ temb=emb,
499
+ encoder_hidden_states=encoder_hidden_states,
500
+ attention_mask=attention_mask,
501
+ )
502
+ else:
503
+ sample, res_samples = downsample_block(
504
+ hidden_states=sample,
505
+ temb=emb,
506
+ encoder_hidden_states=encoder_hidden_states,
507
+ )
508
+ if pose_cond_fea is not None:
509
+ sample = sample + pose_cond_fea[block_count]
510
+ block_count += 1
511
+ down_block_res_samples += res_samples
512
+
513
+ if down_block_additional_residuals is not None:
514
+ new_down_block_res_samples = ()
515
+
516
+ for down_block_res_sample, down_block_additional_residual in zip(
517
+ down_block_res_samples, down_block_additional_residuals
518
+ ):
519
+ down_block_res_sample = (
520
+ down_block_res_sample + down_block_additional_residual
521
+ )
522
+ new_down_block_res_samples += (down_block_res_sample,)
523
+
524
+ down_block_res_samples = new_down_block_res_samples
525
+
526
+ # mid
527
+ sample = self.mid_block(
528
+ sample,
529
+ emb,
530
+ encoder_hidden_states=encoder_hidden_states,
531
+ attention_mask=attention_mask,
532
+ )
533
+
534
+ if mid_block_additional_residual is not None:
535
+ sample = sample + mid_block_additional_residual
536
+
537
+ # up
538
+ for i, upsample_block in enumerate(self.up_blocks):
539
+ is_final_block = i == len(self.up_blocks) - 1
540
+
541
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
542
+ down_block_res_samples = down_block_res_samples[
543
+ : -len(upsample_block.resnets)
544
+ ]
545
+
546
+ # if we have not reached the final block and need to forward the
547
+ # upsample size, we do it here
548
+ if not is_final_block and forward_upsample_size:
549
+ upsample_size = down_block_res_samples[-1].shape[2:]
550
+
551
+ if (
552
+ hasattr(upsample_block, "has_cross_attention")
553
+ and upsample_block.has_cross_attention
554
+ ):
555
+ sample = upsample_block(
556
+ hidden_states=sample,
557
+ temb=emb,
558
+ res_hidden_states_tuple=res_samples,
559
+ encoder_hidden_states=encoder_hidden_states,
560
+ upsample_size=upsample_size,
561
+ attention_mask=attention_mask,
562
+ )
563
+ else:
564
+ sample = upsample_block(
565
+ hidden_states=sample,
566
+ temb=emb,
567
+ res_hidden_states_tuple=res_samples,
568
+ upsample_size=upsample_size,
569
+ encoder_hidden_states=encoder_hidden_states,
570
+ )
571
+
572
+ # post-process
573
+ sample = self.conv_norm_out(sample)
574
+ sample = self.conv_act(sample)
575
+ sample = self.conv_out(sample)
576
+
577
+ if not return_dict:
578
+ return (sample,)
579
+
580
+ return UNet3DConditionOutput(sample=sample)
581
+
582
+ @classmethod
583
+ def from_pretrained_2d(
584
+ cls,
585
+ pretrained_model_path: PathLike,
586
+ motion_module_path: PathLike,
587
+ subfolder=None,
588
+ unet_additional_kwargs=None,
589
+ mm_zero_proj_out=False,
590
+ ):
591
+ pretrained_model_path = Path(pretrained_model_path)
592
+ motion_module_path = Path(motion_module_path)
593
+ if subfolder is not None:
594
+ pretrained_model_path = pretrained_model_path.joinpath(subfolder)
595
+ logger.info(
596
+ f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..."
597
+ )
598
+
599
+ config_file = pretrained_model_path / "config.json"
600
+ if not (config_file.exists() and config_file.is_file()):
601
+ raise RuntimeError(f"{config_file} does not exist or is not a file")
602
+
603
+ unet_config = cls.load_config(config_file)
604
+ unet_config["_class_name"] = cls.__name__
605
+ unet_config["down_block_types"] = [
606
+ "CrossAttnDownBlock3D",
607
+ "CrossAttnDownBlock3D",
608
+ "CrossAttnDownBlock3D",
609
+ "DownBlock3D",
610
+ ]
611
+ unet_config["up_block_types"] = [
612
+ "UpBlock3D",
613
+ "CrossAttnUpBlock3D",
614
+ "CrossAttnUpBlock3D",
615
+ "CrossAttnUpBlock3D",
616
+ ]
617
+ unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
618
+
619
+ model = cls.from_config(unet_config, **unet_additional_kwargs)
620
+ # load the vanilla weights
621
+ if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
622
+ logger.debug(
623
+ f"loading safeTensors weights from {pretrained_model_path} ..."
624
+ )
625
+ state_dict = load_file(
626
+ pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
627
+ )
628
+
629
+ elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
630
+ logger.debug(f"loading weights from {pretrained_model_path} ...")
631
+ state_dict = torch.load(
632
+ pretrained_model_path.joinpath(WEIGHTS_NAME),
633
+ map_location="cpu",
634
+ weights_only=True,
635
+ )
636
+ else:
637
+ raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")
638
+
639
+ # load the motion module weights
640
+ if motion_module_path.exists() and motion_module_path.is_file():
641
+ if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
642
+ logger.info(f"Load motion module params from {motion_module_path}")
643
+ motion_state_dict = torch.load(
644
+ motion_module_path, map_location="cpu", weights_only=True
645
+ )
646
+ elif motion_module_path.suffix.lower() == ".safetensors":
647
+ motion_state_dict = load_file(motion_module_path, device="cpu")
648
+ else:
649
+ raise RuntimeError(
650
+ f"unknown file format for motion module weights: {motion_module_path.suffix}"
651
+ )
652
+ if mm_zero_proj_out:
653
+ logger.info(f"Zero initialize proj_out layers in motion module...")
654
+ new_motion_state_dict = OrderedDict()
655
+ for k in motion_state_dict:
656
+ if "proj_out" in k:
657
+ continue
658
+ new_motion_state_dict[k] = motion_state_dict[k]
659
+ motion_state_dict = new_motion_state_dict
660
+
661
+ # merge the state dicts
662
+ state_dict.update(motion_state_dict)
663
+
664
+ # load the weights into the model
665
+ m, u = model.load_state_dict(state_dict, strict=False)
666
+ logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
667
+
668
+ params = [
669
+ p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()
670
+ ]
671
+ logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module")
672
+
673
+ return model
src/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,861 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import pdb
4
+ from typing import Dict, Optional
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .motion_module import get_motion_module
9
+
10
+ # from .motion_module import get_motion_module
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ from .transformer_3d import Transformer3DModel
13
+
14
+
15
+ def get_down_block(
16
+ down_block_type,
17
+ num_layers,
18
+ in_channels,
19
+ out_channels,
20
+ temb_channels,
21
+ add_downsample,
22
+ resnet_eps,
23
+ resnet_act_fn,
24
+ attn_num_head_channels,
25
+ resnet_groups=None,
26
+ cross_attention_dim=None,
27
+ downsample_padding=None,
28
+ dual_cross_attention=False,
29
+ use_linear_projection=False,
30
+ only_cross_attention=False,
31
+ upcast_attention=False,
32
+ resnet_time_scale_shift="default",
33
+ unet_use_cross_frame_attention=None,
34
+ unet_use_temporal_attention=None,
35
+ use_inflated_groupnorm=None,
36
+ use_motion_module=None,
37
+ motion_module_type=None,
38
+ motion_module_kwargs=None,
39
+ ):
40
+ down_block_type = (
41
+ down_block_type[7:]
42
+ if down_block_type.startswith("UNetRes")
43
+ else down_block_type
44
+ )
45
+ if down_block_type == "DownBlock3D":
46
+ return DownBlock3D(
47
+ num_layers=num_layers,
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ temb_channels=temb_channels,
51
+ add_downsample=add_downsample,
52
+ resnet_eps=resnet_eps,
53
+ resnet_act_fn=resnet_act_fn,
54
+ resnet_groups=resnet_groups,
55
+ downsample_padding=downsample_padding,
56
+ resnet_time_scale_shift=resnet_time_scale_shift,
57
+ use_inflated_groupnorm=use_inflated_groupnorm,
58
+ use_motion_module=use_motion_module,
59
+ motion_module_type=motion_module_type,
60
+ motion_module_kwargs=motion_module_kwargs,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock3D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError(
65
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
66
+ )
67
+ return CrossAttnDownBlock3D(
68
+ num_layers=num_layers,
69
+ in_channels=in_channels,
70
+ out_channels=out_channels,
71
+ temb_channels=temb_channels,
72
+ add_downsample=add_downsample,
73
+ resnet_eps=resnet_eps,
74
+ resnet_act_fn=resnet_act_fn,
75
+ resnet_groups=resnet_groups,
76
+ downsample_padding=downsample_padding,
77
+ cross_attention_dim=cross_attention_dim,
78
+ attn_num_head_channels=attn_num_head_channels,
79
+ dual_cross_attention=dual_cross_attention,
80
+ use_linear_projection=use_linear_projection,
81
+ only_cross_attention=only_cross_attention,
82
+ upcast_attention=upcast_attention,
83
+ resnet_time_scale_shift=resnet_time_scale_shift,
84
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85
+ unet_use_temporal_attention=unet_use_temporal_attention,
86
+ use_inflated_groupnorm=use_inflated_groupnorm,
87
+ use_motion_module=use_motion_module,
88
+ motion_module_type=motion_module_type,
89
+ motion_module_kwargs=motion_module_kwargs,
90
+ )
91
+ raise ValueError(f"{down_block_type} does not exist.")
92
+
93
+
94
+ def get_up_block(
95
+ up_block_type,
96
+ num_layers,
97
+ in_channels,
98
+ out_channels,
99
+ prev_output_channel,
100
+ temb_channels,
101
+ add_upsample,
102
+ resnet_eps,
103
+ resnet_act_fn,
104
+ attn_num_head_channels,
105
+ resnet_groups=None,
106
+ cross_attention_dim=None,
107
+ dual_cross_attention=False,
108
+ use_linear_projection=False,
109
+ only_cross_attention=False,
110
+ upcast_attention=False,
111
+ resnet_time_scale_shift="default",
112
+ unet_use_cross_frame_attention=None,
113
+ unet_use_temporal_attention=None,
114
+ use_inflated_groupnorm=None,
115
+ use_motion_module=None,
116
+ motion_module_type=None,
117
+ motion_module_kwargs=None,
118
+ ):
119
+ up_block_type = (
120
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
121
+ )
122
+ if up_block_type == "UpBlock3D":
123
+ return UpBlock3D(
124
+ num_layers=num_layers,
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ prev_output_channel=prev_output_channel,
128
+ temb_channels=temb_channels,
129
+ add_upsample=add_upsample,
130
+ resnet_eps=resnet_eps,
131
+ resnet_act_fn=resnet_act_fn,
132
+ resnet_groups=resnet_groups,
133
+ resnet_time_scale_shift=resnet_time_scale_shift,
134
+ use_inflated_groupnorm=use_inflated_groupnorm,
135
+ use_motion_module=use_motion_module,
136
+ motion_module_type=motion_module_type,
137
+ motion_module_kwargs=motion_module_kwargs,
138
+ )
139
+ elif up_block_type == "CrossAttnUpBlock3D":
140
+ if cross_attention_dim is None:
141
+ raise ValueError(
142
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
143
+ )
144
+ return CrossAttnUpBlock3D(
145
+ num_layers=num_layers,
146
+ in_channels=in_channels,
147
+ out_channels=out_channels,
148
+ prev_output_channel=prev_output_channel,
149
+ temb_channels=temb_channels,
150
+ add_upsample=add_upsample,
151
+ resnet_eps=resnet_eps,
152
+ resnet_act_fn=resnet_act_fn,
153
+ resnet_groups=resnet_groups,
154
+ cross_attention_dim=cross_attention_dim,
155
+ attn_num_head_channels=attn_num_head_channels,
156
+ dual_cross_attention=dual_cross_attention,
157
+ use_linear_projection=use_linear_projection,
158
+ only_cross_attention=only_cross_attention,
159
+ upcast_attention=upcast_attention,
160
+ resnet_time_scale_shift=resnet_time_scale_shift,
161
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
162
+ unet_use_temporal_attention=unet_use_temporal_attention,
163
+ use_inflated_groupnorm=use_inflated_groupnorm,
164
+ use_motion_module=use_motion_module,
165
+ motion_module_type=motion_module_type,
166
+ motion_module_kwargs=motion_module_kwargs,
167
+ )
168
+ raise ValueError(f"{up_block_type} does not exist.")
169
+
170
+
171
+ class UNetMidBlock3DCrossAttn(nn.Module):
172
+ def __init__(
173
+ self,
174
+ in_channels: int,
175
+ temb_channels: int,
176
+ dropout: float = 0.0,
177
+ num_layers: int = 1,
178
+ resnet_eps: float = 1e-6,
179
+ resnet_time_scale_shift: str = "default",
180
+ resnet_act_fn: str = "swish",
181
+ resnet_groups: int = 32,
182
+ resnet_pre_norm: bool = True,
183
+ attn_num_head_channels=1,
184
+ output_scale_factor=1.0,
185
+ cross_attention_dim=1280,
186
+ dual_cross_attention=False,
187
+ use_linear_projection=False,
188
+ upcast_attention=False,
189
+ unet_use_cross_frame_attention=None,
190
+ unet_use_temporal_attention=None,
191
+ use_inflated_groupnorm=None,
192
+ use_motion_module=None,
193
+ motion_module_type=None,
194
+ motion_module_kwargs=None,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.has_cross_attention = True
199
+ self.attn_num_head_channels = attn_num_head_channels
200
+ resnet_groups = (
201
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
202
+ )
203
+
204
+ # there is always at least one resnet
205
+ resnets = [
206
+ ResnetBlock3D(
207
+ in_channels=in_channels,
208
+ out_channels=in_channels,
209
+ temb_channels=temb_channels,
210
+ eps=resnet_eps,
211
+ groups=resnet_groups,
212
+ dropout=dropout,
213
+ time_embedding_norm=resnet_time_scale_shift,
214
+ non_linearity=resnet_act_fn,
215
+ output_scale_factor=output_scale_factor,
216
+ pre_norm=resnet_pre_norm,
217
+ use_inflated_groupnorm=use_inflated_groupnorm,
218
+ )
219
+ ]
220
+ attentions = []
221
+ motion_modules = []
222
+
223
+ for _ in range(num_layers):
224
+ if dual_cross_attention:
225
+ raise NotImplementedError
226
+ attentions.append(
227
+ Transformer3DModel(
228
+ attn_num_head_channels,
229
+ in_channels // attn_num_head_channels,
230
+ in_channels=in_channels,
231
+ num_layers=1,
232
+ cross_attention_dim=cross_attention_dim,
233
+ norm_num_groups=resnet_groups,
234
+ use_linear_projection=use_linear_projection,
235
+ upcast_attention=upcast_attention,
236
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
237
+ unet_use_temporal_attention=unet_use_temporal_attention,
238
+ )
239
+ )
240
+ motion_modules.append(
241
+ get_motion_module(
242
+ in_channels=in_channels,
243
+ motion_module_type=motion_module_type,
244
+ motion_module_kwargs=motion_module_kwargs,
245
+ )
246
+ if use_motion_module
247
+ else None
248
+ )
249
+ resnets.append(
250
+ ResnetBlock3D(
251
+ in_channels=in_channels,
252
+ out_channels=in_channels,
253
+ temb_channels=temb_channels,
254
+ eps=resnet_eps,
255
+ groups=resnet_groups,
256
+ dropout=dropout,
257
+ time_embedding_norm=resnet_time_scale_shift,
258
+ non_linearity=resnet_act_fn,
259
+ output_scale_factor=output_scale_factor,
260
+ pre_norm=resnet_pre_norm,
261
+ use_inflated_groupnorm=use_inflated_groupnorm,
262
+ )
263
+ )
264
+
265
+ self.attentions = nn.ModuleList(attentions)
266
+ self.resnets = nn.ModuleList(resnets)
267
+ self.motion_modules = nn.ModuleList(motion_modules)
268
+
269
+ def forward(
270
+ self,
271
+ hidden_states,
272
+ temb=None,
273
+ encoder_hidden_states=None,
274
+ attention_mask=None,
275
+ ):
276
+ hidden_states = self.resnets[0](hidden_states, temb)
277
+ for attn, resnet, motion_module in zip(
278
+ self.attentions, self.resnets[1:], self.motion_modules
279
+ ):
280
+ hidden_states = attn(
281
+ hidden_states,
282
+ encoder_hidden_states=encoder_hidden_states,
283
+ ).sample
284
+ hidden_states = (
285
+ motion_module(
286
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
287
+ )
288
+ if motion_module is not None
289
+ else hidden_states
290
+ )
291
+ hidden_states = resnet(hidden_states, temb)
292
+
293
+ return hidden_states
294
+
295
+
296
+ class CrossAttnDownBlock3D(nn.Module):
297
+ def __init__(
298
+ self,
299
+ in_channels: int,
300
+ out_channels: int,
301
+ temb_channels: int,
302
+ dropout: float = 0.0,
303
+ num_layers: int = 1,
304
+ resnet_eps: float = 1e-6,
305
+ resnet_time_scale_shift: str = "default",
306
+ resnet_act_fn: str = "swish",
307
+ resnet_groups: int = 32,
308
+ resnet_pre_norm: bool = True,
309
+ attn_num_head_channels=1,
310
+ cross_attention_dim=1280,
311
+ output_scale_factor=1.0,
312
+ downsample_padding=1,
313
+ add_downsample=True,
314
+ dual_cross_attention=False,
315
+ use_linear_projection=False,
316
+ only_cross_attention=False,
317
+ upcast_attention=False,
318
+ unet_use_cross_frame_attention=None,
319
+ unet_use_temporal_attention=None,
320
+ use_inflated_groupnorm=None,
321
+ use_motion_module=None,
322
+ motion_module_type=None,
323
+ motion_module_kwargs=None,
324
+ ):
325
+ super().__init__()
326
+ resnets = []
327
+ attentions = []
328
+ motion_modules = []
329
+
330
+ self.has_cross_attention = True
331
+ self.attn_num_head_channels = attn_num_head_channels
332
+
333
+ for i in range(num_layers):
334
+ in_channels = in_channels if i == 0 else out_channels
335
+ resnets.append(
336
+ ResnetBlock3D(
337
+ in_channels=in_channels,
338
+ out_channels=out_channels,
339
+ temb_channels=temb_channels,
340
+ eps=resnet_eps,
341
+ groups=resnet_groups,
342
+ dropout=dropout,
343
+ time_embedding_norm=resnet_time_scale_shift,
344
+ non_linearity=resnet_act_fn,
345
+ output_scale_factor=output_scale_factor,
346
+ pre_norm=resnet_pre_norm,
347
+ use_inflated_groupnorm=use_inflated_groupnorm,
348
+ )
349
+ )
350
+ if dual_cross_attention:
351
+ raise NotImplementedError
352
+ attentions.append(
353
+ Transformer3DModel(
354
+ attn_num_head_channels,
355
+ out_channels // attn_num_head_channels,
356
+ in_channels=out_channels,
357
+ num_layers=1,
358
+ cross_attention_dim=cross_attention_dim,
359
+ norm_num_groups=resnet_groups,
360
+ use_linear_projection=use_linear_projection,
361
+ only_cross_attention=only_cross_attention,
362
+ upcast_attention=upcast_attention,
363
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
364
+ unet_use_temporal_attention=unet_use_temporal_attention,
365
+ )
366
+ )
367
+ motion_modules.append(
368
+ get_motion_module(
369
+ in_channels=out_channels,
370
+ motion_module_type=motion_module_type,
371
+ motion_module_kwargs=motion_module_kwargs,
372
+ )
373
+ if use_motion_module
374
+ else None
375
+ )
376
+
377
+ self.attentions = nn.ModuleList(attentions)
378
+ self.resnets = nn.ModuleList(resnets)
379
+ self.motion_modules = nn.ModuleList(motion_modules)
380
+
381
+ if add_downsample:
382
+ self.downsamplers = nn.ModuleList(
383
+ [
384
+ Downsample3D(
385
+ out_channels,
386
+ use_conv=True,
387
+ out_channels=out_channels,
388
+ padding=downsample_padding,
389
+ name="op",
390
+ )
391
+ ]
392
+ )
393
+ else:
394
+ self.downsamplers = None
395
+
396
+ self.gradient_checkpointing = False
397
+
398
+ def forward(
399
+ self,
400
+ hidden_states,
401
+ temb=None,
402
+ encoder_hidden_states=None,
403
+ attention_mask=None,
404
+ ):
405
+ output_states = ()
406
+
407
+ for i, (resnet, attn, motion_module) in enumerate(
408
+ zip(self.resnets, self.attentions, self.motion_modules)
409
+ ):
410
+ # self.gradient_checkpointing = False
411
+ if self.training and self.gradient_checkpointing:
412
+
413
+ def create_custom_forward(module, return_dict=None):
414
+ def custom_forward(*inputs):
415
+ if return_dict is not None:
416
+ return module(*inputs, return_dict=return_dict)
417
+ else:
418
+ return module(*inputs)
419
+
420
+ return custom_forward
421
+
422
+ hidden_states = torch.utils.checkpoint.checkpoint(
423
+ create_custom_forward(resnet), hidden_states, temb
424
+ )
425
+ hidden_states = torch.utils.checkpoint.checkpoint(
426
+ create_custom_forward(attn, return_dict=False),
427
+ hidden_states,
428
+ encoder_hidden_states,
429
+ )[0]
430
+
431
+ # add motion module
432
+ hidden_states = (
433
+ motion_module(
434
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
435
+ )
436
+ if motion_module is not None
437
+ else hidden_states
438
+ )
439
+
440
+ else:
441
+ hidden_states = resnet(hidden_states, temb)
442
+ hidden_states = attn(
443
+ hidden_states,
444
+ encoder_hidden_states=encoder_hidden_states,
445
+ ).sample
446
+
447
+ # add motion module
448
+ hidden_states = (
449
+ motion_module(
450
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
451
+ )
452
+ if motion_module is not None
453
+ else hidden_states
454
+ )
455
+
456
+ output_states += (hidden_states,)
457
+
458
+ if self.downsamplers is not None:
459
+ for downsampler in self.downsamplers:
460
+ hidden_states = downsampler(hidden_states)
461
+
462
+ output_states += (hidden_states,)
463
+
464
+ return hidden_states, output_states
465
+
466
+
467
+ class DownBlock3D(nn.Module):
468
+ def __init__(
469
+ self,
470
+ in_channels: int,
471
+ out_channels: int,
472
+ temb_channels: int,
473
+ dropout: float = 0.0,
474
+ num_layers: int = 1,
475
+ resnet_eps: float = 1e-6,
476
+ resnet_time_scale_shift: str = "default",
477
+ resnet_act_fn: str = "swish",
478
+ resnet_groups: int = 32,
479
+ resnet_pre_norm: bool = True,
480
+ output_scale_factor=1.0,
481
+ add_downsample=True,
482
+ downsample_padding=1,
483
+ use_inflated_groupnorm=None,
484
+ use_motion_module=None,
485
+ motion_module_type=None,
486
+ motion_module_kwargs=None,
487
+ ):
488
+ super().__init__()
489
+ resnets = []
490
+ motion_modules = []
491
+
492
+ # use_motion_module = False
493
+ for i in range(num_layers):
494
+ in_channels = in_channels if i == 0 else out_channels
495
+ resnets.append(
496
+ ResnetBlock3D(
497
+ in_channels=in_channels,
498
+ out_channels=out_channels,
499
+ temb_channels=temb_channels,
500
+ eps=resnet_eps,
501
+ groups=resnet_groups,
502
+ dropout=dropout,
503
+ time_embedding_norm=resnet_time_scale_shift,
504
+ non_linearity=resnet_act_fn,
505
+ output_scale_factor=output_scale_factor,
506
+ pre_norm=resnet_pre_norm,
507
+ use_inflated_groupnorm=use_inflated_groupnorm,
508
+ )
509
+ )
510
+ motion_modules.append(
511
+ get_motion_module(
512
+ in_channels=out_channels,
513
+ motion_module_type=motion_module_type,
514
+ motion_module_kwargs=motion_module_kwargs,
515
+ )
516
+ if use_motion_module
517
+ else None
518
+ )
519
+
520
+ self.resnets = nn.ModuleList(resnets)
521
+ self.motion_modules = nn.ModuleList(motion_modules)
522
+
523
+ if add_downsample:
524
+ self.downsamplers = nn.ModuleList(
525
+ [
526
+ Downsample3D(
527
+ out_channels,
528
+ use_conv=True,
529
+ out_channels=out_channels,
530
+ padding=downsample_padding,
531
+ name="op",
532
+ )
533
+ ]
534
+ )
535
+ else:
536
+ self.downsamplers = None
537
+
538
+ self.gradient_checkpointing = False
539
+
540
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
541
+ output_states = ()
542
+
543
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
544
+ # print(f"DownBlock3D {self.gradient_checkpointing = }")
545
+ if self.training and self.gradient_checkpointing:
546
+
547
+ def create_custom_forward(module):
548
+ def custom_forward(*inputs):
549
+ return module(*inputs)
550
+
551
+ return custom_forward
552
+
553
+ hidden_states = torch.utils.checkpoint.checkpoint(
554
+ create_custom_forward(resnet), hidden_states, temb
555
+ )
556
+ if motion_module is not None:
557
+ hidden_states = torch.utils.checkpoint.checkpoint(
558
+ create_custom_forward(motion_module),
559
+ hidden_states.requires_grad_(),
560
+ temb,
561
+ encoder_hidden_states,
562
+ )
563
+ else:
564
+ hidden_states = resnet(hidden_states, temb)
565
+
566
+ # add motion module
567
+ hidden_states = (
568
+ motion_module(
569
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
570
+ )
571
+ if motion_module is not None
572
+ else hidden_states
573
+ )
574
+
575
+ output_states += (hidden_states,)
576
+
577
+ if self.downsamplers is not None:
578
+ for downsampler in self.downsamplers:
579
+ hidden_states = downsampler(hidden_states)
580
+
581
+ output_states += (hidden_states,)
582
+
583
+ return hidden_states, output_states
584
+
585
+
586
+ class CrossAttnUpBlock3D(nn.Module):
587
+ def __init__(
588
+ self,
589
+ in_channels: int,
590
+ out_channels: int,
591
+ prev_output_channel: int,
592
+ temb_channels: int,
593
+ dropout: float = 0.0,
594
+ num_layers: int = 1,
595
+ resnet_eps: float = 1e-6,
596
+ resnet_time_scale_shift: str = "default",
597
+ resnet_act_fn: str = "swish",
598
+ resnet_groups: int = 32,
599
+ resnet_pre_norm: bool = True,
600
+ attn_num_head_channels=1,
601
+ cross_attention_dim=1280,
602
+ output_scale_factor=1.0,
603
+ add_upsample=True,
604
+ dual_cross_attention=False,
605
+ use_linear_projection=False,
606
+ only_cross_attention=False,
607
+ upcast_attention=False,
608
+ unet_use_cross_frame_attention=None,
609
+ unet_use_temporal_attention=None,
610
+ use_motion_module=None,
611
+ use_inflated_groupnorm=None,
612
+ motion_module_type=None,
613
+ motion_module_kwargs=None,
614
+ ):
615
+ super().__init__()
616
+ resnets = []
617
+ attentions = []
618
+ motion_modules = []
619
+
620
+ self.has_cross_attention = True
621
+ self.attn_num_head_channels = attn_num_head_channels
622
+
623
+ for i in range(num_layers):
624
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
625
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
626
+
627
+ resnets.append(
628
+ ResnetBlock3D(
629
+ in_channels=resnet_in_channels + res_skip_channels,
630
+ out_channels=out_channels,
631
+ temb_channels=temb_channels,
632
+ eps=resnet_eps,
633
+ groups=resnet_groups,
634
+ dropout=dropout,
635
+ time_embedding_norm=resnet_time_scale_shift,
636
+ non_linearity=resnet_act_fn,
637
+ output_scale_factor=output_scale_factor,
638
+ pre_norm=resnet_pre_norm,
639
+ use_inflated_groupnorm=use_inflated_groupnorm,
640
+ )
641
+ )
642
+ if dual_cross_attention:
643
+ raise NotImplementedError
644
+ attentions.append(
645
+ Transformer3DModel(
646
+ attn_num_head_channels,
647
+ out_channels // attn_num_head_channels,
648
+ in_channels=out_channels,
649
+ num_layers=1,
650
+ cross_attention_dim=cross_attention_dim,
651
+ norm_num_groups=resnet_groups,
652
+ use_linear_projection=use_linear_projection,
653
+ only_cross_attention=only_cross_attention,
654
+ upcast_attention=upcast_attention,
655
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
656
+ unet_use_temporal_attention=unet_use_temporal_attention,
657
+ )
658
+ )
659
+ motion_modules.append(
660
+ get_motion_module(
661
+ in_channels=out_channels,
662
+ motion_module_type=motion_module_type,
663
+ motion_module_kwargs=motion_module_kwargs,
664
+ )
665
+ if use_motion_module
666
+ else None
667
+ )
668
+
669
+ self.attentions = nn.ModuleList(attentions)
670
+ self.resnets = nn.ModuleList(resnets)
671
+ self.motion_modules = nn.ModuleList(motion_modules)
672
+
673
+ if add_upsample:
674
+ self.upsamplers = nn.ModuleList(
675
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
676
+ )
677
+ else:
678
+ self.upsamplers = None
679
+
680
+ self.gradient_checkpointing = False
681
+
682
+ def forward(
683
+ self,
684
+ hidden_states,
685
+ res_hidden_states_tuple,
686
+ temb=None,
687
+ encoder_hidden_states=None,
688
+ upsample_size=None,
689
+ attention_mask=None,
690
+ ):
691
+ for i, (resnet, attn, motion_module) in enumerate(
692
+ zip(self.resnets, self.attentions, self.motion_modules)
693
+ ):
694
+ # pop res hidden states
695
+ res_hidden_states = res_hidden_states_tuple[-1]
696
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
697
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
698
+
699
+ if self.training and self.gradient_checkpointing:
700
+
701
+ def create_custom_forward(module, return_dict=None):
702
+ def custom_forward(*inputs):
703
+ if return_dict is not None:
704
+ return module(*inputs, return_dict=return_dict)
705
+ else:
706
+ return module(*inputs)
707
+
708
+ return custom_forward
709
+
710
+ hidden_states = torch.utils.checkpoint.checkpoint(
711
+ create_custom_forward(resnet), hidden_states, temb
712
+ )
713
+ hidden_states = attn(
714
+ hidden_states,
715
+ encoder_hidden_states=encoder_hidden_states,
716
+ ).sample
717
+ if motion_module is not None:
718
+ hidden_states = torch.utils.checkpoint.checkpoint(
719
+ create_custom_forward(motion_module),
720
+ hidden_states.requires_grad_(),
721
+ temb,
722
+ encoder_hidden_states,
723
+ )
724
+
725
+ else:
726
+ hidden_states = resnet(hidden_states, temb)
727
+ hidden_states = attn(
728
+ hidden_states,
729
+ encoder_hidden_states=encoder_hidden_states,
730
+ ).sample
731
+
732
+ # add motion module
733
+ hidden_states = (
734
+ motion_module(
735
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
736
+ )
737
+ if motion_module is not None
738
+ else hidden_states
739
+ )
740
+
741
+ if self.upsamplers is not None:
742
+ for upsampler in self.upsamplers:
743
+ hidden_states = upsampler(hidden_states, upsample_size)
744
+
745
+ return hidden_states
746
+
747
+ class UpBlock3D(nn.Module):
748
+ def __init__(
749
+ self,
750
+ in_channels: int,
751
+ prev_output_channel: int,
752
+ out_channels: int,
753
+ temb_channels: int,
754
+ dropout: float = 0.0,
755
+ num_layers: int = 1,
756
+ resnet_eps: float = 1e-6,
757
+ resnet_time_scale_shift: str = "default",
758
+ resnet_act_fn: str = "swish",
759
+ resnet_groups: int = 32,
760
+ resnet_pre_norm: bool = True,
761
+ output_scale_factor=1.0,
762
+ add_upsample=True,
763
+ use_inflated_groupnorm=None,
764
+ use_motion_module=None,
765
+ motion_module_type=None,
766
+ motion_module_kwargs=None,
767
+ ):
768
+ super().__init__()
769
+ resnets = []
770
+ motion_modules = []
771
+
772
+ # use_motion_module = False
773
+ for i in range(num_layers):
774
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
775
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
776
+
777
+ resnets.append(
778
+ ResnetBlock3D(
779
+ in_channels=resnet_in_channels + res_skip_channels,
780
+ out_channels=out_channels,
781
+ temb_channels=temb_channels,
782
+ eps=resnet_eps,
783
+ groups=resnet_groups,
784
+ dropout=dropout,
785
+ time_embedding_norm=resnet_time_scale_shift,
786
+ non_linearity=resnet_act_fn,
787
+ output_scale_factor=output_scale_factor,
788
+ pre_norm=resnet_pre_norm,
789
+ use_inflated_groupnorm=use_inflated_groupnorm,
790
+ )
791
+ )
792
+ motion_modules.append(
793
+ get_motion_module(
794
+ in_channels=out_channels,
795
+ motion_module_type=motion_module_type,
796
+ motion_module_kwargs=motion_module_kwargs,
797
+ )
798
+ if use_motion_module
799
+ else None
800
+ )
801
+
802
+ self.resnets = nn.ModuleList(resnets)
803
+ self.motion_modules = nn.ModuleList(motion_modules)
804
+
805
+ if add_upsample:
806
+ self.upsamplers = nn.ModuleList(
807
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
808
+ )
809
+ else:
810
+ self.upsamplers = None
811
+
812
+ self.gradient_checkpointing = False
813
+
814
+ def forward(
815
+ self,
816
+ hidden_states,
817
+ res_hidden_states_tuple,
818
+ temb=None,
819
+ upsample_size=None,
820
+ encoder_hidden_states=None,
821
+ ):
822
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
823
+ # pop res hidden states
824
+ res_hidden_states = res_hidden_states_tuple[-1]
825
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
826
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
827
+
828
+ # print(f"UpBlock3D {self.gradient_checkpointing = }")
829
+ if self.training and self.gradient_checkpointing:
830
+
831
+ def create_custom_forward(module):
832
+ def custom_forward(*inputs):
833
+ return module(*inputs)
834
+
835
+ return custom_forward
836
+
837
+ hidden_states = torch.utils.checkpoint.checkpoint(
838
+ create_custom_forward(resnet), hidden_states, temb
839
+ )
840
+ if motion_module is not None:
841
+ hidden_states = torch.utils.checkpoint.checkpoint(
842
+ create_custom_forward(motion_module),
843
+ hidden_states.requires_grad_(),
844
+ temb,
845
+ encoder_hidden_states,
846
+ )
847
+ else:
848
+ hidden_states = resnet(hidden_states, temb)
849
+ hidden_states = (
850
+ motion_module(
851
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
852
+ )
853
+ if motion_module is not None
854
+ else hidden_states
855
+ )
856
+
857
+ if self.upsamplers is not None:
858
+ for upsampler in self.upsamplers:
859
+ hidden_states = upsampler(hidden_states, upsample_size)
860
+
861
+ return hidden_states
src/pipelines/context.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TODO: Adapted from cli
2
+ from typing import Callable, List, Optional
3
+
4
+ import numpy as np
5
+
6
+
7
+ def ordered_halving(val):
8
+ bin_str = f"{val:064b}"
9
+ bin_flip = bin_str[::-1]
10
+ as_int = int(bin_flip, 2)
11
+
12
+ return as_int / (1 << 64)
13
+
14
+
15
+ def uniform(
16
+ step: int = ...,
17
+ num_steps: Optional[int] = None,
18
+ num_frames: int = ...,
19
+ context_size: Optional[int] = None,
20
+ context_stride: int = 3,
21
+ context_overlap: int = 4,
22
+ closed_loop: bool = True,
23
+ ):
24
+ if num_frames <= context_size:
25
+ yield list(range(num_frames))
26
+ return
27
+
28
+ context_stride = min(
29
+ context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30
+ )
31
+
32
+ for context_step in 1 << np.arange(context_stride):
33
+ pad = int(round(num_frames * ordered_halving(step)))
34
+ for j in range(
35
+ int(ordered_halving(step) * context_step) + pad,
36
+ num_frames + pad + (0 if closed_loop else -context_overlap),
37
+ (context_size * context_step - context_overlap),
38
+ ):
39
+ yield [
40
+ e % num_frames
41
+ for e in range(j, j + context_size * context_step, context_step)
42
+ ]
43
+
44
+
45
+ def get_context_scheduler(name: str) -> Callable:
46
+ if name == "uniform":
47
+ return uniform
48
+ else:
49
+ raise ValueError(f"Unknown context_overlap policy {name}")
50
+
51
+
52
+ def get_total_steps(
53
+ scheduler,
54
+ timesteps: List[int],
55
+ num_steps: Optional[int] = None,
56
+ num_frames: int = ...,
57
+ context_size: Optional[int] = None,
58
+ context_stride: int = 3,
59
+ context_overlap: int = 4,
60
+ closed_loop: bool = True,
61
+ ):
62
+ return sum(
63
+ len(
64
+ list(
65
+ scheduler(
66
+ i,
67
+ num_steps,
68
+ num_frames,
69
+ context_size,
70
+ context_stride,
71
+ context_overlap,
72
+ )
73
+ )
74
+ )
75
+ for i in range(len(timesteps))
76
+ )
src/pipelines/pipeline_pose2vid_long.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/pipelines/pipeline_animation.py
2
+ import inspect
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torchvision.transforms as transforms
10
+ from diffusers import DiffusionPipeline
11
+ from diffusers.image_processor import VaeImageProcessor
12
+ from diffusers.schedulers import (
13
+ DDIMScheduler,
14
+ DPMSolverMultistepScheduler,
15
+ EulerAncestralDiscreteScheduler,
16
+ EulerDiscreteScheduler,
17
+ LMSDiscreteScheduler,
18
+ PNDMScheduler,
19
+ )
20
+ from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging
21
+ from diffusers.utils.torch_utils import randn_tensor
22
+ from einops import rearrange
23
+ from tqdm import tqdm
24
+ from transformers import CLIPImageProcessor
25
+
26
+ from src.models.mutual_self_attention import ReferenceAttentionControl
27
+ from src.pipelines.context import get_context_scheduler
28
+ from src.pipelines.utils import get_tensor_interpolation_method
29
+
30
+
31
+ @dataclass
32
+ class Pose2VideoPipelineOutput(BaseOutput):
33
+ videos: Union[torch.Tensor, np.ndarray]
34
+
35
+
36
+ class Pose2VideoPipeline(DiffusionPipeline):
37
+ _optional_components = []
38
+
39
+ def __init__(
40
+ self,
41
+ vae,
42
+ image_encoder,
43
+ reference_unet,
44
+ denoising_unet,
45
+ pose_guider,
46
+ scheduler: Union[
47
+ DDIMScheduler,
48
+ PNDMScheduler,
49
+ LMSDiscreteScheduler,
50
+ EulerDiscreteScheduler,
51
+ EulerAncestralDiscreteScheduler,
52
+ DPMSolverMultistepScheduler,
53
+ ],
54
+ image_proj_model=None,
55
+ tokenizer=None,
56
+ text_encoder=None,
57
+ ):
58
+ super().__init__()
59
+
60
+ self.register_modules(
61
+ vae=vae,
62
+ image_encoder=image_encoder,
63
+ reference_unet=reference_unet,
64
+ denoising_unet=denoising_unet,
65
+ pose_guider=pose_guider,
66
+ scheduler=scheduler,
67
+ image_proj_model=image_proj_model,
68
+ tokenizer=tokenizer,
69
+ text_encoder=text_encoder,
70
+ )
71
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
72
+ self.clip_image_processor = CLIPImageProcessor()
73
+ self.ref_image_processor = VaeImageProcessor(
74
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
75
+ )
76
+ self.cond_image_processor = VaeImageProcessor(
77
+ vae_scale_factor=self.vae_scale_factor,
78
+ do_convert_rgb=True,
79
+ do_normalize=True,
80
+ )
81
+
82
+ def enable_vae_slicing(self):
83
+ self.vae.enable_slicing()
84
+
85
+ def disable_vae_slicing(self):
86
+ self.vae.disable_slicing()
87
+
88
+ def enable_sequential_cpu_offload(self, gpu_id=0):
89
+ if is_accelerate_available():
90
+ from accelerate import cpu_offload
91
+ else:
92
+ raise ImportError("Please install accelerate via `pip install accelerate`")
93
+
94
+ device = torch.device(f"cuda:{gpu_id}")
95
+
96
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
97
+ if cpu_offloaded_model is not None:
98
+ cpu_offload(cpu_offloaded_model, device)
99
+
100
+ @property
101
+ def _execution_device(self):
102
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
103
+ return self.device
104
+ for module in self.unet.modules():
105
+ if (
106
+ hasattr(module, "_hf_hook")
107
+ and hasattr(module._hf_hook, "execution_device")
108
+ and module._hf_hook.execution_device is not None
109
+ ):
110
+ return torch.device(module._hf_hook.execution_device)
111
+ return self.device
112
+
113
+ def decode_latents(self, latents):
114
+ video_length = latents.shape[2]
115
+ latents = 1 / 0.18215 * latents
116
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
117
+ # video = self.vae.decode(latents).sample
118
+ video = []
119
+ for frame_idx in tqdm(range(latents.shape[0])):
120
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
121
+ video = torch.cat(video)
122
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
123
+ video = (video / 2 + 0.5).clamp(0, 1)
124
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
125
+ video = video.cpu().float().numpy()
126
+ return video
127
+
128
+ def prepare_extra_step_kwargs(self, generator, eta):
129
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
130
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
131
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
132
+ # and should be between [0, 1]
133
+
134
+ accepts_eta = "eta" in set(
135
+ inspect.signature(self.scheduler.step).parameters.keys()
136
+ )
137
+ extra_step_kwargs = {}
138
+ if accepts_eta:
139
+ extra_step_kwargs["eta"] = eta
140
+
141
+ # check if the scheduler accepts generator
142
+ accepts_generator = "generator" in set(
143
+ inspect.signature(self.scheduler.step).parameters.keys()
144
+ )
145
+ if accepts_generator:
146
+ extra_step_kwargs["generator"] = generator
147
+ return extra_step_kwargs
148
+
149
+ def prepare_latents(
150
+ self,
151
+ batch_size,
152
+ num_channels_latents,
153
+ width,
154
+ height,
155
+ video_length,
156
+ dtype,
157
+ device,
158
+ generator,
159
+ latents=None,
160
+ ):
161
+ shape = (
162
+ batch_size,
163
+ num_channels_latents,
164
+ video_length,
165
+ height // self.vae_scale_factor,
166
+ width // self.vae_scale_factor,
167
+ )
168
+ if isinstance(generator, list) and len(generator) != batch_size:
169
+ raise ValueError(
170
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
171
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
172
+ )
173
+
174
+ if latents is None:
175
+ latents = randn_tensor(
176
+ shape, generator=generator, device=device, dtype=dtype
177
+ )
178
+ else:
179
+ latents = latents.to(device)
180
+
181
+ # scale the initial noise by the standard deviation required by the scheduler
182
+ latents = latents * self.scheduler.init_noise_sigma
183
+ return latents
184
+
185
+ def _encode_prompt(
186
+ self,
187
+ prompt,
188
+ device,
189
+ num_videos_per_prompt,
190
+ do_classifier_free_guidance,
191
+ negative_prompt,
192
+ ):
193
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
194
+
195
+ text_inputs = self.tokenizer(
196
+ prompt,
197
+ padding="max_length",
198
+ max_length=self.tokenizer.model_max_length,
199
+ truncation=True,
200
+ return_tensors="pt",
201
+ )
202
+ text_input_ids = text_inputs.input_ids
203
+ untruncated_ids = self.tokenizer(
204
+ prompt, padding="longest", return_tensors="pt"
205
+ ).input_ids
206
+
207
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
208
+ text_input_ids, untruncated_ids
209
+ ):
210
+ removed_text = self.tokenizer.batch_decode(
211
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
212
+ )
213
+
214
+ if (
215
+ hasattr(self.text_encoder.config, "use_attention_mask")
216
+ and self.text_encoder.config.use_attention_mask
217
+ ):
218
+ attention_mask = text_inputs.attention_mask.to(device)
219
+ else:
220
+ attention_mask = None
221
+
222
+ text_embeddings = self.text_encoder(
223
+ text_input_ids.to(device),
224
+ attention_mask=attention_mask,
225
+ )
226
+ text_embeddings = text_embeddings[0]
227
+
228
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
229
+ bs_embed, seq_len, _ = text_embeddings.shape
230
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
231
+ text_embeddings = text_embeddings.view(
232
+ bs_embed * num_videos_per_prompt, seq_len, -1
233
+ )
234
+
235
+ # get unconditional embeddings for classifier free guidance
236
+ if do_classifier_free_guidance:
237
+ uncond_tokens: List[str]
238
+ if negative_prompt is None:
239
+ uncond_tokens = [""] * batch_size
240
+ elif type(prompt) is not type(negative_prompt):
241
+ raise TypeError(
242
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
243
+ f" {type(prompt)}."
244
+ )
245
+ elif isinstance(negative_prompt, str):
246
+ uncond_tokens = [negative_prompt]
247
+ elif batch_size != len(negative_prompt):
248
+ raise ValueError(
249
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
250
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
251
+ " the batch size of `prompt`."
252
+ )
253
+ else:
254
+ uncond_tokens = negative_prompt
255
+
256
+ max_length = text_input_ids.shape[-1]
257
+ uncond_input = self.tokenizer(
258
+ uncond_tokens,
259
+ padding="max_length",
260
+ max_length=max_length,
261
+ truncation=True,
262
+ return_tensors="pt",
263
+ )
264
+
265
+ if (
266
+ hasattr(self.text_encoder.config, "use_attention_mask")
267
+ and self.text_encoder.config.use_attention_mask
268
+ ):
269
+ attention_mask = uncond_input.attention_mask.to(device)
270
+ else:
271
+ attention_mask = None
272
+
273
+ uncond_embeddings = self.text_encoder(
274
+ uncond_input.input_ids.to(device),
275
+ attention_mask=attention_mask,
276
+ )
277
+ uncond_embeddings = uncond_embeddings[0]
278
+
279
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
280
+ seq_len = uncond_embeddings.shape[1]
281
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
282
+ uncond_embeddings = uncond_embeddings.view(
283
+ batch_size * num_videos_per_prompt, seq_len, -1
284
+ )
285
+
286
+ # For classifier free guidance, we need to do two forward passes.
287
+ # Here we concatenate the unconditional and text embeddings into a single batch
288
+ # to avoid doing two forward passes
289
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
290
+
291
+ return text_embeddings
292
+
293
+ def interpolate_latents(
294
+ self, latents: torch.Tensor, interpolation_factor: int, device
295
+ ):
296
+ if interpolation_factor < 2:
297
+ return latents
298
+
299
+ new_latents = torch.zeros(
300
+ (
301
+ latents.shape[0],
302
+ latents.shape[1],
303
+ ((latents.shape[2] - 1) * interpolation_factor) + 1,
304
+ latents.shape[3],
305
+ latents.shape[4],
306
+ ),
307
+ device=latents.device,
308
+ dtype=latents.dtype,
309
+ )
310
+
311
+ org_video_length = latents.shape[2]
312
+ rate = [i / interpolation_factor for i in range(interpolation_factor)][1:]
313
+
314
+ new_index = 0
315
+
316
+ v0 = None
317
+ v1 = None
318
+
319
+ for i0, i1 in zip(range(org_video_length), range(org_video_length)[1:]):
320
+ v0 = latents[:, :, i0, :, :]
321
+ v1 = latents[:, :, i1, :, :]
322
+
323
+ new_latents[:, :, new_index, :, :] = v0
324
+ new_index += 1
325
+
326
+ for f in rate:
327
+ v = get_tensor_interpolation_method()(
328
+ v0.to(device=device), v1.to(device=device), f
329
+ )
330
+ new_latents[:, :, new_index, :, :] = v.to(latents.device)
331
+ new_index += 1
332
+
333
+ new_latents[:, :, new_index, :, :] = v1
334
+ new_index += 1
335
+
336
+ return new_latents
337
+
338
+ @torch.no_grad()
339
+ def __call__(
340
+ self,
341
+ ref_image,
342
+ pose_images,
343
+ ref_pose_image,
344
+ width,
345
+ height,
346
+ video_length,
347
+ num_inference_steps,
348
+ guidance_scale,
349
+ num_images_per_prompt=1,
350
+ eta: float = 0.0,
351
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
352
+ output_type: Optional[str] = "tensor",
353
+ return_dict: bool = True,
354
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
355
+ callback_steps: Optional[int] = 1,
356
+ context_schedule="uniform",
357
+ context_frames=16,
358
+ context_stride=1,
359
+ context_overlap=4,
360
+ context_batch_size=1,
361
+ interpolation_factor=1,
362
+ **kwargs,
363
+ ):
364
+ # Default height and width to unet
365
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
366
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
367
+
368
+ device = self._execution_device
369
+
370
+ do_classifier_free_guidance = guidance_scale > 1.0
371
+
372
+ # Prepare timesteps
373
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
374
+ timesteps = self.scheduler.timesteps
375
+
376
+ batch_size = 1
377
+
378
+ # Prepare clip image embeds
379
+ clip_image = self.clip_image_processor.preprocess(
380
+ ref_image.resize((224, 224)), return_tensors="pt"
381
+ ).pixel_values
382
+ clip_image_embeds = self.image_encoder(
383
+ clip_image.to(device, dtype=self.image_encoder.dtype)
384
+ ).image_embeds
385
+ encoder_hidden_states = clip_image_embeds.unsqueeze(1)
386
+ uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
387
+
388
+ if do_classifier_free_guidance:
389
+ encoder_hidden_states = torch.cat(
390
+ [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
391
+ )
392
+
393
+ reference_control_writer = ReferenceAttentionControl(
394
+ self.reference_unet,
395
+ do_classifier_free_guidance=do_classifier_free_guidance,
396
+ mode="write",
397
+ batch_size=batch_size,
398
+ fusion_blocks="full",
399
+ )
400
+ reference_control_reader = ReferenceAttentionControl(
401
+ self.denoising_unet,
402
+ do_classifier_free_guidance=do_classifier_free_guidance,
403
+ mode="read",
404
+ batch_size=batch_size,
405
+ fusion_blocks="full",
406
+ )
407
+
408
+ num_channels_latents = self.denoising_unet.in_channels
409
+ latents = self.prepare_latents(
410
+ batch_size * num_images_per_prompt,
411
+ num_channels_latents,
412
+ width,
413
+ height,
414
+ video_length,
415
+ clip_image_embeds.dtype,
416
+ device,
417
+ generator,
418
+ )
419
+
420
+ # Prepare extra step kwargs.
421
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
422
+
423
+ # Prepare ref image latents
424
+ ref_image_tensor = self.ref_image_processor.preprocess(
425
+ ref_image, height=height, width=width
426
+ ) # (bs, c, width, height)
427
+ ref_image_tensor = ref_image_tensor.to(
428
+ dtype=self.vae.dtype, device=self.vae.device
429
+ )
430
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
431
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
432
+
433
+ # Prepare a list of pose condition images
434
+ pose_cond_tensor_list = []
435
+ for pose_image in pose_images:
436
+ pose_cond_tensor = self.cond_image_processor.preprocess(
437
+ pose_image, height=height, width=width
438
+ )
439
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
440
+ pose_cond_tensor_list.append(pose_cond_tensor)
441
+ pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=2) # (bs, c, t, h, w)
442
+
443
+ pose_cond_tensor = pose_cond_tensor.to(
444
+ device=device, dtype=self.pose_guider.dtype
445
+ )
446
+
447
+ ref_pose_tensor = self.cond_image_processor.preprocess(
448
+ ref_pose_image, height=height, width=width
449
+ )
450
+ ref_pose_tensor = ref_pose_tensor.to(
451
+ device=device, dtype=self.pose_guider.dtype
452
+ )
453
+
454
+ context_scheduler = get_context_scheduler(context_schedule)
455
+
456
+ # denoising loop
457
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
458
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
459
+ for i, t in enumerate(timesteps):
460
+ noise_pred = torch.zeros(
461
+ (
462
+ latents.shape[0] * (2 if do_classifier_free_guidance else 1),
463
+ *latents.shape[1:],
464
+ ),
465
+ device=latents.device,
466
+ dtype=latents.dtype,
467
+ )
468
+ counter = torch.zeros(
469
+ (1, 1, latents.shape[2], 1, 1),
470
+ device=latents.device,
471
+ dtype=latents.dtype,
472
+ )
473
+
474
+ # 1. Forward reference image
475
+ if i == 0:
476
+ self.reference_unet(
477
+ ref_image_latents.repeat(
478
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
479
+ ),
480
+ torch.zeros_like(t),
481
+ # t,
482
+ encoder_hidden_states=encoder_hidden_states,
483
+ return_dict=False,
484
+ )
485
+ reference_control_reader.update(reference_control_writer)
486
+
487
+ context_queue = list(
488
+ context_scheduler(
489
+ 0,
490
+ num_inference_steps,
491
+ latents.shape[2],
492
+ context_frames,
493
+ context_stride,
494
+ 0,
495
+ )
496
+ )
497
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
498
+
499
+ context_queue = list(
500
+ context_scheduler(
501
+ 0,
502
+ num_inference_steps,
503
+ latents.shape[2],
504
+ context_frames,
505
+ context_stride,
506
+ context_overlap,
507
+ )
508
+ )
509
+
510
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
511
+ global_context = []
512
+ for i in range(num_context_batches):
513
+ global_context.append(
514
+ context_queue[
515
+ i * context_batch_size : (i + 1) * context_batch_size
516
+ ]
517
+ )
518
+
519
+ for context in global_context:
520
+ # 3.1 expand the latents if we are doing classifier free guidance
521
+ latent_model_input = (
522
+ torch.cat([latents[:, :, c] for c in context])
523
+ .to(device)
524
+ .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
525
+ )
526
+ latent_model_input = self.scheduler.scale_model_input(
527
+ latent_model_input, t
528
+ )
529
+ b, c, f, h, w = latent_model_input.shape
530
+
531
+ pose_cond_input = (
532
+ torch.cat([pose_cond_tensor[:, :, c] for c in context])
533
+ .to(device)
534
+ .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
535
+ )
536
+ pose_fea = self.pose_guider(pose_cond_input, ref_pose_tensor)
537
+
538
+ pred = self.denoising_unet(
539
+ latent_model_input,
540
+ t,
541
+ encoder_hidden_states=encoder_hidden_states[:b],
542
+ pose_cond_fea=pose_fea,
543
+ return_dict=False,
544
+ )[0]
545
+
546
+ for j, c in enumerate(context):
547
+ noise_pred[:, :, c] = noise_pred[:, :, c] + pred
548
+ counter[:, :, c] = counter[:, :, c] + 1
549
+
550
+ # perform guidance
551
+ if do_classifier_free_guidance:
552
+ noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)
553
+ noise_pred = noise_pred_uncond + guidance_scale * (
554
+ noise_pred_text - noise_pred_uncond
555
+ )
556
+
557
+ latents = self.scheduler.step(
558
+ noise_pred, t, latents, **extra_step_kwargs
559
+ ).prev_sample
560
+
561
+ if i == len(timesteps) - 1 or (
562
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
563
+ ):
564
+ progress_bar.update()
565
+ if callback is not None and i % callback_steps == 0:
566
+ step_idx = i // getattr(self.scheduler, "order", 1)
567
+ callback(step_idx, t, latents)
568
+
569
+ reference_control_reader.clear()
570
+ reference_control_writer.clear()
571
+
572
+ if interpolation_factor > 0:
573
+ latents = self.interpolate_latents(latents, interpolation_factor, device)
574
+ # Post-processing
575
+ images = self.decode_latents(latents) # (b, c, f, h, w)
576
+
577
+ # Convert to tensor
578
+ if output_type == "tensor":
579
+ images = torch.from_numpy(images)
580
+
581
+ if not return_dict:
582
+ return images
583
+
584
+ return Pose2VideoPipelineOutput(videos=images)
src/pipelines/utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ tensor_interpolation = None
4
+
5
+
6
+ def get_tensor_interpolation_method():
7
+ return tensor_interpolation
8
+
9
+
10
+ def set_tensor_interpolation_method(is_slerp):
11
+ global tensor_interpolation
12
+ tensor_interpolation = slerp if is_slerp else linear
13
+
14
+
15
+ def linear(v1, v2, t):
16
+ return (1.0 - t) * v1 + t * v2
17
+
18
+
19
+ def slerp(
20
+ v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
21
+ ) -> torch.Tensor:
22
+ u0 = v0 / v0.norm()
23
+ u1 = v1 / v1.norm()
24
+ dot = (u0 * u1).sum()
25
+ if dot.abs() > DOT_THRESHOLD:
26
+ # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
27
+ return (1.0 - t) * v0 + t * v1
28
+ omega = dot.acos()
29
+ return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
src/utils/audio_util.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+
4
+ import librosa
5
+ import numpy as np
6
+ from transformers import Wav2Vec2FeatureExtractor
7
+
8
+
9
+ class DataProcessor:
10
+ def __init__(self, sampling_rate, wav2vec_model_path):
11
+ self._processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
12
+ self._sampling_rate = sampling_rate
13
+
14
+ def extract_feature(self, audio_path):
15
+ speech_array, sampling_rate = librosa.load(audio_path, sr=self._sampling_rate)
16
+ input_value = np.squeeze(self._processor(speech_array, sampling_rate=sampling_rate).input_values)
17
+ return input_value
18
+
19
+
20
+ def prepare_audio_feature(wav_file, fps=30, sampling_rate=16000, wav2vec_model_path=None):
21
+ data_preprocessor = DataProcessor(sampling_rate, wav2vec_model_path)
22
+
23
+ input_value = data_preprocessor.extract_feature(wav_file)
24
+ seq_len = math.ceil(len(input_value)/sampling_rate*fps)
25
+ return {
26
+ "audio_feature": input_value,
27
+ "seq_len": seq_len
28
+ }
29
+
30
+
src/utils/draw_util.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import mediapipe as mp
3
+ import numpy as np
4
+ from mediapipe.framework.formats import landmark_pb2
5
+
6
+ class FaceMeshVisualizer:
7
+ def __init__(self, forehead_edge=False):
8
+ self.mp_drawing = mp.solutions.drawing_utils
9
+ mp_face_mesh = mp.solutions.face_mesh
10
+ self.mp_face_mesh = mp_face_mesh
11
+ self.forehead_edge = forehead_edge
12
+
13
+ DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
14
+ f_thick = 2
15
+ f_rad = 1
16
+ right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
17
+ right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
18
+ right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
19
+ left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
20
+ left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
21
+ left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
22
+ head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
23
+
24
+ mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad)
25
+ mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad)
26
+
27
+ mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad)
28
+ mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad)
29
+
30
+ mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad)
31
+ mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad)
32
+
33
+ mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad)
34
+ mouth_draw_itr = DrawingSpec(color=(150 ,120, 100), thickness=f_thick, circle_radius=f_rad)
35
+
36
+ FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)]
37
+ FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)]
38
+
39
+ FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)]
40
+ FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)]
41
+
42
+ FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)]
43
+ FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)]
44
+
45
+ FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)]
46
+ FACEMESH_LIPS_INNER_TOP_RIGHT = [(13,312),(312,311),(311,310),(310,415),(415,308)]
47
+
48
+ FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), (397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)]
49
+
50
+ # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
51
+ face_connection_spec = {}
52
+ if self.forehead_edge:
53
+ for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
54
+ face_connection_spec[edge] = head_draw
55
+ else:
56
+ for edge in FACEMESH_CUSTOM_FACE_OVAL:
57
+ face_connection_spec[edge] = head_draw
58
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
59
+ face_connection_spec[edge] = left_eye_draw
60
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
61
+ face_connection_spec[edge] = left_eyebrow_draw
62
+ # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
63
+ # face_connection_spec[edge] = left_iris_draw
64
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
65
+ face_connection_spec[edge] = right_eye_draw
66
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
67
+ face_connection_spec[edge] = right_eyebrow_draw
68
+ # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
69
+ # face_connection_spec[edge] = right_iris_draw
70
+ # for edge in mp_face_mesh.FACEMESH_LIPS:
71
+ # face_connection_spec[edge] = mouth_draw
72
+
73
+ for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT:
74
+ face_connection_spec[edge] = mouth_draw_obl
75
+ for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT:
76
+ face_connection_spec[edge] = mouth_draw_obr
77
+ for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT:
78
+ face_connection_spec[edge] = mouth_draw_ibl
79
+ for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT:
80
+ face_connection_spec[edge] = mouth_draw_ibr
81
+ for edge in FACEMESH_LIPS_OUTER_TOP_LEFT:
82
+ face_connection_spec[edge] = mouth_draw_otl
83
+ for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT:
84
+ face_connection_spec[edge] = mouth_draw_otr
85
+ for edge in FACEMESH_LIPS_INNER_TOP_LEFT:
86
+ face_connection_spec[edge] = mouth_draw_itl
87
+ for edge in FACEMESH_LIPS_INNER_TOP_RIGHT:
88
+ face_connection_spec[edge] = mouth_draw_itr
89
+
90
+
91
+ iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
92
+
93
+ self.face_connection_spec = face_connection_spec
94
+ def draw_pupils(self, image, landmark_list, drawing_spec, halfwidth: int = 2):
95
+ """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
96
+ landmarks. Until our PR is merged into mediapipe, we need this separate method."""
97
+ if len(image.shape) != 3:
98
+ raise ValueError("Input image must be H,W,C.")
99
+ image_rows, image_cols, image_channels = image.shape
100
+ if image_channels != 3: # BGR channels
101
+ raise ValueError('Input image must contain three channel bgr data.')
102
+ for idx, landmark in enumerate(landmark_list.landmark):
103
+ if (
104
+ (landmark.HasField('visibility') and landmark.visibility < 0.9) or
105
+ (landmark.HasField('presence') and landmark.presence < 0.5)
106
+ ):
107
+ continue
108
+ if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
109
+ continue
110
+ image_x = int(image_cols*landmark.x)
111
+ image_y = int(image_rows*landmark.y)
112
+ draw_color = None
113
+ if isinstance(drawing_spec, Mapping):
114
+ if drawing_spec.get(idx) is None:
115
+ continue
116
+ else:
117
+ draw_color = drawing_spec[idx].color
118
+ elif isinstance(drawing_spec, DrawingSpec):
119
+ draw_color = drawing_spec.color
120
+ image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color
121
+
122
+
123
+
124
+ def draw_landmarks(self, image_size, keypoints, normed=False):
125
+ ini_size = [512, 512]
126
+ image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
127
+ new_landmarks = landmark_pb2.NormalizedLandmarkList()
128
+ for i in range(keypoints.shape[0]):
129
+ landmark = new_landmarks.landmark.add()
130
+ if normed:
131
+ landmark.x = keypoints[i, 0]
132
+ landmark.y = keypoints[i, 1]
133
+ else:
134
+ landmark.x = keypoints[i, 0] / image_size[0]
135
+ landmark.y = keypoints[i, 1] / image_size[1]
136
+ landmark.z = 1.0
137
+
138
+ self.mp_drawing.draw_landmarks(
139
+ image=image,
140
+ landmark_list=new_landmarks,
141
+ connections=self.face_connection_spec.keys(),
142
+ landmark_drawing_spec=None,
143
+ connection_drawing_spec=self.face_connection_spec
144
+ )
145
+ # draw_pupils(image, face_landmarks, iris_landmark_spec, 2)
146
+ image = cv2.resize(image, (image_size[0], image_size[1]))
147
+
148
+ return image
149
+
src/utils/face_landmark.py ADDED
@@ -0,0 +1,3305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The MediaPipe Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """MediaPipe face landmarker task."""
15
+
16
+ import dataclasses
17
+ import enum
18
+ from typing import Callable, Mapping, Optional, List
19
+
20
+ import numpy as np
21
+
22
+ from mediapipe.framework.formats import classification_pb2
23
+ from mediapipe.framework.formats import landmark_pb2
24
+ from mediapipe.framework.formats import matrix_data_pb2
25
+ from mediapipe.python import packet_creator
26
+ from mediapipe.python import packet_getter
27
+ from mediapipe.python._framework_bindings import image as image_module
28
+ from mediapipe.python._framework_bindings import packet as packet_module
29
+ # pylint: disable=unused-import
30
+ from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2
31
+ # pylint: enable=unused-import
32
+ from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_graph_options_pb2
33
+ from mediapipe.tasks.python.components.containers import category as category_module
34
+ from mediapipe.tasks.python.components.containers import landmark as landmark_module
35
+ from mediapipe.tasks.python.core import base_options as base_options_module
36
+ from mediapipe.tasks.python.core import task_info as task_info_module
37
+ from mediapipe.tasks.python.core.optional_dependencies import doc_controls
38
+ from mediapipe.tasks.python.vision.core import base_vision_task_api
39
+ from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
40
+ from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
41
+
42
+ _BaseOptions = base_options_module.BaseOptions
43
+ _FaceLandmarkerGraphOptionsProto = (
44
+ face_landmarker_graph_options_pb2.FaceLandmarkerGraphOptions
45
+ )
46
+ _LayoutEnum = matrix_data_pb2.MatrixData.Layout
47
+ _RunningMode = running_mode_module.VisionTaskRunningMode
48
+ _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
49
+ _TaskInfo = task_info_module.TaskInfo
50
+
51
+ _IMAGE_IN_STREAM_NAME = 'image_in'
52
+ _IMAGE_OUT_STREAM_NAME = 'image_out'
53
+ _IMAGE_TAG = 'IMAGE'
54
+ _NORM_RECT_STREAM_NAME = 'norm_rect_in'
55
+ _NORM_RECT_TAG = 'NORM_RECT'
56
+ _NORM_LANDMARKS_STREAM_NAME = 'norm_landmarks'
57
+ _NORM_LANDMARKS_TAG = 'NORM_LANDMARKS'
58
+ _BLENDSHAPES_STREAM_NAME = 'blendshapes'
59
+ _BLENDSHAPES_TAG = 'BLENDSHAPES'
60
+ _FACE_GEOMETRY_STREAM_NAME = 'face_geometry'
61
+ _FACE_GEOMETRY_TAG = 'FACE_GEOMETRY'
62
+ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph'
63
+ _MICRO_SECONDS_PER_MILLISECOND = 1000
64
+
65
+
66
+ class Blendshapes(enum.IntEnum):
67
+ """The 52 blendshape coefficients."""
68
+
69
+ NEUTRAL = 0
70
+ BROW_DOWN_LEFT = 1
71
+ BROW_DOWN_RIGHT = 2
72
+ BROW_INNER_UP = 3
73
+ BROW_OUTER_UP_LEFT = 4
74
+ BROW_OUTER_UP_RIGHT = 5
75
+ CHEEK_PUFF = 6
76
+ CHEEK_SQUINT_LEFT = 7
77
+ CHEEK_SQUINT_RIGHT = 8
78
+ EYE_BLINK_LEFT = 9
79
+ EYE_BLINK_RIGHT = 10
80
+ EYE_LOOK_DOWN_LEFT = 11
81
+ EYE_LOOK_DOWN_RIGHT = 12
82
+ EYE_LOOK_IN_LEFT = 13
83
+ EYE_LOOK_IN_RIGHT = 14
84
+ EYE_LOOK_OUT_LEFT = 15
85
+ EYE_LOOK_OUT_RIGHT = 16
86
+ EYE_LOOK_UP_LEFT = 17
87
+ EYE_LOOK_UP_RIGHT = 18
88
+ EYE_SQUINT_LEFT = 19
89
+ EYE_SQUINT_RIGHT = 20
90
+ EYE_WIDE_LEFT = 21
91
+ EYE_WIDE_RIGHT = 22
92
+ JAW_FORWARD = 23
93
+ JAW_LEFT = 24
94
+ JAW_OPEN = 25
95
+ JAW_RIGHT = 26
96
+ MOUTH_CLOSE = 27
97
+ MOUTH_DIMPLE_LEFT = 28
98
+ MOUTH_DIMPLE_RIGHT = 29
99
+ MOUTH_FROWN_LEFT = 30
100
+ MOUTH_FROWN_RIGHT = 31
101
+ MOUTH_FUNNEL = 32
102
+ MOUTH_LEFT = 33
103
+ MOUTH_LOWER_DOWN_LEFT = 34
104
+ MOUTH_LOWER_DOWN_RIGHT = 35
105
+ MOUTH_PRESS_LEFT = 36
106
+ MOUTH_PRESS_RIGHT = 37
107
+ MOUTH_PUCKER = 38
108
+ MOUTH_RIGHT = 39
109
+ MOUTH_ROLL_LOWER = 40
110
+ MOUTH_ROLL_UPPER = 41
111
+ MOUTH_SHRUG_LOWER = 42
112
+ MOUTH_SHRUG_UPPER = 43
113
+ MOUTH_SMILE_LEFT = 44
114
+ MOUTH_SMILE_RIGHT = 45
115
+ MOUTH_STRETCH_LEFT = 46
116
+ MOUTH_STRETCH_RIGHT = 47
117
+ MOUTH_UPPER_UP_LEFT = 48
118
+ MOUTH_UPPER_UP_RIGHT = 49
119
+ NOSE_SNEER_LEFT = 50
120
+ NOSE_SNEER_RIGHT = 51
121
+
122
+
123
+ class FaceLandmarksConnections:
124
+ """The connections between face landmarks."""
125
+
126
+ @dataclasses.dataclass
127
+ class Connection:
128
+ """The connection class for face landmarks."""
129
+
130
+ start: int
131
+ end: int
132
+
133
+ FACE_LANDMARKS_LIPS: List[Connection] = [
134
+ Connection(61, 146),
135
+ Connection(146, 91),
136
+ Connection(91, 181),
137
+ Connection(181, 84),
138
+ Connection(84, 17),
139
+ Connection(17, 314),
140
+ Connection(314, 405),
141
+ Connection(405, 321),
142
+ Connection(321, 375),
143
+ Connection(375, 291),
144
+ Connection(61, 185),
145
+ Connection(185, 40),
146
+ Connection(40, 39),
147
+ Connection(39, 37),
148
+ Connection(37, 0),
149
+ Connection(0, 267),
150
+ Connection(267, 269),
151
+ Connection(269, 270),
152
+ Connection(270, 409),
153
+ Connection(409, 291),
154
+ Connection(78, 95),
155
+ Connection(95, 88),
156
+ Connection(88, 178),
157
+ Connection(178, 87),
158
+ Connection(87, 14),
159
+ Connection(14, 317),
160
+ Connection(317, 402),
161
+ Connection(402, 318),
162
+ Connection(318, 324),
163
+ Connection(324, 308),
164
+ Connection(78, 191),
165
+ Connection(191, 80),
166
+ Connection(80, 81),
167
+ Connection(81, 82),
168
+ Connection(82, 13),
169
+ Connection(13, 312),
170
+ Connection(312, 311),
171
+ Connection(311, 310),
172
+ Connection(310, 415),
173
+ Connection(415, 308),
174
+ ]
175
+
176
+ FACE_LANDMARKS_LEFT_EYE: List[Connection] = [
177
+ Connection(263, 249),
178
+ Connection(249, 390),
179
+ Connection(390, 373),
180
+ Connection(373, 374),
181
+ Connection(374, 380),
182
+ Connection(380, 381),
183
+ Connection(381, 382),
184
+ Connection(382, 362),
185
+ Connection(263, 466),
186
+ Connection(466, 388),
187
+ Connection(388, 387),
188
+ Connection(387, 386),
189
+ Connection(386, 385),
190
+ Connection(385, 384),
191
+ Connection(384, 398),
192
+ Connection(398, 362),
193
+ ]
194
+
195
+ FACE_LANDMARKS_LEFT_EYEBROW: List[Connection] = [
196
+ Connection(276, 283),
197
+ Connection(283, 282),
198
+ Connection(282, 295),
199
+ Connection(295, 285),
200
+ Connection(300, 293),
201
+ Connection(293, 334),
202
+ Connection(334, 296),
203
+ Connection(296, 336),
204
+ ]
205
+
206
+ FACE_LANDMARKS_LEFT_IRIS: List[Connection] = [
207
+ Connection(474, 475),
208
+ Connection(475, 476),
209
+ Connection(476, 477),
210
+ Connection(477, 474),
211
+ ]
212
+
213
+ FACE_LANDMARKS_RIGHT_EYE: List[Connection] = [
214
+ Connection(33, 7),
215
+ Connection(7, 163),
216
+ Connection(163, 144),
217
+ Connection(144, 145),
218
+ Connection(145, 153),
219
+ Connection(153, 154),
220
+ Connection(154, 155),
221
+ Connection(155, 133),
222
+ Connection(33, 246),
223
+ Connection(246, 161),
224
+ Connection(161, 160),
225
+ Connection(160, 159),
226
+ Connection(159, 158),
227
+ Connection(158, 157),
228
+ Connection(157, 173),
229
+ Connection(173, 133),
230
+ ]
231
+
232
+ FACE_LANDMARKS_RIGHT_EYEBROW: List[Connection] = [
233
+ Connection(46, 53),
234
+ Connection(53, 52),
235
+ Connection(52, 65),
236
+ Connection(65, 55),
237
+ Connection(70, 63),
238
+ Connection(63, 105),
239
+ Connection(105, 66),
240
+ Connection(66, 107),
241
+ ]
242
+
243
+ FACE_LANDMARKS_RIGHT_IRIS: List[Connection] = [
244
+ Connection(469, 470),
245
+ Connection(470, 471),
246
+ Connection(471, 472),
247
+ Connection(472, 469),
248
+ ]
249
+
250
+ FACE_LANDMARKS_FACE_OVAL: List[Connection] = [
251
+ Connection(10, 338),
252
+ Connection(338, 297),
253
+ Connection(297, 332),
254
+ Connection(332, 284),
255
+ Connection(284, 251),
256
+ Connection(251, 389),
257
+ Connection(389, 356),
258
+ Connection(356, 454),
259
+ Connection(454, 323),
260
+ Connection(323, 361),
261
+ Connection(361, 288),
262
+ Connection(288, 397),
263
+ Connection(397, 365),
264
+ Connection(365, 379),
265
+ Connection(379, 378),
266
+ Connection(378, 400),
267
+ Connection(400, 377),
268
+ Connection(377, 152),
269
+ Connection(152, 148),
270
+ Connection(148, 176),
271
+ Connection(176, 149),
272
+ Connection(149, 150),
273
+ Connection(150, 136),
274
+ Connection(136, 172),
275
+ Connection(172, 58),
276
+ Connection(58, 132),
277
+ Connection(132, 93),
278
+ Connection(93, 234),
279
+ Connection(234, 127),
280
+ Connection(127, 162),
281
+ Connection(162, 21),
282
+ Connection(21, 54),
283
+ Connection(54, 103),
284
+ Connection(103, 67),
285
+ Connection(67, 109),
286
+ Connection(109, 10),
287
+ ]
288
+
289
+ FACE_LANDMARKS_CONTOURS: List[Connection] = (
290
+ FACE_LANDMARKS_LIPS
291
+ + FACE_LANDMARKS_LEFT_EYE
292
+ + FACE_LANDMARKS_LEFT_EYEBROW
293
+ + FACE_LANDMARKS_RIGHT_EYE
294
+ + FACE_LANDMARKS_RIGHT_EYEBROW
295
+ + FACE_LANDMARKS_FACE_OVAL
296
+ )
297
+
298
+ FACE_LANDMARKS_TESSELATION: List[Connection] = [
299
+ Connection(127, 34),
300
+ Connection(34, 139),
301
+ Connection(139, 127),
302
+ Connection(11, 0),
303
+ Connection(0, 37),
304
+ Connection(37, 11),
305
+ Connection(232, 231),
306
+ Connection(231, 120),
307
+ Connection(120, 232),
308
+ Connection(72, 37),
309
+ Connection(37, 39),
310
+ Connection(39, 72),
311
+ Connection(128, 121),
312
+ Connection(121, 47),
313
+ Connection(47, 128),
314
+ Connection(232, 121),
315
+ Connection(121, 128),
316
+ Connection(128, 232),
317
+ Connection(104, 69),
318
+ Connection(69, 67),
319
+ Connection(67, 104),
320
+ Connection(175, 171),
321
+ Connection(171, 148),
322
+ Connection(148, 175),
323
+ Connection(118, 50),
324
+ Connection(50, 101),
325
+ Connection(101, 118),
326
+ Connection(73, 39),
327
+ Connection(39, 40),
328
+ Connection(40, 73),
329
+ Connection(9, 151),
330
+ Connection(151, 108),
331
+ Connection(108, 9),
332
+ Connection(48, 115),
333
+ Connection(115, 131),
334
+ Connection(131, 48),
335
+ Connection(194, 204),
336
+ Connection(204, 211),
337
+ Connection(211, 194),
338
+ Connection(74, 40),
339
+ Connection(40, 185),
340
+ Connection(185, 74),
341
+ Connection(80, 42),
342
+ Connection(42, 183),
343
+ Connection(183, 80),
344
+ Connection(40, 92),
345
+ Connection(92, 186),
346
+ Connection(186, 40),
347
+ Connection(230, 229),
348
+ Connection(229, 118),
349
+ Connection(118, 230),
350
+ Connection(202, 212),
351
+ Connection(212, 214),
352
+ Connection(214, 202),
353
+ Connection(83, 18),
354
+ Connection(18, 17),
355
+ Connection(17, 83),
356
+ Connection(76, 61),
357
+ Connection(61, 146),
358
+ Connection(146, 76),
359
+ Connection(160, 29),
360
+ Connection(29, 30),
361
+ Connection(30, 160),
362
+ Connection(56, 157),
363
+ Connection(157, 173),
364
+ Connection(173, 56),
365
+ Connection(106, 204),
366
+ Connection(204, 194),
367
+ Connection(194, 106),
368
+ Connection(135, 214),
369
+ Connection(214, 192),
370
+ Connection(192, 135),
371
+ Connection(203, 165),
372
+ Connection(165, 98),
373
+ Connection(98, 203),
374
+ Connection(21, 71),
375
+ Connection(71, 68),
376
+ Connection(68, 21),
377
+ Connection(51, 45),
378
+ Connection(45, 4),
379
+ Connection(4, 51),
380
+ Connection(144, 24),
381
+ Connection(24, 23),
382
+ Connection(23, 144),
383
+ Connection(77, 146),
384
+ Connection(146, 91),
385
+ Connection(91, 77),
386
+ Connection(205, 50),
387
+ Connection(50, 187),
388
+ Connection(187, 205),
389
+ Connection(201, 200),
390
+ Connection(200, 18),
391
+ Connection(18, 201),
392
+ Connection(91, 106),
393
+ Connection(106, 182),
394
+ Connection(182, 91),
395
+ Connection(90, 91),
396
+ Connection(91, 181),
397
+ Connection(181, 90),
398
+ Connection(85, 84),
399
+ Connection(84, 17),
400
+ Connection(17, 85),
401
+ Connection(206, 203),
402
+ Connection(203, 36),
403
+ Connection(36, 206),
404
+ Connection(148, 171),
405
+ Connection(171, 140),
406
+ Connection(140, 148),
407
+ Connection(92, 40),
408
+ Connection(40, 39),
409
+ Connection(39, 92),
410
+ Connection(193, 189),
411
+ Connection(189, 244),
412
+ Connection(244, 193),
413
+ Connection(159, 158),
414
+ Connection(158, 28),
415
+ Connection(28, 159),
416
+ Connection(247, 246),
417
+ Connection(246, 161),
418
+ Connection(161, 247),
419
+ Connection(236, 3),
420
+ Connection(3, 196),
421
+ Connection(196, 236),
422
+ Connection(54, 68),
423
+ Connection(68, 104),
424
+ Connection(104, 54),
425
+ Connection(193, 168),
426
+ Connection(168, 8),
427
+ Connection(8, 193),
428
+ Connection(117, 228),
429
+ Connection(228, 31),
430
+ Connection(31, 117),
431
+ Connection(189, 193),
432
+ Connection(193, 55),
433
+ Connection(55, 189),
434
+ Connection(98, 97),
435
+ Connection(97, 99),
436
+ Connection(99, 98),
437
+ Connection(126, 47),
438
+ Connection(47, 100),
439
+ Connection(100, 126),
440
+ Connection(166, 79),
441
+ Connection(79, 218),
442
+ Connection(218, 166),
443
+ Connection(155, 154),
444
+ Connection(154, 26),
445
+ Connection(26, 155),
446
+ Connection(209, 49),
447
+ Connection(49, 131),
448
+ Connection(131, 209),
449
+ Connection(135, 136),
450
+ Connection(136, 150),
451
+ Connection(150, 135),
452
+ Connection(47, 126),
453
+ Connection(126, 217),
454
+ Connection(217, 47),
455
+ Connection(223, 52),
456
+ Connection(52, 53),
457
+ Connection(53, 223),
458
+ Connection(45, 51),
459
+ Connection(51, 134),
460
+ Connection(134, 45),
461
+ Connection(211, 170),
462
+ Connection(170, 140),
463
+ Connection(140, 211),
464
+ Connection(67, 69),
465
+ Connection(69, 108),
466
+ Connection(108, 67),
467
+ Connection(43, 106),
468
+ Connection(106, 91),
469
+ Connection(91, 43),
470
+ Connection(230, 119),
471
+ Connection(119, 120),
472
+ Connection(120, 230),
473
+ Connection(226, 130),
474
+ Connection(130, 247),
475
+ Connection(247, 226),
476
+ Connection(63, 53),
477
+ Connection(53, 52),
478
+ Connection(52, 63),
479
+ Connection(238, 20),
480
+ Connection(20, 242),
481
+ Connection(242, 238),
482
+ Connection(46, 70),
483
+ Connection(70, 156),
484
+ Connection(156, 46),
485
+ Connection(78, 62),
486
+ Connection(62, 96),
487
+ Connection(96, 78),
488
+ Connection(46, 53),
489
+ Connection(53, 63),
490
+ Connection(63, 46),
491
+ Connection(143, 34),
492
+ Connection(34, 227),
493
+ Connection(227, 143),
494
+ Connection(123, 117),
495
+ Connection(117, 111),
496
+ Connection(111, 123),
497
+ Connection(44, 125),
498
+ Connection(125, 19),
499
+ Connection(19, 44),
500
+ Connection(236, 134),
501
+ Connection(134, 51),
502
+ Connection(51, 236),
503
+ Connection(216, 206),
504
+ Connection(206, 205),
505
+ Connection(205, 216),
506
+ Connection(154, 153),
507
+ Connection(153, 22),
508
+ Connection(22, 154),
509
+ Connection(39, 37),
510
+ Connection(37, 167),
511
+ Connection(167, 39),
512
+ Connection(200, 201),
513
+ Connection(201, 208),
514
+ Connection(208, 200),
515
+ Connection(36, 142),
516
+ Connection(142, 100),
517
+ Connection(100, 36),
518
+ Connection(57, 212),
519
+ Connection(212, 202),
520
+ Connection(202, 57),
521
+ Connection(20, 60),
522
+ Connection(60, 99),
523
+ Connection(99, 20),
524
+ Connection(28, 158),
525
+ Connection(158, 157),
526
+ Connection(157, 28),
527
+ Connection(35, 226),
528
+ Connection(226, 113),
529
+ Connection(113, 35),
530
+ Connection(160, 159),
531
+ Connection(159, 27),
532
+ Connection(27, 160),
533
+ Connection(204, 202),
534
+ Connection(202, 210),
535
+ Connection(210, 204),
536
+ Connection(113, 225),
537
+ Connection(225, 46),
538
+ Connection(46, 113),
539
+ Connection(43, 202),
540
+ Connection(202, 204),
541
+ Connection(204, 43),
542
+ Connection(62, 76),
543
+ Connection(76, 77),
544
+ Connection(77, 62),
545
+ Connection(137, 123),
546
+ Connection(123, 116),
547
+ Connection(116, 137),
548
+ Connection(41, 38),
549
+ Connection(38, 72),
550
+ Connection(72, 41),
551
+ Connection(203, 129),
552
+ Connection(129, 142),
553
+ Connection(142, 203),
554
+ Connection(64, 98),
555
+ Connection(98, 240),
556
+ Connection(240, 64),
557
+ Connection(49, 102),
558
+ Connection(102, 64),
559
+ Connection(64, 49),
560
+ Connection(41, 73),
561
+ Connection(73, 74),
562
+ Connection(74, 41),
563
+ Connection(212, 216),
564
+ Connection(216, 207),
565
+ Connection(207, 212),
566
+ Connection(42, 74),
567
+ Connection(74, 184),
568
+ Connection(184, 42),
569
+ Connection(169, 170),
570
+ Connection(170, 211),
571
+ Connection(211, 169),
572
+ Connection(170, 149),
573
+ Connection(149, 176),
574
+ Connection(176, 170),
575
+ Connection(105, 66),
576
+ Connection(66, 69),
577
+ Connection(69, 105),
578
+ Connection(122, 6),
579
+ Connection(6, 168),
580
+ Connection(168, 122),
581
+ Connection(123, 147),
582
+ Connection(147, 187),
583
+ Connection(187, 123),
584
+ Connection(96, 77),
585
+ Connection(77, 90),
586
+ Connection(90, 96),
587
+ Connection(65, 55),
588
+ Connection(55, 107),
589
+ Connection(107, 65),
590
+ Connection(89, 90),
591
+ Connection(90, 180),
592
+ Connection(180, 89),
593
+ Connection(101, 100),
594
+ Connection(100, 120),
595
+ Connection(120, 101),
596
+ Connection(63, 105),
597
+ Connection(105, 104),
598
+ Connection(104, 63),
599
+ Connection(93, 137),
600
+ Connection(137, 227),
601
+ Connection(227, 93),
602
+ Connection(15, 86),
603
+ Connection(86, 85),
604
+ Connection(85, 15),
605
+ Connection(129, 102),
606
+ Connection(102, 49),
607
+ Connection(49, 129),
608
+ Connection(14, 87),
609
+ Connection(87, 86),
610
+ Connection(86, 14),
611
+ Connection(55, 8),
612
+ Connection(8, 9),
613
+ Connection(9, 55),
614
+ Connection(100, 47),
615
+ Connection(47, 121),
616
+ Connection(121, 100),
617
+ Connection(145, 23),
618
+ Connection(23, 22),
619
+ Connection(22, 145),
620
+ Connection(88, 89),
621
+ Connection(89, 179),
622
+ Connection(179, 88),
623
+ Connection(6, 122),
624
+ Connection(122, 196),
625
+ Connection(196, 6),
626
+ Connection(88, 95),
627
+ Connection(95, 96),
628
+ Connection(96, 88),
629
+ Connection(138, 172),
630
+ Connection(172, 136),
631
+ Connection(136, 138),
632
+ Connection(215, 58),
633
+ Connection(58, 172),
634
+ Connection(172, 215),
635
+ Connection(115, 48),
636
+ Connection(48, 219),
637
+ Connection(219, 115),
638
+ Connection(42, 80),
639
+ Connection(80, 81),
640
+ Connection(81, 42),
641
+ Connection(195, 3),
642
+ Connection(3, 51),
643
+ Connection(51, 195),
644
+ Connection(43, 146),
645
+ Connection(146, 61),
646
+ Connection(61, 43),
647
+ Connection(171, 175),
648
+ Connection(175, 199),
649
+ Connection(199, 171),
650
+ Connection(81, 82),
651
+ Connection(82, 38),
652
+ Connection(38, 81),
653
+ Connection(53, 46),
654
+ Connection(46, 225),
655
+ Connection(225, 53),
656
+ Connection(144, 163),
657
+ Connection(163, 110),
658
+ Connection(110, 144),
659
+ Connection(52, 65),
660
+ Connection(65, 66),
661
+ Connection(66, 52),
662
+ Connection(229, 228),
663
+ Connection(228, 117),
664
+ Connection(117, 229),
665
+ Connection(34, 127),
666
+ Connection(127, 234),
667
+ Connection(234, 34),
668
+ Connection(107, 108),
669
+ Connection(108, 69),
670
+ Connection(69, 107),
671
+ Connection(109, 108),
672
+ Connection(108, 151),
673
+ Connection(151, 109),
674
+ Connection(48, 64),
675
+ Connection(64, 235),
676
+ Connection(235, 48),
677
+ Connection(62, 78),
678
+ Connection(78, 191),
679
+ Connection(191, 62),
680
+ Connection(129, 209),
681
+ Connection(209, 126),
682
+ Connection(126, 129),
683
+ Connection(111, 35),
684
+ Connection(35, 143),
685
+ Connection(143, 111),
686
+ Connection(117, 123),
687
+ Connection(123, 50),
688
+ Connection(50, 117),
689
+ Connection(222, 65),
690
+ Connection(65, 52),
691
+ Connection(52, 222),
692
+ Connection(19, 125),
693
+ Connection(125, 141),
694
+ Connection(141, 19),
695
+ Connection(221, 55),
696
+ Connection(55, 65),
697
+ Connection(65, 221),
698
+ Connection(3, 195),
699
+ Connection(195, 197),
700
+ Connection(197, 3),
701
+ Connection(25, 7),
702
+ Connection(7, 33),
703
+ Connection(33, 25),
704
+ Connection(220, 237),
705
+ Connection(237, 44),
706
+ Connection(44, 220),
707
+ Connection(70, 71),
708
+ Connection(71, 139),
709
+ Connection(139, 70),
710
+ Connection(122, 193),
711
+ Connection(193, 245),
712
+ Connection(245, 122),
713
+ Connection(247, 130),
714
+ Connection(130, 33),
715
+ Connection(33, 247),
716
+ Connection(71, 21),
717
+ Connection(21, 162),
718
+ Connection(162, 71),
719
+ Connection(170, 169),
720
+ Connection(169, 150),
721
+ Connection(150, 170),
722
+ Connection(188, 174),
723
+ Connection(174, 196),
724
+ Connection(196, 188),
725
+ Connection(216, 186),
726
+ Connection(186, 92),
727
+ Connection(92, 216),
728
+ Connection(2, 97),
729
+ Connection(97, 167),
730
+ Connection(167, 2),
731
+ Connection(141, 125),
732
+ Connection(125, 241),
733
+ Connection(241, 141),
734
+ Connection(164, 167),
735
+ Connection(167, 37),
736
+ Connection(37, 164),
737
+ Connection(72, 38),
738
+ Connection(38, 12),
739
+ Connection(12, 72),
740
+ Connection(38, 82),
741
+ Connection(82, 13),
742
+ Connection(13, 38),
743
+ Connection(63, 68),
744
+ Connection(68, 71),
745
+ Connection(71, 63),
746
+ Connection(226, 35),
747
+ Connection(35, 111),
748
+ Connection(111, 226),
749
+ Connection(101, 50),
750
+ Connection(50, 205),
751
+ Connection(205, 101),
752
+ Connection(206, 92),
753
+ Connection(92, 165),
754
+ Connection(165, 206),
755
+ Connection(209, 198),
756
+ Connection(198, 217),
757
+ Connection(217, 209),
758
+ Connection(165, 167),
759
+ Connection(167, 97),
760
+ Connection(97, 165),
761
+ Connection(220, 115),
762
+ Connection(115, 218),
763
+ Connection(218, 220),
764
+ Connection(133, 112),
765
+ Connection(112, 243),
766
+ Connection(243, 133),
767
+ Connection(239, 238),
768
+ Connection(238, 241),
769
+ Connection(241, 239),
770
+ Connection(214, 135),
771
+ Connection(135, 169),
772
+ Connection(169, 214),
773
+ Connection(190, 173),
774
+ Connection(173, 133),
775
+ Connection(133, 190),
776
+ Connection(171, 208),
777
+ Connection(208, 32),
778
+ Connection(32, 171),
779
+ Connection(125, 44),
780
+ Connection(44, 237),
781
+ Connection(237, 125),
782
+ Connection(86, 87),
783
+ Connection(87, 178),
784
+ Connection(178, 86),
785
+ Connection(85, 86),
786
+ Connection(86, 179),
787
+ Connection(179, 85),
788
+ Connection(84, 85),
789
+ Connection(85, 180),
790
+ Connection(180, 84),
791
+ Connection(83, 84),
792
+ Connection(84, 181),
793
+ Connection(181, 83),
794
+ Connection(201, 83),
795
+ Connection(83, 182),
796
+ Connection(182, 201),
797
+ Connection(137, 93),
798
+ Connection(93, 132),
799
+ Connection(132, 137),
800
+ Connection(76, 62),
801
+ Connection(62, 183),
802
+ Connection(183, 76),
803
+ Connection(61, 76),
804
+ Connection(76, 184),
805
+ Connection(184, 61),
806
+ Connection(57, 61),
807
+ Connection(61, 185),
808
+ Connection(185, 57),
809
+ Connection(212, 57),
810
+ Connection(57, 186),
811
+ Connection(186, 212),
812
+ Connection(214, 207),
813
+ Connection(207, 187),
814
+ Connection(187, 214),
815
+ Connection(34, 143),
816
+ Connection(143, 156),
817
+ Connection(156, 34),
818
+ Connection(79, 239),
819
+ Connection(239, 237),
820
+ Connection(237, 79),
821
+ Connection(123, 137),
822
+ Connection(137, 177),
823
+ Connection(177, 123),
824
+ Connection(44, 1),
825
+ Connection(1, 4),
826
+ Connection(4, 44),
827
+ Connection(201, 194),
828
+ Connection(194, 32),
829
+ Connection(32, 201),
830
+ Connection(64, 102),
831
+ Connection(102, 129),
832
+ Connection(129, 64),
833
+ Connection(213, 215),
834
+ Connection(215, 138),
835
+ Connection(138, 213),
836
+ Connection(59, 166),
837
+ Connection(166, 219),
838
+ Connection(219, 59),
839
+ Connection(242, 99),
840
+ Connection(99, 97),
841
+ Connection(97, 242),
842
+ Connection(2, 94),
843
+ Connection(94, 141),
844
+ Connection(141, 2),
845
+ Connection(75, 59),
846
+ Connection(59, 235),
847
+ Connection(235, 75),
848
+ Connection(24, 110),
849
+ Connection(110, 228),
850
+ Connection(228, 24),
851
+ Connection(25, 130),
852
+ Connection(130, 226),
853
+ Connection(226, 25),
854
+ Connection(23, 24),
855
+ Connection(24, 229),
856
+ Connection(229, 23),
857
+ Connection(22, 23),
858
+ Connection(23, 230),
859
+ Connection(230, 22),
860
+ Connection(26, 22),
861
+ Connection(22, 231),
862
+ Connection(231, 26),
863
+ Connection(112, 26),
864
+ Connection(26, 232),
865
+ Connection(232, 112),
866
+ Connection(189, 190),
867
+ Connection(190, 243),
868
+ Connection(243, 189),
869
+ Connection(221, 56),
870
+ Connection(56, 190),
871
+ Connection(190, 221),
872
+ Connection(28, 56),
873
+ Connection(56, 221),
874
+ Connection(221, 28),
875
+ Connection(27, 28),
876
+ Connection(28, 222),
877
+ Connection(222, 27),
878
+ Connection(29, 27),
879
+ Connection(27, 223),
880
+ Connection(223, 29),
881
+ Connection(30, 29),
882
+ Connection(29, 224),
883
+ Connection(224, 30),
884
+ Connection(247, 30),
885
+ Connection(30, 225),
886
+ Connection(225, 247),
887
+ Connection(238, 79),
888
+ Connection(79, 20),
889
+ Connection(20, 238),
890
+ Connection(166, 59),
891
+ Connection(59, 75),
892
+ Connection(75, 166),
893
+ Connection(60, 75),
894
+ Connection(75, 240),
895
+ Connection(240, 60),
896
+ Connection(147, 177),
897
+ Connection(177, 215),
898
+ Connection(215, 147),
899
+ Connection(20, 79),
900
+ Connection(79, 166),
901
+ Connection(166, 20),
902
+ Connection(187, 147),
903
+ Connection(147, 213),
904
+ Connection(213, 187),
905
+ Connection(112, 233),
906
+ Connection(233, 244),
907
+ Connection(244, 112),
908
+ Connection(233, 128),
909
+ Connection(128, 245),
910
+ Connection(245, 233),
911
+ Connection(128, 114),
912
+ Connection(114, 188),
913
+ Connection(188, 128),
914
+ Connection(114, 217),
915
+ Connection(217, 174),
916
+ Connection(174, 114),
917
+ Connection(131, 115),
918
+ Connection(115, 220),
919
+ Connection(220, 131),
920
+ Connection(217, 198),
921
+ Connection(198, 236),
922
+ Connection(236, 217),
923
+ Connection(198, 131),
924
+ Connection(131, 134),
925
+ Connection(134, 198),
926
+ Connection(177, 132),
927
+ Connection(132, 58),
928
+ Connection(58, 177),
929
+ Connection(143, 35),
930
+ Connection(35, 124),
931
+ Connection(124, 143),
932
+ Connection(110, 163),
933
+ Connection(163, 7),
934
+ Connection(7, 110),
935
+ Connection(228, 110),
936
+ Connection(110, 25),
937
+ Connection(25, 228),
938
+ Connection(356, 389),
939
+ Connection(389, 368),
940
+ Connection(368, 356),
941
+ Connection(11, 302),
942
+ Connection(302, 267),
943
+ Connection(267, 11),
944
+ Connection(452, 350),
945
+ Connection(350, 349),
946
+ Connection(349, 452),
947
+ Connection(302, 303),
948
+ Connection(303, 269),
949
+ Connection(269, 302),
950
+ Connection(357, 343),
951
+ Connection(343, 277),
952
+ Connection(277, 357),
953
+ Connection(452, 453),
954
+ Connection(453, 357),
955
+ Connection(357, 452),
956
+ Connection(333, 332),
957
+ Connection(332, 297),
958
+ Connection(297, 333),
959
+ Connection(175, 152),
960
+ Connection(152, 377),
961
+ Connection(377, 175),
962
+ Connection(347, 348),
963
+ Connection(348, 330),
964
+ Connection(330, 347),
965
+ Connection(303, 304),
966
+ Connection(304, 270),
967
+ Connection(270, 303),
968
+ Connection(9, 336),
969
+ Connection(336, 337),
970
+ Connection(337, 9),
971
+ Connection(278, 279),
972
+ Connection(279, 360),
973
+ Connection(360, 278),
974
+ Connection(418, 262),
975
+ Connection(262, 431),
976
+ Connection(431, 418),
977
+ Connection(304, 408),
978
+ Connection(408, 409),
979
+ Connection(409, 304),
980
+ Connection(310, 415),
981
+ Connection(415, 407),
982
+ Connection(407, 310),
983
+ Connection(270, 409),
984
+ Connection(409, 410),
985
+ Connection(410, 270),
986
+ Connection(450, 348),
987
+ Connection(348, 347),
988
+ Connection(347, 450),
989
+ Connection(422, 430),
990
+ Connection(430, 434),
991
+ Connection(434, 422),
992
+ Connection(313, 314),
993
+ Connection(314, 17),
994
+ Connection(17, 313),
995
+ Connection(306, 307),
996
+ Connection(307, 375),
997
+ Connection(375, 306),
998
+ Connection(387, 388),
999
+ Connection(388, 260),
1000
+ Connection(260, 387),
1001
+ Connection(286, 414),
1002
+ Connection(414, 398),
1003
+ Connection(398, 286),
1004
+ Connection(335, 406),
1005
+ Connection(406, 418),
1006
+ Connection(418, 335),
1007
+ Connection(364, 367),
1008
+ Connection(367, 416),
1009
+ Connection(416, 364),
1010
+ Connection(423, 358),
1011
+ Connection(358, 327),
1012
+ Connection(327, 423),
1013
+ Connection(251, 284),
1014
+ Connection(284, 298),
1015
+ Connection(298, 251),
1016
+ Connection(281, 5),
1017
+ Connection(5, 4),
1018
+ Connection(4, 281),
1019
+ Connection(373, 374),
1020
+ Connection(374, 253),
1021
+ Connection(253, 373),
1022
+ Connection(307, 320),
1023
+ Connection(320, 321),
1024
+ Connection(321, 307),
1025
+ Connection(425, 427),
1026
+ Connection(427, 411),
1027
+ Connection(411, 425),
1028
+ Connection(421, 313),
1029
+ Connection(313, 18),
1030
+ Connection(18, 421),
1031
+ Connection(321, 405),
1032
+ Connection(405, 406),
1033
+ Connection(406, 321),
1034
+ Connection(320, 404),
1035
+ Connection(404, 405),
1036
+ Connection(405, 320),
1037
+ Connection(315, 16),
1038
+ Connection(16, 17),
1039
+ Connection(17, 315),
1040
+ Connection(426, 425),
1041
+ Connection(425, 266),
1042
+ Connection(266, 426),
1043
+ Connection(377, 400),
1044
+ Connection(400, 369),
1045
+ Connection(369, 377),
1046
+ Connection(322, 391),
1047
+ Connection(391, 269),
1048
+ Connection(269, 322),
1049
+ Connection(417, 465),
1050
+ Connection(465, 464),
1051
+ Connection(464, 417),
1052
+ Connection(386, 257),
1053
+ Connection(257, 258),
1054
+ Connection(258, 386),
1055
+ Connection(466, 260),
1056
+ Connection(260, 388),
1057
+ Connection(388, 466),
1058
+ Connection(456, 399),
1059
+ Connection(399, 419),
1060
+ Connection(419, 456),
1061
+ Connection(284, 332),
1062
+ Connection(332, 333),
1063
+ Connection(333, 284),
1064
+ Connection(417, 285),
1065
+ Connection(285, 8),
1066
+ Connection(8, 417),
1067
+ Connection(346, 340),
1068
+ Connection(340, 261),
1069
+ Connection(261, 346),
1070
+ Connection(413, 441),
1071
+ Connection(441, 285),
1072
+ Connection(285, 413),
1073
+ Connection(327, 460),
1074
+ Connection(460, 328),
1075
+ Connection(328, 327),
1076
+ Connection(355, 371),
1077
+ Connection(371, 329),
1078
+ Connection(329, 355),
1079
+ Connection(392, 439),
1080
+ Connection(439, 438),
1081
+ Connection(438, 392),
1082
+ Connection(382, 341),
1083
+ Connection(341, 256),
1084
+ Connection(256, 382),
1085
+ Connection(429, 420),
1086
+ Connection(420, 360),
1087
+ Connection(360, 429),
1088
+ Connection(364, 394),
1089
+ Connection(394, 379),
1090
+ Connection(379, 364),
1091
+ Connection(277, 343),
1092
+ Connection(343, 437),
1093
+ Connection(437, 277),
1094
+ Connection(443, 444),
1095
+ Connection(444, 283),
1096
+ Connection(283, 443),
1097
+ Connection(275, 440),
1098
+ Connection(440, 363),
1099
+ Connection(363, 275),
1100
+ Connection(431, 262),
1101
+ Connection(262, 369),
1102
+ Connection(369, 431),
1103
+ Connection(297, 338),
1104
+ Connection(338, 337),
1105
+ Connection(337, 297),
1106
+ Connection(273, 375),
1107
+ Connection(375, 321),
1108
+ Connection(321, 273),
1109
+ Connection(450, 451),
1110
+ Connection(451, 349),
1111
+ Connection(349, 450),
1112
+ Connection(446, 342),
1113
+ Connection(342, 467),
1114
+ Connection(467, 446),
1115
+ Connection(293, 334),
1116
+ Connection(334, 282),
1117
+ Connection(282, 293),
1118
+ Connection(458, 461),
1119
+ Connection(461, 462),
1120
+ Connection(462, 458),
1121
+ Connection(276, 353),
1122
+ Connection(353, 383),
1123
+ Connection(383, 276),
1124
+ Connection(308, 324),
1125
+ Connection(324, 325),
1126
+ Connection(325, 308),
1127
+ Connection(276, 300),
1128
+ Connection(300, 293),
1129
+ Connection(293, 276),
1130
+ Connection(372, 345),
1131
+ Connection(345, 447),
1132
+ Connection(447, 372),
1133
+ Connection(352, 345),
1134
+ Connection(345, 340),
1135
+ Connection(340, 352),
1136
+ Connection(274, 1),
1137
+ Connection(1, 19),
1138
+ Connection(19, 274),
1139
+ Connection(456, 248),
1140
+ Connection(248, 281),
1141
+ Connection(281, 456),
1142
+ Connection(436, 427),
1143
+ Connection(427, 425),
1144
+ Connection(425, 436),
1145
+ Connection(381, 256),
1146
+ Connection(256, 252),
1147
+ Connection(252, 381),
1148
+ Connection(269, 391),
1149
+ Connection(391, 393),
1150
+ Connection(393, 269),
1151
+ Connection(200, 199),
1152
+ Connection(199, 428),
1153
+ Connection(428, 200),
1154
+ Connection(266, 330),
1155
+ Connection(330, 329),
1156
+ Connection(329, 266),
1157
+ Connection(287, 273),
1158
+ Connection(273, 422),
1159
+ Connection(422, 287),
1160
+ Connection(250, 462),
1161
+ Connection(462, 328),
1162
+ Connection(328, 250),
1163
+ Connection(258, 286),
1164
+ Connection(286, 384),
1165
+ Connection(384, 258),
1166
+ Connection(265, 353),
1167
+ Connection(353, 342),
1168
+ Connection(342, 265),
1169
+ Connection(387, 259),
1170
+ Connection(259, 257),
1171
+ Connection(257, 387),
1172
+ Connection(424, 431),
1173
+ Connection(431, 430),
1174
+ Connection(430, 424),
1175
+ Connection(342, 353),
1176
+ Connection(353, 276),
1177
+ Connection(276, 342),
1178
+ Connection(273, 335),
1179
+ Connection(335, 424),
1180
+ Connection(424, 273),
1181
+ Connection(292, 325),
1182
+ Connection(325, 307),
1183
+ Connection(307, 292),
1184
+ Connection(366, 447),
1185
+ Connection(447, 345),
1186
+ Connection(345, 366),
1187
+ Connection(271, 303),
1188
+ Connection(303, 302),
1189
+ Connection(302, 271),
1190
+ Connection(423, 266),
1191
+ Connection(266, 371),
1192
+ Connection(371, 423),
1193
+ Connection(294, 455),
1194
+ Connection(455, 460),
1195
+ Connection(460, 294),
1196
+ Connection(279, 278),
1197
+ Connection(278, 294),
1198
+ Connection(294, 279),
1199
+ Connection(271, 272),
1200
+ Connection(272, 304),
1201
+ Connection(304, 271),
1202
+ Connection(432, 434),
1203
+ Connection(434, 427),
1204
+ Connection(427, 432),
1205
+ Connection(272, 407),
1206
+ Connection(407, 408),
1207
+ Connection(408, 272),
1208
+ Connection(394, 430),
1209
+ Connection(430, 431),
1210
+ Connection(431, 394),
1211
+ Connection(395, 369),
1212
+ Connection(369, 400),
1213
+ Connection(400, 395),
1214
+ Connection(334, 333),
1215
+ Connection(333, 299),
1216
+ Connection(299, 334),
1217
+ Connection(351, 417),
1218
+ Connection(417, 168),
1219
+ Connection(168, 351),
1220
+ Connection(352, 280),
1221
+ Connection(280, 411),
1222
+ Connection(411, 352),
1223
+ Connection(325, 319),
1224
+ Connection(319, 320),
1225
+ Connection(320, 325),
1226
+ Connection(295, 296),
1227
+ Connection(296, 336),
1228
+ Connection(336, 295),
1229
+ Connection(319, 403),
1230
+ Connection(403, 404),
1231
+ Connection(404, 319),
1232
+ Connection(330, 348),
1233
+ Connection(348, 349),
1234
+ Connection(349, 330),
1235
+ Connection(293, 298),
1236
+ Connection(298, 333),
1237
+ Connection(333, 293),
1238
+ Connection(323, 454),
1239
+ Connection(454, 447),
1240
+ Connection(447, 323),
1241
+ Connection(15, 16),
1242
+ Connection(16, 315),
1243
+ Connection(315, 15),
1244
+ Connection(358, 429),
1245
+ Connection(429, 279),
1246
+ Connection(279, 358),
1247
+ Connection(14, 15),
1248
+ Connection(15, 316),
1249
+ Connection(316, 14),
1250
+ Connection(285, 336),
1251
+ Connection(336, 9),
1252
+ Connection(9, 285),
1253
+ Connection(329, 349),
1254
+ Connection(349, 350),
1255
+ Connection(350, 329),
1256
+ Connection(374, 380),
1257
+ Connection(380, 252),
1258
+ Connection(252, 374),
1259
+ Connection(318, 402),
1260
+ Connection(402, 403),
1261
+ Connection(403, 318),
1262
+ Connection(6, 197),
1263
+ Connection(197, 419),
1264
+ Connection(419, 6),
1265
+ Connection(318, 319),
1266
+ Connection(319, 325),
1267
+ Connection(325, 318),
1268
+ Connection(367, 364),
1269
+ Connection(364, 365),
1270
+ Connection(365, 367),
1271
+ Connection(435, 367),
1272
+ Connection(367, 397),
1273
+ Connection(397, 435),
1274
+ Connection(344, 438),
1275
+ Connection(438, 439),
1276
+ Connection(439, 344),
1277
+ Connection(272, 271),
1278
+ Connection(271, 311),
1279
+ Connection(311, 272),
1280
+ Connection(195, 5),
1281
+ Connection(5, 281),
1282
+ Connection(281, 195),
1283
+ Connection(273, 287),
1284
+ Connection(287, 291),
1285
+ Connection(291, 273),
1286
+ Connection(396, 428),
1287
+ Connection(428, 199),
1288
+ Connection(199, 396),
1289
+ Connection(311, 271),
1290
+ Connection(271, 268),
1291
+ Connection(268, 311),
1292
+ Connection(283, 444),
1293
+ Connection(444, 445),
1294
+ Connection(445, 283),
1295
+ Connection(373, 254),
1296
+ Connection(254, 339),
1297
+ Connection(339, 373),
1298
+ Connection(282, 334),
1299
+ Connection(334, 296),
1300
+ Connection(296, 282),
1301
+ Connection(449, 347),
1302
+ Connection(347, 346),
1303
+ Connection(346, 449),
1304
+ Connection(264, 447),
1305
+ Connection(447, 454),
1306
+ Connection(454, 264),
1307
+ Connection(336, 296),
1308
+ Connection(296, 299),
1309
+ Connection(299, 336),
1310
+ Connection(338, 10),
1311
+ Connection(10, 151),
1312
+ Connection(151, 338),
1313
+ Connection(278, 439),
1314
+ Connection(439, 455),
1315
+ Connection(455, 278),
1316
+ Connection(292, 407),
1317
+ Connection(407, 415),
1318
+ Connection(415, 292),
1319
+ Connection(358, 371),
1320
+ Connection(371, 355),
1321
+ Connection(355, 358),
1322
+ Connection(340, 345),
1323
+ Connection(345, 372),
1324
+ Connection(372, 340),
1325
+ Connection(346, 347),
1326
+ Connection(347, 280),
1327
+ Connection(280, 346),
1328
+ Connection(442, 443),
1329
+ Connection(443, 282),
1330
+ Connection(282, 442),
1331
+ Connection(19, 94),
1332
+ Connection(94, 370),
1333
+ Connection(370, 19),
1334
+ Connection(441, 442),
1335
+ Connection(442, 295),
1336
+ Connection(295, 441),
1337
+ Connection(248, 419),
1338
+ Connection(419, 197),
1339
+ Connection(197, 248),
1340
+ Connection(263, 255),
1341
+ Connection(255, 359),
1342
+ Connection(359, 263),
1343
+ Connection(440, 275),
1344
+ Connection(275, 274),
1345
+ Connection(274, 440),
1346
+ Connection(300, 383),
1347
+ Connection(383, 368),
1348
+ Connection(368, 300),
1349
+ Connection(351, 412),
1350
+ Connection(412, 465),
1351
+ Connection(465, 351),
1352
+ Connection(263, 467),
1353
+ Connection(467, 466),
1354
+ Connection(466, 263),
1355
+ Connection(301, 368),
1356
+ Connection(368, 389),
1357
+ Connection(389, 301),
1358
+ Connection(395, 378),
1359
+ Connection(378, 379),
1360
+ Connection(379, 395),
1361
+ Connection(412, 351),
1362
+ Connection(351, 419),
1363
+ Connection(419, 412),
1364
+ Connection(436, 426),
1365
+ Connection(426, 322),
1366
+ Connection(322, 436),
1367
+ Connection(2, 164),
1368
+ Connection(164, 393),
1369
+ Connection(393, 2),
1370
+ Connection(370, 462),
1371
+ Connection(462, 461),
1372
+ Connection(461, 370),
1373
+ Connection(164, 0),
1374
+ Connection(0, 267),
1375
+ Connection(267, 164),
1376
+ Connection(302, 11),
1377
+ Connection(11, 12),
1378
+ Connection(12, 302),
1379
+ Connection(268, 12),
1380
+ Connection(12, 13),
1381
+ Connection(13, 268),
1382
+ Connection(293, 300),
1383
+ Connection(300, 301),
1384
+ Connection(301, 293),
1385
+ Connection(446, 261),
1386
+ Connection(261, 340),
1387
+ Connection(340, 446),
1388
+ Connection(330, 266),
1389
+ Connection(266, 425),
1390
+ Connection(425, 330),
1391
+ Connection(426, 423),
1392
+ Connection(423, 391),
1393
+ Connection(391, 426),
1394
+ Connection(429, 355),
1395
+ Connection(355, 437),
1396
+ Connection(437, 429),
1397
+ Connection(391, 327),
1398
+ Connection(327, 326),
1399
+ Connection(326, 391),
1400
+ Connection(440, 457),
1401
+ Connection(457, 438),
1402
+ Connection(438, 440),
1403
+ Connection(341, 382),
1404
+ Connection(382, 362),
1405
+ Connection(362, 341),
1406
+ Connection(459, 457),
1407
+ Connection(457, 461),
1408
+ Connection(461, 459),
1409
+ Connection(434, 430),
1410
+ Connection(430, 394),
1411
+ Connection(394, 434),
1412
+ Connection(414, 463),
1413
+ Connection(463, 362),
1414
+ Connection(362, 414),
1415
+ Connection(396, 369),
1416
+ Connection(369, 262),
1417
+ Connection(262, 396),
1418
+ Connection(354, 461),
1419
+ Connection(461, 457),
1420
+ Connection(457, 354),
1421
+ Connection(316, 403),
1422
+ Connection(403, 402),
1423
+ Connection(402, 316),
1424
+ Connection(315, 404),
1425
+ Connection(404, 403),
1426
+ Connection(403, 315),
1427
+ Connection(314, 405),
1428
+ Connection(405, 404),
1429
+ Connection(404, 314),
1430
+ Connection(313, 406),
1431
+ Connection(406, 405),
1432
+ Connection(405, 313),
1433
+ Connection(421, 418),
1434
+ Connection(418, 406),
1435
+ Connection(406, 421),
1436
+ Connection(366, 401),
1437
+ Connection(401, 361),
1438
+ Connection(361, 366),
1439
+ Connection(306, 408),
1440
+ Connection(408, 407),
1441
+ Connection(407, 306),
1442
+ Connection(291, 409),
1443
+ Connection(409, 408),
1444
+ Connection(408, 291),
1445
+ Connection(287, 410),
1446
+ Connection(410, 409),
1447
+ Connection(409, 287),
1448
+ Connection(432, 436),
1449
+ Connection(436, 410),
1450
+ Connection(410, 432),
1451
+ Connection(434, 416),
1452
+ Connection(416, 411),
1453
+ Connection(411, 434),
1454
+ Connection(264, 368),
1455
+ Connection(368, 383),
1456
+ Connection(383, 264),
1457
+ Connection(309, 438),
1458
+ Connection(438, 457),
1459
+ Connection(457, 309),
1460
+ Connection(352, 376),
1461
+ Connection(376, 401),
1462
+ Connection(401, 352),
1463
+ Connection(274, 275),
1464
+ Connection(275, 4),
1465
+ Connection(4, 274),
1466
+ Connection(421, 428),
1467
+ Connection(428, 262),
1468
+ Connection(262, 421),
1469
+ Connection(294, 327),
1470
+ Connection(327, 358),
1471
+ Connection(358, 294),
1472
+ Connection(433, 416),
1473
+ Connection(416, 367),
1474
+ Connection(367, 433),
1475
+ Connection(289, 455),
1476
+ Connection(455, 439),
1477
+ Connection(439, 289),
1478
+ Connection(462, 370),
1479
+ Connection(370, 326),
1480
+ Connection(326, 462),
1481
+ Connection(2, 326),
1482
+ Connection(326, 370),
1483
+ Connection(370, 2),
1484
+ Connection(305, 460),
1485
+ Connection(460, 455),
1486
+ Connection(455, 305),
1487
+ Connection(254, 449),
1488
+ Connection(449, 448),
1489
+ Connection(448, 254),
1490
+ Connection(255, 261),
1491
+ Connection(261, 446),
1492
+ Connection(446, 255),
1493
+ Connection(253, 450),
1494
+ Connection(450, 449),
1495
+ Connection(449, 253),
1496
+ Connection(252, 451),
1497
+ Connection(451, 450),
1498
+ Connection(450, 252),
1499
+ Connection(256, 452),
1500
+ Connection(452, 451),
1501
+ Connection(451, 256),
1502
+ Connection(341, 453),
1503
+ Connection(453, 452),
1504
+ Connection(452, 341),
1505
+ Connection(413, 464),
1506
+ Connection(464, 463),
1507
+ Connection(463, 413),
1508
+ Connection(441, 413),
1509
+ Connection(413, 414),
1510
+ Connection(414, 441),
1511
+ Connection(258, 442),
1512
+ Connection(442, 441),
1513
+ Connection(441, 258),
1514
+ Connection(257, 443),
1515
+ Connection(443, 442),
1516
+ Connection(442, 257),
1517
+ Connection(259, 444),
1518
+ Connection(444, 443),
1519
+ Connection(443, 259),
1520
+ Connection(260, 445),
1521
+ Connection(445, 444),
1522
+ Connection(444, 260),
1523
+ Connection(467, 342),
1524
+ Connection(342, 445),
1525
+ Connection(445, 467),
1526
+ Connection(459, 458),
1527
+ Connection(458, 250),
1528
+ Connection(250, 459),
1529
+ Connection(289, 392),
1530
+ Connection(392, 290),
1531
+ Connection(290, 289),
1532
+ Connection(290, 328),
1533
+ Connection(328, 460),
1534
+ Connection(460, 290),
1535
+ Connection(376, 433),
1536
+ Connection(433, 435),
1537
+ Connection(435, 376),
1538
+ Connection(250, 290),
1539
+ Connection(290, 392),
1540
+ Connection(392, 250),
1541
+ Connection(411, 416),
1542
+ Connection(416, 433),
1543
+ Connection(433, 411),
1544
+ Connection(341, 463),
1545
+ Connection(463, 464),
1546
+ Connection(464, 341),
1547
+ Connection(453, 464),
1548
+ Connection(464, 465),
1549
+ Connection(465, 453),
1550
+ Connection(357, 465),
1551
+ Connection(465, 412),
1552
+ Connection(412, 357),
1553
+ Connection(343, 412),
1554
+ Connection(412, 399),
1555
+ Connection(399, 343),
1556
+ Connection(360, 363),
1557
+ Connection(363, 440),
1558
+ Connection(440, 360),
1559
+ Connection(437, 399),
1560
+ Connection(399, 456),
1561
+ Connection(456, 437),
1562
+ Connection(420, 456),
1563
+ Connection(456, 363),
1564
+ Connection(363, 420),
1565
+ Connection(401, 435),
1566
+ Connection(435, 288),
1567
+ Connection(288, 401),
1568
+ Connection(372, 383),
1569
+ Connection(383, 353),
1570
+ Connection(353, 372),
1571
+ Connection(339, 255),
1572
+ Connection(255, 249),
1573
+ Connection(249, 339),
1574
+ Connection(448, 261),
1575
+ Connection(261, 255),
1576
+ Connection(255, 448),
1577
+ Connection(133, 243),
1578
+ Connection(243, 190),
1579
+ Connection(190, 133),
1580
+ Connection(133, 155),
1581
+ Connection(155, 112),
1582
+ Connection(112, 133),
1583
+ Connection(33, 246),
1584
+ Connection(246, 247),
1585
+ Connection(247, 33),
1586
+ Connection(33, 130),
1587
+ Connection(130, 25),
1588
+ Connection(25, 33),
1589
+ Connection(398, 384),
1590
+ Connection(384, 286),
1591
+ Connection(286, 398),
1592
+ Connection(362, 398),
1593
+ Connection(398, 414),
1594
+ Connection(414, 362),
1595
+ Connection(362, 463),
1596
+ Connection(463, 341),
1597
+ Connection(341, 362),
1598
+ Connection(263, 359),
1599
+ Connection(359, 467),
1600
+ Connection(467, 263),
1601
+ Connection(263, 249),
1602
+ Connection(249, 255),
1603
+ Connection(255, 263),
1604
+ Connection(466, 467),
1605
+ Connection(467, 260),
1606
+ Connection(260, 466),
1607
+ Connection(75, 60),
1608
+ Connection(60, 166),
1609
+ Connection(166, 75),
1610
+ Connection(238, 239),
1611
+ Connection(239, 79),
1612
+ Connection(79, 238),
1613
+ Connection(162, 127),
1614
+ Connection(127, 139),
1615
+ Connection(139, 162),
1616
+ Connection(72, 11),
1617
+ Connection(11, 37),
1618
+ Connection(37, 72),
1619
+ Connection(121, 232),
1620
+ Connection(232, 120),
1621
+ Connection(120, 121),
1622
+ Connection(73, 72),
1623
+ Connection(72, 39),
1624
+ Connection(39, 73),
1625
+ Connection(114, 128),
1626
+ Connection(128, 47),
1627
+ Connection(47, 114),
1628
+ Connection(233, 232),
1629
+ Connection(232, 128),
1630
+ Connection(128, 233),
1631
+ Connection(103, 104),
1632
+ Connection(104, 67),
1633
+ Connection(67, 103),
1634
+ Connection(152, 175),
1635
+ Connection(175, 148),
1636
+ Connection(148, 152),
1637
+ Connection(119, 118),
1638
+ Connection(118, 101),
1639
+ Connection(101, 119),
1640
+ Connection(74, 73),
1641
+ Connection(73, 40),
1642
+ Connection(40, 74),
1643
+ Connection(107, 9),
1644
+ Connection(9, 108),
1645
+ Connection(108, 107),
1646
+ Connection(49, 48),
1647
+ Connection(48, 131),
1648
+ Connection(131, 49),
1649
+ Connection(32, 194),
1650
+ Connection(194, 211),
1651
+ Connection(211, 32),
1652
+ Connection(184, 74),
1653
+ Connection(74, 185),
1654
+ Connection(185, 184),
1655
+ Connection(191, 80),
1656
+ Connection(80, 183),
1657
+ Connection(183, 191),
1658
+ Connection(185, 40),
1659
+ Connection(40, 186),
1660
+ Connection(186, 185),
1661
+ Connection(119, 230),
1662
+ Connection(230, 118),
1663
+ Connection(118, 119),
1664
+ Connection(210, 202),
1665
+ Connection(202, 214),
1666
+ Connection(214, 210),
1667
+ Connection(84, 83),
1668
+ Connection(83, 17),
1669
+ Connection(17, 84),
1670
+ Connection(77, 76),
1671
+ Connection(76, 146),
1672
+ Connection(146, 77),
1673
+ Connection(161, 160),
1674
+ Connection(160, 30),
1675
+ Connection(30, 161),
1676
+ Connection(190, 56),
1677
+ Connection(56, 173),
1678
+ Connection(173, 190),
1679
+ Connection(182, 106),
1680
+ Connection(106, 194),
1681
+ Connection(194, 182),
1682
+ Connection(138, 135),
1683
+ Connection(135, 192),
1684
+ Connection(192, 138),
1685
+ Connection(129, 203),
1686
+ Connection(203, 98),
1687
+ Connection(98, 129),
1688
+ Connection(54, 21),
1689
+ Connection(21, 68),
1690
+ Connection(68, 54),
1691
+ Connection(5, 51),
1692
+ Connection(51, 4),
1693
+ Connection(4, 5),
1694
+ Connection(145, 144),
1695
+ Connection(144, 23),
1696
+ Connection(23, 145),
1697
+ Connection(90, 77),
1698
+ Connection(77, 91),
1699
+ Connection(91, 90),
1700
+ Connection(207, 205),
1701
+ Connection(205, 187),
1702
+ Connection(187, 207),
1703
+ Connection(83, 201),
1704
+ Connection(201, 18),
1705
+ Connection(18, 83),
1706
+ Connection(181, 91),
1707
+ Connection(91, 182),
1708
+ Connection(182, 181),
1709
+ Connection(180, 90),
1710
+ Connection(90, 181),
1711
+ Connection(181, 180),
1712
+ Connection(16, 85),
1713
+ Connection(85, 17),
1714
+ Connection(17, 16),
1715
+ Connection(205, 206),
1716
+ Connection(206, 36),
1717
+ Connection(36, 205),
1718
+ Connection(176, 148),
1719
+ Connection(148, 140),
1720
+ Connection(140, 176),
1721
+ Connection(165, 92),
1722
+ Connection(92, 39),
1723
+ Connection(39, 165),
1724
+ Connection(245, 193),
1725
+ Connection(193, 244),
1726
+ Connection(244, 245),
1727
+ Connection(27, 159),
1728
+ Connection(159, 28),
1729
+ Connection(28, 27),
1730
+ Connection(30, 247),
1731
+ Connection(247, 161),
1732
+ Connection(161, 30),
1733
+ Connection(174, 236),
1734
+ Connection(236, 196),
1735
+ Connection(196, 174),
1736
+ Connection(103, 54),
1737
+ Connection(54, 104),
1738
+ Connection(104, 103),
1739
+ Connection(55, 193),
1740
+ Connection(193, 8),
1741
+ Connection(8, 55),
1742
+ Connection(111, 117),
1743
+ Connection(117, 31),
1744
+ Connection(31, 111),
1745
+ Connection(221, 189),
1746
+ Connection(189, 55),
1747
+ Connection(55, 221),
1748
+ Connection(240, 98),
1749
+ Connection(98, 99),
1750
+ Connection(99, 240),
1751
+ Connection(142, 126),
1752
+ Connection(126, 100),
1753
+ Connection(100, 142),
1754
+ Connection(219, 166),
1755
+ Connection(166, 218),
1756
+ Connection(218, 219),
1757
+ Connection(112, 155),
1758
+ Connection(155, 26),
1759
+ Connection(26, 112),
1760
+ Connection(198, 209),
1761
+ Connection(209, 131),
1762
+ Connection(131, 198),
1763
+ Connection(169, 135),
1764
+ Connection(135, 150),
1765
+ Connection(150, 169),
1766
+ Connection(114, 47),
1767
+ Connection(47, 217),
1768
+ Connection(217, 114),
1769
+ Connection(224, 223),
1770
+ Connection(223, 53),
1771
+ Connection(53, 224),
1772
+ Connection(220, 45),
1773
+ Connection(45, 134),
1774
+ Connection(134, 220),
1775
+ Connection(32, 211),
1776
+ Connection(211, 140),
1777
+ Connection(140, 32),
1778
+ Connection(109, 67),
1779
+ Connection(67, 108),
1780
+ Connection(108, 109),
1781
+ Connection(146, 43),
1782
+ Connection(43, 91),
1783
+ Connection(91, 146),
1784
+ Connection(231, 230),
1785
+ Connection(230, 120),
1786
+ Connection(120, 231),
1787
+ Connection(113, 226),
1788
+ Connection(226, 247),
1789
+ Connection(247, 113),
1790
+ Connection(105, 63),
1791
+ Connection(63, 52),
1792
+ Connection(52, 105),
1793
+ Connection(241, 238),
1794
+ Connection(238, 242),
1795
+ Connection(242, 241),
1796
+ Connection(124, 46),
1797
+ Connection(46, 156),
1798
+ Connection(156, 124),
1799
+ Connection(95, 78),
1800
+ Connection(78, 96),
1801
+ Connection(96, 95),
1802
+ Connection(70, 46),
1803
+ Connection(46, 63),
1804
+ Connection(63, 70),
1805
+ Connection(116, 143),
1806
+ Connection(143, 227),
1807
+ Connection(227, 116),
1808
+ Connection(116, 123),
1809
+ Connection(123, 111),
1810
+ Connection(111, 116),
1811
+ Connection(1, 44),
1812
+ Connection(44, 19),
1813
+ Connection(19, 1),
1814
+ Connection(3, 236),
1815
+ Connection(236, 51),
1816
+ Connection(51, 3),
1817
+ Connection(207, 216),
1818
+ Connection(216, 205),
1819
+ Connection(205, 207),
1820
+ Connection(26, 154),
1821
+ Connection(154, 22),
1822
+ Connection(22, 26),
1823
+ Connection(165, 39),
1824
+ Connection(39, 167),
1825
+ Connection(167, 165),
1826
+ Connection(199, 200),
1827
+ Connection(200, 208),
1828
+ Connection(208, 199),
1829
+ Connection(101, 36),
1830
+ Connection(36, 100),
1831
+ Connection(100, 101),
1832
+ Connection(43, 57),
1833
+ Connection(57, 202),
1834
+ Connection(202, 43),
1835
+ Connection(242, 20),
1836
+ Connection(20, 99),
1837
+ Connection(99, 242),
1838
+ Connection(56, 28),
1839
+ Connection(28, 157),
1840
+ Connection(157, 56),
1841
+ Connection(124, 35),
1842
+ Connection(35, 113),
1843
+ Connection(113, 124),
1844
+ Connection(29, 160),
1845
+ Connection(160, 27),
1846
+ Connection(27, 29),
1847
+ Connection(211, 204),
1848
+ Connection(204, 210),
1849
+ Connection(210, 211),
1850
+ Connection(124, 113),
1851
+ Connection(113, 46),
1852
+ Connection(46, 124),
1853
+ Connection(106, 43),
1854
+ Connection(43, 204),
1855
+ Connection(204, 106),
1856
+ Connection(96, 62),
1857
+ Connection(62, 77),
1858
+ Connection(77, 96),
1859
+ Connection(227, 137),
1860
+ Connection(137, 116),
1861
+ Connection(116, 227),
1862
+ Connection(73, 41),
1863
+ Connection(41, 72),
1864
+ Connection(72, 73),
1865
+ Connection(36, 203),
1866
+ Connection(203, 142),
1867
+ Connection(142, 36),
1868
+ Connection(235, 64),
1869
+ Connection(64, 240),
1870
+ Connection(240, 235),
1871
+ Connection(48, 49),
1872
+ Connection(49, 64),
1873
+ Connection(64, 48),
1874
+ Connection(42, 41),
1875
+ Connection(41, 74),
1876
+ Connection(74, 42),
1877
+ Connection(214, 212),
1878
+ Connection(212, 207),
1879
+ Connection(207, 214),
1880
+ Connection(183, 42),
1881
+ Connection(42, 184),
1882
+ Connection(184, 183),
1883
+ Connection(210, 169),
1884
+ Connection(169, 211),
1885
+ Connection(211, 210),
1886
+ Connection(140, 170),
1887
+ Connection(170, 176),
1888
+ Connection(176, 140),
1889
+ Connection(104, 105),
1890
+ Connection(105, 69),
1891
+ Connection(69, 104),
1892
+ Connection(193, 122),
1893
+ Connection(122, 168),
1894
+ Connection(168, 193),
1895
+ Connection(50, 123),
1896
+ Connection(123, 187),
1897
+ Connection(187, 50),
1898
+ Connection(89, 96),
1899
+ Connection(96, 90),
1900
+ Connection(90, 89),
1901
+ Connection(66, 65),
1902
+ Connection(65, 107),
1903
+ Connection(107, 66),
1904
+ Connection(179, 89),
1905
+ Connection(89, 180),
1906
+ Connection(180, 179),
1907
+ Connection(119, 101),
1908
+ Connection(101, 120),
1909
+ Connection(120, 119),
1910
+ Connection(68, 63),
1911
+ Connection(63, 104),
1912
+ Connection(104, 68),
1913
+ Connection(234, 93),
1914
+ Connection(93, 227),
1915
+ Connection(227, 234),
1916
+ Connection(16, 15),
1917
+ Connection(15, 85),
1918
+ Connection(85, 16),
1919
+ Connection(209, 129),
1920
+ Connection(129, 49),
1921
+ Connection(49, 209),
1922
+ Connection(15, 14),
1923
+ Connection(14, 86),
1924
+ Connection(86, 15),
1925
+ Connection(107, 55),
1926
+ Connection(55, 9),
1927
+ Connection(9, 107),
1928
+ Connection(120, 100),
1929
+ Connection(100, 121),
1930
+ Connection(121, 120),
1931
+ Connection(153, 145),
1932
+ Connection(145, 22),
1933
+ Connection(22, 153),
1934
+ Connection(178, 88),
1935
+ Connection(88, 179),
1936
+ Connection(179, 178),
1937
+ Connection(197, 6),
1938
+ Connection(6, 196),
1939
+ Connection(196, 197),
1940
+ Connection(89, 88),
1941
+ Connection(88, 96),
1942
+ Connection(96, 89),
1943
+ Connection(135, 138),
1944
+ Connection(138, 136),
1945
+ Connection(136, 135),
1946
+ Connection(138, 215),
1947
+ Connection(215, 172),
1948
+ Connection(172, 138),
1949
+ Connection(218, 115),
1950
+ Connection(115, 219),
1951
+ Connection(219, 218),
1952
+ Connection(41, 42),
1953
+ Connection(42, 81),
1954
+ Connection(81, 41),
1955
+ Connection(5, 195),
1956
+ Connection(195, 51),
1957
+ Connection(51, 5),
1958
+ Connection(57, 43),
1959
+ Connection(43, 61),
1960
+ Connection(61, 57),
1961
+ Connection(208, 171),
1962
+ Connection(171, 199),
1963
+ Connection(199, 208),
1964
+ Connection(41, 81),
1965
+ Connection(81, 38),
1966
+ Connection(38, 41),
1967
+ Connection(224, 53),
1968
+ Connection(53, 225),
1969
+ Connection(225, 224),
1970
+ Connection(24, 144),
1971
+ Connection(144, 110),
1972
+ Connection(110, 24),
1973
+ Connection(105, 52),
1974
+ Connection(52, 66),
1975
+ Connection(66, 105),
1976
+ Connection(118, 229),
1977
+ Connection(229, 117),
1978
+ Connection(117, 118),
1979
+ Connection(227, 34),
1980
+ Connection(34, 234),
1981
+ Connection(234, 227),
1982
+ Connection(66, 107),
1983
+ Connection(107, 69),
1984
+ Connection(69, 66),
1985
+ Connection(10, 109),
1986
+ Connection(109, 151),
1987
+ Connection(151, 10),
1988
+ Connection(219, 48),
1989
+ Connection(48, 235),
1990
+ Connection(235, 219),
1991
+ Connection(183, 62),
1992
+ Connection(62, 191),
1993
+ Connection(191, 183),
1994
+ Connection(142, 129),
1995
+ Connection(129, 126),
1996
+ Connection(126, 142),
1997
+ Connection(116, 111),
1998
+ Connection(111, 143),
1999
+ Connection(143, 116),
2000
+ Connection(118, 117),
2001
+ Connection(117, 50),
2002
+ Connection(50, 118),
2003
+ Connection(223, 222),
2004
+ Connection(222, 52),
2005
+ Connection(52, 223),
2006
+ Connection(94, 19),
2007
+ Connection(19, 141),
2008
+ Connection(141, 94),
2009
+ Connection(222, 221),
2010
+ Connection(221, 65),
2011
+ Connection(65, 222),
2012
+ Connection(196, 3),
2013
+ Connection(3, 197),
2014
+ Connection(197, 196),
2015
+ Connection(45, 220),
2016
+ Connection(220, 44),
2017
+ Connection(44, 45),
2018
+ Connection(156, 70),
2019
+ Connection(70, 139),
2020
+ Connection(139, 156),
2021
+ Connection(188, 122),
2022
+ Connection(122, 245),
2023
+ Connection(245, 188),
2024
+ Connection(139, 71),
2025
+ Connection(71, 162),
2026
+ Connection(162, 139),
2027
+ Connection(149, 170),
2028
+ Connection(170, 150),
2029
+ Connection(150, 149),
2030
+ Connection(122, 188),
2031
+ Connection(188, 196),
2032
+ Connection(196, 122),
2033
+ Connection(206, 216),
2034
+ Connection(216, 92),
2035
+ Connection(92, 206),
2036
+ Connection(164, 2),
2037
+ Connection(2, 167),
2038
+ Connection(167, 164),
2039
+ Connection(242, 141),
2040
+ Connection(141, 241),
2041
+ Connection(241, 242),
2042
+ Connection(0, 164),
2043
+ Connection(164, 37),
2044
+ Connection(37, 0),
2045
+ Connection(11, 72),
2046
+ Connection(72, 12),
2047
+ Connection(12, 11),
2048
+ Connection(12, 38),
2049
+ Connection(38, 13),
2050
+ Connection(13, 12),
2051
+ Connection(70, 63),
2052
+ Connection(63, 71),
2053
+ Connection(71, 70),
2054
+ Connection(31, 226),
2055
+ Connection(226, 111),
2056
+ Connection(111, 31),
2057
+ Connection(36, 101),
2058
+ Connection(101, 205),
2059
+ Connection(205, 36),
2060
+ Connection(203, 206),
2061
+ Connection(206, 165),
2062
+ Connection(165, 203),
2063
+ Connection(126, 209),
2064
+ Connection(209, 217),
2065
+ Connection(217, 126),
2066
+ Connection(98, 165),
2067
+ Connection(165, 97),
2068
+ Connection(97, 98),
2069
+ Connection(237, 220),
2070
+ Connection(220, 218),
2071
+ Connection(218, 237),
2072
+ Connection(237, 239),
2073
+ Connection(239, 241),
2074
+ Connection(241, 237),
2075
+ Connection(210, 214),
2076
+ Connection(214, 169),
2077
+ Connection(169, 210),
2078
+ Connection(140, 171),
2079
+ Connection(171, 32),
2080
+ Connection(32, 140),
2081
+ Connection(241, 125),
2082
+ Connection(125, 237),
2083
+ Connection(237, 241),
2084
+ Connection(179, 86),
2085
+ Connection(86, 178),
2086
+ Connection(178, 179),
2087
+ Connection(180, 85),
2088
+ Connection(85, 179),
2089
+ Connection(179, 180),
2090
+ Connection(181, 84),
2091
+ Connection(84, 180),
2092
+ Connection(180, 181),
2093
+ Connection(182, 83),
2094
+ Connection(83, 181),
2095
+ Connection(181, 182),
2096
+ Connection(194, 201),
2097
+ Connection(201, 182),
2098
+ Connection(182, 194),
2099
+ Connection(177, 137),
2100
+ Connection(137, 132),
2101
+ Connection(132, 177),
2102
+ Connection(184, 76),
2103
+ Connection(76, 183),
2104
+ Connection(183, 184),
2105
+ Connection(185, 61),
2106
+ Connection(61, 184),
2107
+ Connection(184, 185),
2108
+ Connection(186, 57),
2109
+ Connection(57, 185),
2110
+ Connection(185, 186),
2111
+ Connection(216, 212),
2112
+ Connection(212, 186),
2113
+ Connection(186, 216),
2114
+ Connection(192, 214),
2115
+ Connection(214, 187),
2116
+ Connection(187, 192),
2117
+ Connection(139, 34),
2118
+ Connection(34, 156),
2119
+ Connection(156, 139),
2120
+ Connection(218, 79),
2121
+ Connection(79, 237),
2122
+ Connection(237, 218),
2123
+ Connection(147, 123),
2124
+ Connection(123, 177),
2125
+ Connection(177, 147),
2126
+ Connection(45, 44),
2127
+ Connection(44, 4),
2128
+ Connection(4, 45),
2129
+ Connection(208, 201),
2130
+ Connection(201, 32),
2131
+ Connection(32, 208),
2132
+ Connection(98, 64),
2133
+ Connection(64, 129),
2134
+ Connection(129, 98),
2135
+ Connection(192, 213),
2136
+ Connection(213, 138),
2137
+ Connection(138, 192),
2138
+ Connection(235, 59),
2139
+ Connection(59, 219),
2140
+ Connection(219, 235),
2141
+ Connection(141, 242),
2142
+ Connection(242, 97),
2143
+ Connection(97, 141),
2144
+ Connection(97, 2),
2145
+ Connection(2, 141),
2146
+ Connection(141, 97),
2147
+ Connection(240, 75),
2148
+ Connection(75, 235),
2149
+ Connection(235, 240),
2150
+ Connection(229, 24),
2151
+ Connection(24, 228),
2152
+ Connection(228, 229),
2153
+ Connection(31, 25),
2154
+ Connection(25, 226),
2155
+ Connection(226, 31),
2156
+ Connection(230, 23),
2157
+ Connection(23, 229),
2158
+ Connection(229, 230),
2159
+ Connection(231, 22),
2160
+ Connection(22, 230),
2161
+ Connection(230, 231),
2162
+ Connection(232, 26),
2163
+ Connection(26, 231),
2164
+ Connection(231, 232),
2165
+ Connection(233, 112),
2166
+ Connection(112, 232),
2167
+ Connection(232, 233),
2168
+ Connection(244, 189),
2169
+ Connection(189, 243),
2170
+ Connection(243, 244),
2171
+ Connection(189, 221),
2172
+ Connection(221, 190),
2173
+ Connection(190, 189),
2174
+ Connection(222, 28),
2175
+ Connection(28, 221),
2176
+ Connection(221, 222),
2177
+ Connection(223, 27),
2178
+ Connection(27, 222),
2179
+ Connection(222, 223),
2180
+ Connection(224, 29),
2181
+ Connection(29, 223),
2182
+ Connection(223, 224),
2183
+ Connection(225, 30),
2184
+ Connection(30, 224),
2185
+ Connection(224, 225),
2186
+ Connection(113, 247),
2187
+ Connection(247, 225),
2188
+ Connection(225, 113),
2189
+ Connection(99, 60),
2190
+ Connection(60, 240),
2191
+ Connection(240, 99),
2192
+ Connection(213, 147),
2193
+ Connection(147, 215),
2194
+ Connection(215, 213),
2195
+ Connection(60, 20),
2196
+ Connection(20, 166),
2197
+ Connection(166, 60),
2198
+ Connection(192, 187),
2199
+ Connection(187, 213),
2200
+ Connection(213, 192),
2201
+ Connection(243, 112),
2202
+ Connection(112, 244),
2203
+ Connection(244, 243),
2204
+ Connection(244, 233),
2205
+ Connection(233, 245),
2206
+ Connection(245, 244),
2207
+ Connection(245, 128),
2208
+ Connection(128, 188),
2209
+ Connection(188, 245),
2210
+ Connection(188, 114),
2211
+ Connection(114, 174),
2212
+ Connection(174, 188),
2213
+ Connection(134, 131),
2214
+ Connection(131, 220),
2215
+ Connection(220, 134),
2216
+ Connection(174, 217),
2217
+ Connection(217, 236),
2218
+ Connection(236, 174),
2219
+ Connection(236, 198),
2220
+ Connection(198, 134),
2221
+ Connection(134, 236),
2222
+ Connection(215, 177),
2223
+ Connection(177, 58),
2224
+ Connection(58, 215),
2225
+ Connection(156, 143),
2226
+ Connection(143, 124),
2227
+ Connection(124, 156),
2228
+ Connection(25, 110),
2229
+ Connection(110, 7),
2230
+ Connection(7, 25),
2231
+ Connection(31, 228),
2232
+ Connection(228, 25),
2233
+ Connection(25, 31),
2234
+ Connection(264, 356),
2235
+ Connection(356, 368),
2236
+ Connection(368, 264),
2237
+ Connection(0, 11),
2238
+ Connection(11, 267),
2239
+ Connection(267, 0),
2240
+ Connection(451, 452),
2241
+ Connection(452, 349),
2242
+ Connection(349, 451),
2243
+ Connection(267, 302),
2244
+ Connection(302, 269),
2245
+ Connection(269, 267),
2246
+ Connection(350, 357),
2247
+ Connection(357, 277),
2248
+ Connection(277, 350),
2249
+ Connection(350, 452),
2250
+ Connection(452, 357),
2251
+ Connection(357, 350),
2252
+ Connection(299, 333),
2253
+ Connection(333, 297),
2254
+ Connection(297, 299),
2255
+ Connection(396, 175),
2256
+ Connection(175, 377),
2257
+ Connection(377, 396),
2258
+ Connection(280, 347),
2259
+ Connection(347, 330),
2260
+ Connection(330, 280),
2261
+ Connection(269, 303),
2262
+ Connection(303, 270),
2263
+ Connection(270, 269),
2264
+ Connection(151, 9),
2265
+ Connection(9, 337),
2266
+ Connection(337, 151),
2267
+ Connection(344, 278),
2268
+ Connection(278, 360),
2269
+ Connection(360, 344),
2270
+ Connection(424, 418),
2271
+ Connection(418, 431),
2272
+ Connection(431, 424),
2273
+ Connection(270, 304),
2274
+ Connection(304, 409),
2275
+ Connection(409, 270),
2276
+ Connection(272, 310),
2277
+ Connection(310, 407),
2278
+ Connection(407, 272),
2279
+ Connection(322, 270),
2280
+ Connection(270, 410),
2281
+ Connection(410, 322),
2282
+ Connection(449, 450),
2283
+ Connection(450, 347),
2284
+ Connection(347, 449),
2285
+ Connection(432, 422),
2286
+ Connection(422, 434),
2287
+ Connection(434, 432),
2288
+ Connection(18, 313),
2289
+ Connection(313, 17),
2290
+ Connection(17, 18),
2291
+ Connection(291, 306),
2292
+ Connection(306, 375),
2293
+ Connection(375, 291),
2294
+ Connection(259, 387),
2295
+ Connection(387, 260),
2296
+ Connection(260, 259),
2297
+ Connection(424, 335),
2298
+ Connection(335, 418),
2299
+ Connection(418, 424),
2300
+ Connection(434, 364),
2301
+ Connection(364, 416),
2302
+ Connection(416, 434),
2303
+ Connection(391, 423),
2304
+ Connection(423, 327),
2305
+ Connection(327, 391),
2306
+ Connection(301, 251),
2307
+ Connection(251, 298),
2308
+ Connection(298, 301),
2309
+ Connection(275, 281),
2310
+ Connection(281, 4),
2311
+ Connection(4, 275),
2312
+ Connection(254, 373),
2313
+ Connection(373, 253),
2314
+ Connection(253, 254),
2315
+ Connection(375, 307),
2316
+ Connection(307, 321),
2317
+ Connection(321, 375),
2318
+ Connection(280, 425),
2319
+ Connection(425, 411),
2320
+ Connection(411, 280),
2321
+ Connection(200, 421),
2322
+ Connection(421, 18),
2323
+ Connection(18, 200),
2324
+ Connection(335, 321),
2325
+ Connection(321, 406),
2326
+ Connection(406, 335),
2327
+ Connection(321, 320),
2328
+ Connection(320, 405),
2329
+ Connection(405, 321),
2330
+ Connection(314, 315),
2331
+ Connection(315, 17),
2332
+ Connection(17, 314),
2333
+ Connection(423, 426),
2334
+ Connection(426, 266),
2335
+ Connection(266, 423),
2336
+ Connection(396, 377),
2337
+ Connection(377, 369),
2338
+ Connection(369, 396),
2339
+ Connection(270, 322),
2340
+ Connection(322, 269),
2341
+ Connection(269, 270),
2342
+ Connection(413, 417),
2343
+ Connection(417, 464),
2344
+ Connection(464, 413),
2345
+ Connection(385, 386),
2346
+ Connection(386, 258),
2347
+ Connection(258, 385),
2348
+ Connection(248, 456),
2349
+ Connection(456, 419),
2350
+ Connection(419, 248),
2351
+ Connection(298, 284),
2352
+ Connection(284, 333),
2353
+ Connection(333, 298),
2354
+ Connection(168, 417),
2355
+ Connection(417, 8),
2356
+ Connection(8, 168),
2357
+ Connection(448, 346),
2358
+ Connection(346, 261),
2359
+ Connection(261, 448),
2360
+ Connection(417, 413),
2361
+ Connection(413, 285),
2362
+ Connection(285, 417),
2363
+ Connection(326, 327),
2364
+ Connection(327, 328),
2365
+ Connection(328, 326),
2366
+ Connection(277, 355),
2367
+ Connection(355, 329),
2368
+ Connection(329, 277),
2369
+ Connection(309, 392),
2370
+ Connection(392, 438),
2371
+ Connection(438, 309),
2372
+ Connection(381, 382),
2373
+ Connection(382, 256),
2374
+ Connection(256, 381),
2375
+ Connection(279, 429),
2376
+ Connection(429, 360),
2377
+ Connection(360, 279),
2378
+ Connection(365, 364),
2379
+ Connection(364, 379),
2380
+ Connection(379, 365),
2381
+ Connection(355, 277),
2382
+ Connection(277, 437),
2383
+ Connection(437, 355),
2384
+ Connection(282, 443),
2385
+ Connection(443, 283),
2386
+ Connection(283, 282),
2387
+ Connection(281, 275),
2388
+ Connection(275, 363),
2389
+ Connection(363, 281),
2390
+ Connection(395, 431),
2391
+ Connection(431, 369),
2392
+ Connection(369, 395),
2393
+ Connection(299, 297),
2394
+ Connection(297, 337),
2395
+ Connection(337, 299),
2396
+ Connection(335, 273),
2397
+ Connection(273, 321),
2398
+ Connection(321, 335),
2399
+ Connection(348, 450),
2400
+ Connection(450, 349),
2401
+ Connection(349, 348),
2402
+ Connection(359, 446),
2403
+ Connection(446, 467),
2404
+ Connection(467, 359),
2405
+ Connection(283, 293),
2406
+ Connection(293, 282),
2407
+ Connection(282, 283),
2408
+ Connection(250, 458),
2409
+ Connection(458, 462),
2410
+ Connection(462, 250),
2411
+ Connection(300, 276),
2412
+ Connection(276, 383),
2413
+ Connection(383, 300),
2414
+ Connection(292, 308),
2415
+ Connection(308, 325),
2416
+ Connection(325, 292),
2417
+ Connection(283, 276),
2418
+ Connection(276, 293),
2419
+ Connection(293, 283),
2420
+ Connection(264, 372),
2421
+ Connection(372, 447),
2422
+ Connection(447, 264),
2423
+ Connection(346, 352),
2424
+ Connection(352, 340),
2425
+ Connection(340, 346),
2426
+ Connection(354, 274),
2427
+ Connection(274, 19),
2428
+ Connection(19, 354),
2429
+ Connection(363, 456),
2430
+ Connection(456, 281),
2431
+ Connection(281, 363),
2432
+ Connection(426, 436),
2433
+ Connection(436, 425),
2434
+ Connection(425, 426),
2435
+ Connection(380, 381),
2436
+ Connection(381, 252),
2437
+ Connection(252, 380),
2438
+ Connection(267, 269),
2439
+ Connection(269, 393),
2440
+ Connection(393, 267),
2441
+ Connection(421, 200),
2442
+ Connection(200, 428),
2443
+ Connection(428, 421),
2444
+ Connection(371, 266),
2445
+ Connection(266, 329),
2446
+ Connection(329, 371),
2447
+ Connection(432, 287),
2448
+ Connection(287, 422),
2449
+ Connection(422, 432),
2450
+ Connection(290, 250),
2451
+ Connection(250, 328),
2452
+ Connection(328, 290),
2453
+ Connection(385, 258),
2454
+ Connection(258, 384),
2455
+ Connection(384, 385),
2456
+ Connection(446, 265),
2457
+ Connection(265, 342),
2458
+ Connection(342, 446),
2459
+ Connection(386, 387),
2460
+ Connection(387, 257),
2461
+ Connection(257, 386),
2462
+ Connection(422, 424),
2463
+ Connection(424, 430),
2464
+ Connection(430, 422),
2465
+ Connection(445, 342),
2466
+ Connection(342, 276),
2467
+ Connection(276, 445),
2468
+ Connection(422, 273),
2469
+ Connection(273, 424),
2470
+ Connection(424, 422),
2471
+ Connection(306, 292),
2472
+ Connection(292, 307),
2473
+ Connection(307, 306),
2474
+ Connection(352, 366),
2475
+ Connection(366, 345),
2476
+ Connection(345, 352),
2477
+ Connection(268, 271),
2478
+ Connection(271, 302),
2479
+ Connection(302, 268),
2480
+ Connection(358, 423),
2481
+ Connection(423, 371),
2482
+ Connection(371, 358),
2483
+ Connection(327, 294),
2484
+ Connection(294, 460),
2485
+ Connection(460, 327),
2486
+ Connection(331, 279),
2487
+ Connection(279, 294),
2488
+ Connection(294, 331),
2489
+ Connection(303, 271),
2490
+ Connection(271, 304),
2491
+ Connection(304, 303),
2492
+ Connection(436, 432),
2493
+ Connection(432, 427),
2494
+ Connection(427, 436),
2495
+ Connection(304, 272),
2496
+ Connection(272, 408),
2497
+ Connection(408, 304),
2498
+ Connection(395, 394),
2499
+ Connection(394, 431),
2500
+ Connection(431, 395),
2501
+ Connection(378, 395),
2502
+ Connection(395, 400),
2503
+ Connection(400, 378),
2504
+ Connection(296, 334),
2505
+ Connection(334, 299),
2506
+ Connection(299, 296),
2507
+ Connection(6, 351),
2508
+ Connection(351, 168),
2509
+ Connection(168, 6),
2510
+ Connection(376, 352),
2511
+ Connection(352, 411),
2512
+ Connection(411, 376),
2513
+ Connection(307, 325),
2514
+ Connection(325, 320),
2515
+ Connection(320, 307),
2516
+ Connection(285, 295),
2517
+ Connection(295, 336),
2518
+ Connection(336, 285),
2519
+ Connection(320, 319),
2520
+ Connection(319, 404),
2521
+ Connection(404, 320),
2522
+ Connection(329, 330),
2523
+ Connection(330, 349),
2524
+ Connection(349, 329),
2525
+ Connection(334, 293),
2526
+ Connection(293, 333),
2527
+ Connection(333, 334),
2528
+ Connection(366, 323),
2529
+ Connection(323, 447),
2530
+ Connection(447, 366),
2531
+ Connection(316, 15),
2532
+ Connection(15, 315),
2533
+ Connection(315, 316),
2534
+ Connection(331, 358),
2535
+ Connection(358, 279),
2536
+ Connection(279, 331),
2537
+ Connection(317, 14),
2538
+ Connection(14, 316),
2539
+ Connection(316, 317),
2540
+ Connection(8, 285),
2541
+ Connection(285, 9),
2542
+ Connection(9, 8),
2543
+ Connection(277, 329),
2544
+ Connection(329, 350),
2545
+ Connection(350, 277),
2546
+ Connection(253, 374),
2547
+ Connection(374, 252),
2548
+ Connection(252, 253),
2549
+ Connection(319, 318),
2550
+ Connection(318, 403),
2551
+ Connection(403, 319),
2552
+ Connection(351, 6),
2553
+ Connection(6, 419),
2554
+ Connection(419, 351),
2555
+ Connection(324, 318),
2556
+ Connection(318, 325),
2557
+ Connection(325, 324),
2558
+ Connection(397, 367),
2559
+ Connection(367, 365),
2560
+ Connection(365, 397),
2561
+ Connection(288, 435),
2562
+ Connection(435, 397),
2563
+ Connection(397, 288),
2564
+ Connection(278, 344),
2565
+ Connection(344, 439),
2566
+ Connection(439, 278),
2567
+ Connection(310, 272),
2568
+ Connection(272, 311),
2569
+ Connection(311, 310),
2570
+ Connection(248, 195),
2571
+ Connection(195, 281),
2572
+ Connection(281, 248),
2573
+ Connection(375, 273),
2574
+ Connection(273, 291),
2575
+ Connection(291, 375),
2576
+ Connection(175, 396),
2577
+ Connection(396, 199),
2578
+ Connection(199, 175),
2579
+ Connection(312, 311),
2580
+ Connection(311, 268),
2581
+ Connection(268, 312),
2582
+ Connection(276, 283),
2583
+ Connection(283, 445),
2584
+ Connection(445, 276),
2585
+ Connection(390, 373),
2586
+ Connection(373, 339),
2587
+ Connection(339, 390),
2588
+ Connection(295, 282),
2589
+ Connection(282, 296),
2590
+ Connection(296, 295),
2591
+ Connection(448, 449),
2592
+ Connection(449, 346),
2593
+ Connection(346, 448),
2594
+ Connection(356, 264),
2595
+ Connection(264, 454),
2596
+ Connection(454, 356),
2597
+ Connection(337, 336),
2598
+ Connection(336, 299),
2599
+ Connection(299, 337),
2600
+ Connection(337, 338),
2601
+ Connection(338, 151),
2602
+ Connection(151, 337),
2603
+ Connection(294, 278),
2604
+ Connection(278, 455),
2605
+ Connection(455, 294),
2606
+ Connection(308, 292),
2607
+ Connection(292, 415),
2608
+ Connection(415, 308),
2609
+ Connection(429, 358),
2610
+ Connection(358, 355),
2611
+ Connection(355, 429),
2612
+ Connection(265, 340),
2613
+ Connection(340, 372),
2614
+ Connection(372, 265),
2615
+ Connection(352, 346),
2616
+ Connection(346, 280),
2617
+ Connection(280, 352),
2618
+ Connection(295, 442),
2619
+ Connection(442, 282),
2620
+ Connection(282, 295),
2621
+ Connection(354, 19),
2622
+ Connection(19, 370),
2623
+ Connection(370, 354),
2624
+ Connection(285, 441),
2625
+ Connection(441, 295),
2626
+ Connection(295, 285),
2627
+ Connection(195, 248),
2628
+ Connection(248, 197),
2629
+ Connection(197, 195),
2630
+ Connection(457, 440),
2631
+ Connection(440, 274),
2632
+ Connection(274, 457),
2633
+ Connection(301, 300),
2634
+ Connection(300, 368),
2635
+ Connection(368, 301),
2636
+ Connection(417, 351),
2637
+ Connection(351, 465),
2638
+ Connection(465, 417),
2639
+ Connection(251, 301),
2640
+ Connection(301, 389),
2641
+ Connection(389, 251),
2642
+ Connection(394, 395),
2643
+ Connection(395, 379),
2644
+ Connection(379, 394),
2645
+ Connection(399, 412),
2646
+ Connection(412, 419),
2647
+ Connection(419, 399),
2648
+ Connection(410, 436),
2649
+ Connection(436, 322),
2650
+ Connection(322, 410),
2651
+ Connection(326, 2),
2652
+ Connection(2, 393),
2653
+ Connection(393, 326),
2654
+ Connection(354, 370),
2655
+ Connection(370, 461),
2656
+ Connection(461, 354),
2657
+ Connection(393, 164),
2658
+ Connection(164, 267),
2659
+ Connection(267, 393),
2660
+ Connection(268, 302),
2661
+ Connection(302, 12),
2662
+ Connection(12, 268),
2663
+ Connection(312, 268),
2664
+ Connection(268, 13),
2665
+ Connection(13, 312),
2666
+ Connection(298, 293),
2667
+ Connection(293, 301),
2668
+ Connection(301, 298),
2669
+ Connection(265, 446),
2670
+ Connection(446, 340),
2671
+ Connection(340, 265),
2672
+ Connection(280, 330),
2673
+ Connection(330, 425),
2674
+ Connection(425, 280),
2675
+ Connection(322, 426),
2676
+ Connection(426, 391),
2677
+ Connection(391, 322),
2678
+ Connection(420, 429),
2679
+ Connection(429, 437),
2680
+ Connection(437, 420),
2681
+ Connection(393, 391),
2682
+ Connection(391, 326),
2683
+ Connection(326, 393),
2684
+ Connection(344, 440),
2685
+ Connection(440, 438),
2686
+ Connection(438, 344),
2687
+ Connection(458, 459),
2688
+ Connection(459, 461),
2689
+ Connection(461, 458),
2690
+ Connection(364, 434),
2691
+ Connection(434, 394),
2692
+ Connection(394, 364),
2693
+ Connection(428, 396),
2694
+ Connection(396, 262),
2695
+ Connection(262, 428),
2696
+ Connection(274, 354),
2697
+ Connection(354, 457),
2698
+ Connection(457, 274),
2699
+ Connection(317, 316),
2700
+ Connection(316, 402),
2701
+ Connection(402, 317),
2702
+ Connection(316, 315),
2703
+ Connection(315, 403),
2704
+ Connection(403, 316),
2705
+ Connection(315, 314),
2706
+ Connection(314, 404),
2707
+ Connection(404, 315),
2708
+ Connection(314, 313),
2709
+ Connection(313, 405),
2710
+ Connection(405, 314),
2711
+ Connection(313, 421),
2712
+ Connection(421, 406),
2713
+ Connection(406, 313),
2714
+ Connection(323, 366),
2715
+ Connection(366, 361),
2716
+ Connection(361, 323),
2717
+ Connection(292, 306),
2718
+ Connection(306, 407),
2719
+ Connection(407, 292),
2720
+ Connection(306, 291),
2721
+ Connection(291, 408),
2722
+ Connection(408, 306),
2723
+ Connection(291, 287),
2724
+ Connection(287, 409),
2725
+ Connection(409, 291),
2726
+ Connection(287, 432),
2727
+ Connection(432, 410),
2728
+ Connection(410, 287),
2729
+ Connection(427, 434),
2730
+ Connection(434, 411),
2731
+ Connection(411, 427),
2732
+ Connection(372, 264),
2733
+ Connection(264, 383),
2734
+ Connection(383, 372),
2735
+ Connection(459, 309),
2736
+ Connection(309, 457),
2737
+ Connection(457, 459),
2738
+ Connection(366, 352),
2739
+ Connection(352, 401),
2740
+ Connection(401, 366),
2741
+ Connection(1, 274),
2742
+ Connection(274, 4),
2743
+ Connection(4, 1),
2744
+ Connection(418, 421),
2745
+ Connection(421, 262),
2746
+ Connection(262, 418),
2747
+ Connection(331, 294),
2748
+ Connection(294, 358),
2749
+ Connection(358, 331),
2750
+ Connection(435, 433),
2751
+ Connection(433, 367),
2752
+ Connection(367, 435),
2753
+ Connection(392, 289),
2754
+ Connection(289, 439),
2755
+ Connection(439, 392),
2756
+ Connection(328, 462),
2757
+ Connection(462, 326),
2758
+ Connection(326, 328),
2759
+ Connection(94, 2),
2760
+ Connection(2, 370),
2761
+ Connection(370, 94),
2762
+ Connection(289, 305),
2763
+ Connection(305, 455),
2764
+ Connection(455, 289),
2765
+ Connection(339, 254),
2766
+ Connection(254, 448),
2767
+ Connection(448, 339),
2768
+ Connection(359, 255),
2769
+ Connection(255, 446),
2770
+ Connection(446, 359),
2771
+ Connection(254, 253),
2772
+ Connection(253, 449),
2773
+ Connection(449, 254),
2774
+ Connection(253, 252),
2775
+ Connection(252, 450),
2776
+ Connection(450, 253),
2777
+ Connection(252, 256),
2778
+ Connection(256, 451),
2779
+ Connection(451, 252),
2780
+ Connection(256, 341),
2781
+ Connection(341, 452),
2782
+ Connection(452, 256),
2783
+ Connection(414, 413),
2784
+ Connection(413, 463),
2785
+ Connection(463, 414),
2786
+ Connection(286, 441),
2787
+ Connection(441, 414),
2788
+ Connection(414, 286),
2789
+ Connection(286, 258),
2790
+ Connection(258, 441),
2791
+ Connection(441, 286),
2792
+ Connection(258, 257),
2793
+ Connection(257, 442),
2794
+ Connection(442, 258),
2795
+ Connection(257, 259),
2796
+ Connection(259, 443),
2797
+ Connection(443, 257),
2798
+ Connection(259, 260),
2799
+ Connection(260, 444),
2800
+ Connection(444, 259),
2801
+ Connection(260, 467),
2802
+ Connection(467, 445),
2803
+ Connection(445, 260),
2804
+ Connection(309, 459),
2805
+ Connection(459, 250),
2806
+ Connection(250, 309),
2807
+ Connection(305, 289),
2808
+ Connection(289, 290),
2809
+ Connection(290, 305),
2810
+ Connection(305, 290),
2811
+ Connection(290, 460),
2812
+ Connection(460, 305),
2813
+ Connection(401, 376),
2814
+ Connection(376, 435),
2815
+ Connection(435, 401),
2816
+ Connection(309, 250),
2817
+ Connection(250, 392),
2818
+ Connection(392, 309),
2819
+ Connection(376, 411),
2820
+ Connection(411, 433),
2821
+ Connection(433, 376),
2822
+ Connection(453, 341),
2823
+ Connection(341, 464),
2824
+ Connection(464, 453),
2825
+ Connection(357, 453),
2826
+ Connection(453, 465),
2827
+ Connection(465, 357),
2828
+ Connection(343, 357),
2829
+ Connection(357, 412),
2830
+ Connection(412, 343),
2831
+ Connection(437, 343),
2832
+ Connection(343, 399),
2833
+ Connection(399, 437),
2834
+ Connection(344, 360),
2835
+ Connection(360, 440),
2836
+ Connection(440, 344),
2837
+ Connection(420, 437),
2838
+ Connection(437, 456),
2839
+ Connection(456, 420),
2840
+ Connection(360, 420),
2841
+ Connection(420, 363),
2842
+ Connection(363, 360),
2843
+ Connection(361, 401),
2844
+ Connection(401, 288),
2845
+ Connection(288, 361),
2846
+ Connection(265, 372),
2847
+ Connection(372, 353),
2848
+ Connection(353, 265),
2849
+ Connection(390, 339),
2850
+ Connection(339, 249),
2851
+ Connection(249, 390),
2852
+ Connection(339, 448),
2853
+ Connection(448, 255),
2854
+ Connection(255, 339),
2855
+ ]
2856
+
2857
+
2858
+ @dataclasses.dataclass
2859
+ class FaceLandmarkerResult:
2860
+ """The face landmarks detection result from FaceLandmarker, where each vector element represents a single face detected in the image.
2861
+
2862
+ Attributes:
2863
+ face_landmarks: Detected face landmarks in normalized image coordinates.
2864
+ face_blendshapes: Optional face blendshapes results.
2865
+ facial_transformation_matrixes: Optional facial transformation matrix.
2866
+ """
2867
+
2868
+ face_landmarks: List[List[landmark_module.NormalizedLandmark]]
2869
+ face_blendshapes: List[List[category_module.Category]]
2870
+ facial_transformation_matrixes: List[np.ndarray]
2871
+
2872
+
2873
+ def _build_landmarker_result(
2874
+ output_packets: Mapping[str, packet_module.Packet]
2875
+ ) -> FaceLandmarkerResult:
2876
+ """Constructs a `FaceLandmarkerResult` from output packets."""
2877
+ face_landmarks_proto_list = packet_getter.get_proto_list(
2878
+ output_packets[_NORM_LANDMARKS_STREAM_NAME]
2879
+ )
2880
+
2881
+ face_landmarks_results = []
2882
+ for proto in face_landmarks_proto_list:
2883
+ face_landmarks = landmark_pb2.NormalizedLandmarkList()
2884
+ face_landmarks.MergeFrom(proto)
2885
+ face_landmarks_list = []
2886
+ for face_landmark in face_landmarks.landmark:
2887
+ face_landmarks_list.append(
2888
+ landmark_module.NormalizedLandmark.create_from_pb2(face_landmark)
2889
+ )
2890
+ face_landmarks_results.append(face_landmarks_list)
2891
+
2892
+ face_blendshapes_results = []
2893
+ if _BLENDSHAPES_STREAM_NAME in output_packets:
2894
+ face_blendshapes_proto_list = packet_getter.get_proto_list(
2895
+ output_packets[_BLENDSHAPES_STREAM_NAME]
2896
+ )
2897
+ for proto in face_blendshapes_proto_list:
2898
+ face_blendshapes_categories = []
2899
+ face_blendshapes_classifications = classification_pb2.ClassificationList()
2900
+ face_blendshapes_classifications.MergeFrom(proto)
2901
+ for face_blendshapes in face_blendshapes_classifications.classification:
2902
+ face_blendshapes_categories.append(
2903
+ category_module.Category(
2904
+ index=face_blendshapes.index,
2905
+ score=face_blendshapes.score,
2906
+ display_name=face_blendshapes.display_name,
2907
+ category_name=face_blendshapes.label,
2908
+ )
2909
+ )
2910
+ face_blendshapes_results.append(face_blendshapes_categories)
2911
+
2912
+ facial_transformation_matrixes_results = []
2913
+ if _FACE_GEOMETRY_STREAM_NAME in output_packets:
2914
+ facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
2915
+ output_packets[_FACE_GEOMETRY_STREAM_NAME]
2916
+ )
2917
+ for proto in facial_transformation_matrixes_proto_list:
2918
+ if hasattr(proto, 'pose_transform_matrix'):
2919
+ matrix_data = matrix_data_pb2.MatrixData()
2920
+ matrix_data.MergeFrom(proto.pose_transform_matrix)
2921
+ matrix = np.array(matrix_data.packed_data)
2922
+ matrix = matrix.reshape((matrix_data.rows, matrix_data.cols))
2923
+ matrix = (
2924
+ matrix if matrix_data.layout == _LayoutEnum.ROW_MAJOR else matrix.T
2925
+ )
2926
+ facial_transformation_matrixes_results.append(matrix)
2927
+
2928
+ return FaceLandmarkerResult(
2929
+ face_landmarks_results,
2930
+ face_blendshapes_results,
2931
+ facial_transformation_matrixes_results,
2932
+ )
2933
+
2934
+ def _build_landmarker_result2(
2935
+ output_packets: Mapping[str, packet_module.Packet]
2936
+ ) -> FaceLandmarkerResult:
2937
+ """Constructs a `FaceLandmarkerResult` from output packets."""
2938
+ face_landmarks_proto_list = packet_getter.get_proto_list(
2939
+ output_packets[_NORM_LANDMARKS_STREAM_NAME]
2940
+ )
2941
+
2942
+ face_landmarks_results = []
2943
+ for proto in face_landmarks_proto_list:
2944
+ face_landmarks = landmark_pb2.NormalizedLandmarkList()
2945
+ face_landmarks.MergeFrom(proto)
2946
+ face_landmarks_list = []
2947
+ for face_landmark in face_landmarks.landmark:
2948
+ face_landmarks_list.append(
2949
+ landmark_module.NormalizedLandmark.create_from_pb2(face_landmark)
2950
+ )
2951
+ face_landmarks_results.append(face_landmarks_list)
2952
+
2953
+ face_blendshapes_results = []
2954
+ if _BLENDSHAPES_STREAM_NAME in output_packets:
2955
+ face_blendshapes_proto_list = packet_getter.get_proto_list(
2956
+ output_packets[_BLENDSHAPES_STREAM_NAME]
2957
+ )
2958
+ for proto in face_blendshapes_proto_list:
2959
+ face_blendshapes_categories = []
2960
+ face_blendshapes_classifications = classification_pb2.ClassificationList()
2961
+ face_blendshapes_classifications.MergeFrom(proto)
2962
+ for face_blendshapes in face_blendshapes_classifications.classification:
2963
+ face_blendshapes_categories.append(
2964
+ category_module.Category(
2965
+ index=face_blendshapes.index,
2966
+ score=face_blendshapes.score,
2967
+ display_name=face_blendshapes.display_name,
2968
+ category_name=face_blendshapes.label,
2969
+ )
2970
+ )
2971
+ face_blendshapes_results.append(face_blendshapes_categories)
2972
+
2973
+ facial_transformation_matrixes_results = []
2974
+ if _FACE_GEOMETRY_STREAM_NAME in output_packets:
2975
+ facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
2976
+ output_packets[_FACE_GEOMETRY_STREAM_NAME]
2977
+ )
2978
+ for proto in facial_transformation_matrixes_proto_list:
2979
+ if hasattr(proto, 'pose_transform_matrix'):
2980
+ matrix_data = matrix_data_pb2.MatrixData()
2981
+ matrix_data.MergeFrom(proto.pose_transform_matrix)
2982
+ matrix = np.array(matrix_data.packed_data)
2983
+ matrix = matrix.reshape((matrix_data.rows, matrix_data.cols))
2984
+ matrix = (
2985
+ matrix if matrix_data.layout == _LayoutEnum.ROW_MAJOR else matrix.T
2986
+ )
2987
+ facial_transformation_matrixes_results.append(matrix)
2988
+
2989
+ return FaceLandmarkerResult(
2990
+ face_landmarks_results,
2991
+ face_blendshapes_results,
2992
+ facial_transformation_matrixes_results,
2993
+ ), facial_transformation_matrixes_proto_list[0].mesh
2994
+
2995
+ @dataclasses.dataclass
2996
+ class FaceLandmarkerOptions:
2997
+ """Options for the face landmarker task.
2998
+
2999
+ Attributes:
3000
+ base_options: Base options for the face landmarker task.
3001
+ running_mode: The running mode of the task. Default to the image mode.
3002
+ FaceLandmarker has three running modes: 1) The image mode for detecting
3003
+ face landmarks on single image inputs. 2) The video mode for detecting
3004
+ face landmarks on the decoded frames of a video. 3) The live stream mode
3005
+ for detecting face landmarks on the live stream of input data, such as
3006
+ from camera. In this mode, the "result_callback" below must be specified
3007
+ to receive the detection results asynchronously.
3008
+ num_faces: The maximum number of faces that can be detected by the
3009
+ FaceLandmarker.
3010
+ min_face_detection_confidence: The minimum confidence score for the face
3011
+ detection to be considered successful.
3012
+ min_face_presence_confidence: The minimum confidence score of face presence
3013
+ score in the face landmark detection.
3014
+ min_tracking_confidence: The minimum confidence score for the face tracking
3015
+ to be considered successful.
3016
+ output_face_blendshapes: Whether FaceLandmarker outputs face blendshapes
3017
+ classification. Face blendshapes are used for rendering the 3D face model.
3018
+ output_facial_transformation_matrixes: Whether FaceLandmarker outputs facial
3019
+ transformation_matrix. Facial transformation matrix is used to transform
3020
+ the face landmarks in canonical face to the detected face, so that users
3021
+ can apply face effects on the detected landmarks.
3022
+ result_callback: The user-defined result callback for processing live stream
3023
+ data. The result callback should only be specified when the running mode
3024
+ is set to the live stream mode.
3025
+ """
3026
+
3027
+ base_options: _BaseOptions
3028
+ running_mode: _RunningMode = _RunningMode.IMAGE
3029
+ num_faces: int = 1
3030
+ min_face_detection_confidence: float = 0.5
3031
+ min_face_presence_confidence: float = 0.5
3032
+ min_tracking_confidence: float = 0.5
3033
+ output_face_blendshapes: bool = False
3034
+ output_facial_transformation_matrixes: bool = False
3035
+ result_callback: Optional[
3036
+ Callable[[FaceLandmarkerResult, image_module.Image, int], None]
3037
+ ] = None
3038
+
3039
+ @doc_controls.do_not_generate_docs
3040
+ def to_pb2(self) -> _FaceLandmarkerGraphOptionsProto:
3041
+ """Generates an FaceLandmarkerGraphOptions protobuf object."""
3042
+ base_options_proto = self.base_options.to_pb2()
3043
+ base_options_proto.use_stream_mode = (
3044
+ False if self.running_mode == _RunningMode.IMAGE else True
3045
+ )
3046
+
3047
+ # Initialize the face landmarker options from base options.
3048
+ face_landmarker_options_proto = _FaceLandmarkerGraphOptionsProto(
3049
+ base_options=base_options_proto
3050
+ )
3051
+
3052
+ # Configure face detector options.
3053
+ face_landmarker_options_proto.face_detector_graph_options.num_faces = (
3054
+ self.num_faces
3055
+ )
3056
+ face_landmarker_options_proto.face_detector_graph_options.min_detection_confidence = (
3057
+ self.min_face_detection_confidence
3058
+ )
3059
+
3060
+ # Configure face landmark detector options.
3061
+ face_landmarker_options_proto.min_tracking_confidence = (
3062
+ self.min_tracking_confidence
3063
+ )
3064
+ face_landmarker_options_proto.face_landmarks_detector_graph_options.min_detection_confidence = (
3065
+ self.min_face_detection_confidence
3066
+ )
3067
+ return face_landmarker_options_proto
3068
+
3069
+
3070
+ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
3071
+ """Class that performs face landmarks detection on images."""
3072
+
3073
+ @classmethod
3074
+ def create_from_model_path(cls, model_path: str) -> 'FaceLandmarker':
3075
+ """Creates an `FaceLandmarker` object from a TensorFlow Lite model and the default `FaceLandmarkerOptions`.
3076
+
3077
+ Note that the created `FaceLandmarker` instance is in image mode, for
3078
+ detecting face landmarks on single image inputs.
3079
+
3080
+ Args:
3081
+ model_path: Path to the model.
3082
+
3083
+ Returns:
3084
+ `FaceLandmarker` object that's created from the model file and the
3085
+ default `FaceLandmarkerOptions`.
3086
+
3087
+ Raises:
3088
+ ValueError: If failed to create `FaceLandmarker` object from the
3089
+ provided file such as invalid file path.
3090
+ RuntimeError: If other types of error occurred.
3091
+ """
3092
+ base_options = _BaseOptions(model_asset_path=model_path)
3093
+ options = FaceLandmarkerOptions(
3094
+ base_options=base_options, running_mode=_RunningMode.IMAGE
3095
+ )
3096
+ return cls.create_from_options(options)
3097
+
3098
+ @classmethod
3099
+ def create_from_options(
3100
+ cls, options: FaceLandmarkerOptions
3101
+ ) -> 'FaceLandmarker':
3102
+ """Creates the `FaceLandmarker` object from face landmarker options.
3103
+
3104
+ Args:
3105
+ options: Options for the face landmarker task.
3106
+
3107
+ Returns:
3108
+ `FaceLandmarker` object that's created from `options`.
3109
+
3110
+ Raises:
3111
+ ValueError: If failed to create `FaceLandmarker` object from
3112
+ `FaceLandmarkerOptions` such as missing the model.
3113
+ RuntimeError: If other types of error occurred.
3114
+ """
3115
+
3116
+ def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
3117
+ if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
3118
+ return
3119
+
3120
+ image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
3121
+ if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
3122
+ return
3123
+
3124
+ if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
3125
+ empty_packet = output_packets[_NORM_LANDMARKS_STREAM_NAME]
3126
+ options.result_callback(
3127
+ FaceLandmarkerResult([], [], []),
3128
+ image,
3129
+ empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
3130
+ )
3131
+ return
3132
+
3133
+ face_landmarks_result = _build_landmarker_result(output_packets)
3134
+ timestamp = output_packets[_NORM_LANDMARKS_STREAM_NAME].timestamp
3135
+ options.result_callback(
3136
+ face_landmarks_result,
3137
+ image,
3138
+ timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
3139
+ )
3140
+
3141
+ output_streams = [
3142
+ ':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]),
3143
+ ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
3144
+ ]
3145
+
3146
+ if options.output_face_blendshapes:
3147
+ output_streams.append(
3148
+ ':'.join([_BLENDSHAPES_TAG, _BLENDSHAPES_STREAM_NAME])
3149
+ )
3150
+ if options.output_facial_transformation_matrixes:
3151
+ output_streams.append(
3152
+ ':'.join([_FACE_GEOMETRY_TAG, _FACE_GEOMETRY_STREAM_NAME])
3153
+ )
3154
+
3155
+ task_info = _TaskInfo(
3156
+ task_graph=_TASK_GRAPH_NAME,
3157
+ input_streams=[
3158
+ ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
3159
+ ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
3160
+ ],
3161
+ output_streams=output_streams,
3162
+ task_options=options,
3163
+ )
3164
+ return cls(
3165
+ task_info.generate_graph_config(
3166
+ enable_flow_limiting=options.running_mode
3167
+ == _RunningMode.LIVE_STREAM
3168
+ ),
3169
+ options.running_mode,
3170
+ packets_callback if options.result_callback else None,
3171
+ )
3172
+
3173
+ def detect(
3174
+ self,
3175
+ image: image_module.Image,
3176
+ image_processing_options: Optional[_ImageProcessingOptions] = None,
3177
+ ) -> FaceLandmarkerResult:
3178
+ """Performs face landmarks detection on the given image.
3179
+
3180
+ Only use this method when the FaceLandmarker is created with the image
3181
+ running mode.
3182
+
3183
+ The image can be of any size with format RGB or RGBA.
3184
+ TODO: Describes how the input image will be preprocessed after the yuv
3185
+ support is implemented.
3186
+
3187
+ Args:
3188
+ image: MediaPipe Image.
3189
+ image_processing_options: Options for image processing.
3190
+
3191
+ Returns:
3192
+ The face landmarks detection results.
3193
+
3194
+ Raises:
3195
+ ValueError: If any of the input arguments is invalid.
3196
+ RuntimeError: If face landmarker detection failed to run.
3197
+ """
3198
+
3199
+ normalized_rect = self.convert_to_normalized_rect(
3200
+ image_processing_options, image, roi_allowed=False
3201
+ )
3202
+ output_packets = self._process_image_data({
3203
+ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
3204
+ _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
3205
+ normalized_rect.to_pb2()
3206
+ ),
3207
+ })
3208
+
3209
+ if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
3210
+ return FaceLandmarkerResult([], [], [])
3211
+
3212
+ return _build_landmarker_result2(output_packets)
3213
+
3214
+ def detect_for_video(
3215
+ self,
3216
+ image: image_module.Image,
3217
+ timestamp_ms: int,
3218
+ image_processing_options: Optional[_ImageProcessingOptions] = None,
3219
+ ):
3220
+ """Performs face landmarks detection on the provided video frame.
3221
+
3222
+ Only use this method when the FaceLandmarker is created with the video
3223
+ running mode.
3224
+
3225
+ Only use this method when the FaceLandmarker is created with the video
3226
+ running mode. It's required to provide the video frame's timestamp (in
3227
+ milliseconds) along with the video frame. The input timestamps should be
3228
+ monotonically increasing for adjacent calls of this method.
3229
+
3230
+ Args:
3231
+ image: MediaPipe Image.
3232
+ timestamp_ms: The timestamp of the input video frame in milliseconds.
3233
+ image_processing_options: Options for image processing.
3234
+
3235
+ Returns:
3236
+ The face landmarks detection results.
3237
+
3238
+ Raises:
3239
+ ValueError: If any of the input arguments is invalid.
3240
+ RuntimeError: If face landmarker detection failed to run.
3241
+ """
3242
+ normalized_rect = self.convert_to_normalized_rect(
3243
+ image_processing_options, image, roi_allowed=False
3244
+ )
3245
+ output_packets = self._process_video_data({
3246
+ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
3247
+ timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
3248
+ ),
3249
+ _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
3250
+ normalized_rect.to_pb2()
3251
+ ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
3252
+ })
3253
+
3254
+ if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
3255
+ return FaceLandmarkerResult([], [], [])
3256
+
3257
+ return _build_landmarker_result2(output_packets)
3258
+
3259
+ def detect_async(
3260
+ self,
3261
+ image: image_module.Image,
3262
+ timestamp_ms: int,
3263
+ image_processing_options: Optional[_ImageProcessingOptions] = None,
3264
+ ) -> None:
3265
+ """Sends live image data to perform face landmarks detection.
3266
+
3267
+ The results will be available via the "result_callback" provided in the
3268
+ FaceLandmarkerOptions. Only use this method when the FaceLandmarker is
3269
+ created with the live stream running mode.
3270
+
3271
+ Only use this method when the FaceLandmarker is created with the live
3272
+ stream running mode. The input timestamps should be monotonically increasing
3273
+ for adjacent calls of this method. This method will return immediately after
3274
+ the input image is accepted. The results will be available via the
3275
+ `result_callback` provided in the `FaceLandmarkerOptions`. The
3276
+ `detect_async` method is designed to process live stream data such as
3277
+ camera input. To lower the overall latency, face landmarker may drop the
3278
+ input images if needed. In other words, it's not guaranteed to have output
3279
+ per input image.
3280
+
3281
+ The `result_callback` provides:
3282
+ - The face landmarks detection results.
3283
+ - The input image that the face landmarker runs on.
3284
+ - The input timestamp in milliseconds.
3285
+
3286
+ Args:
3287
+ image: MediaPipe Image.
3288
+ timestamp_ms: The timestamp of the input image in milliseconds.
3289
+ image_processing_options: Options for image processing.
3290
+
3291
+ Raises:
3292
+ ValueError: If the current input timestamp is smaller than what the
3293
+ face landmarker has already processed.
3294
+ """
3295
+ normalized_rect = self.convert_to_normalized_rect(
3296
+ image_processing_options, image, roi_allowed=False
3297
+ )
3298
+ self._send_live_stream_data({
3299
+ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
3300
+ timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
3301
+ ),
3302
+ _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
3303
+ normalized_rect.to_pb2()
3304
+ ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
3305
+ })
src/utils/mp_utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ import time
5
+ from tqdm import tqdm
6
+ import multiprocessing
7
+ import glob
8
+
9
+ import mediapipe as mp
10
+ from mediapipe import solutions
11
+ from mediapipe.framework.formats import landmark_pb2
12
+ from mediapipe.tasks import python
13
+ from mediapipe.tasks.python import vision
14
+ from . import face_landmark
15
+
16
+ CUR_DIR = os.path.dirname(__file__)
17
+
18
+
19
+ class LMKExtractor():
20
+ def __init__(self, FPS=25):
21
+ # Create an FaceLandmarker object.
22
+ self.mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE
23
+ base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/face_landmarker_v2_with_blendshapes.task'))
24
+ base_options.delegate = mp.tasks.BaseOptions.Delegate.CPU
25
+ options = vision.FaceLandmarkerOptions(base_options=base_options,
26
+ running_mode=self.mode,
27
+ output_face_blendshapes=True,
28
+ output_facial_transformation_matrixes=True,
29
+ num_faces=1)
30
+ self.detector = face_landmark.FaceLandmarker.create_from_options(options)
31
+ self.last_ts = 0
32
+ self.frame_ms = int(1000 / FPS)
33
+
34
+ det_base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/blaze_face_short_range.tflite'))
35
+ det_options = vision.FaceDetectorOptions(base_options=det_base_options)
36
+ self.det_detector = vision.FaceDetector.create_from_options(det_options)
37
+
38
+
39
+ def __call__(self, img):
40
+ frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
41
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
42
+ t0 = time.time()
43
+ if self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.VIDEO:
44
+ det_result = self.det_detector.detect(image)
45
+ if len(det_result.detections) != 1:
46
+ return None
47
+ self.last_ts += self.frame_ms
48
+ try:
49
+ detection_result, mesh3d = self.detector.detect_for_video(image, timestamp_ms=self.last_ts)
50
+ except:
51
+ return None
52
+ elif self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE:
53
+ # det_result = self.det_detector.detect(image)
54
+
55
+ # if len(det_result.detections) != 1:
56
+ # return None
57
+ try:
58
+ detection_result, mesh3d = self.detector.detect(image)
59
+ except:
60
+ return None
61
+
62
+
63
+ bs_list = detection_result.face_blendshapes
64
+ if len(bs_list) == 1:
65
+ bs = bs_list[0]
66
+ bs_values = []
67
+ for index in range(len(bs)):
68
+ bs_values.append(bs[index].score)
69
+ bs_values = bs_values[1:] # remove neutral
70
+ trans_mat = detection_result.facial_transformation_matrixes[0]
71
+ face_landmarks_list = detection_result.face_landmarks
72
+ face_landmarks = face_landmarks_list[0]
73
+ lmks = []
74
+ for index in range(len(face_landmarks)):
75
+ x = face_landmarks[index].x
76
+ y = face_landmarks[index].y
77
+ z = face_landmarks[index].z
78
+ lmks.append([x, y, z])
79
+ lmks = np.array(lmks)
80
+
81
+ lmks3d = np.array(mesh3d.vertex_buffer)
82
+ lmks3d = lmks3d.reshape(-1, 5)[:, :3]
83
+ mp_tris = np.array(mesh3d.index_buffer).reshape(-1, 3) + 1
84
+
85
+ return {
86
+ "lmks": lmks,
87
+ 'lmks3d': lmks3d,
88
+ "trans_mat": trans_mat,
89
+ 'faces': mp_tris,
90
+ "bs": bs_values
91
+ }
92
+ else:
93
+ # print('multiple faces in the image: {}'.format(img_path))
94
+ return None
95
+
src/utils/pose_util.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ from scipy.spatial.transform import Rotation as R
5
+
6
+
7
+ def create_perspective_matrix(aspect_ratio):
8
+ kDegreesToRadians = np.pi / 180.
9
+ near = 1
10
+ far = 10000
11
+ perspective_matrix = np.zeros(16, dtype=np.float32)
12
+
13
+ # Standard perspective projection matrix calculations.
14
+ f = 1.0 / np.tan(kDegreesToRadians * 63 / 2.)
15
+
16
+ denom = 1.0 / (near - far)
17
+ perspective_matrix[0] = f / aspect_ratio
18
+ perspective_matrix[5] = f
19
+ perspective_matrix[10] = (near + far) * denom
20
+ perspective_matrix[11] = -1.
21
+ perspective_matrix[14] = 1. * far * near * denom
22
+
23
+ # If the environment's origin point location is in the top left corner,
24
+ # then skip additional flip along Y-axis is required to render correctly.
25
+
26
+ perspective_matrix[5] *= -1.
27
+ return perspective_matrix
28
+
29
+
30
+ def project_points(points_3d, transformation_matrix, pose_vectors, image_shape):
31
+ P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T
32
+ L, N, _ = points_3d.shape
33
+ projected_points = np.zeros((L, N, 2))
34
+ for i in range(L):
35
+ points_3d_frame = points_3d[i]
36
+ ones = np.ones((points_3d_frame.shape[0], 1))
37
+ points_3d_homogeneous = np.hstack([points_3d_frame, ones])
38
+ transformed_points = points_3d_homogeneous @ (transformation_matrix @ euler_and_translation_to_matrix(pose_vectors[i][:3], pose_vectors[i][3:])).T @ P
39
+ projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1
40
+ projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1]
41
+ projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0]
42
+ projected_points[i] = projected_points_frame
43
+ return projected_points
44
+
45
+
46
+ def project_points_with_trans(points_3d, transformation_matrix, image_shape):
47
+ P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T
48
+ L, N, _ = points_3d.shape
49
+ projected_points = np.zeros((L, N, 2))
50
+ for i in range(L):
51
+ points_3d_frame = points_3d[i]
52
+ ones = np.ones((points_3d_frame.shape[0], 1))
53
+ points_3d_homogeneous = np.hstack([points_3d_frame, ones])
54
+ transformed_points = points_3d_homogeneous @ transformation_matrix[i].T @ P
55
+ projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1
56
+ projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1]
57
+ projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0]
58
+ projected_points[i] = projected_points_frame
59
+ return projected_points
60
+
61
+
62
+ def euler_and_translation_to_matrix(euler_angles, translation_vector):
63
+ rotation = R.from_euler('xyz', euler_angles, degrees=True)
64
+ rotation_matrix = rotation.as_matrix()
65
+
66
+ matrix = np.eye(4)
67
+ matrix[:3, :3] = rotation_matrix
68
+ matrix[:3, 3] = translation_vector
69
+
70
+ return matrix
71
+
72
+
73
+ def matrix_to_euler_and_translation(matrix):
74
+ rotation_matrix = matrix[:3, :3]
75
+ translation_vector = matrix[:3, 3]
76
+ rotation = R.from_matrix(rotation_matrix)
77
+ euler_angles = rotation.as_euler('xyz', degrees=True)
78
+ return euler_angles, translation_vector
src/utils/util.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import os
3
+ import os.path as osp
4
+ import shutil
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ import av
9
+ import numpy as np
10
+ import torch
11
+ import torchvision
12
+ from einops import rearrange
13
+ from PIL import Image
14
+
15
+
16
+ def seed_everything(seed):
17
+ import random
18
+
19
+ import numpy as np
20
+
21
+ torch.manual_seed(seed)
22
+ torch.cuda.manual_seed_all(seed)
23
+ np.random.seed(seed % (2**32))
24
+ random.seed(seed)
25
+
26
+
27
+ def import_filename(filename):
28
+ spec = importlib.util.spec_from_file_location("mymodule", filename)
29
+ module = importlib.util.module_from_spec(spec)
30
+ sys.modules[spec.name] = module
31
+ spec.loader.exec_module(module)
32
+ return module
33
+
34
+
35
+ def delete_additional_ckpt(base_path, num_keep):
36
+ dirs = []
37
+ for d in os.listdir(base_path):
38
+ if d.startswith("checkpoint-"):
39
+ dirs.append(d)
40
+ num_tot = len(dirs)
41
+ if num_tot <= num_keep:
42
+ return
43
+ # ensure ckpt is sorted and delete the ealier!
44
+ del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
45
+ for d in del_dirs:
46
+ path_to_dir = osp.join(base_path, d)
47
+ if osp.exists(path_to_dir):
48
+ shutil.rmtree(path_to_dir)
49
+
50
+
51
+ def save_videos_from_pil(pil_images, path, fps=8):
52
+ import av
53
+
54
+ save_fmt = Path(path).suffix
55
+ os.makedirs(os.path.dirname(path), exist_ok=True)
56
+ width, height = pil_images[0].size
57
+
58
+ if save_fmt == ".mp4":
59
+ codec = "libx264"
60
+ container = av.open(path, "w")
61
+ stream = container.add_stream(codec, rate=fps)
62
+
63
+ stream.width = width
64
+ stream.height = height
65
+
66
+ for pil_image in pil_images:
67
+ # pil_image = Image.fromarray(image_arr).convert("RGB")
68
+ av_frame = av.VideoFrame.from_image(pil_image)
69
+ container.mux(stream.encode(av_frame))
70
+ container.mux(stream.encode())
71
+ container.close()
72
+
73
+ elif save_fmt == ".gif":
74
+ pil_images[0].save(
75
+ fp=path,
76
+ format="GIF",
77
+ append_images=pil_images[1:],
78
+ save_all=True,
79
+ duration=(1 / fps * 1000),
80
+ loop=0,
81
+ )
82
+ else:
83
+ raise ValueError("Unsupported file type. Use .mp4 or .gif.")
84
+
85
+
86
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
87
+ videos = rearrange(videos, "b c t h w -> t b c h w")
88
+ height, width = videos.shape[-2:]
89
+ outputs = []
90
+
91
+ for x in videos:
92
+ x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
93
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
94
+ if rescale:
95
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
96
+ x = (x * 255).numpy().astype(np.uint8)
97
+ x = Image.fromarray(x)
98
+
99
+ outputs.append(x)
100
+
101
+ os.makedirs(os.path.dirname(path), exist_ok=True)
102
+
103
+ save_videos_from_pil(outputs, path, fps)
104
+
105
+
106
+ def read_frames(video_path):
107
+ container = av.open(video_path)
108
+
109
+ video_stream = next(s for s in container.streams if s.type == "video")
110
+ frames = []
111
+ for packet in container.demux(video_stream):
112
+ for frame in packet.decode():
113
+ image = Image.frombytes(
114
+ "RGB",
115
+ (frame.width, frame.height),
116
+ frame.to_rgb().to_ndarray(),
117
+ )
118
+ frames.append(image)
119
+
120
+ return frames
121
+
122
+
123
+ def get_fps(video_path):
124
+ container = av.open(video_path)
125
+ video_stream = next(s for s in container.streams if s.type == "video")
126
+ fps = video_stream.average_rate
127
+ container.close()
128
+ return fps
src/vid2vid.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import shutil
4
+ import ffmpeg
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import cv2
9
+ import torch
10
+ # import spaces
11
+ from diffusers import AutoencoderKL, DDIMScheduler
12
+ from einops import repeat
13
+ from omegaconf import OmegaConf
14
+ from PIL import Image
15
+ from torchvision import transforms
16
+ from transformers import CLIPVisionModelWithProjection
17
+
18
+ from src.models.pose_guider import PoseGuider
19
+ from src.models.unet_2d_condition import UNet2DConditionModel
20
+ from src.models.unet_3d import UNet3DConditionModel
21
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
22
+ from src.utils.util import get_fps, read_frames, save_videos_grid
23
+
24
+ from src.utils.mp_utils import LMKExtractor
25
+ from src.utils.draw_util import FaceMeshVisualizer
26
+ from src.utils.pose_util import project_points_with_trans, matrix_to_euler_and_translation
27
+ from src.audio2vid import smooth_pose_seq
28
+
29
+
30
+ def parse_args():
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument("--config", type=str, default='./configs/prompts/animation_facereenac.yaml')
33
+ parser.add_argument("-W", type=int, default=512)
34
+ parser.add_argument("-H", type=int, default=512)
35
+ parser.add_argument("-L", type=int)
36
+ parser.add_argument("--seed", type=int, default=42)
37
+ parser.add_argument("--cfg", type=float, default=3.5)
38
+ parser.add_argument("--steps", type=int, default=25)
39
+ parser.add_argument("--fps", type=int)
40
+ args = parser.parse_args()
41
+
42
+ return args
43
+
44
+ # @spaces.GPU
45
+ def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
46
+ cfg = 3.5
47
+
48
+ config = OmegaConf.load('./configs/prompts/animation_facereenac.yaml')
49
+
50
+ if config.weight_dtype == "fp16":
51
+ weight_dtype = torch.float16
52
+ else:
53
+ weight_dtype = torch.float32
54
+
55
+ vae = AutoencoderKL.from_pretrained(
56
+ config.pretrained_vae_path,
57
+ ).to("cuda", dtype=weight_dtype)
58
+
59
+ reference_unet = UNet2DConditionModel.from_pretrained(
60
+ config.pretrained_base_model_path,
61
+ subfolder="unet",
62
+ ).to(dtype=weight_dtype, device="cuda")
63
+
64
+ inference_config_path = config.inference_config
65
+ infer_config = OmegaConf.load(inference_config_path)
66
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
67
+ config.pretrained_base_model_path,
68
+ config.motion_module_path,
69
+ subfolder="unet",
70
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
71
+ ).to(dtype=weight_dtype, device="cuda")
72
+
73
+ pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
74
+
75
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
76
+ config.image_encoder_path
77
+ ).to(dtype=weight_dtype, device="cuda")
78
+
79
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
80
+ scheduler = DDIMScheduler(**sched_kwargs)
81
+
82
+ generator = torch.manual_seed(seed)
83
+
84
+ width, height = size, size
85
+
86
+ # load pretrained weights
87
+ denoising_unet.load_state_dict(
88
+ torch.load(config.denoising_unet_path, map_location="cpu"),
89
+ strict=False,
90
+ )
91
+ reference_unet.load_state_dict(
92
+ torch.load(config.reference_unet_path, map_location="cpu"),
93
+ )
94
+ pose_guider.load_state_dict(
95
+ torch.load(config.pose_guider_path, map_location="cpu"),
96
+ )
97
+
98
+ pipe = Pose2VideoPipeline(
99
+ vae=vae,
100
+ image_encoder=image_enc,
101
+ reference_unet=reference_unet,
102
+ denoising_unet=denoising_unet,
103
+ pose_guider=pose_guider,
104
+ scheduler=scheduler,
105
+ )
106
+ pipe = pipe.to("cuda", dtype=weight_dtype)
107
+
108
+ date_str = datetime.now().strftime("%Y%m%d")
109
+ time_str = datetime.now().strftime("%H%M")
110
+ save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
111
+
112
+ save_dir = Path(f"output/{date_str}/{save_dir_name}")
113
+ save_dir.mkdir(exist_ok=True, parents=True)
114
+
115
+
116
+ lmk_extractor = LMKExtractor()
117
+ vis = FaceMeshVisualizer(forehead_edge=False)
118
+
119
+
120
+
121
+ ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
122
+ # TODO: 人脸检测+裁剪
123
+ ref_image_np = cv2.resize(ref_image_np, (size, size))
124
+ ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
125
+
126
+ face_result = lmk_extractor(ref_image_np)
127
+ if face_result is None:
128
+ return None
129
+
130
+ lmks = face_result['lmks'].astype(np.float32)
131
+ ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
132
+
133
+
134
+
135
+ source_images = read_frames(source_video)
136
+ src_fps = get_fps(source_video)
137
+ pose_transform = transforms.Compose(
138
+ [transforms.Resize((height, width)), transforms.ToTensor()]
139
+ )
140
+
141
+ step = 1
142
+ if src_fps == 60:
143
+ src_fps = 30
144
+ step = 2
145
+
146
+ pose_trans_list = []
147
+ verts_list = []
148
+ bs_list = []
149
+ src_tensor_list = []
150
+ args_L = len(source_images) if length==0 or length*step > len(source_images) else length*step
151
+ for src_image_pil in source_images[: args_L: step]:
152
+ src_tensor_list.append(pose_transform(src_image_pil))
153
+ src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
154
+ frame_height, frame_width, _ = src_img_np.shape
155
+ src_img_result = lmk_extractor(src_img_np)
156
+ if src_img_result is None:
157
+ break
158
+ pose_trans_list.append(src_img_result['trans_mat'])
159
+ verts_list.append(src_img_result['lmks3d'])
160
+ bs_list.append(src_img_result['bs'])
161
+
162
+
163
+ # pose_arr = np.array(pose_trans_list)
164
+ trans_mat_arr = np.array(pose_trans_list)
165
+ verts_arr = np.array(verts_list)
166
+ bs_arr = np.array(bs_list)
167
+ min_bs_idx = np.argmin(bs_arr.sum(1))
168
+
169
+ # compute delta pose
170
+ trans_mat_inv_frame_0 = np.linalg.inv(trans_mat_arr[0])
171
+ pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
172
+
173
+ for i in range(pose_arr.shape[0]):
174
+ pose_mat = trans_mat_inv_frame_0 @ trans_mat_arr[i]
175
+ euler_angles, translation_vector = matrix_to_euler_and_translation(pose_mat)
176
+ pose_arr[i, :3] = euler_angles
177
+ pose_arr[i, 3:6] = translation_vector
178
+
179
+ pose_arr = smooth_pose_seq(pose_arr)
180
+
181
+ # face retarget
182
+ verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result['lmks3d']
183
+ # project 3D mesh to 2D landmark
184
+ projected_vertices = project_points_with_trans(verts_arr, pose_arr, [frame_height, frame_width])
185
+
186
+ pose_list = []
187
+ for i, verts in enumerate(projected_vertices):
188
+ lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
189
+ pose_image_np = cv2.resize(lmk_img, (width, height))
190
+ pose_list.append(pose_image_np)
191
+
192
+ pose_list = np.array(pose_list)
193
+
194
+ video_length = len(pose_list)
195
+
196
+ video = pipe(
197
+ ref_image_pil,
198
+ pose_list,
199
+ ref_pose,
200
+ width,
201
+ height,
202
+ video_length,
203
+ steps,
204
+ cfg,
205
+ generator=generator,
206
+ ).videos
207
+
208
+ save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
209
+ save_videos_grid(
210
+ video,
211
+ save_path,
212
+ n_rows=1,
213
+ fps=src_fps,
214
+ )
215
+
216
+ audio_output = f'{save_dir}/audio_from_video.aac'
217
+ # extract audio
218
+ try:
219
+ ffmpeg.input(source_video).output(audio_output, acodec='copy').run()
220
+ # merge audio and video
221
+ stream = ffmpeg.input(save_path)
222
+ audio = ffmpeg.input(audio_output)
223
+ ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
224
+
225
+ os.remove(save_path)
226
+ os.remove(audio_output)
227
+ except:
228
+ shutil.move(
229
+ save_path,
230
+ save_path.replace('_noaudio.mp4', '.mp4')
231
+ )
232
+
233
+ return save_path.replace('_noaudio.mp4', '.mp4')