|
|
|
|
|
import argparse |
|
import os |
|
|
|
from project_settings import project_path |
|
|
|
# Keep the Hugging Face hub cache inside the project tree (forward-slash
# form via as_posix()) instead of the default user-level cache directory.
hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()

# NOTE: this must be set before `transformers` is imported below, otherwise
# the library has already resolved its default cache location.
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
|
|
|
import gradio as gr |
|
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer |
|
|
|
|
|
def get_args(argv=None):
    """Parse command-line arguments for the demo.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case ``argparse`` reads ``sys.argv[1:]`` — so existing
            callers of ``get_args()`` are unaffected.

    Returns:
        argparse.Namespace with ``pretrained_model_name_or_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        default="facebook/m2m100_418M",
        type=str,
    )
    args = parser.parse_args(argv)
    return args
|
|
|
|
|
def main():
    """Load an M2M100 translation model and serve it through a Gradio UI."""
    args = get_args()

    model = M2M100ForConditionalGeneration.from_pretrained(args.pretrained_model_name_or_path)
    tokenizer = M2M100Tokenizer.from_pretrained(args.pretrained_model_name_or_path)

    def multilingual_translate(src_text: str,
                               src_lang: str,
                               tgt_lang: str,
                               ) -> str:
        """Translate ``src_text`` from ``src_lang`` to ``tgt_lang``.

        Language codes are the M2M100 short codes (e.g. ``"en"``, ``"zh"``).
        Returns the translated text as a single string.
        """
        # The tokenizer must know the source language before encoding.
        tokenizer.src_lang = src_lang
        encoded_src = tokenizer(src_text, return_tensors="pt")
        generated_tokens = model.generate(
            **encoded_src,
            # Force the decoder to start generating in the target language.
            forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
        )
        decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        # batch_decode returns a list (one entry per batch item). We encode a
        # single text, so return the single string rather than the list —
        # otherwise the Textbox output shows the list repr, e.g. "['...']".
        return decoded[0] if decoded else ""

    title = "Multilingual Machine Translation"

    description = "M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository."

    examples = [
        [
            "Hello world!",
            "en",
            "zh",
        ],
    ]

    inputs = [
        gr.Textbox(lines=4, value="", label="Input Text"),
        gr.Textbox(lines=1, value="", label="Source Language"),
        gr.Textbox(lines=1, value="", label="Target Language"),
    ]

    output = gr.Textbox(lines=4, label="Output Text")

    app = gr.Interface(
        fn=multilingual_translate,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=True,
    )
    # NOTE(review): `enable_queue` was removed from launch() in Gradio 4.x
    # (use `app.queue()` instead) — confirm the pinned Gradio version before
    # upgrading; kept as-is for compatibility with the version in use here.
    app.launch(debug=True, enable_queue=True)

    return
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|