|
|
|
|
|
import argparse |
|
import json |
|
import os |
|
import platform |
|
import re |
|
from typing import List |
|
|
|
from project_settings import project_path |
|
|
|
# Redirect the Hugging Face hub cache and NLTK data directory into the project
# tree. NOTE(review): these must be assigned before `transformers` / `nltk`
# are imported below, since both libraries read the variables at import time —
# keep this ordering if the imports are ever reorganized.
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()

os.environ['NLTK_DATA'] = (project_path / "thirdparty_data/nltk_data").as_posix()
|
|
|
import gradio as gr |
|
import nltk |
|
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer |
|
|
|
|
|
# Maps the human-readable language names shown in the UI dropdowns to the
# ISO 639-1 codes that M2M100's tokenizer expects (`tokenizer.src_lang` and
# `tokenizer.get_lang_id`).
language_map = {

    "Arabic": "ar",

    "Chinese": "zh",

    "Czech": "cs",

    "Danish": "da",

    "Dutch": "nl",

    # Flemish is a variety of Dutch; it deliberately shares the "nl" code.
    "Flemish": "nl",

    "English": "en",

    "Estonian": "et",

    "Finnish": "fi",

    "French": "fr",

    "German": "de",

    "Italian": "it",

    "Norwegian": "no",

    "Polish": "pl",

    "Portuguese": "pt",

    "Russian": "ru",

    "Spanish": "es",

    "Swedish": "sv",

    "Turkish": "tr",

}
|
|
|
|
|
# Lower-case language names accepted by nltk.sent_tokenize (i.e. languages for
# which a punkt sentence-tokenizer model is distributed with NLTK).
# Fix: the original list also contained "flemish", but NLTK ships no "flemish"
# punkt model, so nltk.sent_tokenize(text, "flemish") raised a LookupError at
# runtime. Flemish input now falls back to whole-text translation instead.
nltk_sent_tokenize_languages = [
    "czech", "danish", "dutch", "english", "estonian",
    "finnish", "french", "german", "italian", "norwegian",
    "polish", "portuguese", "russian", "spanish", "swedish", "turkish",
]
|
|
|
|
|
def chinese_sent_tokenize(text: str):
    """Split Chinese text into sentences.

    Inserts a newline after each sentence-terminal punctuation run (full-width
    。!?, ASCII ?, six-dot ellipses, and terminators followed by a closing
    quote), then splits on those newlines.

    :param text: the raw Chinese text.
    :return: list of sentence strings.
    """
    # Each rule is (pattern, replacement); the replacement re-emits both
    # captured groups with a newline inserted between them.
    boundary_rules = (
        (r"([。!?\?])([^”’])", r"\1\n\2"),
        (r"(\.{6})([^”’])", r"\1\n\2"),
        (r"(\…{2})([^”’])", r"\1\n\2"),
        (r"([。!?\?][”’])([^,。!?\?])", r"\1\n\2"),
    )
    for pattern, replacement in boundary_rules:
        text = re.sub(pattern, replacement, text)

    # Drop trailing whitespace/newlines so the split yields no empty tail.
    return text.rstrip().split("\n")
|
|
|
|
|
def sent_tokenize(text: str, language: str) -> List[str]:
    """Split `text` into sentences with a language-appropriate tokenizer.

    Chinese uses the rule-based splitter above; languages with an NLTK punkt
    model use nltk.sent_tokenize; anything else is returned as a single
    "sentence" so translation still proceeds on the whole text.

    :param text: text to split.
    :param language: lower-case English language name (e.g. "english").
    :return: list of sentence strings.
    """
    # Fix: NLTK distributes no "flemish" punkt model, so passing "flemish"
    # straight to nltk.sent_tokenize raised a LookupError. Flemish is a
    # variety of Dutch, so reuse the Dutch sentence tokenizer.
    if language == "flemish":
        language = "dutch"

    if language in ["chinese"]:
        sent_list = chinese_sent_tokenize(text)
    elif language in nltk_sent_tokenize_languages:
        sent_list = nltk.sent_tokenize(text, language)
    else:
        sent_list = [text]
    return sent_list
|
|
|
|
|
def main():
    """Launch the Gradio web demo for M2M100 multilingual translation."""
    # Pre-load the default checkpoint so the first request doesn't pay the
    # download/initialization cost. At most one checkpoint is kept in memory.
    model_dict = {
        "facebook/m2m100_418M": {
            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
        }
    }

    def multilingual_translation(src_text: str,
                                 src_lang: str,
                                 tgt_lang: str,
                                 model_name: str,
                                 ) -> str:
        """Translate `src_text` from `src_lang` to `tgt_lang`.

        :param src_text: input text (may contain several sentences).
        :param src_lang: UI language name, a key of `language_map`.
        :param tgt_lang: UI language name, a key of `language_map`.
        :param model_name: HF checkpoint id, e.g. "facebook/m2m100_418M".
        :return: the concatenated translation of every sentence.
        """
        model_group = model_dict.get(model_name)
        if model_group is None:
            # Evict whatever checkpoint is currently cached before loading a
            # different one, so only a single model stays resident.
            for k in list(model_dict.keys()):
                del model_dict[k]

            model_dict[model_name] = {
                "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
                "tokenizer": M2M100Tokenizer.from_pretrained(model_name)
            }
            model_group = model_dict[model_name]

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]

        tokenizer.src_lang = language_map[src_lang]

        # Fix: delegate sentence splitting to sent_tokenize, which dispatches
        # per language. The original gated on the NLTK language list here, so
        # Chinese input was never routed through chinese_sent_tokenize and was
        # always translated as one blob.
        src_t_list = sent_tokenize(src_text, language=src_lang.lower())

        # Translate sentence by sentence and concatenate the pieces.
        # NOTE(review): pieces are joined without a separator, matching the
        # original behavior; Latin-script targets may want " ".join instead.
        result = ""
        for src_t in src_t_list:
            encoded_src = tokenizer(src_t, return_tensors="pt")
            generated_tokens = model.generate(
                **encoded_src,
                forced_bos_token_id=tokenizer.get_lang_id(language_map[tgt_lang]),
            )
            text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            result += text_decoded[0]

        # Fix: the original also did `output.value = result` here. Gradio
        # renders the returned value itself, and mutating a shared component
        # from a request handler is not session-safe, so that line was removed.
        return result

    title = "Multilingual Machine Translation"

    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.

[Languages covered](https://huggingface.co/facebook/m2m100_418M#languages-covered)

"""

    examples = [
        [
            "Hello world!",
            "English",
            "Chinese",
            "facebook/m2m100_418M",
        ],
        [
            "我是一个句子。我是另一个句子。",
            "Chinese",
            "English",
            "facebook/m2m100_418M",
        ],
        [
            "M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this paper and first released in this repository.",
            "English",
            "Chinese",
            "facebook/m2m100_418M",
        ]
    ]

    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B"
    ]

    inputs = [
        gr.Textbox(lines=4, placeholder="text", label="Input Text"),
        gr.Dropdown(choices=list(language_map.keys()), value="English", label="Source Language"),
        gr.Dropdown(choices=list(language_map.keys()), value="Chinese", label="Target Language"),
        gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name")
    ]

    output = gr.Textbox(lines=4, label="Output Text")

    demo = gr.Interface(
        fn=multilingual_translation,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=False
    )
    # Bind to localhost on Windows (no container use-case); expose on all
    # interfaces elsewhere (typical Docker deployment).
    demo.queue().launch(
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860
    )
    return
|
|
|
|
|
# Script entry point: run the Gradio demo only when executed directly.
if __name__ == '__main__':

    main()
|
|