|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
import torch |
|
from ui import title, description, examples |
|
from langs import LANGS |
|
|
|
|
|
# Hugging Face pipeline task name used when building the translation pipeline.
TASK = "translation"

# Single checkpoint identifier shared by the model and its tokenizer so the
# two are always loaded from the same source.
CKPT = "facebook/nllb-200-distilled-600M"

# Loaded once at import time; `translate` reuses these module-level objects
# for every request.
tokenizer = AutoTokenizer.from_pretrained(CKPT)
model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
|
|
|
|
|
|
|
|
|
# Shared translation pipeline, built lazily on first use by
# `_get_translation_pipeline` and reused for every subsequent call.
_translation_pipeline = None


def _get_translation_pipeline():
    """Return the shared translation pipeline, creating it on first call.

    The pipeline wraps the module-level ``model`` and ``tokenizer``. No
    language pair or generation length is baked in here; those are supplied
    per call by ``translate``, so one pipeline object serves every request.
    """
    global _translation_pipeline
    if _translation_pipeline is None:
        _translation_pipeline = pipeline(TASK, model=model, tokenizer=tokenizer)
    return _translation_pipeline


def translate(text, src_lang, tgt_lang, max_length=512):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

    Args:
        text: Source text to translate.
        src_lang: Source language code understood by the tokenizer
            (presumably an NLLB code such as ``"eng_Latn"`` taken from
            ``LANGS`` -- confirm against ``langs.py``).
        tgt_lang: Target language code, same format as ``src_lang``.
        max_length: Upper bound on generation length, forwarded to the
            pipeline's generation call. Coerced to ``int`` because UI
            sliders may deliver a float (e.g. ``512.0``).

    Returns:
        The translated string, or ``""`` when ``text`` is empty or
        whitespace-only.
    """
    # Nothing to translate: skip the model instead of letting it generate
    # output from an empty source.
    if not text or not text.strip():
        return ""

    # Call-time parameters override the shared pipeline's defaults, so the
    # same pipeline object handles any language pair without being rebuilt.
    result = _get_translation_pipeline()(
        text,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        max_length=int(max_length),
    )
    return result[0]["translation_text"]
|
|
|
|
|
# Web UI: free-text input, source/target language pickers and a
# generation-length slider, all wired to `translate`, with one text output.
demo = gr.Interface(
    fn=translate,
    inputs=[
        gr.components.Textbox(label="Text"),
        gr.components.Dropdown(label="Source Language", choices=LANGS),
        gr.components.Dropdown(label="Target Language", choices=LANGS),
        gr.components.Slider(8, 512, value=512, step=8, label="Max Length"),
    ],
    outputs=["text"],
    examples=examples,
    # Examples are run on demand rather than precomputed at startup.
    cache_examples=False,
    title=title,
    description=description,
)
demo.launch()
|
|