#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio demo for M2M100 many-to-many multilingual machine translation."""
import argparse
import os

from project_settings import project_path

# Redirect the Hugging Face hub cache into the project tree.  This must be
# set BEFORE `transformers` is imported, which is why the third-party
# imports come after this environment assignment.
hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache

import gradio as gr
import nltk
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers.generation.streamers import TextIteratorStreamer


def main():
    # Cache of loaded models keyed by model name.  At most one model is
    # kept in memory at a time to bound RAM usage (see eviction below).
    model_dict = {
        "facebook/m2m100_418M": {
            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M"),
        }
    }

    # `nltk.sent_tokenize` needs the "punkt" model; fetch it once up front
    # so the first translation request does not fail with a LookupError.
    try:
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt", quiet=True)

    def multilingual_translation(src_text: str,
                                 src_lang: str,
                                 tgt_lang: str,
                                 model_name: str,
                                 ) -> str:
        """
        Translate `src_text` from `src_lang` to `tgt_lang`, sentence by sentence.

        :param src_text: text to translate.
        :param src_lang: M2M100 source language code (e.g. "en").
        :param tgt_lang: M2M100 target language code (e.g. "zh").
        :param model_name: Hugging Face model id; loaded lazily and cached.
        :return: the translated text (this return value is what updates the
            Gradio output textbox).
        """
        model_group = model_dict.get(model_name)
        if model_group is None:
            # Evict any previously loaded model before loading the new one,
            # so only one model occupies memory at a time.
            model_dict.clear()
            model_dict[model_name] = {
                "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
                "tokenizer": M2M100Tokenizer.from_pretrained(model_name),
            }
            model_group = model_dict[model_name]

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]
        tokenizer.src_lang = src_lang

        # Translate one sentence at a time to keep each generate() input short.
        translated = []
        for sentence in nltk.sent_tokenize(src_text):
            encoded_src = tokenizer(sentence, return_tensors="pt")
            generated_tokens = model.generate(
                **encoded_src,
                forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
            )
            translated.append(
                tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            )

        # FIX: the original concatenated sentences with no separator
        # ("Hello.World.") and also assigned `output.value = result`, which
        # does not update a Gradio component at request time — returning the
        # string is what feeds the output textbox.
        return " ".join(translated)

    title = "Multilingual Machine Translation"

    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.

[Languages covered](https://huggingface.co/facebook/m2m100_418M#languages-covered)
"""

    examples = [
        [
            "Hello world!",
            "en",
            "zh",
            "facebook/m2m100_418M",
        ],
    ]

    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B",
    ]

    inputs = [
        gr.Textbox(lines=4, placeholder="text", label="Input Text"),
        gr.Textbox(lines=1, value="en", label="Source Language"),
        gr.Textbox(lines=1, value="zh", label="Target Language"),
        gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name"),
    ]

    output = gr.Textbox(lines=4, label="Output Text")

    demo = gr.Interface(
        fn=multilingual_translation,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=False,
    )
    # FIX: `enable_queue` was removed from `launch()` in modern Gradio
    # (it raises TypeError on 4.x); `demo.queue()` already enables the queue.
    demo.queue().launch(debug=True)
    return


if __name__ == '__main__':
    main()