#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import platform
import re
from typing import List
from project_settings import project_path
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()
os.environ['NLTK_DATA'] = (project_path / "thirdparty_data/nltk_data").as_posix()
import gradio as gr
import nltk
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# Maps the display names offered in the UI dropdowns to the ISO 639-1 codes
# that the M2M100 tokenizer expects (src_lang / get_lang_id).
# Note: "Flemish" and "Dutch" both map to "nl" — Flemish is a variety of Dutch.
language_map = {
    "Arabic": "ar",
    "Chinese": "zh",
    "Czech": "cs",
    "Danish": "da",
    "Dutch": "nl",
    "Flemish": "nl",
    "English": "en",
    "Estonian": "et",
    "Finnish": "fi",
    "French": "fr",
    "German": "de",
    "Italian": "it",
    "Norwegian": "no",
    "Polish": "pl",
    "Portuguese": "pt",
    "Russian": "ru",
    "Spanish": "es",
    "Swedish": "sv",
    "Turkish": "tr",
}
# Lowercased display names for which NLTK sentence tokenization is attempted
# (see sent_tokenize).
# NOTE(review): NLTK's punkt data ships no "flemish" model — verify that
# Flemish input does not raise a LookupError inside nltk.sent_tokenize.
nltk_sent_tokenize_languages = [
    "czech", "danish", "dutch", "flemish", "english", "estonian",
    "finnish", "french", "german", "italian", "norwegian",
    "polish", "portuguese", "russian", "spanish", "swedish", "turkish"
]
def chinese_sent_tokenize(text: str) -> List[str]:
    """Split Chinese text into sentences with punctuation heuristics.

    A newline is inserted after each sentence terminator (keeping closing
    quotes attached to their sentence), then the text is split on newlines.
    Semicolons, dashes and ASCII double quotes are deliberately ignored.
    """
    # (pattern, replacement) pairs, applied in order.  The negative classes
    # [^”’] keep a closing quote glued to the terminator that precedes it.
    rules = (
        # single-character sentence terminators
        (r"([。!?\?])([^”’])", r"\1\n\2"),
        # English-style ellipsis "......"
        (r"(\.{6})([^”’])", r"\1\n\2"),
        # Chinese ellipsis "……"
        (r"(\…{2})([^”’])", r"\1\n\2"),
        # a terminator followed by a closing quote: the quote ends the
        # sentence, so the break goes after the quote
        (r"([。!?\?][”’])([^,。!?\?])", r"\1\n\2"),
    )
    for pattern, repl in rules:
        text = re.sub(pattern, repl, text)
    # discard any trailing whitespace/newline before splitting
    return text.rstrip().split("\n")
def sent_tokenize(text: str, language: str) -> List[str]:
    """Split ``text`` into sentences for the given lowercased language name.

    Chinese uses the local rule-based splitter, punkt-supported languages
    use NLTK, and any other language returns the whole text as a single
    "sentence" so callers never need a fallback of their own.

    :param text: raw input text.
    :param language: lowercased display name, e.g. "english", "chinese".
    :return: list of sentence strings.
    """
    if language == "chinese":
        return chinese_sent_tokenize(text)
    if language in nltk_sent_tokenize_languages:
        # NLTK's punkt data has no "flemish" model, so passing it through
        # unchanged raises LookupError.  Flemish is a variety of Dutch
        # (language_map maps both to "nl"), so reuse the Dutch model.
        punkt_language = "dutch" if language == "flemish" else language
        return nltk.sent_tokenize(text, punkt_language)
    # No tokenizer available for this language: treat the text as one sentence.
    return [text]
def main():
    """Build and launch the gradio demo for M2M100 multilingual translation."""
    # Cache of loaded checkpoints, keyed by model name.  At most one entry is
    # kept resident at a time to bound memory use (see eviction below).
    model_dict = {
        "facebook/m2m100_418M": {
            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
        }
    }

    def multilingual_translation(src_text: str,
                                 src_lang: str,
                                 tgt_lang: str,
                                 model_name: str,
                                 ) -> str:
        """Translate ``src_text`` from ``src_lang`` to ``tgt_lang``.

        :param src_text: input text (may contain several sentences).
        :param src_lang: display name from ``language_map`` (e.g. "English").
        :param tgt_lang: display name from ``language_map``.
        :param model_name: Hugging Face checkpoint id to use.
        :return: concatenated translation of the input sentences.
        """
        # model: evict whatever is cached before loading a different
        # checkpoint, so only one model stays in memory at a time.
        model_group = model_dict.get(model_name)
        if model_group is None:
            model_dict.clear()
            model_dict[model_name] = {
                "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
                "tokenizer": M2M100Tokenizer.from_pretrained(model_name)
            }
            model_group = model_dict[model_name]
        model = model_group["model"]
        tokenizer = model_group["tokenizer"]

        # tokenize: translate sentence-by-sentence so long inputs fit the
        # model.  sent_tokenize itself falls back to [src_text] for languages
        # without a tokenizer, so it is safe to call unconditionally.  (The
        # previous gate on nltk_sent_tokenize_languages excluded Chinese and
        # so never reached chinese_sent_tokenize.)
        tokenizer.src_lang = language_map[src_lang]
        src_t_list = sent_tokenize(src_text, language=src_lang.lower())

        # infer: the target language is fixed for the whole request, so the
        # forced BOS token id is computed once, outside the loop.
        forced_bos_token_id = tokenizer.get_lang_id(language_map[tgt_lang])
        pieces = []
        for src_t in src_t_list:
            encoded_src = tokenizer(src_t, return_tensors="pt")
            generated_tokens = model.generate(
                **encoded_src,
                forced_bos_token_id=forced_bos_token_id,
            )
            text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            pieces.append(text_decoded[0])
        # The returned value is what gradio renders into the output textbox;
        # sentences are concatenated without a separator, as before.  (The
        # old `output.value = result` assignment was a no-op — mutating a
        # component after construction does not update the UI — and was
        # removed.)
        return "".join(pieces)

    title = "Multilingual Machine Translation"
    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.
[Languages covered](https://huggingface.co/facebook/m2m100_418M#languages-covered)
"""
    # Each example row matches the order of `inputs` below.
    examples = [
        [
            "Hello world!",
            "English",
            "Chinese",
            "facebook/m2m100_418M",
        ],
        [
            "我是一个句子。我是另一个句子。",
            "Chinese",
            "English",
            "facebook/m2m100_418M",
        ],
        [
            "M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this paper and first released in this repository.",
            "English",
            "Chinese",
            "facebook/m2m100_418M",
        ]
    ]
    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B"
    ]
    inputs = [
        gr.Textbox(lines=4, placeholder="text", label="Input Text"),
        gr.Dropdown(choices=list(language_map.keys()), value="English", label="Source Language"),
        gr.Dropdown(choices=list(language_map.keys()), value="Chinese", label="Target Language"),
        gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name")
    ]
    output = gr.Textbox(lines=4, label="Output Text")
    demo = gr.Interface(
        fn=multilingual_translation,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=False
    )
    # Bind to localhost on Windows (dev box), all interfaces elsewhere (server).
    demo.queue().launch(
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860
    )
# Script entry point: launch the translation demo when run directly.
if __name__ == '__main__':
    main()