Spaces:
Sleeping
Sleeping
File size: 4,378 Bytes
8b092c8 377de90 8b092c8 0ab6a79 8b092c8 cec6612 8b092c8 0ab6a79 447af8d cec6612 8b092c8 cec6612 8b092c8 cec6612 8b092c8 cec6612 8b092c8 cec6612 8b092c8 0ab6a79 cec6612 d44d8ad 8b092c8 cec6612 0ab6a79 cec6612 0ab6a79 cec6612 0ab6a79 f2a145b 0ab6a79 cec6612 0ab6a79 cec6612 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import streamlit as st
import pandas as pd
from pathlib import Path
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
# Page-level settings: tab title, tab icon and the "wide" layout.
st.set_page_config(
    page_title="Translation Demo",
    page_icon=":milky_way:",
    layout="wide",
)
def get_translation(src_code, trg_code, src):
    """Translate ``src`` using the module-level ``model`` and ``tokenizer``.

    Args:
        src_code: Language code assigned to ``tokenizer.src_lang`` before the
            text is encoded.
        trg_code: Language code looked up in ``tokenizer.lang_code_to_id``; the
            resulting id is passed to ``model.generate`` as
            ``forced_bos_token_id``.
        src: Source text handed to the tokenizer.

    Returns:
        Whatever ``tokenizer.batch_decode`` returns for the generated tokens
        with special tokens skipped; callers in this file read element ``[0]``.
    """
    tokenizer.src_lang = src_code
    model_inputs = tokenizer(src, return_tensors="pt")
    target_bos_id = tokenizer.lang_code_to_id[trg_code]
    output_ids = model.generate(**model_inputs, forced_bos_token_id=target_bos_id)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
def open_input(the_file):
    """Parse an uploaded TSV or Excel file into a pandas DataFrame.

    Args:
        the_file: File-like object exposing a ``name`` attribute; the
            extension of ``name`` selects the parser.

    Returns:
        The DataFrame produced by ``pd.read_csv`` (tab separator) for ``.tsv``
        names or by ``pd.read_excel`` for ``.xlsx`` names.

    Raises:
        ValueError: If ``name`` ends with neither ``.tsv`` nor ``.xlsx``.
    """
    # Lower-case once so that names such as "DATA.TSV" are accepted too.
    file_name = the_file.name.lower()
    if file_name.endswith('.tsv'):
        return pd.read_csv(the_file, sep="\t")
    if file_name.endswith('.xlsx'):
        return pd.read_excel(the_file)
    # Previously this path fell through to `return parsed` with `parsed`
    # never assigned, which surfaced as an opaque UnboundLocalError.
    raise ValueError(
        f"Unsupported file type: {the_file.name!r} (expected .tsv or .xlsx)"
    )
st.subheader("MBART-50 Translator")
# Default text shown in the "Source" box, and a placeholder for the result.
source = "In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move."
target = ""

MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"

# Streamlit re-runs the whole script on every widget interaction, so an
# uncached `from_pretrained` would reload the model each time. Use
# `st.cache_resource` when this Streamlit version provides it; otherwise fall
# back to a no-op decorator, which keeps the previous (uncached) behaviour.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_translator():
    """Load the MBART-50 model and its tokenizer, in that order."""
    loaded_model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME)
    loaded_tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME)
    return loaded_model, loaded_tokenizer


# NOTE(review): with caching the tokenizer object is shared between reruns and
# sessions, and get_translation() sets `tokenizer.src_lang` on it before each
# call -- confirm this is acceptable if concurrent users are expected.
model, tokenizer = _load_translator()

