# Hugging Face Spaces — status: Sleeping
from pathlib import Path | |
import gradio as gr | |
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer | |
# Display name -> language code understood by the tokenizer. The codes must
# match the tokenizer's language tokens after the underscore-stripping done in
# the "fix tokenizer" step, because `translate` passes them to
# `tokenizer.get_lang_id` and assigns them to `tokenizer.src_lang`.
# Keys are shown verbatim in the UI dropdowns.
lang_to_code = {
    "Afrikaans": "af",
    "Albanian": "sq",
    "Amharic": "am",
    "Arabic": "ar",
    "Armenian": "hy",
    "Assamese": "as",
    "Asturian": "ast",
    "Aymara": "ay",
    "Azerbaijani": "az",
    "Bashkir": "ba",
    "Belarusian": "be",
    "Bengali": "bn",
    "Bosnian": "bs",
    "Breton": "br",
    "Bulgarian": "bg",
    "Burmese": "my",
    "Catalan": "ca",
    "Cebuano": "ceb",
    "Central Khmer": "km",
    "Chinese": "zh",
    "Chokwe": "cjk",
    "Croatian": "hr",
    "Czech": "cs",
    "Danish": "da",
    "Dutch": "nl",
    "Dyula": "dyu",
    "English": "en",
    "Estonian": "et",
    "Finnish": "fi",
    "French": "fr",
    "Fulah": "ff",
    "Galician": "gl",
    "Ganda": "lg",
    "Georgian": "ka",
    "German": "de",
    "Greek": "el",
    "Gujarati": "gu",
    "Haitian Creole": "ht",
    "Hausa": "ha",
    "Hebrew": "he",
    "Hindi": "hi",
    "Hungarian": "hu",
    "Icelandic": "is",
    "Igbo": "ig",
    "Iloko": "ilo",
    "Indonesian": "id",
    "Irish": "ga",
    "Italian": "it",
    "Japanese": "ja",
    "Javanese": "jv",
    "Kabuverdianu": "kea",
    "Kachin": "kac",
    "Kamba": "kam",
    "Kannada": "kn",
    "Kazakh": "kk",
    "Kimbundu": "kmb",
    "Kongo": "kg",
    "Korean": "ko",
    "Kurdish": "ku",
    "Kyrgyz": "ky",
    "Lao": "lo",
    "Latvian": "lv",
    "Lingala": "ln",
    "Lithuanian": "lt",
    "Luo": "luo",
    "Luxembourgish": "lb",
    "Macedonian": "mk",
    "Malagasy": "mg",
    "Malay": "ms",
    "Malayalam": "ml",
    "Maltese": "mt",
    "Maori": "mi",
    "Marathi": "mr",
    "Mongolian": "mn",
    "Nepali": "ne",
    "Northern Kurdish": "kmr",
    "Northern Sotho": "ns",
    "Norwegian": "no",
    "Nyanja": "ny",
    "Occitan": "oc",
    "Oriya": "or",
    "Oromo": "om",
    "Pashto": "ps",
    "Persian": "fa",
    "Polish": "pl",
    "Portuguese": "pt",
    "Punjabi": "pa",
    "Quechua": "qu",
    "Romanian": "ro",
    "Russian": "ru",
    "Scottish Gaelic": "gd",
    "Serbian": "sr",
    "Shan": "shn",
    "Shona": "sn",
    "Sindhi": "sd",
    "Sinhala": "si",
    "Slovak": "sk",
    "Slovenian": "sl",
    "Somali": "so",
    "Spanish": "es",
    "Sundanese": "su",
    "Swahili": "sw",
    "Swati": "ss",
    "Swedish": "sv",
    "Tagalog": "tl",
    "Tajik": "tg",
    "Tamil": "ta",
    "Telugu": "te",
    "Thai": "th",
    "Tigrinya": "ti",
    "Tswana": "tn",
    "Turkish": "tr",
    "Ukrainian": "uk",
    "Umbundu": "umb",
    "Urdu": "ur",
    "Uzbek": "uz",
    "Vietnamese": "vi",
    "Welsh": "cy",
    "Western Frisian": "fy",
    "Wolof": "wo",
    "Xhosa": "xh",
    "Yiddish": "yi",
    "Yoruba": "yo",
    "Zulu": "zu",
}
# Dropdown choices, in the dict's insertion order.
lang_names = list(lang_to_code.keys())
# Load the tokenizer and the seq2seq model strictly from the local
# ./model_files directory; local_files_only=True means nothing is fetched
# from the Hugging Face Hub.
model_path = Path("model_files").resolve()
print(f"model_path: {model_path}")
tokenizer: M2M100Tokenizer = M2M100Tokenizer.from_pretrained(
    str(model_path),
    local_files_only=True,
)
model = M2M100ForConditionalGeneration.from_pretrained(
    str(model_path),
    local_files_only=True,
)
# Rebuild the tokenizer's language lookup tables from its special tokens.
# `translate` relies on these through `tokenizer.get_lang_id`.
# NOTE(review): presumably needed because the tables loaded from
# ./model_files do not match this checkpoint's language tokens — confirm.
#
# Only special tokens whose id is > 5 are kept. Presumably ids 0-5 are the
# generic specials (bos/pad/eos/unk, ...) rather than language tags — verify.
tokenizer.lang_token_to_id = dict(
    (special_tok, special_id)
    for special_tok, special_id in zip(
        tokenizer.all_special_tokens, tokenizer.all_special_ids
    )
    if special_id > 5
)
# A language code is its token with surrounding underscores stripped.
# This assumes tokens look like "__en__" -> "en".
tokenizer.lang_code_to_token = {
    lang_tok.strip("_"): lang_tok for lang_tok in tokenizer.lang_token_to_id
}
tokenizer.lang_code_to_id = {
    lang_tok.strip("_"): lang_id
    for lang_tok, lang_id in tokenizer.lang_token_to_id.items()
}
# Reverse mapping: token id -> language token.
tokenizer.id_to_lang_token = {
    lang_id: lang_tok for lang_tok, lang_id in tokenizer.lang_token_to_id.items()
}
def translate(src_text: str, source_lang: str, target_lang: str) -> str:
    """Translate ``src_text`` from ``source_lang`` into ``target_lang``.

    Args:
        src_text: Text to translate. ``None``, empty, or whitespace-only
            input returns ``""`` immediately, without running the model.
        source_lang: Display name of the source language; must be a key of
            ``lang_to_code`` (e.g. ``"English"``).
        target_lang: Display name of the target language; must be a key of
            ``lang_to_code`` (e.g. ``"Korean"``).

    Returns:
        The first decoded output sequence, with special tokens removed.

    Raises:
        KeyError: If the text is non-blank and either language name is not
            a key of ``lang_to_code``.

    Note:
        This assigns ``src_lang`` on the shared module-level ``tokenizer``.
        Concurrent calls with different source languages could therefore
        interfere. NOTE(review): this depends on Gradio's concurrency
        settings.
    """
    # Nothing to translate: don't feed an empty or blank sequence to the
    # model. Also tolerates None coming from API callers.
    if not src_text or not src_text.strip():
        return ""

    # Map UI display names to the tokenizer's language codes.
    src_lang = lang_to_code[source_lang]
    tgt_lang = lang_to_code[target_lang]

    # M2M100Tokenizer uses `src_lang` to pick the language token it adds
    # when encoding the input.
    tokenizer.src_lang = src_lang
    encoded_input = tokenizer(src_text, return_tensors="pt")

    # Force the first decoder token to be the target-language token so the
    # model generates in `tgt_lang`. The total output length is capped at
    # 1024 tokens.
    generated_tokens = model.generate(
        **encoded_input,
        forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
        max_length=1024,
    )

    # One input sequence was encoded, so take the first decoded output.
    pred_texts = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    pred_text = pred_texts[0]
    assert isinstance(pred_text, str)  # type narrowing for static checkers
    return pred_text
# Gradio UI wiring for `translate`:
#   inputs  - a text box plus source/target language dropdowns
#   outputs - a text box showing the translation
inputs = [
    gr.Textbox(value="Hello world!", lines=4, label="Input Text"),
    gr.Dropdown(choices=lang_names, value="English", label="Source Language"),
    gr.Dropdown(choices=lang_names, value="Korean", label="Target Language"),
]
outputs = gr.Textbox(label="Output Text", lines=4)
demo = gr.Interface(
    translate,
    inputs,
    outputs,
    title="Flores101: Large-Scale Multilingual Machine Translation",
    description="[`seyoungsong/flores101_mm100_175M`](https://huggingface.co/seyoungsong/flores101_mm100_175M)",
)

if __name__ == "__main__":
    # Model repo:  https://huggingface.co/seyoungsong/flores101_mm100_175M
    # Hosted demo: https://huggingface.co/spaces/seyoungsong/flores101_mm100_175M
    # Local dev:   gradio src/pretrained/gradio/app.py
    #              then open http://127.0.0.1:7860
    demo.launch()