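"""Streaming data processors for speech-token / TTS training.

This module provides generator-style processors that are chained into a data
pipeline: shard openers (parquet / jsonl + tar), SFT sample loaders, length
filters, resampling, fbank extraction, tokenization, local shuffle/sort,
batching, and padding into training tensors. A minimal usage sketch is given
at the bottom of the file.
"""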
import argparse
import collections
import io
import itertools
import json
import logging
import math
import mmap
import multiprocessing as mp
import os
import pickle
import random
import shutil
import struct
import sys
import tarfile
import wave
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from io import BytesIO
from pathlib import Path

import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from wids import wids

torchaudio.set_audio_backend('soundfile')

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
|
|
try:
    MAIN_SPK_EMBEDDING = torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/mean_embedding.pt")
    GPT_SPK_EMBEDDING = torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/spk_mean_embeddings.pt")
except Exception:
    logging.warning('Speaker embedding checkpoints not found; falling back to zero embeddings.')
    MAIN_SPK_EMBEDDING = torch.zeros(1, 192)
    GPT_SPK_EMBEDDING = torch.zeros(1, 192)
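# NOTE (assumption): the (1, 192) fallback shape matches the embedding width
# used by padding_speech_token() below; 192 is presumed to be the output
# dimension of the upstream speaker encoder.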
|
|
|
def parquet_opener(data, mode='train', tts_data={}):
    """ Open parquet shards and yield one sample per row.
        In-place operation.

        Args:
            data(Iterable[str]): url or local parquet file list

        Returns:
            Iterable[{src, ...row fields}]
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            df = pq.read_table(url).to_pandas()
            for i in range(len(df)):
                if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                    continue
                sample.update(dict(df.loc[i]))
                if mode == 'train':
                    yield {**sample}
                else:
                    # In inference mode, emit one sample per requested tts text.
                    for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                        yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
|
|
TarHeader = collections.namedtuple(
    "TarHeader",
    [
        "name",
        "mode",
        "uid",
        "gid",
        "size",
        "mtime",
        "chksum",
        "typeflag",
        "linkname",
        "magic",
        "version",
        "uname",
        "gname",
        "devmajor",
        "devminor",
        "prefix",
    ],
)


def parse_tar_header(header_bytes):
    """Unpack the first 500 bytes of a 512-byte USTAR header."""
    header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
    return TarHeader(*header)
|
|
|
class MMTar:
    def __init__(self, file_path: Path | str):
        self.stream = open(file_path, "rb")
        self.mmap = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)

    def __del__(self):
        try:
            self.mmap.close()
            self.stream.close()
        except Exception:
            pass

    def get_at_offset(self, offset) -> tuple[str, bytes]:
        # Each tar member starts with a 512-byte header; the payload follows
        # immediately after it. The size field is NUL-terminated octal.
        header = parse_tar_header(self.mmap[offset:offset + 500])
        name = header.name.decode("utf-8").strip("\x00")
        start = offset + 512
        end = start + int(header.size.decode("utf-8")[:-1], 8)
        return name, self.mmap[start:end]
|
|
|
class Tar:
    def __init__(self, path: Path):
        self.tar = MMTar(path)
        # The sidecar "<name>.index" pickle holds (name, offset, size) tuples.
        indices_path = path.with_suffix(".index")
        self.index = pickle.loads(indices_path.read_bytes())
        self.name_mapping = {}
        for name, offset, _ in self.index:
            self.name_mapping[name] = offset

    def read(self, name: str) -> bytes:
        return self.tar.get_at_offset(self.name_mapping[name])[1]
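

# A minimal sketch of how the sidecar index consumed by Tar() could be built.
# This helper is illustrative, not part of the original pipeline: it assumes
# the index is a pickled list of (member_name, header_offset, size) tuples,
# which is the layout Tar.__init__ unpacks above.
def build_tar_index(tar_path: Path) -> None:
    entries = []
    with tarfile.open(tar_path) as tf:
        for member in tf:
            # member.offset is the offset of the 512-byte header, which is
            # exactly what MMTar.get_at_offset expects.
            entries.append((member.name, member.offset, member.size))
    tar_path.with_suffix(".index").write_bytes(pickle.dumps(entries))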
|
|
|
def _cosy_jsonl_opener(data, jsonl_suffix, mode='train'):
    """ Shared implementation behind the cosy_jsonl_opener_* variants below.

        For each jsonl shard, reads the cosy speech tokens line by line and
        loads the matching waveform from the companion ".tar" archive.

        Args:
            data(Iterable[{src}]): jsonl path list
            jsonl_suffix(str): shard suffix to swap for ".tar"

        Returns:
            Iterable[{src, speech_token, speech, sample_rate}]
    """
    for sample in data:
        assert 'src' in sample
        cosy_jsonl_path = sample['src']
        tar_file_path = cosy_jsonl_path.replace(jsonl_suffix, ".tar")
        try:
            tar_data = Tar(Path(tar_file_path))
            with open(cosy_jsonl_path, 'r') as f:
                for line in f:
                    item = json.loads(line)
                    sample['speech_token'] = torch.tensor(item['cosy_token'])
                    sample['speech'], sample['sample_rate'] = torchaudio.load(
                        io.BytesIO(tar_data.read(item['filename'])))
                    yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))


def cosy_jsonl_opener(data, mode='train', tts_data={}):
    yield from _cosy_jsonl_opener(data, ".vq0907.jsonl", mode)
|
|
|
def cosy_jsonl_opener_vq0918_nopool(data, mode='train', tts_data={}):
    yield from _cosy_jsonl_opener(data, ".vq0918-nopool.jsonl", mode)
|
|
|
def cosy_jsonl_opener_vq0918_pool2(data, mode='train', tts_data={}):
    yield from _cosy_jsonl_opener(data, ".vq0918-pool2.jsonl", mode)
|
|
|
def cosy_jsonl_opener_vq0918_pool4(data, mode='train', tts_data={}):
    yield from _cosy_jsonl_opener(data, ".vq0918-pool4.jsonl", mode)
|
|
|
def cosy_jsonl_opener_vq0918_pool8(data, mode='train', tts_data={}):
    yield from _cosy_jsonl_opener(data, ".vq0918-pool8.jsonl", mode)
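# The jsonl suffixes above encode the tokenizer variant and temporal pooling;
# judging by the split helpers below, vq0918-pool2 tokens run at 25 Hz and
# vq0918-pool4 tokens at 12.5 Hz.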
|
|
|
def process_sft_vq0918_pool4(data, mode='train', tts_data={}):
    """ Load (token npy, wav) pairs for SFT; downmixes stereo to mono. """
    for sample in data:
        assert 'src' in sample
        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool4.npy", "")
        try:
            sample['speech_token'] = torch.tensor(np.load(token_npy_path))
            sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
            if sample['speech'].shape[0] > 1:
                sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
|
|
|
def process_sft_vq0918_pool4_split(data, mode='train', split_token=25, tts_data={}):
    """ Like process_sft_vq0918_pool4, but additionally yields truncated
        prefixes of each utterance every `split_token` tokens. """
    for sample in data:
        assert 'src' in sample
        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool4.npy", "")
        try:
            speech_token = torch.tensor(np.load(token_npy_path))
            speech, sample_rate = torchaudio.load(wav_path)
            if speech.shape[0] > 1:
                speech = speech.mean(dim=0, keepdim=True)

            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            sample['sample_rate'] = sample_rate

            num_splits = (speech_token.size(0) + split_token - 1) // split_token
            for split_id in range(num_splits):
                end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
                # pool4 tokens run at 12.5 Hz, so end_token_idx / 12.5 is the
                # prefix duration in seconds.
                end_speech_idx = int(np.ceil(end_token_idx / 12.5 * sample_rate))
                sample['speech_token'] = speech_token[:end_token_idx]
                sample['speech'] = speech[:, :end_speech_idx]
                yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
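# Example of the split arithmetic above (illustrative numbers): with
# split_token=25 and a 24 kHz wav, the first prefix keeps 25 tokens
# ~ 25 / 12.5 = 2.0 s of audio, i.e. speech[:, :48000]; the pool2 variant
# below uses 25 Hz tokens instead.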
|
|
|
|
|
def process_sft_vq0918_pool2(data, mode='train', tts_data={}):
    for sample in data:
        assert 'src' in sample
        # src points at the pool4 npy; swap in the pool2 tokens stored next to it.
        token_npy_path = sample['src'].replace(".vq0918-pool4.npy", ".vq0918-pool2.npy")
        wav_path = token_npy_path.replace(".vq0918-pool2.npy", "")
        try:
            sample['speech_token'] = torch.tensor(np.load(token_npy_path))
            sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
            if sample['speech'].shape[0] > 1:
                sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
|
|
|
def process_sft_vq0918_pool2_split(data, mode='train', split_token=50, tts_data={}):
    for sample in data:
        assert 'src' in sample
        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool2.npy", "")
        try:
            speech_token = torch.tensor(np.load(token_npy_path))
            speech, sample_rate = torchaudio.load(wav_path)
            if speech.shape[0] > 1:
                speech = speech.mean(dim=0, keepdim=True)

            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            sample['sample_rate'] = sample_rate

            num_splits = (speech_token.size(0) + split_token - 1) // split_token
            for split_id in range(num_splits):
                end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
                # pool2 tokens run at 25 Hz.
                end_speech_idx = int(np.ceil(end_token_idx / 25 * sample_rate))
                sample['speech_token'] = speech_token[:end_token_idx]
                sample['speech'] = speech[:, :end_speech_idx]
                yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
|
|
|
def process_sft_vq0918_pool4_gpt(data, mode='train', tts_data={}):
    for sample in data:
        assert 'src' in sample
        wav_path = None
        try:
            entry = json.loads(sample['src'])
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            for conv in entry["conversations"]:
                if "response_wav" in conv:
                    wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
                    token_npy_path = wav_path.replace(".wav", ".wav.vq0918-pool4.npy")
                    sample['speech_token'] = torch.tensor(np.load(token_npy_path))
                    sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
                    if sample['speech'].shape[0] > 1:
                        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    # NOTE: the original assigned an undefined `spk_embedding`
                    # here; GPT_SPK_EMBEDDING is assumed to be the intent.
                    sample['spk_embedding'] = GPT_SPK_EMBEDDING
                    yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
|
|
|
def process_sft_vq0918_pool4_gpt_1010(data, mode='train', tts_data={}):
    for sample in data:
        assert 'src' in sample
        wav_path = None
        try:
            entry = json.loads(sample['src'])
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            for conv in entry["conversations"]:
                if "response_wav" in conv:
                    wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
                    token_npy_path = wav_path.replace(".wav", ".wav.vq0918-pool4.npy")
                    sample['speech_token'] = torch.tensor(np.load(token_npy_path))
                    sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
                    if sample['speech'].shape[0] > 1:
                        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    # NOTE: as above, GPT_SPK_EMBEDDING is assumed for the
                    # original's undefined `spk_embedding`.
                    sample['spk_embedding'] = GPT_SPK_EMBEDDING
                    yield {**sample}
                if "prompt_wav" in conv:
                    # NOTE: the original loaded conv['response_wav'] here, which
                    # looks like a copy-paste slip; prompt_wav is assumed.
                    wav_path = f"/workspace/audio_data/sft/{conv['prompt_wav']}"
                    token_npy_path = wav_path.replace(".wav", ".wav.vq0918-pool4.npy")
                    sample['speech_token'] = torch.tensor(np.load(token_npy_path))
                    sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
                    if sample['speech'].shape[0] > 1:
                        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    sample['spk_embedding'] = GPT_SPK_EMBEDDING
                    yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
|
|
|
def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter samples according to feature and label length.
        In-place operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (10 ms frames)
            min_length: drop utterances shorter than min_length (10 ms frames)
            token_max_length: drop utterances whose token sequence is longer
                than token_max_length, especially when using char units for
                English modeling
            token_min_length: drop utterances whose token sequence is shorter
                than token_min_length
            min_output_input_ratio: minimum ratio of
                token_length / feats_length (10 ms frames)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10 ms frames)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        # Duration in 10 ms frames: samples / sample_rate gives seconds, x100.
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample
|
|
|
def filter_speech_token(data,
                        max_length=10240,
                        min_length=10,
                        token_max_length=5000,
                        token_min_length=1,
                        min_output_input_ratio=0.0005,
                        max_output_input_ratio=30,
                        mode='train'):
    """ Filter samples according to feature and speech-token length.
        In-place operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (10 ms frames)
            min_length: drop utterances shorter than min_length (10 ms frames)
            token_max_length: drop utterances whose speech-token sequence is
                longer than token_max_length
            token_min_length: drop utterances whose speech-token sequence is
                shorter than token_min_length
            min_output_input_ratio: minimum ratio of
                token_length / feats_length (10 ms frames)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10 ms frames)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['speech_token']) < token_min_length:
            continue
        if len(sample['speech_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['speech_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['speech_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample
|
|
|
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data; drops samples below min_sample_rate.
        In-place operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate
            min_sample_rate: drop samples whose original rate is lower

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        # Rescale any clipped waveform back into [-1, 1].
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample
|
|
|
def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """ Extract fbank features.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        waveform = sample['speech']
        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        sample['speech_feat'] = mat
        del sample['speech']
        yield sample
|
|
|
def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample
|
|
|
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Encode text to chars or BPE tokens.
        In-place operation.

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if mode == 'inference':
            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
        yield sample
|
|
|
def shuffle(data, shuffle_size=10000, mode='train'):
    """ Locally shuffle the data with a bounded buffer.

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # Flush the remainder.
    random.shuffle(buf)
    for x in buf:
        yield x
|
|
|
def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`.

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x
|
|
|
def static_batch(data, batch_size=16):
    """ Statically batch the data by `batch_size`.

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf
|
|
|
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamically batch the data until the total frames in the batch
        reach `max_frames_in_batch`.

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        # Cost of the batch if this sample were added: every utterance gets
        # padded to the longest one.
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch and len(buf) > 0:
            # The len(buf) guard avoids yielding an empty batch when a single
            # utterance alone exceeds the frame budget.
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
|
|
|
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch """
    if mode == 'inference':
        return static_batch(data, 1)
    else:
        if batch_type == 'static':
            return static_batch(data, batch_size)
        elif batch_type == 'dynamic':
            return dynamic_batch(data, max_frames_in_batch)
        else:
            logging.fatal('Unsupported batch type {}'.format(batch_type))
|
|
|
def padding(data, use_spk_embedding, mode='train'):
    """ Pad a batch of samples into training tensors.

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
                                       dtype=torch.int32)
        # Sort the batch by feature length, longest first.
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
|
|
|
def padding_speech_token(data, use_spk_embedding, mode='train'):
    """ Pad a batch of speech-token samples into training tensors.
        The conditioning embedding is a zero vector (speaker-free training).

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)
        try:
            speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
            speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
            speech_token = pad_sequence(speech_token,
                                        batch_first=True,
                                        padding_value=0)
            speech_feat = [sample[i]['speech_feat'] for i in order]
            speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
            speech_feat = pad_sequence(speech_feat,
                                       batch_first=True,
                                       padding_value=0)
            batch = {
                "speech_token": speech_token,
                "speech_token_len": speech_token_len,
                "speech_feat": speech_feat,
                "speech_feat_len": speech_feat_len,
            }
            if mode == 'inference':
                tts_text = [sample[i]['tts_text'] for i in order]
                tts_index = [sample[i]['tts_index'] for i in order]
                tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
                tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
                tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
                batch.update({'tts_text': tts_text,
                              'tts_index': tts_index,
                              'tts_text_token': tts_text_token,
                              'tts_text_token_len': tts_text_token_len})
            batch["embedding"] = torch.zeros((batch["speech_feat"].size(0), 192),
                                             device=batch["speech_feat"].device)
            yield batch
        except Exception as ex:
            logging.warning('Failed to pad batch, ex info {}'.format(ex))
|
|
|
def padding_speech_token_spk(data, use_spk_embedding, mode='train'):
    """ Pad a batch of speech-token samples into training tensors, carrying
        the per-sample speaker embedding through as the conditioning vector.

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)
        try:
            speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
            speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
            speech_token = pad_sequence(speech_token,
                                        batch_first=True,
                                        padding_value=0)
            speech_feat = [sample[i]['speech_feat'] for i in order]
            speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
            speech_feat = pad_sequence(speech_feat,
                                       batch_first=True,
                                       padding_value=0)
            spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
            batch = {
                "speech_token": speech_token,
                "speech_token_len": speech_token_len,
                "speech_feat": speech_feat,
                "speech_feat_len": speech_feat_len,
                "spk_embedding": spk_embedding,
            }
            if mode == 'inference':
                tts_text = [sample[i]['tts_text'] for i in order]
                tts_index = [sample[i]['tts_index'] for i in order]
                tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
                tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
                tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
                batch.update({'tts_text': tts_text,
                              'tts_index': tts_index,
                              'tts_text_token': tts_text_token,
                              'tts_text_token_len': tts_text_token_len})
            batch["embedding"] = batch["spk_embedding"]
            yield batch
        except Exception as ex:
            logging.warning('Failed to pad batch, ex info {}'.format(ex))
|
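
# A minimal sketch of how these processors chain together into a dataset
# pipeline. This mirrors the usual WeNet/CosyVoice-style recipe but is
# illustrative only: `shards`, `feat_extractor` and `get_tokenizer` are
# assumed to be supplied by the caller.
def _example_pipeline(shards, feat_extractor, get_tokenizer):
    data = parquet_opener(shards)                       # {'src': ...} -> samples
    data = tokenize(data, get_tokenizer, allowed_special='all')
    data = filter(data)                                 # length/ratio filtering
    data = resample(data, resample_rate=22050)
    data = compute_fbank(data, feat_extractor)          # 'speech' -> 'speech_feat'
    data = parse_embedding(data, normalize=True)
    data = shuffle(data)
    data = sort(data)
    data = batch(data, batch_type='dynamic', max_frames_in_batch=12000)
    return padding(data, use_spk_embedding=True)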
|