import pyopenjtalk
import re
import sys
import os

# Temporarily redirect stdout and stderr
sys.stdout = open(os.devnull, 'w')
sys.stderr = open(os.devnull, 'w')

# Call the function that produces the warning
# e.g., pyopenjtalk.some_function()

# Restore stdout and stderr
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__

# 定义平假名到片假名的转换表
hiragana_to_katakana = str.maketrans(
    "ぁあぃいぅうぇえぉおかがきぎくぐけげこご"
    "さざしじすずせぜそぞただちぢっつづてでとど"
    "なにぬねのはばぱひびぴふぶぷへべぺほぼぽ"
    "まみむめもゃやゅゆょよらりるれろゎわゐゑをんゔゕゖ",
    "ァアィイゥウェエォオカガキギクグケゲコゴ"
    "サザシジスズセゼソゾタダチヂッツヅテデトド"
    "ナニヌネノハバパヒビピフブプヘベペホボポ"
    "マミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶ"
)

# 定义一个函数,将平假名转换为片假名
def hiragana_to_katakana_func(text):
    return text.translate(hiragana_to_katakana)

# 定义一个函数,准确地分割假名为音拍(mora)
def split_into_moras(kana):
    # 正则表达式匹配日语音拍,包括拗音、小写片假名和长音符号
    mora_pattern = re.compile(
        r"(?:[ァ-ヴー]|[ぁ-ゖ]|ー)[ァィゥェォャュョ]?|ー"
    )
    moras = mora_pattern.findall(kana)
    return moras

# 定义一个函数,根据 acc 值标注升降调
def annotate_kana_with_accent(moras, acc):
    annotated_moras = []
    for i, mora in enumerate(moras):
        annotated_moras.append(mora)
        # 当 acc == 0 时,在第一个假名后添加上升符号
        if acc == 0 and i == 0:
            annotated_moras.append('↑')
        # 当 acc > 1 时,在第一个假名后添加上升符号
        elif acc > 1 and i == 0:
            annotated_moras.append('↑')
        # 当 acc > 0 时,在第 n 个假名后添加下降符号
        elif acc > 0 and i + 1 == acc:
            annotated_moras.append('↓')
    return ''.join(annotated_moras)

# 主函数,获取带音调符号的片假名序列
def get_katakana_with_accent(text):
    current_accent = 0
    # 对于0形,其结束时current_accent为1,对于其他,其结束时current_accent为0
    # 
    tokens = pyopenjtalk.run_frontend(text)
    result = ''
    for token in tokens:
        #print(token)
        mora_size = token['mora_size']
        if mora_size > 1:
            pron = token['pron']
            acc = token['acc']
            # 将发音转换为平假名
            kana = pyopenjtalk.g2p(pron, kana=True)
            # 转换为片假名
            kana = hiragana_to_katakana_func(kana)
            # 分割为音拍(mora)
            moras = split_into_moras(kana)
            # 标注音调符号
            annotated_kana = annotate_kana_with_accent(moras, acc)
            result += annotated_kana
        elif mora_size == 0 or token['pron'] == '’':
            # 对于标点符号等,直接添加原始字符串
            result += token['string'] 
        else:
            result += token['pron']
    result.replace('’', '↑')
    return result

import pyopenjtalk
import re
def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
    """Extract phoneme + prosoody symbol sequence from input full-context labels.

    The algorithm is based on `Prosodic features control by symbols as input of
    sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks.

    Args:
        text (str): Input text.
        drop_unvoiced_vowels (bool): whether to drop unvoiced vowels.

    Returns:
        List[str]: List of phoneme + prosody symbols.

    Examples:
        >>> from espnet2.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody
        >>> pyopenjtalk_g2p_prosody("こんにちは。")
        ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']

    .. _`Prosodic features control by symbols as input of sequence-to-sequence acoustic
        modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104

    """
    labels = pyopenjtalk.make_label(pyopenjtalk.run_frontend(text))
    #print(labels)
    N = len(labels)

    phones = []
    for n in range(N):
        lab_curr = labels[n]

        # current phoneme
        p3 = re.search(r"\-(.*?)\+", lab_curr).group(1)
        # deal unvoiced vowels as normal vowels
        if drop_unvoiced_vowels and p3 in "AEIOU":
            p3 = p3.lower()

        # deal with sil at the beginning and the end of text
        if p3 == "sil":
            assert n == 0 or n == N - 1
            if n == 0:
                phones.append("^")
            elif n == N - 1:
                # check question form or not
                e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr)
                if e3 == 0:
                    phones.append("$")
                elif e3 == 1:
                    phones.append("?")
            continue
        elif p3 == "pau":
            phones.append("_")
            continue
        else:
            phones.append(p3)

        # accent type and position info (forward or backward)
        a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr)
        a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr)
        a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr)

        # number of mora in accent phrase
        f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr)

        a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1])
        # accent phrase border
        if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl":
            phones.append("#")
        # pitch falling
        elif a1 == 0 and a2_next == a2 + 1 and a2 != f1:
            phones.append("]")
        # pitch rising
        elif a2 == 1 and a2_next == 2:
            phones.append("[")

    return phones

