Dionyssos committed
Commit 8cc530a · Parent: 52c4e0a

fx diffusion

Modules/diffusion/sampler.py CHANGED
@@ -4,7 +4,6 @@ from einops import rearrange
 from torch import Tensor
 from functools import reduce
 # from inspect import isfunction
-# from math import ceil, floor, log2, pi
 import torch
 import torch.nn.functional as F
 
@@ -29,8 +28,6 @@ class UniformDistribution():
     def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
         return torch.rand(num_samples, device=device)
 
-
-
 def to_batch(
     batch_size: int,
     device: torch.device,
@@ -59,8 +56,6 @@ class KDiffusion(nn.Module):
         super().__init__()
         self.net = net
         self.sigma_data = sigma_data
-        self.sigma_distribution = sigma_distribution
-        self.dynamic_threshold = dynamic_threshold
 
     def get_scale_weights(self, sigmas):
         sigma_data = self.sigma_data
@@ -91,17 +86,6 @@ class KDiffusion(nn.Module):
         return x_denoised
 
 
-
-
-
-
-
-
-
-
-
-
-
 class KarrasSchedule(nn.Module):
     """https://arxiv.org/abs/2206.00364 equation 5"""
 
@@ -165,27 +149,20 @@ class ADPM2Sampler(nn.Module):
         return x
 
 class DiffusionSampler(nn.Module):
+
     def __init__(
         self,
-        diffusion,
-        *,
-        sampler,
-        sigma_schedule,
+        diffusion=None,
         num_steps=None,
-        clamp=True,
+        clamp=True,  # default=False
     ):
         super().__init__()
         self.denoise_fn = diffusion.denoise_fn
-        self.sampler = sampler
-        self.sigma_schedule = sigma_schedule
+        self.sampler = ADPM2Sampler()
+        self.sigma_schedule = KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)
         self.num_steps = num_steps
         self.clamp = clamp
 
-        # Check sampler is compatible with diffusion type
-        sampler_class = sampler.__class__.__name__
-        diffusion_class = diffusion.__class__.__name__
-        message = f"{sampler_class} incompatible with {diffusion_class}"
-        assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
 
     def forward(
         self, noise, num_steps=None, **kwargs):
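
With this commit, DiffusionSampler no longer accepts sampler or sigma_schedule arguments; it instantiates ADPM2Sampler() and KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0) itself, and the sampler/diffusion compatibility assert is dropped. A minimal usage sketch of the new constructor; the noise shape and the conditioning_kwargs placeholder are illustrative assumptions, not taken from this repository:

import torch
from Modules.diffusion.sampler import DiffusionSampler

# model.diffusion.diffusion is the KDiffusion module wired up in msinference.py;
# any module exposing .denoise_fn satisfies the new constructor.
sampler = DiffusionSampler(diffusion=model.diffusion.diffusion, num_steps=5)

noise = torch.randn(1, 1, 256)   # illustrative latent shape
# whatever conditioning the denoise function expects is forwarded via **kwargs
style = sampler(noise, num_steps=5, **conditioning_kwargs)
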
Utils/PLBERT/util.py CHANGED
@@ -37,6 +37,6 @@ def load_plbert(log_dir):
         name = name[8:] # remove `encoder.`
         new_state_dict[name] = v
     del new_state_dict["embeddings.position_ids"]
-    bert.load_state_dict(new_state_dict, strict=False)
+    bert.load_state_dict(new_state_dict, strict=True)
 
     return bert
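
Loading with strict=True makes load_state_dict raise a RuntimeError if the filtered checkpoint and the BERT module disagree on any key, whereas strict=False merely reported the mismatch and loaded whatever matched. A self-contained sketch of that difference using a toy module (hypothetical keys, unrelated to PL-BERT):

import torch
from torch import nn

layer = nn.Linear(2, 2)
partial_state = {"weight": torch.zeros(2, 2)}       # "bias" is intentionally missing

result = layer.load_state_dict(partial_state, strict=False)
print(result.missing_keys)                          # ['bias'] -- matching keys still load

try:
    layer.load_state_dict(partial_state, strict=True)
except RuntimeError as err:
    print(err)                                      # Missing key(s) in state_dict: "bias"
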
msinference.py CHANGED
@@ -17,8 +17,8 @@ from torch import nn
 from nltk.tokenize import word_tokenize
 
 torch.manual_seed(0)
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
+# torch.backends.cudnn.benchmark = False
+# torch.backends.cudnn.deterministic = True
 
 
 # IPA Phonemizer: https://github.com/bootphon/phonemizer
@@ -160,14 +160,9 @@ for key in model:
     # _load(params[key], model[key])
 _ = [model[key].eval() for key in model]
 
-from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+from Modules.diffusion.sampler import DiffusionSampler
 
-sampler = DiffusionSampler(
-    model.diffusion.diffusion,
-    sampler=ADPM2Sampler(),
-    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),  # empirical parameters
-    clamp=False
-)
+sampler = DiffusionSampler(diffusion=model.diffusion.diffusion)
 
 def inference(text,
               ref_s,
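
Commenting out the two cuDNN flags leaves PyTorch at its defaults (cudnn.benchmark=False, cudnn.deterministic=False), so only the deterministic setting actually changes and GPU convolutions may differ slightly between runs even with torch.manual_seed(0). A short sketch, standard PyTorch rather than anything added by this commit, of how the previous bit-exact behaviour could be restored if needed:

import torch

torch.manual_seed(0)
# benchmark is already False by default; deterministic is the flag this commit relaxes.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
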