#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os

from project_settings import project_path

# Point the Hugging Face hub cache at the project directory.
# This must be set before `transformers` is imported, because the
# cache location is read at import time.
hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache

import gradio as gr
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        default="facebook/m2m100_418M",
        type=str,
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    model = M2M100ForConditionalGeneration.from_pretrained(args.pretrained_model_name_or_path)
    tokenizer = M2M100Tokenizer.from_pretrained(args.pretrained_model_name_or_path)

    def multilingual_translate(src_text: str, src_lang: str, tgt_lang: str) -> str:
        # Tell the tokenizer which language the input text is in.
        tokenizer.src_lang = src_lang
        encoded_src = tokenizer(src_text, return_tensors="pt")
        # Force the decoder to begin with the target-language token.
        generated_tokens = model.generate(
            **encoded_src,
            forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
        )
        # batch_decode returns a list of strings; the single Textbox
        # output expects one string, so return the first element.
        result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return result[0]

    title = "Multilingual Machine Translation"
    description = (
        "M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for "
        "Many-to-Many multilingual translation. It was introduced in this "
        "[paper](https://arxiv.org/abs/2010.11125) and first released in "
        "[this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository."
    )
    examples = [
        ["Hello world!", "en", "zh"],
    ]

    inputs = [
        gr.Textbox(lines=4, value="", label="Input Text"),
        gr.Textbox(lines=1, value="", label="Source Language"),
        gr.Textbox(lines=1, value="", label="Target Language"),
    ]
    output = gr.Textbox(lines=4, label="Output Text")

    app = gr.Interface(
        fn=multilingual_translate,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=True,
    )
    # Queue requests so long-running translations do not time out.
    app.queue().launch(debug=True)
    return


if __name__ == '__main__':
    main()
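
# Usage sketch (the script filename below is hypothetical; use whatever
# this file is saved as in the repo):
#   python3 main.py --pretrained_model_name_or_path facebook/m2m100_418M
# Gradio then serves the demo on a local URL (http://127.0.0.1:7860 by
# default). Source and target languages are M2M100 language ids such as
# "en" and "zh", as accepted by tokenizer.get_lang_id().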