File size: 6,617 Bytes
8b092c8
 
 
37d9263
 
f6ee9aa
312ff3c
 
8b092c8
0ab6a79
8b092c8
2ed25e9
f690e46
68ea804
 
cec6612
8b092c8
 
37d9263
 
 
 
 
 
 
68ea804
37d9263
8b092c8
37d9263
8b092c8
 
 
 
0ab6a79
 
 
312ff3c
 
0ab6a79
 
312ff3c
0ab6a79
312ff3c
 
 
 
 
 
 
 
 
 
f17c002
312ff3c
9831833
 
f17c002
312ff3c
 
f17c002
312ff3c
 
 
 
 
 
 
 
 
 
0ab6a79
 
50f7072
312ff3c
 
 
 
 
 
 
 
 
 
 
 
 
 
50f7072
312ff3c
 
 
50f7072
312ff3c
 
 
 
 
0ab6a79
37d9263
0ab6a79
 
 
37d9263
 
0ab6a79
f690e46
37d9263
0ab6a79
37d9263
2a4bf26
447af8d
cec6612
 
 
8b092c8
 
 
cec6612
 
8b092c8
cec6612
8b092c8
cec6612
 
8b092c8
cec6612
8b092c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312ff3c
cec6612
 
d44d8ad
cec6612
312ff3c
 
 
 
 
 
f17c002
312ff3c
 
 
f17c002
312ff3c
49f070a
0ab6a79
312ff3c
50f7072
312ff3c
 
 
 
 
0ab6a79
312ff3c
49f070a
cec6612
 
 
312ff3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cec6612
 
 
f17c002
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import streamlit as st
import pandas as pd
from pathlib import Path
#from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from transformers import M2M100ForConditionalGeneration
from tokenization_small100 import SMALL100Tokenizer
import io 


# Browser-tab title and icon; "wide" lets the page use the full window width.
st.set_page_config(
    layout="wide",
    page_title="Translation Demo",
    page_icon=":milky_way:",
)

