File size: 4,378 Bytes
8b092c8
 
 
377de90
8b092c8
0ab6a79
8b092c8
cec6612
8b092c8
 
 
 
 
 
 
 
 
 
 
 
0ab6a79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447af8d
 
 
cec6612
 
 
8b092c8
 
 
cec6612
 
8b092c8
cec6612
8b092c8
cec6612
 
8b092c8
cec6612
8b092c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ab6a79
cec6612
 
d44d8ad
8b092c8
cec6612
 
 
0ab6a79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cec6612
0ab6a79
 
 
cec6612
0ab6a79
 
 
 
 
 
 
 
 
 
 
f2a145b
 
 
0ab6a79
 
cec6612
0ab6a79
 
cec6612
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import streamlit as st
import pandas as pd
from pathlib import Path
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

st.set_page_config(page_title="Translation Demo", page_icon=":milky_way:", layout="wide")


def get_translation(src_code, trg_code, src):
    """Translate ``src`` from language ``src_code`` into ``trg_code``.

    Uses the module-level ``tokenizer`` and ``model`` (MBART-50).

    Args:
        src_code: MBART-50 language code of the input text, e.g. ``"en_XX"``.
        trg_code: MBART-50 language code to translate into, e.g. ``"de_DE"``.
        src: the text to translate.

    Returns:
        The list produced by ``tokenizer.batch_decode`` for the generated
        sequences; callers in this file take element ``[0]``.

    Raises:
        KeyError: if ``trg_code`` is not in ``tokenizer.lang_code_to_id``.
    """
    # The source language is configured on the tokenizer itself, so it has
    # to be set before the text is encoded.
    tokenizer.src_lang = src_code
    target_bos_id = tokenizer.lang_code_to_id[trg_code]
    batch = tokenizer(src, return_tensors="pt")
    # Forcing the first generated token to the target-language id is what
    # selects the output language.
    output_ids = model.generate(**batch, forced_bos_token_id=target_bos_id)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)


def open_input(the_file):
    """Parse an uploaded tabular file into a DataFrame.

    The format is chosen from the file name's extension, compared
    case-insensitively: ``.tsv`` is read as tab-separated text and ``.xlsx``
    as an Excel workbook.

    Args:
        the_file: a readable file-like object with a ``name`` attribute,
            e.g. the object returned by ``st.file_uploader``.

    Returns:
        pandas.DataFrame with the file's contents.

    Raises:
        ValueError: if the extension is neither ``.tsv`` nor ``.xlsx``.
    """
    lowered_name = the_file.name.lower()
    if lowered_name.endswith('.tsv'):
        return pd.read_csv(the_file, sep="\t")
    if lowered_name.endswith('.xlsx'):
        return pd.read_excel(the_file)
    # Reject unknown formats explicitly instead of falling through with no
    # parsed result.
    raise ValueError(
        f"Unsupported file type: {the_file.name!r} (expected .tsv or .xlsx)"
    )


st.subheader("MBART-50 Translator")

source = "In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move."
target = ""

MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"


def _load_model():
    """Instantiate the MBART-50 many-to-many model from ``MODEL_NAME``."""
    return MBartForConditionalGeneration.from_pretrained(MODEL_NAME)


# Streamlit re-executes this whole script on every widget interaction, so an
# uncached load would rebuild the model on each rerun. Cache it when the
# installed Streamlit provides ``st.cache_resource``; otherwise fall back to
# loading per run. The tokenizer is deliberately NOT cached: get_translation()
# mutates ``tokenizer.src_lang``, which would be shared across sessions if the
# same object were reused.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_model = _cache_resource(_load_model)

model = _load_model()
tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME)


# MBART-50 language codes offered in the UI.
valid_languages = ['de_DE', 'en_XX', 'it_IT']

# Options for the form's source/target select boxes. Real tuples, so they can
# be iterated more than once.
valid_languages_tuple = tuple(valid_languages)
valid_languages_tuple_trg = tuple(valid_languages)

with st.form("my_form"):
    # Free-text translation form: pick languages, enter text, submit.
    src_lang = st.selectbox(
        'Source language',
        valid_languages_tuple,
    )
    trg_lang = st.selectbox(
        'Target language',
        valid_languages_tuple_trg,
    )
    source = st.text_area("Source", value=source, height=130, placeholder="Enter the source text...")

    submitted = st.form_submit_button("Translate")
    if submitted:
        # Whitespace-only input counts as empty.
        if source.strip() and src_lang in valid_languages and trg_lang in valid_languages:
            with st.spinner("Translating..."):
                try:
                    target = get_translation(src_lang, trg_lang, source)[0]
                except Exception as err:
                    # Show the cause instead of silently discarding it.
                    st.subheader("Translation failed :sad:")
                    st.error(f"{type(err).__name__}: {err}")
                else:
                    st.subheader("Translation done!")
                    target = st.text_area("Target", value=target, height=130)
        else:
            st.write("Please enter the source text, source language and target language.")


st.subheader('Input Excel/TSV')
# Only offer the formats open_input() knows how to parse.
uploaded_file = st.file_uploader("Choose a file", type=["tsv", "xlsx"])
done = False


if uploaded_file is not None:
    data = open_input(uploaded_file)
    st.subheader("DataFrame")
    st.write(data)
    st.write(data.describe())

    src_col = st.selectbox(
        'Select the column to translate:',
        tuple(data.columns),
    )

    # selectbox yields None only when there are no options (no columns); a
    # falsy column label such as 0 is still a valid choice.
    if src_col is not None:
        # The options are exactly valid_languages, so whatever is selected is
        # a supported MBART-50 code.
        col_src_lang = st.selectbox(
            'Source language:',
            tuple(valid_languages),
        )
        col_trg_lang = st.selectbox(
            'Target language:',
            tuple(valid_languages),
        )
        submitted_cols = st.button("Translate column")

        if submitted_cols:
            # Exactly one output value per input row, so the column can be
            # replaced wholesale without a length mismatch.
            translated_data = []
            failed = False
            with st.spinner("Translating..."):
                for text in data[src_col]:
                    # Empty cells are parsed by pandas as NaN (a float) and
                    # numeric cells are not strings; keep those, and blank
                    # strings, unchanged.
                    if not isinstance(text, str) or not text.strip():
                        translated_data.append(text)
                        continue
                    try:
                        translated_data.append(
                            get_translation(col_src_lang, col_trg_lang, text)[0]
                        )
                    except Exception as err:
                        st.subheader("Translation failed :sad:")
                        st.error(f"{type(err).__name__}: {err}")
                        failed = True
                        break

            # Only publish a result when every row was processed; work on a
            # copy so the uploaded frame is left untouched.
            if not failed:
                new_df = data.copy()
                new_df[src_col] = translated_data
                done = True

if done:
    st.subheader("Translated DataFrame")
    st.write(new_df)
    st.write(new_df.describe())
    to_dl = new_df.to_csv(index=False, sep='\t').encode('utf-8')
    st.download_button('Download TSV', to_dl, 'translated_file.tsv', 'text/tsv', key='download-tsv')
elif uploaded_file is None:
    # Prompt only while nothing has been uploaded yet.
    st.info("☝️ Upload a TSV or Excel (.xlsx) file")