File size: 6,068 Bytes
68eb545
 
 
e6a62de
68eb545
b21dd8d
b06652e
 
68eb545
 
 
9499ec0
 
68eb545
 
22ed4b7
68eb545
 
 
e6a62de
3b5b01c
e6a62de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b06652e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b5b01c
b06652e
3b5b01c
 
b06652e
 
 
68eb545
9b27f4a
 
 
 
 
 
68eb545
ec314ff
 
 
 
 
e6a62de
9b27f4a
 
27d3705
 
 
9b27f4a
 
 
 
 
 
 
 
 
e6a62de
 
68eb545
e6a62de
b06652e
e6a62de
 
22ed4b7
e6a62de
22ed4b7
 
 
 
e6a62de
22ed4b7
 
 
 
57b5dd0
 
22ed4b7
68eb545
2c51a4b
68eb545
6d06dfd
 
 
 
b72e90e
 
6d06dfd
68eb545
 
 
 
c82bbd4
 
9b27f4a
68eb545
aa4e278
 
 
 
 
 
 
 
 
 
 
 
68eb545
 
9b27f4a
 
 
 
68eb545
6757676
e6a62de
 
6757676
68eb545
 
2c51a4b
68eb545
57b5dd0
ec314ff
68eb545
 
 
 
 
57b5dd0
68eb545
49d142d
 
 
 
 
68eb545
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import platform
import re
from typing import List

from project_settings import project_path

os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()
os.environ['NLTK_DATA'] = (project_path / "thirdparty_data/nltk_data").as_posix()

import gradio as gr
import nltk
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


# Display-name -> ISO 639-1 code mapping for M2M100.
# Keys populate the source/target dropdowns (insertion order is the
# dropdown order); values are fed to `tokenizer.src_lang` and
# `tokenizer.get_lang_id()`.
# NOTE(review): Flemish is mapped to Dutch ("nl") — M2M100 has no separate
# Flemish code; confirm this aliasing is intended.
language_map = {
    "Arabic": "ar",
    "Chinese": "zh",
    "Czech": "cs",
    "Danish": "da",
    "Dutch": "nl",
    "Flemish": "nl",
    "English": "en",
    "Estonian": "et",
    "Finnish": "fi",
    "French": "fr",
    "German": "de",
    "Italian": "it",
    "Norwegian": "no",
    "Polish": "pl",
    "Portuguese": "pt",
    "Russian": "ru",
    "Spanish": "es",
    "Swedish": "sv",
    "Turkish": "tr",

}


# Lowercase language names routed to nltk.sent_tokenize for sentence
# splitting; Chinese is handled separately and anything else is left unsplit.
# NOTE(review): NLTK ships no dedicated "flemish" punkt model — verify that
# this entry does not trigger a LookupError at runtime.
nltk_sent_tokenize_languages = [
    "czech",
    "danish",
    "dutch",
    "flemish",
    "english",
    "estonian",
    "finnish",
    "french",
    "german",
    "italian",
    "norwegian",
    "polish",
    "portuguese",
    "russian",
    "spanish",
    "swedish",
    "turkish",
]


def chinese_sent_tokenize(text: str) -> List[str]:
    """Split Chinese text into sentences on terminal punctuation.

    A sentence break (newline) is inserted after each terminator
    (。!??), after an English-style six-dot ellipsis, after a Chinese
    ellipsis (……), and after a terminator followed by a closing quote
    (the quote then ends the sentence). Semicolons, dashes and double
    quotes are deliberately not treated as boundaries.
    """
    # (pattern, replacement) rules applied in order; each inserts "\n"
    # between the boundary and the character that follows it, carefully
    # keeping closing quotes attached to their sentence.
    boundary_rules = (
        # single-character terminator not followed by a closing quote
        (r"([。!?\?])([^”’])", r"\1\n\2"),
        # English-style ellipsis (six dots)
        (r"(\.{6})([^”’])", r"\1\n\2"),
        # Chinese ellipsis
        (r"(\…{2})([^”’])", r"\1\n\2"),
        # terminator + closing quote: break after the quote
        (r"([。!?\?][”’])([^,。!?\?])", r"\1\n\2"),
    )
    for pattern, replacement in boundary_rules:
        text = re.sub(pattern, replacement, text)
    # strip trailing whitespace so no empty final sentence is emitted
    return text.rstrip().split("\n")