def _numeric_feature_by_regex(regex, s):
    match = re.search(regex, s)
    if match is None:
        return -50
    return int(match.group(1))
import pyopenjtalk
def build_phone_to_katakana():
    # 所有基本的片假名音节
    basic_katakana = [
        'ア', 'イ', 'ウ', 'エ', 'オ',
        'カ', 'キ', 'ク', 'ケ', 'コ',
        'サ', 'シ', 'ス', 'セ', 'ソ',
        'タ', 'チ', 'ツ', 'テ', 'ト',
        'ナ', 'ニ', 'ヌ', 'ネ', 'ノ',
        'ハ', 'ヒ', 'フ', 'ヘ', 'ホ',
        'マ', 'ミ', 'ム', 'メ', 'モ',
        'ヤ', 'ユ', 'ヨ',
        'ラ', 'リ', 'ル', 'レ', 'ロ',
        'ワ', 'ヲ', 'ン',
        'ガ', 'ギ', 'グ', 'ゲ', 'ゴ',
        'ザ', 'ジ', 'ズ', 'ゼ', 'ゾ',
        'ダ', 'ヂ', 'ヅ', 'デ', 'ド',
        'バ', 'ビ', 'ブ', 'ベ', 'ボ',
        'パ', 'ピ', 'プ', 'ペ', 'ポ',
        'キャ', 'キュ', 'キョ',
        'シャ', 'シュ', 'ショ',
        'チャ', 'チュ', 'チョ',
        'ニャ', 'ニュ', 'ニョ',
        'ヒャ', 'ヒュ', 'ヒョ',
        'ミャ', 'ミュ', 'ミョ',
        'リャ', 'リュ', 'リョ',
        'ギャ', 'ギュ', 'ギョ',
        'ジャ', 'ジュ', 'ジョ',
        'ビャ', 'ビュ', 'ビョ',
        'ピャ', 'ピュ', 'ピョ',
        'ヴァ', 'ヴィ', 'ヴ', 'ヴェ', 'ヴォ',
        'ファ', 'フィ', 'フェ', 'フォ',
        'ウィ', 'ウェ', 'ウォ',
        'ティ', 'トゥ',
        'ディ', 'ドゥ',
        'ツァ', 'ツィ', 'ツェ', 'ツォ',
        'デュ', 'デョ',
        'ジェ', 'ジョ',
        'チェ', 'チョ',
        'シェ', 'ショ',
        'ヂェ', 'ヂョ',
        'ヒェ', 'ヒョ',
        'ビェ', 'ビョ',
        'ピェ', 'ピョ',
        'キェ', 'キョ',
        'ギェ', 'ギョ',
        'ミェ', 'ミョ',
        'リェ', 'リョ',
        'アァ', 'イィ', 'ウゥ', 'エェ', 'オォ',
        'ヴャ', 'ヴュ', 'ヴョ',
        'ッ', 'ー'
    ]


    katakana_to_phone = {}

    for kana in basic_katakana:
        # 将片假名转换为平假名
        # hiragana = pyopenjtalk.g2p(kana, kana=True)
        # 将平假名转换为音素表示
        phones = pyopenjtalk.g2p(kana)
        #print(phones)
        # 去除开头和结尾的静音标记(pau)
        phones = phones.strip('')
        # 存储映射关系
        katakana_to_phone[kana] = phones

    phone_to_katakana = {}

    for kana, phones in katakana_to_phone.items():
        # 检查是否已有相同的音素映射
        phone_to_katakana[phones] = kana
    return phone_to_katakana, katakana_to_phone
