Dionyssos committed
Commit 52c4e0a · 1 Parent(s): 6a0b5fd

del voice-blind embedding

Files changed (1):
  Modules/diffusion/modules.py  +12 -41
Modules/diffusion/modules.py CHANGED
@@ -117,10 +117,9 @@ class StyleTransformer1d(nn.Module):
             nn.GELU(),
         )
 
-        self.fixed_embedding = FixedEmbedding(
-            max_length=embedding_max_length, features=context_embedding_features
-        )
-
+        # self.fixed_embedding = FixedEmbedding(
+        #     max_length=embedding_max_length, features=context_embedding_features
+        # )  # non-speaker-aware lookup: the embedding indexes only the time-frame index [0, 1, 2, ..., num-asr-time-frames]
 
     def get_mapping(
         self,
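
Note: the comment added above explains the deletion. FixedEmbedding is "voice-blind" because it looks up only the time-frame position, never the input content, so any two inputs of equal length receive identical embeddings. A minimal sketch of that behaviour (table sizes are illustrative, not taken from the repo config):

```python
# Why a position-only lookup is "voice-blind" (sizes are assumed values).
import torch
from torch import nn

table = nn.Embedding(512, 8)      # max_length=512, features=8 (assumed)
positions = torch.arange(100)     # time-frame indices 0..99, nothing else
e = table(positions)              # (100, 8)

# Any two utterances with 100 ASR time frames index the same rows,
# so the resulting "conditioning" carries no speaker information.
```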
@@ -144,40 +143,26 @@ class StyleTransformer1d(nn.Module):
         mapping = self.to_mapping(mapping)
 
         return mapping
-
-    def run(self, x, time, embedding, features):
-        # called by forward()
+
+    def forward(self,
+                x,
+                time,
+                embedding=None,
+                features=None):
+
+        # formerly a separate run() method,
+        # now merged directly into forward()
 
         mapping = self.get_mapping(time, features)
         x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
         mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
-
         for block in self.blocks:
             x = x + mapping
             x = block(x, features)
-
         x = x.mean(axis=1).unsqueeze(1)
         x = self.to_out(x)
         x = x.transpose(-1, -2)
-
         return x
-
-    def forward(self,
-                x,
-                time,
-                embedding= None,
-                features = None):
-
-        b, device = embedding.shape[0], embedding.device
-        # if
-        # embedding_mask_proba: float = 0.0, > 0
-        # fixed_embedding = self.fixed_embedding(embedding)
-        # embedding = torch.where(batch_mask, fixed_embedding, embedding)
-        return self.run(x,
-                        time,
-                        embedding=embedding,
-                        # embedding=self.fixed_embedding(embedding), # fixedemb has noisy beginnings on chapters.wav
-                        features=features)
 
 
 class StyleTransformerBlock(nn.Module):
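
The merged forward() broadcasts x across the embedding's time frames, concatenates along the feature axis, runs the transformer blocks, then mean-pools the frames back to length 1. A toy-shape walkthrough (all dimensions assumed for illustration):

```python
# Toy-shape walkthrough of the tensor plumbing in forward() (shapes assumed).
import torch

x = torch.randn(2, 1, 4)            # (batch, 1, d_x)
embedding = torch.randn(2, 5, 3)    # (batch, frames, d_emb)

# Broadcast x to every frame, then concatenate per-frame features:
h = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], dim=-1)
print(h.shape)                      # torch.Size([2, 5, 7])

# After the transformer blocks, frames are mean-pooled back to length 1:
pooled = h.mean(dim=1).unsqueeze(1)
print(pooled.shape)                 # torch.Size([2, 1, 7])
```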
@@ -379,17 +364,3 @@ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
         nn.Linear(in_features=dim + 1, out_features=out_features),
     )
 
-class FixedEmbedding(nn.Module):
-    def __init__(self, max_length: int, features: int):
-        super().__init__()
-        self.max_length = max_length
-        self.embedding = nn.Embedding(max_length, features)
-
-    def forward(self, x: Tensor) -> Tensor:
-        batch_size, length, device = *x.shape[0:2], x.device
-        assert_message = "Input sequence length must be <= max_length"
-        assert length <= self.max_length, assert_message
-        position = torch.arange(length, device=device)
-        fixed_embedding = self.embedding(position)
-        fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
-        return fixed_embedding
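
The comments deleted from the old forward() sketch how FixedEmbedding was meant to be used: with some probability embedding_mask_proba, a batch element's conditioning would be replaced by the voice-blind lookup, in the style of classifier-free guidance. A hedged reconstruction follows; only the torch.where call appears in the removed code, and the batch_mask construction is an assumption:

```python
# Hedged reconstruction of the dropped masking path. Only the torch.where
# call comes from the removed comments; the rest is assumed plumbing.
import torch
from torch import Tensor

def mask_embedding(embedding: Tensor,
                   fixed_embedding: Tensor,
                   embedding_mask_proba: float = 0.1) -> Tensor:
    b = embedding.shape[0]
    # Per-sample coin flip, broadcast over (frames, features) -- an assumption.
    batch_mask = torch.rand(b, device=embedding.device) < embedding_mask_proba
    batch_mask = batch_mask.view(b, 1, 1)
    # The line preserved verbatim in the removed comments:
    return torch.where(batch_mask, fixed_embedding, embedding)
```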
 