del voice-blind embedding
Browse files- Modules/diffusion/modules.py +12 -41
Modules/diffusion/modules.py
CHANGED
@@ -117,10 +117,9 @@ class StyleTransformer1d(nn.Module):
|
|
117 |
nn.GELU(),
|
118 |
)
|
119 |
|
120 |
-
self.fixed_embedding = FixedEmbedding(
|
121 |
-
|
122 |
-
)
|
123 |
-
|
124 |
|
125 |
def get_mapping(
|
126 |
self,
|
@@ -144,40 +143,26 @@ class StyleTransformer1d(nn.Module):
|
|
144 |
mapping = self.to_mapping(mapping)
|
145 |
|
146 |
return mapping
|
147 |
-
|
148 |
-
def
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
mapping = self.get_mapping(time, features)
|
152 |
x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
|
153 |
mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
|
154 |
-
|
155 |
for block in self.blocks:
|
156 |
x = x + mapping
|
157 |
x = block(x, features)
|
158 |
-
|
159 |
x = x.mean(axis=1).unsqueeze(1)
|
160 |
x = self.to_out(x)
|
161 |
x = x.transpose(-1, -2)
|
162 |
-
|
163 |
return x
|
164 |
-
|
165 |
-
def forward(self,
|
166 |
-
x,
|
167 |
-
time,
|
168 |
-
embedding= None,
|
169 |
-
features = None):
|
170 |
-
|
171 |
-
b, device = embedding.shape[0], embedding.device
|
172 |
-
# if
|
173 |
-
# embedding_mask_proba: float = 0.0, > 0
|
174 |
-
# fixed_embedding = self.fixed_embedding(embedding)
|
175 |
-
# embedding = torch.where(batch_mask, fixed_embedding, embedding)
|
176 |
-
return self.run(x,
|
177 |
-
time,
|
178 |
-
embedding=embedding,
|
179 |
-
# embedding=self.fixed_embedding(embedding), # fixedemb has noisy beginnings on chapters.wav
|
180 |
-
features=features)
|
181 |
|
182 |
|
183 |
class StyleTransformerBlock(nn.Module):
|
@@ -379,17 +364,3 @@ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
|
|
379 |
nn.Linear(in_features=dim + 1, out_features=out_features),
|
380 |
)
|
381 |
|
382 |
-
class FixedEmbedding(nn.Module):
|
383 |
-
def __init__(self, max_length: int, features: int):
|
384 |
-
super().__init__()
|
385 |
-
self.max_length = max_length
|
386 |
-
self.embedding = nn.Embedding(max_length, features)
|
387 |
-
|
388 |
-
def forward(self, x: Tensor) -> Tensor:
|
389 |
-
batch_size, length, device = *x.shape[0:2], x.device
|
390 |
-
assert_message = "Input sequence length must be <= max_length"
|
391 |
-
assert length <= self.max_length, assert_message
|
392 |
-
position = torch.arange(length, device=device)
|
393 |
-
fixed_embedding = self.embedding(position)
|
394 |
-
fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
|
395 |
-
return fixed_embedding
|
|
|
117 |
nn.GELU(),
|
118 |
)
|
119 |
|
120 |
+
# self.fixed_embedding = FixedEmbedding(
|
121 |
+
# max_length=embedding_max_length, features=context_embedding_features
|
122 |
+
# ) # Non-speaker-aware lookup: the embedding is indexed only by the time-frame index [0, 1, 2, ..., num-asr-time-frames]
|
|
|
123 |
|
124 |
def get_mapping(
|
125 |
self,
|
|
|
143 |
mapping = self.to_mapping(mapping)
|
144 |
|
145 |
return mapping
|
146 |
+
|
147 |
+
def forward(self,
|
148 |
+
x,
|
149 |
+
time,
|
150 |
+
embedding= None,
|
151 |
+
features = None):
|
152 |
+
|
153 |
+
# --
|
154 |
+
# called by forward()
|
155 |
|
156 |
mapping = self.get_mapping(time, features)
|
157 |
x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
|
158 |
mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
|
|
|
159 |
for block in self.blocks:
|
160 |
x = x + mapping
|
161 |
x = block(x, features)
|
|
|
162 |
x = x.mean(axis=1).unsqueeze(1)
|
163 |
x = self.to_out(x)
|
164 |
x = x.transpose(-1, -2)
|
|
|
165 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
|
168 |
class StyleTransformerBlock(nn.Module):
|
|
|
364 |
nn.Linear(in_features=dim + 1, out_features=out_features),
|
365 |
)
|
366 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|