# 定义转换函数
def phones_list_to_katakana(phone_list, phone_to_katakana):
    output = ''
    i = 0
    length = len(phone_list)
    special_symbols = {'^', '_', '[', ']', '#', '$', '?', '↑', '↓'}
    
    while i < length:
        phone = phone_list[i]
        if phone in special_symbols:
            output += phone
            i += 1
        else:
            max_match_length = 5
            match_found = False
            for l in range(max_match_length, 0, -1):
                if i + l <= length:
                    phones_seq = ' '.join(phone_list[i:i+l])
                    if phones_seq in phone_to_katakana:
                        output += phone_to_katakana[phones_seq]
                        i += l
                        match_found = True
                        break
            if not match_found:
                single_phone = phone_list[i]
                if single_phone in phone_to_katakana:
                    output += phone_to_katakana[single_phone]
                    i += 1
                else:
                    print(f"无法映射的音素: {single_phone}")
                    i += 1
    if len(output) == 0:
        return "…"
    return output.replace("[", "↑").replace("]", "↓")
def katakana_to_phones_list(katakana_list, katakana_to_phone):
    output = []
    i = 0
    length = len(katakana_list)
    special_symbols = {'^', '_', '[', ']', '#', '$', '?', '↑', '↓'}
    
    while i < length:
        katakana = katakana_list[i]
        if katakana in special_symbols:
            output.append(katakana)
            i += 1
        else:
            max_match_length = 5
            match_found = False
            for l in range(max_match_length, 0, -1):
                if i + l <= length:
                    katakana_seq = ''.join(katakana_list[i:i+l])
                    if katakana_seq in katakana_to_phone:
                        output.append(katakana_to_phone[katakana_seq])
                        i += l
                        match_found = True
                        break
            if not match_found:
                single_katakana = katakana_list[i]
                if single_katakana in katakana_to_phone:
                    output.append(katakana_to_phone[single_katakana])
                    i += 1
                else:
                    print(f"无法映射的片假名: {single_katakana}")
                    i += 1
    if len(output) == 0:
        return ["…"]
    return output

phone_to_katakana, katakana_to_phone = build_phone_to_katakana()

def surface_to_katakana_with_accent(text):
    text = text.replace("…", "")
    phones = pyopenjtalk_g2p_prosody(text)
    return phones_list_to_katakana(phones, phone_to_katakana)

def katakana_to_phones(katakana, katakana_to_phone = katakana_to_phone):
    katakana_list = list(katakana)
    phone_list = katakana_to_phones_list(katakana_list, katakana_to_phone)
    return ' '.join(phone_list).replace("^", "").replace("#", "").replace("$", "").replace("  "," ").strip()

# 处理文本中的标点符号和空格
# def preprocess_text(text):
#     # 定义日语字符的正则表达式
#     japanese_characters = re.compile(
#         r"[ぁ-ゟ゠-ヿ一-龯]"
#     )
#     # 定义非日语字符(包括标点符号、空格等)的正则表达式
#     non_japanese_characters = re.compile(
#         r"[^ぁ-ゟ゠-ヿ一-龯]+"
#     )
#     sentences = re.split(non_japanese_characters, text)
#     marks = re.findall(non_japanese_characters, text)
#     processed_text = []
#     for i, sentence in enumerate(sentences):
#         if sentence:
#             annotated_sentence = get_katakana_with_accent(sentence)
#             processed_text.append(annotated_sentence)
#         if i < len(marks):
#             mark = marks[i]
#             if mark.strip():
#                 processed_text.append(mark)
#     temp = ''.join(processed_text)
#     return_text = temp.replace("’", "↑")
#     return return_text
def preprocess_text(text):
    #print(text)
    return surface_to_katakana_with_accent(text)
# 示例用法
if __name__ == "__main__":
    text = "^キョ↓オワ#ワ↑タシノ#マ↑ホオ#エ↑ネル↓キイノ#ホ↑キュウノ#タ↑メ↓ギ$"
    annotated_text = katakana_to_phones(text)
    print(annotated_text)