# @ [email protected]

import argparse
def parse_args():
    parser = argparse.ArgumentParser(description="phonemize the transcripts of the dataset and collect the phoneme vocabulary")
    parser.add_argument("--dataset_name", type=str, default='English', help='name of the dataset')
    parser.add_argument('--dataset_dir', type=str, default=None, help="dataset root, expected to contain a 'trans' subdir with <split>.json files")
    parser.add_argument('--save_dir', type=str, default=None, help="path to the manifest, phonemes, and encodec codes dirs")
    parser.add_argument('--model_code_sr', type=float, default=50, help="codec frames per second, only used for length statistics; 50 is an assumed default, set it to match your codec model")
    return parser.parse_args()
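
# example invocation (the script file name and the paths below are illustrative, not taken from the repo):
#   python phonemize_dataset.py --dataset_name English \
#       --dataset_dir /data/my_dataset --save_dir /data/outputs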

if __name__ == "__main__":
    import logging
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    args = parse_args()

    import os
    import numpy as np
    import torch
    import tqdm
    import time
    import pandas as pd
    import multiprocessing
    from tokenizer import TextTokenizer, tokenize_text
    
    # output dir: one phoneme .txt file per segment will be written under <save_dir>/<dataset_name>/phonemes
    phn_save_root = os.path.join(args.save_dir, args.dataset_name, "phonemes")
    os.makedirs(phn_save_root, exist_ok=True)

    # helper for logging length statistics over audio durations (in seconds); returns indices sorted longest-first
    def sort_by_audio_len(lens):
        inds = np.argsort(lens).tolist()
        logging.info(f"longest: {lens[inds[-1]]*args.model_code_sr} encodec codes, {lens[inds[-1]]:.2f} sec.")
        logging.info(f"shortest: {lens[inds[0]]*args.model_code_sr} encodec codes, {lens[inds[0]]:.2f} sec.")
        logging.info(f"median: {lens[inds[len(inds)//2]]*args.model_code_sr} encodec codes, {lens[inds[len(inds)//2]]:.2f} sec.")
        logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]*args.model_code_sr} encodec codes, {lens[inds[int(len(inds)*0.95)]]:.2f} sec.")
        return inds[::-1]
    
    def write_array_to_txt_file(array, filename):
        # write one space-separated row per line, without a trailing newline
        if len(array) == 0:
            return
        with open(filename, 'w') as f:
            for a in array[:-1]:
                f.write(' '.join(map(str, a))+'\n')
            f.write(' '.join(map(str, array[-1])))
    

    ### phonemization
    # load tokenizer
    text_tokenizer = TextTokenizer(backend="espeak") # add language='cmn' when you process mandarin

    punc2sym = {" <COMMA>": ",", " <PERIOD>": ".", " <QUESTIONMARK>": "?", " <EXCLAMATIONPOINT>": "!"} # note the space in front of each punc name
    gar2sym = {"<SIL>": "#%#", "<MUSIC>": "##%", "<NOISE>": "%%#", "<OTHER>": "%#%"} # so that they are safely kept as the original symbols when running tokenize_text
    punc2sym.update(gar2sym)

    word2sym = { "h æ ʃ h ɐ ʃ p ɚ s ɛ n t": "<MUSIC>", "h æ ʃ p ɚ s ɛ n t h æ ʃ": "<SIL>", "p ɚ s ɛ n t h ɐ ʃ p ɚ s ɛ n t": "<OTHER>", "p ɚ s ɛ n t p ɚ s ɛ n t h æ ʃ": "<NOISE>"}
    forbidden_words = set(['#%#', '##%', '%%#', '%#%'])
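    # round trip: e.g. "<MUSIC>" in a transcript becomes "##%", which espeak presumably reads out as
    # "hash hash percent" -> "h æ ʃ h ɐ ʃ p ɚ s ɛ n t", and word2sym maps that phoneme sequence back
    # to the single token "<MUSIC>"; forbidden_words catches transcripts that already contain the
    # placeholder strings themselves, which would break this mapping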

    stime = time.time()
    logging.info("loading the dataset...")

    splits = ['validation', 'test', 'train']
    
    logging.info(f"phonemizing...")
    # you will see a ton of [WARNING] words_mismatch.py:88 messages, they are not an issue
    for split in tqdm.tqdm(splits):
        logging.info(f"now processing split {split}...")
        jsondata = pd.read_json(path_or_buf=os.path.join(args.dataset_dir, 'trans', split+'.json'), lines=True)
        logging.info(f"split {split} contains {len(jsondata)} segments")
        # shard the transcripts so each worker handles one shard
        N = 88  # number of shards / worker processes
        df_split = np.array_split(jsondata, N)
        cmds = []
        for idx, part in enumerate(df_split):
            part.reset_index(drop=True, inplace=True)
            cmds.append((idx, part))

        def process_one(indx, splitdata):
            # phonemize one shard of the split; phoneme files go to phn_save_root,
            # and the phoneme vocabulary seen in this shard is written to vocab_fn
            vocab_fn = os.path.join(args.save_dir, args.dataset_name, f"vocab_{split}_{str(indx)}.txt")
            phn_vocab = set()
            all_lens = []
            skip = 0  # number of segments skipped because of forbidden words
            for key in tqdm.tqdm(range(len(splitdata))):
                save_fn = os.path.join(phn_save_root, splitdata['segment_id'][key]+".txt")
                if not os.path.exists(save_fn):
                    text = splitdata['trans'][key]
                    if any(word in forbidden_words for word in text.split(" ")):
                        logging.info(f"skip {splitdata['segment_id'][key]}, because it contains forbidden words. Its transcript: {text}")
                        skip += 1
                        continue
                    for k, v in punc2sym.items():
                        text = text.replace(k, v)
                    phn = tokenize_text(text_tokenizer, text)
                    phn_seq = " ".join(phn)
                    for k, v in word2sym.items():
                        phn_seq = phn_seq.replace(k, v)
                    phn_vocab.update(phn_seq.split(" "))
                    all_lens.append(len(phn_seq.split(" ")))
                    with open(save_fn, "w") as f:
                        f.write(phn_seq)
                else:
                    # already phonemized in a previous run; reload it so the vocab still covers it
                    with open(save_fn, "r") as f:
                        phn_seq = f.read()
                    phn_vocab.update(phn_seq.split(" "))
                    all_lens.append(len(phn_seq.split(" ")))
            # write the phoneme vocabulary collected from this shard (one symbol per line)
            with open(vocab_fn, "w") as f:
                f.write("\n".join(sorted(phn_vocab)))
            logging.info(f"split {split} shard {indx}: {skip} segments skipped due to forbidden words")
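        # note: this relies on the workers inheriting `split`, `args`, and the loaded text_tokenizer
        # from the parent process (fork-style start method, the long-standing default on Linux);
        # with the 'spawn' start method this closure-based setup would not carry over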

        with multiprocessing.Pool(processes=N) as pool:
            pool.starmap(process_one, cmds)
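
    # downstream, the per-shard vocab files (vocab_<split>_<idx>.txt, written above) can be merged
    # into a single vocabulary; a minimal sketch, assuming the one-symbol-per-line format used here:
    #   import glob
    #   vocab = set()
    #   for fn in glob.glob(os.path.join(args.save_dir, args.dataset_name, "vocab_*.txt")):
    #       with open(fn) as f:
    #           vocab.update(line for line in f.read().split("\n") if line)
    #   with open(os.path.join(args.save_dir, args.dataset_name, "vocab.txt"), "w") as f:
    #       f.write("\n".join(sorted(vocab)))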