Spaces:
Sleeping
Sleeping
File size: 6,617 Bytes
8b092c8 37d9263 f6ee9aa 312ff3c 8b092c8 0ab6a79 8b092c8 2ed25e9 f690e46 68ea804 cec6612 8b092c8 37d9263 68ea804 37d9263 8b092c8 37d9263 8b092c8 0ab6a79 312ff3c 0ab6a79 312ff3c 0ab6a79 312ff3c f17c002 312ff3c 9831833 f17c002 312ff3c f17c002 312ff3c 0ab6a79 50f7072 312ff3c 50f7072 312ff3c 50f7072 312ff3c 0ab6a79 37d9263 0ab6a79 37d9263 0ab6a79 f690e46 37d9263 0ab6a79 37d9263 2a4bf26 447af8d cec6612 8b092c8 cec6612 8b092c8 cec6612 8b092c8 cec6612 8b092c8 cec6612 8b092c8 312ff3c cec6612 d44d8ad cec6612 312ff3c f17c002 312ff3c f17c002 312ff3c 49f070a 0ab6a79 312ff3c 50f7072 312ff3c 0ab6a79 312ff3c 49f070a cec6612 312ff3c cec6612 f17c002 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
import streamlit as st
import pandas as pd
from pathlib import Path
#from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from transformers import M2M100ForConditionalGeneration
from tokenization_small100 import SMALL100Tokenizer
import io
# Page-wide layout/title settings. Kept as the very first Streamlit command,
# which older Streamlit releases require for set_page_config.
st.set_page_config(
    layout="wide",
    page_title="Translation Demo",
    page_icon=":milky_way:",
)
# Caching decorator that keeps a single model instance alive across reruns.
# `st.cache_resource` (newer Streamlit) is intended for global resources such
# as ML models. On older releases fall back to `st.cache` with
# `allow_output_mutation=True`, which skips hashing the returned model on
# every call (plain `@st.cache` hashes return values to detect mutation).
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def load_model():
    """Load the SMALL-100 seq2seq model (cached across Streamlit reruns).

    Returns:
        The ``M2M100ForConditionalGeneration`` instance loaded from the
        ``alirezamsh/small100`` checkpoint.
    """
    return M2M100ForConditionalGeneration.from_pretrained("alirezamsh/small100")
def get_translation(src_code, trg_code, src):
    """Translate ``src`` into the language ``trg_code`` with SMALL-100.

    Args:
        src_code: source language code. It is accepted for interface
            compatibility but not used here; only the tokenizer's target
            language is set.
        trg_code: target language code, assigned to ``tokenizer.tgt_lang``.
        src: text to translate.

    Returns:
        The list produced by ``tokenizer.batch_decode`` (callers take ``[0]``).

    Note:
        Relies on the module-level ``tokenizer`` global and mutates its
        ``tgt_lang`` attribute.
    """
    translator = load_model()
    tokenizer.tgt_lang = trg_code
    batch = tokenizer(src, return_tensors="pt")
    return tokenizer.batch_decode(
        translator.generate(**batch),
        skip_special_tokens=True,
    )
def open_input(the_file):
    """Parse an uploaded TSV or XLSX file into pandas objects.

    Args:
        the_file: readable file-like object with a ``name`` attribute
            (e.g. a Streamlit upload). The extension check ignores case.

    Returns:
        ``(parsed, sheets)``:
        - TSV or single-sheet workbook: ``parsed`` is a DataFrame and
          ``sheets`` is ``[]``.
        - Multi-sheet workbook: ``parsed`` is a list of DataFrames, one per
          sheet, in the same order as ``sheets`` (the sheet names).

    Raises:
        ValueError: if the file name ends with neither ``.tsv`` nor ``.xlsx``.
    """
    lowered = the_file.name.lower()
    if lowered.endswith('.tsv'):
        return pd.read_csv(the_file, sep="\t"), []
    if lowered.endswith('.xlsx'):
        # Parse the workbook once and read every sheet from that handle
        # instead of re-reading the (already consumed) upload stream.
        xlsx = pd.ExcelFile(the_file)
        sheets = list(xlsx.sheet_names)
        if len(sheets) > 1:
            return [pd.read_excel(xlsx, sheet_name=sheet) for sheet in sheets], sheets
        return pd.read_excel(xlsx), []
    raise ValueError(
        f"Unsupported file type: {the_file.name!r} (expected .tsv or .xlsx)"
    )
