|
|
|
|
|
import argparse |
|
import os |
|
|
|
from project_settings import project_path |
|
|
|
# Keep Hugging Face hub downloads inside the project tree instead of the
# user-wide default location.
hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()

# NOTE(review): this assignment deliberately sits above the `transformers`
# import further down — presumably the library reads the variable when it is
# imported, so do not let an import sorter hoist that import above this line.
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
|
|
|
import gradio as gr |
|
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer |
|
|
|
|
|
def main():
    """Build and launch a Gradio demo for many-to-many translation with M2M100.

    The first entry of ``model_choices`` is loaded eagerly. Any other
    checkpoint selected in the UI is loaded on first use and replaces the one
    currently cached, so ``model_dict`` never holds more than one
    model/tokenizer pair.
    """
    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B",
    ]
    default_model_name = model_choices[0]

    def load_model_group(model_name: str) -> dict:
        """Load the model and tokenizer published under ``model_name``."""
        return {
            "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
            "tokenizer": M2M100Tokenizer.from_pretrained(model_name),
        }

    # Maps checkpoint name -> {"model": ..., "tokenizer": ...}.
    model_dict = {default_model_name: load_model_group(default_model_name)}

    def multilingual_translate(src_text: str,
                               src_lang: str,
                               tgt_lang: str,
                               model_name: str,
                               ) -> str:
        """Translate ``src_text`` from ``src_lang`` to ``tgt_lang``.

        Args:
            src_text: Text to translate. Empty or whitespace-only text yields
                an empty string without running the model.
            src_lang: Language code of ``src_text`` (e.g. ``"en"``).
            tgt_lang: Language code to translate into (e.g. ``"zh"``).
            model_name: Checkpoint to use; falls back to the default
                checkpoint when empty or ``None``.

        Returns:
            The first decoded sequence produced by the model.

        Raises:
            ValueError: If the tokenizer rejects a language code with
                ``KeyError``.
        """
        if not src_text or not src_text.strip():
            return ""
        if not model_name:
            # The dropdown can deliver None when nothing is selected.
            model_name = default_model_name
        src_lang = (src_lang or "").strip()
        tgt_lang = (tgt_lang or "").strip()

        model_group = model_dict.get(model_name)
        if model_group is None:
            # Drop the cached entries entirely *before* loading the new
            # checkpoint. Removing only the "model" key would leave a
            # half-populated entry behind that breaks the next lookup of
            # that checkpoint; clearing first also releases our reference
            # to the old model ahead of the new allocation.
            model_dict.clear()
            model_group = load_model_group(model_name)
            model_dict[model_name] = model_group

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]

        try:
            tokenizer.src_lang = src_lang
            tgt_lang_id = tokenizer.get_lang_id(tgt_lang)
        except KeyError as error:
            # NOTE(review): assumes the tokenizer reports an unknown code as
            # KeyError; any other exception type propagates unchanged.
            raise ValueError(
                f"Unsupported language code {error} "
                f"(source={src_lang!r}, target={tgt_lang!r}); "
                "use one of the codes listed in the description."
            ) from error

        encoded_src = tokenizer(src_text, return_tensors="pt")
        generated_tokens = model.generate(
            **encoded_src,
            forced_bos_token_id=tgt_lang_id,
        )
        result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        return result[0]

    title = "Multilingual Machine Translation"

    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.

### Languages covered

Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az),
Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs),
Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el),
English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr),
Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu),
Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy),
Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it),
Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km),
Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg),
Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv),
Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms),
Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns),
Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt),
Romanian; Moldavian; Moldovan (ro), Russian (ru),
Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so),
Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw),
Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr),
Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh),
Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
"""

    # Each example row follows the order of `inputs` below:
    # text, source language, target language, model name.
    examples = [
        [
            "Hello world!",
            "en",
            "zh",
            "facebook/m2m100_418M",
        ],
    ]

    inputs = [
        gr.Textbox(lines=4, value="", label="Input Text"),
        gr.Textbox(lines=1, value="", label="Source Language"),
        gr.Textbox(lines=1, value="", label="Target Language"),
        # Pre-select the checkpoint that is already loaded so the callback
        # never receives an empty model name from an untouched dropdown.
        gr.Dropdown(choices=model_choices, value=default_model_name, label="model_name"),
    ]

    output = gr.Textbox(lines=4, label="Output Text")

    app = gr.Interface(
        fn=multilingual_translate,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=True,
    )
    # NOTE(review): `enable_queue` is kept for compatibility with the Gradio
    # version this script was written against; newer releases are understood
    # to expect `app.queue().launch(...)` instead — confirm when upgrading.
    app.launch(debug=True, enable_queue=True)
|
|
|
|
|
# Start the demo only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|