Spaces:
Sleeping
Sleeping
File size: 4,891 Bytes
8b092c8 37d9263 f6ee9aa 8b092c8 0ab6a79 8b092c8 cec6612 8b092c8 37d9263 8b092c8 37d9263 8b092c8 0ab6a79 37d9263 0ab6a79 37d9263 0ab6a79 37d9263 0ab6a79 37d9263 447af8d cec6612 8b092c8 cec6612 8b092c8 cec6612 8b092c8 cec6612 8b092c8 cec6612 8b092c8 0ab6a79 cec6612 d44d8ad 8b092c8 cec6612 0ab6a79 cec6612 0ab6a79 cec6612 0ab6a79 f2a145b 0ab6a79 cec6612 0ab6a79 cec6612 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import streamlit as st
import pandas as pd
from pathlib import Path
#from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from transformers import M2M100ForConditionalGeneration
from tokenization_small100 import SMALL100Tokenizer
# Page-level settings: browser-tab title, emoji icon and wide layout.
st.set_page_config(
    layout="wide",
    page_icon=":milky_way:",
    page_title="Translation Demo",
)
def get_translation(src_code, trg_code, src):
    """Translate ``src`` into the language identified by ``trg_code``.

    Relies on the module-level ``tokenizer`` and ``model``. Only the target
    language is configured on the tokenizer; ``src_code`` is accepted for
    interface compatibility but is not used by this implementation.

    Returns whatever ``tokenizer.batch_decode`` produces for the generated
    tokens (callers in this file take element ``[0]``).
    """
    # Side effect: the target language is stored on the shared tokenizer.
    tokenizer.tgt_lang = trg_code
    model_inputs = tokenizer(src, return_tensors="pt")
    output_ids = model.generate(**model_inputs)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
def open_input(the_file):
    """Parse an uploaded spreadsheet-like file into a DataFrame.

    Args:
        the_file: A file-like object exposing a ``name`` attribute; the
            extension of ``name`` selects the parser.

    Returns:
        The parsed ``pandas.DataFrame``: ``.tsv`` files are read as
        tab-separated text, ``.xlsx`` files via ``pandas.read_excel``.

    Raises:
        ValueError: If the file name ends in neither ``.tsv`` nor ``.xlsx``.
    """
    # Compare case-insensitively so that e.g. "DATA.TSV" is accepted too.
    file_name = the_file.name.lower()
    if file_name.endswith('.tsv'):
        return pd.read_csv(the_file, sep="\t")
    if file_name.endswith('.xlsx'):
        return pd.read_excel(the_file)
    # Previously this path fell through to an unbound local variable.
    raise ValueError(
        f"Unsupported file type: {the_file.name!r} (expected .tsv or .xlsx)"
    )
st.subheader("SMALL-100 Translator")

# Default text shown in the "Source" box, and the (initially empty) translation.
source = "In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move."
target = ""


def _cached_resource(loader):
    """Wrap ``loader`` with Streamlit's resource cache when one is available.

    Falls back to the older ``experimental_singleton`` name, and to the
    undecorated loader if this Streamlit version offers neither.
    """
    decorator = (
        getattr(st, "cache_resource", None)
        or getattr(st, "experimental_singleton", None)
    )
    return loader if decorator is None else decorator(loader)


@_cached_resource
def _load_model():
    """Load the SMALL-100 checkpoint once instead of on every script rerun."""
    return M2M100ForConditionalGeneration.from_pretrained("alirezamsh/small100")


model = _load_model()
# The tokenizer is deliberately NOT cached: get_translation() assigns
# tokenizer.tgt_lang, so a shared instance would let concurrent sessions
# overwrite each other's target language.
tokenizer = SMALL100Tokenizer.from_pretrained("alirezamsh/small100")

# Language codes offered in the UI.
valid_languages = ['de', 'it', 'en']
valid_languages_tuple = tuple(valid_languages)
valid_languages_tuple_trg = tuple(valid_languages)
# Single-text translation form: pick languages, enter text, submit.
with st.form("my_form"):
    # NOTE: the two columns are created but not used for placement; the
    # language selectors below render stacked at form level.
    left_c, right_c = st.columns(2)
    src_lang = st.selectbox(
        'Source language',
        valid_languages_tuple,
    )
    trg_lang = st.selectbox(
        'Target language',
        valid_languages_tuple_trg,
    )
    source = st.text_area("Source", value=source, height=130, placeholder="Enter the source text...")
    submitted = st.form_submit_button("Translate")

    if submitted:
        # Whitespace-only input is treated as empty.
        if source.strip() and src_lang in valid_languages and trg_lang in valid_languages:
            with st.spinner("Translating..."):
                # Only the translation call is guarded; a bare ``except:``
                # would also trap BaseException subclasses and hide the cause.
                try:
                    target = get_translation(src_lang, trg_lang, source)[0]
                except Exception as exc:
                    st.subheader("Translation failed :sad:")
                    st.error(f"{type(exc).__name__}: {exc}")
                else:
                    st.subheader("Translation done!")
                    target = st.text_area("Target", value=target, height=130)
        else:
            st.write("Please enter the source text, source language and target language.")
# Batch mode: upload a TSV/Excel file and translate one of its columns.
st.subheader('Input Excel/TSV')
uploaded_file = st.file_uploader("Choose a file")
done = False  # becomes True only once every row has been translated

if uploaded_file is not None:
    data = open_input(uploaded_file)
    st.subheader("DataFrame")
    st.write(data)
    st.write(data.describe())

    src_col = st.selectbox(
        'Select the column to translate:',
        list(data.columns),
    )
    if src_col is not None:
        col_src_lang = st.selectbox(
            'Source language:',
            list(valid_languages),
        )
        col_trg_lang = st.selectbox(
            'Target language:',
            list(valid_languages),
        )
        submitted_cols = st.button("Translate column")
        if submitted_cols:
            # The language check does not depend on the row, so do it once.
            if col_src_lang not in valid_languages or col_trg_lang not in valid_languages:
                st.write("Please enter the source text, source language and target language.")
            else:
                translated_data = []
                failed = False
                with st.spinner("Translating..."):
                    for text in data[src_col]:
                        # Empty or non-text cells (NaN, numbers, ...) are kept
                        # unchanged so the result stays aligned with the rows.
                        if not isinstance(text, str) or not text:
                            translated_data.append(text)
                            continue
                        try:
                            translated_data.append(
                                get_translation(col_src_lang, col_trg_lang, text)[0]
                            )
                        except Exception as exc:
                            st.subheader("Translation failed :sad:")
                            st.error(f"{type(exc).__name__}: {exc}")
                            failed = True
                            break
                # Assign only a complete result: a partial list would not
                # match the frame's length and the assignment would raise.
                if not failed:
                    new_df = data.copy()  # leave the uploaded frame untouched
                    new_df[src_col] = translated_data
                    done = True

    if done:
        st.subheader("Translated DataFrame")
        st.write(new_df)
        st.write(new_df.describe())
        to_dl = new_df.to_csv(index=False, sep='\t').encode('utf-8')
        st.download_button('Download TSV', to_dl, 'translated_file.tsv', 'text/tsv', key='download-tsv')
else:
    st.info("☝️ Upload a TSV file")
|