#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os

from project_settings import project_path

# Point the Hugging Face Hub cache at a project-local directory. This has to
# happen before transformers is imported, because the cache location is read
# at import time.
hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache

import gradio as gr
import nltk
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers.generation.streamers import TextIteratorStreamer
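
# nltk.sent_tokenize needs the NLTK "punkt" tokenizer data; fetch it on first
# run if it is missing (assumption: the runtime may not ship with NLTK data
# pre-installed).
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")
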

def main():
    # Pre-load the default checkpoint so the first request does not pay the
    # download and initialization cost.
    model_dict = {
        "facebook/m2m100_418M": {
            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M"),
        }
    }

    def fn_non_stream(src_text: str,
                      src_lang: str,
                      tgt_lang: str,
                      model_name: str,
                      ):
        model_group = model_dict.get(model_name)
        if model_group is None:
            # Keep at most one model in memory: evict the cached checkpoint,
            # then load the requested one.
            for k in list(model_dict.keys()):
                del model_dict[k]
            model_dict[model_name] = {
                "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
                "tokenizer": M2M100Tokenizer.from_pretrained(model_name),
            }
            model_group = model_dict[model_name]

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]
        tokenizer.src_lang = src_lang

        # Split the input into sentences and translate them one at a time,
        # so each chunk stays well within the model's input length.
        src_t_list = nltk.sent_tokenize(src_text)

        result = ""
        for src_t in src_t_list:
            encoded_src = tokenizer(src_t, return_tensors="pt")
            generated_tokens = model.generate(
                **encoded_src,
                forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
            )
            text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            result += text_decoded[0] + " "

        # The returned string is what populates the output Textbox.
        return result.strip()
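
    def fn_stream(src_text: str,
                  src_lang: str,
                  tgt_lang: str,
                  model_name: str,
                  ):
        # A minimal, hypothetical streaming counterpart to fn_non_stream,
        # sketched here to show what the otherwise-unused TextIteratorStreamer
        # import is for. It is not wired into the Interface below, and it
        # assumes the requested model is already loaded in model_dict.
        from threading import Thread

        model_group = model_dict[model_name]
        model = model_group["model"]
        tokenizer = model_group["tokenizer"]
        tokenizer.src_lang = src_lang

        result = ""
        for src_t in nltk.sent_tokenize(src_text):
            encoded_src = tokenizer(src_t, return_tensors="pt")
            streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
            # generate() runs in a background thread and feeds the streamer,
            # which this generator drains piece by piece.
            thread = Thread(
                target=model.generate,
                kwargs=dict(
                    **encoded_src,
                    streamer=streamer,
                    forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
                ),
            )
            thread.start()
            # Yield partial results so Gradio can render them incrementally.
            for piece in streamer:
                result += piece
                yield result
            thread.join()
            result += " "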

    title = "Multilingual Machine Translation"

    description = """
M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.

[Languages covered](https://huggingface.co/facebook/m2m100_418M#languages-covered)
"""

    examples = [
        [
            "Hello world!",
            "en",
            "zh",
            "facebook/m2m100_418M",
        ],
    ]

    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B",
    ]

    inputs = [
        gr.Textbox(lines=4, placeholder="text", label="Input Text"),
        gr.Textbox(lines=1, value="en", label="Source Language"),
        gr.Textbox(lines=1, value="zh", label="Target Language"),
        gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="Model Name"),
    ]
    output = gr.Textbox(lines=4, label="Output Text")

    demo = gr.Interface(
        fn=fn_non_stream,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=False,
    )
    # .queue() already enables queueing; the legacy enable_queue launch flag
    # has been removed from launch() in recent Gradio releases.
    demo.queue().launch(debug=True)
    return
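
# Typical usage (assuming this script is the app entry point, e.g. main.py):
#   python3 main.py
# Gradio serves the demo on http://127.0.0.1:7860 by default.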

if __name__ == '__main__':
    main()