# ``st.cache`` is deprecated, and without ``allow_output_mutation`` it hashes
# the returned object on every call to detect mutation -- expensive (or
# failing outright) for a large model. Prefer ``st.cache_resource`` where this
# Streamlit version provides it; otherwise fall back to ``st.cache`` with the
# mutation check disabled.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the SMALL-100 translation model, cached so reruns reuse it.

    Returns:
        The ``M2M100ForConditionalGeneration`` checkpoint
        ``"alirezamsh/small100"``.
    """
    return M2M100ForConditionalGeneration.from_pretrained("alirezamsh/small100")

def get_translation(src_code, trg_code, src):
    """Translate ``src`` into the language ``trg_code`` with SMALL-100.

    The SMALL-100 tokenizer is configured only with the target language
    (``tgt_lang``); ``src_code`` is not used here but is kept so the
    caller-facing signature stays the same.

    Args:
        src_code: Source-language code (unused).
        trg_code: Target-language code, e.g. ``"de"``.
        src: Text to translate.

    Returns:
        The list produced by ``tokenizer.batch_decode``; callers in this file
        take element ``[0]``.
    """
    model = load_model()
    # Uses the module-level ``tokenizer`` created when the script runs.
    tokenizer.tgt_lang = trg_code
    encoded = tokenizer(src, return_tensors="pt")
    generated_tokens = model.generate(**encoded)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)


def open_input(the_file):
    """Parse an uploaded ``.tsv`` or ``.xlsx`` file into pandas DataFrame(s).

    Args:
        the_file: A file-like object with a ``name`` attribute (for example
            a Streamlit upload).

    Returns:
        ``(parsed, sheets)``. For a TSV or a single-sheet workbook,
        ``parsed`` is one DataFrame and ``sheets`` is ``[]``. For a
        multi-sheet workbook, ``sheets`` lists the sheet names and
        ``parsed`` is a list of DataFrames in the same order.

    Raises:
        ValueError: If the file name ends in neither ``.tsv`` nor ``.xlsx``.
    """
    sheets = []
    # Case-insensitive so names like "DATA.TSV" or "Book.XLSX" are accepted.
    filename = the_file.name.lower()

    if filename.endswith(".tsv"):
        parsed = pd.read_csv(the_file, sep="\t")
    elif filename.endswith(".xlsx"):
        # The context manager releases the workbook once parsing is done.
        with pd.ExcelFile(the_file) as xlsx:
            if len(xlsx.sheet_names) > 1:
                sheets = list(xlsx.sheet_names)
                parsed = [pd.read_excel(xlsx, sheet_name=sheet) for sheet in sheets]
            else:
                # Reuse the opened workbook instead of parsing the upload again.
                parsed = pd.read_excel(xlsx)
    else:
        # Without this branch ``parsed`` would be unbound (UnboundLocalError).
        raise ValueError(
            f"Unsupported file type: {the_file.name!r} (expected .tsv or .xlsx)"
        )

    return parsed, sheets


def translate_data(df, s_lang, t_lang, col_for_translation, languages):
    """Translate every cell of one DataFrame column with SMALL-100.

    Args:
        df: Input DataFrame; it is not modified.
        s_lang: Source-language code.
        t_lang: Target-language code.
        col_for_translation: Name of the column whose cells are translated.
        languages: Collection of language codes the app accepts.

    Returns:
        A copy of ``df`` with an added ``"SMALL-100 translation"`` column.
        Cells that are empty or not strings get ``None``. Rows not reached
        after a translation failure also get ``None``, so the new column
        always has one entry per row.
    """
    # Copy so the caller's DataFrame does not silently gain a column.
    new_df = df.copy()
    texts = df[col_for_translation]

    # Languages are the same for every row, so validate them once.
    if s_lang not in languages or t_lang not in languages:
        st.write("Please enter the source text, source language and target language.")
        new_df["SMALL-100 translation"] = [None] * len(texts)
        return new_df

    translated_data = []
    with st.spinner("Translating..."):
        for text in texts:
            # Blank cells are read by pandas as NaN (a float), so check the
            # type before taking len().
            if not isinstance(text, str) or len(text) == 0:
                translated_data.append(None)
                continue
            try:
                translated_data.append(get_translation(s_lang, t_lang, text)[0])
            except Exception:
                # ``Exception`` rather than a bare ``except`` so BaseException
                # control-flow signals (e.g. KeyboardInterrupt) propagate.
                st.subheader("Translation failed :sad:")
                break

    # Pad rows not reached after a failure so the length matches the index.
    translated_data.extend([None] * (len(texts) - len(translated_data)))
    new_df["SMALL-100 translation"] = translated_data
    return new_df


def select_column(data, valid_source, valid_target, is_excel=False):
    """Render widgets for choosing the column and languages to translate.

    Args:
        data: A DataFrame, or (when ``is_excel``) a list of DataFrames whose
            first element supplies the column names offered.
        valid_source: Options for the source-language selectbox.
        valid_target: Options for the target-language selectbox.
        is_excel: Whether ``data`` is a list of per-sheet DataFrames.

    Returns:
        ``(submitted, column, source_lang, target_lang)``. If the column
        selectbox yields ``None`` (nothing to select), the language widgets
        are not rendered and ``(False, None, None, None)`` is returned.
    """
    frame = data[0] if is_excel else data
    columns = list(frame.columns)

    src_col = st.selectbox(
        'Select the column to translate (WARNING: You can only select a single column - please make sure all columns are named accordingly):',
        columns,
    )

    # Defaults keep the return well-defined when no column is selected;
    # previously these names were unbound in that case.
    submitted_cols = False
    col_src_lang = None
    col_trg_lang = None

    # ``is not None`` rather than truthiness: a column may legitimately be
    # named 0 or "" (e.g. headerless files).
    if src_col is not None:
        col_src_lang = st.selectbox(
            'Source language:',
            valid_source,
        )
        col_trg_lang = st.selectbox(
            'Target language:',
            valid_target,
        )
        submitted_cols = st.button("Translate column")

    return submitted_cols, src_col, col_src_lang, col_trg_lang
    

st.subheader("SMALL-100 Translator")

# Default text pre-filled into the single-text form below.
source = "In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move."
target = ""

# Re-created each time the script runs rather than shared through a resource
# cache: ``get_translation`` assigns ``tokenizer.tgt_lang``, so a single
# shared instance could have its target language changed by a concurrent
# request.
tokenizer = SMALL100Tokenizer.from_pretrained("alirezamsh/small100")

# Language codes offered in every language selectbox.
valid_languages = ['de', 'it', 'en', 'fr', 'sw', 'wo']

# Real tuples (the originals were one-shot generators despite the names):
# they can be iterated more than once without silently becoming empty.
valid_languages_tuple = tuple(valid_languages)
valid_languages_tuple_trg = tuple(valid_languages)

# Single-text translation form: pick languages, edit the text, submit.
with st.form("my_form"):
    src_lang = st.selectbox(
        'Source language',
        valid_languages_tuple,
    )
    trg_lang = st.selectbox(
        'Target language',
        valid_languages_tuple_trg,
    )
    source = st.text_area("Source", value=source, height=130, placeholder="Enter the source text...")

    submitted = st.form_submit_button("Translate")
    if submitted:
        if source and src_lang in valid_languages and trg_lang in valid_languages:
            try:
                # Only the model call is guarded; UI output goes in ``else``.
                with st.spinner("Translating..."):
                    target = get_translation(src_lang, trg_lang, source)[0]
            except Exception as err:
                # ``Exception`` rather than a bare ``except`` so BaseException
                # control-flow signals (e.g. KeyboardInterrupt) propagate.
                st.subheader("Translation failed :sad:")
                st.error(f"{type(err).__name__}: {err}")
            else:
                st.subheader("Translation done!")
                target = st.text_area("Target", value=target, height=130)
        else:
            st.write("Please enter the source text, source language and target language.")


st.subheader('Input XLSX/TSV')
# Restrict the picker to the formats ``open_input`` can parse.
uploaded_file = st.file_uploader("Choose a file", type=["tsv", "xlsx"])
done = False

if uploaded_file is not None:
    data, sheets = open_input(uploaded_file)

    if len(sheets) > 0:
        # Multi-sheet workbook: ``data`` is a list of DataFrames, one per sheet.
        translated_sheets = []
        submitted_cols, src_col, src_code, trg_code = select_column(
            data, tuple(valid_languages), tuple(valid_languages), is_excel=True
        )

        if submitted_cols:
            translated_sheets = [
                translate_data(sheet, src_code, trg_code, src_col, valid_languages)
                for sheet in data
            ]
            done = True

    else:
        # TSV or single-sheet workbook: ``data`` is one DataFrame.
        submitted_cols, src_col, src_code, trg_code = select_column(
            data, tuple(valid_languages), tuple(valid_languages)
        )

        st.subheader("DataFrame")
        st.write(data)
        st.write(data.describe())

        if submitted_cols:
            new_df = translate_data(data, src_code, trg_code, src_col, valid_languages)
            done = True

if done:
    st.subheader("Translated DataFrame")

    if len(sheets) > 0:
        buffer = io.BytesIO()
        with pd.ExcelWriter(buffer) as writer:
            for sheet_name, sheet in zip(sheets, translated_sheets):
                # index=False matches the TSV export and avoids adding an
                # extra unnamed index column to every sheet.
                sheet.to_excel(writer, sheet_name=sheet_name, index=False)

        st.download_button(
            'Download XLSX',
            buffer.getvalue(),
            'translated_file.xlsx',
            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
            key='download-xlsx',
        )

    else:
        st.write(new_df)
        st.write(new_df.describe())
        to_dl = new_df.to_csv(index=False, sep='\t').encode('utf-8')
        st.download_button('Download TSV', to_dl, 'translated_file.tsv', 'text/tsv', key='download-tsv')

elif uploaded_file is None:
    # Prompt only while nothing has been uploaded yet.
    st.info("☝️ Upload a TSV or XLSX file")