File size: 3,324 Bytes
68eb545
 
 
 
 
 
 
 
 
 
 
 
22ed4b7
68eb545
57b5dd0
68eb545
 
 
9b27f4a
 
 
 
 
 
68eb545
57b5dd0
 
 
 
 
9b27f4a
 
27d3705
 
 
9b27f4a
 
 
 
 
 
 
 
 
68eb545
 
22ed4b7
 
 
 
 
 
 
 
 
 
 
57b5dd0
 
22ed4b7
68eb545
2c51a4b
68eb545
6d06dfd
 
 
 
b72e90e
 
6d06dfd
68eb545
 
 
 
 
 
9b27f4a
68eb545
 
 
9b27f4a
 
 
 
68eb545
6757676
 
 
 
68eb545
 
2c51a4b
68eb545
57b5dd0
 
68eb545
 
 
 
 
57b5dd0
68eb545
57b5dd0
68eb545
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os

from project_settings import project_path

# Keep all Hugging Face hub downloads inside the project tree. This must run
# before gradio/transformers are imported below — presumably so huggingface_hub
# picks up this cache directory when it resolves its settings at import time.
hf_hub_cache = project_path.joinpath("cache", "huggingface", "hub").as_posix()
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache

import gradio as gr
import nltk
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers.generation.streamers import TextIteratorStreamer


# M2M100 target languages whose scripts do not put spaces between sentences;
# translations into every other target language are joined with one space.
_NO_SPACE_TARGET_LANGS = frozenset({"zh", "ja"})


def _join_translations(sentences, tgt_lang):
    """Join per-sentence translations with a separator suited to ``tgt_lang``."""
    separator = "" if tgt_lang in _NO_SPACE_TARGET_LANGS else " "
    return separator.join(sentences)


def main():
    """Build and launch the Gradio demo for M2M100 multilingual translation.

    The 418M checkpoint is loaded eagerly. Any other checkpoint offered in
    ``model_choices`` is loaded on first use and replaces the cached one, so at
    most one model is kept in ``model_dict`` at a time.
    """
    # Checkpoints the UI offers; requests naming anything else are rejected.
    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B"
    ]

    # Cache of the currently loaded checkpoint: {model_name: {"model": ..., "tokenizer": ...}}.
    model_dict = {
        "facebook/m2m100_418M": {
            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
        }
    }

    def fn_non_stream(src_text: str,
                      src_lang: str,
                      tgt_lang: str,
                      model_name: str,
                      ) -> str:
        """Translate ``src_text`` sentence by sentence from ``src_lang`` to ``tgt_lang``.

        Returns:
            The translated sentences joined together ("" for empty input).

        Raises:
            gr.Error: if ``model_name`` is not one of ``model_choices`` or a
                language code is unknown to the tokenizer.
        """
        # Inputs arrive from the web UI / API and are untrusted: never fetch and
        # load an arbitrary hub repository, only the checkpoints the demo offers.
        if model_name not in model_choices:
            raise gr.Error(f"Unsupported model: {model_name!r}")

        model_group = model_dict.get(model_name)
        if model_group is None:
            # Drop the previously loaded checkpoint so only one model is held.
            model_dict.clear()
            model_group = {
                "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
                "tokenizer": M2M100Tokenizer.from_pretrained(model_name),
            }
            model_dict[model_name] = model_group

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]

        try:
            # NOTE(review): the tokenizer object is shared by all requests and
            # src_lang is set on it per call; this relies on the queue running
            # one request at a time — confirm the Gradio concurrency settings.
            tokenizer.src_lang = src_lang
            # Loop-invariant: the target language id is the same for every sentence.
            tgt_lang_id = tokenizer.get_lang_id(tgt_lang)
        except KeyError as err:
            raise gr.Error(
                f"Unsupported language code: src_lang={src_lang!r}, tgt_lang={tgt_lang!r}"
            ) from err

        # Translate one sentence at a time; sent_tokenize relies on NLTK's
        # "punkt" tokenizer data being installed.
        translations = []
        for sentence in nltk.sent_tokenize(src_text):
            encoded_src = tokenizer(sentence, return_tensors="pt")
            generated_tokens = model.generate(**encoded_src,
                                              forced_bos_token_id=tgt_lang_id,
                                              )
            text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            translations.append(text_decoded[0])

        return _join_translations(translations, tgt_lang)

    title = "Multilingual Machine Translation"

    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. 
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.

[Languages covered](https://huggingface.co/facebook/m2m100_418M#languages-covered)

"""

    examples = [
        [
            "Hello world!",
            "en",
            "zh",
            "facebook/m2m100_418M",
        ],
    ]

    inputs = [
        gr.Textbox(lines=4, placeholder="text", label="Input Text"),
        gr.Textbox(lines=1, value="en", label="Source Language"),
        gr.Textbox(lines=1, value="zh", label="Target Language"),
        gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name")
    ]

    output = gr.Textbox(lines=4, label="Output Text")

    demo = gr.Interface(
        fn=fn_non_stream,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=False
    )
    # .queue() already enables request queuing; launch()'s old enable_queue flag
    # is redundant and no longer accepted by newer Gradio releases.
    demo.queue().launch(debug=True)

    return


if __name__ == "__main__":
    # Start the demo only when executed as a script, not when imported.
    main()