seyoungsong's picture
r
4ec5914 verified
from pathlib import Path
import gradio as gr
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# Display name (shown in the UI dropdowns) -> language code that `translate`
# assigns to `tokenizer.src_lang` / passes to `tokenizer.get_lang_id`.
# Insertion order is the order the dropdowns list the languages in.
lang_to_code = {
    "Afrikaans": "af",  # was misspelled "Akrikaans"
    "Albanian": "sq",
    "Amharic": "am",
    "Arabic": "ar",
    "Armenian": "hy",
    "Assamese": "as",
    "Asturian": "ast",
    "Aymara": "ay",
    "Azerbaijani": "az",
    "Bashkir": "ba",
    "Belarusian": "be",
    "Bengali": "bn",
    "Bosnian": "bs",
    "Breton": "br",
    "Bulgarian": "bg",
    "Burmese": "my",
    "Catalan": "ca",
    "Cebuano": "ceb",
    "Central Khmer": "km",
    "Chinese": "zh",
    "Chokwe": "cjk",
    "Croatian": "hr",
    "Czech": "cs",
    "Danish": "da",
    "Dutch": "nl",
    "Dyula": "dyu",
    "English": "en",
    "Estonian": "et",
    "Finnish": "fi",
    "French": "fr",
    "Fulah": "ff",
    "Galician": "gl",
    "Ganda": "lg",
    "Georgian": "ka",
    "German": "de",
    "Greek": "el",
    "Gujarati": "gu",
    "Haitian Creole": "ht",
    "Hausa": "ha",
    "Hebrew": "he",
    "Hindi": "hi",
    "Hungarian": "hu",
    "Icelandic": "is",
    "Igbo": "ig",
    "Iloko": "ilo",
    "Indonesian": "id",
    "Irish": "ga",
    "Italian": "it",
    "Japanese": "ja",
    "Javanese": "jv",
    "Kabuverdianu": "kea",
    "Kachin": "kac",
    "Kamba": "kam",
    "Kannada": "kn",
    "Kazakh": "kk",
    "Kimbundu": "kmb",
    "Kongo": "kg",
    "Korean": "ko",
    "Kurdish": "ku",
    "Kyrgyz": "ky",
    "Lao": "lo",
    "Latvian": "lv",
    "Lingala": "ln",
    "Lithuanian": "lt",
    "Luo": "luo",
    "Luxembourgish": "lb",
    "Macedonian": "mk",
    "Malagasy": "mg",
    "Malay": "ms",
    "Malayalam": "ml",
    "Maltese": "mt",
    "Maori": "mi",
    "Marathi": "mr",
    "Mongolian": "mn",
    "Nepali": "ne",
    "Northern Kurdish": "kmr",
    "Northern Sotho": "ns",
    "Norwegian": "no",
    "Nyanja": "ny",
    "Occitan": "oc",
    "Oriya": "or",
    "Oromo": "om",
    "Pashto": "ps",
    "Persian": "fa",
    "Polish": "pl",
    "Portuguese": "pt",
    "Punjabi": "pa",
    "Quechua": "qu",
    "Romanian": "ro",
    "Russian": "ru",
    "Scottish Gaelic": "gd",
    "Serbian": "sr",
    "Shan": "shn",
    "Shona": "sn",
    "Sindhi": "sd",
    "Sinhala": "si",
    "Slovak": "sk",
    "Slovenian": "sl",
    "Somali": "so",
    "Spanish": "es",
    "Sundanese": "su",
    "Swahili": "sw",
    "Swati": "ss",
    "Swedish": "sv",
    "Tagalog": "tl",
    "Tajik": "tg",
    "Tamil": "ta",
    "Telugu": "te",
    "Thai": "th",
    "Tigrinya": "ti",
    "Tswana": "tn",
    "Turkish": "tr",
    "Ukrainian": "uk",
    "Umbundu": "umb",
    "Urdu": "ur",
    "Uzbek": "uz",
    "Vietnamese": "vi",
    "Welsh": "cy",
    "Western Frisian": "fy",
    "Wolof": "wo",
    "Xhosa": "xh",
    "Yiddish": "yi",
    "Yoruba": "yo",
    "Zulu": "zu",
}
# Dropdown choices: the display names, in dictionary order.
lang_names = list(lang_to_code.keys())
# --- load model -------------------------------------------------------------
# The tokenizer and the model both come from the ``model_files`` directory
# under the current working directory. ``local_files_only=True`` restricts
# ``from_pretrained`` to those local files.
model_path = Path("./model_files").resolve()
print(f"model_path: {model_path}")
_load_kwargs = {
    "pretrained_model_name_or_path": str(model_path),
    "local_files_only": True,
}
tokenizer: M2M100Tokenizer = M2M100Tokenizer.from_pretrained(**_load_kwargs)
model = M2M100ForConditionalGeneration.from_pretrained(**_load_kwargs)
# --- fix tokenizer ----------------------------------------------------------
# Rebuild the tokenizer's language lookup tables from its own special tokens.
# Only special tokens with id > 5 are treated as language tokens.
# NOTE(review): assumes ids 0-5 are the non-language specials — confirm
# against the vocab in model_files.
# Language tokens are presumably of the form "__xx__"; stripping underscores
# yields the bare code ("xx") that `lang_to_code` values are matched against.
_special_pairs = zip(tokenizer.all_special_tokens, tokenizer.all_special_ids)
tokenizer.lang_token_to_id = dict(
    pair for pair in _special_pairs if pair[1] > 5
)
_lang_items = list(tokenizer.lang_token_to_id.items())
tokenizer.lang_code_to_token = {tok.strip("_"): tok for tok, _ in _lang_items}
tokenizer.lang_code_to_id = {tok.strip("_"): idx for tok, idx in _lang_items}
tokenizer.id_to_lang_token = {idx: tok for tok, idx in _lang_items}
def translate(src_text: str, source_lang: str, target_lang: str) -> str:
    """Translate ``src_text`` from ``source_lang`` into ``target_lang``.

    Args:
        src_text: Text to translate.
        source_lang: Display name of the source language; must be a key of
            ``lang_to_code`` (e.g. "English").
        target_lang: Display name of the target language; must be a key of
            ``lang_to_code`` (e.g. "Korean").

    Returns:
        The first decoded generated sequence, with special tokens removed.

    Raises:
        gr.Error: If either language name is missing or unknown (e.g. a
            cleared dropdown yields ``None``). Gradio shows this message to
            the user instead of a bare ``KeyError``.
        TypeError: If decoding unexpectedly does not yield a ``str``.
    """
    # Map UI display names to the language codes the tokenizer uses.
    try:
        src_lang = lang_to_code[source_lang]
        tgt_lang = lang_to_code[target_lang]
    except KeyError as err:
        raise gr.Error(f"Unsupported language: {err.args[0]!r}") from err

    # Encode. `src_lang` selects the source-language token the M2M100
    # tokenizer uses. NOTE: this mutates the shared module-level tokenizer.
    tokenizer.src_lang = src_lang
    encoded_input = tokenizer(src_text, return_tensors="pt")

    # Generate, forcing the first decoder token to be the target-language id
    # so the output is in `tgt_lang`.
    generated_tokens = model.generate(
        **encoded_input,
        forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
        max_length=1024,
    )

    # Decode. Only one input string was encoded, so take the first result.
    pred_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(pred_text, str):
        raise TypeError(
            f"expected decoded text to be str, got {type(pred_text).__name__}"
        )
    return pred_text
# --- Gradio UI --------------------------------------------------------------
# A free-text box and two language pickers (choices from `lang_names`) feed
# `translate`. The result is shown in a second text box. Components are
# created in the same order as the inputs/outputs they become.
_input_text = gr.Textbox(lines=4, value="Hello world!", label="Input Text")
_source_picker = gr.Dropdown(lang_names, value="English", label="Source Language")
_target_picker = gr.Dropdown(lang_names, value="Korean", label="Target Language")
inputs = [_input_text, _source_picker, _target_picker]
outputs = gr.Textbox(lines=4, label="Output Text")

demo = gr.Interface(
    fn=translate,
    inputs=inputs,
    outputs=outputs,
    title="Flores101: Large-Scale Multilingual Machine Translation",
    description="[`seyoungsong/flores101_mm100_175M`](https://huggingface.co/seyoungsong/flores101_mm100_175M)",
)

if __name__ == "__main__":
    # Model:     https://huggingface.co/seyoungsong/flores101_mm100_175M
    # Space:     https://huggingface.co/spaces/seyoungsong/flores101_mm100_175M
    # Also runnable via the gradio CLI: gradio src/pretrained/gradio/app.py
    # Local URL noted by the author: http://127.0.0.1:7860
    demo.launch()