qgyd2021's picture
Update main.py
b21dd8d verified
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import platform
import re
from typing import List
from project_settings import project_path
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()
os.environ['NLTK_DATA'] = (project_path / "thirdparty_data/nltk_data").as_posix()
import gradio as gr
import nltk
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# Display name shown in the UI -> language code passed to the M2M100 tokenizer.
# Insertion order matters: it is the order of the language dropdowns.
language_map = dict(
    Arabic="ar",
    Chinese="zh",
    Czech="cs",
    Danish="da",
    Dutch="nl",
    Flemish="nl",
    English="en",
    Estonian="et",
    Finnish="fi",
    French="fr",
    German="de",
    Italian="it",
    Norwegian="no",
    Polish="pl",
    Portuguese="pt",
    Russian="ru",
    Spanish="es",
    Swedish="sv",
    Turkish="tr",
)

# Lower-case language names for which sentences are split with nltk.sent_tokenize.
nltk_sent_tokenize_languages = (
    "czech danish dutch flemish english estonian "
    "finnish french german italian norwegian "
    "polish portuguese russian spanish swedish turkish"
).split()
def chinese_sent_tokenize(text: str):
    """Split Chinese text into sentences using a few regex rules.

    Each rule inserts a newline after a sentence boundary. The result is
    right-stripped and split on newlines. Semicolons, dashes and ASCII double
    quotes are deliberately not treated as boundaries; adjust the rules if
    that is needed.

    :param text: the text to split.
    :return: list of sentence strings (``[""]`` for empty input).
    """
    boundary_rules = (
        # Single-character terminators, unless a closing quote follows.
        (r"([。!?\?])([^”’])", r"\1\n\2"),
        # English-style ellipsis (six dots).
        (r"(\.{6})([^”’])", r"\1\n\2"),
        # Chinese ellipsis (two "…" characters).
        (r"(\…{2})([^”’])", r"\1\n\2"),
        # A terminator followed by a closing quote: the quote ends the sentence,
        # so the break goes after the quote. The rules above intentionally
        # left such quotes attached.
        (r"([。!?\?][”’])([^,。!?\?])", r"\1\n\2"),
    )
    for pattern, replacement in boundary_rules:
        text = re.sub(pattern, replacement, text)
    # Drop trailing whitespace (including surplus newlines) before splitting.
    return text.rstrip().split("\n")
def sent_tokenize(text: str, language: str) -> List[str]:
    """Split ``text`` into sentences with a splitter suited to ``language``.

    :param text: the text to split.
    :param language: lower-case English language name, e.g. ``"english"``.
    :return: list of sentences. ``"chinese"`` uses ``chinese_sent_tokenize``;
        languages listed in ``nltk_sent_tokenize_languages`` use
        ``nltk.sent_tokenize``; any other language is returned unsplit as
        ``[text]``.
    """
    # NLTK's Punkt models have no "flemish" model, so passing it through
    # raises LookupError. This file already treats Flemish as Dutch
    # ("nl" in language_map), so reuse the Dutch model.
    nltk_language_aliases = {"flemish": "dutch"}

    if language == "chinese":
        return chinese_sent_tokenize(text)
    if language in nltk_sent_tokenize_languages:
        nltk_language = nltk_language_aliases.get(language, language)
        return nltk.sent_tokenize(text, language=nltk_language)
    return [text]
def main():
    """Build and launch the Gradio demo for M2M100 multilingual translation.

    ``facebook/m2m100_418M`` is loaded at startup. Another model chosen in the
    dropdown is loaded on first use and replaces the previous one, so
    ``model_dict`` holds at most one model. The server listens on port 7860,
    bound to 127.0.0.1 on Windows and 0.0.0.0 elsewhere.
    """
    model_dict = {
        "facebook/m2m100_418M": {
            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
        }
    }

    # Target languages whose sentences are written without separating spaces.
    no_space_target_languages = {"Chinese"}

    def multilingual_translation(src_text: str,
                                 src_lang: str,
                                 tgt_lang: str,
                                 model_name: str,
                                 ):
        """Translate ``src_text`` from ``src_lang`` to ``tgt_lang`` sentence by sentence.

        :param src_text: text to translate.
        :param src_lang: source language display name (a key of ``language_map``).
        :param tgt_lang: target language display name (a key of ``language_map``).
        :param model_name: model id passed to ``from_pretrained``.
        :return: the translated sentences joined into one string.
        """
        # model: keep a single model in memory. Drop the old one before loading
        # a different checkpoint.
        model_group = model_dict.get(model_name)
        if model_group is None:
            model_dict.clear()
            model_dict[model_name] = {
                "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
                "tokenizer": M2M100Tokenizer.from_pretrained(model_name)
            }
            model_group = model_dict[model_name]
        model = model_group["model"]
        tokenizer = model_group["tokenizer"]

        # tokenize: sent_tokenize also covers Chinese and falls back to
        # [src_text] for languages without a splitter, so call it
        # unconditionally. Gating on the NLTK list would skip the Chinese splitter.
        tokenizer.src_lang = language_map[src_lang]
        src_t_list = sent_tokenize(src_text, language=src_lang.lower())
        forced_bos_token_id = tokenizer.get_lang_id(language_map[tgt_lang])

        # infer
        translations = []
        for src_t in src_t_list:
            # Blank segments (e.g. from empty input lines) carry nothing to translate.
            if not src_t.strip():
                continue
            encoded_src = tokenizer(src_t, return_tensors="pt")
            generated_tokens = model.generate(**encoded_src,
                                              forced_bos_token_id=forced_bos_token_id,
                                              )
            text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            translations.append(text_decoded[0].strip())

        # Bare concatenation glues sentences together ("One.Two."). Only
        # languages written without inter-sentence spaces are joined with "".
        separator = "" if tgt_lang in no_space_target_languages else " "
        return separator.join(translations)

    title = "Multilingual Machine Translation"
    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.
[Languages covered](https://huggingface.co/facebook/m2m100_418M#languages-covered)
"""
    examples = [
        [
            "Hello world!",
            "English",
            "Chinese",
            "facebook/m2m100_418M",
        ],
        [
            "我是一个句子。我是另一个句子。",
            "Chinese",
            "English",
            "facebook/m2m100_418M",
        ],
        [
            "M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this paper and first released in this repository.",
            "English",
            "Chinese",
            "facebook/m2m100_418M",
        ]
    ]
    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B"
    ]
    inputs = [
        gr.Textbox(lines=4, placeholder="text", label="Input Text"),
        gr.Dropdown(choices=list(language_map.keys()), value="English", label="Source Language"),
        gr.Dropdown(choices=list(language_map.keys()), value="Chinese", label="Target Language"),
        gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name")
    ]
    # The handler's return value populates this box; it is not mutated elsewhere.
    output = gr.Textbox(lines=4, label="Output Text")
    demo = gr.Interface(
        fn=multilingual_translation,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=False
    )
    demo.queue().launch(
        # debug=True, enable_queue=True,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860
    )
if __name__ == "__main__":
    # Start the demo only when run as a script, not when imported.
    main()