valhalla commited on
Commit
4996763
·
1 Parent(s): 8eaba8f

Update modeling_latent_diffusion.py

Browse files
Files changed (1) hide show
  1. modeling_latent_diffusion.py +26 -893
modeling_latent_diffusion.py CHANGED
@@ -7,862 +7,17 @@ import torch
7
  import torch.nn as nn
8
 
9
  from diffusers import DiffusionPipeline
10
- from diffusers.configuration_utils import ConfigMixin
11
- from diffusers.modeling_utils import ModelMixin
12
-
13
-
14
- def get_timestep_embedding(timesteps, embedding_dim):
15
- """
16
- This matches the implementation in Denoising Diffusion Probabilistic Models:
17
- From Fairseq.
18
- Build sinusoidal embeddings.
19
- This matches the implementation in tensor2tensor, but differs slightly
20
- from the description in Section 3.5 of "Attention Is All You Need".
21
- """
22
- assert len(timesteps.shape) == 1
23
-
24
- half_dim = embedding_dim // 2
25
- emb = math.log(10000) / (half_dim - 1)
26
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
27
- emb = emb.to(device=timesteps.device)
28
- emb = timesteps.float()[:, None] * emb[None, :]
29
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
30
- if embedding_dim % 2 == 1: # zero pad
31
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
32
- return emb
33
-
34
-
35
- def nonlinearity(x):
36
- # swish
37
- return x * torch.sigmoid(x)
38
-
39
-
40
- def Normalize(in_channels):
41
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
42
-
43
-
44
- class Upsample(nn.Module):
45
- def __init__(self, in_channels, with_conv):
46
- super().__init__()
47
- self.with_conv = with_conv
48
- if self.with_conv:
49
- self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
50
-
51
- def forward(self, x):
52
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
53
- if self.with_conv:
54
- x = self.conv(x)
55
- return x
56
-
57
-
58
- class Downsample(nn.Module):
59
- def __init__(self, in_channels, with_conv):
60
- super().__init__()
61
- self.with_conv = with_conv
62
- if self.with_conv:
63
- # no asymmetric padding in torch conv, must do it ourselves
64
- self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
65
-
66
- def forward(self, x):
67
- if self.with_conv:
68
- pad = (0, 1, 0, 1)
69
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
70
- x = self.conv(x)
71
- else:
72
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
73
- return x
74
-
75
-
76
- class ResnetBlock(nn.Module):
77
- def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
78
- super().__init__()
79
- self.in_channels = in_channels
80
- out_channels = in_channels if out_channels is None else out_channels
81
- self.out_channels = out_channels
82
- self.use_conv_shortcut = conv_shortcut
83
-
84
- self.norm1 = Normalize(in_channels)
85
- self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
86
- if temb_channels > 0:
87
- self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
88
- self.norm2 = Normalize(out_channels)
89
- self.dropout = torch.nn.Dropout(dropout)
90
- self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
91
- if self.in_channels != self.out_channels:
92
- if self.use_conv_shortcut:
93
- self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
94
- else:
95
- self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
96
-
97
- def forward(self, x, temb):
98
- h = x
99
- h = self.norm1(h)
100
- h = nonlinearity(h)
101
- h = self.conv1(h)
102
-
103
- if temb is not None:
104
- h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
105
-
106
- h = self.norm2(h)
107
- h = nonlinearity(h)
108
- h = self.dropout(h)
109
- h = self.conv2(h)
110
-
111
- if self.in_channels != self.out_channels:
112
- if self.use_conv_shortcut:
113
- x = self.conv_shortcut(x)
114
- else:
115
- x = self.nin_shortcut(x)
116
-
117
- return x + h
118
-
119
-
120
- class AttnBlock(nn.Module):
121
- def __init__(self, in_channels):
122
- super().__init__()
123
- self.in_channels = in_channels
124
-
125
- self.norm = Normalize(in_channels)
126
- self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
127
- self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
128
- self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
129
- self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
130
-
131
- def forward(self, x):
132
- h_ = x
133
- h_ = self.norm(h_)
134
- q = self.q(h_)
135
- k = self.k(h_)
136
- v = self.v(h_)
137
-
138
- # compute attention
139
- b, c, h, w = q.shape
140
- q = q.reshape(b, c, h * w)
141
- q = q.permute(0, 2, 1) # b,hw,c
142
- k = k.reshape(b, c, h * w) # b,c,hw
143
- w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
144
- w_ = w_ * (int(c) ** (-0.5))
145
- w_ = torch.nn.functional.softmax(w_, dim=2)
146
-
147
- # attend to values
148
- v = v.reshape(b, c, h * w)
149
- w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
150
- h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
151
- h_ = h_.reshape(b, c, h, w)
152
-
153
- h_ = self.proj_out(h_)
154
-
155
- return x + h_
156
-
157
-
158
- class Model(nn.Module):
159
- def __init__(
160
- self,
161
- *,
162
- ch,
163
- out_ch,
164
- ch_mult=(1, 2, 4, 8),
165
- num_res_blocks,
166
- attn_resolutions,
167
- dropout=0.0,
168
- resamp_with_conv=True,
169
- in_channels,
170
- resolution,
171
- use_timestep=True,
172
- ):
173
- super().__init__()
174
- self.ch = ch
175
- self.temb_ch = self.ch * 4
176
- self.num_resolutions = len(ch_mult)
177
- self.num_res_blocks = num_res_blocks
178
- self.resolution = resolution
179
- self.in_channels = in_channels
180
-
181
- self.use_timestep = use_timestep
182
- if self.use_timestep:
183
- # timestep embedding
184
- self.temb = nn.Module()
185
- self.temb.dense = nn.ModuleList(
186
- [
187
- torch.nn.Linear(self.ch, self.temb_ch),
188
- torch.nn.Linear(self.temb_ch, self.temb_ch),
189
- ]
190
- )
191
-
192
- # downsampling
193
- self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
194
-
195
- curr_res = resolution
196
- in_ch_mult = (1,) + tuple(ch_mult)
197
- self.down = nn.ModuleList()
198
- for i_level in range(self.num_resolutions):
199
- block = nn.ModuleList()
200
- attn = nn.ModuleList()
201
- block_in = ch * in_ch_mult[i_level]
202
- block_out = ch * ch_mult[i_level]
203
- for i_block in range(self.num_res_blocks):
204
- block.append(
205
- ResnetBlock(
206
- in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
207
- )
208
- )
209
- block_in = block_out
210
- if curr_res in attn_resolutions:
211
- attn.append(AttnBlock(block_in))
212
- down = nn.Module()
213
- down.block = block
214
- down.attn = attn
215
- if i_level != self.num_resolutions - 1:
216
- down.downsample = Downsample(block_in, resamp_with_conv)
217
- curr_res = curr_res // 2
218
- self.down.append(down)
219
-
220
- # middle
221
- self.mid = nn.Module()
222
- self.mid.block_1 = ResnetBlock(
223
- in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
224
- )
225
- self.mid.attn_1 = AttnBlock(block_in)
226
- self.mid.block_2 = ResnetBlock(
227
- in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
228
- )
229
-
230
- # upsampling
231
- self.up = nn.ModuleList()
232
- for i_level in reversed(range(self.num_resolutions)):
233
- block = nn.ModuleList()
234
- attn = nn.ModuleList()
235
- block_out = ch * ch_mult[i_level]
236
- skip_in = ch * ch_mult[i_level]
237
- for i_block in range(self.num_res_blocks + 1):
238
- if i_block == self.num_res_blocks:
239
- skip_in = ch * in_ch_mult[i_level]
240
- block.append(
241
- ResnetBlock(
242
- in_channels=block_in + skip_in,
243
- out_channels=block_out,
244
- temb_channels=self.temb_ch,
245
- dropout=dropout,
246
- )
247
- )
248
- block_in = block_out
249
- if curr_res in attn_resolutions:
250
- attn.append(AttnBlock(block_in))
251
- up = nn.Module()
252
- up.block = block
253
- up.attn = attn
254
- if i_level != 0:
255
- up.upsample = Upsample(block_in, resamp_with_conv)
256
- curr_res = curr_res * 2
257
- self.up.insert(0, up) # prepend to get consistent order
258
-
259
- # end
260
- self.norm_out = Normalize(block_in)
261
- self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
262
-
263
- def forward(self, x, t=None):
264
- # assert x.shape[2] == x.shape[3] == self.resolution
265
-
266
- if self.use_timestep:
267
- # timestep embedding
268
- assert t is not None
269
- temb = get_timestep_embedding(t, self.ch)
270
- temb = self.temb.dense[0](temb)
271
- temb = nonlinearity(temb)
272
- temb = self.temb.dense[1](temb)
273
- else:
274
- temb = None
275
-
276
- # downsampling
277
- hs = [self.conv_in(x)]
278
- for i_level in range(self.num_resolutions):
279
- for i_block in range(self.num_res_blocks):
280
- h = self.down[i_level].block[i_block](hs[-1], temb)
281
- if len(self.down[i_level].attn) > 0:
282
- h = self.down[i_level].attn[i_block](h)
283
- hs.append(h)
284
- if i_level != self.num_resolutions - 1:
285
- hs.append(self.down[i_level].downsample(hs[-1]))
286
-
287
- # middle
288
- h = hs[-1]
289
- h = self.mid.block_1(h, temb)
290
- h = self.mid.attn_1(h)
291
- h = self.mid.block_2(h, temb)
292
-
293
- # upsampling
294
- for i_level in reversed(range(self.num_resolutions)):
295
- for i_block in range(self.num_res_blocks + 1):
296
- h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
297
- if len(self.up[i_level].attn) > 0:
298
- h = self.up[i_level].attn[i_block](h)
299
- if i_level != 0:
300
- h = self.up[i_level].upsample(h)
301
-
302
- # end
303
- h = self.norm_out(h)
304
- h = nonlinearity(h)
305
- h = self.conv_out(h)
306
- return h
307
-
308
-
309
- class Encoder(nn.Module):
310
- def __init__(
311
- self,
312
- *,
313
- ch,
314
- out_ch,
315
- ch_mult=(1, 2, 4, 8),
316
- num_res_blocks,
317
- attn_resolutions,
318
- dropout=0.0,
319
- resamp_with_conv=True,
320
- in_channels,
321
- resolution,
322
- z_channels,
323
- double_z=True,
324
- **ignore_kwargs,
325
- ):
326
- super().__init__()
327
- self.ch = ch
328
- self.temb_ch = 0
329
- self.num_resolutions = len(ch_mult)
330
- self.num_res_blocks = num_res_blocks
331
- self.resolution = resolution
332
- self.in_channels = in_channels
333
-
334
- # downsampling
335
- self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
336
-
337
- curr_res = resolution
338
- in_ch_mult = (1,) + tuple(ch_mult)
339
- self.down = nn.ModuleList()
340
- for i_level in range(self.num_resolutions):
341
- block = nn.ModuleList()
342
- attn = nn.ModuleList()
343
- block_in = ch * in_ch_mult[i_level]
344
- block_out = ch * ch_mult[i_level]
345
- for i_block in range(self.num_res_blocks):
346
- block.append(
347
- ResnetBlock(
348
- in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
349
- )
350
- )
351
- block_in = block_out
352
- if curr_res in attn_resolutions:
353
- attn.append(AttnBlock(block_in))
354
- down = nn.Module()
355
- down.block = block
356
- down.attn = attn
357
- if i_level != self.num_resolutions - 1:
358
- down.downsample = Downsample(block_in, resamp_with_conv)
359
- curr_res = curr_res // 2
360
- self.down.append(down)
361
-
362
- # middle
363
- self.mid = nn.Module()
364
- self.mid.block_1 = ResnetBlock(
365
- in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
366
- )
367
- self.mid.attn_1 = AttnBlock(block_in)
368
- self.mid.block_2 = ResnetBlock(
369
- in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
370
- )
371
-
372
- # end
373
- self.norm_out = Normalize(block_in)
374
- self.conv_out = torch.nn.Conv2d(
375
- block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
376
- )
377
-
378
- def forward(self, x):
379
- # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
380
-
381
- # timestep embedding
382
- temb = None
383
-
384
- # downsampling
385
- hs = [self.conv_in(x)]
386
- for i_level in range(self.num_resolutions):
387
- for i_block in range(self.num_res_blocks):
388
- h = self.down[i_level].block[i_block](hs[-1], temb)
389
- if len(self.down[i_level].attn) > 0:
390
- h = self.down[i_level].attn[i_block](h)
391
- hs.append(h)
392
- if i_level != self.num_resolutions - 1:
393
- hs.append(self.down[i_level].downsample(hs[-1]))
394
-
395
- # middle
396
- h = hs[-1]
397
- h = self.mid.block_1(h, temb)
398
- h = self.mid.attn_1(h)
399
- h = self.mid.block_2(h, temb)
400
-
401
- # end
402
- h = self.norm_out(h)
403
- h = nonlinearity(h)
404
- h = self.conv_out(h)
405
- return h
406
-
407
-
408
- class Decoder(nn.Module):
409
- def __init__(
410
- self,
411
- *,
412
- ch,
413
- out_ch,
414
- ch_mult=(1, 2, 4, 8),
415
- num_res_blocks,
416
- attn_resolutions,
417
- dropout=0.0,
418
- resamp_with_conv=True,
419
- in_channels,
420
- resolution,
421
- z_channels,
422
- give_pre_end=False,
423
- **ignorekwargs,
424
- ):
425
- super().__init__()
426
- self.ch = ch
427
- self.temb_ch = 0
428
- self.num_resolutions = len(ch_mult)
429
- self.num_res_blocks = num_res_blocks
430
- self.resolution = resolution
431
- self.in_channels = in_channels
432
- self.give_pre_end = give_pre_end
433
-
434
- # compute in_ch_mult, block_in and curr_res at lowest res
435
- in_ch_mult = (1,) + tuple(ch_mult)
436
- block_in = ch * ch_mult[self.num_resolutions - 1]
437
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
438
- self.z_shape = (1, z_channels, curr_res, curr_res)
439
- print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
440
-
441
- # z to block_in
442
- self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
443
-
444
- # middle
445
- self.mid = nn.Module()
446
- self.mid.block_1 = ResnetBlock(
447
- in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
448
- )
449
- self.mid.attn_1 = AttnBlock(block_in)
450
- self.mid.block_2 = ResnetBlock(
451
- in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
452
- )
453
-
454
- # upsampling
455
- self.up = nn.ModuleList()
456
- for i_level in reversed(range(self.num_resolutions)):
457
- block = nn.ModuleList()
458
- attn = nn.ModuleList()
459
- block_out = ch * ch_mult[i_level]
460
- for i_block in range(self.num_res_blocks + 1):
461
- block.append(
462
- ResnetBlock(
463
- in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
464
- )
465
- )
466
- block_in = block_out
467
- if curr_res in attn_resolutions:
468
- attn.append(AttnBlock(block_in))
469
- up = nn.Module()
470
- up.block = block
471
- up.attn = attn
472
- if i_level != 0:
473
- up.upsample = Upsample(block_in, resamp_with_conv)
474
- curr_res = curr_res * 2
475
- self.up.insert(0, up) # prepend to get consistent order
476
-
477
- # end
478
- self.norm_out = Normalize(block_in)
479
- self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
480
-
481
- def forward(self, z):
482
- # assert z.shape[1:] == self.z_shape[1:]
483
- self.last_z_shape = z.shape
484
-
485
- # timestep embedding
486
- temb = None
487
-
488
- # z to block_in
489
- h = self.conv_in(z)
490
-
491
- # middle
492
- h = self.mid.block_1(h, temb)
493
- h = self.mid.attn_1(h)
494
- h = self.mid.block_2(h, temb)
495
-
496
- # upsampling
497
- for i_level in reversed(range(self.num_resolutions)):
498
- for i_block in range(self.num_res_blocks + 1):
499
- h = self.up[i_level].block[i_block](h, temb)
500
- if len(self.up[i_level].attn) > 0:
501
- h = self.up[i_level].attn[i_block](h)
502
- if i_level != 0:
503
- h = self.up[i_level].upsample(h)
504
-
505
- # end
506
- if self.give_pre_end:
507
- return h
508
-
509
- h = self.norm_out(h)
510
- h = nonlinearity(h)
511
- h = self.conv_out(h)
512
- return h
513
-
514
-
515
- class VectorQuantizer(nn.Module):
516
- """
517
- Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
518
- avoids costly matrix multiplications and allows for post-hoc remapping of indices.
519
- """
520
-
521
- # NOTE: due to a bug the beta term was applied to the wrong term. for
522
- # backwards compatibility we use the buggy version by default, but you can
523
- # specify legacy=False to fix it.
524
- def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
525
- super().__init__()
526
- self.n_e = n_e
527
- self.e_dim = e_dim
528
- self.beta = beta
529
- self.legacy = legacy
530
-
531
- self.embedding = nn.Embedding(self.n_e, self.e_dim)
532
- self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
533
-
534
- self.remap = remap
535
- if self.remap is not None:
536
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
537
- self.re_embed = self.used.shape[0]
538
- self.unknown_index = unknown_index # "random" or "extra" or integer
539
- if self.unknown_index == "extra":
540
- self.unknown_index = self.re_embed
541
- self.re_embed = self.re_embed + 1
542
- print(
543
- f"Remapping {self.n_e} indices to {self.re_embed} indices. "
544
- f"Using {self.unknown_index} for unknown indices."
545
- )
546
- else:
547
- self.re_embed = n_e
548
-
549
- self.sane_index_shape = sane_index_shape
550
-
551
- def remap_to_used(self, inds):
552
- ishape = inds.shape
553
- assert len(ishape) > 1
554
- inds = inds.reshape(ishape[0], -1)
555
- used = self.used.to(inds)
556
- match = (inds[:, :, None] == used[None, None, ...]).long()
557
- new = match.argmax(-1)
558
- unknown = match.sum(2) < 1
559
- if self.unknown_index == "random":
560
- new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
561
- else:
562
- new[unknown] = self.unknown_index
563
- return new.reshape(ishape)
564
-
565
- def unmap_to_all(self, inds):
566
- ishape = inds.shape
567
- assert len(ishape) > 1
568
- inds = inds.reshape(ishape[0], -1)
569
- used = self.used.to(inds)
570
- if self.re_embed > self.used.shape[0]: # extra token
571
- inds[inds >= self.used.shape[0]] = 0 # simply set to zero
572
- back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
573
- return back.reshape(ishape)
574
-
575
- def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
576
- assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
577
- assert rescale_logits == False, "Only for interface compatible with Gumbel"
578
- assert return_logits == False, "Only for interface compatible with Gumbel"
579
- # reshape z -> (batch, height, width, channel) and flatten
580
- z = rearrange(z, "b c h w -> b h w c").contiguous()
581
- z_flattened = z.view(-1, self.e_dim)
582
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
583
-
584
- d = (
585
- torch.sum(z_flattened**2, dim=1, keepdim=True)
586
- + torch.sum(self.embedding.weight**2, dim=1)
587
- - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
588
- )
589
-
590
- min_encoding_indices = torch.argmin(d, dim=1)
591
- z_q = self.embedding(min_encoding_indices).view(z.shape)
592
- perplexity = None
593
- min_encodings = None
594
-
595
- # compute loss for embedding
596
- if not self.legacy:
597
- loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
598
- else:
599
- loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
600
-
601
- # preserve gradients
602
- z_q = z + (z_q - z).detach()
603
-
604
- # reshape back to match original input shape
605
- z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
606
-
607
- if self.remap is not None:
608
- min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
609
- min_encoding_indices = self.remap_to_used(min_encoding_indices)
610
- min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
611
-
612
- if self.sane_index_shape:
613
- min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
614
-
615
- return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
616
-
617
- def get_codebook_entry(self, indices, shape):
618
- # shape specifying (batch, height, width, channel)
619
- if self.remap is not None:
620
- indices = indices.reshape(shape[0], -1) # add batch axis
621
- indices = self.unmap_to_all(indices)
622
- indices = indices.reshape(-1) # flatten again
623
-
624
- # get quantized latent vectors
625
- z_q = self.embedding(indices)
626
-
627
- if shape is not None:
628
- z_q = z_q.view(shape)
629
- # reshape back to match original input shape
630
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
631
-
632
- return z_q
633
-
634
-
635
- class VQModel(ModelMixin, ConfigMixin):
636
- def __init__(
637
- self,
638
- ch,
639
- out_ch,
640
- num_res_blocks,
641
- attn_resolutions,
642
- in_channels,
643
- resolution,
644
- z_channels,
645
- n_embed,
646
- embed_dim,
647
- remap=None,
648
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
649
- ch_mult=(1, 2, 4, 8),
650
- dropout=0.0,
651
- double_z=True,
652
- resamp_with_conv=True,
653
- give_pre_end=False,
654
- ):
655
- super().__init__()
656
-
657
- # register all __init__ params with self.register
658
- self.register(
659
- ch=ch,
660
- out_ch=out_ch,
661
- num_res_blocks=num_res_blocks,
662
- attn_resolutions=attn_resolutions,
663
- in_channels=in_channels,
664
- resolution=resolution,
665
- z_channels=z_channels,
666
- n_embed=n_embed,
667
- embed_dim=embed_dim,
668
- remap=remap,
669
- sane_index_shape=sane_index_shape,
670
- ch_mult=ch_mult,
671
- dropout=dropout,
672
- double_z=double_z,
673
- resamp_with_conv=resamp_with_conv,
674
- give_pre_end=give_pre_end,
675
- )
676
-
677
- # pass init params to Encoder
678
- self.encoder = Encoder(
679
- ch=ch,
680
- out_ch=out_ch,
681
- num_res_blocks=num_res_blocks,
682
- attn_resolutions=attn_resolutions,
683
- in_channels=in_channels,
684
- resolution=resolution,
685
- z_channels=z_channels,
686
- ch_mult=ch_mult,
687
- dropout=dropout,
688
- resamp_with_conv=resamp_with_conv,
689
- double_z=double_z,
690
- give_pre_end=give_pre_end,
691
- )
692
-
693
- self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)
694
-
695
- # pass init params to Decoder
696
- self.decoder = Decoder(
697
- ch=ch,
698
- out_ch=out_ch,
699
- num_res_blocks=num_res_blocks,
700
- attn_resolutions=attn_resolutions,
701
- in_channels=in_channels,
702
- resolution=resolution,
703
- z_channels=z_channels,
704
- ch_mult=ch_mult,
705
- dropout=dropout,
706
- resamp_with_conv=resamp_with_conv,
707
- give_pre_end=give_pre_end,
708
- )
709
-
710
- def encode(self, x):
711
- h = self.encoder(x)
712
- h = self.quant_conv(h)
713
- return h
714
-
715
- def decode(self, h, force_not_quantize=False):
716
- # also go through quantization layer
717
- if not force_not_quantize:
718
- quant, emb_loss, info = self.quantize(h)
719
- else:
720
- quant = h
721
- quant = self.post_quant_conv(quant)
722
- dec = self.decoder(quant)
723
- return dec
724
-
725
-
726
- class DiagonalGaussianDistribution(object):
727
- def __init__(self, parameters, deterministic=False):
728
- self.parameters = parameters
729
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
730
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
731
- self.deterministic = deterministic
732
- self.std = torch.exp(0.5 * self.logvar)
733
- self.var = torch.exp(self.logvar)
734
- if self.deterministic:
735
- self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
736
-
737
- def sample(self):
738
- x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
739
- return x
740
-
741
- def kl(self, other=None):
742
- if self.deterministic:
743
- return torch.Tensor([0.])
744
- else:
745
- if other is None:
746
- return 0.5 * torch.sum(torch.pow(self.mean, 2)
747
- + self.var - 1.0 - self.logvar,
748
- dim=[1, 2, 3])
749
- else:
750
- return 0.5 * torch.sum(
751
- torch.pow(self.mean - other.mean, 2) / other.var
752
- + self.var / other.var - 1.0 - self.logvar + other.logvar,
753
- dim=[1, 2, 3])
754
-
755
- def nll(self, sample, dims=[1,2,3]):
756
- if self.deterministic:
757
- return torch.Tensor([0.])
758
- logtwopi = np.log(2.0 * np.pi)
759
- return 0.5 * torch.sum(
760
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
761
- dim=dims)
762
-
763
- def mode(self):
764
- return self.mean
765
-
766
- class AutoencoderKL(ModelMixin, ConfigMixin):
767
- def __init__(
768
- self,
769
- ch,
770
- out_ch,
771
- num_res_blocks,
772
- attn_resolutions,
773
- in_channels,
774
- resolution,
775
- z_channels,
776
- embed_dim,
777
- remap=None,
778
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
779
- ch_mult=(1, 2, 4, 8),
780
- dropout=0.0,
781
- double_z=True,
782
- resamp_with_conv=True,
783
- give_pre_end=False,
784
- ):
785
- super().__init__()
786
-
787
- # register all __init__ params with self.register
788
- self.register(
789
- ch=ch,
790
- out_ch=out_ch,
791
- num_res_blocks=num_res_blocks,
792
- attn_resolutions=attn_resolutions,
793
- in_channels=in_channels,
794
- resolution=resolution,
795
- z_channels=z_channels,
796
- embed_dim=embed_dim,
797
- remap=remap,
798
- sane_index_shape=sane_index_shape,
799
- ch_mult=ch_mult,
800
- dropout=dropout,
801
- double_z=double_z,
802
- resamp_with_conv=resamp_with_conv,
803
- give_pre_end=give_pre_end,
804
- )
805
-
806
- # pass init params to Encoder
807
- self.encoder = Encoder(
808
- ch=ch,
809
- out_ch=out_ch,
810
- num_res_blocks=num_res_blocks,
811
- attn_resolutions=attn_resolutions,
812
- in_channels=in_channels,
813
- resolution=resolution,
814
- z_channels=z_channels,
815
- ch_mult=ch_mult,
816
- dropout=dropout,
817
- resamp_with_conv=resamp_with_conv,
818
- double_z=double_z,
819
- give_pre_end=give_pre_end,
820
- )
821
-
822
- # pass init params to Decoder
823
- self.decoder = Decoder(
824
- ch=ch,
825
- out_ch=out_ch,
826
- num_res_blocks=num_res_blocks,
827
- attn_resolutions=attn_resolutions,
828
- in_channels=in_channels,
829
- resolution=resolution,
830
- z_channels=z_channels,
831
- ch_mult=ch_mult,
832
- dropout=dropout,
833
- resamp_with_conv=resamp_with_conv,
834
- give_pre_end=give_pre_end,
835
- )
836
-
837
- self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
838
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
839
-
840
- def encode(self, x):
841
- h = self.encoder(x)
842
- moments = self.quant_conv(h)
843
- posterior = DiagonalGaussianDistribution(moments)
844
- return posterior
845
-
846
- def decode(self, z):
847
- z = self.post_quant_conv(z)
848
- dec = self.decoder(z)
849
- return dec
850
-
851
- def forward(self, input, sample_posterior=True):
852
- posterior = self.encode(input)
853
- if sample_posterior:
854
- z = posterior.sample()
855
- else:
856
- z = posterior.mode()
857
- dec = self.decode(z)
858
- return dec, posterior
859
 
 
 
 
860
 