def translate_data(df, s_lang, t_lang, col_for_translation, languages):
    """Translate one column of a DataFrame row by row with SMALL-100.

    Args:
        df: input table. It is copied, never modified in place.
        s_lang: source language code; must be in ``languages``.
        t_lang: target language code; must be in ``languages``.
        col_for_translation: label of the column whose cells are translated.
        languages: allowed language codes.

    Returns:
        A copy of ``df`` with an added ``"SMALL-100 translation"`` column.
        The column always has one entry per row. It is ``""`` for cells that
        are not strings (e.g. NaN) or are blank, for every row when the
        languages are invalid, and for all rows after a failed translation.
    """
    new_df = df.copy()
    # Pre-size so the new column always matches the frame length, even when
    # cells are skipped or translation stops early (a shorter list would
    # make the column assignment below raise).
    translated_data = [""] * len(new_df)
    if s_lang not in languages or t_lang not in languages:
        st.write("Please enter the source text, source language and target language.")
    else:
        # One spinner for the whole column rather than one per row.
        with st.spinner("Translating..."):
            for row_pos, text in enumerate(new_df[col_for_translation]):
                # NaN, numeric or blank cells: nothing to translate.
                if not isinstance(text, str) or not text.strip():
                    continue
                try:
                    translated_data[row_pos] = get_translation(s_lang, t_lang, text)[0]
                except Exception:
                    # `Exception` (not a bare except) so BaseException-based
                    # control flow such as KeyboardInterrupt is not swallowed.
                    st.subheader("Translation failed :sad:")
                    break
    new_df["SMALL-100 translation"] = translated_data
    return new_df
def select_column(data, valid_source, valid_target, is_excel=False):
    """Render the column and language pickers for an uploaded table.

    Args:
        data: a DataFrame, or (when ``is_excel`` is true) a list of
            DataFrames whose first element supplies the column choices.
        valid_source: iterable of language codes offered as source languages.
        valid_target: iterable of language codes offered as target languages.
        is_excel: whether ``data`` is the list of sheets from a multi-sheet
            workbook.

    Returns:
        ``(submitted_cols, src_col, col_src_lang, col_trg_lang)``. When no
        column can be selected (the selectbox returns ``None``), the result
        is ``(False, None, None, None)`` and no language pickers are shown.
    """
    # Only the first sheet's columns are offered; the other sheets are
    # expected to use the same column names (see the warning text).
    frame = data[0] if is_excel else data
    src_col = st.selectbox(
        'Select the column to translate (WARNING: You can only select a single column - please make sure all columns are named accordingly):',
        list(frame.columns),
    )
    # Defaults so the return statement is valid even with no selectable column.
    col_src_lang = None
    col_trg_lang = None
    submitted_cols = False
    # Compare with None rather than truthiness: a column labelled 0 or ""
    # is still a valid selection.
    if src_col is not None:
        col_src_lang = st.selectbox(
            'Source language:',
            valid_source,
        )
        col_trg_lang = st.selectbox(
            'Target language:',
            valid_target,
        )
        submitted_cols = st.button("Translate column")
    return submitted_cols, src_col, col_src_lang, col_trg_lang
st.subheader("SMALL-100 Translator")

# Initial contents of the "Source" text area (overwritten by the form below).
source = (
    "In the beginning the Universe was created. "
    "This has made a lot of people very angry and been widely regarded as a bad move."
)
target = ""

# Module-level tokenizer: get_translation() reads this global and sets its
# `tgt_lang` before each call.
tokenizer = SMALL100Tokenizer.from_pretrained("alirezamsh/small100")

