# @ hwang258@jh.edu import random import numpy as np import logging import argparse, copy from typing import Dict, Optional import torch import torch.nn as nn import torch.nn.functional as F from torchmetrics.classification import MulticlassAccuracy from .modules.utils import make_pad_mask from .modules.embedding import SinePositionalEmbedding, TokenEmbedding from .modules.transformer import ( LayerNorm, TransformerEncoder, TransformerEncoderLayer, ) from huggingface_hub import PyTorchModelHubMixin from argparse import Namespace import typing as tp def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 ): """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering Args: logits: logits distribution shape (batch size, vocabulary size) if top_k > 0: keep only top k tokens with highest probability (top-k filtering). if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) Make sure we keep at least min_tokens_to_keep per batch example in the output From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ if top_k > 0: top_k = min( max(top_k, min_tokens_to_keep), logits.size(-1) ) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] logits[indices_to_remove] = filter_value if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum( F.softmax(sorted_logits, dim=-1), dim=-1 ) # Remove tokens with cumulative probability above the threshold (token with 0 are kept) sorted_indices_to_remove = cumulative_probs > top_p if min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 # Shift the indices to the right to keep also the first token above the threshold sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ ..., :-1 ].clone() sorted_indices_to_remove[..., 0] = 0 # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter( 1, sorted_indices, sorted_indices_to_remove ) logits[indices_to_remove] = filter_value return logits def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): # temperature: (`optional`) float # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. # top_k: (`optional`) int # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. # top_p: (`optional`) float # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. # Temperature (higher temperature => more likely to sample low probability tokens) if temperature != 1.0: logits = logits / temperature # Top-p/top-k filtering logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) # Sample token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) return token class SSR_Speech( nn.Module, PyTorchModelHubMixin, library_name="ssr_speech", repo_url=None, tags=None, ): def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "SSR_Speech": # If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json # Won't affect instance initialization if args is not None: if config is not None: raise ValueError("Cannot provide both `args` and `config`.") config = vars(args) return super().__new__(cls, args=args, config=config, **kwargs) def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None): super().__init__() # If loaded from HF Hub => convert config.json to Namespace args before initializing if args is None: if config is None: raise ValueError("Either `args` or `config` must be provided.") args = Namespace(**config) self.args = copy.copy(args) if not getattr(self.args, "n_special", False): self.args.n_special = 3 self.args.eos = getattr(self.args, "eos", -1) if isinstance(self.args.audio_vocab_size, str): self.args.audio_vocab_size = eval(self.args.audio_vocab_size) self.n_text_tokens = self.args.text_vocab_size + 1 assert self.args.text_pad_token == self.args.text_vocab_size, f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}" self.n_audio_tokens = [int(self.args.audio_vocab_size) + self.args.n_special + self.args.max_n_spans] * self.args.n_codebooks # special tokens: empty token, EOG token, audio pad token, mask tokens assert self.args.audio_vocab_size == self.args.empty_token, self.args.empty_token assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog assert self.args.audio_pad_token == self.args.audio_vocab_size + 2, self.args.audio_pad_token assert self.args.eos == self.args.audio_vocab_size + 3, self.args.eos assert self.args.sos == self.args.audio_vocab_size + 4, self.args.sos assert self.args.mts == self.args.audio_vocab_size + 5, self.args.mts self.text_embedding = TokenEmbedding( dim_model=self.args.d_model, vocab_size=self.n_text_tokens, dropout=self.args.text_embedding_dropout ) self.audio_embedding = nn.ModuleList( [ TokenEmbedding( dim_model=self.args.audio_embedding_dim, vocab_size=self.n_audio_tokens[k], dropout=self.args.audio_embedding_dropout ) for k in range(self.args.n_codebooks) ] ) self.text_positional_embedding = SinePositionalEmbedding( self.args.d_model, dropout=self.args.text_positional_embedding_dropout, scale=False, alpha=True, # learnable scaler, scale the volume of positional embedding ) self.audio_positional_embedding = SinePositionalEmbedding( self.args.d_model, dropout=self.args.audio_positional_embedding_dropout, scale=False, alpha=True, # learnable scaler, scale the volume of positional embedding ) dec_layer = TransformerEncoderLayer( self.args.d_model, self.args.nhead, dim_feedforward=self.args.d_model * 4, dropout=self.args.trm_dropout, batch_first=True, norm_first=True, layer_norm_cls=LayerNorm ) self.decoder = TransformerEncoder( dec_layer, num_layers=self.args.num_decoder_layers, norm=LayerNorm(self.args.d_model), ) self.predict_layer = nn.ModuleList( [ nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks) ] ) self.accuracy_metrics = nn.ModuleList( [MulticlassAccuracy( self.n_audio_tokens[k], top_k=10, average="micro", multidim_average="global", ignore_index=None, ) for k in range(self.args.n_codebooks)] ) def embed_y(self, cated_y): # [K,T,B] embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, T, B, D] assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape embedded_y = embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D] embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D] return embedded_y def prepare_input_target(self, cated_y, y_lens): embedded_y = self.embed_y(cated_y) # [B,T,D] # positional embedding y_input = self.audio_positional_embedding(embedded_y) # make attention mask and padding mask y_padding_mask = make_pad_mask(y_lens).to(cated_y.device) y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y_padding_mask.device) return y_input, y_padding_mask, y_attention_mask def dec_forward( self, x_input, x_lens, x_attention_mask, x_padding_mask, y_input, new_y_lens, y_attention_mask, y_padding_mask, past=None, last_3_tokens=False ): x_attn_mask = F.pad( x_attention_mask, (0, new_y_lens.max()), value=True, ) # x attn to all x, doesn't attn to any y, this follow figure 3 of the valle paper y_attn_mask = F.pad( y_attention_mask, (x_lens.max(), 0), # y is padded at the front value=False, ) # y attn to all x, for y itself use lower triangle mask to ensure autoregressive xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) # merge key padding and attention masks bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max() xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1) _xy_padding_mask = ( xy_padding_mask.view(bsz, 1, 1, src_len) .expand(-1, self.args.nhead, -1, -1) .reshape(bsz * self.args.nhead, 1, src_len) ) # Check shapes and resize+broadcast as necessary if xy_attn_mask.shape != _xy_padding_mask.shape: assert xy_attn_mask.ndim + 1 == _xy_padding_mask.ndim, f"xy_attn_mask.shape: {xy_attn_mask.shape}, _xy_padding_mask: {_xy_padding_mask.shape}" xy_attn_mask = xy_attn_mask.unsqueeze(0).repeat(_xy_padding_mask.shape[0], 1, 1) # Example approach xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) new_attn_mask = torch.zeros_like(xy_attn_mask) new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) xy_attn_mask = new_attn_mask xy_input = torch.cat([x_input, y_input], dim=1) if past == None: # do not use kvcache out, _ = self.decoder((xy_input, None), mask=xy_attn_mask) return out[:, x_lens.max():], None else: # use kvcache if past.ndim > 3: # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet if last_3_tokens: xy_input = xy_input[:, -3:] xy_attn_mask = xy_attn_mask[:, -3:] else: xy_input = xy_input[:, -1:] xy_attn_mask = xy_attn_mask[:, -1:] out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past) if isinstance(out, tuple): # get rid of stage_embedding out = out[0] if out.shape[1] > x_lens.max(): # the first pass, not kvcache yet return out[:, x_lens.max():], present else: # used kvcache return out, present def forward(self, batch): """ Args: x: A 2-D tensor of shape (N, S). x_lens: A 1-D tensor of shape (N,). It contains the number of tokens in `x` before padding. y: A 3-D tensor of shape (N, K, T). where K is the number of codebooks y_lens: A 1-D tensor of shape (N,). It contains the number of tokens in `x` before padding. """ x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"] if len(x) == 0: return None x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x y = y[:, :, :y_lens.max()] assert x.ndim == 2, x.shape assert x_lens.ndim == 1, x_lens.shape assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape assert y_lens.ndim == 1, y_lens.shape targets = y.clone() y = y.permute(1,2,0) # [B,K,T]->[K,T,B] # makes attention mask and padding mask for x x_padding_mask = make_pad_mask(x_lens).to(x.device) x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x_padding_mask.device) x_input = self.text_embedding(x) x_input = self.text_positional_embedding(x_input) y_input, y_padding_mask, y_attention_mask = self.prepare_input_target(y, y_lens) y_out = self.dec_forward( x_input, x_lens, x_attention_mask, x_padding_mask, y_input, y_lens, y_attention_mask, y_padding_mask ) y_out = y_out[0] # no kv-caching during training assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D] logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card] assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape targets = targets.permute(1,0,2) # [K B T] logits = logits.permute(1,0,2,3) # [K B S card] logits = logits[:, :, :-1] targets = targets[:, :, 1:] if self.args.predict_mask_token: masks = (targets != self.args.audio_pad_token) & (targets != self.args.empty_token) else: masks = (targets != self.args.audio_pad_token) & (targets != self.args.empty_token) & (targets < self.args.mts) tmp_masks = masks.clone() if not self.args.predict_all: eos_pos = (targets == self.args.mts).nonzero(as_tuple=False) for k, b, t in eos_pos: tmp_masks[k, b, :t] = False assert masks.shape[0] == self.args.n_codebooks, masks.shape loss = [] ntokens = [] top10acc = [] for k, (logit, target, mask, tmp_mask) in enumerate(zip(logits, targets, masks, tmp_masks)): logit = logit.reshape(-1, logit.size(-1)) # B*S card target = target.reshape(-1) # B*T mask = mask.reshape(-1).bool() tmp_mask = tmp_mask.reshape(-1).bool() loss.append(F.cross_entropy(logit[tmp_mask], target[tmp_mask], reduction='mean')) top10acc.append(self.accuracy_metrics[k](logit[tmp_mask].detach(), target[tmp_mask])) ntokens.append(len(target[mask])) all_ntokens = sum(ntokens) if self.args.codebook_weight != None: codebook_weight = eval(self.args.codebook_weight) else: codebook_weight = [1.] * self.args.n_codebooks loss = sum([l*nt*cw for l, nt, cw in zip(loss, ntokens, codebook_weight)]) top10acc_by_codebook = [t10a*nt for t10a, nt in zip(top10acc, ntokens)] top10acc = sum(top10acc_by_codebook) ntokens = torch.tensor(all_ntokens).to(logits.device) return { "loss": loss, "top10acc": top10acc, "top10acc_by_codebook": top10acc_by_codebook, "effective_ntoken": ntokens, } def rearrange(self, y, non_mask_intervals, mask_intervals): assert self.args.eos > 0, f"eos={self.args.eos} should > 0" rearranged_y = [] sos_tensor = torch.LongTensor([self.args.sos] * self.args.n_codebooks).unsqueeze(-1).to(y.device) eos_tensor = torch.LongTensor([self.args.eos] * self.args.n_codebooks).unsqueeze(-1).to(y.device) eog_tensor = torch.LongTensor([self.args.eog] * self.args.n_codebooks).unsqueeze(-1).to(y.device) for i, item in enumerate(non_mask_intervals): if i == 0: if item[0] == item[1]: # case: (0,0) rearranged_y.append(sos_tensor) else: rearranged_y.append(torch.cat([sos_tensor, y[:, item[0]: item[1]]], dim=-1)) elif i == len(non_mask_intervals)-1: if item[0] == item[1]: # case: (N,N) rearranged_y.append(eos_tensor) else: rearranged_y.append(torch.cat([y[:, item[0]: item[1]], eos_tensor], dim=-1)) else: rearranged_y.append(y[:, item[0]: item[1]]) for i, item in enumerate(mask_intervals): rearranged_y.append(torch.cat([y[:, item[0]: item[1]], eog_tensor], dim=-1)) return rearranged_y def get_pattern_sequence(self, tokens: torch.Tensor, n_q: int, special_token: int, delays: tp.Optional[tp.List[int]] = None, empty_initial: int = 0) -> torch.Tensor: """Generate a pattern sequence for delayed codebooks without batch dimension. Args: tokens (torch.Tensor): Input tensor of shape [K, T]. n_q (int): Number of codebooks. delays (Optional[List[int]]): Delay for each codebook. Defaults to increasing delays. empty_initial (int): Number of initial empty steps. Defaults to 0. special_token (int): Special token used to fill non-pattern coordinates in the new sequence. Returns: torch.Tensor: Modified tokens based on the pattern. """ K, T = tokens.shape assert K == n_q, "Number of codebooks (K) must match n_q" if delays is None: delays = list(range(n_q)) max_delay = max(delays) pattern_length = T + max_delay + empty_initial pattern_tokens = torch.full((K, pattern_length), fill_value=special_token, dtype=tokens.dtype).to(tokens.device) for t in range(T): for q in range(n_q): delayed_t = t + delays[q] + empty_initial if delayed_t < pattern_length: pattern_tokens[q, delayed_t] = tokens[q, t] return pattern_tokens def revert_pattern_sequence(self, pattern_tokens: torch.Tensor, n_q: int, delays: tp.Optional[tp.List[int]] = None, special_token: int = -1) -> torch.Tensor: """Revert the pattern sequence back to the original multi-codebook sequence without batch dimension. Args: pattern_tokens (torch.Tensor): Pattern tensor of shape [K, S]. n_q (int): Number of codebooks. delays (Optional[List[int]]): Delay for each codebook. Defaults to increasing delays. special_token (int): Special token used to fill non-pattern coordinates in the new sequence. Returns: torch.Tensor: Reverted tokens of shape [K, T]. """ K, S = pattern_tokens.shape assert K == n_q, "Number of codebooks (K) must match n_q" if delays is None: delays = list(range(n_q)) T = S - max(delays) reverted_tokens = torch.full((K, T), fill_value=special_token, dtype=pattern_tokens.dtype).to(pattern_tokens.device) for t in range(T): for q in range(n_q): delayed_t = t + delays[q] if delayed_t < S: reverted_tokens[q, t] = pattern_tokens[q, delayed_t] return reverted_tokens def shift(self, rearranged_y): shifted_y = [self.get_pattern_sequence(tokens=cur_y, n_q=self.args.n_codebooks, special_token=self.args.empty_token) for cur_y in rearranged_y] # the first item is values, later two are indexes and mask return shifted_y def insert_mask(self, shifted_y): num_masks = (len(shifted_y) - 1) // 2 assert num_masks == (len(shifted_y) - 1) / 2, len(shifted_y) emb_inds = list(range(self.args.mts, self.args.mts+ self.args.max_n_spans)) if self.args.shuffle_mask_embedding: random.shuffle(emb_inds) emb_inds_use = emb_inds[:num_masks] mask_value = emb_inds_use + emb_inds_use assert len(shifted_y) == len(mask_value) + 1, len(mask_value) inserted_y = [] mask_position = [-1] * (self.args.max_n_spans*2) for j in range(len(shifted_y)-1): inserted_y.append(shifted_y[j]) mask_position[j] = sum([item.shape[1] for item in inserted_y]) # each item is of shape [K S], so take shape[1] tmp = torch.LongTensor([mask_value[j]] * self.args.n_codebooks).unsqueeze(-1).to(shifted_y[0].device) inserted_y.append(tmp) inserted_y.append(shifted_y[-1]) mask_position = [item for item in mask_position if item != -1] return inserted_y, mask_position def cat_y(self, inserted_y): cated_y = torch.cat(inserted_y, dim=1) assert cated_y.shape[0] == self.args.n_codebooks, cated_y.shape new_y_lens = cated_y.shape[1] return cated_y, new_y_lens def inference( self, x: torch.Tensor, x_lens: torch.Tensor, prompt_x: torch.Tensor, prompt_x_lens: torch.Tensor, y: torch.Tensor, prompt: torch.Tensor, mask_interval: list[torch.Tensor], top_k: int=-100, top_p: float=1.0, temperature: float=1.0, stop_repetition: int=-1, kvcache: int=1, silence_tokens: list[int]=[1388,1898,131], cfg_coef: float=1.5, aug_text: bool=False, aug_context: bool=False, cfg_pretrained: bool=False, ) -> torch.Tensor: """ Args: x: A 2-D tensor of shape (1, L). x_lens: A 1-D tensor of shape (1,). It contains the number of tokens in `x` before padding. y: A 3-D tensor of shape (1, T, K). mask_interval: a list of tensors of shape (M, 2). contains M mask_start and mask_end. list length is actually 1, because we only support single sample inference for now top_k: (`optional`) int The number of highest probability tokens to keep for top-k-filtering. Default to -100. top_p: (`optional`) float For Neucleus sampling temperature: (`optional`) float The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. stop_repetition (`optional`) int if not -1, will set the logits of a token that repeated this many times to be -100000, to avoid generating it again. This only apply to tokens from the first codebook kvcache (`optional`) int if 1, use kvcache to speed up sampling cfg_coef: float (>= 1.0) aug_text: whether use cfg to improve the text input aug_context: whether improve the context by combining original audio and text cfg_pretrained: whether use cfg in training """ assert cfg_coef >= 1.0, cfg_coef assert x.ndim == 2, x.shape assert x_lens.ndim == 1, x_lens.shape assert y.ndim == 3, y.shape y = y.transpose(2,1) # [1,T,K] -> [1,K,T] assert prompt.ndim == 3, prompt.shape prompt = prompt.transpose(2,1) assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding assert prompt.shape[0] == 1 and prompt.shape[1] == self.args.n_codebooks, prompt.shape # there is no padding assert mask_interval.shape == torch.Size((1, mask_interval.shape[1], 2)), mask_interval # whether to use context context_len = sum([item[1] - item[0] for item in mask_interval[0]]) if aug_context and context_len < 2 * 50: aug_context = True else: aug_context = False # augment if aug_text and not aug_context: # [t, ab, m] [t', ab, m] y = y.repeat(2, 1, 1) if not cfg_pretrained: uncond_x = torch.randint(0, self.n_text_tokens, (1, x.shape[1])).to(x.device) else: uncond_x = torch.tensor([self.args.text_vocab_size-1], dtype=torch.long).unsqueeze(0).repeat(1, x.shape[1]).to(x.device) x = torch.cat([x, uncond_x], dim=0) if aug_text and aug_context: # [tc, t, c, ab, m] [tc, t', c, ab, m] out_len = prompt.shape[2] gt_y = torch.cat([prompt, y], dim=-1) y = gt_y.repeat(2, 1, 1) gt_x = torch.cat([prompt_x, x], dim=1) if not cfg_pretrained: uncond_x = torch.randint(0, self.n_text_tokens, (1, gt_x.shape[1])).to(gt_x.device) else: uncond_x = torch.tensor([self.args.text_vocab_size-1], dtype=torch.long).unsqueeze(0).repeat(1, gt_x.shape[1]).to(gt_x.device) x = torch.cat([gt_x, uncond_x], dim=0) if not aug_text and aug_context: # [tc, t, c, ab, m] out_len = prompt.shape[2] y = torch.cat([prompt, y], dim=-1) x = torch.cat([prompt_x, x], dim=1) # make x attention mask and x_input x_lens = torch.LongTensor([x.shape[-1]]).to(x_lens.device) x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device) x_input = self.text_embedding(x) x_input = self.text_positional_embedding(x_input) # make initial y_input # make mask_interval and non_mask_interval y_len = y.shape[2] y_lens = torch.LongTensor([y_len]).to(y.device) mask_interval = mask_interval[0] if aug_context: mask_interval = [[item[0]+out_len, item[1]+out_len] for item in mask_interval] starts = [item[0].item() for item in mask_interval] + [y_len] ends = [0] + [item[1].item() for item in mask_interval] mask_intervals = [ (item[0].item(), item[1].item()) for item in mask_interval ] # a werid name change, mask_interval is input, now is mask_intervals, with one more dimension non_mask_intervals = [ (ns, ne) for ns, ne in zip(ends, starts) ] # prepare input sequences rearranged_y = self.rearrange(y[0], non_mask_intervals, mask_intervals) shifted_y = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y inserted_y, mask_position = self.insert_mask(shifted_y) cated_y, new_y_lens = self.cat_y(inserted_y) # KT num_task = len(mask_position)//2 cated_y = cated_y[:, :mask_position[num_task]] # of shape [K,T] input of the network new_y_lens = torch.LongTensor([mask_position[num_task]]).to(cated_y.device) cated_y = cated_y.unsqueeze(0).permute(1,2,0) # B,K,T -> K,T,B if aug_text: cated_y = cated_y.repeat(1, 1, 2) embedded_y = self.embed_y(cated_y) #BTD if aug_text: x_padding_mask = torch.full((2, x_lens[0]), False).to(x.device) if cfg_pretrained: x_padding_mask[1:, 1:] = True past = torch.ones([self.args.num_decoder_layers, 2, 2], device=x.device, dtype=torch.float32) if kvcache else None else: x_padding_mask = torch.full((1, x_lens[0]), False).to(x.device) past = torch.ones([self.args.num_decoder_layers, 2, 1], device=x.device, dtype=torch.float32) if kvcache else None emb_inds = list(range(self.args.mts, self.args.mts+ self.args.max_n_spans)) generated = [] logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default") for idx in range(num_task): cur_generated = [] prev_token = None consec_silence_count = 0 num_gen = 0 num_eog = 0 # add mask token mts = torch.LongTensor([emb_inds[idx]] * self.args.n_codebooks).unsqueeze(-1).to(embedded_y.device) # K, 1 mts_emb = torch.stack([self.audio_embedding[k](mts[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D] mts_emb = mts_emb.sum(dim=0,keepdim=True) # [1,1,D] if aug_text: mts_emb = mts_emb.repeat(2,1,1) embedded_y = torch.cat([embedded_y, mts_emb], dim=1) # positional embedding y_input = self.audio_positional_embedding(embedded_y) # [B T D] # make attention mask and padding mask y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device) if aug_text: y_padding_mask = torch.full((2,new_y_lens[0]), False).to(y.device) else: y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device) while True: # get model output y_out, present = self.dec_forward( x_input, x_lens, x_attention_mask, x_padding_mask, y_input, new_y_lens, y_attention_mask, y_padding_mask, past=past, last_3_tokens=False ) if past != None: past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype) y_out = y_out[:, -1:] # only take the last one logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card] logits = logits.squeeze() # [K card] if aug_text: logits = cfg_coef * logits[0] + (1 - cfg_coef) * logits[1] assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}" # filter out mts, sos and eos for jj in range(self.args.n_codebooks): logits[jj][self.args.eos] = -10000. logits[jj][self.args.sos] = -10000. for mts in range(self.args.mts, self.args.mts+ self.args.max_n_spans): logits[jj][mts] = -10000. # add first empty tokens if num_gen < self.args.n_codebooks - 1: for jj in range(num_gen + 1, self.args.n_codebooks): logits[jj][self.args.empty_token] = 10000. # deal with eog token if num_eog > 0: # codebook 1 has produced eog token for jj in range(num_eog+1,self.args.n_codebooks): logits[jj][self.args.eog] = -10000 logits[jj][self.args.empty_token] = -10000 samples = topk_sampling( logits, top_k=top_k, top_p=top_p, temperature=temperature ) # [K, 1] for jj in range(num_eog): samples[jj, 0] = self.args.empty_token samples[num_eog, 0] = self.args.eog num_eog += 1 else: # codebook 1 did not produce eog token # filter out eog for codebook 2-4 for jj in range(1,self.args.n_codebooks): logits[jj][self.args.eog] = -10000 # silence repetition handling if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition: if logits[0, prev_token] < 0: logits[0, prev_token] = logits[0, prev_token] * (consec_silence_count - (stop_repetition-1)) else: logits[0, prev_token] = logits[0, prev_token] / (consec_silence_count - (stop_repetition-1)) samples = topk_sampling( logits, top_k=top_k, top_p=top_p, temperature=temperature ) # [K, 1] assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}" if ( samples[0,0] == self.args.eog or torch.argmax(logits[0], dim=-1) == self.args.eog or y_input.shape[1] > x_lens[0] * 10 ): # last one means y is already too long, shouldn't happen, but put it here samples[0,0] = self.args.eog num_eog += 1 if samples[0,0] in silence_tokens and samples[0,0] == prev_token: consec_silence_count += 1 else: consec_silence_count = 0 prev_token = samples[0,0] num_gen += 1 cur_generated.append(samples.squeeze(-1)) if num_eog == self.args.n_codebooks: # current span is done break # prepare input for next token prediction samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D] samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D] if aug_text: samples_emb = samples_emb.repeat(2, 1, 1) embedded_y = torch.cat([embedded_y, samples_emb], dim=1) # positional embedding y_input = self.audio_positional_embedding(embedded_y) # [B T D] # make attention mask and padding mask y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device) if aug_text: y_padding_mask = torch.full((2,new_y_lens[0]), False).to(y.device) else: y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device) generated.append(cur_generated) assert len(generated) == num_task, f"len(generated): {len(generated)}, num_task: {num_task}" # # combine non_masked_span with generated spans # first need to shift the generated part back flatten_gen = [] for i, orig_span in enumerate(generated): span = torch.stack(orig_span, dim=0) # [T K] span = span.transpose(1,0) # [K, T] assert span.shape[0] == self.args.n_codebooks, span.shape unshifted_span = self.revert_pattern_sequence(pattern_tokens=span, n_q=self.args.n_codebooks, special_token=self.args.empty_token) assert unshifted_span.shape[1] == span.shape[1]-self.args.n_codebooks+1, f"unshifted_span:{unshifted_span.shape}, orig_span:{span.shape}" unshifted_span = unshifted_span[:,:-1] # remove eog token flatten_gen.append(unshifted_span) res = [] marks = [] masks = [] tmp = 0 for orig_interval, gen in zip(non_mask_intervals, flatten_gen): res.append(y[0, :, orig_interval[0]:orig_interval[1]]) masks.append((tmp, tmp+orig_interval[1]-orig_interval[0])) tmp_mark = [0] * (orig_interval[1] - orig_interval[0]) marks = [*marks, *tmp_mark] res.append(gen) tmp += orig_interval[1]-orig_interval[0] + gen.shape[-1] tmp_mark = [1] * gen.shape[-1] marks = [*marks, *tmp_mark] if y.shape[-1] != non_mask_intervals[-1][1] + 1: # edit last tokens or tts res.append(y[0, :, non_mask_intervals[-1][0]:non_mask_intervals[-1][1]]) masks.append((tmp, tmp+non_mask_intervals[-1][1]-non_mask_intervals[-1][0])) tmp_mark = [0] * (non_mask_intervals[-1][1] - non_mask_intervals[-1][0]) marks = [*marks, *tmp_mark] res = torch.cat(res, dim=1).unsqueeze(0) # [K,new_T] -> [1, K, new_T] marks = torch.LongTensor(marks).unsqueeze(0) if aug_context: res = res[:, :, out_len:] marks = marks[:, out_len:] masks = [(item[0]-out_len, item[1]-out_len) for item in masks] non_mask_intervals = [(item[0]-out_len, item[1]-out_len) for item in non_mask_intervals] return res, marks, masks, non_mask_intervals if __name__ == "__main__": # debug pass