def sent_tokenize(text: str, language: str) -> List[str]:
    """Split ``text`` into sentences for a lowercase language name.

    Chinese uses the rule-based ``chinese_sent_tokenize``; languages in
    ``nltk_sent_tokenize_languages`` use NLTK's punkt tokenizer; any other
    language is returned unsplit as a single-element list.
    """
    if language in ["chinese"]:
        sent_list = chinese_sent_tokenize(text)
    elif language in nltk_sent_tokenize_languages:
        # Bug fix: NLTK ships no "flemish" punkt model, so passing it
        # through raised LookupError. Flemish is a variety of Dutch —
        # fall back to the Dutch model for it.
        punkt_language = "dutch" if language == "flemish" else language
        sent_list = nltk.sent_tokenize(text, punkt_language)
    else:
        sent_list = [text]
    return sent_list


def main():
    """Launch a Gradio demo that translates text with M2M100 models."""
    # Single-entry model cache: at most one (model, tokenizer) pair is kept
    # resident to bound memory use; switching models evicts the old one.
    model_dict = {
        "facebook/m2m100_418M": {
            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
        }
    }

    def multilingual_translation(src_text: str,
                                 src_lang: str,
                                 tgt_lang: str,
                                 model_name: str,
                                 ) -> str:
        """Translate ``src_text`` from ``src_lang`` to ``tgt_lang``.

        The input is first split into sentences when a sentence tokenizer
        is available for the source language, then each sentence is
        translated independently and the results are joined.
        """
        # model: evict whatever is cached before loading a different
        # checkpoint so only one model resides in memory at a time.
        model_group = model_dict.get(model_name)
        if model_group is None:
            for k in list(model_dict.keys()):
                del model_dict[k]

            model_dict[model_name] = {
                "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
                "tokenizer": M2M100Tokenizer.from_pretrained(model_name)
            }
            model_group = model_dict[model_name]

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]

        # tokenize
        tokenizer.src_lang = language_map[src_lang]

        # Bug fix: the old gate only checked nltk_sent_tokenize_languages,
        # so Chinese input was never sentence-split even though
        # chinese_sent_tokenize exists. sent_tokenize already falls back to
        # [text] for unsupported languages, so call it unconditionally.
        src_t_list = sent_tokenize(src_text, language=src_lang.lower())

        # infer sentence by sentence. Bug fix: results used to be
        # concatenated with no separator ("one.two." for English targets);
        # join with a space except for Chinese, which uses no word spacing.
        separator = "" if tgt_lang == "Chinese" else " "
        translated: List[str] = []
        for src_t in src_t_list:
            encoded_src = tokenizer(src_t, return_tensors="pt")
            generated_tokens = model.generate(**encoded_src,
                                              forced_bos_token_id=tokenizer.get_lang_id(language_map[tgt_lang]),
                                              )
            text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            translated.append(text_decoded[0])
            # NOTE: the former `output.value = result` here was removed — it
            # did not stream partial output; it mutated the Textbox
            # component's default value, which is shared across sessions.

        return separator.join(translated)

    title = "Multilingual Machine Translation"

    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. 
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.

[Languages covered](https://huggingface.co/facebook/m2m100_418M#languages-covered)

"""

    examples = [
        [
            "Hello world!",
            "English",
            "Chinese",
            "facebook/m2m100_418M",
        ],
        [
            "我是一个句子。我是另一个句子。",
            "Chinese",
            "English",
            "facebook/m2m100_418M",
        ],
        [
            "M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this paper and first released in this repository.",
            "English",
            "Chinese",
            "facebook/m2m100_418M",
        ]
    ]

    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B"
    ]
    inputs = [
        gr.Textbox(lines=4, placeholder="text", label="Input Text"),
        gr.Dropdown(choices=list(language_map.keys()), value="English", label="Source Language"),
        gr.Dropdown(choices=list(language_map.keys()), value="Chinese", label="Target Language"),
        gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name")
    ]

    output = gr.Textbox(lines=4, label="Output Text")

    demo = gr.Interface(
        fn=multilingual_translation,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=False
    )
    # Bind to localhost on Windows (dev box); 0.0.0.0 elsewhere (container).
    demo.queue().launch(
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860
    )

    return


if __name__ == '__main__':
    main()