#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio demo for many-to-many multilingual translation with M2M100."""
import argparse
import os

from project_settings import project_path

# The hub cache location is exported before gradio/transformers are imported
# so that model downloads land in the project-local cache directory.
hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache

import gradio as gr
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


def _load_model_group(model_name: str) -> dict:
    """Load an M2M100 model and its tokenizer by name.

    Returns a dict with keys ``"model"`` and ``"tokenizer"``.
    """
    return {
        "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
        "tokenizer": M2M100Tokenizer.from_pretrained(model_name),
    }


def main():
    """Build and launch the Gradio translation interface."""
    default_model_name = "facebook/m2m100_418M"

    # Loaded model groups keyed by model name. At most one entry is kept at a
    # time, so only one model is held in memory.
    model_dict = {
        default_model_name: _load_model_group(default_model_name),
    }

    def multilingual_translate(src_text: str,
                               src_lang: str,
                               tgt_lang: str,
                               model_name: str,
                               ):
        """Translate ``src_text`` from ``src_lang`` to ``tgt_lang``.

        :param src_text: text to translate.
        :param src_lang: M2M100 language code of the input (e.g. "en").
        :param tgt_lang: M2M100 language code of the output (e.g. "zh").
        :param model_name: model to use; falls back to the default model
            when empty (the dropdown may have no selection).
        :return: the first decoded translation.
        :raises ValueError: if a language code is not supported by the
            tokenizer.
        """
        model_name = model_name or default_model_name

        model_group = model_dict.get(model_name)
        if model_group is None:
            # Evict every cached group before loading the new one. Dropping
            # whole groups (instead of deleting only their "model" entry)
            # avoids leaving half-populated groups behind, which used to
            # raise KeyError when switching back to an earlier model.
            model_dict.clear()
            model_group = _load_model_group(model_name)
            model_dict[model_name] = model_group

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]

        src_lang = (src_lang or "").strip()
        tgt_lang = (tgt_lang or "").strip()
        # get_lang_id raises a bare KeyError for unknown codes; surface a
        # readable message instead.
        for lang in (src_lang, tgt_lang):
            try:
                tokenizer.get_lang_id(lang)
            except KeyError:
                raise ValueError(
                    "unsupported language code: {!r}".format(lang)
                ) from None

        tokenizer.src_lang = src_lang
        encoded_src = tokenizer(src_text, return_tensors="pt")
        # Forcing the first generated token to the target-language id is
        # what selects the output language.
        generated_tokens = model.generate(
            **encoded_src,
            forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
        )
        result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return result[0]

    title = "Multilingual Machine Translation"

    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.

### Languages covered
Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
"""

    examples = [
        [
            "Hello world!",
            "en",
            "zh",
            "facebook/m2m100_418M",
        ],
    ]

    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B",
    ]

    inputs = [
        gr.Textbox(lines=4, value="", label="Input Text"),
        gr.Textbox(lines=1, value="", label="Source Language"),
        gr.Textbox(lines=1, value="", label="Target Language"),
        # Pre-select the already-loaded model so an untouched dropdown does
        # not send None.
        gr.Dropdown(choices=model_choices, value=default_model_name, label="model_name"),
    ]

    output = gr.Textbox(lines=4, label="Output Text")

    app = gr.Interface(
        fn=multilingual_translate,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=True,
    )
    app.launch(debug=True, enable_queue=True)

    return


if __name__ == '__main__':
    main()