Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
from pathlib import Path | |
#from transformers import MBartForConditionalGeneration, MBart50TokenizerFast | |
from transformers import M2M100ForConditionalGeneration | |
from tokenization_small100 import SMALL100Tokenizer | |
# Configure the browser tab (title and icon) and use the full-width layout.
# NOTE(review): Streamlit expects set_page_config to run before other st.*
# calls (strictly enforced in older releases) — keep it first in the script.
st.set_page_config(
    layout="wide",
    page_title="Translation Demo",
    page_icon=":milky_way:",
)
def get_translation(src_code, trg_code, src):
    """Translate *src* into the language *trg_code* with the SMALL-100 model.

    Uses the module-level ``model`` and ``tokenizer`` objects; it does not
    load anything itself.

    Args:
        src_code: Source-language code. Currently unused: only the target
            language is handed to the tokenizer. (SMALL-100 presumably does not
            need the source language — confirm against the model card.) Kept so
            existing callers do not change.
        trg_code: Target-language code such as ``'de'``; assigned to
            ``tokenizer.tgt_lang`` before encoding.
        src: The text to translate.

    Returns:
        list[str]: The decoded translations (callers in this app take ``[0]``).
    """
    # The tokenizer is stateful: the target language must be set before the
    # text is encoded so the right language token is produced.
    tokenizer.tgt_lang = trg_code
    encoded = tokenizer(src, return_tensors="pt")
    generated_tokens = model.generate(**encoded)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
def open_input(the_file):
    """Parse an uploaded table into a pandas DataFrame.

    Args:
        the_file: A file-like object that also has a ``name`` attribute (for
            example a Streamlit ``UploadedFile``). The extension of ``name``
            picks the parser and is matched case-insensitively.

    Returns:
        pandas.DataFrame: ``.tsv`` files are read as tab-separated text and
        ``.xlsx`` files with ``pandas.read_excel``.

    Raises:
        ValueError: If the file name ends in neither ``.tsv`` nor ``.xlsx``.
    """
    lowered_name = the_file.name.lower()
    if lowered_name.endswith('.tsv'):
        return pd.read_csv(the_file, sep="\t")
    if lowered_name.endswith('.xlsx'):
        return pd.read_excel(the_file)
    # Fail loudly instead of falling through with nothing parsed.
    raise ValueError(
        f"Unsupported file {the_file.name!r}: expected a .tsv or .xlsx file."
    )
st.subheader("SMALL-100 Translator")
source = "In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move."
target = ""

MODEL_NAME = "alirezamsh/small100"


@st.cache_resource
def load_model(model_name=MODEL_NAME):
    """Load the SMALL-100 seq2seq model once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction; without
    caching, the checkpoint would be loaded again on each rerun.
    """
    return M2M100ForConditionalGeneration.from_pretrained(model_name)


model = load_model()
# The tokenizer is intentionally NOT shared through st.cache_resource:
# get_translation() mutates tokenizer.tgt_lang on every call, so an instance
# shared between sessions could be changed by another user mid-request.
tokenizer = SMALL100Tokenizer.from_pretrained(MODEL_NAME)

valid_languages = ['de', 'it', 'en']
# Kept for code that refers to these names. They are real tuples now (they
# used to be one-shot generator expressions despite the name).
valid_languages_tuple = tuple(valid_languages)
valid_languages_tuple_trg = tuple(valid_languages)
# Single-text translation: pick the languages, edit the text, submit.
with st.form("my_form"):
    left_c, right_c = st.columns(2)
    with left_c:
        src_lang = st.selectbox(
            'Source language',
            valid_languages,
        )
    with right_c:
        trg_lang = st.selectbox(
            'Target language',
            valid_languages,
        )
    source = st.text_area("Source", value=source, height=130, placeholder="Enter the source text...")
    submitted = st.form_submit_button("Translate")

if submitted:
    # Reject empty and whitespace-only input before calling the model.
    if source.strip() and src_lang in valid_languages and trg_lang in valid_languages:
        with st.spinner("Translating..."):
            try:
                target = get_translation(src_lang, trg_lang, source)[0]
            except Exception as err:
                # Catch Exception rather than using a bare except, which would also
                # swallow BaseException-derived control flow (e.g. KeyboardInterrupt).
                st.subheader("Translation failed :sad:")
                st.error(f"{type(err).__name__}: {err}")
            else:
                st.subheader("Translation done!")
                target = st.text_area("Target", value=target, height=130)
    else:
        st.write("Please enter the source text, source language and target language.")
def _translate_column(texts, src_code, trg_code):
    """Translate every cell of *texts* and return the results as a list.

    The list has exactly one entry per input cell, so it can be assigned back
    to the DataFrame column without a length mismatch. Cells that are not
    non-blank strings (e.g. NaN from empty spreadsheet cells, or numbers) are
    copied through unchanged instead of being sent to the model.

    Any exception raised by get_translation() propagates to the caller.
    """
    translated = []
    for cell in texts:
        if isinstance(cell, str) and cell.strip():
            translated.append(get_translation(src_code, trg_code, cell)[0])
        else:
            translated.append(cell)
    return translated


st.subheader('Input Excel/TSV')
# Only offer the formats open_input() knows how to parse.
uploaded_file = st.file_uploader("Choose a file", type=["tsv", "xlsx"])
done = False
if uploaded_file is not None:
    data = open_input(uploaded_file)
    st.subheader("DataFrame")
    st.write(data)
    st.write(data.describe())
    src_col = st.selectbox(
        'Select the column to translate:',
        list(data.columns),
    )
    # Compare against None: a legitimate column label may be falsy (e.g. 0).
    if src_col is not None:
        col_src_lang = st.selectbox(
            'Source language:',
            valid_languages,
        )
        col_trg_lang = st.selectbox(
            'Target language:',
            valid_languages,
        )
        if st.button("Translate column"):
            # One spinner for the whole column instead of one per row.
            with st.spinner("Translating..."):
                try:
                    translated_data = _translate_column(
                        data[src_col], col_src_lang, col_trg_lang
                    )
                except Exception as err:
                    st.subheader("Translation failed :sad:")
                    st.error(f"{type(err).__name__}: {err}")
                else:
                    # Work on a copy so the parsed upload itself is not modified.
                    new_df = data.copy()
                    new_df[src_col] = translated_data
                    done = True
    if done:
        st.subheader("Translated DataFrame")
        st.write(new_df)
        st.write(new_df.describe())
        to_dl = new_df.to_csv(index=False, sep='\t').encode('utf-8')
        st.download_button('Download TSV', to_dl, 'translated_file.tsv', 'text/tsv', key='download-tsv')
else:
    st.info("☝️ Upload a TSV file")