# Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
# If you use the updated (WebDataset) version of the dataset, feel free to modify this script or draft your own.

# Generate the audio-text map for Emilia ZH & EN
# and compute the resulting vocabulary size.

import json
import os
import sys
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm

from model.utils import (
    convert_char_to_pinyin,
    repetition_found,
)

# Keys (second "/"-separated component of a manifest "wav" entry) whose samples
# deal_with_audio_dir treats as bad transcriptions.
# NOTE(review): in this copy the original set literals were truncated and never
# closed, so their contents are lost. Restore the full ID lists from the original
# source of this script. With empty sets, nothing is excluded by ID.
out_zh = set()
# zh transcripts containing any of these characters are treated as bad cases
zh_filters = ["い", "て"]
# seems synthesized audios, or heavily code-switched
out_en = set()
# en transcripts containing any of these characters are treated as bad cases
en_filters = ["ا", "い", "て"]

def deal_with_audio_dir(audio_dir):
    """Parse one shard's ``.jsonl`` manifest into training samples.

    The manifest is ``audio_dir`` with its suffix replaced by ``.jsonl``. Each
    line is a JSON object; the keys read here are ``text``, ``language``, ``wav``
    (joined onto ``audio_dir.parent`` to form the audio path) and ``duration``.

    A zh/en utterance is counted as a bad case and dropped when any of these hold:
    - the second ``/``-component of its ``wav`` path is in ``out_zh`` / ``out_en``;
    - its text contains a character from ``zh_filters`` / ``en_filters``;
    - ``repetition_found`` flags its text.

    Kept zh text has full-width ``，！？`` converted to ASCII.

    Reads the module-level ``tokenizer`` and ``polyphone`` settings, which are
    assigned under ``if __name__ == "__main__"``. NOTE(review): worker processes
    only see them when the pool uses the ``fork`` start method. Confirm before
    running with ``spawn`` (macOS/Windows default).

    Args:
        audio_dir: ``pathlib.Path`` of the shard directory.

    Returns:
        Tuple ``(sub_result, durations, vocab_set, bad_case_zh, bad_case_en)``:
        - the kept samples, as dicts with ``audio_path`` / ``text`` / ``duration``;
        - their durations, in the same order;
        - the set of elements of the kept (possibly pinyin-converted) texts;
        - the number of dropped zh utterances;
        - the number of dropped en utterances.
    """
    # Full-width -> ASCII punctuation for zh text. Written as escapes so the
    # mapping survives re-encoding of this file. "。" is deliberately not
    # converted because many transcripts are code-switched.
    zh_punct_to_ascii = str.maketrans({"\uff0c": ",", "\uff01": "!", "\uff1f": "?"})

    audio_jsonl = audio_dir.with_suffix(".jsonl")
    sub_result, durations = [], []
    vocab_set = set()
    bad_case_zh = 0
    bad_case_en = 0

    # Manifests contain non-ASCII text; don't depend on the platform encoding.
    with open(audio_jsonl, "r", encoding="utf-8") as f:
        lines = f.readlines()

    for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
        obj = json.loads(line)
        text = obj["text"]
        if obj["language"] == "zh":
            wav_key = obj["wav"].split("/")[1]
            if wav_key in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
                bad_case_zh += 1
                continue  # drop the bad transcription
            text = text.translate(zh_punct_to_ascii)
        if obj["language"] == "en":
            wav_key = obj["wav"].split("/")[1]
            if wav_key in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
                bad_case_en += 1
                continue  # drop the bad transcription
        if tokenizer == "pinyin":
            text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
        duration = obj["duration"]
        sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
        durations.append(duration)
        vocab_set.update(text)
    return sub_result, durations, vocab_set, bad_case_zh, bad_case_en

def main():
    """Build the training artifacts under ``data/{dataset_name}``.

    Reads these module-level settings, assigned under ``if __name__ == "__main__"``:
    ``tokenizer``, ``max_workers``, ``langs``, ``dataset_dir`` and ``dataset_name``.

    Every shard directory in ``dataset_dir/<lang>`` is parsed with
    ``deal_with_audio_dir`` in a process pool. The merged output is written as:
    - ``raw.arrow``: one row per kept sample (``audio_path``, ``text``, ``duration``);
    - ``duration.json``: ``{"duration": [...]}``, in the same order as the rows;
    - ``vocab.txt``: the sorted vocabulary, one entry per line.

    Raises:
        ValueError: if ``tokenizer`` is not ``"pinyin"`` or ``"char"``.
    """
    # An explicit check rather than `assert`, so it still runs under `python -O`.
    if tokenizer not in ("pinyin", "char"):
        raise ValueError(f"tokenizer must be 'pinyin' or 'char', got {tokenizer!r}")

    result = []
    duration_list = []
    text_vocab_set = set()
    total_bad_case_zh = 0
    total_bad_case_en = 0

    # process raw data
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for lang in langs:
            dataset_path = Path(dataset_dir) / lang
            # sorted() makes the shard (and therefore row) order reproducible
            for audio_dir in sorted(dataset_path.iterdir()):
                if audio_dir.is_dir():
                    futures.append(executor.submit(deal_with_audio_dir, audio_dir))
        # Collect in submission order (not completion order) so rows and durations stay aligned and stable.
        for future in tqdm(futures, total=len(futures)):
            sub_result, durations, vocab_set, bad_case_zh, bad_case_en = future.result()
            result.extend(sub_result)
            duration_list.extend(durations)
            text_vocab_set.update(vocab_set)
            total_bad_case_zh += bad_case_zh
            total_bad_case_en += bad_case_en

    # save preprocessed dataset to disk
    save_dir = f"data/{dataset_name}"
    os.makedirs(save_dir, exist_ok=True)
    print(f"\nSaving to {save_dir} ...")
    # Rows are streamed through ArrowWriter; building a Dataset via Dataset.from_dict ran out of memory here.
    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        # write() buffers rows in batches; finalize() flushes any rows still buffered.
        writer.finalize()

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    # if tokenizer == "pinyin":
    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
    with open(f"{save_dir}/vocab.txt", "w", encoding="utf-8") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")

    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
    if "ZH" in langs:
        print(f"Bad zh transcription case: {total_bad_case_zh}")
    if "EN" in langs:
        print(f"Bad en transcription case: {total_bad_case_en}\n")

if __name__ == "__main__":
    # Run configuration. main() and deal_with_audio_dir() read these names as
    # module globals, so they must be assigned before main() is called.
    max_workers = 32  # ProcessPoolExecutor worker count

    tokenizer = "pinyin"  # "pinyin" | "char"
    polyphone = True  # forwarded to convert_char_to_pinyin when tokenizer == "pinyin"

    langs = ["ZH", "EN"]
    dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
    dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
    print(f"\nPrepare for {dataset_name}\n")

    main()

    # Emilia               ZH & EN
    # samples count       37837916   (after removal)
    # pinyin vocab size       2543   (polyphone)
    # total duration      95281.87   (hours)
    # bad zh asr cnt        230435   (samples)
    # bad en asr cnt         37217   (samples)

    # vocab size may differ slightly due to the jieba tokenizer and pypinyin (e.g. how polyphones are handled)
    # if using a pretrained model, make sure your vocab.txt is identical to the one it was trained with