# Language codes offered by the selectors; the chosen target code is
# assigned to `tokenizer.tgt_lang`.
valid_languages = ['de', 'it', 'en', 'fr', 'sw', 'wo']
# Separate option sequences for the source and target selectboxes.
valid_languages_tuple = tuple(valid_languages)
valid_languages_tuple_trg = tuple(valid_languages)
# Free-text translation form: the model only runs when "Translate" is pressed.
with st.form("my_form"):
    # Two-column layout is created but currently unused (the `with left_c:` /
    # `with right_c:` wrappers were disabled), so the selectors stack vertically.
    left_c, right_c = st.columns(2)
    src_lang = st.selectbox(
        'Source language',
        valid_languages_tuple,
    )
    trg_lang = st.selectbox(
        'Target language',
        valid_languages_tuple_trg,
    )
    source = st.text_area("Source", value=source, height=130, placeholder="Enter the source text...")
    submitted = st.form_submit_button("Translate")

if submitted:
    # Whitespace-only input is treated as empty: there is nothing to translate.
    if source.strip() and src_lang in valid_languages and trg_lang in valid_languages:
        with st.spinner("Translating..."):
            try:
                target = get_translation(src_lang, trg_lang, source)[0]
            except Exception:
                # Not a bare `except:`; that would also catch BaseException
                # subclasses such as KeyboardInterrupt and Streamlit's
                # script-control (rerun/stop) exceptions.
                st.subheader("Translation failed :sad:")
            else:
                st.subheader("Translation done!")
                target = st.text_area("Target", value=target, height=130)
    else:
        st.write("Please enter the source text, source language and target language.")
# Batch translation of one column of an uploaded TSV/XLSX file.
st.subheader('Input XLSX/TSV')
uploaded_file = st.file_uploader("Choose a file")
done = False
if uploaded_file is not None:
    data, sheets = open_input(uploaded_file)
    if len(sheets) > 0:
        # Multi-sheet workbook: `data` is a list of DataFrames, one per sheet.
        translated_sheets = []
        submitted_cols, src_col, src_code, trg_code = select_column(
            data, list(valid_languages), list(valid_languages), is_excel=True
        )
        if submitted_cols:
            for sheet_name, sheet in zip(sheets, data):
                # The column list came from the first sheet only; export
                # sheets without that column unchanged instead of crashing.
                if src_col in sheet.columns:
                    translated_sheets.append(
                        translate_data(sheet, src_code, trg_code, src_col, valid_languages)
                    )
                else:
                    st.warning(f"Sheet '{sheet_name}' has no column '{src_col}'; it is exported untranslated.")
                    translated_sheets.append(sheet)
            done = True
    else:
        # TSV or single-sheet workbook: `data` is one DataFrame.
        submitted_cols, src_col, src_code, trg_code = select_column(
            data, list(valid_languages), list(valid_languages)
        )
        st.subheader("DataFrame")
        st.write(data)
        # describe() raises on a frame without columns.
        if len(data.columns) > 0:
            st.write(data.describe())
        if submitted_cols:
            new_df = translate_data(data, src_code, trg_code, src_col, valid_languages)
            done = True
    if done:
        st.subheader("Translated DataFrame")
        if len(sheets) > 0:
            buffer = io.BytesIO()
            with pd.ExcelWriter(buffer) as writer:
                for sheet_name, sheet in zip(sheets, translated_sheets):
                    # index=False to match the TSV export below.
                    sheet.to_excel(writer, sheet_name=sheet_name, index=False)
            st.download_button('Download XLSX', buffer.getvalue(), 'translated_file.xlsx', 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', key='download-xlsx')
        else:
            st.write(new_df)
            st.write(new_df.describe())
            to_dl = new_df.to_csv(index=False, sep='\t').encode('utf-8')
            st.download_button('Download TSV', to_dl, 'translated_file.tsv', 'text/tab-separated-values', key='download-tsv')
else:
    st.info("☝️ Upload a TSV file")