File size: 4,545 Bytes
68eb545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b27f4a
 
 
 
 
 
68eb545
 
 
 
9b27f4a
68eb545
9b27f4a
 
 
 
 
 
 
 
 
 
 
 
 
68eb545
 
473aa84
 
 
68eb545
 
24c189c
68eb545
2c51a4b
68eb545
6d06dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68eb545
 
 
 
 
 
9b27f4a
68eb545
 
 
9b27f4a
 
 
 
68eb545
 
 
2c51a4b
9b27f4a
68eb545
 
2c51a4b
68eb545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Gradio web demo for multilingual machine translation with Facebook's
# M2M100 models (via Hugging Face transformers).
import argparse
import os

from project_settings import project_path

# Redirect the Hugging Face hub cache into the project tree so downloaded
# model weights live alongside the code instead of the user's home directory.
hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()

# NOTE(review): this is set BEFORE importing transformers below — presumably
# because the library reads HUGGINGFACE_HUB_CACHE at import time; keep the
# ordering when editing.
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache

# Imported after the env var is set (see note above) — do not move these up.
import gradio as gr
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

def main():
    """Build and launch a Gradio interface for M2M100 multilingual translation.

    Pre-loads the small 418M model; larger models are loaded lazily on first
    use, evicting previously loaded models to bound memory usage.
    """
    # Cache of loaded models keyed by HF model name; each entry holds the
    # model and its matching tokenizer.
    model_dict = {
        "facebook/m2m100_418M": {
            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
        }
    }

    def multilingual_translate(src_text: str,
                               src_lang: str,
                               tgt_lang: str,
                               model_name: str,
                               ) -> str:
        """Translate ``src_text`` from ``src_lang`` to ``tgt_lang``.

        :param src_text: text to translate.
        :param src_lang: source language code (e.g. "en").
        :param tgt_lang: target language code (e.g. "zh").
        :param model_name: HF model id, e.g. "facebook/m2m100_418M".
        :return: the translated text.
        """
        model_group = model_dict.get(model_name)
        if model_group is None:
            # BUG FIX: the original deleted only the "model" key from each
            # cached group (`del mg["model"]`) while leaving the stale
            # entries in model_dict. Switching back to a previously evicted
            # model then returned a group without a "model" key and raised
            # KeyError below. Evict the entries entirely so an evicted
            # model is simply re-loaded on demand.
            model_dict.clear()
            model_dict[model_name] = {
                "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
                "tokenizer": M2M100Tokenizer.from_pretrained(model_name)
            }
            model_group = model_dict[model_name]

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]

        # M2M100 selects the source language via the tokenizer and forces
        # the target language as the first generated token.
        tokenizer.src_lang = src_lang
        encoded_src = tokenizer(src_text, return_tensors="pt")
        generated_tokens = model.generate(**encoded_src,
                                          forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
                                          )
        result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        return result[0]

    title = "Multilingual Machine Translation"

    # Description shown below the title; language list mirrors the M2M100
    # model card ("Greeek" typo from the upstream card fixed).
    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. 
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.

### Languages covered

Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), 
Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), 
Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), 
English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), 
Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), 
Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), 
Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), 
Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), 
Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), 
Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), 
Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), 
Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), 
Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), 
Romanian; Moldavian; Moldovan (ro), Russian (ru), 
Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so), 
Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), 
Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr), 
Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), 
Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
"""

    # One clickable example row: (text, src_lang, tgt_lang, model_name).
    examples = [
        [
            "Hello world!",
            "en",
            "zh",
            "facebook/m2m100_418M",
        ],
    ]

    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B"
    ]
    inputs = [
        gr.Textbox(lines=4, value="", label="Input Text"),
        gr.Textbox(lines=1, value="", label="Source Language"),
        gr.Textbox(lines=1, value="", label="Target Language"),
        gr.Dropdown(choices=model_choices, label="model_name")
    ]

    output = gr.Textbox(lines=4, label="Output Text")

    app = gr.Interface(
        fn=multilingual_translate,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=True
    )
    # enable_queue is deprecated in gradio >= 3.x (use app.queue()); kept
    # here for compatibility with the gradio version this demo targets.
    app.launch(debug=True, enable_queue=True)

    return


# Script entry point: build and launch the Gradio demo.
if __name__ == '__main__':
    main()