soumickmj commited on
Commit
ad947b4
·
verified ·
1 Parent(s): b5dede8

Upload DiffAE

Browse files
DiffAE.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ import numpy as np
4
+ import torch
5
+ from pytorch_lightning.callbacks import *
6
+ from torch.optim.optimizer import Optimizer
7
+
8
+ from transformers import PreTrainedModel
9
+
10
+ from .DiffAEConfig import DiffAEConfig
11
+ from .DiffAE_support import *
12
+
13
+ class DiffAE(PreTrainedModel):
14
+ config_class = DiffAEConfig
15
+ def __init__(self, config):
16
+ super().__init__(config)
17
+
18
+ conf = ukbb_autoenc(n_latents=config.latent_dim)
19
+ conf.__dict__.update(**vars(config)) #update the supplied DiffAE params
20
+
21
+ if config.test_with_TEval:
22
+ conf.T_inv = conf.T_eval
23
+ conf.T_step = conf.T_eval
24
+
25
+ conf.fp16 = config.ampmode not in ["32", "32-true"]
26
+
27
+ conf.refresh_values()
28
+ conf.make_model_conf()
29
+
30
+ self.config = config
31
+ self.conf = conf
32
+
33
+ self.net = conf.make_model_conf().make_model()
34
+ self.ema_net = copy.deepcopy(self.net)
35
+ self.ema_net.requires_grad_(False)
36
+ self.ema_net.eval()
37
+
38
+ model_size = sum(param.data.nelement() for param in self.net.parameters())
39
+ print('Model params: %.2f M' % (model_size / 1024 / 1024))
40
+
41
+ self.sampler = conf.make_diffusion_conf().make_sampler()
42
+ self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
43
+
44
+ # this is shared for both model and latent
45
+ self.T_sampler = conf.make_T_sampler()
46
+
47
+ if conf.train_mode.use_latent_net():
48
+ self.latent_sampler = conf.make_latent_diffusion_conf(
49
+ ).make_sampler()
50
+ self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
51
+ ).make_sampler()
52
+ else:
53
+ self.latent_sampler = None
54
+ self.eval_latent_sampler = None
55
+
56
+ # initial variables for consistent sampling
57
+ self.register_buffer('x_T', torch.randn(conf.sample_size, conf.in_channels, *conf.input_shape))
58
+
59
+ if conf.pretrain is not None:
60
+ print(f'loading pretrain ... {conf.pretrain.name}')
61
+ state = torch.load(conf.pretrain.path, map_location='cpu')
62
+ print('step:', state['global_step'])
63
+ self.load_state_dict(state['state_dict'], strict=False)
64
+
65
+ if conf.latent_infer_path is not None:
66
+ print('loading latent stats ...')
67
+ state = torch.load(conf.latent_infer_path)
68
+ self.conds = state['conds']
69
+ self.register_buffer('conds_mean', state['conds_mean'][None, :])
70
+ self.register_buffer('conds_std', state['conds_std'][None, :])
71
+ else:
72
+ self.conds_mean = None
73
+ self.conds_std = None
74
+
75
+ def normalise(self, cond):
76
+ cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
77
+ self.device)
78
+ return cond
79
+
80
+ def denormalise(self, cond):
81
+ cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
82
+ self.device)
83
+ return cond
84
+
85
+ def sample(self, N, device, T=None, T_latent=None):
86
+ if T is None:
87
+ sampler = self.eval_sampler
88
+ latent_sampler = self.latent_sampler
89
+ else:
90
+ sampler = self.conf._make_diffusion_conf(T).make_sampler()
91
+ latent_sampler = self.conf._make_latent_diffusion_conf(T_latent).make_sampler()
92
+
93
+ noise = torch.randn(N,
94
+ self.conf.in_channels,
95
+ *self.conf.input_shape,
96
+ device=device)
97
+ pred_img = render_uncondition(
98
+ self.conf,
99
+ self.ema_net,
100
+ noise,
101
+ sampler=sampler,
102
+ latent_sampler=latent_sampler,
103
+ conds_mean=self.conds_mean,
104
+ conds_std=self.conds_std,
105
+ )
106
+ pred_img = (pred_img + 1) / 2
107
+ return pred_img
108
+
109
+ def render(self, noise, cond=None, T=None, use_ema=True):
110
+ if T is None:
111
+ sampler = self.eval_sampler
112
+ else:
113
+ sampler = self.conf._make_diffusion_conf(T).make_sampler()
114
+
115
+ if cond is not None:
116
+ pred_img = render_condition(self.conf,
117
+ self.ema_net if use_ema else self.net,
118
+ noise,
119
+ sampler=sampler,
120
+ cond=cond)
121
+ else:
122
+ pred_img = render_uncondition(self.conf,
123
+ self.ema_net if use_ema else self.net,
124
+ noise,
125
+ sampler=sampler,
126
+ latent_sampler=None)
127
+ pred_img = (pred_img + 1) / 2
128
+ return pred_img
129
+
130
+ def encode(self, x, use_ema=True):
131
+ assert self.conf.model_type.has_autoenc()
132
+ return self.ema_net.encoder.forward(x) if use_ema else self.net.encoder.forward(x)
133
+
134
+ def encode_stochastic(self, x, cond, T=None, use_ema=True):
135
+ if T is None:
136
+ sampler = self.eval_sampler
137
+ else:
138
+ sampler = self.conf._make_diffusion_conf(T).make_sampler()
139
+ out = sampler.ddim_reverse_sample_loop(self.ema_net if use_ema else self.net,
140
+ x,
141
+ model_kwargs={'cond': cond})
142
+ return out['sample']
143
+
144
+ def forward(self, x_start=None, noise=None, ema_model: bool = False):
145
+ with amp.autocast(False):
146
+ model = self.ema_net if ema_model else self.net
147
+ return self.eval_sampler.sample(
148
+ model=model,
149
+ noise=noise,
150
+ x_start=x_start,
151
+ shape=noise.shape if noise is not None else x_start.shape,
152
+ )
153
+
154
+ def is_last_accum(self, batch_idx):
155
+ """
156
+ is it the last gradient accumulation loop?
157
+ used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not
158
+ """
159
+ return (batch_idx + 1) % self.conf.accum_batches == 0
160
+
161
+ def training_step(self, batch, batch_idx):
162
+ """
163
+ given an input, calculate the loss function
164
+ no optimization at this stage.
165
+ """
166
+ with amp.autocast(False):
167
+ # forward
168
+ if self.conf.train_mode.require_dataset_infer():
169
+ # this mode as pre-calculated cond
170
+ cond = batch[0]
171
+ if self.conf.latent_znormalize:
172
+ cond = (cond - self.conds_mean.to(
173
+ self.device)) / self.conds_std.to(self.device)
174
+ else:
175
+ imgs, idxs = batch['inp']['data'], batch_idx
176
+ # print(f'(rank {self.global_rank}) batch size:', len(imgs))
177
+ x_start = imgs
178
+
179
+ if self.conf.train_mode == TrainMode.diffusion:
180
+ """
181
+ main training mode!!!
182
+ """
183
+ # with numpy seed we have the problem that the sample t's are related!
184
+ t, weight = self.T_sampler.sample(len(x_start), x_start.device)
185
+ losses = self.sampler.training_losses(model=self.net,
186
+ x_start=x_start,
187
+ t=t)
188
+ elif self.conf.train_mode.is_latent_diffusion():
189
+ """
190
+ training the latent variables!
191
+ """
192
+ # diffusion on the latent
193
+ t, weight = self.T_sampler.sample(len(cond), cond.device)
194
+ latent_losses = self.latent_sampler.training_losses(
195
+ model=self.net.latent_net, x_start=cond, t=t)
196
+ # train only do the latent diffusion
197
+ losses = {
198
+ 'latent': latent_losses['loss'],
199
+ 'loss': latent_losses['loss']
200
+ }
201
+ else:
202
+ raise NotImplementedError()
203
+
204
+ loss = losses['loss'].mean()
205
+ loss_dict = {"train_loss": loss}
206
+ for key in ['vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']:
207
+ if key in losses:
208
+ loss_dict[f'train_{key}'] = losses[key].mean()
209
+ self.log_dict(loss_dict, on_step=True, on_epoch=True, reduce_fx="mean", sync_dist=True, batch_size=batch['inp']['data'].shape[0])
210
+
211
+ return loss
212
+
213
+ def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None:
214
+ """
215
+ after each training step ...
216
+ """
217
+ if self.is_last_accum(batch_idx):
218
+ # only apply ema on the last gradient accumulation step,
219
+ # if it is the iteration that has optimizer.step()
220
+ if self.conf.train_mode == TrainMode.latent_diffusion:
221
+ # it trains only the latent hence change only the latent
222
+ ema(self.net.latent_net, self.ema_net.latent_net,
223
+ self.conf.ema_decay)
224
+ else:
225
+ ema(self.net, self.ema_net, self.conf.ema_decay)
226
+
227
+ def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
228
+ # fix the fp16 + clip grad norm problem with pytorch lightinng
229
+ # this is the currently correct way to do it
230
+ if self.conf.grad_clip > 0:
231
+ # from trainer.params_grads import grads_norm, iter_opt_params
232
+ params = [
233
+ p for group in optimizer.param_groups for p in group['params']
234
+ ]
235
+ # print('before:', grads_norm(iter_opt_params(optimizer)))
236
+ torch.nn.utils.clip_grad_norm_(params,
237
+ max_norm=self.conf.grad_clip)
238
+ # print('after:', grads_norm(iter_opt_params(optimizer)))
239
+
240
+ #Validation
241
+ def validation_step(self, batch, batch_idx):
242
+ _, prediction_ema = self.inference_pass(batch['inp']['data'], T_inv=self.conf.T_eval, T_step=self.conf.T_eval, use_ema=True)
243
+ _, prediction_base = self.inference_pass(batch['inp']['data'], T_inv=self.conf.T_eval, T_step=self.conf.T_eval, use_ema=False)
244
+
245
+ inp = batch['inp']['data'].cpu()
246
+ inp = (inp + 1) / 2
247
+
248
+ _, val_ssim_ema = self._eval_prediction(inp, prediction_ema)
249
+ _, val_ssim_base = self._eval_prediction(inp, prediction_base)
250
+
251
+ self.log_dict({"val_ssim_ema": val_ssim_ema, "val_ssim_base": val_ssim_base, "val_loss": -val_ssim_ema}, on_step=True, on_epoch=True, reduce_fx="mean", sync_dist=True, batch_size=batch['inp']['data'].shape[0])
252
+ self.img_logger("val_ema", batch_idx, inp, prediction_ema)
253
+ self.img_logger("val_base", batch_idx, inp, prediction_base)
254
+
255
+ def _eval_prediction(self, inp, prediction):
256
+ prediction = prediction.detach().cpu()
257
+ prediction = prediction.numpy() if prediction.dtype not in {torch.bfloat16, torch.float16} else prediction.to(dtype=torch.float32).numpy()
258
+ if self.config.grey2RGB in [0, 2]:
259
+ inp = inp[:, 1, ...].unsqueeze(1)
260
+ prediction = np.expand_dims(prediction[:, 1, ...], axis=1)
261
+ val_ssim = getSSIM(inp.numpy(), prediction, data_range=1)
262
+ return prediction, val_ssim
263
+
264
+ def inference_pass(self, inp, T_inv, T_step, use_ema=True):
265
+ semantic_latent = self.encode(inp, use_ema=use_ema)
266
+ if self.config.test_emb_only:
267
+ return semantic_latent, None
268
+ stochastic_latent = self.encode_stochastic(inp, semantic_latent, T=T_inv)
269
+ prediction = self.render(stochastic_latent, semantic_latent, T=T_step, use_ema=use_ema)
270
+ return semantic_latent, prediction
271
+
272
+ # Testing
273
+ def test_step(self, batch, batch_idx):
274
+ emb, recon = self.inference_pass(batch['inp']['data'], T_inv=self.conf.T_inv, T_step=self.conf.T_step, use_ema=self.config.test_ema)
275
+
276
+ emb = emb.detach().cpu()
277
+ emb = emb.numpy() if emb.dtype not in {torch.bfloat16, torch.float16} else emb.to(dtype=torch.float32).numpy()
278
+
279
+ return emb, recon
280
+
281
+ #Prediction
282
+ def predict_step(self, batch, batch_idx):
283
+ emb = self.encode(batch['inp']['data']).detach().cpu()
284
+ return emb.numpy() if emb.dtype not in {torch.bfloat16, torch.float16} else emb.to(dtype=torch.float32).numpy()
285
+
286
+ def configure_optimizers(self):
287
+ if self.conf.optimizer == OptimizerType.adam:
288
+ optim = torch.optim.Adam(self.net.parameters(),
289
+ lr=self.conf.lr,
290
+ weight_decay=self.conf.weight_decay)
291
+ elif self.conf.optimizer == OptimizerType.adamw:
292
+ optim = torch.optim.AdamW(self.net.parameters(),
293
+ lr=self.conf.lr,
294
+ weight_decay=self.conf.weight_decay)
295
+ else:
296
+ raise NotImplementedError()
297
+ out = {'optimizer': optim}
298
+ if self.conf.warmup > 0:
299
+ sched = torch.optim.lr_scheduler.LambdaLR(optim,
300
+ lr_lambda=WarmupLR(
301
+ self.conf.warmup))
302
+ out['lr_scheduler'] = {
303
+ 'scheduler': sched,
304
+ 'interval': 'step',
305
+ }
306
+ return out
307
+
308
+ def split_tensor(self, x):
309
+ """
310
+ extract the tensor for a corresponding "worker" in the batch dimension
311
+
312
+ Args:
313
+ x: (n, c)
314
+
315
+ Returns: x: (n_local, c)
316
+ """
317
+ n = len(x)
318
+ rank = self.global_rank
319
+ world_size = get_world_size()
320
+ # print(f'rank: {rank}/{world_size}')
321
+ per_rank = n // world_size
322
+ return x[rank * per_rank:(rank + 1) * per_rank]
DiffAEConfig.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class DiffAEConfig(PretrainedConfig):
4
+ model_type = "DiffAE"
5
+ def __init__(self,
6
+ is3D=True,
7
+ in_channels=1,
8
+ out_channels=1,
9
+ latent_dim=128,
10
+ net_ch=32,
11
+ sample_every_batches=1000, #log samples during training. Set it to 0 to disable
12
+ sample_size=4, #Number of samples in the buffer for consistent sampling (batch size of x_T)
13
+ test_with_TEval=True,
14
+ ampmode="16-mixed",
15
+ grey2RGB=-1,
16
+ test_emb_only=True,
17
+ test_ema=True,
18
+ batch_size=9,
19
+ # beta_scheduler='linear',
20
+ # latent_beta_scheduler='linear',
21
+ data_name="ukbb",
22
+ diffusion_type = 'beatgans',
23
+ # eval_ema_every_samples = 200_000,
24
+ # eval_every_samples = 200_000,
25
+ lr=0.0001,
26
+ # net_beatgans_attn_head = 1,
27
+ # net_beatgans_embed_channels = 128,
28
+ # net_ch_mult = (1, 1, 2, 3, 4),
29
+ # T_eval = 20,
30
+ # latent_T_eval=1000,
31
+ # group_norm_limit=32,
32
+ seed=1701,
33
+ input_shape=(50, 128, 128),
34
+ # dropout=0.1,
35
+ **kwargs):
36
+ self.is3D = is3D
37
+ self.in_channels = in_channels
38
+ self.out_channels = out_channels
39
+ self.latent_dim = latent_dim
40
+ self.net_ch = net_ch
41
+ self.sample_every_batches = sample_every_batches
42
+ self.sample_size = sample_size
43
+ self.test_with_TEval = test_with_TEval
44
+ self.ampmode = ampmode
45
+ self.grey2RGB = grey2RGB
46
+ self.test_emb_only = test_emb_only
47
+ self.test_ema = test_ema
48
+ self.batch_size = batch_size
49
+ self.data_name = data_name
50
+ self.diffusion_type = diffusion_type
51
+ self.lr = lr
52
+ self.seed = seed
53
+ self.input_shape = input_shape
54
+ super().__init__(**kwargs)
DiffAE_diffusion_base.py ADDED
@@ -0,0 +1,1109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This code started out as a PyTorch port of Ho et al's diffusion models:
3
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
4
+
5
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
6
+ """
7
+
8
+ from .DiffAE_model_unet_autoenc import AutoencReturn
9
+ from .DiffAE_support_config_base import BaseConfig
10
+ import enum
11
+ import math
12
+
13
+ import numpy as np
14
+ import torch as th
15
+ from .DiffAE_model import *
16
+ from .DiffAE_model_nn import mean_flat
17
+ from typing import NamedTuple, Tuple
18
+ from .DiffAE_support_choices import *
19
+ from torch.cuda.amp import autocast
20
+ import torch.nn.functional as F
21
+
22
+ from dataclasses import dataclass
23
+
24
+
25
+ @dataclass
26
+ class GaussianDiffusionBeatGansConfig(BaseConfig):
27
+ gen_type: GenerativeType
28
+ betas: Tuple[float]
29
+ model_type: ModelType
30
+ model_mean_type: ModelMeanType
31
+ model_var_type: ModelVarType
32
+ loss_type: LossType
33
+ rescale_timesteps: bool
34
+ fp16: bool
35
+ train_pred_xstart_detach: bool = True
36
+
37
+ def make_sampler(self):
38
+ return GaussianDiffusionBeatGans(self)
39
+
40
+
41
+ class GaussianDiffusionBeatGans:
42
+ """
43
+ Utilities for training and sampling diffusion models.
44
+
45
+ Ported directly from here, and then adapted over time to further experimentation.
46
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
47
+
48
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
49
+ starting at T and going to 1.
50
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
51
+ :param model_var_type: a ModelVarType determining how variance is output.
52
+ :param loss_type: a LossType determining the loss function to use.
53
+ :param rescale_timesteps: if True, pass floating point timesteps into the
54
+ model so that they are always scaled like in the
55
+ original paper (0 to 1000).
56
+ """
57
+ def __init__(self, conf: GaussianDiffusionBeatGansConfig):
58
+ self.conf = conf
59
+ self.model_mean_type = conf.model_mean_type
60
+ self.model_var_type = conf.model_var_type
61
+ self.loss_type = conf.loss_type
62
+ self.rescale_timesteps = conf.rescale_timesteps
63
+
64
+ # Use float64 for accuracy.
65
+ betas = np.array(conf.betas, dtype=np.float64)
66
+ self.betas = betas
67
+ assert len(betas.shape) == 1, "betas must be 1-D"
68
+ assert (betas > 0).all() and (betas <= 1).all()
69
+
70
+ self.num_timesteps = int(betas.shape[0])
71
+
72
+ alphas = 1.0 - betas
73
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
74
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
75
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
76
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps, )
77
+
78
+ # calculations for diffusion q(x_t | x_{t-1}) and others
79
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
80
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
81
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
82
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
83
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod -
84
+ 1)
85
+
86
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
87
+ self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) /
88
+ (1.0 - self.alphas_cumprod))
89
+ # log calculation clipped because the posterior variance is 0 at the
90
+ # beginning of the diffusion chain.
91
+ self.posterior_log_variance_clipped = np.log(
92
+ np.append(self.posterior_variance[1], self.posterior_variance[1:]))
93
+ self.posterior_mean_coef1 = (betas *
94
+ np.sqrt(self.alphas_cumprod_prev) /
95
+ (1.0 - self.alphas_cumprod))
96
+ self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) *
97
+ np.sqrt(alphas) /
98
+ (1.0 - self.alphas_cumprod))
99
+
100
+ def training_losses(self,
101
+ model: Model,
102
+ x_start: th.Tensor,
103
+ t: th.Tensor,
104
+ model_kwargs=None,
105
+ noise: th.Tensor = None):
106
+ """
107
+ Compute training losses for a single timestep.
108
+
109
+ :param model: the model to evaluate loss on.
110
+ :param x_start: the [N x C x ...] tensor of inputs.
111
+ :param t: a batch of timestep indices.
112
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
113
+ pass to the model. This can be used for conditioning.
114
+ :param noise: if specified, the specific Gaussian noise to try to remove.
115
+ :return: a dict with the key "loss" containing a tensor of shape [N].
116
+ Some mean or variance settings may also have other keys.
117
+ """
118
+ if model_kwargs is None:
119
+ model_kwargs = {}
120
+ if noise is None:
121
+ noise = th.randn_like(x_start)
122
+
123
+ x_t = self.q_sample(x_start, t, noise=noise)
124
+
125
+ terms = {'x_t': x_t}
126
+
127
+ if self.loss_type in [
128
+ LossType.mse,
129
+ LossType.l1,
130
+ ]:
131
+ with autocast(self.conf.fp16):
132
+ # x_t is static wrt. to the diffusion process
133
+ model_forward = model.forward(x=x_t.detach(),
134
+ t=self._scale_timesteps(t),
135
+ x_start=x_start.detach(),
136
+ **model_kwargs)
137
+ model_output = model_forward.pred
138
+
139
+ _model_output = model_output
140
+ if self.conf.train_pred_xstart_detach:
141
+ _model_output = _model_output.detach()
142
+ # get the pred xstart
143
+ p_mean_var = self.p_mean_variance(
144
+ model=DummyModel(pred=_model_output),
145
+ # gradient goes through x_t
146
+ x=x_t,
147
+ t=t,
148
+ clip_denoised=False)
149
+ terms['pred_xstart'] = p_mean_var['pred_xstart']
150
+
151
+ # model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
152
+
153
+ target_types = {
154
+ ModelMeanType.eps: noise,
155
+ }
156
+ target = target_types[self.model_mean_type]
157
+ assert model_output.shape == target.shape == x_start.shape
158
+
159
+ if self.loss_type == LossType.mse:
160
+ if self.model_mean_type == ModelMeanType.eps:
161
+ # (n, c, h, w) => (n, )
162
+ terms["mse"] = mean_flat((target - model_output)**2)
163
+ else:
164
+ raise NotImplementedError()
165
+ elif self.loss_type == LossType.l1:
166
+ # (n, c, h, w) => (n, )
167
+ terms["mse"] = mean_flat((target - model_output).abs())
168
+ else:
169
+ raise NotImplementedError()
170
+
171
+ if "vb" in terms:
172
+ # if learning the variance also use the vlb loss
173
+ terms["loss"] = terms["mse"] + terms["vb"]
174
+ else:
175
+ terms["loss"] = terms["mse"]
176
+ else:
177
+ raise NotImplementedError(self.loss_type)
178
+
179
+ return terms
180
+
181
+ def sample(self,
182
+ model: Model,
183
+ shape=None,
184
+ noise=None,
185
+ cond=None,
186
+ x_start=None,
187
+ clip_denoised=True,
188
+ model_kwargs=None,
189
+ progress=False):
190
+ """
191
+ Args:
192
+ x_start: given for the autoencoder
193
+ """
194
+ if model_kwargs is None:
195
+ model_kwargs = {}
196
+ if self.conf.model_type.has_autoenc():
197
+ model_kwargs['x_start'] = x_start
198
+ model_kwargs['cond'] = cond
199
+
200
+ if self.conf.gen_type == GenerativeType.ddpm:
201
+ return self.p_sample_loop(model,
202
+ shape=shape,
203
+ noise=noise,
204
+ clip_denoised=clip_denoised,
205
+ model_kwargs=model_kwargs,
206
+ progress=progress)
207
+ elif self.conf.gen_type == GenerativeType.ddim:
208
+ return self.ddim_sample_loop(model,
209
+ shape=shape,
210
+ noise=noise,
211
+ clip_denoised=clip_denoised,
212
+ model_kwargs=model_kwargs,
213
+ progress=progress)
214
+ else:
215
+ raise NotImplementedError()
216
+
217
+ def q_mean_variance(self, x_start, t):
218
+ """
219
+ Get the distribution q(x_t | x_0).
220
+
221
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
222
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
223
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
224
+ """
225
+ mean = (
226
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
227
+ x_start)
228
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t,
229
+ x_start.shape)
230
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod,
231
+ t, x_start.shape)
232
+ return mean, variance, log_variance
233
+
234
+ def q_sample(self, x_start, t, noise=None):
235
+ """
236
+ Diffuse the data for a given number of diffusion steps.
237
+
238
+ In other words, sample from q(x_t | x_0).
239
+
240
+ :param x_start: the initial data batch.
241
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
242
+ :param noise: if specified, the split-out normal noise.
243
+ :return: A noisy version of x_start.
244
+ """
245
+ if noise is None:
246
+ noise = th.randn_like(x_start)
247
+ assert noise.shape == x_start.shape
248
+ return (
249
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
250
+ x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
251
+ t, x_start.shape) * noise)
252
+
253
+ def q_posterior_mean_variance(self, x_start, x_t, t):
254
+ """
255
+ Compute the mean and variance of the diffusion posterior:
256
+
257
+ q(x_{t-1} | x_t, x_0)
258
+
259
+ """
260
+ assert x_start.shape == x_t.shape
261
+ posterior_mean = (
262
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) *
263
+ x_start +
264
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) *
265
+ x_t)
266
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t,
267
+ x_t.shape)
268
+ posterior_log_variance_clipped = _extract_into_tensor(
269
+ self.posterior_log_variance_clipped, t, x_t.shape)
270
+ assert (posterior_mean.shape[0] == posterior_variance.shape[0] ==
271
+ posterior_log_variance_clipped.shape[0] == x_start.shape[0])
272
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
273
+
274
+ def p_mean_variance(self,
275
+ model: Model,
276
+ x,
277
+ t,
278
+ clip_denoised=True,
279
+ denoised_fn=None,
280
+ model_kwargs=None):
281
+ """
282
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
283
+ the initial x, x_0.
284
+
285
+ :param model: the model, which takes a signal and a batch of timesteps
286
+ as input.
287
+ :param x: the [N x C x ...] tensor at time t.
288
+ :param t: a 1-D Tensor of timesteps.
289
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
290
+ :param denoised_fn: if not None, a function which applies to the
291
+ x_start prediction before it is used to sample. Applies before
292
+ clip_denoised.
293
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
294
+ pass to the model. This can be used for conditioning.
295
+ :return: a dict with the following keys:
296
+ - 'mean': the model mean output.
297
+ - 'variance': the model variance output.
298
+ - 'log_variance': the log of 'variance'.
299
+ - 'pred_xstart': the prediction for x_0.
300
+ """
301
+ if model_kwargs is None:
302
+ model_kwargs = {}
303
+
304
+ B, C = x.shape[:2]
305
+ assert t.shape == (B, )
306
+ with autocast(self.conf.fp16):
307
+ model_forward = model.forward(x=x,
308
+ t=self._scale_timesteps(t),
309
+ **model_kwargs)
310
+ model_output = model_forward.pred
311
+
312
+ if self.model_var_type in [
313
+ ModelVarType.fixed_large, ModelVarType.fixed_small
314
+ ]:
315
+ model_variance, model_log_variance = {
316
+ # for fixedlarge, we set the initial (log-)variance like so
317
+ # to get a better decoder log likelihood.
318
+ ModelVarType.fixed_large: (
319
+ np.append(self.posterior_variance[1], self.betas[1:]),
320
+ np.log(
321
+ np.append(self.posterior_variance[1], self.betas[1:])),
322
+ ),
323
+ ModelVarType.fixed_small: (
324
+ self.posterior_variance,
325
+ self.posterior_log_variance_clipped,
326
+ ),
327
+ }[self.model_var_type]
328
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
329
+ model_log_variance = _extract_into_tensor(model_log_variance, t,
330
+ x.shape)
331
+
332
+ def process_xstart(x):
333
+ if denoised_fn is not None:
334
+ x = denoised_fn(x)
335
+ if clip_denoised:
336
+ return x.clamp(-1, 1)
337
+ return x
338
+
339
+ if self.model_mean_type in [
340
+ ModelMeanType.eps,
341
+ ]:
342
+ if self.model_mean_type == ModelMeanType.eps:
343
+ pred_xstart = process_xstart(
344
+ self._predict_xstart_from_eps(x_t=x, t=t,
345
+ eps=model_output))
346
+ else:
347
+ raise NotImplementedError()
348
+ model_mean, _, _ = self.q_posterior_mean_variance(
349
+ x_start=pred_xstart, x_t=x, t=t)
350
+ else:
351
+ raise NotImplementedError(self.model_mean_type)
352
+
353
+ assert (model_mean.shape == model_log_variance.shape ==
354
+ pred_xstart.shape == x.shape)
355
+ return {
356
+ "mean": model_mean,
357
+ "variance": model_variance,
358
+ "log_variance": model_log_variance,
359
+ "pred_xstart": pred_xstart,
360
+ 'model_forward': model_forward,
361
+ }
362
+
363
+ def _predict_xstart_from_eps(self, x_t, t, eps):
364
+ assert x_t.shape == eps.shape
365
+ return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
366
+ x_t.shape) * x_t -
367
+ _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t,
368
+ x_t.shape) * eps)
369
+
370
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
371
+ assert x_t.shape == xprev.shape
372
+ return ( # (xprev - coef2*x_t) / coef1
373
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape)
374
+ * xprev - _extract_into_tensor(
375
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
376
+ x_t.shape) * x_t)
377
+
378
+ def _predict_xstart_from_scaled_xstart(self, t, scaled_xstart):
379
+ return scaled_xstart * _extract_into_tensor(
380
+ self.sqrt_recip_alphas_cumprod, t, scaled_xstart.shape)
381
+
382
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
383
+ return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
384
+ x_t.shape) * x_t -
385
+ pred_xstart) / _extract_into_tensor(
386
+ self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
387
+
388
+ def _predict_eps_from_scaled_xstart(self, x_t, t, scaled_xstart):
389
+ """
390
+ Args:
391
+ scaled_xstart: is supposed to be sqrt(alphacum) * x_0
392
+ """
393
+ # 1 / sqrt(1-alphabar) * (x_t - scaled xstart)
394
+ return (x_t - scaled_xstart) / _extract_into_tensor(
395
+ self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
396
+
397
+ def _scale_timesteps(self, t):
398
+ if self.rescale_timesteps:
399
+ # scale t to be maxed out at 1000 steps
400
+ return t.float() * (1000.0 / self.num_timesteps)
401
+ return t
402
+
403
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
404
+ """
405
+ Compute the mean for the previous step, given a function cond_fn that
406
+ computes the gradient of a conditional log probability with respect to
407
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
408
+ condition on y.
409
+
410
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
411
+ """
412
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
413
+ new_mean = (p_mean_var["mean"].float() +
414
+ p_mean_var["variance"] * gradient.float())
415
+ return new_mean
416
+
417
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
418
+ """
419
+ Compute what the p_mean_variance output would have been, should the
420
+ model's score function be conditioned by cond_fn.
421
+
422
+ See condition_mean() for details on cond_fn.
423
+
424
+ Unlike condition_mean(), this instead uses the conditioning strategy
425
+ from Song et al (2020).
426
+ """
427
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
428
+
429
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
430
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
431
+ x, self._scale_timesteps(t), **model_kwargs)
432
+
433
+ out = p_mean_var.copy()
434
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
435
+ out["mean"], _, _ = self.q_posterior_mean_variance(
436
+ x_start=out["pred_xstart"], x_t=x, t=t)
437
+ return out
438
+
439
+ def p_sample(
440
+ self,
441
+ model: Model,
442
+ x,
443
+ t,
444
+ clip_denoised=True,
445
+ denoised_fn=None,
446
+ cond_fn=None,
447
+ model_kwargs=None,
448
+ ):
449
+ """
450
+ Sample x_{t-1} from the model at the given timestep.
451
+
452
+ :param model: the model to sample from.
453
+ :param x: the current tensor at x_{t-1}.
454
+ :param t: the value of t, starting at 0 for the first diffusion step.
455
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
456
+ :param denoised_fn: if not None, a function which applies to the
457
+ x_start prediction before it is used to sample.
458
+ :param cond_fn: if not None, this is a gradient function that acts
459
+ similarly to the model.
460
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
461
+ pass to the model. This can be used for conditioning.
462
+ :return: a dict containing the following keys:
463
+ - 'sample': a random sample from the model.
464
+ - 'pred_xstart': a prediction of x_0.
465
+ """
466
+ out = self.p_mean_variance(
467
+ model,
468
+ x,
469
+ t,
470
+ clip_denoised=clip_denoised,
471
+ denoised_fn=denoised_fn,
472
+ model_kwargs=model_kwargs,
473
+ )
474
+ noise = th.randn_like(x)
475
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
476
+ ) # no noise when t == 0
477
+ if cond_fn is not None:
478
+ out["mean"] = self.condition_mean(cond_fn,
479
+ out,
480
+ x,
481
+ t,
482
+ model_kwargs=model_kwargs)
483
+ sample = out["mean"] + nonzero_mask * th.exp(
484
+ 0.5 * out["log_variance"]) * noise
485
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
486
+
487
+ def p_sample_loop(
488
+ self,
489
+ model: Model,
490
+ shape=None,
491
+ noise=None,
492
+ clip_denoised=True,
493
+ denoised_fn=None,
494
+ cond_fn=None,
495
+ model_kwargs=None,
496
+ device=None,
497
+ progress=False,
498
+ ):
499
+ """
500
+ Generate samples from the model.
501
+
502
+ :param model: the model module.
503
+ :param shape: the shape of the samples, (N, C, H, W).
504
+ :param noise: if specified, the noise from the encoder to sample.
505
+ Should be of the same shape as `shape`.
506
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
507
+ :param denoised_fn: if not None, a function which applies to the
508
+ x_start prediction before it is used to sample.
509
+ :param cond_fn: if not None, this is a gradient function that acts
510
+ similarly to the model.
511
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
512
+ pass to the model. This can be used for conditioning.
513
+ :param device: if specified, the device to create the samples on.
514
+ If not specified, use a model parameter's device.
515
+ :param progress: if True, show a tqdm progress bar.
516
+ :return: a non-differentiable batch of samples.
517
+ """
518
+ final = None
519
+ for sample in self.p_sample_loop_progressive(
520
+ model,
521
+ shape,
522
+ noise=noise,
523
+ clip_denoised=clip_denoised,
524
+ denoised_fn=denoised_fn,
525
+ cond_fn=cond_fn,
526
+ model_kwargs=model_kwargs,
527
+ device=device,
528
+ progress=progress,
529
+ ):
530
+ final = sample
531
+ return final["sample"]
532
+
533
+ def p_sample_loop_progressive(
534
+ self,
535
+ model: Model,
536
+ shape=None,
537
+ noise=None,
538
+ clip_denoised=True,
539
+ denoised_fn=None,
540
+ cond_fn=None,
541
+ model_kwargs=None,
542
+ device=None,
543
+ progress=False,
544
+ ):
545
+ """
546
+ Generate samples from the model and yield intermediate samples from
547
+ each timestep of diffusion.
548
+
549
+ Arguments are the same as p_sample_loop().
550
+ Returns a generator over dicts, where each dict is the return value of
551
+ p_sample().
552
+ """
553
+ if device is None:
554
+ device = next(model.parameters()).device
555
+ if noise is not None:
556
+ img = noise
557
+ else:
558
+ assert isinstance(shape, (tuple, list))
559
+ img = th.randn(*shape, device=device)
560
+ indices = list(range(self.num_timesteps))[::-1]
561
+
562
+ if progress:
563
+ # Lazy import so that we don't depend on tqdm.
564
+ from tqdm.auto import tqdm
565
+
566
+ indices = tqdm(indices)
567
+
568
+ for i in indices:
569
+ # t = th.tensor([i] * shape[0], device=device)
570
+ t = th.tensor([i] * len(img), device=device)
571
+ with th.no_grad():
572
+ out = self.p_sample(
573
+ model,
574
+ img,
575
+ t,
576
+ clip_denoised=clip_denoised,
577
+ denoised_fn=denoised_fn,
578
+ cond_fn=cond_fn,
579
+ model_kwargs=model_kwargs,
580
+ )
581
+ yield out
582
+ img = out["sample"]
583
+
584
+ def ddim_sample(
585
+ self,
586
+ model: Model,
587
+ x,
588
+ t,
589
+ clip_denoised=True,
590
+ denoised_fn=None,
591
+ cond_fn=None,
592
+ model_kwargs=None,
593
+ eta=0.0,
594
+ ):
595
+ """
596
+ Sample x_{t-1} from the model using DDIM.
597
+
598
+ Same usage as p_sample().
599
+ """
600
+ out = self.p_mean_variance(
601
+ model,
602
+ x,
603
+ t,
604
+ clip_denoised=clip_denoised,
605
+ denoised_fn=denoised_fn,
606
+ model_kwargs=model_kwargs,
607
+ )
608
+ if cond_fn is not None:
609
+ out = self.condition_score(cond_fn,
610
+ out,
611
+ x,
612
+ t,
613
+ model_kwargs=model_kwargs)
614
+
615
+ # Usually our model outputs epsilon, but we re-derive it
616
+ # in case we used x_start or x_prev prediction.
617
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
618
+
619
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
620
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t,
621
+ x.shape)
622
+ sigma = (eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) *
623
+ th.sqrt(1 - alpha_bar / alpha_bar_prev))
624
+ # Equation 12.
625
+ noise = th.randn_like(x)
626
+ mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_prev) +
627
+ th.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
628
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
629
+ ) # no noise when t == 0
630
+ sample = mean_pred + nonzero_mask * sigma * noise
631
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
632
+
633
+ def ddim_reverse_sample(
634
+ self,
635
+ model: Model,
636
+ x,
637
+ t,
638
+ clip_denoised=True,
639
+ denoised_fn=None,
640
+ model_kwargs=None,
641
+ eta=0.0,
642
+ ):
643
+ """
644
+ Sample x_{t+1} from the model using DDIM reverse ODE.
645
+ NOTE: never used ?
646
+ """
647
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
648
+ out = self.p_mean_variance(
649
+ model,
650
+ x,
651
+ t,
652
+ clip_denoised=clip_denoised,
653
+ denoised_fn=denoised_fn,
654
+ model_kwargs=model_kwargs,
655
+ )
656
+ # Usually our model outputs epsilon, but we re-derive it
657
+ # in case we used x_start or x_prev prediction.
658
+ eps = (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape)
659
+ * x - out["pred_xstart"]) / _extract_into_tensor(
660
+ self.sqrt_recipm1_alphas_cumprod, t, x.shape)
661
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t,
662
+ x.shape)
663
+
664
+ # Equation 12. reversed (DDIM paper) (th.sqrt == torch.sqrt)
665
+ mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_next) +
666
+ th.sqrt(1 - alpha_bar_next) * eps)
667
+
668
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
669
+
670
+ def ddim_reverse_sample_loop(
671
+ self,
672
+ model: Model,
673
+ x,
674
+ clip_denoised=True,
675
+ denoised_fn=None,
676
+ model_kwargs=None,
677
+ eta=0.0,
678
+ device=None,
679
+ ):
680
+ if device is None:
681
+ device = next(model.parameters()).device
682
+ sample_t = []
683
+ xstart_t = []
684
+ T = []
685
+ indices = list(range(self.num_timesteps))
686
+ sample = x
687
+ for i in indices:
688
+ t = th.tensor([i] * len(sample), device=device)
689
+ with th.no_grad():
690
+ out = self.ddim_reverse_sample(model,
691
+ sample,
692
+ t=t,
693
+ clip_denoised=clip_denoised,
694
+ denoised_fn=denoised_fn,
695
+ model_kwargs=model_kwargs,
696
+ eta=eta)
697
+ sample = out['sample']
698
+ # [1, ..., T]
699
+ sample_t.append(sample)
700
+ # [0, ...., T-1]
701
+ xstart_t.append(out['pred_xstart'])
702
+ # [0, ..., T-1] ready to use
703
+ T.append(t)
704
+
705
+ return {
706
+ # xT "
707
+ 'sample': sample,
708
+ # (1, ..., T)
709
+ 'sample_t': sample_t,
710
+ # xstart here is a bit different from sampling from T = T-1 to T = 0
711
+ # may not be exact
712
+ 'xstart_t': xstart_t,
713
+ 'T': T,
714
+ }
715
+
716
+ def ddim_sample_loop(
717
+ self,
718
+ model: Model,
719
+ shape=None,
720
+ noise=None,
721
+ clip_denoised=True,
722
+ denoised_fn=None,
723
+ cond_fn=None,
724
+ model_kwargs=None,
725
+ device=None,
726
+ progress=False,
727
+ eta=0.0,
728
+ ):
729
+ """
730
+ Generate samples from the model using DDIM.
731
+
732
+ Same usage as p_sample_loop().
733
+ """
734
+ final = None
735
+ for sample in self.ddim_sample_loop_progressive(
736
+ model,
737
+ shape,
738
+ noise=noise,
739
+ clip_denoised=clip_denoised,
740
+ denoised_fn=denoised_fn,
741
+ cond_fn=cond_fn,
742
+ model_kwargs=model_kwargs,
743
+ device=device,
744
+ progress=progress,
745
+ eta=eta,
746
+ ):
747
+ final = sample
748
+ return final["sample"]
749
+
750
+ def ddim_sample_loop_progressive(
751
+ self,
752
+ model: Model,
753
+ shape=None,
754
+ noise=None,
755
+ clip_denoised=True,
756
+ denoised_fn=None,
757
+ cond_fn=None,
758
+ model_kwargs=None,
759
+ device=None,
760
+ progress=False,
761
+ eta=0.0,
762
+ ):
763
+ """
764
+ Use DDIM to sample from the model and yield intermediate samples from
765
+ each timestep of DDIM.
766
+
767
+ Same usage as p_sample_loop_progressive().
768
+ """
769
+ if device is None:
770
+ device = next(model.parameters()).device
771
+ if noise is not None:
772
+ img = noise
773
+ else:
774
+ assert isinstance(shape, (tuple, list))
775
+ img = th.randn(*shape, device=device)
776
+ indices = list(range(self.num_timesteps))[::-1]
777
+
778
+ if progress:
779
+ # Lazy import so that we don't depend on tqdm.
780
+ from tqdm.auto import tqdm
781
+
782
+ indices = tqdm(indices)
783
+
784
+ for i in indices:
785
+
786
+ if isinstance(model_kwargs, list):
787
+ # index dependent model kwargs
788
+ # (T-1, ..., 0)
789
+ _kwargs = model_kwargs[i]
790
+ else:
791
+ _kwargs = model_kwargs
792
+
793
+ t = th.tensor([i] * len(img), device=device)
794
+ with th.no_grad():
795
+ out = self.ddim_sample(
796
+ model,
797
+ img,
798
+ t,
799
+ clip_denoised=clip_denoised,
800
+ denoised_fn=denoised_fn,
801
+ cond_fn=cond_fn,
802
+ model_kwargs=_kwargs,
803
+ eta=eta,
804
+ )
805
+ out['t'] = t
806
+ yield out
807
+ img = out["sample"]
808
+
809
+ def _vb_terms_bpd(self,
810
+ model: Model,
811
+ x_start,
812
+ x_t,
813
+ t,
814
+ clip_denoised=True,
815
+ model_kwargs=None):
816
+ """
817
+ Get a term for the variational lower-bound.
818
+
819
+ The resulting units are bits (rather than nats, as one might expect).
820
+ This allows for comparison to other papers.
821
+
822
+ :return: a dict with the following keys:
823
+ - 'output': a shape [N] tensor of NLLs or KLs.
824
+ - 'pred_xstart': the x_0 predictions.
825
+ """
826
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
827
+ x_start=x_start, x_t=x_t, t=t)
828
+ out = self.p_mean_variance(model,
829
+ x_t,
830
+ t,
831
+ clip_denoised=clip_denoised,
832
+ model_kwargs=model_kwargs)
833
+ kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"],
834
+ out["log_variance"])
835
+ kl = mean_flat(kl) / np.log(2.0)
836
+
837
+ decoder_nll = -discretized_gaussian_log_likelihood(
838
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"])
839
+ assert decoder_nll.shape == x_start.shape
840
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
841
+
842
+ # At the first timestep return the decoder NLL,
843
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
844
+ output = th.where((t == 0), decoder_nll, kl)
845
+ return {
846
+ "output": output,
847
+ "pred_xstart": out["pred_xstart"],
848
+ 'model_forward': out['model_forward'],
849
+ }
850
+
851
+ def _prior_bpd(self, x_start):
852
+ """
853
+ Get the prior KL term for the variational lower-bound, measured in
854
+ bits-per-dim.
855
+
856
+ This term can't be optimized, as it only depends on the encoder.
857
+
858
+ :param x_start: the [N x C x ...] tensor of inputs.
859
+ :return: a batch of [N] KL values (in bits), one per batch element.
860
+ """
861
+ batch_size = x_start.shape[0]
862
+ t = th.tensor([self.num_timesteps - 1] * batch_size,
863
+ device=x_start.device)
864
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
865
+ kl_prior = normal_kl(mean1=qt_mean,
866
+ logvar1=qt_log_variance,
867
+ mean2=0.0,
868
+ logvar2=0.0)
869
+ return mean_flat(kl_prior) / np.log(2.0)
870
+
871
+ def calc_bpd_loop(self,
872
+ model: Model,
873
+ x_start,
874
+ clip_denoised=True,
875
+ model_kwargs=None):
876
+ """
877
+ Compute the entire variational lower-bound, measured in bits-per-dim,
878
+ as well as other related quantities.
879
+
880
+ :param model: the model to evaluate loss on.
881
+ :param x_start: the [N x C x ...] tensor of inputs.
882
+ :param clip_denoised: if True, clip denoised samples.
883
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
884
+ pass to the model. This can be used for conditioning.
885
+
886
+ :return: a dict containing the following keys:
887
+ - total_bpd: the total variational lower-bound, per batch element.
888
+ - prior_bpd: the prior term in the lower-bound.
889
+ - vb: an [N x T] tensor of terms in the lower-bound.
890
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
891
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
892
+ """
893
+ device = x_start.device
894
+ batch_size = x_start.shape[0]
895
+
896
+ vb = []
897
+ xstart_mse = []
898
+ mse = []
899
+ for t in list(range(self.num_timesteps))[::-1]:
900
+ t_batch = th.tensor([t] * batch_size, device=device)
901
+ noise = th.randn_like(x_start)
902
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
903
+ # Calculate VLB term at the current timestep
904
+ with th.no_grad():
905
+ out = self._vb_terms_bpd(
906
+ model,
907
+ x_start=x_start,
908
+ x_t=x_t,
909
+ t=t_batch,
910
+ clip_denoised=clip_denoised,
911
+ model_kwargs=model_kwargs,
912
+ )
913
+ vb.append(out["output"])
914
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start)**2))
915
+ eps = self._predict_eps_from_xstart(x_t, t_batch,
916
+ out["pred_xstart"])
917
+ mse.append(mean_flat((eps - noise)**2))
918
+
919
+ vb = th.stack(vb, dim=1)
920
+ xstart_mse = th.stack(xstart_mse, dim=1)
921
+ mse = th.stack(mse, dim=1)
922
+
923
+ prior_bpd = self._prior_bpd(x_start)
924
+ total_bpd = vb.sum(dim=1) + prior_bpd
925
+ return {
926
+ "total_bpd": total_bpd,
927
+ "prior_bpd": prior_bpd,
928
+ "vb": vb,
929
+ "xstart_mse": xstart_mse,
930
+ "mse": mse,
931
+ }
932
+
933
+
934
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
935
+ """
936
+ Extract values from a 1-D numpy array for a batch of indices.
937
+
938
+ :param arr: the 1-D numpy array.
939
+ :param timesteps: a tensor of indices into the array to extract.
940
+ :param broadcast_shape: a larger shape of K dimensions with the batch
941
+ dimension equal to the length of timesteps.
942
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
943
+ """
944
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
945
+ while len(res.shape) < len(broadcast_shape):
946
+ res = res[..., None]
947
+ return res.expand(broadcast_shape)
948
+
949
+
950
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
951
+ """
952
+ Get a pre-defined beta schedule for the given name.
953
+
954
+ The beta schedule library consists of beta schedules which remain similar
955
+ in the limit of num_diffusion_timesteps.
956
+ Beta schedules may be added, but should not be removed or changed once
957
+ they are committed to maintain backwards compatibility.
958
+ """
959
+ if schedule_name == "linear":
960
+ # Linear schedule from Ho et al, extended to work for any number of
961
+ # diffusion steps.
962
+ scale = 1000 / num_diffusion_timesteps
963
+ beta_start = scale * 0.0001
964
+ beta_end = scale * 0.02
965
+ return np.linspace(beta_start,
966
+ beta_end,
967
+ num_diffusion_timesteps,
968
+ dtype=np.float64)
969
+ elif schedule_name == "cosine":
970
+ return betas_for_alpha_bar(
971
+ num_diffusion_timesteps,
972
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2,
973
+ )
974
+ elif schedule_name == "const0.01":
975
+ scale = 1000 / num_diffusion_timesteps
976
+ return np.array([scale * 0.01] * num_diffusion_timesteps,
977
+ dtype=np.float64)
978
+ elif schedule_name == "const0.015":
979
+ scale = 1000 / num_diffusion_timesteps
980
+ return np.array([scale * 0.015] * num_diffusion_timesteps,
981
+ dtype=np.float64)
982
+ elif schedule_name == "const0.008":
983
+ scale = 1000 / num_diffusion_timesteps
984
+ return np.array([scale * 0.008] * num_diffusion_timesteps,
985
+ dtype=np.float64)
986
+ elif schedule_name == "const0.0065":
987
+ scale = 1000 / num_diffusion_timesteps
988
+ return np.array([scale * 0.0065] * num_diffusion_timesteps,
989
+ dtype=np.float64)
990
+ elif schedule_name == "const0.0055":
991
+ scale = 1000 / num_diffusion_timesteps
992
+ return np.array([scale * 0.0055] * num_diffusion_timesteps,
993
+ dtype=np.float64)
994
+ elif schedule_name == "const0.0045":
995
+ scale = 1000 / num_diffusion_timesteps
996
+ return np.array([scale * 0.0045] * num_diffusion_timesteps,
997
+ dtype=np.float64)
998
+ elif schedule_name == "const0.0035":
999
+ scale = 1000 / num_diffusion_timesteps
1000
+ return np.array([scale * 0.0035] * num_diffusion_timesteps,
1001
+ dtype=np.float64)
1002
+ elif schedule_name == "const0.0025":
1003
+ scale = 1000 / num_diffusion_timesteps
1004
+ return np.array([scale * 0.0025] * num_diffusion_timesteps,
1005
+ dtype=np.float64)
1006
+ elif schedule_name == "const0.0015":
1007
+ scale = 1000 / num_diffusion_timesteps
1008
+ return np.array([scale * 0.0015] * num_diffusion_timesteps,
1009
+ dtype=np.float64)
1010
+ else:
1011
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
1012
+
1013
+
1014
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
1015
+ """
1016
+ Create a beta schedule that discretizes the given alpha_t_bar function,
1017
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
1018
+
1019
+ :param num_diffusion_timesteps: the number of betas to produce.
1020
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
1021
+ produces the cumulative product of (1-beta) up to that
1022
+ part of the diffusion process.
1023
+ :param max_beta: the maximum beta to use; use values lower than 1 to
1024
+ prevent singularities.
1025
+ """
1026
+ betas = []
1027
+ for i in range(num_diffusion_timesteps):
1028
+ t1 = i / num_diffusion_timesteps
1029
+ t2 = (i + 1) / num_diffusion_timesteps
1030
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
1031
+ return np.array(betas)
1032
+
1033
+
1034
+ def normal_kl(mean1, logvar1, mean2, logvar2):
1035
+ """
1036
+ Compute the KL divergence between two gaussians.
1037
+
1038
+ Shapes are automatically broadcasted, so batches can be compared to
1039
+ scalars, among other use cases.
1040
+ """
1041
+ tensor = None
1042
+ for obj in (mean1, logvar1, mean2, logvar2):
1043
+ if isinstance(obj, th.Tensor):
1044
+ tensor = obj
1045
+ break
1046
+ assert tensor is not None, "at least one argument must be a Tensor"
1047
+
1048
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
1049
+ # Tensors, but it does not work for th.exp().
1050
+ logvar1, logvar2 = [
1051
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
1052
+ for x in (logvar1, logvar2)
1053
+ ]
1054
+
1055
+ return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) +
1056
+ ((mean1 - mean2)**2) * th.exp(-logvar2))
1057
+
1058
+
1059
+ def approx_standard_normal_cdf(x):
1060
+ """
1061
+ A fast approximation of the cumulative distribution function of the
1062
+ standard normal.
1063
+ """
1064
+ return 0.5 * (
1065
+ 1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
1066
+
1067
+
1068
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
1069
+ """
1070
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
1071
+ given image.
1072
+
1073
+ :param x: the target images. It is assumed that this was uint8 values,
1074
+ rescaled to the range [-1, 1].
1075
+ :param means: the Gaussian mean Tensor.
1076
+ :param log_scales: the Gaussian log stddev Tensor.
1077
+ :return: a tensor like x of log probabilities (in nats).
1078
+ """
1079
+ assert x.shape == means.shape == log_scales.shape
1080
+ centered_x = x - means
1081
+ inv_stdv = th.exp(-log_scales)
1082
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
1083
+ cdf_plus = approx_standard_normal_cdf(plus_in)
1084
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
1085
+ cdf_min = approx_standard_normal_cdf(min_in)
1086
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
1087
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
1088
+ cdf_delta = cdf_plus - cdf_min
1089
+ log_probs = th.where(
1090
+ x < -0.999,
1091
+ log_cdf_plus,
1092
+ th.where(x > 0.999, log_one_minus_cdf_min,
1093
+ th.log(cdf_delta.clamp(min=1e-12))),
1094
+ )
1095
+ assert log_probs.shape == x.shape
1096
+ return log_probs
1097
+
1098
+
1099
+ class DummyModel(th.nn.Module):
1100
+ def __init__(self, pred):
1101
+ super().__init__()
1102
+ self.pred = pred
1103
+
1104
+ def forward(self, *args, **kwargs):
1105
+ return DummyReturn(pred=self.pred)
1106
+
1107
+
1108
+ class DummyReturn(NamedTuple):
1109
+ pred: th.Tensor
DiffAE_diffusion_diffusion.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .DiffAE_diffusion_base import *
2
+ from dataclasses import dataclass
3
+
4
+
5
+ def space_timesteps(num_timesteps, section_counts):
6
+ """
7
+ Create a list of timesteps to use from an original diffusion process,
8
+ given the number of timesteps we want to take from equally-sized portions
9
+ of the original process.
10
+
11
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
12
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
13
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
14
+
15
+ If the stride is a string starting with "ddim", then the fixed striding
16
+ from the DDIM paper is used, and only one section is allowed.
17
+
18
+ :param num_timesteps: the number of diffusion steps in the original
19
+ process to divide up.
20
+ :param section_counts: either a list of numbers, or a string containing
21
+ comma-separated numbers, indicating the step count
22
+ per section. As a special case, use "ddimN" where N
23
+ is a number of steps to use the striding from the
24
+ DDIM paper.
25
+ :return: a set of diffusion steps from the original process to use.
26
+ """
27
+ if isinstance(section_counts, str):
28
+ if section_counts.startswith("ddim"):
29
+ desired_count = int(section_counts[len("ddim"):])
30
+ for i in range(1, num_timesteps):
31
+ if len(range(0, num_timesteps, i)) == desired_count:
32
+ return set(range(0, num_timesteps, i))
33
+ raise ValueError(
34
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
35
+ )
36
+ section_counts = [int(x) for x in section_counts.split(",")]
37
+ size_per = num_timesteps // len(section_counts)
38
+ extra = num_timesteps % len(section_counts)
39
+ start_idx = 0
40
+ all_steps = []
41
+ for i, section_count in enumerate(section_counts):
42
+ size = size_per + (1 if i < extra else 0)
43
+ if size < section_count:
44
+ raise ValueError(
45
+ f"cannot divide section of {size} steps into {section_count}")
46
+ if section_count <= 1:
47
+ frac_stride = 1
48
+ else:
49
+ frac_stride = (size - 1) / (section_count - 1)
50
+ cur_idx = 0.0
51
+ taken_steps = []
52
+ for _ in range(section_count):
53
+ taken_steps.append(start_idx + round(cur_idx))
54
+ cur_idx += frac_stride
55
+ all_steps += taken_steps
56
+ start_idx += size
57
+ return set(all_steps)
58
+
59
+
60
+ @dataclass
61
+ class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
62
+ use_timesteps: Tuple[int] = None
63
+
64
+ def make_sampler(self):
65
+ return SpacedDiffusionBeatGans(self)
66
+
67
+
68
+ class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
69
+ """
70
+ A diffusion process which can skip steps in a base diffusion process.
71
+
72
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
73
+ original diffusion process to retain.
74
+ :param kwargs: the kwargs to create the base diffusion process.
75
+ """
76
+ def __init__(self, conf: SpacedDiffusionBeatGansConfig):
77
+ self.conf = conf
78
+ self.use_timesteps = set(conf.use_timesteps)
79
+ # how the new t's mapped to the old t's
80
+ self.timestep_map = []
81
+ self.original_num_steps = len(conf.betas)
82
+
83
+ base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa
84
+ last_alpha_cumprod = 1.0
85
+ new_betas = []
86
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
87
+ if i in self.use_timesteps:
88
+ # getting the new betas of the new timesteps
89
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
90
+ last_alpha_cumprod = alpha_cumprod
91
+ self.timestep_map.append(i)
92
+ conf.betas = np.array(new_betas)
93
+ super().__init__(conf)
94
+
95
+ def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
96
+ return super().p_mean_variance(self._wrap_model(model), *args,
97
+ **kwargs)
98
+
99
+ def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
100
+ return super().training_losses(self._wrap_model(model), *args,
101
+ **kwargs)
102
+
103
+ def condition_mean(self, cond_fn, *args, **kwargs):
104
+ return super().condition_mean(self._wrap_model(cond_fn), *args,
105
+ **kwargs)
106
+
107
+ def condition_score(self, cond_fn, *args, **kwargs):
108
+ return super().condition_score(self._wrap_model(cond_fn), *args,
109
+ **kwargs)
110
+
111
+ def _wrap_model(self, model: Model):
112
+ if isinstance(model, _WrappedModel):
113
+ return model
114
+ return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
115
+ self.original_num_steps)
116
+
117
+ def _scale_timesteps(self, t):
118
+ # Scaling is done by the wrapped model.
119
+ return t
120
+
121
+
122
+ class _WrappedModel:
123
+ """
124
+ converting the supplied t's to the old t's scales.
125
+ """
126
+ def __init__(self, model, timestep_map, rescale_timesteps,
127
+ original_num_steps):
128
+ self.model = model
129
+ self.timestep_map = timestep_map
130
+ self.rescale_timesteps = rescale_timesteps
131
+ self.original_num_steps = original_num_steps
132
+
133
+ def forward(self, x, t, t_cond=None, **kwargs):
134
+ """
135
+ Args:
136
+ t: t's with differrent ranges (can be << T due to smaller eval T) need to be converted to the original t's
137
+ t_cond: the same as t but can be of different values
138
+ """
139
+ map_tensor = th.tensor(self.timestep_map,
140
+ device=t.device,
141
+ dtype=t.dtype)
142
+
143
+ def do(t):
144
+ new_ts = map_tensor[t]
145
+ if self.rescale_timesteps:
146
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
147
+ return new_ts
148
+
149
+ if t_cond is not None:
150
+ # support t_cond
151
+ t_cond = do(t_cond)
152
+
153
+ return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)
154
+
155
+ def __getattr__(self, name):
156
+ # allow for calling the model's methods
157
+ if hasattr(self.model, name):
158
+ func = getattr(self.model, name)
159
+ return func
160
+ raise AttributeError(name)
DiffAE_diffusion_resample.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import torch as th
5
+ import torch.distributed as dist
6
+
7
+
8
+ def create_named_schedule_sampler(name, diffusion):
9
+ """
10
+ Create a ScheduleSampler from a library of pre-defined samplers.
11
+
12
+ :param name: the name of the sampler.
13
+ :param diffusion: the diffusion object to sample for.
14
+ """
15
+ if name == "uniform":
16
+ return UniformSampler(diffusion)
17
+ else:
18
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
19
+
20
+
21
+ class ScheduleSampler(ABC):
22
+ """
23
+ A distribution over timesteps in the diffusion process, intended to reduce
24
+ variance of the objective.
25
+
26
+ By default, samplers perform unbiased importance sampling, in which the
27
+ objective's mean is unchanged.
28
+ However, subclasses may override sample() to change how the resampled
29
+ terms are reweighted, allowing for actual changes in the objective.
30
+ """
31
+ @abstractmethod
32
+ def weights(self):
33
+ """
34
+ Get a numpy array of weights, one per diffusion step.
35
+
36
+ The weights needn't be normalized, but must be positive.
37
+ """
38
+
39
+ def sample(self, batch_size, device):
40
+ """
41
+ Importance-sample timesteps for a batch.
42
+
43
+ :param batch_size: the number of timesteps.
44
+ :param device: the torch device to save to.
45
+ :return: a tuple (timesteps, weights):
46
+ - timesteps: a tensor of timestep indices.
47
+ - weights: a tensor of weights to scale the resulting losses.
48
+ """
49
+ w = self.weights()
50
+ p = w / np.sum(w)
51
+ indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
52
+ indices = th.from_numpy(indices_np).long().to(device)
53
+ weights_np = 1 / (len(p) * p[indices_np])
54
+ weights = th.from_numpy(weights_np).float().to(device)
55
+ return indices, weights
56
+
57
+
58
+ class UniformSampler(ScheduleSampler):
59
+ def __init__(self, num_timesteps):
60
+ self._weights = np.ones([num_timesteps])
61
+
62
+ def weights(self):
63
+ return self._weights
DiffAE_model.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+ from .DiffAE_model_unet import BeatGANsUNetModel, BeatGANsUNetConfig
3
+ from .DiffAE_model_unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel
4
+ from .DiffAE_model_latentnet import *
5
+
6
+ Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
7
+ ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]
DiffAE_model_blocks.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ from abc import abstractmethod
4
+ from dataclasses import dataclass
5
+ from numbers import Number
6
+
7
+ import torch as th
8
+ import torch.nn.functional as F
9
+ from .DiffAE_support_choices import *
10
+ from .DiffAE_support_config_base import BaseConfig
11
+ from torch import nn
12
+
13
+ from .DiffAE_model_nn import (avg_pool_nd, conv_nd, linear, normalization,
14
+ timestep_embedding, torch_checkpoint, zero_module)
15
+
16
+ class ScaleAt(Enum):
17
+ after_norm = 'afternorm'
18
+
19
+
20
+ class TimestepBlock(nn.Module):
21
+ """
22
+ Any module where forward() takes timestep embeddings as a second argument.
23
+ """
24
+ @abstractmethod
25
+ def forward(self, x, emb=None, cond=None, lateral=None):
26
+ """
27
+ Apply the module to `x` given `emb` timestep embeddings.
28
+ """
29
+
30
+
31
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
32
+ """
33
+ A sequential module that passes timestep embeddings to the children that
34
+ support it as an extra input.
35
+ """
36
+ def forward(self, x, emb=None, cond=None, lateral=None):
37
+ for layer in self:
38
+ if isinstance(layer, TimestepBlock):
39
+ x = layer(x, emb=emb, cond=cond, lateral=lateral)
40
+ else:
41
+ x = layer(x)
42
+ return x
43
+
44
+
45
+ @dataclass
46
+ class ResBlockConfig(BaseConfig):
47
+ channels: int
48
+ emb_channels: int
49
+ dropout: float
50
+ out_channels: int = None
51
+ # condition the resblock with time (and encoder's output)
52
+ use_condition: bool = True
53
+ # whether to use 3x3 conv for skip path when the channels aren't matched
54
+ use_conv: bool = False
55
+ group_norm_limit: int = 32
56
+ # dimension of conv (always 2 = 2d)
57
+ dims: int = 2
58
+ # gradient checkpoint
59
+ use_checkpoint: bool = False
60
+ up: bool = False
61
+ down: bool = False
62
+ # whether to condition with both time & encoder's output
63
+ two_cond: bool = False
64
+ # number of encoders' output channels
65
+ cond_emb_channels: int = None
66
+ # suggest: False
67
+ has_lateral: bool = False
68
+ lateral_channels: int = None
69
+ # whether to init the convolution with zero weights
70
+ # this is default from BeatGANs and seems to help learning
71
+ use_zero_module: bool = True
72
+
73
+ def __post_init__(self):
74
+ self.out_channels = self.out_channels or self.channels
75
+ self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
76
+
77
+ def make_model(self):
78
+ return ResBlock(self)
79
+
80
+
81
+ class ResBlock(TimestepBlock):
82
+ """
83
+ A residual block that can optionally change the number of channels.
84
+
85
+ total layers:
86
+ in_layers
87
+ - norm
88
+ - act
89
+ - conv
90
+ out_layers
91
+ - norm
92
+ - (modulation)
93
+ - act
94
+ - conv
95
+ """
96
+ def __init__(self, conf: ResBlockConfig):
97
+ super().__init__()
98
+ self.conf = conf
99
+
100
+ #############################
101
+ # IN LAYERS
102
+ #############################
103
+ assert conf.lateral_channels is None
104
+ layers = [
105
+ normalization(conf.channels, limit=conf.group_norm_limit if "group_norm_limit" in conf.__dict__ else 32),
106
+ nn.SiLU(),
107
+ conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1)
108
+ ]
109
+ self.in_layers = nn.Sequential(*layers)
110
+
111
+ self.updown = conf.up or conf.down
112
+
113
+ if conf.up:
114
+ self.h_upd = Upsample(conf.channels, False, conf.dims)
115
+ self.x_upd = Upsample(conf.channels, False, conf.dims)
116
+ elif conf.down:
117
+ self.h_upd = Downsample(conf.channels, False, conf.dims)
118
+ self.x_upd = Downsample(conf.channels, False, conf.dims)
119
+ else:
120
+ self.h_upd = self.x_upd = nn.Identity()
121
+
122
+ #############################
123
+ # OUT LAYERS CONDITIONS
124
+ #############################
125
+ if conf.use_condition:
126
+ # condition layers for the out_layers
127
+ self.emb_layers = nn.Sequential(
128
+ nn.SiLU(),
129
+ linear(conf.emb_channels, 2 * conf.out_channels),
130
+ )
131
+
132
+ if conf.two_cond:
133
+ self.cond_emb_layers = nn.Sequential(
134
+ nn.SiLU(),
135
+ linear(conf.cond_emb_channels, conf.out_channels),
136
+ )
137
+ #############################
138
+ # OUT LAYERS (ignored when there is no condition)
139
+ #############################
140
+ # original version
141
+ conv = conv_nd(conf.dims,
142
+ conf.out_channels,
143
+ conf.out_channels,
144
+ 3,
145
+ padding=1)
146
+ if conf.use_zero_module:
147
+ # zere out the weights
148
+ # it seems to help training
149
+ conv = zero_module(conv)
150
+
151
+ # construct the layers
152
+ # - norm
153
+ # - (modulation)
154
+ # - act
155
+ # - dropout
156
+ # - conv
157
+ layers = []
158
+ layers += [
159
+ normalization(conf.out_channels, limit=conf.group_norm_limit if "group_norm_limit" in conf.__dict__ else 32),
160
+ nn.SiLU(),
161
+ nn.Dropout(p=conf.dropout),
162
+ conv,
163
+ ]
164
+ self.out_layers = nn.Sequential(*layers)
165
+
166
+ #############################
167
+ # SKIP LAYERS
168
+ #############################
169
+ if conf.out_channels == conf.channels:
170
+ # cannot be used with gatedconv, also gatedconv is alsways used as the first block
171
+ self.skip_connection = nn.Identity()
172
+ else:
173
+ if conf.use_conv:
174
+ kernel_size = 3
175
+ padding = 1
176
+ else:
177
+ kernel_size = 1
178
+ padding = 0
179
+
180
+ self.skip_connection = conv_nd(conf.dims,
181
+ conf.channels,
182
+ conf.out_channels,
183
+ kernel_size,
184
+ padding=padding)
185
+
186
+ def forward(self, x, emb=None, cond=None, lateral=None):
187
+ """
188
+ Apply the block to a Tensor, conditioned on a timestep embedding.
189
+
190
+ Args:
191
+ x: input
192
+ lateral: lateral connection from the encoder
193
+ """
194
+ return torch_checkpoint(self._forward, (x, emb, cond, lateral),
195
+ self.conf.use_checkpoint)
196
+
197
+ def _forward(
198
+ self,
199
+ x,
200
+ emb=None,
201
+ cond=None,
202
+ lateral=None,
203
+ ):
204
+ """
205
+ Args:
206
+ lateral: required if "has_lateral" and non-gated, with gated, it can be supplied optionally
207
+ """
208
+ if self.conf.has_lateral:
209
+ # lateral may be supplied even if it doesn't require
210
+ # the model will take the lateral only if "has_lateral"
211
+ assert lateral is not None
212
+ x = th.cat([x, lateral], dim=1)
213
+
214
+ if self.updown:
215
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
216
+ h = in_rest(x)
217
+ h = self.h_upd(h)
218
+ x = self.x_upd(x)
219
+ h = in_conv(h)
220
+ else:
221
+ h = self.in_layers(x)
222
+
223
+ if self.conf.use_condition:
224
+ # it's possible that the network may not receieve the time emb
225
+ # this happens with autoenc and setting the time_at
226
+ if emb is not None:
227
+ emb_out = self.emb_layers(emb).type(h.dtype)
228
+ else:
229
+ emb_out = None
230
+
231
+ if self.conf.two_cond:
232
+ # it's possible that the network is two_cond
233
+ # but it doesn't get the second condition
234
+ # in which case, we ignore the second condition
235
+ # and treat as if the network has one condition
236
+ if cond is None:
237
+ cond_out = None
238
+ else:
239
+ cond_out = self.cond_emb_layers(cond).type(h.dtype)
240
+
241
+ if cond_out is not None:
242
+ while len(cond_out.shape) < len(h.shape):
243
+ cond_out = cond_out[..., None]
244
+ else:
245
+ cond_out = None
246
+
247
+ # this is the new refactored code
248
+ h = apply_conditions(
249
+ h=h,
250
+ emb=emb_out,
251
+ cond=cond_out,
252
+ layers=self.out_layers,
253
+ scale_bias=1,
254
+ in_channels=self.conf.out_channels,
255
+ up_down_layer=None,
256
+ )
257
+
258
+ return self.skip_connection(x) + h
259
+
260
+
261
+ def apply_conditions(
262
+ h,
263
+ emb=None,
264
+ cond=None,
265
+ layers: nn.Sequential = None,
266
+ scale_bias: float = 1,
267
+ in_channels: int = 512,
268
+ up_down_layer: nn.Module = None,
269
+ ):
270
+ """
271
+ apply conditions on the feature maps
272
+
273
+ Args:
274
+ emb: time conditional (ready to scale + shift)
275
+ cond: encoder's conditional (read to scale + shift)
276
+ """
277
+ two_cond = emb is not None and cond is not None
278
+
279
+ if emb is not None:
280
+ # adjusting shapes
281
+ while len(emb.shape) < len(h.shape):
282
+ emb = emb[..., None]
283
+
284
+ if two_cond:
285
+ # adjusting shapes
286
+ while len(cond.shape) < len(h.shape):
287
+ cond = cond[..., None]
288
+ # time first
289
+ scale_shifts = [emb, cond]
290
+ else:
291
+ # "cond" is not used with single cond mode
292
+ scale_shifts = [emb]
293
+
294
+ # support scale, shift or shift only
295
+ for i, each in enumerate(scale_shifts):
296
+ if each is None:
297
+ # special case: the condition is not provided
298
+ a = None
299
+ b = None
300
+ else:
301
+ if each.shape[1] == in_channels * 2:
302
+ a, b = th.chunk(each, 2, dim=1)
303
+ else:
304
+ a = each
305
+ b = None
306
+ scale_shifts[i] = (a, b)
307
+
308
+ # condition scale bias could be a list
309
+ if isinstance(scale_bias, Number):
310
+ biases = [scale_bias] * len(scale_shifts)
311
+ else:
312
+ # a list
313
+ biases = scale_bias
314
+
315
+ # default, the scale & shift are applied after the group norm but BEFORE SiLU
316
+ pre_layers, post_layers = layers[0], layers[1:]
317
+
318
+ # spilt the post layer to be able to scale up or down before conv
319
+ # post layers will contain only the conv
320
+ mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
321
+
322
+ h = pre_layers(h)
323
+ # scale and shift for each condition
324
+ for i, (scale, shift) in enumerate(scale_shifts):
325
+ # if scale is None, it indicates that the condition is not provided
326
+ if scale is not None:
327
+ h = h * (biases[i] + scale)
328
+ if shift is not None:
329
+ h = h + shift
330
+ h = mid_layers(h)
331
+
332
+ # upscale or downscale if any just before the last conv
333
+ if up_down_layer is not None:
334
+ h = up_down_layer(h)
335
+ h = post_layers(h)
336
+ return h
337
+
338
+
339
+ class Upsample(nn.Module):
340
+ """
341
+ An upsampling layer with an optional convolution.
342
+
343
+ :param channels: channels in the inputs and outputs.
344
+ :param use_conv: a bool determining if a convolution is applied.
345
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
346
+ upsampling occurs in the inner-two dimensions.
347
+ """
348
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
349
+ super().__init__()
350
+ self.channels = channels
351
+ self.out_channels = out_channels or channels
352
+ self.use_conv = use_conv
353
+ self.dims = dims
354
+ if use_conv:
355
+ self.conv = conv_nd(dims,
356
+ self.channels,
357
+ self.out_channels,
358
+ 3,
359
+ padding=1)
360
+
361
+ def forward(self, x):
362
+ assert x.shape[1] == self.channels
363
+ if self.dims == 3:
364
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
365
+ mode="nearest")
366
+ else:
367
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
368
+ if self.use_conv:
369
+ x = self.conv(x)
370
+ return x
371
+
372
+
373
+ class Downsample(nn.Module):
374
+ """
375
+ A downsampling layer with an optional convolution.
376
+
377
+ :param channels: channels in the inputs and outputs.
378
+ :param use_conv: a bool determining if a convolution is applied.
379
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
380
+ downsampling occurs in the inner-two dimensions.
381
+ """
382
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
383
+ super().__init__()
384
+ self.channels = channels
385
+ self.out_channels = out_channels or channels
386
+ self.use_conv = use_conv
387
+ self.dims = dims
388
+ stride = 2 if dims != 3 else (1, 2, 2)
389
+ if use_conv:
390
+ self.op = conv_nd(dims,
391
+ self.channels,
392
+ self.out_channels,
393
+ 3,
394
+ stride=stride,
395
+ padding=1)
396
+ else:
397
+ assert self.channels == self.out_channels
398
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
399
+
400
+ def forward(self, x):
401
+ assert x.shape[1] == self.channels
402
+ return self.op(x)
403
+
404
+
405
+ class AttentionBlock(nn.Module):
406
+ """
407
+ An attention block that allows spatial positions to attend to each other.
408
+
409
+ Originally ported from here, but adapted to the N-d case.
410
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
411
+ """
412
+ def __init__(
413
+ self,
414
+ channels,
415
+ num_heads=1,
416
+ num_head_channels=-1,
417
+ group_norm_limit=32,
418
+ use_checkpoint=False,
419
+ use_new_attention_order=False,
420
+ ):
421
+ super().__init__()
422
+ self.channels = channels
423
+ if num_head_channels == -1:
424
+ self.num_heads = num_heads
425
+ else:
426
+ assert (
427
+ channels % num_head_channels == 0
428
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
429
+ self.num_heads = channels // num_head_channels
430
+ self.use_checkpoint = use_checkpoint
431
+ self.norm = normalization(channels, limit=group_norm_limit)
432
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
433
+ if use_new_attention_order:
434
+ # split qkv before split heads
435
+ self.attention = QKVAttention(self.num_heads)
436
+ else:
437
+ # split heads before split qkv
438
+ self.attention = QKVAttentionLegacy(self.num_heads)
439
+
440
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
441
+
442
+ def forward(self, x):
443
+ return torch_checkpoint(self._forward, (x, ), self.use_checkpoint)
444
+
445
+ def _forward(self, x):
446
+ b, c, *spatial = x.shape
447
+ x = x.reshape(b, c, -1)
448
+ qkv = self.qkv(self.norm(x))
449
+ h = self.attention(qkv)
450
+ h = self.proj_out(h)
451
+ return (x + h).reshape(b, c, *spatial)
452
+
453
+
454
+ def count_flops_attn(model, _x, y):
455
+ """
456
+ A counter for the `thop` package to count the operations in an
457
+ attention operation.
458
+ Meant to be used like:
459
+ macs, params = thop.profile(
460
+ model,
461
+ inputs=(inputs, timestamps),
462
+ custom_ops={QKVAttention: QKVAttention.count_flops},
463
+ )
464
+ """
465
+ b, c, *spatial = y[0].shape
466
+ num_spatial = int(np.prod(spatial))
467
+ # We perform two matmuls with the same number of ops.
468
+ # The first computes the weight matrix, the second computes
469
+ # the combination of the value vectors.
470
+ matmul_ops = 2 * b * (num_spatial**2) * c
471
+ model.total_ops += th.DoubleTensor([matmul_ops])
472
+
473
+
474
+ class QKVAttentionLegacy(nn.Module):
475
+ """
476
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
477
+ """
478
+ def __init__(self, n_heads):
479
+ super().__init__()
480
+ self.n_heads = n_heads
481
+
482
+ def forward(self, qkv):
483
+ """
484
+ Apply QKV attention.
485
+
486
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
487
+ :return: an [N x (H * C) x T] tensor after attention.
488
+ """
489
+ bs, width, length = qkv.shape
490
+ assert width % (3 * self.n_heads) == 0
491
+ ch = width // (3 * self.n_heads)
492
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
493
+ dim=1)
494
+ scale = 1 / math.sqrt(math.sqrt(ch))
495
+ weight = th.einsum(
496
+ "bct,bcs->bts", q * scale,
497
+ k * scale) # More stable with f16 than dividing afterwards
498
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
499
+ a = th.einsum("bts,bcs->bct", weight, v)
500
+ return a.reshape(bs, -1, length)
501
+
502
+ @staticmethod
503
+ def count_flops(model, _x, y):
504
+ return count_flops_attn(model, _x, y)
505
+
506
+
507
+ class QKVAttention(nn.Module):
508
+ """
509
+ A module which performs QKV attention and splits in a different order.
510
+ """
511
+ def __init__(self, n_heads):
512
+ super().__init__()
513
+ self.n_heads = n_heads
514
+
515
+ def forward(self, qkv):
516
+ """
517
+ Apply QKV attention.
518
+
519
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
520
+ :return: an [N x (H * C) x T] tensor after attention.
521
+ """
522
+ bs, width, length = qkv.shape
523
+ assert width % (3 * self.n_heads) == 0
524
+ ch = width // (3 * self.n_heads)
525
+ q, k, v = qkv.chunk(3, dim=1)
526
+ scale = 1 / math.sqrt(math.sqrt(ch))
527
+ weight = th.einsum(
528
+ "bct,bcs->bts",
529
+ (q * scale).view(bs * self.n_heads, ch, length),
530
+ (k * scale).view(bs * self.n_heads, ch, length),
531
+ ) # More stable with f16 than dividing afterwards
532
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
533
+ a = th.einsum("bts,bcs->bct", weight,
534
+ v.reshape(bs * self.n_heads, ch, length))
535
+ return a.reshape(bs, -1, length)
536
+
537
+ @staticmethod
538
+ def count_flops(model, _x, y):
539
+ return count_flops_attn(model, _x, y)
540
+
541
+
542
+ class AttentionPool2d(nn.Module):
543
+ """
544
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
545
+ """
546
+ def __init__(
547
+ self,
548
+ spacial_dim: int,
549
+ embed_dim: int,
550
+ num_heads_channels: int,
551
+ output_dim: int = None,
552
+ ):
553
+ super().__init__()
554
+ self.positional_embedding = nn.Parameter(
555
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
556
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
557
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
558
+ self.num_heads = embed_dim // num_heads_channels
559
+ self.attention = QKVAttention(self.num_heads)
560
+
561
+ def forward(self, x):
562
+ b, c, *_spatial = x.shape
563
+ x = x.reshape(b, c, -1) # NC(HW)
564
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
565
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
566
+ x = self.qkv_proj(x)
567
+ x = self.attention(x)
568
+ x = self.c_proj(x)
569
+ return x[:, :, 0]
DiffAE_model_latentnet.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from enum import Enum
4
+ from typing import NamedTuple, Tuple
5
+
6
+ import torch
7
+ from .DiffAE_support_choices import *
8
+ from .DiffAE_support_config_base import BaseConfig
9
+ from torch import nn
10
+ from torch.nn import init
11
+
12
+ from .DiffAE_model_blocks import *
13
+ from .DiffAE_model_nn import timestep_embedding
14
+ from .DiffAE_model_unet import *
15
+
16
+
17
+ class LatentNetType(Enum):
18
+ none = 'none'
19
+ # injecting inputs into the hidden layers
20
+ skip = 'skip'
21
+
22
+
23
+ class LatentNetReturn(NamedTuple):
24
+ pred: torch.Tensor = None
25
+
26
+
27
+ @dataclass
28
+ class MLPSkipNetConfig(BaseConfig):
29
+ """
30
+ default MLP for the latent DPM in the paper!
31
+ """
32
+ num_channels: int
33
+ skip_layers: Tuple[int]
34
+ num_hid_channels: int
35
+ num_layers: int
36
+ num_time_emb_channels: int = 64
37
+ activation: Activation = Activation.silu
38
+ use_norm: bool = True
39
+ condition_bias: float = 1
40
+ dropout: float = 0
41
+ last_act: Activation = Activation.none
42
+ num_time_layers: int = 2
43
+ time_last_act: bool = False
44
+
45
+ def make_model(self):
46
+ return MLPSkipNet(self)
47
+
48
+
49
+ class MLPSkipNet(nn.Module):
50
+ """
51
+ concat x to hidden layers
52
+
53
+ default MLP for the latent DPM in the paper!
54
+ """
55
+ def __init__(self, conf: MLPSkipNetConfig):
56
+ super().__init__()
57
+ self.conf = conf
58
+
59
+ layers = []
60
+ for i in range(conf.num_time_layers):
61
+ if i == 0:
62
+ a = conf.num_time_emb_channels
63
+ b = conf.num_channels
64
+ else:
65
+ a = conf.num_channels
66
+ b = conf.num_channels
67
+ layers.append(nn.Linear(a, b))
68
+ if i < conf.num_time_layers - 1 or conf.time_last_act:
69
+ layers.append(conf.activation.get_act())
70
+ self.time_embed = nn.Sequential(*layers)
71
+
72
+ self.layers = nn.ModuleList([])
73
+ for i in range(conf.num_layers):
74
+ if i == 0:
75
+ act = conf.activation
76
+ norm = conf.use_norm
77
+ cond = True
78
+ a, b = conf.num_channels, conf.num_hid_channels
79
+ dropout = conf.dropout
80
+ elif i == conf.num_layers - 1:
81
+ act = Activation.none
82
+ norm = False
83
+ cond = False
84
+ a, b = conf.num_hid_channels, conf.num_channels
85
+ dropout = 0
86
+ else:
87
+ act = conf.activation
88
+ norm = conf.use_norm
89
+ cond = True
90
+ a, b = conf.num_hid_channels, conf.num_hid_channels
91
+ dropout = conf.dropout
92
+
93
+ if i in conf.skip_layers:
94
+ a += conf.num_channels
95
+
96
+ self.layers.append(
97
+ MLPLNAct(
98
+ a,
99
+ b,
100
+ norm=norm,
101
+ activation=act,
102
+ cond_channels=conf.num_channels,
103
+ use_cond=cond,
104
+ condition_bias=conf.condition_bias,
105
+ dropout=dropout,
106
+ ))
107
+ self.last_act = conf.last_act.get_act()
108
+
109
+ def forward(self, x, t, **kwargs):
110
+ t = timestep_embedding(t, self.conf.num_time_emb_channels)
111
+ cond = self.time_embed(t)
112
+ h = x
113
+ for i in range(len(self.layers)):
114
+ if i in self.conf.skip_layers:
115
+ # injecting input into the hidden layers
116
+ h = torch.cat([h, x], dim=1)
117
+ h = self.layers[i].forward(x=h, cond=cond)
118
+ h = self.last_act(h)
119
+ return LatentNetReturn(h)
120
+
121
+
122
+ class MLPLNAct(nn.Module):
123
+ def __init__(
124
+ self,
125
+ in_channels: int,
126
+ out_channels: int,
127
+ norm: bool,
128
+ use_cond: bool,
129
+ activation: Activation,
130
+ cond_channels: int,
131
+ condition_bias: float = 0,
132
+ dropout: float = 0,
133
+ ):
134
+ super().__init__()
135
+ self.activation = activation
136
+ self.condition_bias = condition_bias
137
+ self.use_cond = use_cond
138
+
139
+ self.linear = nn.Linear(in_channels, out_channels)
140
+ self.act = activation.get_act()
141
+ if self.use_cond:
142
+ self.linear_emb = nn.Linear(cond_channels, out_channels)
143
+ self.cond_layers = nn.Sequential(self.act, self.linear_emb)
144
+ if norm:
145
+ self.norm = nn.LayerNorm(out_channels)
146
+ else:
147
+ self.norm = nn.Identity()
148
+
149
+ if dropout > 0:
150
+ self.dropout = nn.Dropout(p=dropout)
151
+ else:
152
+ self.dropout = nn.Identity()
153
+
154
+ self.init_weights()
155
+
156
+ def init_weights(self):
157
+ for module in self.modules():
158
+ if isinstance(module, nn.Linear):
159
+ if self.activation == Activation.relu:
160
+ init.kaiming_normal_(module.weight,
161
+ a=0,
162
+ nonlinearity='relu')
163
+ elif self.activation == Activation.lrelu:
164
+ init.kaiming_normal_(module.weight,
165
+ a=0.2,
166
+ nonlinearity='leaky_relu')
167
+ elif self.activation == Activation.silu:
168
+ init.kaiming_normal_(module.weight,
169
+ a=0,
170
+ nonlinearity='relu')
171
+ else:
172
+ # leave it as default
173
+ pass
174
+
175
+ def forward(self, x, cond=None):
176
+ x = self.linear(x)
177
+ if self.use_cond:
178
+ # (n, c) or (n, c * 2)
179
+ cond = self.cond_layers(cond)
180
+ cond = (cond, None)
181
+
182
+ # scale shift first
183
+ x = x * (self.condition_bias + cond[0])
184
+ if cond[1] is not None:
185
+ x = x + cond[1]
186
+ # then norm
187
+ x = self.norm(x)
188
+ else:
189
+ # no condition
190
+ x = self.norm(x)
191
+ x = self.act(x)
192
+ x = self.dropout(x)
193
+ return x
DiffAE_model_nn.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Various utilities for neural networks.
3
+ """
4
+
5
+ from enum import Enum
6
+ import math
7
+ from typing import Optional
8
+
9
+ import torch as th
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+
13
+ import torch.nn.functional as F
14
+
15
+
16
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
17
+ class SiLU(nn.Module):
18
+ # @th.jit.script
19
+ def forward(self, x):
20
+ return x * th.sigmoid(x)
21
+
22
+
23
+ class GroupNorm32(nn.GroupNorm):
24
+ def forward(self, x):
25
+ return super().forward(x.float()).type(x.dtype)
26
+
27
+
28
+ def conv_nd(dims, *args, **kwargs):
29
+ """
30
+ Create a 1D, 2D, or 3D convolution module.
31
+ """
32
+ if dims == 1:
33
+ return nn.Conv1d(*args, **kwargs)
34
+ elif dims == 2:
35
+ return nn.Conv2d(*args, **kwargs)
36
+ elif dims == 3:
37
+ return nn.Conv3d(*args, **kwargs)
38
+ raise ValueError(f"unsupported dimensions: {dims}")
39
+
40
+
41
+ def linear(*args, **kwargs):
42
+ """
43
+ Create a linear module.
44
+ """
45
+ return nn.Linear(*args, **kwargs)
46
+
47
+
48
+ def avg_pool_nd(dims, *args, **kwargs):
49
+ """
50
+ Create a 1D, 2D, or 3D average pooling module.
51
+ """
52
+ if dims == 1:
53
+ return nn.AvgPool1d(*args, **kwargs)
54
+ elif dims == 2:
55
+ return nn.AvgPool2d(*args, **kwargs)
56
+ elif dims == 3:
57
+ return nn.AvgPool3d(*args, **kwargs)
58
+ raise ValueError(f"unsupported dimensions: {dims}")
59
+
60
+
61
+ def update_ema(target_params, source_params, rate=0.99):
62
+ """
63
+ Update target parameters to be closer to those of source parameters using
64
+ an exponential moving average.
65
+
66
+ :param target_params: the target parameter sequence.
67
+ :param source_params: the source parameter sequence.
68
+ :param rate: the EMA rate (closer to 1 means slower).
69
+ """
70
+ for targ, src in zip(target_params, source_params):
71
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
72
+
73
+
74
+ def zero_module(module):
75
+ """
76
+ Zero out the parameters of a module and return it.
77
+ """
78
+ for p in module.parameters():
79
+ p.detach().zero_()
80
+ return module
81
+
82
+
83
+ def scale_module(module, scale):
84
+ """
85
+ Scale the parameters of a module and return it.
86
+ """
87
+ for p in module.parameters():
88
+ p.detach().mul_(scale)
89
+ return module
90
+
91
+
92
+ def mean_flat(tensor):
93
+ """
94
+ Take the mean over all non-batch dimensions.
95
+ """
96
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
97
+
98
+
99
+ def normalization(channels, limit=32):
100
+ """
101
+ Make a standard normalization layer.
102
+
103
+ :param channels: number of input channels.
104
+ :param limit: the maximum number of groups. It's required if the number of net_channel is too small. Default: 32 (Added by Soumick, default from original)
105
+ :return: an nn.Module for normalization.
106
+ """
107
+ return GroupNorm32(min(limit, channels), channels)
108
+
109
+
110
+ def timestep_embedding(timesteps, dim, max_period=10000):
111
+ """
112
+ Create sinusoidal timestep embeddings.
113
+
114
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
115
+ These may be fractional.
116
+ :param dim: the dimension of the output.
117
+ :param max_period: controls the minimum frequency of the embeddings.
118
+ :return: an [N x dim] Tensor of positional embeddings.
119
+ """
120
+ half = dim // 2
121
+ freqs = th.exp(-math.log(max_period) *
122
+ th.arange(start=0, end=half, dtype=th.float32) /
123
+ half).to(device=timesteps.device)
124
+ args = timesteps[:, None].float() * freqs[None]
125
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
126
+ if dim % 2:
127
+ embedding = th.cat(
128
+ [embedding, th.zeros_like(embedding[:, :1])], dim=-1)
129
+ return embedding
130
+
131
+
132
+ def torch_checkpoint(func, args, flag, preserve_rng_state=False):
133
+ # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
134
+ if flag:
135
+ return torch.utils.checkpoint.checkpoint(
136
+ func, *args, preserve_rng_state=preserve_rng_state)
137
+ else:
138
+ return func(*args)
DiffAE_model_unet.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from numbers import Number
4
+ from typing import NamedTuple, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ from .DiffAE_support_choices import *
11
+ from .DiffAE_support_config_base import BaseConfig
12
+ from .DiffAE_model_blocks import *
13
+
14
+ from .DiffAE_model_nn import (conv_nd, linear, normalization, timestep_embedding,
15
+ torch_checkpoint, zero_module)
16
+
17
+
18
+ @dataclass
19
+ class BeatGANsUNetConfig(BaseConfig):
20
+ image_size: int = 64
21
+ in_channels: int = 3
22
+ # base channels, will be multiplied
23
+ model_channels: int = 64
24
+ # output of the unet
25
+ # suggest: 3
26
+ # you only need 6 if you also model the variance of the noise prediction (usually we use an analytical variance hence 3)
27
+ out_channels: int = 3
28
+ # how many repeating resblocks per resolution
29
+ # the decoding side would have "one more" resblock
30
+ # default: 2
31
+ num_res_blocks: int = 2
32
+ # you can also set the number of resblocks specifically for the input blocks
33
+ # default: None = above
34
+ num_input_res_blocks: int = None
35
+ # number of time embed channels and style channels
36
+ embed_channels: int = 512
37
+ # at what resolutions you want to do self-attention of the feature maps
38
+ # attentions generally improve performance
39
+ # default: [16]
40
+ # beatgans: [32, 16, 8]
41
+ attention_resolutions: Tuple[int] = (16, )
42
+ # number of time embed channels
43
+ time_embed_channels: int = None
44
+ # dropout applies to the resblocks (on feature maps)
45
+ dropout: float = 0.1
46
+ channel_mult: Tuple[int] = (1, 2, 4, 8)
47
+ input_channel_mult: Tuple[int] = None
48
+ conv_resample: bool = True
49
+ group_norm_limit: int = 32
50
+ # always 2 = 2d conv
51
+ dims: int = 2
52
+ # don't use this, legacy from BeatGANs
53
+ num_classes: int = None
54
+ use_checkpoint: bool = False
55
+ # number of attention heads
56
+ num_heads: int = 1
57
+ # or specify the number of channels per attention head
58
+ num_head_channels: int = -1
59
+ # what's this?
60
+ num_heads_upsample: int = -1
61
+ # use resblock for upscale/downscale blocks (expensive)
62
+ # default: True (BeatGANs)
63
+ resblock_updown: bool = True
64
+ # never tried
65
+ use_new_attention_order: bool = False
66
+ resnet_two_cond: bool = False
67
+ resnet_cond_channels: int = None
68
+ # init the decoding conv layers with zero weights, this speeds up training
69
+ # default: True (BeattGANs)
70
+ resnet_use_zero_module: bool = True
71
+ # gradient checkpoint the attention operation
72
+ attn_checkpoint: bool = False
73
+
74
+ def make_model(self):
75
+ return BeatGANsUNetModel(self)
76
+
77
+
78
+ class BeatGANsUNetModel(nn.Module):
79
+ def __init__(self, conf: BeatGANsUNetConfig):
80
+ super().__init__()
81
+ self.conf = conf
82
+
83
+ if conf.num_heads_upsample == -1:
84
+ self.num_heads_upsample = conf.num_heads
85
+
86
+ self.dtype = th.float32
87
+
88
+ self.time_emb_channels = conf.time_embed_channels or conf.model_channels
89
+ self.time_embed = nn.Sequential(
90
+ linear(self.time_emb_channels, conf.embed_channels),
91
+ nn.SiLU(),
92
+ linear(conf.embed_channels, conf.embed_channels),
93
+ )
94
+
95
+ if conf.num_classes is not None:
96
+ self.label_emb = nn.Embedding(conf.num_classes,
97
+ conf.embed_channels)
98
+
99
+ ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
100
+ self.input_blocks = nn.ModuleList([
101
+ TimestepEmbedSequential(
102
+ conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
103
+ ])
104
+
105
+ kwargs = dict(
106
+ use_condition=True,
107
+ two_cond=conf.resnet_two_cond,
108
+ use_zero_module=conf.resnet_use_zero_module,
109
+ # style channels for the resnet block
110
+ cond_emb_channels=conf.resnet_cond_channels,
111
+ )
112
+
113
+ self._feature_size = ch
114
+
115
+ # input_block_chans = [ch]
116
+ input_block_chans = [[] for _ in range(len(conf.channel_mult))]
117
+ input_block_chans[0].append(ch)
118
+
119
+ # number of blocks at each resolution
120
+ self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
121
+ self.input_num_blocks[0] = 1
122
+ self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
123
+
124
+ ds = 1
125
+ resolution = conf.image_size
126
+ for level, mult in enumerate(conf.input_channel_mult
127
+ or conf.channel_mult):
128
+ for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
129
+ layers = [
130
+ ResBlockConfig(
131
+ ch,
132
+ conf.embed_channels,
133
+ conf.dropout,
134
+ out_channels=int(mult * conf.model_channels),
135
+ group_norm_limit=conf.group_norm_limit,
136
+ dims=conf.dims,
137
+ use_checkpoint=conf.use_checkpoint,
138
+ **kwargs,
139
+ ).make_model()
140
+ ]
141
+ ch = int(mult * conf.model_channels)
142
+ if resolution in conf.attention_resolutions:
143
+ layers.append(
144
+ AttentionBlock(
145
+ ch,
146
+ use_checkpoint=conf.use_checkpoint
147
+ or conf.attn_checkpoint,
148
+ num_heads=conf.num_heads,
149
+ num_head_channels=conf.num_head_channels,
150
+ group_norm_limit=conf.group_norm_limit,
151
+ use_new_attention_order=conf.
152
+ use_new_attention_order,
153
+ ))
154
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
155
+ self._feature_size += ch
156
+ # input_block_chans.append(ch)
157
+ input_block_chans[level].append(ch)
158
+ self.input_num_blocks[level] += 1
159
+ # print(input_block_chans)
160
+ if level != len(conf.channel_mult) - 1:
161
+ resolution //= 2
162
+ out_ch = ch
163
+ self.input_blocks.append(
164
+ TimestepEmbedSequential(
165
+ ResBlockConfig(
166
+ ch,
167
+ conf.embed_channels,
168
+ conf.dropout,
169
+ out_channels=out_ch,
170
+ group_norm_limit=conf.group_norm_limit,
171
+ dims=conf.dims,
172
+ use_checkpoint=conf.use_checkpoint,
173
+ down=True,
174
+ **kwargs,
175
+ ).make_model() if conf.
176
+ resblock_updown else Downsample(ch,
177
+ conf.conv_resample,
178
+ dims=conf.dims,
179
+ out_channels=out_ch)))
180
+ ch = out_ch
181
+ # input_block_chans.append(ch)
182
+ input_block_chans[level + 1].append(ch)
183
+ self.input_num_blocks[level + 1] += 1
184
+ ds *= 2
185
+ self._feature_size += ch
186
+
187
+ self.middle_block = TimestepEmbedSequential(
188
+ ResBlockConfig(
189
+ ch,
190
+ conf.embed_channels,
191
+ conf.dropout,
192
+ group_norm_limit=conf.group_norm_limit,
193
+ dims=conf.dims,
194
+ use_checkpoint=conf.use_checkpoint,
195
+ **kwargs,
196
+ ).make_model(),
197
+ AttentionBlock(
198
+ ch,
199
+ use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
200
+ num_heads=conf.num_heads,
201
+ num_head_channels=conf.num_head_channels,
202
+ group_norm_limit=conf.group_norm_limit,
203
+ use_new_attention_order=conf.use_new_attention_order,
204
+ ),
205
+ ResBlockConfig(
206
+ ch,
207
+ conf.embed_channels,
208
+ conf.dropout,
209
+ group_norm_limit=conf.group_norm_limit,
210
+ dims=conf.dims,
211
+ use_checkpoint=conf.use_checkpoint,
212
+ **kwargs,
213
+ ).make_model(),
214
+ )
215
+ self._feature_size += ch
216
+
217
+ self.output_blocks = nn.ModuleList([])
218
+ for level, mult in list(enumerate(conf.channel_mult))[::-1]:
219
+ for i in range(conf.num_res_blocks + 1):
220
+ # print(input_block_chans)
221
+ # ich = input_block_chans.pop()
222
+ try:
223
+ ich = input_block_chans[level].pop()
224
+ except IndexError:
225
+ # this happens only when num_res_block > num_enc_res_block
226
+ # we will not have enough lateral (skip) connecions for all decoder blocks
227
+ ich = 0
228
+ # print('pop:', ich)
229
+ layers = [
230
+ ResBlockConfig(
231
+ # only direct channels when gated
232
+ channels=ch + ich,
233
+ emb_channels=conf.embed_channels,
234
+ dropout=conf.dropout,
235
+ out_channels=int(conf.model_channels * mult),
236
+ group_norm_limit=conf.group_norm_limit,
237
+ dims=conf.dims,
238
+ use_checkpoint=conf.use_checkpoint,
239
+ # lateral channels are described here when gated
240
+ has_lateral=True if ich > 0 else False,
241
+ lateral_channels=None,
242
+ **kwargs,
243
+ ).make_model()
244
+ ]
245
+ ch = int(conf.model_channels * mult)
246
+ if resolution in conf.attention_resolutions:
247
+ layers.append(
248
+ AttentionBlock(
249
+ ch,
250
+ use_checkpoint=conf.use_checkpoint
251
+ or conf.attn_checkpoint,
252
+ num_heads=self.num_heads_upsample,
253
+ num_head_channels=conf.num_head_channels,
254
+ group_norm_limit=conf.group_norm_limit,
255
+ use_new_attention_order=conf.
256
+ use_new_attention_order,
257
+ ))
258
+ if level and i == conf.num_res_blocks:
259
+ resolution *= 2
260
+ out_ch = ch
261
+ layers.append(
262
+ ResBlockConfig(
263
+ ch,
264
+ conf.embed_channels,
265
+ conf.dropout,
266
+ out_channels=out_ch,
267
+ group_norm_limit=conf.group_norm_limit,
268
+ dims=conf.dims,
269
+ use_checkpoint=conf.use_checkpoint,
270
+ up=True,
271
+ **kwargs,
272
+ ).make_model() if (
273
+ conf.resblock_updown
274
+ ) else Upsample(ch,
275
+ conf.conv_resample,
276
+ dims=conf.dims,
277
+ out_channels=out_ch))
278
+ ds //= 2
279
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
280
+ self.output_num_blocks[level] += 1
281
+ self._feature_size += ch
282
+
283
+ # print(input_block_chans)
284
+ # print('inputs:', self.input_num_blocks)
285
+ # print('outputs:', self.output_num_blocks)
286
+
287
+ if conf.resnet_use_zero_module:
288
+ self.out = nn.Sequential(
289
+ normalization(ch, limit=conf.group_norm_limit if "group_norm_limit" in conf.__dict__ else 32),
290
+ nn.SiLU(),
291
+ zero_module(
292
+ conv_nd(conf.dims,
293
+ input_ch,
294
+ conf.out_channels,
295
+ 3,
296
+ padding=1)),
297
+ )
298
+ else:
299
+ self.out = nn.Sequential(
300
+ normalization(ch, limit=conf.group_norm_limit if "group_norm_limit" in conf.__dict__ else 32),
301
+ nn.SiLU(),
302
+ conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
303
+ )
304
+
305
+ def forward(self, x, t, y=None, **kwargs):
306
+ """
307
+ Apply the model to an input batch.
308
+
309
+ :param x: an [N x C x ...] Tensor of inputs.
310
+ :param timesteps: a 1-D batch of timesteps.
311
+ :param y: an [N] Tensor of labels, if class-conditional.
312
+ :return: an [N x C x ...] Tensor of outputs.
313
+ """
314
+ assert (y is not None) == (
315
+ self.conf.num_classes is not None
316
+ ), "must specify y if and only if the model is class-conditional"
317
+
318
+ # hs = []
319
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
320
+ emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
321
+
322
+ if self.conf.num_classes is not None:
323
+ raise NotImplementedError()
324
+ # assert y.shape == (x.shape[0], )
325
+ # emb = emb + self.label_emb(y)
326
+
327
+ # new code supports input_num_blocks != output_num_blocks
328
+ h = x.type(self.dtype)
329
+ k = 0
330
+ for i in range(len(self.input_num_blocks)):
331
+ for j in range(self.input_num_blocks[i]):
332
+ h = self.input_blocks[k](h, emb=emb)
333
+ # print(i, j, h.shape)
334
+ hs[i].append(h)
335
+ k += 1
336
+ assert k == len(self.input_blocks)
337
+
338
+ h = self.middle_block(h, emb=emb)
339
+ k = 0
340
+ for i in range(len(self.output_num_blocks)):
341
+ for j in range(self.output_num_blocks[i]):
342
+ # take the lateral connection from the same layer (in reserve)
343
+ # until there is no more, use None
344
+ try:
345
+ lateral = hs[-i - 1].pop()
346
+ # print(i, j, lateral.shape)
347
+ except IndexError:
348
+ lateral = None
349
+ # print(i, j, lateral)
350
+ h = self.output_blocks[k](h, emb=emb, lateral=lateral)
351
+ k += 1
352
+
353
+ h = h.type(x.dtype)
354
+ pred = self.out(h)
355
+ return Return(pred=pred)
356
+
357
+
358
+ class Return(NamedTuple):
359
+ pred: th.Tensor
360
+
361
+
362
+ @dataclass
363
+ class BeatGANsEncoderConfig(BaseConfig):
364
+ image_size: int
365
+ in_channels: int
366
+ model_channels: int
367
+ out_hid_channels: int
368
+ out_channels: int
369
+ num_res_blocks: int
370
+ attention_resolutions: Tuple[int]
371
+ dropout: float = 0
372
+ channel_mult: Tuple[int] = (1, 2, 4, 8)
373
+ use_time_condition: bool = True
374
+ conv_resample: bool = True
375
+ group_norm_limit: int = 32
376
+ dims: int = 2
377
+ use_checkpoint: bool = False
378
+ num_heads: int = 1
379
+ num_head_channels: int = -1
380
+ resblock_updown: bool = False
381
+ use_new_attention_order: bool = False
382
+ pool: str = 'adaptivenonzero'
383
+
384
+ def make_model(self):
385
+ return BeatGANsEncoderModel(self)
386
+
387
+
388
+ class BeatGANsEncoderModel(nn.Module):
389
+ """
390
+ The half UNet model with attention and timestep embedding.
391
+
392
+ For usage, see UNet.
393
+ """
394
+ def __init__(self, conf: BeatGANsEncoderConfig):
395
+ super().__init__()
396
+ self.conf = conf
397
+ self.dtype = th.float32
398
+
399
+ if conf.use_time_condition:
400
+ time_embed_dim = conf.model_channels * 4
401
+ self.time_embed = nn.Sequential(
402
+ linear(conf.model_channels, time_embed_dim),
403
+ nn.SiLU(),
404
+ linear(time_embed_dim, time_embed_dim),
405
+ )
406
+ else:
407
+ time_embed_dim = None
408
+
409
+ ch = int(conf.channel_mult[0] * conf.model_channels)
410
+ self.input_blocks = nn.ModuleList([
411
+ TimestepEmbedSequential(
412
+ conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
413
+ ])
414
+ self._feature_size = ch
415
+ input_block_chans = [ch]
416
+ ds = 1
417
+ resolution = conf.image_size
418
+ for level, mult in enumerate(conf.channel_mult):
419
+ for _ in range(conf.num_res_blocks):
420
+ layers = [
421
+ ResBlockConfig(
422
+ ch,
423
+ time_embed_dim,
424
+ conf.dropout,
425
+ out_channels=int(mult * conf.model_channels),
426
+ group_norm_limit=conf.group_norm_limit,
427
+ dims=conf.dims,
428
+ use_condition=conf.use_time_condition,
429
+ use_checkpoint=conf.use_checkpoint,
430
+ ).make_model()
431
+ ]
432
+ ch = int(mult * conf.model_channels)
433
+ if resolution in conf.attention_resolutions:
434
+ layers.append(
435
+ AttentionBlock(
436
+ ch,
437
+ use_checkpoint=conf.use_checkpoint,
438
+ num_heads=conf.num_heads,
439
+ num_head_channels=conf.num_head_channels,
440
+ group_norm_limit=conf.group_norm_limit,
441
+ use_new_attention_order=conf.
442
+ use_new_attention_order,
443
+ ))
444
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
445
+ self._feature_size += ch
446
+ input_block_chans.append(ch)
447
+ if level != len(conf.channel_mult) - 1:
448
+ resolution //= 2
449
+ out_ch = ch
450
+ self.input_blocks.append(
451
+ TimestepEmbedSequential(
452
+ ResBlockConfig(
453
+ ch,
454
+ time_embed_dim,
455
+ conf.dropout,
456
+ out_channels=out_ch,
457
+ group_norm_limit=conf.group_norm_limit,
458
+ dims=conf.dims,
459
+ use_condition=conf.use_time_condition,
460
+ use_checkpoint=conf.use_checkpoint,
461
+ down=True,
462
+ ).make_model() if (
463
+ conf.resblock_updown
464
+ ) else Downsample(ch,
465
+ conf.conv_resample,
466
+ dims=conf.dims,
467
+ out_channels=out_ch)))
468
+ ch = out_ch
469
+ input_block_chans.append(ch)
470
+ ds *= 2
471
+ self._feature_size += ch
472
+
473
+ self.middle_block = TimestepEmbedSequential(
474
+ ResBlockConfig(
475
+ ch,
476
+ time_embed_dim,
477
+ conf.dropout,
478
+ group_norm_limit=conf.group_norm_limit,
479
+ dims=conf.dims,
480
+ use_condition=conf.use_time_condition,
481
+ use_checkpoint=conf.use_checkpoint,
482
+ ).make_model(),
483
+ AttentionBlock(
484
+ ch,
485
+ use_checkpoint=conf.use_checkpoint,
486
+ num_heads=conf.num_heads,
487
+ num_head_channels=conf.num_head_channels,
488
+ group_norm_limit=conf.group_norm_limit,
489
+ use_new_attention_order=conf.use_new_attention_order,
490
+ ),
491
+ ResBlockConfig(
492
+ ch,
493
+ time_embed_dim,
494
+ conf.dropout,
495
+ group_norm_limit=conf.group_norm_limit,
496
+ dims=conf.dims,
497
+ use_condition=conf.use_time_condition,
498
+ use_checkpoint=conf.use_checkpoint,
499
+ ).make_model(),
500
+ )
501
+ self._feature_size += ch
502
+ if conf.pool == "adaptivenonzero":
503
+ self.out = nn.Sequential(
504
+ normalization(ch, limit=conf.group_norm_limit if "group_norm_limit" in conf.__dict__ else 32),
505
+ nn.SiLU(),
506
+ nn.AdaptiveAvgPool2d((1, 1)) if conf.dims == 2 else nn.AdaptiveAvgPool3d((1, 1, 1)),
507
+ conv_nd(conf.dims, ch, conf.out_channels, 1),
508
+ nn.Flatten(),
509
+ )
510
+ else:
511
+ raise NotImplementedError(f"Unexpected {conf.pool} pooling")
512
+
513
+ def forward(self, x, t=None, return_Nd_feature=False):
514
+ """
515
+ Apply the model to an input batch.
516
+
517
+ :param x: an [N x C x ...] Tensor of inputs.
518
+ :param timesteps: a 1-D batch of timesteps.
519
+ :return: an [N x K] Tensor of outputs.
520
+ """
521
+ if self.conf.use_time_condition:
522
+ emb = self.time_embed(timestep_embedding(t, self.model_channels))
523
+ else:
524
+ emb = None
525
+
526
+ results = []
527
+ h = x.type(self.dtype)
528
+ for module in self.input_blocks:
529
+ h = module(h, emb=emb)
530
+ if self.conf.pool.startswith("spatial"):
531
+ results.append(h.type(x.dtype).mean(dim=(2, 3) if self.conf.dims == 2 else (2, 3, 4)))
532
+ h = self.middle_block(h, emb=emb)
533
+ if self.conf.pool.startswith("spatial"):
534
+ results.append(h.type(x.dtype).mean(dim=(2, 3) if self.conf.dims == 2 else (2, 3, 4)))
535
+ h = th.cat(results, axis=-1)
536
+ else:
537
+ h = h.type(x.dtype)
538
+
539
+ h_Nd = h
540
+ h = self.out(h)
541
+
542
+ if return_Nd_feature:
543
+ return h, h_Nd
544
+ else:
545
+ return h
546
+
547
+ def forward_flatten(self, x):
548
+ """
549
+ transform the last Nd feature into a flatten vector
550
+ """
551
+ h = self.out(x)
552
+ return h
553
+
554
+
555
+ class SuperResModel(BeatGANsUNetModel):
556
+ """
557
+ A UNetModel that performs super-resolution.
558
+
559
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
560
+ """
561
+ def __init__(self, image_size, in_channels, *args, **kwargs):
562
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
563
+
564
+ def forward(self, x, timesteps, low_res=None, **kwargs):
565
+ _, _, new_height, new_width = x.shape
566
+ upsampled = F.interpolate(low_res, (new_height, new_width),
567
+ mode="bilinear")
568
+ x = th.cat([x, upsampled], dim=1)
569
+ return super().forward(x, timesteps, **kwargs)
DiffAE_model_unet_autoenc.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.nn.functional import silu
6
+
7
+ from .DiffAE_model_latentnet import *
8
+ from .DiffAE_model_unet import *
9
+ from .DiffAE_support_choices import *
10
+
11
+
12
+ @dataclass
13
+ class BeatGANsAutoencConfig(BeatGANsUNetConfig):
14
+ # number of style channels
15
+ enc_out_channels: int = 512
16
+ enc_attn_resolutions: Tuple[int] = None
17
+ enc_pool: str = 'depthconv'
18
+ enc_num_res_block: int = 2
19
+ enc_channel_mult: Tuple[int] = None
20
+ enc_grad_checkpoint: bool = False
21
+ latent_net_conf: MLPSkipNetConfig = None
22
+
23
+ def make_model(self):
24
+ return BeatGANsAutoencModel(self)
25
+
26
+
27
+ class BeatGANsAutoencModel(BeatGANsUNetModel):
28
+ def __init__(self, conf: BeatGANsAutoencConfig):
29
+ super().__init__(conf)
30
+ self.conf = conf
31
+
32
+ # having only time, cond
33
+ self.time_embed = TimeStyleSeperateEmbed(
34
+ time_channels=conf.model_channels,
35
+ time_out_channels=conf.embed_channels,
36
+ )
37
+
38
+ self.encoder = BeatGANsEncoderConfig(
39
+ image_size=conf.image_size,
40
+ in_channels=conf.in_channels,
41
+ model_channels=conf.model_channels,
42
+ out_hid_channels=conf.enc_out_channels,
43
+ out_channels=conf.enc_out_channels,
44
+ num_res_blocks=conf.enc_num_res_block,
45
+ attention_resolutions=(conf.enc_attn_resolutions
46
+ or conf.attention_resolutions),
47
+ dropout=conf.dropout,
48
+ channel_mult=conf.enc_channel_mult or conf.channel_mult,
49
+ use_time_condition=False,
50
+ conv_resample=conf.conv_resample,
51
+ group_norm_limit=conf.group_norm_limit,
52
+ dims=conf.dims,
53
+ use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
54
+ num_heads=conf.num_heads,
55
+ num_head_channels=conf.num_head_channels,
56
+ resblock_updown=conf.resblock_updown,
57
+ use_new_attention_order=conf.use_new_attention_order,
58
+ pool=conf.enc_pool,
59
+ ).make_model()
60
+
61
+ if conf.latent_net_conf is not None:
62
+ self.latent_net = conf.latent_net_conf.make_model()
63
+
64
+ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
65
+ """
66
+ Reparameterization trick to sample from N(mu, var) from
67
+ N(0,1).
68
+ :param mu: (Tensor) Mean of the latent Gaussian [B x D]
69
+ :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
70
+ :return: (Tensor) [B x D]
71
+ """
72
+ assert self.conf.is_stochastic
73
+ std = torch.exp(0.5 * logvar)
74
+ eps = torch.randn_like(std)
75
+ return eps * std + mu
76
+
77
+ def sample_z(self, n: int, device):
78
+ assert self.conf.is_stochastic
79
+ return torch.randn(n, self.conf.enc_out_channels, device=device)
80
+
81
+ def noise_to_cond(self, noise: Tensor):
82
+ raise NotImplementedError()
83
+ assert self.conf.noise_net_conf is not None
84
+ return self.noise_net.forward(noise)
85
+
86
+ def encode(self, x):
87
+ cond = self.encoder.forward(x)
88
+ return {'cond': cond}
89
+
90
+ @property
91
+ def stylespace_sizes(self):
92
+ modules = list(self.input_blocks.modules()) + list(
93
+ self.middle_block.modules()) + list(self.output_blocks.modules())
94
+ sizes = []
95
+ for module in modules:
96
+ if isinstance(module, ResBlock):
97
+ linear = module.cond_emb_layers[-1]
98
+ sizes.append(linear.weight.shape[0])
99
+ return sizes
100
+
101
+ def encode_stylespace(self, x, return_vector: bool = True):
102
+ """
103
+ encode to style space
104
+ """
105
+ modules = list(self.input_blocks.modules()) + list(
106
+ self.middle_block.modules()) + list(self.output_blocks.modules())
107
+ # (n, c)
108
+ cond = self.encoder.forward(x)
109
+ S = []
110
+ for module in modules:
111
+ if isinstance(module, ResBlock):
112
+ # (n, c')
113
+ s = module.cond_emb_layers.forward(cond)
114
+ S.append(s)
115
+
116
+ if return_vector:
117
+ # (n, sum_c)
118
+ return torch.cat(S, dim=1)
119
+ else:
120
+ return S
121
+
122
+ def forward(self,
123
+ x,
124
+ t,
125
+ y=None,
126
+ x_start=None,
127
+ cond=None,
128
+ style=None,
129
+ noise=None,
130
+ t_cond=None,
131
+ **kwargs):
132
+ """
133
+ Apply the model to an input batch.
134
+
135
+ Args:
136
+ x_start: the original image to encode
137
+ cond: output of the encoder
138
+ noise: random noise (to predict the cond)
139
+ """
140
+
141
+ if t_cond is None:
142
+ t_cond = t
143
+
144
+ if noise is not None:
145
+ # if the noise is given, we predict the cond from noise
146
+ cond = self.noise_to_cond(noise)
147
+
148
+ if cond is None:
149
+ if x is not None:
150
+ assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
151
+
152
+ tmp = self.encode(x_start)
153
+ cond = tmp['cond']
154
+
155
+ if t is not None:
156
+ _t_emb = timestep_embedding(t, self.conf.model_channels)
157
+ _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
158
+ else:
159
+ # this happens when training only autoenc
160
+ _t_emb = None
161
+ _t_cond_emb = None
162
+
163
+ if self.conf.resnet_two_cond:
164
+ res = self.time_embed.forward(
165
+ time_emb=_t_emb,
166
+ cond=cond,
167
+ time_cond_emb=_t_cond_emb,
168
+ )
169
+ else:
170
+ raise NotImplementedError()
171
+
172
+ if self.conf.resnet_two_cond:
173
+ # two cond: first = time emb, second = cond_emb
174
+ emb = res.time_emb
175
+ cond_emb = res.emb
176
+ else:
177
+ # one cond = combined of both time and cond
178
+ emb = res.emb
179
+ cond_emb = None
180
+
181
+ # override the style if given
182
+ style = style or res.style
183
+
184
+ assert (y is not None) == (
185
+ self.conf.num_classes is not None
186
+ ), "must specify y if and only if the model is class-conditional"
187
+
188
+ if self.conf.num_classes is not None:
189
+ raise NotImplementedError()
190
+ # assert y.shape == (x.shape[0], )
191
+ # emb = emb + self.label_emb(y)
192
+
193
+ # where in the model to supply time conditions
194
+ enc_time_emb = emb
195
+ mid_time_emb = emb
196
+ dec_time_emb = emb
197
+ # where in the model to supply style conditions
198
+ enc_cond_emb = cond_emb
199
+ mid_cond_emb = cond_emb
200
+ dec_cond_emb = cond_emb
201
+
202
+ # hs = []
203
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
204
+
205
+ if x is not None:
206
+ h = x.type(self.dtype)
207
+
208
+ # input blocks
209
+ k = 0
210
+ for i in range(len(self.input_num_blocks)):
211
+ for j in range(self.input_num_blocks[i]):
212
+ h = self.input_blocks[k](h,
213
+ emb=enc_time_emb,
214
+ cond=enc_cond_emb)
215
+
216
+ # print(i, j, h.shape)
217
+ hs[i].append(h)
218
+ k += 1
219
+ assert k == len(self.input_blocks)
220
+
221
+ # middle blocks
222
+ h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
223
+ else:
224
+ # no lateral connections
225
+ # happens when training only the autonecoder
226
+ h = None
227
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
228
+
229
+ # output blocks
230
+ k = 0
231
+ for i in range(len(self.output_num_blocks)):
232
+ for j in range(self.output_num_blocks[i]):
233
+ # take the lateral connection from the same layer (in reserve)
234
+ # until there is no more, use None
235
+ try:
236
+ lateral = hs[-i - 1].pop()
237
+ # print(i, j, lateral.shape)
238
+ except IndexError:
239
+ lateral = None
240
+ # print(i, j, lateral)
241
+
242
+ h = self.output_blocks[k](h,
243
+ emb=dec_time_emb,
244
+ cond=dec_cond_emb,
245
+ lateral=lateral)
246
+ k += 1
247
+
248
+ pred = self.out(h)
249
+ return AutoencReturn(pred=pred, cond=cond)
250
+
251
+
252
+ class AutoencReturn(NamedTuple):
253
+ pred: Tensor
254
+ cond: Tensor = None
255
+
256
+
257
+ class EmbedReturn(NamedTuple):
258
+ # style and time
259
+ emb: Tensor = None
260
+ # time only
261
+ time_emb: Tensor = None
262
+ # style only (but could depend on time)
263
+ style: Tensor = None
264
+
265
+
266
+ class TimeStyleSeperateEmbed(nn.Module):
267
+ # embed only style
268
+ def __init__(self, time_channels, time_out_channels):
269
+ super().__init__()
270
+ self.time_embed = nn.Sequential(
271
+ linear(time_channels, time_out_channels),
272
+ nn.SiLU(),
273
+ linear(time_out_channels, time_out_channels),
274
+ )
275
+ self.style = nn.Identity()
276
+
277
+ def forward(self, time_emb=None, cond=None, **kwargs):
278
+ if time_emb is None:
279
+ # happens with autoenc training mode
280
+ time_emb = None
281
+ else:
282
+ time_emb = self.time_embed(time_emb)
283
+ style = self.style(cond)
284
+ return EmbedReturn(emb=style, time_emb=time_emb, style=style)
DiffAE_support.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .DiffAE_support_choices import *
2
+ from .DiffAE_support_config_base import *
3
+ from .DiffAE_support_config import *
4
+ from .DiffAE_support_dist_utils import *
5
+ from .DiffAE_support_metrics import *
6
+ from .DiffAE_support_renderer import *
7
+ from .DiffAE_support_templates_latent import *
8
+ from .DiffAE_support_templates import *
9
+ from .DiffAE_support_utils import *
DiffAE_support_choices.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from torch import nn
3
+
4
+
5
+ class TrainMode(Enum):
6
+ # manipulate mode = training the classifier
7
+ manipulate = 'manipulate'
8
+ # default trainin mode!
9
+ diffusion = 'diffusion'
10
+ # default latent training mode!
11
+ # fitting the a DDPM to a given latent
12
+ latent_diffusion = 'latentdiffusion'
13
+
14
+ def is_manipulate(self):
15
+ return self in [
16
+ TrainMode.manipulate,
17
+ ]
18
+
19
+ def is_diffusion(self):
20
+ return self in [
21
+ TrainMode.diffusion,
22
+ TrainMode.latent_diffusion,
23
+ ]
24
+
25
+ def is_autoenc(self):
26
+ # the network possibly does autoencoding
27
+ return self in [
28
+ TrainMode.diffusion,
29
+ ]
30
+
31
+ def is_latent_diffusion(self):
32
+ return self in [
33
+ TrainMode.latent_diffusion,
34
+ ]
35
+
36
+ def use_latent_net(self):
37
+ return self.is_latent_diffusion()
38
+
39
+ def require_dataset_infer(self):
40
+ """
41
+ whether training in this mode requires the latent variables to be available?
42
+ """
43
+ # this will precalculate all the latents before hand
44
+ # and the dataset will be all the predicted latents
45
+ return self in [
46
+ TrainMode.latent_diffusion,
47
+ TrainMode.manipulate,
48
+ ]
49
+
50
+
51
+ class ManipulateMode(Enum):
52
+ """
53
+ how to train the classifier to manipulate
54
+ """
55
+ # train on whole celeba attr dataset
56
+ celebahq_all = 'celebahq_all'
57
+ # celeba with D2C's crop
58
+ d2c_fewshot = 'd2cfewshot'
59
+ d2c_fewshot_allneg = 'd2cfewshotallneg'
60
+
61
+ def is_celeba_attr(self):
62
+ return self in [
63
+ ManipulateMode.d2c_fewshot,
64
+ ManipulateMode.d2c_fewshot_allneg,
65
+ ManipulateMode.celebahq_all,
66
+ ]
67
+
68
+ def is_single_class(self):
69
+ return self in [
70
+ ManipulateMode.d2c_fewshot,
71
+ ManipulateMode.d2c_fewshot_allneg,
72
+ ]
73
+
74
+ def is_fewshot(self):
75
+ return self in [
76
+ ManipulateMode.d2c_fewshot,
77
+ ManipulateMode.d2c_fewshot_allneg,
78
+ ]
79
+
80
+ def is_fewshot_allneg(self):
81
+ return self in [
82
+ ManipulateMode.d2c_fewshot_allneg,
83
+ ]
84
+
85
+
86
+ class ModelType(Enum):
87
+ """
88
+ Kinds of the backbone models
89
+ """
90
+
91
+ # unconditional ddpm
92
+ ddpm = 'ddpm'
93
+ # autoencoding ddpm cannot do unconditional generation
94
+ autoencoder = 'autoencoder'
95
+
96
+ def has_autoenc(self):
97
+ return self in [
98
+ ModelType.autoencoder,
99
+ ]
100
+
101
+ def can_sample(self):
102
+ return self in [ModelType.ddpm]
103
+
104
+
105
+ class ModelName(Enum):
106
+ """
107
+ List of all supported model classes
108
+ """
109
+
110
+ beatgans_ddpm = 'beatgans_ddpm'
111
+ beatgans_autoenc = 'beatgans_autoenc'
112
+
113
+
114
+ class ModelMeanType(Enum):
115
+ """
116
+ Which type of output the model predicts.
117
+ """
118
+
119
+ eps = 'eps' # the model predicts epsilon
120
+
121
+
122
+ class ModelVarType(Enum):
123
+ """
124
+ What is used as the model's output variance.
125
+
126
+ The LEARNED_RANGE option has been added to allow the model to predict
127
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
128
+ """
129
+
130
+ # posterior beta_t
131
+ fixed_small = 'fixed_small'
132
+ # beta_t
133
+ fixed_large = 'fixed_large'
134
+
135
+
136
+ class LossType(Enum):
137
+ mse = 'mse' # use raw MSE loss (and KL when learning variances)
138
+ l1 = 'l1'
139
+
140
+
141
+ class GenerativeType(Enum):
142
+ """
143
+ How's a sample generated
144
+ """
145
+
146
+ ddpm = 'ddpm'
147
+ ddim = 'ddim'
148
+
149
+
150
+ class OptimizerType(Enum):
151
+ adam = 'adam'
152
+ adamw = 'adamw'
153
+
154
+
155
+ class Activation(Enum):
156
+ none = 'none'
157
+ relu = 'relu'
158
+ lrelu = 'lrelu'
159
+ silu = 'silu'
160
+ tanh = 'tanh'
161
+
162
+ def get_act(self):
163
+ if self == Activation.none:
164
+ return nn.Identity()
165
+ elif self == Activation.relu:
166
+ return nn.ReLU()
167
+ elif self == Activation.lrelu:
168
+ return nn.LeakyReLU(negative_slope=0.2)
169
+ elif self == Activation.silu:
170
+ return nn.SiLU()
171
+ elif self == Activation.tanh:
172
+ return nn.Tanh()
173
+ else:
174
+ raise NotImplementedError()
175
+
176
+
177
+ class ManipulateLossType(Enum):
178
+ bce = 'bce'
179
+ mse = 'mse'
DiffAE_support_config.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .DiffAE_model_blocks import ScaleAt
2
+ from .DiffAE_model import *
3
+ from .DiffAE_diffusion_resample import UniformSampler
4
+ from .DiffAE_diffusion_diffusion import space_timesteps
5
+ from typing import Tuple
6
+
7
+ from torch.utils.data import DataLoader
8
+
9
+ from .DiffAE_support_config_base import BaseConfig
10
+ from .DiffAE_support_choices import GenerativeType, LossType, ModelMeanType, ModelVarType
11
+ from .DiffAE_diffusion_base import get_named_beta_schedule
12
+ from .DiffAE_support_choices import *
13
+ from .DiffAE_diffusion_diffusion import SpacedDiffusionBeatGansConfig
14
+ from multiprocessing import get_context
15
+ import os
16
+ from torch.utils.data.distributed import DistributedSampler
17
+
18
+ from dataclasses import dataclass
19
+
20
+ data_paths = {
21
+ 'ffhqlmdb256':
22
+ os.path.expanduser('datasets/ffhq256.lmdb'),
23
+ # used for training a classifier
24
+ 'celeba':
25
+ os.path.expanduser('datasets/celeba'),
26
+ # used for training DPM models
27
+ 'celebalmdb':
28
+ os.path.expanduser('datasets/celeba.lmdb'),
29
+ 'celebahq':
30
+ os.path.expanduser('datasets/celebahq256.lmdb'),
31
+ 'horse256':
32
+ os.path.expanduser('datasets/horse256.lmdb'),
33
+ 'bedroom256':
34
+ os.path.expanduser('datasets/bedroom256.lmdb'),
35
+ 'celeba_anno':
36
+ os.path.expanduser('datasets/celeba_anno/list_attr_celeba.txt'),
37
+ 'celebahq_anno':
38
+ os.path.expanduser(
39
+ 'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
40
+ 'celeba_relight':
41
+ os.path.expanduser('datasets/celeba_hq_light/celeba_light.txt'),
42
+ }
43
+
44
+
45
+ @dataclass
46
+ class PretrainConfig(BaseConfig):
47
+ name: str
48
+ path: str
49
+
50
+
51
+ @dataclass
52
+ class TrainConfig(BaseConfig):
53
+ #new params added (Soumick)
54
+ n_dims: int = 2
55
+ in_channels: int = 3
56
+ out_channels: int = 3
57
+ group_norm_limit: int = 32
58
+
59
+ # random seed
60
+ seed: int = 0
61
+ train_mode: TrainMode = TrainMode.diffusion
62
+ train_cond0_prob: float = 0
63
+ train_pred_xstart_detach: bool = True
64
+ train_interpolate_prob: float = 0
65
+ train_interpolate_img: bool = False
66
+ manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
67
+ manipulate_cls: str = None
68
+ manipulate_shots: int = None
69
+ manipulate_loss: ManipulateLossType = ManipulateLossType.bce
70
+ manipulate_znormalize: bool = False
71
+ manipulate_seed: int = 0
72
+ accum_batches: int = 1
73
+ autoenc_mid_attn: bool = True
74
+ batch_size: int = 16
75
+ batch_size_eval: int = None
76
+ beatgans_gen_type: GenerativeType = GenerativeType.ddim
77
+ beatgans_loss_type: LossType = LossType.mse
78
+ beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
79
+ beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
80
+ beatgans_rescale_timesteps: bool = False
81
+ latent_infer_path: str = None
82
+ latent_znormalize: bool = False
83
+ latent_gen_type: GenerativeType = GenerativeType.ddim
84
+ latent_loss_type: LossType = LossType.mse
85
+ latent_model_mean_type: ModelMeanType = ModelMeanType.eps
86
+ latent_model_var_type: ModelVarType = ModelVarType.fixed_large
87
+ latent_rescale_timesteps: bool = False
88
+ latent_T_eval: int = 1_000
89
+ latent_clip_sample: bool = False
90
+ latent_beta_scheduler: str = 'linear'
91
+ beta_scheduler: str = 'linear'
92
+ data_name: str = ''
93
+ data_val_name: str = None
94
+ diffusion_type: str = None
95
+ dropout: float = 0.1
96
+ ema_decay: float = 0.9999
97
+ eval_num_images: int = 5_000
98
+ eval_every_samples: int = 200_000
99
+ eval_ema_every_samples: int = 200_000
100
+ fid_use_torch: bool = True
101
+ fp16: bool = False
102
+ grad_clip: float = 1
103
+ img_size: int = 64
104
+ lr: float = 0.0001
105
+ optimizer: OptimizerType = OptimizerType.adam
106
+ weight_decay: float = 0
107
+ model_conf: ModelConfig = None
108
+ model_name: ModelName = None
109
+ model_type: ModelType = None
110
+ net_attn: Tuple[int] = None
111
+ net_beatgans_attn_head: int = 1
112
+ # not necessarily the same as the the number of style channels
113
+ net_beatgans_embed_channels: int = 512
114
+ net_resblock_updown: bool = True
115
+ net_enc_use_time: bool = False
116
+ net_enc_pool: str = 'adaptivenonzero'
117
+ net_beatgans_gradient_checkpoint: bool = False
118
+ net_beatgans_resnet_two_cond: bool = False
119
+ net_beatgans_resnet_use_zero_module: bool = True
120
+ net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
121
+ net_beatgans_resnet_cond_channels: int = None
122
+ net_ch_mult: Tuple[int] = None
123
+ net_ch: int = 64
124
+ net_enc_attn: Tuple[int] = None
125
+ net_enc_k: int = None
126
+ # number of resblocks for the encoder (half-unet)
127
+ net_enc_num_res_blocks: int = 2
128
+ net_enc_channel_mult: Tuple[int] = None
129
+ net_enc_grad_checkpoint: bool = False
130
+ net_autoenc_stochastic: bool = False
131
+ net_latent_activation: Activation = Activation.silu
132
+ net_latent_channel_mult: Tuple[int] = (1, 2, 4)
133
+ net_latent_condition_bias: float = 0
134
+ net_latent_dropout: float = 0
135
+ net_latent_layers: int = None
136
+ net_latent_net_last_act: Activation = Activation.none
137
+ net_latent_net_type: LatentNetType = LatentNetType.none
138
+ net_latent_num_hid_channels: int = 1024
139
+ net_latent_num_time_layers: int = 2
140
+ net_latent_skip_layers: Tuple[int] = None
141
+ net_latent_time_emb_channels: int = 64
142
+ net_latent_use_norm: bool = False
143
+ net_latent_time_last_act: bool = False
144
+ net_num_res_blocks: int = 2
145
+ # number of resblocks for the UNET
146
+ net_num_input_res_blocks: int = None
147
+ net_enc_num_cls: int = None
148
+ num_workers: int = 4
149
+ parallel: bool = False
150
+ postfix: str = ''
151
+ sample_size: int = 64
152
+ sample_every_samples: int = 20_000
153
+ save_every_samples: int = 100_000
154
+ style_ch: int = 512
155
+ T_eval: int = 1_000
156
+ T_sampler: str = 'uniform'
157
+ T: int = 1_000
158
+ total_samples: int = 10_000_000
159
+ warmup: int = 0
160
+ pretrain: PretrainConfig = None
161
+ continue_from: PretrainConfig = None
162
+ eval_programs: Tuple[str] = None
163
+ # if present load the checkpoint from this path instead
164
+ eval_path: str = None
165
+ base_dir: str = 'checkpoints'
166
+ use_cache_dataset: bool = False
167
+ data_cache_dir: str = os.path.expanduser('~/cache')
168
+ work_cache_dir: str = os.path.expanduser('~/mycache')
169
+ # to be overridden
170
+ name: str = ''
171
+
172
+ def refresh_values(self):
173
+ self.img_size = max(self.input_shape)
174
+ self.n_dims = 3 if self.is3D else 2
175
+ self.group_norm_limit = min(32, self.net_ch)
176
+
177
+ def __post_init__(self):
178
+ self.batch_size_eval = self.batch_size_eval or self.batch_size
179
+ self.data_val_name = self.data_val_name or self.data_name
180
+
181
+ def scale_up_gpus(self, num_gpus, num_nodes=1):
182
+ self.eval_ema_every_samples *= num_gpus * num_nodes
183
+ self.eval_every_samples *= num_gpus * num_nodes
184
+ self.sample_every_samples *= num_gpus * num_nodes
185
+ self.batch_size *= num_gpus * num_nodes
186
+ self.batch_size_eval *= num_gpus * num_nodes
187
+ return self
188
+
189
+ @property
190
+ def batch_size_effective(self):
191
+ return self.batch_size * self.accum_batches
192
+
193
+ @property
194
+ def fid_cache(self):
195
+ # we try to use the local dirs to reduce the load over network drives
196
+ # hopefully, this would reduce the disconnection problems with sshfs
197
+ return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}'
198
+
199
+ @property
200
+ def data_path(self):
201
+ # may use the cache dir
202
+ path = data_paths[self.data_name]
203
+ if self.use_cache_dataset and path is not None:
204
+ path = use_cached_dataset_path(
205
+ path, f'{self.data_cache_dir}/{self.data_name}')
206
+ return path
207
+
208
+ @property
209
+ def logdir(self):
210
+ return f'{self.base_dir}/{self.name}'
211
+
212
+ @property
213
+ def generate_dir(self):
214
+ # we try to use the local dirs to reduce the load over network drives
215
+ # hopefully, this would reduce the disconnection problems with sshfs
216
+ return f'{self.work_cache_dir}/gen_images/{self.name}'
217
+
218
+ def _make_diffusion_conf(self, T=None):
219
+ if self.diffusion_type == 'beatgans':
220
+ # can use T < self.T for evaluation
221
+ # follows the guided-diffusion repo conventions
222
+ # t's are evenly spaced
223
+ if self.beatgans_gen_type == GenerativeType.ddpm:
224
+ section_counts = [T]
225
+ elif self.beatgans_gen_type == GenerativeType.ddim:
226
+ section_counts = f'ddim{T}'
227
+ else:
228
+ raise NotImplementedError()
229
+
230
+ return SpacedDiffusionBeatGansConfig(
231
+ gen_type=self.beatgans_gen_type,
232
+ model_type=self.model_type,
233
+ betas=get_named_beta_schedule(self.beta_scheduler, self.T),
234
+ model_mean_type=self.beatgans_model_mean_type,
235
+ model_var_type=self.beatgans_model_var_type,
236
+ loss_type=self.beatgans_loss_type,
237
+ rescale_timesteps=self.beatgans_rescale_timesteps,
238
+ use_timesteps=space_timesteps(num_timesteps=self.T,
239
+ section_counts=section_counts),
240
+ fp16=self.fp16,
241
+ )
242
+ else:
243
+ raise NotImplementedError()
244
+
245
+ def _make_latent_diffusion_conf(self, T=None):
246
+ # can use T < self.T for evaluation
247
+ # follows the guided-diffusion repo conventions
248
+ # t's are evenly spaced
249
+ if self.latent_gen_type == GenerativeType.ddpm:
250
+ section_counts = [T]
251
+ elif self.latent_gen_type == GenerativeType.ddim:
252
+ section_counts = f'ddim{T}'
253
+ else:
254
+ raise NotImplementedError()
255
+
256
+ return SpacedDiffusionBeatGansConfig(
257
+ train_pred_xstart_detach=self.train_pred_xstart_detach,
258
+ gen_type=self.latent_gen_type,
259
+ # latent's model is always ddpm
260
+ model_type=ModelType.ddpm,
261
+ # latent shares the beta scheduler and full T
262
+ betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
263
+ model_mean_type=self.latent_model_mean_type,
264
+ model_var_type=self.latent_model_var_type,
265
+ loss_type=self.latent_loss_type,
266
+ rescale_timesteps=self.latent_rescale_timesteps,
267
+ use_timesteps=space_timesteps(num_timesteps=self.T,
268
+ section_counts=section_counts),
269
+ fp16=self.fp16,
270
+ )
271
+
272
+ @property
273
+ def model_out_channels(self):
274
+ return self.out_channels
275
+
276
+ def make_T_sampler(self):
277
+ if self.T_sampler == 'uniform':
278
+ return UniformSampler(self.T)
279
+ else:
280
+ raise NotImplementedError()
281
+
282
+ def make_diffusion_conf(self):
283
+ return self._make_diffusion_conf(self.T)
284
+
285
+ def make_eval_diffusion_conf(self):
286
+ return self._make_diffusion_conf(T=self.T_eval)
287
+
288
+ def make_latent_diffusion_conf(self):
289
+ return self._make_latent_diffusion_conf(T=self.T)
290
+
291
+ def make_latent_eval_diffusion_conf(self):
292
+ # latent can have different eval T
293
+ return self._make_latent_diffusion_conf(T=self.latent_T_eval)
294
+
295
+ def make_dataset(self, path=None, **kwargs):
296
+ if self.data_name == 'ffhqlmdb256':
297
+ return FFHQlmdb(path=path or self.data_path,
298
+ image_size=self.img_size,
299
+ **kwargs)
300
+ elif self.data_name == 'horse256':
301
+ return Horse_lmdb(path=path or self.data_path,
302
+ image_size=self.img_size,
303
+ **kwargs)
304
+ elif self.data_name == 'bedroom256':
305
+ return Horse_lmdb(path=path or self.data_path,
306
+ image_size=self.img_size,
307
+ **kwargs)
308
+ elif self.data_name == 'celebalmdb':
309
+ # always use d2c crop
310
+ return CelebAlmdb(path=path or self.data_path,
311
+ image_size=self.img_size,
312
+ original_resolution=None,
313
+ crop_d2c=True,
314
+ **kwargs)
315
+ else:
316
+ raise NotImplementedError()
317
+
318
+ def make_loader(self,
319
+ dataset,
320
+ shuffle: bool,
321
+ num_worker: bool = None,
322
+ drop_last: bool = True,
323
+ batch_size: int = None,
324
+ parallel: bool = False):
325
+ if parallel and distributed.is_initialized():
326
+ # drop last to make sure that there is no added special indexes
327
+ sampler = DistributedSampler(dataset,
328
+ shuffle=shuffle,
329
+ drop_last=True)
330
+ else:
331
+ sampler = None
332
+ return DataLoader(
333
+ dataset,
334
+ batch_size=batch_size or self.batch_size,
335
+ sampler=sampler,
336
+ # with sampler, use the sample instead of this option
337
+ shuffle=False if sampler else shuffle,
338
+ num_workers=num_worker or self.num_workers,
339
+ pin_memory=True,
340
+ drop_last=drop_last,
341
+ multiprocessing_context=get_context('fork'),
342
+ )
343
+
344
+ def make_model_conf(self):
345
+ if self.model_name == ModelName.beatgans_ddpm:
346
+ self.model_type = ModelType.ddpm
347
+ self.model_conf = BeatGANsUNetConfig(
348
+ attention_resolutions=self.net_attn,
349
+ channel_mult=self.net_ch_mult,
350
+ conv_resample=True,
351
+ group_norm_limit=self.group_norm_limit,
352
+ dims=self.n_dims,
353
+ dropout=self.dropout,
354
+ embed_channels=self.net_beatgans_embed_channels,
355
+ image_size=self.img_size,
356
+ in_channels=self.in_channels,
357
+ model_channels=self.net_ch,
358
+ num_classes=None,
359
+ num_head_channels=-1,
360
+ num_heads_upsample=-1,
361
+ num_heads=self.net_beatgans_attn_head,
362
+ num_res_blocks=self.net_num_res_blocks,
363
+ num_input_res_blocks=self.net_num_input_res_blocks,
364
+ out_channels=self.model_out_channels,
365
+ resblock_updown=self.net_resblock_updown,
366
+ use_checkpoint=self.net_beatgans_gradient_checkpoint,
367
+ use_new_attention_order=False,
368
+ resnet_two_cond=self.net_beatgans_resnet_two_cond,
369
+ resnet_use_zero_module=self.
370
+ net_beatgans_resnet_use_zero_module,
371
+ )
372
+ elif self.model_name in [
373
+ ModelName.beatgans_autoenc,
374
+ ]:
375
+ cls = BeatGANsAutoencConfig
376
+ # supports both autoenc and vaeddpm
377
+ if self.model_name == ModelName.beatgans_autoenc:
378
+ self.model_type = ModelType.autoencoder
379
+ else:
380
+ raise NotImplementedError()
381
+
382
+ if self.net_latent_net_type == LatentNetType.none:
383
+ latent_net_conf = None
384
+ elif self.net_latent_net_type == LatentNetType.skip:
385
+ latent_net_conf = MLPSkipNetConfig(
386
+ num_channels=self.style_ch,
387
+ skip_layers=self.net_latent_skip_layers,
388
+ num_hid_channels=self.net_latent_num_hid_channels,
389
+ num_layers=self.net_latent_layers,
390
+ num_time_emb_channels=self.net_latent_time_emb_channels,
391
+ activation=self.net_latent_activation,
392
+ use_norm=self.net_latent_use_norm,
393
+ condition_bias=self.net_latent_condition_bias,
394
+ dropout=self.net_latent_dropout,
395
+ last_act=self.net_latent_net_last_act,
396
+ num_time_layers=self.net_latent_num_time_layers,
397
+ time_last_act=self.net_latent_time_last_act,
398
+ )
399
+ else:
400
+ raise NotImplementedError()
401
+
402
+ self.model_conf = cls(
403
+ attention_resolutions=self.net_attn,
404
+ channel_mult=self.net_ch_mult,
405
+ conv_resample=True,
406
+ group_norm_limit=self.group_norm_limit,
407
+ dims=self.n_dims,
408
+ dropout=self.dropout,
409
+ embed_channels=self.net_beatgans_embed_channels,
410
+ enc_out_channels=self.style_ch,
411
+ enc_pool=self.net_enc_pool,
412
+ enc_num_res_block=self.net_enc_num_res_blocks,
413
+ enc_channel_mult=self.net_enc_channel_mult,
414
+ enc_grad_checkpoint=self.net_enc_grad_checkpoint,
415
+ enc_attn_resolutions=self.net_enc_attn,
416
+ image_size=self.img_size,
417
+ in_channels=self.in_channels,
418
+ model_channels=self.net_ch,
419
+ num_classes=None,
420
+ num_head_channels=-1,
421
+ num_heads_upsample=-1,
422
+ num_heads=self.net_beatgans_attn_head,
423
+ num_res_blocks=self.net_num_res_blocks,
424
+ num_input_res_blocks=self.net_num_input_res_blocks,
425
+ out_channels=self.model_out_channels,
426
+ resblock_updown=self.net_resblock_updown,
427
+ use_checkpoint=self.net_beatgans_gradient_checkpoint,
428
+ use_new_attention_order=False,
429
+ resnet_two_cond=self.net_beatgans_resnet_two_cond,
430
+ resnet_use_zero_module=self.
431
+ net_beatgans_resnet_use_zero_module,
432
+ latent_net_conf=latent_net_conf,
433
+ resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
434
+ )
435
+ else:
436
+ raise NotImplementedError(self.model_name)
437
+
438
+ return self.model_conf
DiffAE_support_config_base.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass
5
+
6
+
7
+ @dataclass
8
+ class BaseConfig:
9
+ def clone(self):
10
+ return deepcopy(self)
11
+
12
+ def inherit(self, another):
13
+ """inherit common keys from a given config"""
14
+ common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
15
+ for k in common_keys:
16
+ setattr(self, k, getattr(another, k))
17
+
18
+ def propagate(self):
19
+ """push down the configuration to all members"""
20
+ for k, v in self.__dict__.items():
21
+ if isinstance(v, BaseConfig):
22
+ v.inherit(self)
23
+ v.propagate()
24
+
25
+ def save(self, save_path):
26
+ """save config to json file"""
27
+ dirname = os.path.dirname(save_path)
28
+ if not os.path.exists(dirname):
29
+ os.makedirs(dirname)
30
+ conf = self.as_dict_jsonable()
31
+ with open(save_path, 'w') as f:
32
+ json.dump(conf, f)
33
+
34
+ def load(self, load_path):
35
+ """load json config"""
36
+ with open(load_path) as f:
37
+ conf = json.load(f)
38
+ self.from_dict(conf)
39
+
40
+ def from_dict(self, dict, strict=False):
41
+ for k, v in dict.items():
42
+ if not hasattr(self, k):
43
+ if strict:
44
+ raise ValueError(f"loading extra '{k}'")
45
+ else:
46
+ print(f"loading extra '{k}'")
47
+ continue
48
+ if isinstance(self.__dict__[k], BaseConfig):
49
+ self.__dict__[k].from_dict(v)
50
+ else:
51
+ self.__dict__[k] = v
52
+
53
+ def as_dict_jsonable(self):
54
+ conf = {}
55
+ for k, v in self.__dict__.items():
56
+ if isinstance(v, BaseConfig):
57
+ conf[k] = v.as_dict_jsonable()
58
+ else:
59
+ if jsonable(v):
60
+ conf[k] = v
61
+ else:
62
+ # ignore not jsonable
63
+ pass
64
+ return conf
65
+
66
+
67
+ def jsonable(x):
68
+ try:
69
+ json.dumps(x)
70
+ return True
71
+ except TypeError:
72
+ return False
DiffAE_support_dist_utils.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from torch import distributed
3
+
4
+
5
+ def barrier():
6
+ if distributed.is_initialized():
7
+ distributed.barrier()
8
+ else:
9
+ pass
10
+
11
+
12
+ def broadcast(data, src):
13
+ if distributed.is_initialized():
14
+ distributed.broadcast(data, src)
15
+ else:
16
+ pass
17
+
18
+
19
+ def all_gather(data: List, src):
20
+ if distributed.is_initialized():
21
+ distributed.all_gather(data, src)
22
+ else:
23
+ data[0] = src
24
+
25
+
26
+ def get_rank():
27
+ if distributed.is_initialized():
28
+ return distributed.get_rank()
29
+ else:
30
+ return 0
31
+
32
+
33
+ def get_world_size():
34
+ if distributed.is_initialized():
35
+ return distributed.get_world_size()
36
+ else:
37
+ return 1
38
+
39
+
40
+ def chunk_size(size, rank, world_size):
41
+ extra = rank < size % world_size
42
+ return size // world_size + extra
DiffAE_support_metrics.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+
4
+ import torch
5
+ import torchvision
6
+ from pytorch_fid import fid_score
7
+ from torch import distributed
8
+ from torch.utils.data import DataLoader
9
+ from torch.utils.data.distributed import DistributedSampler
10
+ from tqdm.autonotebook import tqdm, trange
11
+
12
+ from .DiffAE_support_renderer import *
13
+ from .DiffAE_support_config import *
14
+ from .DiffAE_diffusion_diffusion import SpacedDiffusionBeatGans as Sampler
15
+ import lpips
16
+ from ssim import compute_ssim as ssim
17
+
18
+
19
+ def make_subset_loader(conf: TrainConfig,
20
+ dataset,
21
+ batch_size: int,
22
+ shuffle: bool,
23
+ parallel: bool,
24
+ drop_last=True):
25
+ dataset = SubsetDataset(dataset, size=conf.eval_num_images)
26
+ if parallel and distributed.is_initialized():
27
+ sampler = DistributedSampler(dataset, shuffle=shuffle)
28
+ else:
29
+ sampler = None
30
+ return DataLoader(
31
+ dataset,
32
+ batch_size=batch_size,
33
+ sampler=sampler,
34
+ # with sampler, use the sample instead of this option
35
+ shuffle=False if sampler else shuffle,
36
+ num_workers=conf.num_workers,
37
+ pin_memory=True,
38
+ drop_last=drop_last,
39
+ multiprocessing_context=get_context('fork'),
40
+ )
41
+
42
+
43
+ def evaluate_lpips(
44
+ sampler: Sampler,
45
+ model: Model,
46
+ conf: TrainConfig,
47
+ device,
48
+ val_data,
49
+ latent_sampler: Sampler = None,
50
+ use_inverted_noise: bool = False,
51
+ ):
52
+ """
53
+ compare the generated images from autoencoder on validation dataset
54
+
55
+ Args:
56
+ use_inversed_noise: the noise is also inverted from DDIM
57
+ """
58
+ lpips_fn = lpips.LPIPS(net='alex').to(device)
59
+ val_loader = make_subset_loader(conf,
60
+ dataset=val_data,
61
+ batch_size=conf.batch_size_eval,
62
+ shuffle=False,
63
+ parallel=True)
64
+
65
+ model.eval()
66
+ with torch.no_grad():
67
+ scores = {
68
+ 'lpips': [],
69
+ 'mse': [],
70
+ 'ssim': [],
71
+ 'psnr': [],
72
+ }
73
+ for batch in tqdm(val_loader, desc='lpips'):
74
+ imgs = batch['img'].to(device)
75
+
76
+ if use_inverted_noise:
77
+ # inverse the noise
78
+ # with condition from the encoder
79
+ model_kwargs = {}
80
+ if conf.model_type.has_autoenc():
81
+ with torch.no_grad():
82
+ model_kwargs = model.encode(imgs)
83
+ x_T = sampler.ddim_reverse_sample_loop(
84
+ model=model,
85
+ x=imgs,
86
+ clip_denoised=True,
87
+ model_kwargs=model_kwargs)
88
+ x_T = x_T['sample']
89
+ else:
90
+ x_T = torch.randn((len(imgs), 3, conf.img_size, conf.img_size),
91
+ device=device)
92
+
93
+ if conf.model_type == ModelType.ddpm:
94
+ # the case where you want to calculate the inversion capability of the DDIM model
95
+ assert use_inverted_noise
96
+ pred_imgs = render_uncondition(
97
+ conf=conf,
98
+ model=model,
99
+ x_T=x_T,
100
+ sampler=sampler,
101
+ latent_sampler=latent_sampler,
102
+ )
103
+ else:
104
+ pred_imgs = render_condition(conf=conf,
105
+ model=model,
106
+ x_T=x_T,
107
+ x_start=imgs,
108
+ cond=None,
109
+ sampler=sampler)
110
+ # # returns {'cond', 'cond2'}
111
+ # conds = model.encode(imgs)
112
+ # pred_imgs = sampler.sample(model=model,
113
+ # noise=x_T,
114
+ # model_kwargs=conds)
115
+
116
+ # (n, 1, 1, 1) => (n, )
117
+ scores['lpips'].append(lpips_fn.forward(imgs, pred_imgs).view(-1))
118
+
119
+ # need to normalize into [0, 1]
120
+ norm_imgs = (imgs + 1) / 2
121
+ norm_pred_imgs = (pred_imgs + 1) / 2
122
+ # (n, )
123
+ scores['ssim'].append(
124
+ ssim(norm_imgs, norm_pred_imgs, size_average=False))
125
+ # (n, )
126
+ scores['mse'].append(
127
+ (norm_imgs - norm_pred_imgs).pow(2).mean(dim=[1, 2, 3]))
128
+ # (n, )
129
+ scores['psnr'].append(psnr(norm_imgs, norm_pred_imgs))
130
+ # (N, )
131
+ for key in scores.keys():
132
+ scores[key] = torch.cat(scores[key]).float()
133
+ model.train()
134
+
135
+ barrier()
136
+
137
+ # support multi-gpu
138
+ outs = {
139
+ key: [
140
+ torch.zeros(len(scores[key]), device=device)
141
+ for i in range(get_world_size())
142
+ ]
143
+ for key in scores.keys()
144
+ }
145
+ for key in scores.keys():
146
+ all_gather(outs[key], scores[key])
147
+
148
+ # final scores
149
+ for key in scores.keys():
150
+ scores[key] = torch.cat(outs[key]).mean().item()
151
+
152
+ # {'lpips', 'mse', 'ssim'}
153
+ return scores
154
+
155
+
156
+ def psnr(img1, img2):
157
+ """
158
+ Args:
159
+ img1: (n, c, h, w)
160
+ """
161
+ v_max = 1.
162
+ # (n,)
163
+ mse = torch.mean((img1 - img2)**2, dim=[1, 2, 3])
164
+ return 20 * torch.log10(v_max / torch.sqrt(mse))
165
+
166
+
167
+ def evaluate_fid(
168
+ sampler: Sampler,
169
+ model: Model,
170
+ conf: TrainConfig,
171
+ device,
172
+ train_data,
173
+ val_data,
174
+ latent_sampler: Sampler = None,
175
+ conds_mean=None,
176
+ conds_std=None,
177
+ remove_cache: bool = True,
178
+ clip_latent_noise: bool = False,
179
+ ):
180
+ assert conf.fid_cache is not None
181
+ if get_rank() == 0:
182
+ # no parallel
183
+ # validation data for a comparing FID
184
+ val_loader = make_subset_loader(conf,
185
+ dataset=val_data,
186
+ batch_size=conf.batch_size_eval,
187
+ shuffle=False,
188
+ parallel=False)
189
+
190
+ # put the val images to a directory
191
+ cache_dir = f'{conf.fid_cache}_{conf.eval_num_images}'
192
+ if (os.path.exists(cache_dir)
193
+ and len(os.listdir(cache_dir)) < conf.eval_num_images):
194
+ shutil.rmtree(cache_dir)
195
+
196
+ if not os.path.exists(cache_dir):
197
+ # write files to the cache
198
+ # the images are normalized, hence need to denormalize first
199
+ loader_to_path(val_loader, cache_dir, denormalize=True)
200
+
201
+ # create the generate dir
202
+ if os.path.exists(conf.generate_dir):
203
+ shutil.rmtree(conf.generate_dir)
204
+ os.makedirs(conf.generate_dir)
205
+
206
+ barrier()
207
+
208
+ world_size = get_world_size()
209
+ rank = get_rank()
210
+ batch_size = chunk_size(conf.batch_size_eval, rank, world_size)
211
+
212
+ def filename(idx):
213
+ return world_size * idx + rank
214
+
215
+ model.eval()
216
+ with torch.no_grad():
217
+ if conf.model_type.can_sample():
218
+ eval_num_images = chunk_size(conf.eval_num_images, rank,
219
+ world_size)
220
+ desc = "generating images"
221
+ for i in trange(0, eval_num_images, batch_size, desc=desc):
222
+ batch_size = min(batch_size, eval_num_images - i)
223
+ x_T = torch.randn(
224
+ (batch_size, 3, conf.img_size, conf.img_size),
225
+ device=device)
226
+ batch_images = render_uncondition(
227
+ conf=conf,
228
+ model=model,
229
+ x_T=x_T,
230
+ sampler=sampler,
231
+ latent_sampler=latent_sampler,
232
+ conds_mean=conds_mean,
233
+ conds_std=conds_std).cpu()
234
+
235
+ batch_images = (batch_images + 1) / 2
236
+ # keep the generated images
237
+ for j in range(len(batch_images)):
238
+ img_name = filename(i + j)
239
+ torchvision.utils.save_image(
240
+ batch_images[j],
241
+ os.path.join(conf.generate_dir, f'{img_name}.png'))
242
+ elif conf.model_type == ModelType.autoencoder:
243
+ if conf.train_mode.is_latent_diffusion():
244
+ # evaluate autoencoder + latent diffusion (doesn't give the images)
245
+ model: BeatGANsAutoencModel
246
+ eval_num_images = chunk_size(conf.eval_num_images, rank,
247
+ world_size)
248
+ desc = "generating images"
249
+ for i in trange(0, eval_num_images, batch_size, desc=desc):
250
+ batch_size = min(batch_size, eval_num_images - i)
251
+ x_T = torch.randn(
252
+ (batch_size, 3, conf.img_size, conf.img_size),
253
+ device=device)
254
+ batch_images = render_uncondition(
255
+ conf=conf,
256
+ model=model,
257
+ x_T=x_T,
258
+ sampler=sampler,
259
+ latent_sampler=latent_sampler,
260
+ conds_mean=conds_mean,
261
+ conds_std=conds_std,
262
+ clip_latent_noise=clip_latent_noise,
263
+ ).cpu()
264
+ batch_images = (batch_images + 1) / 2
265
+ # keep the generated images
266
+ for j in range(len(batch_images)):
267
+ img_name = filename(i + j)
268
+ torchvision.utils.save_image(
269
+ batch_images[j],
270
+ os.path.join(conf.generate_dir, f'{img_name}.png'))
271
+ else:
272
+ # evaulate autoencoder (given the images)
273
+ # to make the FID fair, autoencoder must not see the validation dataset
274
+ # also shuffle to make it closer to unconditional generation
275
+ train_loader = make_subset_loader(conf,
276
+ dataset=train_data,
277
+ batch_size=batch_size,
278
+ shuffle=True,
279
+ parallel=True)
280
+
281
+ i = 0
282
+ for batch in tqdm(train_loader, desc='generating images'):
283
+ imgs = batch['img'].to(device)
284
+ x_T = torch.randn(
285
+ (len(imgs), 3, conf.img_size, conf.img_size),
286
+ device=device)
287
+ batch_images = render_condition(
288
+ conf=conf,
289
+ model=model,
290
+ x_T=x_T,
291
+ x_start=imgs,
292
+ cond=None,
293
+ sampler=sampler,
294
+ latent_sampler=latent_sampler).cpu()
295
+ # model: BeatGANsAutoencModel
296
+ # # returns {'cond', 'cond2'}
297
+ # conds = model.encode(imgs)
298
+ # batch_images = sampler.sample(model=model,
299
+ # noise=x_T,
300
+ # model_kwargs=conds).cpu()
301
+ # denormalize the images
302
+ batch_images = (batch_images + 1) / 2
303
+ # keep the generated images
304
+ for j in range(len(batch_images)):
305
+ img_name = filename(i + j)
306
+ torchvision.utils.save_image(
307
+ batch_images[j],
308
+ os.path.join(conf.generate_dir, f'{img_name}.png'))
309
+ i += len(imgs)
310
+ else:
311
+ raise NotImplementedError()
312
+ model.train()
313
+
314
+ barrier()
315
+
316
+ if get_rank() == 0:
317
+ fid = fid_score.calculate_fid_given_paths(
318
+ [cache_dir, conf.generate_dir],
319
+ batch_size,
320
+ device=device,
321
+ dims=2048)
322
+
323
+ # remove the cache
324
+ if remove_cache and os.path.exists(conf.generate_dir):
325
+ shutil.rmtree(conf.generate_dir)
326
+
327
+ barrier()
328
+
329
+ if get_rank() == 0:
330
+ # need to float it! unless the broadcasted value is wrong
331
+ fid = torch.tensor(float(fid), device=device)
332
+ broadcast(fid, 0)
333
+ else:
334
+ fid = torch.tensor(0., device=device)
335
+ broadcast(fid, 0)
336
+ fid = fid.item()
337
+ print(f'fid ({get_rank()}):', fid)
338
+
339
+ return fid
340
+
341
+
342
+ def loader_to_path(loader: DataLoader, path: str, denormalize: bool):
343
+ # not process safe!
344
+
345
+ if not os.path.exists(path):
346
+ os.makedirs(path)
347
+
348
+ # write the loader to files
349
+ i = 0
350
+ for batch in tqdm(loader, desc='copy images'):
351
+ imgs = batch['img']
352
+ if denormalize:
353
+ imgs = (imgs + 1) / 2
354
+ for j in range(len(imgs)):
355
+ torchvision.utils.save_image(imgs[j],
356
+ os.path.join(path, f'{i+j}.png'))
357
+ i += len(imgs)
DiffAE_support_renderer.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .DiffAE_support_config import *
2
+ from .DiffAE_diffusion_diffusion import SpacedDiffusionBeatGans as Sampler
3
+ from .DiffAE_model_unet_autoenc import BeatGANsAutoencModel
4
+
5
+ from torch.cuda import amp
6
+
7
+
8
+ def render_uncondition(conf: TrainConfig,
9
+ model: BeatGANsAutoencModel,
10
+ x_T,
11
+ sampler: Sampler,
12
+ latent_sampler: Sampler,
13
+ conds_mean=None,
14
+ conds_std=None,
15
+ clip_latent_noise: bool = False):
16
+ device = x_T.device
17
+ if conf.train_mode == TrainMode.diffusion:
18
+ assert conf.model_type.can_sample()
19
+ return sampler.sample(model=model, noise=x_T)
20
+ elif conf.train_mode.is_latent_diffusion():
21
+ model: BeatGANsAutoencModel
22
+ if conf.train_mode == TrainMode.latent_diffusion:
23
+ latent_noise = torch.randn(len(x_T), conf.style_ch, device=device)
24
+ else:
25
+ raise NotImplementedError()
26
+
27
+ if clip_latent_noise:
28
+ latent_noise = latent_noise.clip(-1, 1)
29
+
30
+ cond = latent_sampler.sample(
31
+ model=model.latent_net,
32
+ noise=latent_noise,
33
+ clip_denoised=conf.latent_clip_sample,
34
+ )
35
+
36
+ if conf.latent_znormalize:
37
+ cond = cond * conds_std.to(device) + conds_mean.to(device)
38
+
39
+ # the diffusion on the model
40
+ return sampler.sample(model=model, noise=x_T, cond=cond)
41
+ else:
42
+ raise NotImplementedError()
43
+
44
+
45
+ def render_condition(
46
+ conf: TrainConfig,
47
+ model: BeatGANsAutoencModel,
48
+ x_T,
49
+ sampler: Sampler,
50
+ x_start=None,
51
+ cond=None,
52
+ ):
53
+ if conf.train_mode == TrainMode.diffusion:
54
+ assert conf.model_type.has_autoenc()
55
+ # returns {'cond', 'cond2'}
56
+ if cond is None:
57
+ cond = model.encode(x_start)
58
+ return sampler.sample(model=model,
59
+ noise=x_T,
60
+ model_kwargs={'cond': cond})
61
+ else:
62
+ raise NotImplementedError()
DiffAE_support_templates.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .DiffAE_support_config import *
2
+
3
+
4
+ def ddpm():
5
+ """
6
+ base configuration for all DDIM-based models.
7
+ """
8
+ conf = TrainConfig()
9
+ conf.batch_size = 32
10
+ conf.beatgans_gen_type = GenerativeType.ddim
11
+ conf.beta_scheduler = 'linear'
12
+ conf.data_name = 'ffhq'
13
+ conf.diffusion_type = 'beatgans'
14
+ conf.eval_ema_every_samples = 200_000
15
+ conf.eval_every_samples = 200_000
16
+ conf.fp16 = True
17
+ conf.lr = 1e-4
18
+ conf.model_name = ModelName.beatgans_ddpm
19
+ conf.net_attn = (16, )
20
+ conf.net_beatgans_attn_head = 1
21
+ conf.net_beatgans_embed_channels = 512
22
+ conf.net_ch_mult = (1, 2, 4, 8)
23
+ conf.net_ch = 64
24
+ conf.sample_size = 32
25
+ conf.T_eval = 20
26
+ conf.T = 1000
27
+ conf.make_model_conf()
28
+ return conf
29
+
30
+
31
+ def autoenc_base():
32
+ """
33
+ base configuration for all Diff-AE models.
34
+ """
35
+ conf = TrainConfig()
36
+ conf.batch_size = 32
37
+ conf.beatgans_gen_type = GenerativeType.ddim
38
+ conf.beta_scheduler = 'linear'
39
+ conf.data_name = 'ffhq'
40
+ conf.diffusion_type = 'beatgans'
41
+ conf.eval_ema_every_samples = 200_000
42
+ conf.eval_every_samples = 200_000
43
+ conf.fp16 = True
44
+ conf.lr = 1e-4
45
+ conf.model_name = ModelName.beatgans_autoenc
46
+ conf.net_attn = (16, )
47
+ conf.net_beatgans_attn_head = 1
48
+ conf.net_beatgans_embed_channels = 512
49
+ conf.net_beatgans_resnet_two_cond = True
50
+ conf.net_ch_mult = (1, 2, 4, 8)
51
+ conf.net_ch = 64
52
+ conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
53
+ conf.net_enc_pool = 'adaptivenonzero'
54
+ conf.sample_size = 32
55
+ conf.T_eval = 20
56
+ conf.T = 1000
57
+ conf.make_model_conf()
58
+ return conf
59
+
60
+ def ffhq64_ddpm():
61
+ conf = ddpm()
62
+ conf.data_name = 'ffhqlmdb256'
63
+ conf.warmup = 0
64
+ conf.total_samples = 72_000_000
65
+ conf.scale_up_gpus(4)
66
+ return conf
67
+
68
+
69
+ def ffhq64_autoenc():
70
+ conf = autoenc_base()
71
+ conf.data_name = 'ffhqlmdb256'
72
+ conf.warmup = 0
73
+ conf.total_samples = 72_000_000
74
+ conf.net_ch_mult = (1, 2, 4, 8)
75
+ conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
76
+ conf.eval_every_samples = 1_000_000
77
+ conf.eval_ema_every_samples = 1_000_000
78
+ conf.scale_up_gpus(4)
79
+ conf.make_model_conf()
80
+ return conf
81
+
82
+
83
+ def celeba64d2c_ddpm():
84
+ conf = ffhq128_ddpm()
85
+ conf.data_name = 'celebalmdb'
86
+ conf.eval_every_samples = 10_000_000
87
+ conf.eval_ema_every_samples = 10_000_000
88
+ conf.total_samples = 72_000_000
89
+ conf.name = 'celeba64d2c_ddpm'
90
+ return conf
91
+
92
+
93
+ def celeba64d2c_autoenc():
94
+ conf = ffhq64_autoenc()
95
+ conf.data_name = 'celebalmdb'
96
+ conf.eval_every_samples = 10_000_000
97
+ conf.eval_ema_every_samples = 10_000_000
98
+ conf.total_samples = 72_000_000
99
+ conf.name = 'celeba64d2c_autoenc'
100
+ return conf
101
+
102
+
103
+ def ffhq128_ddpm():
104
+ conf = ddpm()
105
+ conf.data_name = 'ffhqlmdb256'
106
+ conf.warmup = 0
107
+ conf.total_samples = 48_000_000
108
+ conf.img_size = 128
109
+ conf.net_ch = 128
110
+ # channels:
111
+ # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4
112
+ # sizes:
113
+ # 128 => 128 => 64 => 32 => 16 => 8
114
+ conf.net_ch_mult = (1, 1, 2, 3, 4)
115
+ conf.eval_every_samples = 1_000_000
116
+ conf.eval_ema_every_samples = 1_000_000
117
+ conf.scale_up_gpus(4)
118
+ conf.eval_ema_every_samples = 10_000_000
119
+ conf.eval_every_samples = 10_000_000
120
+ conf.make_model_conf()
121
+ return conf
122
+
123
+
124
+ def ffhq128_autoenc_base():
125
+ conf = autoenc_base()
126
+ conf.data_name = 'ffhqlmdb256'
127
+ conf.scale_up_gpus(4)
128
+ conf.img_size = 128
129
+ conf.net_ch = 128
130
+ # final resolution = 8x8
131
+ conf.net_ch_mult = (1, 1, 2, 3, 4)
132
+ # final resolution = 4x4
133
+ conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)
134
+ conf.eval_ema_every_samples = 10_000_000
135
+ conf.eval_every_samples = 10_000_000
136
+ conf.make_model_conf()
137
+ return conf
138
+
139
+ def ffhq256_autoenc():
140
+ conf = ffhq128_autoenc_base()
141
+ conf.img_size = 256
142
+ conf.net_ch = 128
143
+ conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
144
+ conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
145
+ conf.eval_every_samples = 10_000_000
146
+ conf.eval_ema_every_samples = 10_000_000
147
+ conf.total_samples = 200_000_000
148
+ conf.batch_size = 64
149
+ conf.make_model_conf()
150
+ conf.name = 'ffhq256_autoenc'
151
+ return conf
152
+
153
+
154
+ def ffhq256_autoenc_eco():
155
+ conf = ffhq128_autoenc_base()
156
+ conf.img_size = 256
157
+ conf.net_ch = 128
158
+ conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
159
+ conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
160
+ conf.eval_every_samples = 10_000_000
161
+ conf.eval_ema_every_samples = 10_000_000
162
+ conf.total_samples = 200_000_000
163
+ conf.batch_size = 64
164
+ conf.make_model_conf()
165
+ conf.name = 'ffhq256_autoenc_eco'
166
+ return conf
167
+
168
+
169
+ def ffhq128_ddpm_72M():
170
+ conf = ffhq128_ddpm()
171
+ conf.total_samples = 72_000_000
172
+ conf.name = 'ffhq128_ddpm_72M'
173
+ return conf
174
+
175
+
176
+ def ffhq128_autoenc_72M():
177
+ conf = ffhq128_autoenc_base()
178
+ conf.total_samples = 72_000_000
179
+ conf.name = 'ffhq128_autoenc_72M'
180
+ return conf
181
+
182
+
183
+ def ffhq128_ddpm_130M():
184
+ conf = ffhq128_ddpm()
185
+ conf.total_samples = 130_000_000
186
+ conf.eval_ema_every_samples = 10_000_000
187
+ conf.eval_every_samples = 10_000_000
188
+ conf.name = 'ffhq128_ddpm_130M'
189
+ return conf
190
+
191
+
192
+ def ffhq128_autoenc_130M():
193
+ conf = ffhq128_autoenc_base()
194
+ conf.total_samples = 130_000_000
195
+ conf.eval_ema_every_samples = 10_000_000
196
+ conf.eval_every_samples = 10_000_000
197
+ conf.name = 'ffhq128_autoenc_130M'
198
+ return conf
199
+
200
+ #created from ffhq128_autoenc_130M
201
+ def ukbb_autoenc(ds_name="ukbb", n_latents=128):
202
+ conf = TrainConfig()
203
+ conf.beatgans_gen_type = GenerativeType.ddim
204
+ conf.beta_scheduler = 'linear'
205
+ conf.diffusion_type = 'beatgans'
206
+ conf.fp16 = True
207
+ conf.model_name = ModelName.beatgans_autoenc
208
+ conf.net_attn = (16, )
209
+ conf.net_beatgans_attn_head = 1
210
+ conf.net_beatgans_embed_channels = n_latents
211
+ conf.style_ch = n_latents
212
+ conf.net_beatgans_resnet_two_cond = True
213
+ conf.net_enc_pool = 'adaptivenonzero'
214
+ conf.sample_size = 32
215
+ conf.T_eval = 20
216
+ conf.T = 1000
217
+
218
+ conf.T_inv = 200
219
+ conf.T_step = 100
220
+
221
+ conf.data_name = ds_name
222
+ conf.net_ch_mult = (1, 1, 2, 3, 4)
223
+ conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)
224
+
225
+ conf.name = 'ukbb_ffhq128_autoenc'
226
+ return conf
227
+
228
+
229
+ def horse128_ddpm():
230
+ conf = ffhq128_ddpm()
231
+ conf.data_name = 'horse256'
232
+ conf.total_samples = 130_000_000
233
+ conf.eval_ema_every_samples = 10_000_000
234
+ conf.eval_every_samples = 10_000_000
235
+ conf.name = 'horse128_ddpm'
236
+ return conf
237
+
238
+
239
+ def horse128_autoenc():
240
+ conf = ffhq128_autoenc_base()
241
+ conf.data_name = 'horse256'
242
+ conf.total_samples = 130_000_000
243
+ conf.eval_ema_every_samples = 10_000_000
244
+ conf.eval_every_samples = 10_000_000
245
+ conf.name = 'horse128_autoenc'
246
+ return conf
247
+
248
+
249
+ def bedroom128_ddpm():
250
+ conf = ffhq128_ddpm()
251
+ conf.data_name = 'bedroom256'
252
+ conf.eval_ema_every_samples = 10_000_000
253
+ conf.eval_every_samples = 10_000_000
254
+ conf.total_samples = 120_000_000
255
+ conf.name = 'bedroom128_ddpm'
256
+ return conf
257
+
258
+
259
+ def bedroom128_autoenc():
260
+ conf = ffhq128_autoenc_base()
261
+ conf.data_name = 'bedroom256'
262
+ conf.eval_ema_every_samples = 10_000_000
263
+ conf.eval_every_samples = 10_000_000
264
+ conf.total_samples = 120_000_000
265
+ conf.name = 'bedroom128_autoenc'
266
+ return conf
267
+
268
+
269
+ def pretrain_celeba64d2c_72M():
270
+ conf = celeba64d2c_autoenc()
271
+ conf.pretrain = PretrainConfig(
272
+ name='72M',
273
+ path=f'checkpoints/{celeba64d2c_autoenc().name}/last.ckpt',
274
+ )
275
+ conf.latent_infer_path = f'checkpoints/{celeba64d2c_autoenc().name}/latent.pkl'
276
+ return conf
277
+
278
+
279
+ def pretrain_ffhq128_autoenc72M():
280
+ conf = ffhq128_autoenc_base()
281
+ conf.postfix = ''
282
+ conf.pretrain = PretrainConfig(
283
+ name='72M',
284
+ path=f'checkpoints/{ffhq128_autoenc_72M().name}/last.ckpt',
285
+ )
286
+ conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_72M().name}/latent.pkl'
287
+ return conf
288
+
289
+
290
+ def pretrain_ffhq128_autoenc130M():
291
+ conf = ffhq128_autoenc_base()
292
+ conf.pretrain = PretrainConfig(
293
+ name='130M',
294
+ path=f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
295
+ )
296
+ conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
297
+ return conf
298
+
299
+
300
+ def pretrain_ffhq256_autoenc():
301
+ conf = ffhq256_autoenc()
302
+ conf.pretrain = PretrainConfig(
303
+ name='90M',
304
+ path=f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
305
+ )
306
+ conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl'
307
+ return conf
308
+
309
+
310
+ def pretrain_horse128():
311
+ conf = horse128_autoenc()
312
+ conf.pretrain = PretrainConfig(
313
+ name='82M',
314
+ path=f'checkpoints/{horse128_autoenc().name}/last.ckpt',
315
+ )
316
+ conf.latent_infer_path = f'checkpoints/{horse128_autoenc().name}/latent.pkl'
317
+ return conf
318
+
319
+
320
+ def pretrain_bedroom128():
321
+ conf = bedroom128_autoenc()
322
+ conf.pretrain = PretrainConfig(
323
+ name='120M',
324
+ path=f'checkpoints/{bedroom128_autoenc().name}/last.ckpt',
325
+ )
326
+ conf.latent_infer_path = f'checkpoints/{bedroom128_autoenc().name}/latent.pkl'
327
+ return conf
DiffAE_support_templates_latent.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .DiffAE_support_templates import *
2
+
3
+
4
+ def latent_diffusion_config(conf: TrainConfig):
5
+ conf.batch_size = 128
6
+ conf.train_mode = TrainMode.latent_diffusion
7
+ conf.latent_gen_type = GenerativeType.ddim
8
+ conf.latent_loss_type = LossType.mse
9
+ conf.latent_model_mean_type = ModelMeanType.eps
10
+ conf.latent_model_var_type = ModelVarType.fixed_large
11
+ conf.latent_rescale_timesteps = False
12
+ conf.latent_clip_sample = False
13
+ conf.latent_T_eval = 20
14
+ conf.latent_znormalize = True
15
+ conf.total_samples = 96_000_000
16
+ conf.sample_every_samples = 400_000
17
+ conf.eval_every_samples = 20_000_000
18
+ conf.eval_ema_every_samples = 20_000_000
19
+ conf.save_every_samples = 2_000_000
20
+ return conf
21
+
22
+
23
+ def latent_diffusion128_config(conf: TrainConfig):
24
+ conf = latent_diffusion_config(conf)
25
+ conf.batch_size_eval = 32
26
+ return conf
27
+
28
+
29
+ def latent_mlp_2048_norm_10layers(conf: TrainConfig):
30
+ conf.net_latent_net_type = LatentNetType.skip
31
+ conf.net_latent_layers = 10
32
+ conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
33
+ conf.net_latent_activation = Activation.silu
34
+ conf.net_latent_num_hid_channels = 2048
35
+ conf.net_latent_use_norm = True
36
+ conf.net_latent_condition_bias = 1
37
+ return conf
38
+
39
+
40
+ def latent_mlp_2048_norm_20layers(conf: TrainConfig):
41
+ conf = latent_mlp_2048_norm_10layers(conf)
42
+ conf.net_latent_layers = 20
43
+ conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
44
+ return conf
45
+
46
+
47
+ def latent_256_batch_size(conf: TrainConfig):
48
+ conf.batch_size = 256
49
+ conf.eval_ema_every_samples = 100_000_000
50
+ conf.eval_every_samples = 100_000_000
51
+ conf.sample_every_samples = 1_000_000
52
+ conf.save_every_samples = 2_000_000
53
+ conf.total_samples = 301_000_000
54
+ return conf
55
+
56
+
57
+ def latent_512_batch_size(conf: TrainConfig):
58
+ conf.batch_size = 512
59
+ conf.eval_ema_every_samples = 100_000_000
60
+ conf.eval_every_samples = 100_000_000
61
+ conf.sample_every_samples = 1_000_000
62
+ conf.save_every_samples = 5_000_000
63
+ conf.total_samples = 501_000_000
64
+ return conf
65
+
66
+
67
+ def latent_2048_batch_size(conf: TrainConfig):
68
+ conf.batch_size = 2048
69
+ conf.eval_ema_every_samples = 200_000_000
70
+ conf.eval_every_samples = 200_000_000
71
+ conf.sample_every_samples = 4_000_000
72
+ conf.save_every_samples = 20_000_000
73
+ conf.total_samples = 1_501_000_000
74
+ return conf
75
+
76
+
77
+ def adamw_weight_decay(conf: TrainConfig):
78
+ conf.optimizer = OptimizerType.adamw
79
+ conf.weight_decay = 0.01
80
+ return conf
81
+
82
+
83
+ def ffhq128_autoenc_latent():
84
+ conf = pretrain_ffhq128_autoenc130M()
85
+ conf = latent_diffusion128_config(conf)
86
+ conf = latent_mlp_2048_norm_10layers(conf)
87
+ conf = latent_256_batch_size(conf)
88
+ conf = adamw_weight_decay(conf)
89
+ conf.total_samples = 101_000_000
90
+ conf.latent_loss_type = LossType.l1
91
+ conf.latent_beta_scheduler = 'const0.008'
92
+ conf.name = 'ffhq128_autoenc_latent'
93
+ return conf
94
+
95
+
96
+ def ffhq256_autoenc_latent():
97
+ conf = pretrain_ffhq256_autoenc()
98
+ conf = latent_diffusion128_config(conf)
99
+ conf = latent_mlp_2048_norm_10layers(conf)
100
+ conf = latent_256_batch_size(conf)
101
+ conf = adamw_weight_decay(conf)
102
+ conf.total_samples = 101_000_000
103
+ conf.latent_loss_type = LossType.l1
104
+ conf.latent_beta_scheduler = 'const0.008'
105
+ conf.eval_ema_every_samples = 200_000_000
106
+ conf.eval_every_samples = 200_000_000
107
+ conf.sample_every_samples = 4_000_000
108
+ conf.name = 'ffhq256_autoenc_latent'
109
+ return conf
110
+
111
+
112
+ def horse128_autoenc_latent():
113
+ conf = pretrain_horse128()
114
+ conf = latent_diffusion128_config(conf)
115
+ conf = latent_2048_batch_size(conf)
116
+ conf = latent_mlp_2048_norm_20layers(conf)
117
+ conf.total_samples = 2_001_000_000
118
+ conf.latent_beta_scheduler = 'const0.008'
119
+ conf.latent_loss_type = LossType.l1
120
+ conf.name = 'horse128_autoenc_latent'
121
+ return conf
122
+
123
+
124
+ def bedroom128_autoenc_latent():
125
+ conf = pretrain_bedroom128()
126
+ conf = latent_diffusion128_config(conf)
127
+ conf = latent_2048_batch_size(conf)
128
+ conf = latent_mlp_2048_norm_20layers(conf)
129
+ conf.total_samples = 2_001_000_000
130
+ conf.latent_beta_scheduler = 'const0.008'
131
+ conf.latent_loss_type = LossType.l1
132
+ conf.name = 'bedroom128_autoenc_latent'
133
+ return conf
134
+
135
+
136
+ def celeba64d2c_autoenc_latent():
137
+ conf = pretrain_celeba64d2c_72M()
138
+ conf = latent_diffusion_config(conf)
139
+ conf = latent_512_batch_size(conf)
140
+ conf = latent_mlp_2048_norm_10layers(conf)
141
+ conf = adamw_weight_decay(conf)
142
+ # just for the name
143
+ conf.continue_from = PretrainConfig('200M',
144
+ f'log-latent/{conf.name}/last.ckpt')
145
+ conf.postfix = '_300M'
146
+ conf.total_samples = 301_000_000
147
+ conf.latent_beta_scheduler = 'const0.008'
148
+ conf.latent_loss_type = LossType.l1
149
+ conf.name = 'celeba64d2c_autoenc_latent'
150
+ return conf
DiffAE_support_utils.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from statistics import median
2
+ from skimage.metrics import structural_similarity
3
+
4
+ def getSSIM(gt, out, gt_flag=None, data_range=1):
5
+ if gt_flag is None: # all of the samples have GTs
6
+ gt_flag = [True]*gt.shape[0]
7
+
8
+ vals = []
9
+ for i in range(gt.shape[0]):
10
+ if not gt_flag[i]:
11
+ continue
12
+ vals.extend(
13
+ structural_similarity(
14
+ gt[i, j, ...], out[i, j, ...], data_range=data_range
15
+ )
16
+ for j in range(gt.shape[1])
17
+ )
18
+ return median(vals)
19
+
20
+ def ema(source, target, decay):
21
+ source_dict = source.state_dict()
22
+ target_dict = target.state_dict()
23
+ for key in source_dict.keys():
24
+ target_dict[key].data.copy_(target_dict[key].data * decay +
25
+ source_dict[key].data * (1 - decay))
26
+
27
+
28
+ class WarmupLR:
29
+ def __init__(self, warmup) -> None:
30
+ self.warmup = warmup
31
+
32
+ def __call__(self, step):
33
+ return min(step, self.warmup) / self.warmup
config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ampmode": "16-mixed",
3
+ "architectures": [
4
+ "DiffAE"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "DiffAEConfig.DiffAEConfig",
8
+ "AutoModel": "DiffAE.DiffAE"
9
+ },
10
+ "batch_size": 9,
11
+ "data_name": "ukbb",
12
+ "diffusion_type": "beatgans",
13
+ "grey2RGB": -1,
14
+ "in_channels": 1,
15
+ "input_shape": [
16
+ 50,
17
+ 128,
18
+ 128
19
+ ],
20
+ "is3D": true,
21
+ "latent_dim": 128,
22
+ "lr": 0.0001,
23
+ "model_type": "DiffAE",
24
+ "net_ch": 32,
25
+ "out_channels": 1,
26
+ "sample_every_batches": 1000,
27
+ "sample_size": 4,
28
+ "seed": 1701,
29
+ "test_ema": true,
30
+ "test_emb_only": true,
31
+ "test_with_TEval": true,
32
+ "torch_dtype": "float32",
33
+ "transformers_version": "4.44.2"
34
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c810229d67650b6b73aa6b30722bb25e62c01a0e805d3964b7789051a46194ab
3
+ size 179944264