OpenSound commited on
Commit
888ba34
·
1 Parent(s): 6d01598

Update models/ssr.py

Browse files
Files changed (1) hide show
  1. models/ssr.py +9 -2
models/ssr.py CHANGED
@@ -517,6 +517,7 @@ class SSR_Speech(
517
  kvcache: int=1,
518
  silence_tokens: list[int]=[1388,1898,131],
519
  cfg_coef: float=1.5,
 
520
  aug_text: bool=False,
521
  aug_context: bool=False,
522
  cfg_pretrained: bool=False,
@@ -648,6 +649,7 @@ class SSR_Speech(
648
  consec_silence_count = 0
649
  num_gen = 0
650
  num_eog = 0
 
651
 
652
  # add mask token
653
  mts = torch.LongTensor([emb_inds[idx]] * self.args.n_codebooks).unsqueeze(-1).to(embedded_y.device) # K, 1
@@ -686,7 +688,12 @@ class SSR_Speech(
686
  logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
687
  logits = logits.squeeze() # [K card]
688
  if aug_text:
689
- logits = cfg_coef * logits[0] + (1 - cfg_coef) * logits[1]
 
 
 
 
 
690
  assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
691
  # filter out mts, sos and eos
692
  for jj in range(self.args.n_codebooks):
@@ -807,4 +814,4 @@ class SSR_Speech(
807
 
808
  if __name__ == "__main__":
809
  # debug
810
- pass
 
517
  kvcache: int=1,
518
  silence_tokens: list[int]=[1388,1898,131],
519
  cfg_coef: float=1.5,
520
+ cfg_stride: int=1,
521
  aug_text: bool=False,
522
  aug_context: bool=False,
523
  cfg_pretrained: bool=False,
 
649
  consec_silence_count = 0
650
  num_gen = 0
651
  num_eog = 0
652
+ num_cfg_tag = 1
653
 
654
  # add mask token
655
  mts = torch.LongTensor([emb_inds[idx]] * self.args.n_codebooks).unsqueeze(-1).to(embedded_y.device) # K, 1
 
688
  logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
689
  logits = logits.squeeze() # [K card]
690
  if aug_text:
691
+ if num_cfg_tag == cfg_stride:
692
+ logits = cfg_coef * logits[0] + (1 - cfg_coef) * logits[1]
693
+ num_cfg_tag = 1
694
+ else:
695
+ num_cfg_tag += 1
696
+ logits = logits[0]
697
  assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
698
  # filter out mts, sos and eos
699
  for jj in range(self.args.n_codebooks):
 
814
 
815
  if __name__ == "__main__":
816
  # debug
817
+ pass