# @ [email protected]

import argparse
def parse_args():
    parser = argparse.ArgumentParser(description="encode the LibriLight dataset using an EnCodec model")
    parser.add_argument('--audiopath', type=str, default=None)
    parser.add_argument('--save_dir', type=str, default=None)
    parser.add_argument('--save_tag', type=str, default='encodec_16khz_4codebooks')
    parser.add_argument('--dataset_name', type=str, default=None)
    parser.add_argument('--encodec_model_path', type=str, default=None)
    parser.add_argument('--n_workers', type=int, default=4, help="Number of parallel worker processes")
    parser.add_argument('--mega_batch_size', type=int, default=120, help="Number of samples in each mega batch for multiprocess dataloading")
    parser.add_argument('--batch_size', type=int, default=32, help="batch size for encodec encoding; decrease it if you hit OOM. This is the total batch size summed across all GPUs, so increase it if you are using more GPUs")
    parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
    parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
    parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
    parser.add_argument('--len_cap', type=float, default=20.0, help='drop audios longer than this many seconds')
    parser.add_argument('--start', type=int, default=0, help='start index for parallel processing')
    parser.add_argument('--end', type=int, default=500000, help='end index for parallel processing')
    parser.add_argument('--max_len', type=int, default=30000, help='max audio length in samples; batches containing longer samples are split in half before encoding. Decrease this if you hit OOM on your machine')
    return parser.parse_args()
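
# Example invocation (every path below is an illustrative placeholder, not a path shipped with this script):
#   python encodec_encode.py \
#       --audiopath /data/librilight/wavs \
#       --save_dir /data/codes \
#       --dataset_name librilight \
#       --encodec_model_path /ckpts/encodec_16khz_4codebooks.th \
#       --n_workers 8 --batch_size 32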
    
if __name__ == "__main__":
    import logging
    log_format = "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    args = parse_args()

    import os
    os.environ["USER"] = "root"  # some audiocraft/dora utilities read the USER env var; set a fallback for containers where it may be unset
    import numpy as np
    import torch
    import tqdm
    import torchaudio
    import torchaudio.transforms as transforms
    
    # directory where the extracted codes will be saved: {save_dir}/{dataset_name}/{save_tag}
    codes_save_root = os.path.join(args.save_dir, args.dataset_name, args.save_tag)
    os.makedirs(codes_save_root, exist_ok=True)


    def sort_by_audio_len(lens):
        # sort by duration, log length statistics, and return indices ordered longest-first
        inds = np.argsort(lens).tolist()
        logging.info(f"longest: {lens[inds[-1]]*args.model_code_sr:.0f} encodec codes, {lens[inds[-1]]:.2f} sec.")
        logging.info(f"shortest: {lens[inds[0]]*args.model_code_sr:.0f} encodec codes, {lens[inds[0]]:.2f} sec.")
        logging.info(f"median: {lens[inds[len(inds)//2]]*args.model_code_sr:.0f} encodec codes, {lens[inds[len(inds)//2]]:.2f} sec.")
        logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]*args.model_code_sr:.0f} encodec codes, {lens[inds[int(len(inds)*0.95)]]:.2f} sec.")
        return inds[::-1]
    
    def write_array_to_txt_file(array, filename):
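        # one codebook per row; entries are space-separated code indices (no trailing newline after the last row)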
        with open(filename, 'w') as f:
            for a in array[:-1]:
                f.write(' '.join(map(str, a))+'\n')
            f.write(' '.join(map(str, array[-1])))
    

    # load the encodec model
    from audiocraft.solvers import WMCompressionSolver
    model = WMCompressionSolver.model_from_checkpoint(args.encodec_model_path)
    model = model.cuda()
    model = model.eval()
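
    # NOTE (assumption): WMCompressionSolver is taken to expose audiocraft's standard
    # CompressionModel interface, where model.encode(wav) returns a (codes, scale) tuple
    # and codes has shape [B, n_codebooks, T_frames]; the [0] indexing below relies on this.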

        
    class mydataset(torch.utils.data.Dataset):
        def __init__(self, args, split):
            super().__init__()
            import glob
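            # collect every .wav under audiopath; args.start/args.end shard the file list across parallel jobs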
            self.data = glob.glob(os.path.join(args.audiopath, "*.wav"))
            self.data = self.data[args.start:args.end]

        def checkout(self, data):
            out = []
            for ind in range(len(data)):
                segment_id = data[ind].split('/')[-1].split(".wav")[0]
                save_fn = os.path.join(codes_save_root, segment_id+".txt")
                if not os.path.exists(save_fn):
                    out.append(data[ind])
            return out
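
        # NOTE: checkout() filters out segments whose code files already exist on disk, but it is
        # currently never called; applying `self.data = self.checkout(self.data)` at the end of
        # __init__ would make interrupted runs resumable.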
            
        def __len__(self):
            return len(self.data)
        def __getitem__(self, ind):
            segment_id = self.data[ind].split('/')[-1].split(".wav")[0]
            if os.path.exists(self.data[ind]):
                audio, sr = torchaudio.load(self.data[ind])
            else:
                # fall back to an alternate mount prefix (cluster-specific path remapping)
                audio, sr = torchaudio.load(self.data[ind].replace('/apdcephfs_cq2', '/apdcephfs_cq2_1297902'))
            duration = audio.shape[1] / sr # compute duration before resampling, while the sample count still matches sr
            if sr != args.model_sr:
                resampler = transforms.Resample(orig_freq=sr, new_freq=args.model_sr)
                audio = resampler(audio)
                sr = args.model_sr
            return segment_id, audio.squeeze(), sr, duration
        def collate(self, batch):
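            # keep variable-length audios in plain Python lists; padding to the batch max happens
            # later, per encoding mini-batch, so short clips are not padded to the mega-batch max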
            res = {'segment_id': [], "audio": [], "sr": [], "duration":[]}
            for item in batch:
                if item[0] is not None:
                    res['segment_id'].append(item[0])
                    res['audio'].append(item[1])
                    res['sr'].append(item[2])
                    res['duration'].append(item[3])
            return res


    ## encodec codes extraction
    logging.info("encodec encoding...")
    train_dataset = mydataset(args, 'train')
    logging.info(f"number of files to encode: {len(train_dataset)}")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
    splits = ['train']
    loaders = [train_loader]

    for split, loader in zip(splits, loaders):
        skip = 0
        logging.info(f"now processing split {split}...")
        for m, mega_batch in enumerate(loader):
            logging.info(f"====================================")
            logging.info(f"====================================")
            lengths = np.array(mega_batch['duration'])
            sorted_inds = sort_by_audio_len(lengths)
            for j in reversed(range(len(sorted_inds))): # iterate backwards so deletion doesn't shift pending indices
                if lengths[sorted_inds[j]] < 0.2 or lengths[sorted_inds[j]] > args.len_cap: # skip samples shorter than 0.2s or longer than args.len_cap seconds
                    skip += 1
                    del sorted_inds[j]
            
            n_steps = int(np.ceil(len(sorted_inds) / args.batch_size))
            for n in tqdm.tqdm(range(n_steps), disable=True): # progress bar disabled; progress is tracked via the logs instead
                inds_used = sorted_inds[n*args.batch_size:(n+1)*args.batch_size]
                audio_batch = [mega_batch['audio'][ind] for ind in inds_used]
                sr_batch = [mega_batch['sr'][ind] for ind in inds_used]
                segment_id_batch = [mega_batch['segment_id'][ind] for ind in inds_used]
                padded_wav = torch.nn.utils.rnn.pad_sequence(audio_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
                all_lens = [lengths[ind] for ind in inds_used]
                with torch.no_grad():
                    # all_lens holds durations in seconds while args.max_len is in samples, so convert before comparing
                    if max(all_lens) * args.model_sr > args.max_len and len(all_lens) > 1: # NOTE decrease args.max_len if OOM, or chunk into more than 2 forward passes
                        # too long for a single pass: split the batch in half along the batch dimension
                        codes = []
                        inwav = padded_wav.cuda()
                        codes.append(model.encode(inwav[:len(inwav)//2])[0].cpu())
                        codes.append(model.encode(inwav[len(inwav)//2:])[0].cpu())
                        codes = torch.cat(codes, dim=0)
                    else:
                        encoded_frames = model.encode(padded_wav.cuda())
                        codes = encoded_frames[0].cpu() # encode() returns (codes, scale); keep only the codes

                for i, length in enumerate(all_lens):
                    save_fn = os.path.join(codes_save_root, segment_id_batch[i]+".txt")
                    if not os.path.exists(save_fn):
                        actual_len = round(length * args.model_code_sr) # model_code_sr = model_sr / downsample_rate = 16000 / 320 = 50 codes per second
                        cur_code = codes[i, :, :actual_len].tolist() # codes is always a tensor here; trim the padding frames
                        write_array_to_txt_file(cur_code, save_fn)
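
    # All codes end up under {save_dir}/{dataset_name}/{save_tag}/{segment_id}.txt, one codebook per row.
    logging.info(f"done. skipped {skip} too-short/too-long samples; codes saved under {codes_save_root}")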