# Language codes offered in the UI.
valid_languages = ['de_DE', 'en_XX', 'it_IT']
# Real tuples rather than one-shot generator expressions, so the option lists
# can be iterated more than once.
valid_languages_tuple = tuple(valid_languages)
valid_languages_tuple_trg = tuple(valid_languages)
# Single-text translation form: pick languages, enter text, press "Translate".
with st.form("my_form"):
    # NOTE(review): the two columns are created but the selectboxes are not
    # placed inside them -- the `with left_c:` / `with right_c:` lines below
    # were already commented out; confirm whether side-by-side layout is wanted.
    left_c, right_c = st.columns(2)
    #with left_c:
    src_lang = st.selectbox(
        'Source language',
        valid_languages_tuple,
    )
    #with right_c:
    trg_lang = st.selectbox(
        'Target language',
        valid_languages_tuple_trg,
    )
    source = st.text_area("Source", value=source, height=130, placeholder="Enter the source text...")
    submitted = st.form_submit_button("Translate")
    if submitted:
        # `strip()` also rejects input that consists only of whitespace.
        if source.strip() and src_lang in valid_languages and trg_lang in valid_languages:
            with st.spinner("Translating..."):
                try:
                    # Only the model call is guarded; UI calls stay outside.
                    target = get_translation(src_lang, trg_lang, source)[0]
                except Exception as err:
                    # `except Exception` instead of a bare `except:` so that
                    # BaseException-derived signals (KeyboardInterrupt,
                    # SystemExit, ...) are not swallowed; the error itself is
                    # shown rather than discarded.
                    st.subheader("Translation failed :sad:")
                    st.exception(err)
                else:
                    st.subheader("Translation done!")
                    target = st.text_area("Target", value=target, height=130)
        else:
            st.write("Please enter the source text, source language and target language.")
# Batch mode: upload a TSV/XLSX file, pick a column, translate every cell in
# it and offer the result as a TSV download.
st.subheader('Input Excel/TSV')
# Restrict uploads to the two extensions open_input() can parse.
uploaded_file = st.file_uploader("Choose a file", type=["tsv", "xlsx"])
done = False
if uploaded_file is not None:
    data = open_input(uploaded_file)
    st.subheader("DataFrame")
    st.write(data)
    st.write(data.describe())
    src_col = st.selectbox(
        'Select the column to translate:',
        list(data.columns),
    )
    # `is not None` rather than truthiness so that falsy column labels
    # (for example 0 or "") can still be selected.
    if src_col is not None:
        col_src_lang = st.selectbox(
            'Source language:',
            tuple(valid_languages),
        )
        col_trg_lang = st.selectbox(
            'Target language:',
            tuple(valid_languages),
        )
        submitted_cols = st.button("Translate column")
        if submitted_cols:
            # The language check does not depend on the row, so do it once.
            if col_src_lang in valid_languages and col_trg_lang in valid_languages:
                translated_data = []
                failed = False
                with st.spinner("Translating..."):
                    for text in data[src_col]:
                        # Keep empty, missing (NaN) or non-text cells as they
                        # are; this keeps the result aligned with the rows and
                        # avoids calling len()/the model on non-strings.
                        if not isinstance(text, str) or not text.strip():
                            translated_data.append(text)
                            continue
                        try:
                            target_text = get_translation(col_src_lang, col_trg_lang, text)[0]
                        except Exception as err:
                            st.subheader("Translation failed :sad:")
                            st.exception(err)
                            failed = True
                            break
                        translated_data.append(target_text)
                if not failed:
                    # Work on a copy so the uploaded frame is not mutated, and
                    # only assign when every row was processed (a partial list
                    # would raise a length-mismatch error).
                    new_df = data.copy()
                    new_df[src_col] = translated_data
                    done = True
            else:
                st.write("Please enter the source text, source language and target language.")
    if done:
        st.subheader("Translated DataFrame")
        st.write(new_df)
        st.write(new_df.describe())
        to_dl = new_df.to_csv(index=False, sep='\t').encode('utf-8')
        st.download_button('Download TSV', to_dl, 'translated_file.tsv', 'text/tsv', key='download-tsv')
else:
    st.info("☝️ Upload a TSV file")
|