# Source: Hugging Face Space file by qgyd2021, commit 473aa84 ("[update]add main"), 2.4 kB.
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os

from project_settings import project_path

# Redirect the Hugging Face hub cache into the project tree.
# NOTE(review): this is set *before* `transformers` is imported — presumably
# because HUGGINGFACE_HUB_CACHE is read when the hub library initializes, so
# the import order below is deliberate; confirm before regrouping imports.
hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache

import gradio as gr
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
def get_args():
    """Parse the command-line arguments for the demo.

    Returns an ``argparse.Namespace`` with a single attribute,
    ``pretrained_model_name_or_path`` (default: ``facebook/m2m100_418M``).
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="facebook/m2m100_418M",
    )
    return arg_parser.parse_args()
def main():
    """Load an M2M100 checkpoint and launch a Gradio translation demo.

    The checkpoint is chosen by ``--pretrained_model_name_or_path``
    (see :func:`get_args`). Blocks in ``launch`` until the app is stopped.
    """
    args = get_args()

    model = M2M100ForConditionalGeneration.from_pretrained(args.pretrained_model_name_or_path)
    tokenizer = M2M100Tokenizer.from_pretrained(args.pretrained_model_name_or_path)

    def multilingual_translate(src_text: str,
                               src_lang: str,
                               tgt_lang: str,
                               ) -> str:
        """Translate ``src_text`` from ``src_lang`` into ``tgt_lang``.

        ``src_lang`` / ``tgt_lang`` are M2M100 language codes (e.g. "en", "zh").
        Returns the translated text as a plain string.
        """
        tokenizer.src_lang = src_lang
        encoded_src = tokenizer(src_text, return_tensors="pt")
        generated_tokens = model.generate(
            **encoded_src,
            # Force the first generated token to the target-language id so the
            # many-to-many model decodes into the requested language.
            forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
        )
        result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        # Fix: batch_decode returns a list of strings. There is exactly one
        # input sequence, so return its translation directly — returning the
        # list made the output Textbox display a Python list representation.
        return result[0]

    title = "Multilingual Machine Translation"
    description = "M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository."

    examples = [
        [
            "Hello world!",
            "en",
            "zh",
        ],
    ]

    inputs = [
        gr.Textbox(lines=4, value="", label="Input Text"),
        gr.Textbox(lines=1, value="", label="Source Language"),
        gr.Textbox(lines=1, value="", label="Target Language"),
    ]
    output = gr.Textbox(lines=4, label="Output Text")

    app = gr.Interface(
        fn=multilingual_translate,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        # Pre-compute example outputs at startup (runs the model once per example).
        cache_examples=True
    )
    app.launch(debug=True, enable_queue=True)
    return
if __name__ == '__main__':
    # Launch the demo only when run as a script, not when imported.
    main()