Canberk Baykal committed
Commit b5ed368 · 1 Parent(s): 06eb18b
adapter/adapter_decoder.py ADDED
@@ -0,0 +1,68 @@
+ import torch
+ from torch import nn
+ from adapter import clipadapter
+ from models.stylegan2.model_remapper import Generator
+
+
+ def get_keys(d, name):
+     if 'state_dict' in d:
+         d = d['state_dict']
+     d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
+     return d_filt
+
+
+ class CLIPAdapterWithDecoder(nn.Module):
+
+     def __init__(self, opts):
+         super(CLIPAdapterWithDecoder, self).__init__()
+         self.opts = opts
+         # Define architecture
+         self.adapter = clipadapter.CLIPAdapter(self.opts)
+         self.decoder = Generator(self.opts.stylegan_size, 512, 8)
+         self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+         # Load weights if needed
+         self.load_weights()
+
+     def load_weights(self):
+         if self.opts.checkpoint_path is not None:
+             print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
+             ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
+             self.adapter.load_state_dict(get_keys(ckpt, 'mapper'), strict=False)
+             self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
+         else:
+             print('Loading decoder weights from pretrained!')
+             ckpt = torch.load(self.opts.stylegan_weights)
+             self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
+
+     def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
+                 inject_latent=None, return_latents=False, alpha=None):
+         if input_code:
+             codes = x
+         else:
+             codes = self.adapter(x)
+
+         if latent_mask is not None:
+             for i in latent_mask:
+                 if inject_latent is not None:
+                     if alpha is not None:
+                         codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
+                     else:
+                         codes[:, i] = inject_latent[:, i]
+                 else:
+                     codes[:, i] = 0
+
+         input_is_latent = not input_code
+         images, result_latent = self.decoder([codes],
+                                              input_is_latent=input_is_latent,
+                                              randomize_noise=randomize_noise,
+                                              return_latents=return_latents)
+
+         if resize:
+             images = self.face_pool(images)
+
+         if return_latents:
+             return images, result_latent
+         else:
+             return images
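For reference, a minimal sketch of what the get_keys helper above does with a combined training checkpoint (the keys are invented for illustration):

    ckpt = {'state_dict': {'mapper.fc.weight': 'W1', 'decoder.conv1.weight': 'W2'}}
    get_keys(ckpt, 'mapper')   # -> {'fc.weight': 'W1'}
    get_keys(ckpt, 'decoder')  # -> {'conv1.weight': 'W2'}

Each submodule therefore loads only the slice of the checkpoint saved under its own prefix.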
adapter/clipadapter.py ADDED
@@ -0,0 +1,60 @@
+ from torch import nn
+ from models.stylegan2.model import PixelNorm
+ from torch.nn import Linear, LayerNorm, LeakyReLU, Sequential, Module, Conv2d, GroupNorm
+
+
+ class TextModulationModule(Module):
+     def __init__(self, in_channels):
+         super(TextModulationModule, self).__init__()
+         self.conv = Conv2d(in_channels, in_channels, 3, stride=1, padding=1, bias=False)
+         self.norm = GroupNorm(32, in_channels)
+         self.gamma_function = Sequential(Linear(512, 512), LayerNorm([512]), LeakyReLU(), Linear(512, in_channels))
+         self.beta_function = Sequential(Linear(512, 512), LayerNorm([512]), LeakyReLU(), Linear(512, in_channels))
+         self.leakyrelu = LeakyReLU()
+
+     def forward(self, x, embedding):
+         x = self.conv(x)
+         x = self.norm(x)
+         log_gamma = self.gamma_function(embedding.float())
+         gamma = log_gamma.exp().unsqueeze(2).unsqueeze(3)
+         beta = self.beta_function(embedding.float()).unsqueeze(2).unsqueeze(3)
+         out = x * (1 + gamma) + beta
+         out = self.leakyrelu(out)
+         return out
+
+
+ class SubTextMapper(Module):
+     def __init__(self, opts, in_channels):
+         super(SubTextMapper, self).__init__()
+         self.opts = opts
+         self.pixelnorm = PixelNorm()
+         self.modulation_module_list = nn.ModuleList([TextModulationModule(in_channels) for _ in range(1)])
+
+     def forward(self, x, embedding):
+         x = self.pixelnorm(x)
+         for modulation_module in self.modulation_module_list:
+             x = modulation_module(x, embedding)
+         return x
+
+
+ class CLIPAdapter(Module):
+     def __init__(self, opts):
+         super(CLIPAdapter, self).__init__()
+         self.opts = opts
+
+         if not opts.no_coarse_mapper:
+             self.coarse_mapping = SubTextMapper(opts, 512)
+         if not opts.no_medium_mapper:
+             self.medium_mapping = SubTextMapper(opts, 256)
+         if not opts.no_fine_mapper:
+             self.fine_mapping = SubTextMapper(opts, 128)
+
+     def forward(self, features, txt_embed):
+         txt_embed = txt_embed.detach()
+         c1, c2, c3 = features
+
+         if not self.opts.no_coarse_mapper:
+             c3 = self.coarse_mapping(c3, txt_embed)
+         if not self.opts.no_medium_mapper:
+             c2 = self.medium_mapping(c2, txt_embed)
+         if not self.opts.no_fine_mapper:
+             c1 = self.fine_mapping(c1, txt_embed)
+         return (c1, c2, c3)
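The modulation above is FiLM-style conditioning: the CLIP text embedding predicts a per-channel scale and shift that are applied to the convolved, group-normalized feature map. A standalone sketch of just that rule in plain torch, with assumed shapes (512-channel coarse features, 512-d ViT-B/32 text embedding):

    import torch
    from torch import nn

    feat = torch.randn(1, 512, 16, 16)  # coarse feature map, like c3
    txt = torch.randn(1, 512)           # CLIP text embedding
    gamma_fn = nn.Linear(512, 512)      # stand-in for gamma_function
    beta_fn = nn.Linear(512, 512)       # stand-in for beta_function
    gamma = gamma_fn(txt).exp().unsqueeze(-1).unsqueeze(-1)  # per-channel scale
    beta = beta_fn(txt).unsqueeze(-1).unsqueeze(-1)          # per-channel shift
    out = feat * (1 + gamma) + beta     # same rule as TextModulationModule.forward
    print(out.shape)                    # torch.Size([1, 512, 16, 16])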
align_faces_parallel.py ADDED
@@ -0,0 +1,209 @@
1
+ """
2
+ brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
3
+ author: lzhbrian (https://lzhbrian.me)
4
+ date: 2020.1.5
5
+ note: code is heavily borrowed from
6
+ https://github.com/NVlabs/ffhq-dataset
7
+ http://dlib.net/face_landmark_detection.py.html
8
+
9
+ requirements:
10
+ apt install cmake
11
+ conda install Pillow numpy scipy
12
+ pip install dlib
13
+ # download face landmark model from:
14
+ # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
15
+ """
16
+ from argparse import ArgumentParser
17
+ import time
18
+ import numpy as np
19
+ import PIL
20
+ import PIL.Image
21
+ import os
22
+ import scipy
23
+ import scipy.ndimage
24
+ import dlib
25
+ import multiprocessing as mp
26
+ import math
27
+
28
+
29
+ SHAPE_PREDICTOR_PATH = "shape_predictor_68_face_landmarks.dat"
30
+
31
+
32
+ def get_landmark(img, predictor):
33
+ """get landmark with dlib
34
+ :return: np.array shape=(68, 2)
35
+ """
36
+ detector = dlib.get_frontal_face_detector()
37
+
38
+ img = np.uint8(np.array(img))
41
+ dets = detector(img, 1)
42
+
43
+ shape = None
44
+ for k, d in enumerate(dets):
45
+ shape = predictor(img, d)
46
+
47
+ if not shape:
48
+ # raise Exception("Could not find face in image! Please try another image!")
49
+ return None
50
+
51
+ t = list(shape.parts())
52
+ a = []
53
+ for tt in t:
54
+ a.append([tt.x, tt.y])
55
+ lm = np.array(a)
56
+ return lm
57
+
58
+
59
+ def align_face(img, predictor, output_size=256, transform_size=256):
60
+ """
61
+ :param img: PIL Image
62
+ :return: PIL Image
63
+ """
64
+
65
+ lm = get_landmark(img, predictor)
66
+ if lm is None:
67
+ return None
68
+
69
+ lm_chin = lm[0: 17] # left-right
70
+ lm_eyebrow_left = lm[17: 22] # left-right
71
+ lm_eyebrow_right = lm[22: 27] # left-right
72
+ lm_nose = lm[27: 31] # top-down
73
+ lm_nostrils = lm[31: 36] # top-down
74
+ lm_eye_left = lm[36: 42] # left-clockwise
75
+ lm_eye_right = lm[42: 48] # left-clockwise
76
+ lm_mouth_outer = lm[48: 60] # left-clockwise
77
+ lm_mouth_inner = lm[60: 68] # left-clockwise
78
+
79
+ # Calculate auxiliary vectors.
80
+ eye_left = np.mean(lm_eye_left, axis=0)
81
+ eye_right = np.mean(lm_eye_right, axis=0)
82
+ eye_avg = (eye_left + eye_right) * 0.5
83
+ eye_to_eye = eye_right - eye_left
84
+ mouth_left = lm_mouth_outer[0]
85
+ mouth_right = lm_mouth_outer[6]
86
+ mouth_avg = (mouth_left + mouth_right) * 0.5
87
+ eye_to_mouth = mouth_avg - eye_avg
88
+
89
+ # Choose oriented crop rectangle.
90
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
91
+ x /= np.hypot(*x)
92
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
93
+ y = np.flipud(x) * [-1, 1]
94
+ c = eye_avg + eye_to_mouth * 0.1
95
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
96
+ qsize = np.hypot(*x) * 2
97
+
98
100
+ enable_padding = True
101
+
102
+ # Shrink.
103
+ shrink = int(np.floor(qsize / output_size * 0.5))
104
+ if shrink > 1:
105
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
106
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
107
+ quad /= shrink
108
+ qsize /= shrink
109
+
110
+ # Crop.
111
+ border = max(int(np.rint(qsize * 0.1)), 3)
112
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
113
+ int(np.ceil(max(quad[:, 1]))))
114
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
115
+ min(crop[3] + border, img.size[1]))
116
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
117
+ img = img.crop(crop)
118
+ quad -= crop[0:2]
119
+
120
+ # Pad.
121
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
122
+ int(np.ceil(max(quad[:, 1]))))
123
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
124
+ max(pad[3] - img.size[1] + border, 0))
125
+ if enable_padding and max(pad) > border - 4:
126
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
127
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
128
+ h, w, _ = img.shape
129
+ y, x, _ = np.ogrid[:h, :w, :1]
130
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
131
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
132
+ blur = qsize * 0.02
133
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
134
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
135
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
136
+ quad += pad[:2]
137
+
138
+ # Transform.
139
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
140
+ if output_size < transform_size:
141
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
142
+
143
+ # Save aligned image.
144
+ return img
145
+
146
+
147
+ def chunks(lst, n):
148
+ """Yield successive n-sized chunks from lst."""
149
+ for i in range(0, len(lst), n):
150
+ yield lst[i:i + n]
151
+
152
+
153
+ def extract_on_paths(file_paths):
154
+ predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
155
+ pid = mp.current_process().name
156
+ print('\t{} is starting to extract on #{} images'.format(pid, len(file_paths)))
157
+ tot_count = len(file_paths)
158
+ count = 0
159
+ for file_path, res_path in file_paths:
160
+ count += 1
161
+ if count % 100 == 0:
162
+ print('{} done with {}/{}'.format(pid, count, tot_count))
163
+ try:
164
+ res = align_face(PIL.Image.open(file_path).convert('RGB'), predictor)
165
+ res = res.convert('RGB')
166
+ os.makedirs(os.path.dirname(res_path), exist_ok=True)
167
+ res.save(res_path)
168
+ except Exception:
169
+ continue
170
+ print('\tDone!')
171
+
172
+
173
+ def parse_args():
174
+ parser = ArgumentParser(add_help=False)
175
+ parser.add_argument('--num_threads', type=int, default=1)
176
+ parser.add_argument('--root_path', type=str, default='')
177
+ args = parser.parse_args()
178
+ return args
179
+
180
+
181
+ def run(args):
182
+ root_path = args.root_path
183
+ out_crops_path = root_path + '_crops'
184
+ if not os.path.exists(out_crops_path):
185
+ os.makedirs(out_crops_path, exist_ok=True)
186
+
187
+ file_paths = []
188
+ for root, dirs, files in os.walk(root_path):
189
+ for file in files:
190
+ file_path = os.path.join(root, file)
191
+ fname = os.path.join(out_crops_path, os.path.relpath(file_path, root_path))
192
+ res_path = '{}.jpg'.format(os.path.splitext(fname)[0])
193
+ if os.path.splitext(file_path)[1] == '.txt' or os.path.exists(res_path):
194
+ continue
195
+ file_paths.append((file_path, res_path))
196
+
197
+ file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads))))
198
+ print(len(file_chunks))
199
+ pool = mp.Pool(args.num_threads)
200
+ print('Running on {} paths\nHere we goooo'.format(len(file_paths)))
201
+ tic = time.time()
202
+ pool.map(extract_on_paths, file_chunks)
203
+ toc = time.time()
204
+ print('Mischief managed in {}s'.format(toc - tic))
205
+
206
+
207
+ if __name__ == '__main__':
208
+ args = parse_args()
209
+ run(args)
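Usage note: with shape_predictor_68_face_landmarks.dat in the working directory, the script is run as, for example, python align_faces_parallel.py --root_path /path/to/raw_images --num_threads 4 (the path is a placeholder). Despite its name, --num_threads sets the size of a multiprocessing pool. Aligned 256x256 crops are written to a sibling <root_path>_crops directory, and images where no face is detected are silently skipped.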
app.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+ from argparse import Namespace
+ import torchvision.transforms as transforms
+ import clip
+ import numpy as np
+ import sys
+ sys.path.append(".")
+ sys.path.append("..")
+ from models.e4e_features import pSp
+ from adapter.adapter_decoder import CLIPAdapterWithDecoder
+
+ import gradio as gr
+
+
+ def tensor2im(var):
+     var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
+     var = ((var + 1) / 2)
+     var[var < 0] = 0
+     var[var > 1] = 1
+     var = var * 255
+     return var.astype('uint8')
+
+
+ def run_alignment(image_path):
+     import dlib
+     from align_faces_parallel import align_face
+     predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
+     aligned_image = align_face(image_path, predictor=predictor)
+     # print("Aligned image has shape: {}".format(aligned_image.size))
+     return aligned_image
+
+
+ input_transforms = transforms.Compose([
+     transforms.Resize((256, 256)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+
+ model_path = "/scratch/users/aanees20/hpc_run/light_textmod/textmodulation/mapper/exp_ada5/checkpoints/best_model.pt"
+ e4e_path = 'e4e_ffhq_encode.pt'  # "/scratch/users/abaykal20/hpc_run/sam/SAM/pretrained_models/e4e_ffhq_encode.pt"
+
+ ckpt = torch.load(model_path, map_location='cpu')
+ opts = ckpt['opts']
+ opts['checkpoint_path'] = model_path
+ opts['pretrained_e4e_path'] = e4e_path
+ opts['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
+ opts = Namespace(**opts)
+ encoder = pSp(opts)
+ encoder.eval()
+ encoder.cuda()
+
+ adapter = CLIPAdapterWithDecoder(opts)
+ adapter.eval()
+ adapter.cuda()
+
+ clip_model, _ = clip.load("ViT-B/32", device='cuda')
+
+
+ def manipulate(input_image, caption):
+     aligned_image = run_alignment(input_image)
+     input_image = input_transforms(aligned_image)
+     input_image = input_image.unsqueeze(0)
+     text_input = clip.tokenize(caption)
+     text_input = text_input.cuda()
+     input_image = input_image.cuda().float()
+
+     with torch.no_grad():
+         text_features = clip_model.encode_text(text_input).float()
+
+         w, features = encoder.forward(input_image, return_latents=True)
+         features = adapter.adapter(features, text_features)
+         w_hat = w + 0.1 * encoder.forward_features(features)
+
+         result_tensor, _ = adapter.decoder([w_hat], input_is_latent=True, return_latents=False, randomize_noise=False, truncation=1, txt_embed=text_features)
+     result_tensor = result_tensor.squeeze(0)
+     result_image = tensor2im(result_tensor)
+
+     return result_image
+
+
+ gr.Interface(fn=manipulate,
+              inputs=[gr.Image(type="pil"), "text"],
+              outputs="image",).launch(share=True)
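Once the referenced checkpoints are available, launching the Space is just a matter of running this file. For a quick check without the Gradio UI, manipulate can also be called directly; a minimal sketch (example_face.jpg is a placeholder for any RGB face photo):

    from PIL import Image

    img = Image.open("example_face.jpg").convert("RGB")  # placeholder input
    edited = manipulate(img, "he has a beard")           # HxWx3 uint8 array
    Image.fromarray(edited).save("edited_face.jpg")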
models/e4e_features.py ADDED
@@ -0,0 +1,72 @@
+ import matplotlib
+
+ matplotlib.use('Agg')
+ import torch
+ from torch import nn
+ from models.encoders import psp_encoders_features
+
+
+ def get_keys(d, name):
+     if 'state_dict' in d:
+         d = d['state_dict']
+     d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
+     return d_filt
+
+
+ class pSp(nn.Module):
+
+     def __init__(self, opts):
+         super(pSp, self).__init__()
+         self.opts = opts
+         # Define architecture
+         self.encoder = self.set_encoder().eval()
+         # Load weights if needed
+         self.load_weights()
+
+     def set_encoder(self):
+         encoder = psp_encoders_features.Encoder4Editing(50, 'ir_se', self.opts)
+         return encoder
+
+     def load_weights(self):
+         # We only load the encoder weights
+         print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.pretrained_e4e_path))
+         ckpt = torch.load(self.opts.pretrained_e4e_path, map_location='cpu')
+         self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
+         self.__load_latent_avg(ckpt)
+
+     def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
+                 inject_latent=None, return_latents=False, alpha=None):
+         if input_code:
+             codes = x
+         else:
+             codes, features = self.encoder(x)
+             # normalize with respect to the center of an average face
+             if self.opts.start_from_latent_avg:
+                 if codes.ndim == 2:
+                     codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
+                 else:
+                     codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
+
+         if latent_mask is not None:
+             for i in latent_mask:
+                 if inject_latent is not None:
+                     if alpha is not None:
+                         codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
+                     else:
+                         codes[:, i] = inject_latent[:, i]
+                 else:
+                     codes[:, i] = 0
+
+         return codes, features
+
+     # Forward the modulated feature maps
+     def forward_features(self, features):
+         return self.encoder.forward_features(features)
+
+     def __load_latent_avg(self, ckpt, repeat=None):
+         if 'latent_avg' in ckpt:
+             self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
+             if repeat is not None:
+                 self.latent_avg = self.latent_avg.repeat(repeat, 1)
+         else:
+             self.latent_avg = None
models/encoders/__init__.py ADDED
File without changes
models/encoders/helpers.py ADDED
@@ -0,0 +1,140 @@
1
+ from collections import namedtuple
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
5
+
6
+ """
7
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
8
+ """
9
+
10
+
11
+ class Flatten(Module):
12
+ def forward(self, input):
13
+ return input.view(input.size(0), -1)
14
+
15
+
16
+ def l2_norm(input, axis=1):
17
+ norm = torch.norm(input, 2, axis, True)
18
+ output = torch.div(input, norm)
19
+ return output
20
+
21
+
22
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
23
+ """ A named tuple describing a ResNet block. """
24
+
25
+
26
+ def get_block(in_channel, depth, num_units, stride=2):
27
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
28
+
29
+
30
+ def get_blocks(num_layers):
31
+ if num_layers == 50:
32
+ blocks = [
33
+ get_block(in_channel=64, depth=64, num_units=3),
34
+ get_block(in_channel=64, depth=128, num_units=4),
35
+ get_block(in_channel=128, depth=256, num_units=14),
36
+ get_block(in_channel=256, depth=512, num_units=3)
37
+ ]
38
+ elif num_layers == 100:
39
+ blocks = [
40
+ get_block(in_channel=64, depth=64, num_units=3),
41
+ get_block(in_channel=64, depth=128, num_units=13),
42
+ get_block(in_channel=128, depth=256, num_units=30),
43
+ get_block(in_channel=256, depth=512, num_units=3)
44
+ ]
45
+ elif num_layers == 152:
46
+ blocks = [
47
+ get_block(in_channel=64, depth=64, num_units=3),
48
+ get_block(in_channel=64, depth=128, num_units=8),
49
+ get_block(in_channel=128, depth=256, num_units=36),
50
+ get_block(in_channel=256, depth=512, num_units=3)
51
+ ]
52
+ else:
53
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
54
+ return blocks
55
+
56
+
57
+ class SEModule(Module):
58
+ def __init__(self, channels, reduction):
59
+ super(SEModule, self).__init__()
60
+ self.avg_pool = AdaptiveAvgPool2d(1)
61
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
62
+ self.relu = ReLU(inplace=True)
63
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
64
+ self.sigmoid = Sigmoid()
65
+
66
+ def forward(self, x):
67
+ module_input = x
68
+ x = self.avg_pool(x)
69
+ x = self.fc1(x)
70
+ x = self.relu(x)
71
+ x = self.fc2(x)
72
+ x = self.sigmoid(x)
73
+ return module_input * x
74
+
75
+
76
+ class bottleneck_IR(Module):
77
+ def __init__(self, in_channel, depth, stride):
78
+ super(bottleneck_IR, self).__init__()
79
+ if in_channel == depth:
80
+ self.shortcut_layer = MaxPool2d(1, stride)
81
+ else:
82
+ self.shortcut_layer = Sequential(
83
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
84
+ BatchNorm2d(depth)
85
+ )
86
+ self.res_layer = Sequential(
87
+ BatchNorm2d(in_channel),
88
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
89
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
90
+ )
91
+
92
+ def forward(self, x):
93
+ shortcut = self.shortcut_layer(x)
94
+ res = self.res_layer(x)
95
+ return res + shortcut
96
+
97
+
98
+ class bottleneck_IR_SE(Module):
99
+ def __init__(self, in_channel, depth, stride):
100
+ super(bottleneck_IR_SE, self).__init__()
101
+ if in_channel == depth:
102
+ self.shortcut_layer = MaxPool2d(1, stride)
103
+ else:
104
+ self.shortcut_layer = Sequential(
105
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
106
+ BatchNorm2d(depth)
107
+ )
108
+ self.res_layer = Sequential(
109
+ BatchNorm2d(in_channel),
110
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
111
+ PReLU(depth),
112
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
113
+ BatchNorm2d(depth),
114
+ SEModule(depth, 16)
115
+ )
116
+
117
+ def forward(self, x):
118
+ shortcut = self.shortcut_layer(x)
119
+ res = self.res_layer(x)
120
+ return res + shortcut
121
+
122
+
123
+ def _upsample_add(x, y):
124
+ """Upsample and add two feature maps.
125
+ Args:
126
+ x: (Variable) top feature map to be upsampled.
127
+ y: (Variable) lateral feature map.
128
+ Returns:
129
+ (Variable) added feature map.
130
+ Note in PyTorch, when input size is odd, the upsampled feature map
131
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
132
+ maybe not equal to the lateral feature map size.
133
+ e.g.
134
+ original input size: [N,_,15,15] ->
135
+ conv2d feature map size: [N,_,8,8] ->
136
+ upsampled feature map size: [N,_,16,16]
137
+ So we choose bilinear upsample which supports arbitrary output sizes.
138
+ """
139
+ _, _, H, W = y.size()
140
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
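A standalone shape check of the pattern _upsample_add implements, in plain torch with assumed FPN shapes: the coarse map is bilinearly resized to the lateral map's spatial size, then the two are summed.

    import torch
    import torch.nn.functional as F

    top = torch.randn(1, 512, 16, 16)      # coarse feature map
    lateral = torch.randn(1, 512, 32, 32)  # finer lateral feature map
    fused = F.interpolate(top, size=(32, 32), mode='bilinear', align_corners=True) + lateral
    print(fused.shape)                     # torch.Size([1, 512, 32, 32])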
models/encoders/map2style.py ADDED
@@ -0,0 +1,29 @@
+ import numpy as np
+ from torch import nn
+ from torch.nn import Conv2d, Module
+
+ from models.stylegan2.model import EqualLinear
+
+
+ class GradualStyleBlock(Module):
+     def __init__(self, in_c, out_c, spatial):
+         super(GradualStyleBlock, self).__init__()
+         self.out_c = out_c
+         self.spatial = spatial
+         num_pools = int(np.log2(spatial))
+         modules = []
+         modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
+                     nn.LeakyReLU()]
+         for i in range(num_pools - 1):
+             modules += [
+                 Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
+                 nn.LeakyReLU()
+             ]
+         self.convs = nn.Sequential(*modules)
+         self.linear = EqualLinear(out_c, out_c, lr_mul=1)
+
+     def forward(self, x):
+         x = self.convs(x)
+         x = x.view(-1, self.out_c)
+         x = self.linear(x)
+         return x
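Note that GradualStyleBlock collapses a spatial x spatial feature map to 1x1 through log2(spatial) stride-2 convolutions and then maps it to a single out_c-dimensional style vector; the same block is defined again in models/encoders/psp_encoders.py below.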
models/encoders/model_irse.py ADDED
@@ -0,0 +1,84 @@
1
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
2
+ from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
3
+
4
+ """
5
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
6
+ """
7
+
8
+
9
+ class Backbone(Module):
10
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
11
+ super(Backbone, self).__init__()
12
+ assert input_size in [112, 224], "input_size should be 112 or 224"
13
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
14
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
15
+ blocks = get_blocks(num_layers)
16
+ if mode == 'ir':
17
+ unit_module = bottleneck_IR
18
+ elif mode == 'ir_se':
19
+ unit_module = bottleneck_IR_SE
20
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
21
+ BatchNorm2d(64),
22
+ PReLU(64))
23
+ if input_size == 112:
24
+ self.output_layer = Sequential(BatchNorm2d(512),
25
+ Dropout(drop_ratio),
26
+ Flatten(),
27
+ Linear(512 * 7 * 7, 512),
28
+ BatchNorm1d(512, affine=affine))
29
+ else:
30
+ self.output_layer = Sequential(BatchNorm2d(512),
31
+ Dropout(drop_ratio),
32
+ Flatten(),
33
+ Linear(512 * 14 * 14, 512),
34
+ BatchNorm1d(512, affine=affine))
35
+
36
+ modules = []
37
+ for block in blocks:
38
+ for bottleneck in block:
39
+ modules.append(unit_module(bottleneck.in_channel,
40
+ bottleneck.depth,
41
+ bottleneck.stride))
42
+ self.body = Sequential(*modules)
43
+
44
+ def forward(self, x):
45
+ x = self.input_layer(x)
46
+ x = self.body(x)
47
+ x = self.output_layer(x)
48
+ return l2_norm(x)
49
+
50
+
51
+ def IR_50(input_size):
52
+ """Constructs a ir-50 model."""
53
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
54
+ return model
55
+
56
+
57
+ def IR_101(input_size):
58
+ """Constructs a ir-101 model."""
59
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
60
+ return model
61
+
62
+
63
+ def IR_152(input_size):
64
+ """Constructs a ir-152 model."""
65
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
66
+ return model
67
+
68
+
69
+ def IR_SE_50(input_size):
70
+ """Constructs a ir_se-50 model."""
71
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
72
+ return model
73
+
74
+
75
+ def IR_SE_101(input_size):
76
+ """Constructs a ir_se-101 model."""
77
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
78
+ return model
79
+
80
+
81
+ def IR_SE_152(input_size):
82
+ """Constructs a ir_se-152 model."""
83
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
84
+ return model
models/encoders/psp_encoders.py ADDED
@@ -0,0 +1,235 @@
1
+ from enum import Enum
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
7
+
8
+ from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
9
+ from models.stylegan2.model import EqualLinear
10
+
11
+
12
+ class ProgressiveStage(Enum):
13
+ WTraining = 0
14
+ Delta1Training = 1
15
+ Delta2Training = 2
16
+ Delta3Training = 3
17
+ Delta4Training = 4
18
+ Delta5Training = 5
19
+ Delta6Training = 6
20
+ Delta7Training = 7
21
+ Delta8Training = 8
22
+ Delta9Training = 9
23
+ Delta10Training = 10
24
+ Delta11Training = 11
25
+ Delta12Training = 12
26
+ Delta13Training = 13
27
+ Delta14Training = 14
28
+ Delta15Training = 15
29
+ Delta16Training = 16
30
+ Delta17Training = 17
31
+ Inference = 18
32
+
33
+
34
+ class GradualStyleBlock(Module):
35
+ def __init__(self, in_c, out_c, spatial):
36
+ super(GradualStyleBlock, self).__init__()
37
+ self.out_c = out_c
38
+ self.spatial = spatial
39
+ num_pools = int(np.log2(spatial))
40
+ modules = []
41
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
42
+ nn.LeakyReLU()]
43
+ for i in range(num_pools - 1):
44
+ modules += [
45
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
46
+ nn.LeakyReLU()
47
+ ]
48
+ self.convs = nn.Sequential(*modules)
49
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
50
+
51
+ def forward(self, x):
52
+ x = self.convs(x)
53
+ x = x.view(-1, self.out_c)
54
+ x = self.linear(x)
55
+ return x
56
+
57
+
58
+ class GradualStyleEncoder(Module):
59
+ def __init__(self, num_layers, mode='ir', opts=None):
60
+ super(GradualStyleEncoder, self).__init__()
61
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
62
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
63
+ blocks = get_blocks(num_layers)
64
+ if mode == 'ir':
65
+ unit_module = bottleneck_IR
66
+ elif mode == 'ir_se':
67
+ unit_module = bottleneck_IR_SE
68
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
69
+ BatchNorm2d(64),
70
+ PReLU(64))
71
+ modules = []
72
+ for block in blocks:
73
+ for bottleneck in block:
74
+ modules.append(unit_module(bottleneck.in_channel,
75
+ bottleneck.depth,
76
+ bottleneck.stride))
77
+ self.body = Sequential(*modules)
78
+
79
+ self.styles = nn.ModuleList()
80
+ log_size = int(math.log(opts.stylegan_size, 2))
81
+ self.style_count = 2 * log_size - 2
82
+ self.coarse_ind = 3
83
+ self.middle_ind = 7
84
+ for i in range(self.style_count):
85
+ if i < self.coarse_ind:
86
+ style = GradualStyleBlock(512, 512, 16)
87
+ elif i < self.middle_ind:
88
+ style = GradualStyleBlock(512, 512, 32)
89
+ else:
90
+ style = GradualStyleBlock(512, 512, 64)
91
+ self.styles.append(style)
92
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
93
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
94
+
95
+ def forward(self, x):
96
+ x = self.input_layer(x)
97
+
98
+ latents = []
99
+ modulelist = list(self.body._modules.values())
100
+ for i, l in enumerate(modulelist):
101
+ x = l(x)
102
+ if i == 6:
103
+ c1 = x
104
+ elif i == 20:
105
+ c2 = x
106
+ elif i == 23:
107
+ c3 = x
108
+
109
+ for j in range(self.coarse_ind):
110
+ latents.append(self.styles[j](c3))
111
+
112
+ p2 = _upsample_add(c3, self.latlayer1(c2))
113
+ for j in range(self.coarse_ind, self.middle_ind):
114
+ latents.append(self.styles[j](p2))
115
+
116
+ p1 = _upsample_add(p2, self.latlayer2(c1))
117
+ for j in range(self.middle_ind, self.style_count):
118
+ latents.append(self.styles[j](p1))
119
+
120
+ out = torch.stack(latents, dim=1)
121
+ return out
122
+
123
+
124
+ class Encoder4Editing(Module):
125
+ def __init__(self, num_layers, mode='ir', opts=None):
126
+ super(Encoder4Editing, self).__init__()
127
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
128
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
129
+ blocks = get_blocks(num_layers)
130
+ if mode == 'ir':
131
+ unit_module = bottleneck_IR
132
+ elif mode == 'ir_se':
133
+ unit_module = bottleneck_IR_SE
134
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
135
+ BatchNorm2d(64),
136
+ PReLU(64))
137
+ modules = []
138
+ for block in blocks:
139
+ for bottleneck in block:
140
+ modules.append(unit_module(bottleneck.in_channel,
141
+ bottleneck.depth,
142
+ bottleneck.stride))
143
+ self.body = Sequential(*modules)
144
+
145
+ self.styles = nn.ModuleList()
146
+ log_size = int(math.log(opts.stylegan_size, 2))
147
+ self.style_count = 2 * log_size - 2
148
+ self.coarse_ind = 3
149
+ self.middle_ind = 7
150
+
151
+ for i in range(self.style_count):
152
+ if i < self.coarse_ind:
153
+ style = GradualStyleBlock(512, 512, 16)
154
+ elif i < self.middle_ind:
155
+ style = GradualStyleBlock(512, 512, 32)
156
+ else:
157
+ style = GradualStyleBlock(512, 512, 64)
158
+ self.styles.append(style)
159
+
160
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
161
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
162
+
163
+ self.progressive_stage = ProgressiveStage.Inference
164
+
165
+ def get_deltas_starting_dimensions(self):
166
+ ''' Get a list of the initial dimension of every delta from which it is applied '''
167
+ return list(range(self.style_count)) # Each dimension has a delta applied to it
168
+
169
+ def set_progressive_stage(self, new_stage: ProgressiveStage):
170
+ self.progressive_stage = new_stage
171
+ print('Changed progressive stage to: ', new_stage)
172
+
173
+ def forward(self, x):
174
+ x = self.input_layer(x)
175
+
176
+ modulelist = list(self.body._modules.values())
177
+ for i, l in enumerate(modulelist):
178
+ x = l(x)
179
+ if i == 6:
180
+ c1 = x
181
+ elif i == 20:
182
+ c2 = x
183
+ elif i == 23:
184
+ c3 = x
185
+
186
+ # Infer main W and duplicate it
187
+ w0 = self.styles[0](c3)
188
+ w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
189
+ stage = self.progressive_stage.value
190
+ features = c3
191
+ for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
192
+ if i == self.coarse_ind:
193
+ p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
194
+ features = p2
195
+ elif i == self.middle_ind:
196
+ p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
197
+ features = p1
198
+ delta_i = self.styles[i](features)
199
+ w[:, i] += delta_i
200
+ return w
201
+
202
+
203
+ class BackboneEncoderUsingLastLayerIntoW(Module):
204
+ def __init__(self, num_layers, mode='ir', opts=None):
205
+ super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
206
+ print('Using BackboneEncoderUsingLastLayerIntoW')
207
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
208
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
209
+ blocks = get_blocks(num_layers)
210
+ if mode == 'ir':
211
+ unit_module = bottleneck_IR
212
+ elif mode == 'ir_se':
213
+ unit_module = bottleneck_IR_SE
214
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
215
+ BatchNorm2d(64),
216
+ PReLU(64))
217
+ self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
218
+ self.linear = EqualLinear(512, 512, lr_mul=1)
219
+ modules = []
220
+ for block in blocks:
221
+ for bottleneck in block:
222
+ modules.append(unit_module(bottleneck.in_channel,
223
+ bottleneck.depth,
224
+ bottleneck.stride))
225
+ self.body = Sequential(*modules)
226
+ log_size = int(math.log(opts.stylegan_size, 2))
227
+ self.style_count = 2 * log_size - 2
228
+
229
+ def forward(self, x):
230
+ x = self.input_layer(x)
231
+ x = self.body(x)
232
+ x = self.output_pool(x)
233
+ x = x.view(-1, 512)
234
+ x = self.linear(x)
235
+ return x.repeat(self.style_count, 1, 1).permute(1, 0, 2)
models/encoders/psp_encoders_features.py ADDED
@@ -0,0 +1,256 @@
1
+ from enum import Enum
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
7
+
8
+ from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
9
+ from models.stylegan2.model import EqualLinear
10
+
11
+
12
+ class ProgressiveStage(Enum):
13
+ WTraining = 0
14
+ Delta1Training = 1
15
+ Delta2Training = 2
16
+ Delta3Training = 3
17
+ Delta4Training = 4
18
+ Delta5Training = 5
19
+ Delta6Training = 6
20
+ Delta7Training = 7
21
+ Delta8Training = 8
22
+ Delta9Training = 9
23
+ Delta10Training = 10
24
+ Delta11Training = 11
25
+ Delta12Training = 12
26
+ Delta13Training = 13
27
+ Delta14Training = 14
28
+ Delta15Training = 15
29
+ Delta16Training = 16
30
+ Delta17Training = 17
31
+ Inference = 18
32
+
33
+
34
+ class GradualStyleBlock(Module):
35
+ def __init__(self, in_c, out_c, spatial):
36
+ super(GradualStyleBlock, self).__init__()
37
+ self.out_c = out_c
38
+ self.spatial = spatial
39
+ num_pools = int(np.log2(spatial))
40
+ modules = []
41
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
42
+ nn.LeakyReLU()]
43
+ for i in range(num_pools - 1):
44
+ modules += [
45
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
46
+ nn.LeakyReLU()
47
+ ]
48
+ self.convs = nn.Sequential(*modules)
49
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
50
+
51
+ def forward(self, x):
52
+ x = self.convs(x)
53
+ x = x.view(-1, self.out_c)
54
+ x = self.linear(x)
55
+ return x
56
+
57
+
58
+ class GradualStyleEncoder(Module):
59
+ def __init__(self, num_layers, mode='ir', opts=None):
60
+ super(GradualStyleEncoder, self).__init__()
61
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
62
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
63
+ blocks = get_blocks(num_layers)
64
+ if mode == 'ir':
65
+ unit_module = bottleneck_IR
66
+ elif mode == 'ir_se':
67
+ unit_module = bottleneck_IR_SE
68
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
69
+ BatchNorm2d(64),
70
+ PReLU(64))
71
+ modules = []
72
+ for block in blocks:
73
+ for bottleneck in block:
74
+ modules.append(unit_module(bottleneck.in_channel,
75
+ bottleneck.depth,
76
+ bottleneck.stride))
77
+ self.body = Sequential(*modules)
78
+
79
+ self.styles = nn.ModuleList()
80
+ log_size = int(math.log(opts.stylegan_size, 2))
81
+ self.style_count = 2 * log_size - 2
82
+ self.coarse_ind = 3
83
+ self.middle_ind = 7
84
+ for i in range(self.style_count):
85
+ if i < self.coarse_ind:
86
+ style = GradualStyleBlock(512, 512, 16)
87
+ elif i < self.middle_ind:
88
+ style = GradualStyleBlock(512, 512, 32)
89
+ else:
90
+ style = GradualStyleBlock(512, 512, 64)
91
+ self.styles.append(style)
92
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
93
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
94
+
95
+ def forward(self, x):
96
+ x = self.input_layer(x)
97
+
98
+ latents = []
99
+ modulelist = list(self.body._modules.values())
100
+ for i, l in enumerate(modulelist):
101
+ x = l(x)
102
+ if i == 6:
103
+ c1 = x
104
+ elif i == 20:
105
+ c2 = x
106
+ elif i == 23:
107
+ c3 = x
108
+
109
+ for j in range(self.coarse_ind):
110
+ latents.append(self.styles[j](c3))
111
+
112
+ p2 = _upsample_add(c3, self.latlayer1(c2))
113
+ for j in range(self.coarse_ind, self.middle_ind):
114
+ latents.append(self.styles[j](p2))
115
+
116
+ p1 = _upsample_add(p2, self.latlayer2(c1))
117
+ for j in range(self.middle_ind, self.style_count):
118
+ latents.append(self.styles[j](p1))
119
+
120
+ out = torch.stack(latents, dim=1)
121
+ return out
122
+
123
+
124
+ class Encoder4Editing(Module):
125
+ def __init__(self, num_layers, mode='ir', opts=None):
126
+ super(Encoder4Editing, self).__init__()
127
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
128
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
129
+ blocks = get_blocks(num_layers)
130
+ if mode == 'ir':
131
+ unit_module = bottleneck_IR
132
+ elif mode == 'ir_se':
133
+ unit_module = bottleneck_IR_SE
134
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
135
+ BatchNorm2d(64),
136
+ PReLU(64))
137
+ modules = []
138
+ for block in blocks:
139
+ for bottleneck in block:
140
+ modules.append(unit_module(bottleneck.in_channel,
141
+ bottleneck.depth,
142
+ bottleneck.stride))
143
+ self.body = Sequential(*modules)
144
+
145
+ self.styles = nn.ModuleList()
146
+ log_size = int(math.log(opts.stylegan_size, 2))
147
+ self.style_count = 2 * log_size - 2
148
+ self.coarse_ind = 3
149
+ self.middle_ind = 7
150
+
151
+ for i in range(self.style_count):
152
+ if i < self.coarse_ind:
153
+ style = GradualStyleBlock(512, 512, 16)
154
+ elif i < self.middle_ind:
155
+ style = GradualStyleBlock(512, 512, 32)
156
+ else:
157
+ style = GradualStyleBlock(512, 512, 64)
158
+ self.styles.append(style)
159
+
160
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
161
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
162
+
163
+ self.progressive_stage = ProgressiveStage.Inference
164
+
165
+ def get_deltas_starting_dimensions(self):
166
+ ''' Get a list of the initial dimension of every delta from which it is applied '''
167
+ return list(range(self.style_count)) # Each dimension has a delta applied to it
168
+
169
+ def set_progressive_stage(self, new_stage: ProgressiveStage):
170
+ self.progressive_stage = new_stage
171
+ print('Changed progressive stage to: ', new_stage)
172
+
173
+ # Forward pass also returns the intermediate feature maps
174
+ def forward(self, x):
175
+ x = self.input_layer(x)
176
+
177
+ modulelist = list(self.body._modules.values())
178
+ for i, l in enumerate(modulelist):
179
+ x = l(x)
180
+ if i == 6:
181
+ c1 = x
182
+ elif i == 20:
183
+ c2 = x
184
+ elif i == 23:
185
+ c3 = x
186
+
187
+ # Infer main W and duplicate it
188
+ w0 = self.styles[0](c3)
189
+ w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
190
+ stage = self.progressive_stage.value
191
+ features = c3
192
+ for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
193
+ if i == self.coarse_ind:
194
+ p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
195
+ features = p2
196
+ elif i == self.middle_ind:
197
+ p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
198
+ features = p1
199
+ delta_i = self.styles[i](features)
200
+ w[:, i] += delta_i
201
+ return w, (c1,c2,c3)
202
+
203
+ # Forward the modulated features through the map2style layers to obtain the residual latents
204
+ def forward_features(self, features):
205
+ c1, c2, c3 = features
206
+
207
+ # Infer main W and duplicate it
208
+ w0 = self.styles[0](c3)
209
+ w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
210
+ stage = self.progressive_stage.value
211
+ features = c3
212
+ for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
213
+ if i == self.coarse_ind:
214
+ p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
215
+ features = p2
216
+ elif i == self.middle_ind:
217
+ p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
218
+ features = p1
219
+ delta_i = self.styles[i](features)
220
+ w[:, i] += delta_i
221
+ return w
222
+
223
+
224
+ class BackboneEncoderUsingLastLayerIntoW(Module):
225
+ def __init__(self, num_layers, mode='ir', opts=None):
226
+ super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
227
+ print('Using BackboneEncoderUsingLastLayerIntoW')
228
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
229
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
230
+ blocks = get_blocks(num_layers)
231
+ if mode == 'ir':
232
+ unit_module = bottleneck_IR
233
+ elif mode == 'ir_se':
234
+ unit_module = bottleneck_IR_SE
235
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
236
+ BatchNorm2d(64),
237
+ PReLU(64))
238
+ self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
239
+ self.linear = EqualLinear(512, 512, lr_mul=1)
240
+ modules = []
241
+ for block in blocks:
242
+ for bottleneck in block:
243
+ modules.append(unit_module(bottleneck.in_channel,
244
+ bottleneck.depth,
245
+ bottleneck.stride))
246
+ self.body = Sequential(*modules)
247
+ log_size = int(math.log(opts.stylegan_size, 2))
248
+ self.style_count = 2 * log_size - 2
249
+
250
+ def forward(self, x):
251
+ x = self.input_layer(x)
252
+ x = self.body(x)
253
+ x = self.output_pool(x)
254
+ x = x.view(-1, 512)
255
+ x = self.linear(x)
256
+ return x.repeat(self.style_count, 1, 1).permute(1, 0, 2)
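Relative to psp_encoders.py above, the key change here is that Encoder4Editing.forward also returns the intermediate FPN features (c1, c2, c3), and forward_features reruns only the style heads and lateral layers on an externally modulated (c1, c2, c3) to produce residual latents; app.py uses this as w_hat = w + 0.1 * encoder.forward_features(adapted_features).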
models/stylegan2/__init__.py ADDED
File without changes
models/stylegan2/model.py ADDED
@@ -0,0 +1,674 @@
1
+ import math
2
+ import random
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
9
+
10
+
11
+ class PixelNorm(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def forward(self, input):
16
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
17
+
18
+
19
+ def make_kernel(k):
20
+ k = torch.tensor(k, dtype=torch.float32)
21
+
22
+ if k.ndim == 1:
23
+ k = k[None, :] * k[:, None]
24
+
25
+ k /= k.sum()
26
+
27
+ return k
28
+
29
+
30
+ class Upsample(nn.Module):
31
+ def __init__(self, kernel, factor=2):
32
+ super().__init__()
33
+
34
+ self.factor = factor
35
+ kernel = make_kernel(kernel) * (factor ** 2)
36
+ self.register_buffer('kernel', kernel)
37
+
38
+ p = kernel.shape[0] - factor
39
+
40
+ pad0 = (p + 1) // 2 + factor - 1
41
+ pad1 = p // 2
42
+
43
+ self.pad = (pad0, pad1)
44
+
45
+ def forward(self, input):
46
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
47
+
48
+ return out
49
+
50
+
51
+ class Downsample(nn.Module):
52
+ def __init__(self, kernel, factor=2):
53
+ super().__init__()
54
+
55
+ self.factor = factor
56
+ kernel = make_kernel(kernel)
57
+ self.register_buffer('kernel', kernel)
58
+
59
+ p = kernel.shape[0] - factor
60
+
61
+ pad0 = (p + 1) // 2
62
+ pad1 = p // 2
63
+
64
+ self.pad = (pad0, pad1)
65
+
66
+ def forward(self, input):
67
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
68
+
69
+ return out
70
+
71
+
72
+ class Blur(nn.Module):
73
+ def __init__(self, kernel, pad, upsample_factor=1):
74
+ super().__init__()
75
+
76
+ kernel = make_kernel(kernel)
77
+
78
+ if upsample_factor > 1:
79
+ kernel = kernel * (upsample_factor ** 2)
80
+
81
+ self.register_buffer('kernel', kernel)
82
+
83
+ self.pad = pad
84
+
85
+ def forward(self, input):
86
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
87
+
88
+ return out
89
+
90
+
91
+ class EqualConv2d(nn.Module):
92
+ def __init__(
93
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
94
+ ):
95
+ super().__init__()
96
+
97
+ self.weight = nn.Parameter(
98
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
99
+ )
100
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
101
+
102
+ self.stride = stride
103
+ self.padding = padding
104
+
105
+ if bias:
106
+ self.bias = nn.Parameter(torch.zeros(out_channel))
107
+
108
+ else:
109
+ self.bias = None
110
+
111
+ def forward(self, input):
112
+ out = F.conv2d(
113
+ input,
114
+ self.weight * self.scale,
115
+ bias=self.bias,
116
+ stride=self.stride,
117
+ padding=self.padding,
118
+ )
119
+
120
+ return out
121
+
122
+ def __repr__(self):
123
+ return (
124
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
125
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
126
+ )
127
+
128
+
129
+ class EqualLinear(nn.Module):
130
+ def __init__(
131
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
132
+ ):
133
+ super().__init__()
134
+
135
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
136
+
137
+ if bias:
138
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
139
+
140
+ else:
141
+ self.bias = None
142
+
143
+ self.activation = activation
144
+
145
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
146
+ self.lr_mul = lr_mul
147
+
148
+ def forward(self, input):
149
+ if self.activation:
150
+ out = F.linear(input, self.weight * self.scale)
151
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
152
+
153
+ else:
154
+ out = F.linear(
155
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
156
+ )
157
+
158
+ return out
159
+
160
+ def __repr__(self):
161
+ return (
162
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
163
+ )
164
+
165
+
166
+ class ScaledLeakyReLU(nn.Module):
167
+ def __init__(self, negative_slope=0.2):
168
+ super().__init__()
169
+
170
+ self.negative_slope = negative_slope
171
+
172
+ def forward(self, input):
173
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
174
+
175
+ return out * math.sqrt(2)
176
+
177
+
178
+ class ModulatedConv2d(nn.Module):
179
+ def __init__(
180
+ self,
181
+ in_channel,
182
+ out_channel,
183
+ kernel_size,
184
+ style_dim,
185
+ demodulate=True,
186
+ upsample=False,
187
+ downsample=False,
188
+ blur_kernel=[1, 3, 3, 1],
189
+ ):
190
+ super().__init__()
191
+
192
+ self.eps = 1e-8
193
+ self.kernel_size = kernel_size
194
+ self.in_channel = in_channel
195
+ self.out_channel = out_channel
196
+ self.upsample = upsample
197
+ self.downsample = downsample
198
+
199
+ if upsample:
200
+ factor = 2
201
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
202
+ pad0 = (p + 1) // 2 + factor - 1
203
+ pad1 = p // 2 + 1
204
+
205
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
206
+
207
+ if downsample:
208
+ factor = 2
209
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
210
+ pad0 = (p + 1) // 2
211
+ pad1 = p // 2
212
+
213
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
214
+
215
+ fan_in = in_channel * kernel_size ** 2
216
+ self.scale = 1 / math.sqrt(fan_in)
217
+ self.padding = kernel_size // 2
218
+
219
+ self.weight = nn.Parameter(
220
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
221
+ )
222
+
223
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
224
+
225
+ self.demodulate = demodulate
226
+
227
+ def __repr__(self):
228
+ return (
229
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
230
+ f'upsample={self.upsample}, downsample={self.downsample})'
231
+ )
232
+
233
+ def forward(self, input, style):
234
+ batch, in_channel, height, width = input.shape
235
+
236
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
237
+ weight = self.scale * self.weight * style
238
+
239
+ if self.demodulate:
240
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
241
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
242
+
243
+ weight = weight.view(
244
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
245
+ )
246
+
247
+ if self.upsample:
248
+ input = input.view(1, batch * in_channel, height, width)
249
+ weight = weight.view(
250
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
251
+ )
252
+ weight = weight.transpose(1, 2).reshape(
253
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
254
+ )
255
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
256
+ _, _, height, width = out.shape
257
+ out = out.view(batch, self.out_channel, height, width)
258
+ out = self.blur(out)
259
+
260
+ elif self.downsample:
261
+ input = self.blur(input)
262
+ _, _, height, width = input.shape
263
+ input = input.view(1, batch * in_channel, height, width)
264
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
265
+ _, _, height, width = out.shape
266
+ out = out.view(batch, self.out_channel, height, width)
267
+
268
+ else:
269
+ input = input.view(1, batch * in_channel, height, width)
270
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
271
+ _, _, height, width = out.shape
272
+ out = out.view(batch, self.out_channel, height, width)
273
+
274
+ return out
275
+
276
+
277
+ class NoiseInjection(nn.Module):
278
+ def __init__(self):
279
+ super().__init__()
280
+
281
+ self.weight = nn.Parameter(torch.zeros(1))
282
+
283
+ def forward(self, image, noise=None):
284
+ if noise is None:
285
+ batch, _, height, width = image.shape
286
+ noise = image.new_empty(batch, 1, height, width).normal_()
287
+
288
+ return image + self.weight * noise
289
+
290
+
291
+ class ConstantInput(nn.Module):
292
+ def __init__(self, channel, size=4):
293
+ super().__init__()
294
+
295
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
296
+
297
+ def forward(self, input):
298
+ batch = input.shape[0]
299
+ out = self.input.repeat(batch, 1, 1, 1)
300
+
301
+ return out
302
+
303
+
304
+ class StyledConv(nn.Module):
305
+ def __init__(
306
+ self,
307
+ in_channel,
308
+ out_channel,
309
+ kernel_size,
310
+ style_dim,
311
+ upsample=False,
312
+ blur_kernel=[1, 3, 3, 1],
313
+ demodulate=True,
314
+ ):
315
+ super().__init__()
316
+
317
+ self.conv = ModulatedConv2d(
318
+ in_channel,
319
+ out_channel,
320
+ kernel_size,
321
+ style_dim,
322
+ upsample=upsample,
323
+ blur_kernel=blur_kernel,
324
+ demodulate=demodulate,
325
+ )
326
+
327
+ self.noise = NoiseInjection()
328
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
329
+ # self.activate = ScaledLeakyReLU(0.2)
330
+ self.activate = FusedLeakyReLU(out_channel)
331
+
332
+ def forward(self, input, style, noise=None):
333
+ out = self.conv(input, style)
334
+ out = self.noise(out, noise=noise)
335
+ # out = out + self.bias
336
+ out = self.activate(out)
337
+
338
+ return out
339
+
340
+
341
+ class ToRGB(nn.Module):
342
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
343
+ super().__init__()
344
+
345
+ if upsample:
346
+ self.upsample = Upsample(blur_kernel)
347
+
348
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
349
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
350
+
351
+ def forward(self, input, style, skip=None):
352
+ out = self.conv(input, style)
353
+ out = out + self.bias
354
+
355
+ if skip is not None:
356
+ skip = self.upsample(skip)
357
+
358
+ out = out + skip
359
+
360
+ return out
361
+
362
+
363
+ class Generator(nn.Module):
364
+ def __init__(
365
+ self,
366
+ size,
367
+ style_dim,
368
+ n_mlp,
369
+ channel_multiplier=2,
370
+ blur_kernel=[1, 3, 3, 1],
371
+ lr_mlp=0.01,
372
+ ):
373
+ super().__init__()
374
+
375
+ self.size = size
376
+
377
+ self.style_dim = style_dim
378
+
379
+ layers = [PixelNorm()]
380
+
381
+ for i in range(n_mlp):
382
+ layers.append(
383
+ EqualLinear(
384
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
385
+ )
386
+ )
387
+
388
+ self.style = nn.Sequential(*layers)
389
+
390
+ self.channels = {
391
+ 4: 512,
392
+ 8: 512,
393
+ 16: 512,
394
+ 32: 512,
395
+ 64: 256 * channel_multiplier,
396
+ 128: 128 * channel_multiplier,
397
+ 256: 64 * channel_multiplier,
398
+ 512: 32 * channel_multiplier,
399
+ 1024: 16 * channel_multiplier,
400
+ }
401
+
402
+ self.input = ConstantInput(self.channels[4])
403
+ self.conv1 = StyledConv(
404
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
405
+ )
406
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
407
+
408
+ self.log_size = int(math.log(size, 2))
409
+ self.num_layers = (self.log_size - 2) * 2 + 1
410
+
411
+ self.convs = nn.ModuleList()
412
+ self.upsamples = nn.ModuleList()
413
+ self.to_rgbs = nn.ModuleList()
414
+ self.noises = nn.Module()
415
+
416
+ in_channel = self.channels[4]
417
+
418
+ for layer_idx in range(self.num_layers):
419
+ res = (layer_idx + 5) // 2
420
+ shape = [1, 1, 2 ** res, 2 ** res]
421
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
422
+
423
+ for i in range(3, self.log_size + 1):
424
+ out_channel = self.channels[2 ** i]
425
+
426
+ self.convs.append(
427
+ StyledConv(
428
+ in_channel,
429
+ out_channel,
430
+ 3,
431
+ style_dim,
432
+ upsample=True,
433
+ blur_kernel=blur_kernel,
434
+ )
435
+ )
436
+
437
+ self.convs.append(
438
+ StyledConv(
439
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
440
+ )
441
+ )
442
+
443
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
444
+
445
+ in_channel = out_channel
446
+
447
+ self.n_latent = self.log_size * 2 - 2
448
+
449
+ def make_noise(self):
450
+ device = self.input.input.device
451
+
452
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
453
+
454
+ for i in range(3, self.log_size + 1):
455
+ for _ in range(2):
456
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
457
+
458
+ return noises
459
+
460
+ def mean_latent(self, n_latent):
461
+ latent_in = torch.randn(
462
+ n_latent, self.style_dim, device=self.input.input.device
463
+ )
464
+ latent = self.style(latent_in).mean(0, keepdim=True)
465
+
466
+ return latent
467
+
468
+ def get_latent(self, input):
469
+ return self.style(input)
470
+
471
+ def forward(
472
+ self,
473
+ styles,
474
+ return_latents=False,
475
+ inject_index=None,
476
+ truncation=1,
477
+ truncation_latent=None,
478
+ input_is_latent=False,
479
+ noise=None,
480
+ randomize_noise=True,
481
+ ):
482
+ if not input_is_latent:
483
+ styles = [self.style(s) for s in styles]
484
+
485
+ if noise is None:
486
+ if randomize_noise:
487
+ noise = [None] * self.num_layers
488
+ else:
489
+ noise = [
490
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
491
+ ]
492
+
493
+ if truncation < 1:
494
+ style_t = []
495
+
496
+ for style in styles:
497
+ style_t.append(
498
+ truncation_latent + truncation * (style - truncation_latent)
499
+ )
500
+
501
+ styles = style_t
502
+
503
+ if len(styles) < 2:
504
+ inject_index = self.n_latent
505
+
506
+ if styles[0].ndim < 3:
507
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
508
+
509
+ else:
510
+ latent = styles[0]
511
+
512
+ else:
513
+ if inject_index is None:
514
+ inject_index = random.randint(1, self.n_latent - 1)
515
+
516
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
517
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
518
+
519
+ latent = torch.cat([latent, latent2], 1)
520
+
521
+ out = self.input(latent)
522
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
523
+
524
+ skip = self.to_rgb1(out, latent[:, 1])
525
+
526
+ i = 1
527
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
528
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
529
+ ):
530
+ out = conv1(out, latent[:, i], noise=noise1)
531
+ out = conv2(out, latent[:, i + 1], noise=noise2)
532
+ skip = to_rgb(out, latent[:, i + 2], skip)
533
+
534
+ i += 2
535
+
536
+ image = skip
537
+
538
+ if return_latents:
539
+ return image, latent
540
+
541
+ else:
542
+ return image, None
543
+
544
+
545
+ class ConvLayer(nn.Sequential):
546
+ def __init__(
547
+ self,
548
+ in_channel,
549
+ out_channel,
550
+ kernel_size,
551
+ downsample=False,
552
+ blur_kernel=[1, 3, 3, 1],
553
+ bias=True,
554
+ activate=True,
555
+ ):
556
+ layers = []
557
+
558
+ if downsample:
559
+ factor = 2
560
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
561
+ pad0 = (p + 1) // 2
562
+ pad1 = p // 2
563
+
564
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
565
+
566
+ stride = 2
567
+ self.padding = 0
568
+
569
+ else:
570
+ stride = 1
571
+ self.padding = kernel_size // 2
572
+
573
+ layers.append(
574
+ EqualConv2d(
575
+ in_channel,
576
+ out_channel,
577
+ kernel_size,
578
+ padding=self.padding,
579
+ stride=stride,
580
+ bias=bias and not activate,
581
+ )
582
+ )
583
+
584
+ if activate:
585
+ if bias:
586
+ layers.append(FusedLeakyReLU(out_channel))
587
+
588
+ else:
589
+ layers.append(ScaledLeakyReLU(0.2))
590
+
591
+ super().__init__(*layers)
592
+
593
+
594
+ class ResBlock(nn.Module):
595
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
596
+ super().__init__()
597
+
598
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
599
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
600
+
601
+ self.skip = ConvLayer(
602
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
603
+ )
604
+
605
+ def forward(self, input):
606
+ out = self.conv1(input)
607
+ out = self.conv2(out)
608
+
609
+ skip = self.skip(input)
610
+ out = (out + skip) / math.sqrt(2)
611
+
612
+ return out
613
+
614
+
615
+ class Discriminator(nn.Module):
616
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
617
+ super().__init__()
618
+
619
+ channels = {
620
+ 4: 512,
621
+ 8: 512,
622
+ 16: 512,
623
+ 32: 512,
624
+ 64: 256 * channel_multiplier,
625
+ 128: 128 * channel_multiplier,
626
+ 256: 64 * channel_multiplier,
627
+ 512: 32 * channel_multiplier,
628
+ 1024: 16 * channel_multiplier,
629
+ }
630
+
631
+ convs = [ConvLayer(3, channels[size], 1)]
632
+
633
+ log_size = int(math.log(size, 2))
634
+
635
+ in_channel = channels[size]
636
+
637
+ for i in range(log_size, 2, -1):
638
+ out_channel = channels[2 ** (i - 1)]
639
+
640
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
641
+
642
+ in_channel = out_channel
643
+
644
+ self.convs = nn.Sequential(*convs)
645
+
646
+ self.stddev_group = 4
647
+ self.stddev_feat = 1
648
+
649
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
650
+ self.final_linear = nn.Sequential(
651
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
652
+ EqualLinear(channels[4], 1),
653
+ )
654
+
655
+ def forward(self, input):
656
+ out = self.convs(input)
657
+
658
+ batch, channel, height, width = out.shape
659
+ group = min(batch, self.stddev_group)
660
+ stddev = out.view(
661
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
662
+ )
663
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
664
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
665
+ stddev = stddev.repeat(group, 1, height, width)
666
+ out = torch.cat([out, stddev], 1)
667
+
668
+ out = self.final_conv(out)
669
+
670
+ out = out.view(batch, -1)
671
+ out = self.final_linear(out)
672
+
673
+ return out
674
+
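For orientation, a minimal sampling sketch for the vanilla Generator above (the import path and checkpoint filename are assumptions for illustration, not part of the commit):

import torch
from models.stylegan2.model import Generator  # assumed location of the class defined above

g_ema = Generator(1024, 512, 8)
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location='cpu')  # assumed pretrained weights
g_ema.load_state_dict(ckpt['g_ema'], strict=False)
g_ema.eval()

with torch.no_grad():
    z = torch.randn(2, 512)                      # two latent codes in Z
    mean_w = g_ema.mean_latent(4096)             # average W latent, used for truncation
    images, _ = g_ema([z], truncation=0.7, truncation_latent=mean_w)
# images: (2, 3, 1024, 1024), values roughly in [-1, 1]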
models/stylegan2/model_remapper.py ADDED
@@ -0,0 +1,694 @@
1
+ import math
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
8
+
9
+
10
+ class PixelNorm(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+
14
+ def forward(self, input):
15
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
16
+
17
+
18
+ def make_kernel(k):
19
+ k = torch.tensor(k, dtype=torch.float32)
20
+
21
+ if k.ndim == 1:
22
+ k = k[None, :] * k[:, None]
23
+
24
+ k /= k.sum()
25
+
26
+ return k
27
+
28
+
29
+ class Upsample(nn.Module):
30
+ def __init__(self, kernel, factor=2):
31
+ super().__init__()
32
+
33
+ self.factor = factor
34
+ kernel = make_kernel(kernel) * (factor ** 2)
35
+ self.register_buffer('kernel', kernel)
36
+
37
+ p = kernel.shape[0] - factor
38
+
39
+ pad0 = (p + 1) // 2 + factor - 1
40
+ pad1 = p // 2
41
+
42
+ self.pad = (pad0, pad1)
43
+
44
+ def forward(self, input):
45
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
46
+
47
+ return out
48
+
49
+
50
+ class Downsample(nn.Module):
51
+ def __init__(self, kernel, factor=2):
52
+ super().__init__()
53
+
54
+ self.factor = factor
55
+ kernel = make_kernel(kernel)
56
+ self.register_buffer('kernel', kernel)
57
+
58
+ p = kernel.shape[0] - factor
59
+
60
+ pad0 = (p + 1) // 2
61
+ pad1 = p // 2
62
+
63
+ self.pad = (pad0, pad1)
64
+
65
+ def forward(self, input):
66
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
67
+
68
+ return out
69
+
70
+
71
+ class Blur(nn.Module):
72
+ def __init__(self, kernel, pad, upsample_factor=1):
73
+ super().__init__()
74
+
75
+ kernel = make_kernel(kernel)
76
+
77
+ if upsample_factor > 1:
78
+ kernel = kernel * (upsample_factor ** 2)
79
+
80
+ self.register_buffer('kernel', kernel)
81
+
82
+ self.pad = pad
83
+
84
+ def forward(self, input):
85
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
86
+
87
+ return out
88
+
89
+
90
+ class EqualConv2d(nn.Module):
91
+ def __init__(
92
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
93
+ ):
94
+ super().__init__()
95
+
96
+ self.weight = nn.Parameter(
97
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
98
+ )
99
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
100
+
101
+ self.stride = stride
102
+ self.padding = padding
103
+
104
+ if bias:
105
+ self.bias = nn.Parameter(torch.zeros(out_channel))
106
+
107
+ else:
108
+ self.bias = None
109
+
110
+ def forward(self, input):
111
+ out = F.conv2d(
112
+ input,
113
+ self.weight * self.scale,
114
+ bias=self.bias,
115
+ stride=self.stride,
116
+ padding=self.padding,
117
+ )
118
+
119
+ return out
120
+
121
+ def __repr__(self):
122
+ return (
123
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
124
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
125
+ )
126
+
127
+
128
+ class EqualLinear(nn.Module):
129
+ def __init__(
130
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
131
+ ):
132
+ super().__init__()
133
+
134
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
135
+
136
+ if bias:
137
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
138
+
139
+ else:
140
+ self.bias = None
141
+
142
+ self.activation = activation
143
+
144
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
145
+ self.lr_mul = lr_mul
146
+
147
+ def forward(self, input):
148
+ if self.activation:
149
+ out = F.linear(input, self.weight * self.scale)
150
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
151
+
152
+ else:
153
+ out = F.linear(
154
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
155
+ )
156
+
157
+ return out
158
+
159
+ def __repr__(self):
160
+ return (
161
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
162
+ )
163
+
164
+
165
+ class ScaledLeakyReLU(nn.Module):
166
+ def __init__(self, negative_slope=0.2):
167
+ super().__init__()
168
+
169
+ self.negative_slope = negative_slope
170
+
171
+ def forward(self, input):
172
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
173
+
174
+ return out * math.sqrt(2)
175
+
176
+
177
+ class ModulatedConv2d(nn.Module):
178
+ def __init__(
179
+ self,
180
+ in_channel,
181
+ out_channel,
182
+ kernel_size,
183
+ style_dim,
184
+ demodulate=True,
185
+ upsample=False,
186
+ downsample=False,
187
+ blur_kernel=[1, 3, 3, 1],
188
+ ):
189
+ super().__init__()
190
+
191
+ self.eps = 1e-8
192
+ self.kernel_size = kernel_size
193
+ self.in_channel = in_channel
194
+ self.out_channel = out_channel
195
+ self.upsample = upsample
196
+ self.downsample = downsample
197
+
198
+ if upsample:
199
+ factor = 2
200
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
201
+ pad0 = (p + 1) // 2 + factor - 1
202
+ pad1 = p // 2 + 1
203
+
204
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
205
+
206
+ if downsample:
207
+ factor = 2
208
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
209
+ pad0 = (p + 1) // 2
210
+ pad1 = p // 2
211
+
212
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
213
+
214
+ fan_in = in_channel * kernel_size ** 2
215
+ self.scale = 1 / math.sqrt(fan_in)
216
+ self.padding = kernel_size // 2
217
+
218
+ self.weight = nn.Parameter(
219
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
220
+ )
221
+
222
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
223
+
224
+ self.demodulate = demodulate
225
+
226
+ def __repr__(self):
227
+ return (
228
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
229
+ f'upsample={self.upsample}, downsample={self.downsample})'
230
+ )
231
+
232
+ def forward(self, input, style):
233
+ batch, in_channel, height, width = input.shape
234
+
235
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
236
+ weight = self.scale * self.weight * style
237
+
238
+ if self.demodulate:
239
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
240
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
241
+
242
+ weight = weight.view(
243
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
244
+ )
245
+
246
+ if self.upsample:
247
+ input = input.view(1, batch * in_channel, height, width)
248
+ weight = weight.view(
249
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
250
+ )
251
+ weight = weight.transpose(1, 2).reshape(
252
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
253
+ )
254
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
255
+ _, _, height, width = out.shape
256
+ out = out.view(batch, self.out_channel, height, width)
257
+ out = self.blur(out)
258
+
259
+ elif self.downsample:
260
+ input = self.blur(input)
261
+ _, _, height, width = input.shape
262
+ input = input.view(1, batch * in_channel, height, width)
263
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
264
+ _, _, height, width = out.shape
265
+ out = out.view(batch, self.out_channel, height, width)
266
+
267
+ else:
268
+ input = input.view(1, batch * in_channel, height, width)
269
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
270
+ _, _, height, width = out.shape
271
+ out = out.view(batch, self.out_channel, height, width)
272
+
273
+ return out
274
+
275
+
276
+ class NoiseInjection(nn.Module):
277
+ def __init__(self):
278
+ super().__init__()
279
+
280
+ self.weight = nn.Parameter(torch.zeros(1))
281
+
282
+ def forward(self, image, noise=None):
283
+ if noise is None:
284
+ batch, _, height, width = image.shape
285
+ noise = image.new_empty(batch, 1, height, width).normal_()
286
+
287
+ return image + self.weight * noise
288
+
289
+
290
+ class ConstantInput(nn.Module):
291
+ def __init__(self, channel, size=4):
292
+ super().__init__()
293
+
294
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
295
+
296
+ def forward(self, input):
297
+ batch = input.shape[0]
298
+ out = self.input.repeat(batch, 1, 1, 1)
299
+
300
+ return out
301
+
302
+
303
+ class StyledConv(nn.Module):
304
+ def __init__(
305
+ self,
306
+ in_channel,
307
+ out_channel,
308
+ kernel_size,
309
+ style_dim,
310
+ upsample=False,
311
+ blur_kernel=[1, 3, 3, 1],
312
+ demodulate=True,
313
+ ):
314
+ super().__init__()
315
+
316
+ self.conv = ModulatedConv2d(
317
+ in_channel,
318
+ out_channel,
319
+ kernel_size,
320
+ style_dim,
321
+ upsample=upsample,
322
+ blur_kernel=blur_kernel,
323
+ demodulate=demodulate,
324
+ )
325
+
326
+ self.noise = NoiseInjection()
327
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
328
+ # self.activate = ScaledLeakyReLU(0.2)
329
+ self.activate = FusedLeakyReLU(out_channel)
330
+
331
+ def forward(self, input, style, noise=None):
332
+ out = self.conv(input, style)
333
+ out = self.noise(out, noise=noise)
334
+ # out = out + self.bias
335
+ out = self.activate(out)
336
+
337
+ return out
338
+
339
+
340
+ class ToRGB(nn.Module):
341
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
342
+ super().__init__()
343
+
344
+ if upsample:
345
+ self.upsample = Upsample(blur_kernel)
346
+
347
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
348
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
349
+
350
+ def forward(self, input, style, skip=None):
351
+ out = self.conv(input, style)
352
+ out = out + self.bias
353
+
354
+ if skip is not None:
355
+ skip = self.upsample(skip)
356
+
357
+ out = out + skip
358
+
359
+ return out
360
+
361
+ # Generator with the remapper layers
362
+ class Generator(nn.Module):
363
+ def __init__(
364
+ self,
365
+ size,
366
+ style_dim,
367
+ n_mlp,
368
+ channel_multiplier=2,
369
+ blur_kernel=[1, 3, 3, 1],
370
+ lr_mlp=0.01,
371
+ ):
372
+ super().__init__()
373
+
374
+ self.size = size
375
+
376
+ self.style_dim = style_dim
377
+
378
+ layers = [PixelNorm()]
379
+
380
+ for i in range(n_mlp):
381
+ layers.append(
382
+ EqualLinear(
383
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
384
+ )
385
+ )
386
+
387
+ self.style = nn.Sequential(*layers)
388
+
389
+ self.channels = {
390
+ 4: 512,
391
+ 8: 512,
392
+ 16: 512,
393
+ 32: 512,
394
+ 64: 256 * channel_multiplier,
395
+ 128: 128 * channel_multiplier,
396
+ 256: 64 * channel_multiplier,
397
+ 512: 32 * channel_multiplier,
398
+ 1024: 16 * channel_multiplier,
399
+ }
400
+
401
+ self.input = ConstantInput(self.channels[4])
402
+ self.conv1 = StyledConv(
403
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
404
+ )
405
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
406
+
407
+ self.log_size = int(math.log(size, 2))
408
+ self.num_layers = (self.log_size - 2) * 2 + 1
409
+
410
+ self.convs = nn.ModuleList()
411
+ self.upsamples = nn.ModuleList()
412
+ self.to_rgbs = nn.ModuleList()
413
+ self.noises = nn.Module()
414
+
415
+ in_channel = self.channels[4]
416
+
417
+ for layer_idx in range(self.num_layers):
418
+ res = (layer_idx + 5) // 2
419
+ shape = [1, 1, 2 ** res, 2 ** res]
420
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
421
+
422
+ for i in range(3, self.log_size + 1):
423
+ out_channel = self.channels[2 ** i]
424
+
425
+ self.convs.append(
426
+ StyledConv(
427
+ in_channel,
428
+ out_channel,
429
+ 3,
430
+ style_dim,
431
+ upsample=True,
432
+ blur_kernel=blur_kernel,
433
+ )
434
+ )
435
+
436
+ self.convs.append(
437
+ StyledConv(
438
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
439
+ )
440
+ )
441
+
442
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
443
+
444
+ in_channel = out_channel
445
+
446
+ self.n_latent = self.log_size * 2 - 2
447
+
448
+ # CLIPREMAPPER
449
+ cond_size = style_dim
450
+ self.affine_mlps = nn.ModuleList()
451
+ self.lambdas = nn.ParameterList()
452
+ for i in range(1,9):
453
+ self.lambdas.append(nn.Parameter(torch.tensor(0.05), requires_grad=True))
454
+ self.lambdas.append(nn.Parameter(torch.tensor(0.05), requires_grad=True))
455
+ self.affine_mlps.append(nn.Sequential(nn.Linear(cond_size, cond_size), nn.ReLU(), nn.Linear(cond_size, style_dim), nn.LayerNorm(style_dim), nn.ReLU(), nn.Linear(style_dim, style_dim)))
456
+ self.affine_mlps.append(nn.Sequential(nn.Linear(cond_size, cond_size), nn.ReLU(), nn.Linear(cond_size, style_dim), nn.LayerNorm(style_dim), nn.ReLU(), nn.Linear(style_dim, style_dim)))
457
+
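(The CLIPREMAPPER block above registers sixteen small MLPs and sixteen scalar gates, two per synthesis resolution, for a 1024-px model covering 8x8 up to 1024x1024. In forward, each gate lambda, initialised to 0.05 and scaled by 0.9, linearly blends the per-layer w latent with that layer's MLP projection of the CLIP text embedding, roughly:

    w_k_blended = (1 - 0.9 * lambda_k) * w_k + (0.9 * lambda_k) * affine_mlp_k(txt_embed)

so at initialisation the text signal contributes only a small fraction of each style, and its per-layer influence is learned during training.)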
458
+ def make_noise(self):
459
+ device = self.input.input.device
460
+
461
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
462
+
463
+ for i in range(3, self.log_size + 1):
464
+ for _ in range(2):
465
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
466
+
467
+ return noises
468
+
469
+ def mean_latent(self, n_latent):
470
+ latent_in = torch.randn(
471
+ n_latent, self.style_dim, device=self.input.input.device
472
+ )
473
+ latent = self.style(latent_in).mean(0, keepdim=True)
474
+
475
+ return latent
476
+
477
+ def get_latent(self, input):
478
+ return self.style(input)
479
+
480
+ def forward(
481
+ self,
482
+ styles,
483
+ return_latents=False,
484
+ return_features=False,
485
+ inject_index=None,
486
+ truncation=1,
487
+ truncation_latent=None,
488
+ input_is_latent=False,
489
+ noise=None,
490
+ randomize_noise=True,
491
+ txt_embed=None
492
+ ):
493
+ if not input_is_latent:
494
+ styles = [self.style(s) for s in styles]
495
+
496
+ if noise is None:
497
+ if randomize_noise:
498
+ noise = [None] * self.num_layers
499
+ else:
500
+ noise = [
501
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
502
+ ]
503
+
504
+ if truncation < 1:
505
+ style_t = []
506
+
507
+ for style in styles:
508
+ style_t.append(
509
+ truncation_latent + truncation * (style - truncation_latent)
510
+ )
511
+
512
+ styles = style_t
513
+
514
+ if len(styles) < 2:
515
+ inject_index = self.n_latent
516
+
517
+ if styles[0].ndim < 3:
518
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
519
+ else:
520
+ latent = styles[0]
521
+
522
+ else:
523
+ if inject_index is None:
524
+ inject_index = random.randint(1, self.n_latent - 1)
525
+
526
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
527
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
528
+
529
+ latent = torch.cat([latent, latent2], 1)
530
+
531
+ out = self.input(latent)
532
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
533
+
534
+ skip = self.to_rgb1(out, latent[:, 1])
535
+
536
+ i = 1
537
+ for affine1, affine2, conv1, conv2, noise1, noise2, to_rgb in zip(
538
+ self.affine_mlps[::2], self.affine_mlps[1::2], self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
539
+ ):
540
+ if txt_embed is None:
541
+ out = conv1(out, latent[:, i], noise=noise1)
542
+ out = conv2(out, latent[:, i + 1], noise=noise2)
543
+ else:
544
+ latent_pr_1 = affine1(txt_embed)
545
+ lambda1 = self.lambdas[i-1] * 0.9
546
+ inject_latent1 = latent[:, i] * (1-lambda1) + latent_pr_1 * (lambda1)
547
+ out = conv1(out, inject_latent1, noise=noise1)
548
+ latent_pr_2 = affine2(txt_embed)
549
+ lambda2 = self.lambdas[i] * 0.9
550
+ inject_latent2 = latent[:, i + 1] * (1-lambda2) + latent_pr_2 * (lambda2)
551
+ out = conv2(out, inject_latent2, noise=noise2)
552
+ skip = to_rgb(out, latent[:, i + 2], skip)
553
+
554
+ i += 2
555
+
556
+ image = skip
557
+
558
+ if return_latents:
559
+ return image, latent
560
+ elif return_features:
561
+ return image, out
562
+ else:
563
+ return image, None
564
+
565
+
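A minimal sketch of driving the remapper path (the shapes and the random stand-ins are assumptions for illustration; in this commit the W+ latents come from the adapter and the text embedding from CLIP's text encoder):

import torch
from models.stylegan2.model_remapper import Generator

g = Generator(1024, 512, 8)                      # weights would normally be loaded from a checkpoint
w = torch.randn(2, g.n_latent, 512)              # (B, 18, 512) W+ latents, e.g. produced by the adapter
txt = torch.randn(2, 512)                        # stand-in for a 512-d CLIP text embedding

images, _ = g([w], input_is_latent=True, randomize_noise=False, txt_embed=txt)
# images: (2, 3, 1024, 1024); with txt_embed=None the remapper MLPs are bypassed entirely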
566
+ class ConvLayer(nn.Sequential):
567
+ def __init__(
568
+ self,
569
+ in_channel,
570
+ out_channel,
571
+ kernel_size,
572
+ downsample=False,
573
+ blur_kernel=[1, 3, 3, 1],
574
+ bias=True,
575
+ activate=True,
576
+ ):
577
+ layers = []
578
+
579
+ if downsample:
580
+ factor = 2
581
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
582
+ pad0 = (p + 1) // 2
583
+ pad1 = p // 2
584
+
585
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
586
+
587
+ stride = 2
588
+ self.padding = 0
589
+
590
+ else:
591
+ stride = 1
592
+ self.padding = kernel_size // 2
593
+
594
+ layers.append(
595
+ EqualConv2d(
596
+ in_channel,
597
+ out_channel,
598
+ kernel_size,
599
+ padding=self.padding,
600
+ stride=stride,
601
+ bias=bias and not activate,
602
+ )
603
+ )
604
+
605
+ if activate:
606
+ if bias:
607
+ layers.append(FusedLeakyReLU(out_channel))
608
+
609
+ else:
610
+ layers.append(ScaledLeakyReLU(0.2))
611
+
612
+ super().__init__(*layers)
613
+
614
+
615
+ class ResBlock(nn.Module):
616
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
617
+ super().__init__()
618
+
619
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
620
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
621
+
622
+ self.skip = ConvLayer(
623
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
624
+ )
625
+
626
+ def forward(self, input):
627
+ out = self.conv1(input)
628
+ out = self.conv2(out)
629
+
630
+ skip = self.skip(input)
631
+ out = (out + skip) / math.sqrt(2)
632
+
633
+ return out
634
+
635
+
636
+ class Discriminator(nn.Module):
637
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
638
+ super().__init__()
639
+
640
+ channels = {
641
+ 4: 512,
642
+ 8: 512,
643
+ 16: 512,
644
+ 32: 512,
645
+ 64: 256 * channel_multiplier,
646
+ 128: 128 * channel_multiplier,
647
+ 256: 64 * channel_multiplier,
648
+ 512: 32 * channel_multiplier,
649
+ 1024: 16 * channel_multiplier,
650
+ }
651
+
652
+ convs = [ConvLayer(3, channels[size], 1)]
653
+
654
+ log_size = int(math.log(size, 2))
655
+
656
+ in_channel = channels[size]
657
+
658
+ for i in range(log_size, 2, -1):
659
+ out_channel = channels[2 ** (i - 1)]
660
+
661
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
662
+
663
+ in_channel = out_channel
664
+
665
+ self.convs = nn.Sequential(*convs)
666
+
667
+ self.stddev_group = 4
668
+ self.stddev_feat = 1
669
+
670
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
671
+ self.final_linear = nn.Sequential(
672
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
673
+ EqualLinear(channels[4], 1),
674
+ )
675
+
676
+ def forward(self, input):
677
+ out = self.convs(input)
678
+
679
+ batch, channel, height, width = out.shape
680
+ group = min(batch, self.stddev_group)
681
+ stddev = out.view(
682
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
683
+ )
684
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
685
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
686
+ stddev = stddev.repeat(group, 1, height, width)
687
+ out = torch.cat([out, stddev], 1)
688
+
689
+ out = self.final_conv(out)
690
+
691
+ out = out.view(batch, -1)
692
+ out = self.final_linear(out)
693
+
694
+ return out
models/stylegan2/op/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
models/stylegan2/op/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (253 Bytes).
 
models/stylegan2/op/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (257 Bytes).
 
models/stylegan2/op/__pycache__/fused_act.cpython-36.pyc ADDED
Binary file (1.33 kB).
 
models/stylegan2/op/__pycache__/fused_act.cpython-37.pyc ADDED
Binary file (1.31 kB).
 
models/stylegan2/op/__pycache__/upfirdn2d.cpython-36.pyc ADDED
Binary file (1.45 kB).
 
models/stylegan2/op/__pycache__/upfirdn2d.cpython-37.pyc ADDED
Binary file (1.43 kB).
 
models/stylegan2/op/fused_act.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ module_path = os.path.dirname(__file__)
8
+
9
+
10
+
11
+ class FusedLeakyReLU(nn.Module):
12
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
13
+ super().__init__()
14
+
15
+ self.bias = nn.Parameter(torch.zeros(channel))
16
+ self.negative_slope = negative_slope
17
+ self.scale = scale
18
+
19
+ def forward(self, input):
20
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
21
+
22
+
23
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
24
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
25
+ # input = input.cuda()
26
+ if input.ndim == 3:
27
+ return (
28
+ F.leaky_relu(
29
+ input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
30
+ )
31
+ * scale
32
+ )
33
+ else:
34
+ return (
35
+ F.leaky_relu(
36
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
37
+ )
38
+ * scale
39
+ )
40
+
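(This fused_act module is the pure-PyTorch fallback of StyleGAN2's fused bias + LeakyReLU op: for a 4-D activation it is equivalent to F.leaky_relu(x + bias.view(1, -1, 1, 1), 0.2) * sqrt(2), just without the custom CUDA kernel.)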
models/stylegan2/op/upfirdn2d.py ADDED
@@ -0,0 +1,60 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ module_path = os.path.dirname(__file__)
8
+
9
+
10
+
11
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
12
+ out = upfirdn2d_native(
13
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
14
+ )
15
+
16
+ return out
17
+
18
+
19
+ def upfirdn2d_native(
20
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
21
+ ):
22
+ _, channel, in_h, in_w = input.shape
23
+ input = input.reshape(-1, in_h, in_w, 1)
24
+
25
+ _, in_h, in_w, minor = input.shape
26
+ kernel_h, kernel_w = kernel.shape
27
+
28
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
29
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
30
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
31
+
32
+ out = F.pad(
33
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
34
+ )
35
+ out = out[
36
+ :,
37
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
38
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
39
+ :,
40
+ ]
41
+
42
+ out = out.permute(0, 3, 1, 2)
43
+ out = out.reshape(
44
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
45
+ )
46
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
47
+ out = F.conv2d(out, w)
48
+ out = out.reshape(
49
+ -1,
50
+ minor,
51
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
52
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
53
+ )
54
+ out = out.permute(0, 2, 3, 1)
55
+ out = out[:, ::down_y, ::down_x, :]
56
+
57
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
58
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
59
+
60
+ return out.view(-1, channel, out_h, out_w)
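A small sketch of calling this native upfirdn2d for a size-preserving blur (the kernel normalisation mirrors make_kernel above; the tensor shapes and pad choice are illustrative assumptions):

import torch
from models.stylegan2.op import upfirdn2d

k = torch.tensor([1., 3., 3., 1.])
k = k[None, :] * k[:, None]
k = k / k.sum()                                  # normalised 4x4 blur kernel, as in make_kernel

x = torch.randn(1, 3, 16, 16)
y = upfirdn2d(x, k, up=1, down=1, pad=(2, 1))    # out_h = (16 + 2 + 1 - 4) // 1 + 1 = 16
print(y.shape)                                   # torch.Size([1, 3, 16, 16])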