Spaces:
Running
on
Zero
Running
on
Zero
Update models/ssr.py
Browse files- models/ssr.py +9 -2
models/ssr.py
CHANGED
@@ -517,6 +517,7 @@ class SSR_Speech(
|
|
517 |
kvcache: int=1,
|
518 |
silence_tokens: list[int]=[1388,1898,131],
|
519 |
cfg_coef: float=1.5,
|
|
|
520 |
aug_text: bool=False,
|
521 |
aug_context: bool=False,
|
522 |
cfg_pretrained: bool=False,
|
@@ -648,6 +649,7 @@ class SSR_Speech(
|
|
648 |
consec_silence_count = 0
|
649 |
num_gen = 0
|
650 |
num_eog = 0
|
|
|
651 |
|
652 |
# add mask token
|
653 |
mts = torch.LongTensor([emb_inds[idx]] * self.args.n_codebooks).unsqueeze(-1).to(embedded_y.device) # K, 1
|
@@ -686,7 +688,12 @@ class SSR_Speech(
|
|
686 |
logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
|
687 |
logits = logits.squeeze() # [K card]
|
688 |
if aug_text:
|
689 |
-
|
|
|
|
|
|
|
|
|
|
|
690 |
assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
|
691 |
# filter out mts, sos and eos
|
692 |
for jj in range(self.args.n_codebooks):
|
@@ -807,4 +814,4 @@ class SSR_Speech(
|
|
807 |
|
808 |
if __name__ == "__main__":
|
809 |
# debug
|
810 |
-
pass
|
|
|
517 |
kvcache: int=1,
|
518 |
silence_tokens: list[int]=[1388,1898,131],
|
519 |
cfg_coef: float=1.5,
|
520 |
+
cfg_stride: int=1,
|
521 |
aug_text: bool=False,
|
522 |
aug_context: bool=False,
|
523 |
cfg_pretrained: bool=False,
|
|
|
649 |
consec_silence_count = 0
|
650 |
num_gen = 0
|
651 |
num_eog = 0
|
652 |
+
num_cfg_tag = 1
|
653 |
|
654 |
# add mask token
|
655 |
mts = torch.LongTensor([emb_inds[idx]] * self.args.n_codebooks).unsqueeze(-1).to(embedded_y.device) # K, 1
|
|
|
688 |
logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
|
689 |
logits = logits.squeeze() # [K card]
|
690 |
if aug_text:
|
691 |
+
if num_cfg_tag == cfg_stride:
|
692 |
+
logits = cfg_coef * logits[0] + (1 - cfg_coef) * logits[1]
|
693 |
+
num_cfg_tag = 1
|
694 |
+
else:
|
695 |
+
num_cfg_tag += 1
|
696 |
+
logits = logits[0]
|
697 |
assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
|
698 |
# filter out mts, sos and eos
|
699 |
for jj in range(self.args.n_codebooks):
|
|
|
814 |
|
815 |
if __name__ == "__main__":
|
816 |
# debug
|
817 |
+
pass
|