|
|
|
|
|
import argparse |
|
import os |
|
|
|
from project_settings import project_path |
|
|
|
# Keep downloaded Hugging Face checkpoints inside the project tree. This is set
# before gradio/transformers are imported further down, presumably because the
# hub cache location is read at import time.
hf_hub_cache = (project_path / "cache" / "huggingface" / "hub").as_posix()
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
|
|
|
import gradio as gr |
|
import nltk |
|
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer |
|
from transformers.generation.streamers import TextIteratorStreamer |
|
|
|
|
|
# Target languages whose scripts do not put spaces between sentences; their
# per-sentence translations are concatenated directly, all others are joined
# with a single space. NOTE(review): chosen from M2M100's language codes —
# extend if other unspaced scripts are needed.
_NO_SPACE_LANGS = frozenset({"zh", "ja", "th", "my", "km", "lo"})


def _load_model_group(model_name: str) -> dict:
    """Load the M2M100 model and tokenizer for the checkpoint ``model_name``."""
    return {
        "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
        "tokenizer": M2M100Tokenizer.from_pretrained(model_name),
    }


def _join_translations(pieces: list, tgt_lang: str) -> str:
    """Join per-sentence translations with a separator suited to ``tgt_lang``."""
    separator = "" if tgt_lang in _NO_SPACE_LANGS else " "
    return separator.join(piece.strip() for piece in pieces)


def main():
    """Build and launch the Gradio demo for M2M100 multilingual translation.

    The ``facebook/m2m100_418M`` checkpoint is loaded eagerly. Choosing another
    checkpoint in the UI clears this cache and loads the new one on first use,
    so the cache holds at most one model at a time.
    """
    # Cache of loaded checkpoints, keyed by Hugging Face model name.
    model_dict = {
        "facebook/m2m100_418M": _load_model_group("facebook/m2m100_418M"),
    }

    def multilingual_translation(src_text: str,
                                 src_lang: str,
                                 tgt_lang: str,
                                 model_name: str,
                                 ):
        """Translate ``src_text`` from ``src_lang`` to ``tgt_lang``.

        The text is split with ``nltk.sent_tokenize`` and each sentence is
        translated separately.

        Args:
            src_text: Text to translate.
            src_lang: M2M100 language code of the input (e.g. "en").
            tgt_lang: M2M100 language code of the output (e.g. "zh").
            model_name: Hugging Face checkpoint name to translate with.

        Returns:
            The translated sentences joined into one string ("" for empty input).
        """
        # Tolerate stray whitespace typed into the language textboxes.
        src_lang = src_lang.strip()
        tgt_lang = tgt_lang.strip()

        model_group = model_dict.get(model_name)
        if model_group is None:
            # Drop the cached checkpoint(s) before loading a different one so
            # that only one model is kept in this cache.
            model_dict.clear()
            model_group = model_dict[model_name] = _load_model_group(model_name)

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]

        tokenizer.src_lang = src_lang
        # Forces the first generated token to be the target-language id; it is
        # the same for every sentence, so it is looked up once.
        forced_bos_token_id = tokenizer.get_lang_id(tgt_lang)

        pieces = []
        for sentence in nltk.sent_tokenize(src_text):
            encoded_src = tokenizer(sentence, return_tensors="pt")
            generated_tokens = model.generate(**encoded_src,
                                              forced_bos_token_id=forced_bos_token_id,
                                              )
            text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            pieces.append(text_decoded[0])

        # Return the result to Gradio; the output component is updated from
        # the return value, not by mutating it here.
        return _join_translations(pieces, tgt_lang)

    title = "Multilingual Machine Translation"

    description = """

M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.

It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.



[Languages covered](https://huggingface.co/facebook/m2m100_418M#languages-covered)



"""

    examples = [
        [
            "Hello world!",
            "en",
            "zh",
            "facebook/m2m100_418M",
        ],
    ]

    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B",
    ]

    inputs = [
        gr.Textbox(lines=4, placeholder="text", label="Input Text"),
        gr.Textbox(lines=1, value="en", label="Source Language"),
        gr.Textbox(lines=1, value="zh", label="Target Language"),
        gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name"),
    ]

    output = gr.Textbox(lines=4, label="Output Text")

    demo = gr.Interface(
        fn=multilingual_translation,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=False,
    )

    # `.queue()` already enables queuing; the former `enable_queue=True`
    # launch kwarg was redundant and was removed from `launch()` in Gradio 4.
    demo.queue().launch(debug=True)

    return
|
|
|
|
|
if __name__ == "__main__":
    # Launch the demo only when executed as a script, not on import.
    main()
|
|