# SSR-Speech / data / gigaspeech.py
# (provenance: OpenSound Hugging Face repository, revision f5b4ff2)
# @ [email protected]
import os
import torch
import random
import copy
import logging
import shutil
import typing as tp
class dataset(torch.utils.data.Dataset):
    """GigaSpeech dataset for SSR-Speech.

    Each item pairs a phoneme-id sequence ``x`` with a multi-codebook
    Encodec token matrix ``y`` of shape [n_codebooks, T]. ``__getitem__``
    samples random mask spans over ``y``, rearranges it into
    [non-masked segments ..., masked segments ...] order, applies the
    per-codebook delay pattern, and splices mask-embedding tokens in
    between the segments.
    """

    def __init__(self, args, split):
        super().__init__()
        self.args = args
        self.split = split
        assert self.split in ['train', 'validation', 'test']
        # Manifest: tab-separated records, last field is the length in
        # encodec frames — TODO confirm field layout against preprocessing.
        manifest_fn = os.path.join(self.args.dataset_dir, self.args.manifest_name, self.split+".txt")
        with open(manifest_fn, "r") as rf:
            data = [l.strip().split("\t") for l in rf.readlines()]
        lengths_list = [int(item[-1]) for item in data]
        # Keep utterances no shorter than audio_min_length; optionally drop
        # ones longer than audio_max_length (lengths are in encodec frames).
        self.data = []
        self.lengths_list = []
        for d, l in zip(data, lengths_list):
            if l >= self.args.encodec_sr*self.args.audio_min_length:
                if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length:
                    continue
                self.data.append(d)
                self.lengths_list.append(l)
        logging.info(f"number of data points for {self.split} split: {len(self.lengths_list)}")

        # Phoneme vocabulary ("<id> <phoneme>" per line); a copy is placed in
        # exp_dir so the experiment folder is self-contained.
        vocab_fn = os.path.join(self.args.dataset_dir,"vocab.txt")
        shutil.copy(vocab_fn, os.path.join(self.args.exp_dir, "vocab.txt"))
        with open(vocab_fn, "r") as f:
            temp = [l.strip().split(" ") for l in f.readlines() if len(l) != 0]
            self.phn2num = {item[1]:int(item[0]) for item in temp}
        # Non-speech markers that are absent from the training annotation.
        self.symbol_set = {"<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"}

    def __len__(self):
        return len(self.lengths_list)

    def _load_phn_enc(self, index):
        """Load (phoneme ids, encodec codes) for one item.

        Returns ``([], [[]])`` on any failure so the caller can skip the
        item instead of crashing the data loader.
        """
        item = self.data[index]
        pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt")
        ef = os.path.join(self.args.dataset_dir, self.args.encodec_folder_name, item[1]+".txt")
        try:
            with open(pf, "r") as p, open(ef, "r") as e:
                phns = [l.strip() for l in p.readlines()]
                assert len(phns) == 1, phns
                # drop ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"], as they are not in training set annotation
                x = [self.phn2num[item] for item in phns[0].split(" ") if item not in self.symbol_set]
                # One line of codes per codebook; ignore any extra lines.
                encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]
                assert len(encos) == self.args.n_codebooks, ef
                if self.args.special_first:
                    # Shift codes up to make room for special tokens 0..n_special-1.
                    y = [[int(n)+self.args.n_special for n in l] for l in encos]
                else:
                    y = [[int(n) for n in l] for l in encos]
        except Exception as err:  # was `e`, which shadowed the encodec file handle
            logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
            logging.info(f"error message: {err}")
            return [], [[]]
        return x, y

    def prepare_mask_intervals(self, y_len):
        """Randomly sample mask spans over a sequence of length ``y_len``.

        Returns ``(mask_intervals, non_mask_intervals)``, e.g.
        mask intervals ``[(5, 9), (19, 29)]`` and non-mask intervals
        ``[(0, 5), (9, 19), (29, 30)]``.

        Raises:
            ValueError: if ``args.mask_sample_dist`` is not "uniform" or
                "poisson<lambda>".
        """
        if self.args.mask_sample_dist == "uniform":
            n_spans = random.choice(range(1, self.args.max_n_spans + 1))
        elif "poisson" in self.args.mask_sample_dist.lower():
            param = float(self.args.mask_sample_dist[len("poisson"):])
            poisson_sample = torch.poisson(torch.tensor([param]))
            n_spans = int(poisson_sample.clamp(1, self.args.max_n_spans).item())
        else:
            # Previously an unknown value fell through to a NameError on n_spans.
            raise ValueError(f"unknown mask_sample_dist: {self.args.mask_sample_dist}")
        # NOTE(review): assumes y_len - mask_len_min >= n_spans; shorter
        # utterances are filtered out in __init__/__getitem__ — confirm.
        starts = random.sample(range(0, y_len - self.args.mask_len_min), n_spans)
        starts = sorted(starts)
        # Enforce a minimum gap between consecutive span starts.
        for j in range(len(starts) - 1, 0, -1):
            if starts[j] - starts[j - 1] < self.args.min_gap:
                del starts[j]
        assert len(starts) > 0, f"there is no masked span left, y_len: {y_len}, sampled n_spans: {n_spans}"
        # Budget the total masked portion evenly across the surviving spans.
        tmp_mask_len_max = int(self.args.max_mask_portion * y_len / len(starts))
        ends = []
        for j, start in enumerate(starts):
            if j < len(starts) - 1:
                # Span must end min_gap before the next span's start.
                mask_len = random.randint(self.args.mask_len_min, min(tmp_mask_len_max, starts[j+1]-starts[j]-self.args.min_gap+1))
            else:
                mask_len = random.randint(self.args.mask_len_min, min(tmp_mask_len_max, y_len-starts[j]))
            ends.append(start + mask_len)
        if self.args.tts_enhanced > 0 and random.random() < 0.5:
            # TTS-enhanced training: half the time force the last span to
            # run to the end of the utterance (continuation-style mask).
            starts[-1] = max(starts[-1], y_len - tmp_mask_len_max)
            ends[-1] = y_len
        mask_intervals = [(s, e) for s, e in zip(starts, ends)]
        non_mask_intervals = [(ns, ne) for ns, ne in zip([0] + ends, starts + [y_len])]
        return mask_intervals, non_mask_intervals

    def rearrange(self, y, non_mask_intervals, mask_intervals):
        """Reorder ``y`` into [sos + non-masked segments + eos, masked segments + eog].

        ``y`` is [n_codebooks, T]; each returned element is [n_codebooks, S_i].
        """
        assert self.args.eos > 0, f"eos={self.args.eos} should > 0"
        rearranged_y = []
        # Column vectors [n_codebooks, 1] repeating the special token per codebook.
        sos_tensor = torch.tensor([self.args.sos] * self.args.n_codebooks).unsqueeze(-1)
        eos_tensor = torch.tensor([self.args.eos] * self.args.n_codebooks).unsqueeze(-1)
        eog_tensor = torch.tensor([self.args.eog] * self.args.n_codebooks).unsqueeze(-1)
        for i, item in enumerate(non_mask_intervals):
            if i == 0:
                if item[0] == item[1]: # case: (0,0) — the utterance starts with a mask
                    rearranged_y.append(sos_tensor)
                else:
                    rearranged_y.append(torch.cat([sos_tensor, y[:, item[0]: item[1]]], dim=-1))
            elif i == len(non_mask_intervals)-1:
                if item[0] == item[1]: # case: (N,N) — the utterance ends with a mask
                    rearranged_y.append(eos_tensor)
                else:
                    rearranged_y.append(torch.cat([y[:, item[0]: item[1]], eos_tensor], dim=-1))
            else:
                rearranged_y.append(y[:, item[0]: item[1]])
        # Masked segments go last, each terminated with an end-of-generation token.
        for i, item in enumerate(mask_intervals):
            rearranged_y.append(torch.cat([y[:, item[0]: item[1]], eog_tensor], dim=-1))
        return rearranged_y

    def get_pattern_sequence(self, tokens: torch.Tensor, n_q: int, special_token: int, delays: tp.Optional[tp.List[int]] = None,
                             empty_initial: int = 0) -> torch.Tensor:
        """Generate a pattern sequence for delayed codebooks without batch dimension.
        Args:
            tokens (torch.Tensor): Input tensor of shape [K, T].
            n_q (int): Number of codebooks.
            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
            delays (Optional[List[int]]): Delay for each codebook. Defaults to increasing delays.
            empty_initial (int): Number of initial empty steps. Defaults to 0.
        Returns:
            torch.Tensor: Modified tokens based on the pattern, shape [K, T + max(delays) + empty_initial].
        """
        K, T = tokens.shape
        assert K == n_q, "Number of codebooks (K) must match n_q"
        if delays is None:
            delays = list(range(n_q))  # default MusicGen-style delay: codebook q shifted by q steps
        max_delay = max(delays)
        pattern_length = T + max_delay + empty_initial
        pattern_tokens = torch.full((K, pattern_length), fill_value=special_token, dtype=tokens.dtype, device=tokens.device)
        for t in range(T):
            for q in range(n_q):
                delayed_t = t + delays[q] + empty_initial
                if delayed_t < pattern_length:
                    pattern_tokens[q, delayed_t] = tokens[q, t]
        return pattern_tokens

    def revert_pattern_sequence(self, pattern_tokens: torch.Tensor, n_q: int,
                                delays: tp.Optional[tp.List[int]] = None, special_token: int = -1) -> torch.Tensor:
        """Revert the pattern sequence back to the original multi-codebook sequence without batch dimension.
        Args:
            pattern_tokens (torch.Tensor): Pattern tensor of shape [K, S].
            n_q (int): Number of codebooks.
            delays (Optional[List[int]]): Delay for each codebook. Defaults to increasing delays.
            special_token (int): Special token used to fill positions past the end of the pattern.
        Returns:
            torch.Tensor: Reverted tokens of shape [K, T] with T = S - max(delays).
        """
        K, S = pattern_tokens.shape
        assert K == n_q, "Number of codebooks (K) must match n_q"
        if delays is None:
            delays = list(range(n_q))
        T = S - max(delays)
        reverted_tokens = torch.full((K, T), fill_value=special_token, dtype=pattern_tokens.dtype, device=pattern_tokens.device)
        for t in range(T):
            for q in range(n_q):
                delayed_t = t + delays[q]
                if delayed_t < S:
                    reverted_tokens[q, t] = pattern_tokens[q, delayed_t]
        return reverted_tokens

    def shift(self, rearranged_y):
        """Apply the codebook delay pattern to every segment, padding with empty_token."""
        shifted_y = [self.get_pattern_sequence(tokens=cur_y, n_q=self.args.n_codebooks, special_token=self.args.empty_token) for cur_y in rearranged_y]
        return shifted_y

    def insert_mask(self, shifted_y):
        """Interleave mask-embedding tokens between segments.

        ``shifted_y`` holds 1 + 2*num_masks segments (non-masked part plus the
        masked spans); the same mask id is inserted before a span's context
        position and before its content so the model can pair them.
        Returns (inserted_y, mask_position) where mask_position gives the
        column index of each inserted mask token in the concatenated sequence.
        """
        num_masks = (len(shifted_y) - 1) // 2
        assert num_masks == (len(shifted_y) - 1) / 2, len(shifted_y)
        emb_inds = list(range(self.args.mts, self.args.mts+ self.args.max_n_spans))
        if self.args.shuffle_mask_embedding:
            random.shuffle(emb_inds)
        emb_inds_use = emb_inds[:num_masks]
        # Each mask id appears twice: once in the context, once before its span.
        mask_value = emb_inds_use + emb_inds_use
        assert len(shifted_y) == len(mask_value) + 1, len(mask_value)
        inserted_y = []
        mask_position = [-1] * (self.args.max_n_spans*2)
        for j in range(len(shifted_y)-1):
            inserted_y.append(shifted_y[j])
            mask_position[j] = sum([item.shape[1] for item in inserted_y]) # each item is of shape [K S], so take shape[1]
            tmp = torch.tensor([mask_value[j]] * self.args.n_codebooks).unsqueeze(-1)
            inserted_y.append(tmp)
        inserted_y.append(shifted_y[-1])
        mask_position = [item for item in mask_position if item != -1]
        return inserted_y, mask_position

    def cat_y(self, inserted_y):
        """Concatenate all segments along time; returns (y [K, T], T)."""
        cated_y = torch.cat(inserted_y, dim=1)
        assert cated_y.shape[0] == self.args.n_codebooks, cated_y.shape
        new_y_lens = cated_y.shape[1]
        return cated_y, new_y_lens

    def __getitem__(self, index):
        """Build one training example; returns all-None fields on any failure."""
        x, y = self._load_phn_enc(index)
        x_len, y_len = len(x), len(y[0])
        if x_len == 0 or y_len == 0: # load failure
            item = self.data[index]
            pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt")
            logging.info(f"loading failed for {pf}, length is 0")
            return {
                "x": None,
                "x_len": None,
                "y": None,
                "y_len": None,
            }
        if y_len < self.args.encodec_sr * self.args.audio_min_length or x_len < self.args.text_min_length: # too short
            item = self.data[index]
            pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt")
            logging.info(f"loading failed for {pf}, too short")
            return {
                "x": None,
                "x_len": None,
                "y": None,
                "y_len": None,
            }
        if self.args.drop_long:
            if x_len > self.args.text_max_length or y_len > self.args.encodec_sr * self.args.audio_max_length: # too long
                item = self.data[index]
                pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt")
                logging.info(f"loading failed for {pf}, too long")
                return {
                    "x": None,
                    "x_len": None,
                    "y": None,
                    "y_len": None,
                }
        if self.args.cfg_enhanced and random.random() < 0.1: # We use the last unused token for cfg training
            x = torch.tensor([self.args.text_vocab_size-1], dtype=torch.long)
            x_len = len(x)
        # Sample spans, reorder, delay-shift, then splice in mask tokens.
        mask_intervals, non_mask_intervals = self.prepare_mask_intervals(y_len)
        rearranged_y = self.rearrange(torch.LongTensor(y), non_mask_intervals, mask_intervals)
        shifted_y = self.shift(rearranged_y)
        inserted_y, mask_position = self.insert_mask(shifted_y)
        y, y_len = self.cat_y(inserted_y)
        x = torch.LongTensor(x)
        y = torch.LongTensor(y)
        # Sanity check: every code must fit the model's audio embedding table.
        if not (y < int(self.args.audio_vocab_size) + self.args.n_special + self.args.max_n_spans).all():
            item = self.data[index]
            pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt")
            logging.info(f"loading failed for {pf}, index out of range")
            return {
                "x": None,
                "x_len": None,
                "y": None,
                "y_len": None,
            }
        return {
            "x": x,
            "x_len": x_len,
            "y": y,
            "y_len": y_len
        }

    def collate(self, batch):
        """Collate a list of __getitem__ dicts into padded batch tensors.

        Items whose fields are None (load failures) are silently dropped.
        """
        out = {key:[] for key in batch[0]}
        for item in batch:
            # Identity test: `== None` on a torch.Tensor relies on tensor
            # comparison semantics; `is None` is the correct failure check.
            if item['x'] is None: # deal with load failure
                continue
            for key, val in item.items():
                out[key].append(val)
        res = {}
        if self.args.pad_x:
            res["x"] = torch.stack(out["x"], dim=0)
        else:
            res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
        res["x_lens"] = torch.LongTensor(out["x_len"])
        if self.args.dynamic_batching:
            # NOTE(review): raises IndexError if every item in the batch
            # failed to load — assumed rare enough upstream; confirm.
            if out['y'][0].ndim==2:
                res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
                res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
            else:
                assert out['y'][0].ndim==1, out['y'][0].shape
                res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token)
        else:
            res['y'] = torch.stack(out['y'], dim=0)
        res["y_lens"] = torch.LongTensor(out["y_len"])
        return res
if __name__ == "__main__":
    # debug
    # Placeholder entry point: nothing runs when this module is executed
    # directly; the dataset is consumed by the training scripts instead.
    pass