861
  class LatentDiffusion(DiffusionPipeline):
862
  def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
863
  super().__init__()
864
  self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
865
 
 
866
  def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
867
  # eta corresponds to η in paper and should be between [0, 1]
868
 
@@ -873,6 +28,7 @@ class LatentDiffusion(DiffusionPipeline):
873
  self.vqvae.to(torch_device)
874
  self.bert.to(torch_device)
875
 
 
876
  if guidance_scale != 1.0:
877
  uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
878
  uncond_embeddings = self.bert(uncond_input.input_ids)[0]
@@ -901,65 +57,42 @@ class LatentDiffusion(DiffusionPipeline):
901
  # - pred_image_direction -> "direction pointingc to x_t"
902
  # - pred_prev_image -> "x_t-1"
903
  for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
904
- # 1. predict noise residual
905
  if guidance_scale == 1.0:
906
- timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
907
- context = text_embedding
908
  image_in = image
 
 
909
  else:
 
 
 
910
  image_in = torch.cat([image] * 2)
911
- timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
912
  context = torch.cat([uncond_embeddings, text_embedding])
 
 
 
 
913
 
914
- with torch.no_grad():
915
- pred_noise_t = self.unet(image_in, timesteps, context=context)
916
-
917
  if guidance_scale != 1.0:
