#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
from project_settings import project_path
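
# Redirect the Hugging Face Hub cache into the project directory. This has to
# happen before `transformers` / `huggingface_hub` are imported, otherwise the
# default user cache location is used for downloaded checkpoints.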
hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
import gradio as gr
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


def main():
    model_dict = {
        "facebook/m2m100_418M": {
            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
        }
    }
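
    # Only the default 418M checkpoint is loaded eagerly; other checkpoints are
    # loaded on demand inside `multilingual_translate` and cached in this dict.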

    def multilingual_translate(src_text: str,
                               src_lang: str,
                               tgt_lang: str,
                               model_name: str,
                               ):
        model_group = model_dict.get(model_name)
        if model_group is None or "model" not in model_group:
            # Free the large model weights of previously used checkpoints (the
            # tokenizers are cheap to keep). pop() avoids a KeyError on entries
            # whose weights were already released, and the "model" check above
            # reloads a checkpoint whose weights were freed earlier.
            for mg in model_dict.values():
                mg.pop("model", None)

            model_dict[model_name] = {
                "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
                "tokenizer": M2M100Tokenizer.from_pretrained(model_name)
            }
            model_group = model_dict[model_name]

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]

        tokenizer.src_lang = src_lang
        encoded_src = tokenizer(src_text, return_tensors="pt")
        generated_tokens = model.generate(
            **encoded_src,
            forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
        )
        result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return result[0]
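
    # A minimal sketch of calling the handler directly (outside Gradio), assuming
    # the default checkpoint above; the language arguments use the ISO codes
    # listed in the description below:
    #
    #   multilingual_translate("Hello world!", "en", "zh", "facebook/m2m100_418M")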

    title = "Multilingual Machine Translation"

    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.

### Languages covered
Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az),
Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs),
Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el),
English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr),
Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu),
Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy),
Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it),
Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km),
Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg),
Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv),
Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms),
Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns),
Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt),
Romanian; Moldavian; Moldovan (ro), Russian (ru),
Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so),
Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw),
Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr),
Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh),
Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
"""

    examples = [
        [
            "Hello world!",
            "en",
            "zh",
            "facebook/m2m100_418M",
        ],
    ]

    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B"
    ]

    inputs = [
        gr.Textbox(lines=4, value="", label="Input Text"),
        gr.Textbox(lines=1, value="", label="Source Language"),
        gr.Textbox(lines=1, value="", label="Target Language"),
        gr.Dropdown(choices=model_choices, label="Model Name")
    ]
    output = gr.Textbox(lines=4, label="Output Text")
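
    # Gradio passes the component values to `multilingual_translate` positionally:
    # input text, source language, target language, model name.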
    app = gr.Interface(
        fn=multilingual_translate,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=True
    )
    app.launch(debug=True, enable_queue=True)
    return
if __name__ == '__main__':
main()