# Scrape residue from the HuggingFace Spaces web UI (author avatar, commit
# message, commit hash, "raw / history / blame" links, file size) — commented
# out so the file is valid Python; kept for provenance:
# qgyd2021's picture
# [update]add sent_tokenize
# ec314ff
# raw
# history blame
# 3.39 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
from project_settings import project_path
hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
import gradio as gr
import nltk
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers.generation.streamers import TextIteratorStreamer
def main():
    """Launch a Gradio demo that translates text with Facebook's M2M100 models.

    Builds a simple Interface (text in, text out) whose handler splits the
    input into sentences, translates each one, and concatenates the results.
    Blocks on ``launch()`` until the server is stopped.
    """
    # nltk.sent_tokenize needs the "punkt" tokenizer data; fetch it once if
    # it is missing so the first translation request does not raise LookupError.
    try:
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt")

    # Cache of the currently loaded model/tokenizer, keyed by model name.
    # Only one entry is kept alive at a time to bound memory usage.
    model_dict = {
        "facebook/m2m100_418M": {
            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
        }
    }

    def multilingual_translation(src_text: str,
                                 src_lang: str,
                                 tgt_lang: str,
                                 model_name: str,
                                 ) -> str:
        """Translate ``src_text`` from ``src_lang`` to ``tgt_lang`` sentence by sentence.

        :param src_text: text to translate (may contain multiple sentences).
        :param src_lang: source language code, e.g. "en".
        :param tgt_lang: target language code, e.g. "zh".
        :param model_name: HF hub id of the M2M100 checkpoint to use.
        :return: translated sentences concatenated without a separator
            (matches the original behavior; fine for zh, en output may
            lack inter-sentence spaces).
        """
        model_group = model_dict.get(model_name)
        if model_group is None:
            # Evict whatever model is currently loaded before loading the
            # requested one, so at most one checkpoint is in memory.
            model_dict.clear()
            model_dict[model_name] = {
                "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
                "tokenizer": M2M100Tokenizer.from_pretrained(model_name)
            }
            model_group = model_dict[model_name]

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]
        tokenizer.src_lang = src_lang

        # Translate per sentence to stay within the model's sequence-length
        # limit; collect pieces and join once instead of quadratic str +=.
        pieces = []
        for sentence in nltk.sent_tokenize(src_text):
            encoded_src = tokenizer(sentence, return_tensors="pt")
            generated_tokens = model.generate(
                **encoded_src,
                # Force the decoder to start in the target language.
                forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
            )
            text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            pieces.append(text_decoded[0])
        # NOTE: the original also did `output.value = result` here; mutating a
        # Gradio component's .value after construction does not update the UI
        # (the returned value is what fills the output box), so it was removed.
        return "".join(pieces)

    title = "Multilingual Machine Translation"

    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.

[Languages covered](https://huggingface.co/facebook/m2m100_418M#languages-covered)
"""

    examples = [
        [
            "Hello world!",
            "en",
            "zh",
            "facebook/m2m100_418M",
        ],
    ]

    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B"
    ]

    inputs = [
        gr.Textbox(lines=4, placeholder="text", label="Input Text"),
        gr.Textbox(lines=1, value="en", label="Source Language"),
        gr.Textbox(lines=1, value="zh", label="Target Language"),
        gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name")
    ]

    output = gr.Textbox(lines=4, label="Output Text")

    demo = gr.Interface(
        fn=multilingual_translation,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=False
    )
    # .queue() already enables the queue; the deprecated enable_queue kwarg
    # (removed in Gradio 4, where it raises TypeError) was dropped.
    demo.queue().launch(debug=True)
    return
if __name__ == '__main__':
main()