918
  pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
919
  pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
920
 
921
- # 2. get actual t and t-1
922
- train_step = inference_step_times[t]
923
- prev_train_step = inference_step_times[t - 1] if t > 0 else -1
924
-
925
- # 3. compute alphas, betas
926
- alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
927
- alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
928
- beta_prod_t = 1 - alpha_prod_t
929
- beta_prod_t_prev = 1 - alpha_prod_t_prev
930
-
931
- # 4. Compute predicted previous image from predicted noise
932
- # First: compute predicted original image from predicted noise also called
933
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
934
- pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
935
-
936
- # Second: Clip "predicted x_0"
937
- # pred_original_image = torch.clamp(pred_original_image, -1, 1)
938
 
939
- # Third: Compute variance: "sigma_t(η)" -> see formula (16)
940
- # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
941
- std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
942
- std_dev_t = eta * std_dev_t
943
-
944
- # Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
945
- pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
946
-
947
- # Fifth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
948
- pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
949
-
950
- # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
951
- # Note: eta = 1.0 essentially corresponds to DDPM
952
- if eta > 0.0:
953
  noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
954
- prev_image = pred_prev_image + std_dev_t * noise
955
- else:
956
- prev_image = pred_prev_image
957
 
958
- # 6. Set current image to prev_image: x_t -> x_t-1
959
- image = prev_image
960
 
 
961
  image = 1 / 0.18215 * image
