Dionyssos committed on
Commit 21b15e5 · 1 Parent(s): a212d92

DEL DIFFUSION

Files changed (3)
  1. Modules/diffusion/sampler.py +0 -181
  2. models.py +5 -23
  3. msinference.py +3 -12
Modules/diffusion/sampler.py DELETED
@@ -1,181 +0,0 @@
- from math import sqrt
- import torch.nn as nn
- from einops import rearrange
- from torch import Tensor
- from functools import reduce
- # from inspect import isfunction
- import torch
- import torch.nn.functional as F
-
- def default(val, d):
-     if val is not None: #exists(val):
-         return val
-     return d #d() if isfunction(d) else d
-
- class LogNormalDistribution():
-     def __init__(self, mean: float, std: float):
-         self.mean = mean
-         self.std = std
-
-     def __call__(
-         self, num_samples: int, device: torch.device = torch.device("cpu")
-     ) -> Tensor:
-         normal = self.mean + self.std * torch.randn((num_samples,), device=device)
-         return normal.exp()
-
-
- class UniformDistribution():
-     def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
-         return torch.rand(num_samples, device=device)
-
- def to_batch(
-         batch_size: int,
-         device: torch.device,
-         x = None,
-         xs = None):
-     # assert exists(x) ^ exists(xs), "Either x or xs must be provided"
-     # If x provided use the same for all batch items
-     if x is not None: #exists(x):
-         xs = torch.full(size=(batch_size,), fill_value=x).to(device)
-     # assert exists(xs)
-     return xs
-
- class KDiffusion(nn.Module):
-     """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
-
-     alias = "k"
-
-     def __init__(
-         self,
-         net: nn.Module,
-         *,
-         sigma_distribution,
-         sigma_data: float,  # data distribution standard deviation
-         dynamic_threshold: float = 0.0,
-     ):
-         super().__init__()
-         self.net = net
-         self.sigma_data = sigma_data
-
-     def get_scale_weights(self, sigmas):
-         sigma_data = self.sigma_data
-         c_noise = torch.log(sigmas) * 0.25
-         sigmas = rearrange(sigmas, "b -> b 1 1")
-         c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
-         c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
-         c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
-         return c_skip, c_out, c_in, c_noise
-
-     def denoise_fn(
-         self,
-         x_noisy,
-         sigmas = None,
-         sigma = None,
-         **kwargs,
-     ):
-         # raise ValueError
-         batch_size, device = x_noisy.shape[0], x_noisy.device
-         sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
-
-         # Predict network output and add skip connection
-         # print('\n\n\n\n', kwargs, '\nKWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWAr\n\n\n\n') 'embedding tensor'
-         c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
-         x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
-         x_denoised = c_skip * x_noisy + c_out * x_pred
-
-         return x_denoised
-
-
- class KarrasSchedule(nn.Module):
-     """https://arxiv.org/abs/2206.00364 equation 5"""
-
-     def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
-         super().__init__()
-         self.sigma_min = sigma_min
-         self.sigma_max = sigma_max
-         self.rho = rho
-
-     def forward(self, num_steps: int, device):
-         rho_inv = 1.0 / self.rho
-         steps = torch.arange(num_steps, device=device, dtype=torch.float32)
-         sigmas = (
-             self.sigma_max ** rho_inv
-             + (steps / (num_steps - 1))
-             * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
-         ) ** self.rho
-         sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
-         return sigmas
-
- class ADPM2Sampler(nn.Module):
-     """https://www.desmos.com/calculator/jbxjlqd9mb"""
-
-     diffusion_types = [KDiffusion,]  # VKDiffusion]
-
-     def __init__(self, rho: float = 1.0):
-         super().__init__()
-         self.rho = rho
-
-     def get_sigmas(self,
-                    sigma,
-                    sigma_next):
-         r = self.rho
-         sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
-         sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
-         sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
-         return sigma_up, sigma_down, sigma_mid
-
-     def step(self, x, fn, sigma, sigma_next):
-
-         sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
-         # Derivative at sigma (∂x/∂sigma)
-         d = (x - fn(x, sigma=sigma)) / sigma
-         # Denoise to midpoint
-         x_mid = x + d * (sigma_mid - sigma)
-         # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
-         d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
-         # Denoise to next
-         x = x + d_mid * (sigma_down - sigma)
-         # Add randomness
-         x_next = x + torch.randn_like(x) * sigma_up
-         return x_next
-
-     def forward(
-             self, noise, fn, sigmas, num_steps):
-         # raise ValueError
-         x = sigmas[0] * noise
-         # Denoise to sample
-         for i in range(num_steps - 1):
-             x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])  # type: ignore # noqa
-         return x
-
- class DiffusionSampler(nn.Module):
-
-     def __init__(
-         self,
-         diffusion=None,
-         num_steps=None,
-         clamp=True,  # default=False
-     ):
-         super().__init__()
-         self.denoise_fn = diffusion.denoise_fn
-         self.sampler = ADPM2Sampler()
-         self.sigma_schedule = KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)
-         self.num_steps = num_steps
-         self.clamp = clamp
-
-
-     def forward(
-             self, noise, num_steps=None, **kwargs):
-         # raise ValueError
-         device = noise.device
-         num_steps = default(num_steps, self.num_steps)  # type: ignore
-
-         # Compute sigmas using schedule
-         sigmas = self.sigma_schedule(num_steps, device)
-
-         # L242 KWARGS dict_keys(['embedding', 'features'])
-         fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs})  # noqa
-         # Sample using sampler
-         x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
-         x = x.clamp(-1.0, 1.0) if self.clamp else x
-         return x
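
Editor's note (not part of the commit): the deleted file combined three pieces — the EDM preconditioning of KDiffusion (the c_skip/c_out/c_in/c_noise weights wrapped around the network), the KarrasSchedule noise levels, and the stochastic second-order ADPM2Sampler that DiffusionSampler drove. The sketch below reproduces the schedule and one sampling loop; `toy_denoiser`, `num_steps`, and the tensor shape are illustrative assumptions standing in for the real StyleTransformer1d net, not repo code.

import torch
import torch.nn.functional as F
from math import sqrt

def karras_sigmas(num_steps, sigma_min=0.0001, sigma_max=3.0, rho=9.0):
    # KarrasSchedule: interpolate in sigma**(1/rho) space (Karras et al. 2022, eq. 5),
    # then append a final sigma = 0.0, exactly as the deleted forward() did.
    rho_inv = 1.0 / rho
    steps = torch.arange(num_steps, dtype=torch.float32)
    sigmas = (sigma_max ** rho_inv
              + (steps / (num_steps - 1)) * (sigma_min ** rho_inv - sigma_max ** rho_inv)) ** rho
    return F.pad(sigmas, pad=(0, 1), value=0.0)

def adpm2_step(x, fn, sigma, sigma_next, rho=1.0):
    # ADPM2Sampler.step: split the move to sigma_next into a deterministic part
    # (sigma_down) plus fresh noise (sigma_up), and use a midpoint (2nd-order) slope.
    sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
    sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
    sigma_mid = ((sigma ** (1 / rho) + sigma_down ** (1 / rho)) / 2) ** rho
    d = (x - fn(x, sigma)) / sigma                      # slope at sigma
    x_mid = x + d * (sigma_mid - sigma)                 # probe the midpoint
    d_mid = (x_mid - fn(x_mid, sigma_mid)) / sigma_mid  # slope at the midpoint
    return x + d_mid * (sigma_down - sigma) + torch.randn_like(x) * sigma_up

toy_denoiser = lambda x, sigma: torch.zeros_like(x)  # stand-in: pretends the clean signal is 0
num_steps = 5
sigmas = karras_sigmas(num_steps)                    # num_steps + 1 values, ending in 0.0
x = sigmas[0] * torch.randn(1, 1, 256)               # start from pure noise, as DiffusionSampler did
for i in range(num_steps - 1):                       # mirrors ADPM2Sampler.forward
    x = adpm2_step(x, toy_denoiser, float(sigmas[i]), float(sigmas[i + 1]))
print(float(x.norm()))                               # contracts toward the denoiser's answer

Note the ancestral structure: because sigma_up worth of noise is re-injected at every step, the deleted sampler was stochastic — which is why msinference.py drew fresh torch.randn noise per call.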
models.py CHANGED
@@ -11,20 +11,19 @@ import torch.nn.functional as F
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
  from Utils.ASR.models import ASRCNN
  from Utils.JDC.model import JDCNet
- from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
+
  from Modules.diffusion.modules import StyleTransformer1d
- # from Modules.diffusion.diffusion import AudioDiffusionConditional
+
  from munch import Munch
  import yaml
  from math import pi
  from random import randint
- # from typing import Any, Optional, Sequence, Tuple, Union
+
  import torch
  from einops import rearrange
  from torch import Tensor, nn
  from tqdm import tqdm
- # from Modules.diffusion.utils import *
- # from Modules.diffusion.sampler import *
+

@@ -623,23 +622,7 @@ def build_model(args, text_aligner, pitch_extractor, bert):
      else:
          raise NotImplementedError

-     diffusion = AudioDiffusionConditional(
-         in_channels=1,
-         embedding_max_length=bert.config.max_position_embeddings,
-         embedding_features=bert.config.hidden_size,
-         embedding_mask_proba=args.diffusion.embedding_mask_proba,  # Conditional dropout of batch elements
-         channels=args.style_dim*2,
-         context_features=args.style_dim*2,
-     )
-     # this initialises self.diffusion for AudioDiffusionConditional
-     diffusion.diffusion = KDiffusion(
-         net=diffusion.unet,
-         sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
-         sigma_data=args.diffusion.dist.sigma_data,  # a placeholder, will be changed dynamically when start training diffusion model
-         dynamic_threshold=0.0
-     )
-     diffusion.diffusion.net = transformer
-     diffusion.unet = transformer
+

      nets = Munch(
@@ -652,7 +635,6 @@ def build_model(args, text_aligner, pitch_extractor, bert):

          predictor_encoder=predictor_encoder,
          style_encoder=style_encoder,
-         diffusion=diffusion,

          text_aligner = text_aligner,
          pitch_extractor=pitch_extractor
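
After this hunk, the Munch returned by build_model simply lacks a `diffusion` entry. A toy illustration (placeholder strings stand in for the real nn.Modules; not repo code) of what downstream `for key in model` loops now observe:

from munch import Munch

nets = Munch(predictor_encoder="predictor_encoder_module",
             style_encoder="style_encoder_module",
             text_aligner="text_aligner_module",
             pitch_extractor="pitch_extractor_module")

print("diffusion" in nets)  # False: loaders iterating the Munch see one key fewer
print(nets.style_encoder)   # Munch allows attribute access alongside nets["style_encoder"]

A checkpoint that still stores diffusion weights would presumably need those keys skipped at load time, since the model container no longer has a matching entry.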
msinference.py CHANGED
@@ -160,9 +160,7 @@ for key in model:
  # _load(params[key], model[key])
  _ = [model[key].eval() for key in model]

- from Modules.diffusion.sampler import DiffusionSampler

- sampler = DiffusionSampler(diffusion=model.diffusion.diffusion)

  def inference(text,
                ref_s,
@@ -205,17 +203,10 @@ def inference(text,
      # print('BERTdu', bert_dur.shape, tokens.shape, '\n')  # bert what is the 768 per token -> IS USED in sampler
      # BERTdu torch.Size([1, 11, 768]) torch.Size([1, 11])

-     s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
-                      embedding=bert_dur,
-                      features=ref_s,  # reference from the same speaker as the embedding
-                      num_steps=diffusion_steps).squeeze(1)
-
-
-     s = s_pred[:, 128:]
-     ref = s_pred[:, :128]
+

-     ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
-     s = beta * s + (1 - beta) * ref_s[:, 128:]
+     ref = ref_s[:, :128]
+     s = ref_s[:, 128:]

      d = model.predictor.text_encoder(d_en,
                                       s, input_lengths, text_mask)
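
The net effect on inference: the style vector is no longer sampled. Previously, `sampler` drew `torch.randn` noise, conditioned on `bert_dur` and `ref_s`, and blended the result with `ref_s` via `alpha`/`beta`; now both halves are deterministic slices of `ref_s`, so `diffusion_steps`, `alpha`, `beta`, and the BERT embedding presumably no longer influence the style. A minimal sketch (toy tensor; the 256/128 split is taken from the deleted call):

import torch

ref_s = torch.randn(1, 256)  # stand-in for the reference style computed earlier in msinference.py

ref = ref_s[:, :128]         # first half, consumed later in inference() (consumer not shown in this hunk)
s = ref_s[:, 128:]           # second half, fed to model.predictor.text_encoder just below
assert ref.shape == (1, 128) and s.shape == (1, 128)

This makes each call deterministic given the same ref_s, at the cost of the sample-to-sample variation the diffusion prior provided.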