fx diffusion
Browse files- Modules/diffusion/sampler.py +5 -28
- Utils/PLBERT/util.py +1 -1
- msinference.py +4 -9
Modules/diffusion/sampler.py
CHANGED
@@ -4,7 +4,6 @@ from einops import rearrange
|
|
4 |
from torch import Tensor
|
5 |
from functools import reduce
|
6 |
# from inspect import isfunction
|
7 |
-
# from math import ceil, floor, log2, pi
|
8 |
import torch
|
9 |
import torch.nn.functional as F
|
10 |
|
@@ -29,8 +28,6 @@ class UniformDistribution():
|
|
29 |
def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
|
30 |
return torch.rand(num_samples, device=device)
|
31 |
|
32 |
-
|
33 |
-
|
34 |
def to_batch(
|
35 |
batch_size: int,
|
36 |
device: torch.device,
|
@@ -59,8 +56,6 @@ class KDiffusion(nn.Module):
|
|
59 |
super().__init__()
|
60 |
self.net = net
|
61 |
self.sigma_data = sigma_data
|
62 |
-
self.sigma_distribution = sigma_distribution
|
63 |
-
self.dynamic_threshold = dynamic_threshold
|
64 |
|
65 |
def get_scale_weights(self, sigmas):
|
66 |
sigma_data = self.sigma_data
|
@@ -91,17 +86,6 @@ class KDiffusion(nn.Module):
|
|
91 |
return x_denoised
|
92 |
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
class KarrasSchedule(nn.Module):
|
106 |
"""https://arxiv.org/abs/2206.00364 equation 5"""
|
107 |
|
@@ -165,27 +149,20 @@ class ADPM2Sampler(nn.Module):
|
|
165 |
return x
|
166 |
|
167 |
class DiffusionSampler(nn.Module):
|
|
|
168 |
def __init__(
|
169 |
self,
|
170 |
-
diffusion,
|
171 |
-
*,
|
172 |
-
sampler,
|
173 |
-
sigma_schedule,
|
174 |
num_steps=None,
|
175 |
-
clamp=True,
|
176 |
):
|
177 |
super().__init__()
|
178 |
self.denoise_fn = diffusion.denoise_fn
|
179 |
-
self.sampler =
|
180 |
-
self.sigma_schedule =
|
181 |
self.num_steps = num_steps
|
182 |
self.clamp = clamp
|
183 |
|
184 |
-
# Check sampler is compatible with diffusion type
|
185 |
-
sampler_class = sampler.__class__.__name__
|
186 |
-
diffusion_class = diffusion.__class__.__name__
|
187 |
-
message = f"{sampler_class} incompatible with {diffusion_class}"
|
188 |
-
assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
|
189 |
|
190 |
def forward(
|
191 |
self, noise, num_steps=None, **kwargs):
|
|
|
4 |
from torch import Tensor
|
5 |
from functools import reduce
|
6 |
# from inspect import isfunction
|
|
|
7 |
import torch
|
8 |
import torch.nn.functional as F
|
9 |
|
|
|
28 |
def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
|
29 |
return torch.rand(num_samples, device=device)
|
30 |
|
|
|
|
|
31 |
def to_batch(
|
32 |
batch_size: int,
|
33 |
device: torch.device,
|
|
|
56 |
super().__init__()
|
57 |
self.net = net
|
58 |
self.sigma_data = sigma_data
|
|
|
|
|
59 |
|
60 |
def get_scale_weights(self, sigmas):
|
61 |
sigma_data = self.sigma_data
|
|
|
86 |
return x_denoised
|
87 |
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
class KarrasSchedule(nn.Module):
|
90 |
"""https://arxiv.org/abs/2206.00364 equation 5"""
|
91 |
|
|
|
149 |
return x
|
150 |
|
151 |
class DiffusionSampler(nn.Module):
|
152 |
+
|
153 |
def __init__(
|
154 |
self,
|
155 |
+
diffusion=None,
|
|
|
|
|
|
|
156 |
num_steps=None,
|
157 |
+
clamp=True, # default=False
|
158 |
):
|
159 |
super().__init__()
|
160 |
self.denoise_fn = diffusion.denoise_fn
|
161 |
+
self.sampler = ADPM2Sampler()
|
162 |
+
self.sigma_schedule = KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)
|
163 |
self.num_steps = num_steps
|
164 |
self.clamp = clamp
|
165 |
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
def forward(
|
168 |
self, noise, num_steps=None, **kwargs):
|
Utils/PLBERT/util.py
CHANGED
@@ -37,6 +37,6 @@ def load_plbert(log_dir):
|
|
37 |
name = name[8:] # remove `encoder.`
|
38 |
new_state_dict[name] = v
|
39 |
del new_state_dict["embeddings.position_ids"]
|
40 |
-
bert.load_state_dict(new_state_dict, strict=
|
41 |
|
42 |
return bert
|
|
|
37 |
name = name[8:] # remove `encoder.`
|
38 |
new_state_dict[name] = v
|
39 |
del new_state_dict["embeddings.position_ids"]
|
40 |
+
bert.load_state_dict(new_state_dict, strict=True)
|
41 |
|
42 |
return bert
|
msinference.py
CHANGED
@@ -17,8 +17,8 @@ from torch import nn
|
|
17 |
from nltk.tokenize import word_tokenize
|
18 |
|
19 |
torch.manual_seed(0)
|
20 |
-
torch.backends.cudnn.benchmark = False
|
21 |
-
torch.backends.cudnn.deterministic = True
|
22 |
|
23 |
|
24 |
# IPA Phonemizer: https://github.com/bootphon/phonemizer
|
@@ -160,14 +160,9 @@ for key in model:
|
|
160 |
# _load(params[key], model[key])
|
161 |
_ = [model[key].eval() for key in model]
|
162 |
|
163 |
-
from Modules.diffusion.sampler import DiffusionSampler
|
164 |
|
165 |
-
sampler = DiffusionSampler(
|
166 |
-
model.diffusion.diffusion,
|
167 |
-
sampler=ADPM2Sampler(),
|
168 |
-
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
|
169 |
-
clamp=False
|
170 |
-
)
|
171 |
|
172 |
def inference(text,
|
173 |
ref_s,
|
|
|
17 |
from nltk.tokenize import word_tokenize
|
18 |
|
19 |
torch.manual_seed(0)
|
20 |
+
# torch.backends.cudnn.benchmark = False
|
21 |
+
# torch.backends.cudnn.deterministic = True
|
22 |
|
23 |
|
24 |
# IPA Phonemizer: https://github.com/bootphon/phonemizer
|
|
|
160 |
# _load(params[key], model[key])
|
161 |
_ = [model[key].eval() for key in model]
|
162 |
|
163 |
+
from Modules.diffusion.sampler import DiffusionSampler
|
164 |
|
165 |
+
sampler = DiffusionSampler(diffusion=model.diffusion.diffusion)
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
def inference(text,
|
168 |
ref_s,
|