962
  image = self.vqvae.decode(image)
963
  image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
964
 
965
- return image
 
7
  import torch.nn as nn
8
 
9
  from diffusers import DiffusionPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ from .modeling_vae import AutoencoderKL
12
+ from .configuration_ldmbert import LDMBertConfig
13
+ from .modeling_ldmbert import LDMBertModel
14
 
15
  class LatentDiffusion(DiffusionPipeline):
16
  def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
17
  super().__init__()
18
  self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
19
 
20
+ @torch.no_grad()
21
  def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
22
  # eta corresponds to η in paper and should be between [0, 1]
23
 
 
28
  self.vqvae.to(torch_device)
29
  self.bert.to(torch_device)
30
 
31
+ # get unconditional embeddings for classifier free guidence
32
  if guidance_scale != 1.0:
33
  uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
34
  uncond_embeddings = self.bert(uncond_input.input_ids)[0]
 
57
  # - pred_image_direction -> "direction pointingc to x_t"
58
  # - pred_prev_image -> "x_t-1"
59
  for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
60
+ # guidance_scale of 1 means no guidance
61
  if guidance_scale == 1.0:
 
 
62
  image_in = image
63
+ context = text_embedding
64
+ timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
65
  else:
66
+ # for classifier free guidance, we need to do two forward passes
67
+ # here we concanate embedding and unconditioned embedding in a single batch
68
+ # to avoid doing two forward passes
69
  image_in = torch.cat([image] * 2)
 
70
  context = torch.cat([uncond_embeddings, text_embedding])
71
+ timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
72
+
73
+ # 1. predict noise residual
74
+ pred_noise_t = self.unet(image_in, timesteps, context=context)
75
 
76
+ # perform guidance
 
 
77
  if guidance_scale != 1.0:
78
  pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
79
  pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
80
 
81
+ # 2. predict previous mean of image x_t-1
82
+ pred_prev_image = self.noise_scheduler.compute_prev_image_step(pred_noise_t, image, t, num_inference_steps, eta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ # 3. optionally sample variance
85
+ variance = 0
86
+ if eta > 0:
 
 
 
 
 
 
 
 
 
 
 
87
  noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
88
+ variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
 
 
89
 
90
+ # 4. set current image to prev_image: x_t -> x_t-1
91
+ image = pred_prev_image + variance
92
 
93
+ # scale and decode image with vae
94
  image = 1 / 0.18215 * image
95
  image = self.vqvae.decode(image)
96
  image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
97
 
98
+ return image