DEL DIFFUSION
Browse files- Modules/diffusion/sampler.py +0 -181
- models.py +5 -23
- msinference.py +3 -12
Modules/diffusion/sampler.py
DELETED
@@ -1,181 +0,0 @@
|
|
1 |
-
from math import sqrt
|
2 |
-
import torch.nn as nn
|
3 |
-
from einops import rearrange
|
4 |
-
from torch import Tensor
|
5 |
-
from functools import reduce
|
6 |
-
# from inspect import isfunction
|
7 |
-
import torch
|
8 |
-
import torch.nn.functional as F
|
9 |
-
|
10 |
-
def default(val, d):
|
11 |
-
if val is not None: #exists(val):
|
12 |
-
return val
|
13 |
-
return d #d() if isfunction(d) else d
|
14 |
-
|
15 |
-
class LogNormalDistribution():
|
16 |
-
def __init__(self, mean: float, std: float):
|
17 |
-
self.mean = mean
|
18 |
-
self.std = std
|
19 |
-
|
20 |
-
def __call__(
|
21 |
-
self, num_samples: int, device: torch.device = torch.device("cpu")
|
22 |
-
) -> Tensor:
|
23 |
-
normal = self.mean + self.std * torch.randn((num_samples,), device=device)
|
24 |
-
return normal.exp()
|
25 |
-
|
26 |
-
|
27 |
-
class UniformDistribution():
|
28 |
-
def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
|
29 |
-
return torch.rand(num_samples, device=device)
|
30 |
-
|
31 |
-
def to_batch(
|
32 |
-
batch_size: int,
|
33 |
-
device: torch.device,
|
34 |
-
x = None,
|
35 |
-
xs = None):
|
36 |
-
# assert exists(x) ^ exists(xs), "Either x or xs must be provided"
|
37 |
-
# If x provided use the same for all batch items
|
38 |
-
if x is not None: #exists(x):
|
39 |
-
xs = torch.full(size=(batch_size,), fill_value=x).to(device)
|
40 |
-
# assert exists(xs)
|
41 |
-
return xs
|
42 |
-
|
43 |
-
class KDiffusion(nn.Module):
|
44 |
-
"""Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
|
45 |
-
|
46 |
-
alias = "k"
|
47 |
-
|
48 |
-
def __init__(
|
49 |
-
self,
|
50 |
-
net: nn.Module,
|
51 |
-
*,
|
52 |
-
sigma_distribution,
|
53 |
-
sigma_data: float, # data distribution standard deviation
|
54 |
-
dynamic_threshold: float = 0.0,
|
55 |
-
):
|
56 |
-
super().__init__()
|
57 |
-
self.net = net
|
58 |
-
self.sigma_data = sigma_data
|
59 |
-
|
60 |
-
def get_scale_weights(self, sigmas):
|
61 |
-
sigma_data = self.sigma_data
|
62 |
-
c_noise = torch.log(sigmas) * 0.25
|
63 |
-
sigmas = rearrange(sigmas, "b -> b 1 1")
|
64 |
-
c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
|
65 |
-
c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
|
66 |
-
c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
|
67 |
-
return c_skip, c_out, c_in, c_noise
|
68 |
-
|
69 |
-
def denoise_fn(
|
70 |
-
self,
|
71 |
-
x_noisy,
|
72 |
-
sigmas = None,
|
73 |
-
sigma = None,
|
74 |
-
**kwargs,
|
75 |
-
):
|
76 |
-
# raise ValueError
|
77 |
-
batch_size, device = x_noisy.shape[0], x_noisy.device
|
78 |
-
sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
|
79 |
-
|
80 |
-
# Predict network output and add skip connection
|
81 |
-
# print('\n\n\n\n', kwargs, '\nKWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWWAr\n\n\n\n') 'embedding tensor'
|
82 |
-
c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
|
83 |
-
x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
|
84 |
-
x_denoised = c_skip * x_noisy + c_out * x_pred
|
85 |
-
|
86 |
-
return x_denoised
|
87 |
-
|
88 |
-
|
89 |
-
class KarrasSchedule(nn.Module):
|
90 |
-
"""https://arxiv.org/abs/2206.00364 equation 5"""
|
91 |
-
|
92 |
-
def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
|
93 |
-
super().__init__()
|
94 |
-
self.sigma_min = sigma_min
|
95 |
-
self.sigma_max = sigma_max
|
96 |
-
self.rho = rho
|
97 |
-
|
98 |
-
def forward(self, num_steps: int, device):
|
99 |
-
rho_inv = 1.0 / self.rho
|
100 |
-
steps = torch.arange(num_steps, device=device, dtype=torch.float32)
|
101 |
-
sigmas = (
|
102 |
-
self.sigma_max ** rho_inv
|
103 |
-
+ (steps / (num_steps - 1))
|
104 |
-
* (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
|
105 |
-
) ** self.rho
|
106 |
-
sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
|
107 |
-
return sigmas
|
108 |
-
|
109 |
-
class ADPM2Sampler(nn.Module):
|
110 |
-
"""https://www.desmos.com/calculator/jbxjlqd9mb"""
|
111 |
-
|
112 |
-
diffusion_types = [KDiffusion,] # VKDiffusion]
|
113 |
-
|
114 |
-
def __init__(self, rho: float = 1.0):
|
115 |
-
super().__init__()
|
116 |
-
self.rho = rho
|
117 |
-
|
118 |
-
def get_sigmas(self,
|
119 |
-
sigma,
|
120 |
-
sigma_next):
|
121 |
-
r = self.rho
|
122 |
-
sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
|
123 |
-
sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
|
124 |
-
sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
|
125 |
-
return sigma_up, sigma_down, sigma_mid
|
126 |
-
|
127 |
-
def step(self, x, fn, sigma, sigma_next):
|
128 |
-
|
129 |
-
sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
|
130 |
-
# Derivative at sigma (∂x/∂sigma)
|
131 |
-
d = (x - fn(x, sigma=sigma)) / sigma
|
132 |
-
# Denoise to midpoint
|
133 |
-
x_mid = x + d * (sigma_mid - sigma)
|
134 |
-
# Derivative at sigma_mid (∂x_mid/∂sigma_mid)
|
135 |
-
d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
|
136 |
-
# Denoise to next
|
137 |
-
x = x + d_mid * (sigma_down - sigma)
|
138 |
-
# Add randomness
|
139 |
-
x_next = x + torch.randn_like(x) * sigma_up
|
140 |
-
return x_next
|
141 |
-
|
142 |
-
def forward(
|
143 |
-
self, noise, fn, sigmas, num_steps):
|
144 |
-
# raise ValueError
|
145 |
-
x = sigmas[0] * noise
|
146 |
-
# Denoise to sample
|
147 |
-
for i in range(num_steps - 1):
|
148 |
-
x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
|
149 |
-
return x
|
150 |
-
|
151 |
-
class DiffusionSampler(nn.Module):
|
152 |
-
|
153 |
-
def __init__(
|
154 |
-
self,
|
155 |
-
diffusion=None,
|
156 |
-
num_steps=None,
|
157 |
-
clamp=True, # default=False
|
158 |
-
):
|
159 |
-
super().__init__()
|
160 |
-
self.denoise_fn = diffusion.denoise_fn
|
161 |
-
self.sampler = ADPM2Sampler()
|
162 |
-
self.sigma_schedule = KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)
|
163 |
-
self.num_steps = num_steps
|
164 |
-
self.clamp = clamp
|
165 |
-
|
166 |
-
|
167 |
-
def forward(
|
168 |
-
self, noise, num_steps=None, **kwargs):
|
169 |
-
# raise ValueError
|
170 |
-
device = noise.device
|
171 |
-
num_steps = default(num_steps, self.num_steps) # type: ignore
|
172 |
-
|
173 |
-
# Compute sigmas using schedule
|
174 |
-
sigmas = self.sigma_schedule(num_steps, device)
|
175 |
-
|
176 |
-
# L242 KWARGS dict_keys(['embedding', 'features'])
|
177 |
-
fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
|
178 |
-
# Sample using sampler
|
179 |
-
x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
|
180 |
-
x = x.clamp(-1.0, 1.0) if self.clamp else x
|
181 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models.py
CHANGED
@@ -11,20 +11,19 @@ import torch.nn.functional as F
|
|
11 |
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
12 |
from Utils.ASR.models import ASRCNN
|
13 |
from Utils.JDC.model import JDCNet
|
14 |
-
|
15 |
from Modules.diffusion.modules import StyleTransformer1d
|
16 |
-
|
17 |
from munch import Munch
|
18 |
import yaml
|
19 |
from math import pi
|
20 |
from random import randint
|
21 |
-
|
22 |
import torch
|
23 |
from einops import rearrange
|
24 |
from torch import Tensor, nn
|
25 |
from tqdm import tqdm
|
26 |
-
|
27 |
-
# from Modules.diffusion.sampler import *
|
28 |
|
29 |
|
30 |
|
@@ -623,23 +622,7 @@ def build_model(args, text_aligner, pitch_extractor, bert):
|
|
623 |
else:
|
624 |
raise NotImplementedError
|
625 |
|
626 |
-
|
627 |
-
in_channels=1,
|
628 |
-
embedding_max_length=bert.config.max_position_embeddings,
|
629 |
-
embedding_features=bert.config.hidden_size,
|
630 |
-
embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
|
631 |
-
channels=args.style_dim*2,
|
632 |
-
context_features=args.style_dim*2,
|
633 |
-
)
|
634 |
-
# this initialises self.diffusion for AudioDiffusionConditional
|
635 |
-
diffusion.diffusion = KDiffusion(
|
636 |
-
net=diffusion.unet,
|
637 |
-
sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
|
638 |
-
sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model
|
639 |
-
dynamic_threshold=0.0
|
640 |
-
)
|
641 |
-
diffusion.diffusion.net = transformer
|
642 |
-
diffusion.unet = transformer
|
643 |
|
644 |
|
645 |
nets = Munch(
|
@@ -652,7 +635,6 @@ def build_model(args, text_aligner, pitch_extractor, bert):
|
|
652 |
|
653 |
predictor_encoder=predictor_encoder,
|
654 |
style_encoder=style_encoder,
|
655 |
-
diffusion=diffusion,
|
656 |
|
657 |
text_aligner = text_aligner,
|
658 |
pitch_extractor=pitch_extractor
|
|
|
11 |
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
12 |
from Utils.ASR.models import ASRCNN
|
13 |
from Utils.JDC.model import JDCNet
|
14 |
+
|
15 |
from Modules.diffusion.modules import StyleTransformer1d
|
16 |
+
|
17 |
from munch import Munch
|
18 |
import yaml
|
19 |
from math import pi
|
20 |
from random import randint
|
21 |
+
|
22 |
import torch
|
23 |
from einops import rearrange
|
24 |
from torch import Tensor, nn
|
25 |
from tqdm import tqdm
|
26 |
+
|
|
|
27 |
|
28 |
|
29 |
|
|
|
622 |
else:
|
623 |
raise NotImplementedError
|
624 |
|
625 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
626 |
|
627 |
|
628 |
nets = Munch(
|
|
|
635 |
|
636 |
predictor_encoder=predictor_encoder,
|
637 |
style_encoder=style_encoder,
|
|
|
638 |
|
639 |
text_aligner = text_aligner,
|
640 |
pitch_extractor=pitch_extractor
|
msinference.py
CHANGED
@@ -160,9 +160,7 @@ for key in model:
|
|
160 |
# _load(params[key], model[key])
|
161 |
_ = [model[key].eval() for key in model]
|
162 |
|
163 |
-
from Modules.diffusion.sampler import DiffusionSampler
|
164 |
|
165 |
-
sampler = DiffusionSampler(diffusion=model.diffusion.diffusion)
|
166 |
|
167 |
def inference(text,
|
168 |
ref_s,
|
@@ -205,17 +203,10 @@ def inference(text,
|
|
205 |
# print('BERTdu', bert_dur.shape, tokens.shape, '\n') # bert what is the 768 per token -> IS USED in sampler
|
206 |
# BERTdu torch.Size([1, 11, 768]) torch.Size([1, 11])
|
207 |
|
208 |
-
|
209 |
-
embedding=bert_dur,
|
210 |
-
features=ref_s, # reference from the same speaker as the embedding
|
211 |
-
num_steps=diffusion_steps).squeeze(1)
|
212 |
-
|
213 |
-
|
214 |
-
s = s_pred[:, 128:]
|
215 |
-
ref = s_pred[:, :128]
|
216 |
|
217 |
-
ref =
|
218 |
-
s =
|
219 |
|
220 |
d = model.predictor.text_encoder(d_en,
|
221 |
s, input_lengths, text_mask)
|
|
|
160 |
# _load(params[key], model[key])
|
161 |
_ = [model[key].eval() for key in model]
|
162 |
|
|
|
163 |
|
|
|
164 |
|
165 |
def inference(text,
|
166 |
ref_s,
|
|
|
203 |
# print('BERTdu', bert_dur.shape, tokens.shape, '\n') # bert what is the 768 per token -> IS USED in sampler
|
204 |
# BERTdu torch.Size([1, 11, 768]) torch.Size([1, 11])
|
205 |
|
206 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
|
208 |
+
ref = ref_s[:, :128]
|
209 |
+
s = ref_s[:, 128:]
|
210 |
|
211 |
d = model.predictor.text_encoder(d_en,
|
212 |
s, input_lengths, text_mask)
|