nisheeth committed
Commit 37b69a5 · verified · 1 parent: f93e55e

Upload 9 files

Files changed (9)
  1. .gitattributes +8 -1
  2. .gitignore.txt +2 -0
  3. README.md +5 -5
  4. app.py +52 -103
  5. langs.py +8 -0
  6. langs_all.py +204 -0
  7. requirements.txt +2 -6
  8. ui.cpython-310.pyc +0 -0
  9. ui.py +12 -0
.gitattributes CHANGED
@@ -2,20 +2,27 @@
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
@@ -23,5 +30,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
-*.zstandard filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore.txt ADDED
@@ -0,0 +1,2 @@
+.devcontainer/*
+ui.cpython-310.pyc
README.md CHANGED
@@ -1,13 +1,13 @@
 ---
-title: Nllb Translation Demo
-emoji: 👀
+title: NLLB200 Translate Distill 600
+emoji: 🐢
 colorFrom: indigo
-colorTo: green
+colorTo: yellow
 sdk: gradio
-sdk_version: 3.0.26
+sdk_version: 3.18.0
 app_file: app.py
 pinned: false
-duplicated_from: Geonmo/nllb-translation-demo
+license: mit
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,106 +1,55 @@
-import os
-import torch
 import gradio as gr
-import time
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
-from flores200_codes import flores_codes
-
-
-def load_models():
-    # build model and tokenizer
-    model_name_dict = {#'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
-                       'nllb-1.3B': 'facebook/nllb-200-1.3B',
-                       #'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
-                       #'nllb-3.3B': 'facebook/nllb-200-3.3B',
-                       }
-
-    model_dict = {}
-
-    for call_name, real_name in model_name_dict.items():
-        print('\tLoading model: %s' % call_name)
-        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
-        tokenizer = AutoTokenizer.from_pretrained(real_name)
-        model_dict[call_name+'_model'] = model
-        model_dict[call_name+'_tokenizer'] = tokenizer
-
-    return model_dict
-
-
-def load_models():
-    # build model and tokenizer
-    model_name_dict = {'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
-                       #'nllb-1.3B': 'facebook/nllb-200-1.3B',
-                       #'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
-                       #'nllb-3.3B': 'facebook/nllb-200-3.3B',
-                       }
-
-    model_dict = {}
-
-    for call_name, real_name in model_name_dict.items():
-        print('\tLoading model: %s' % call_name)
-        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
-        tokenizer = AutoTokenizer.from_pretrained(real_name)
-        model_dict[call_name+'_model'] = model
-        model_dict[call_name+'_tokenizer'] = tokenizer
-
-    return model_dict
-
-
-def translation(source, target, text):
-    if len(model_dict) == 2:
-        model_name = 'nllb-distilled-600M'
-
-    start_time = time.time()
-    source = flores_codes[source]
-    target = flores_codes[target]
-
-    model = model_dict[model_name + '_model']
-    tokenizer = model_dict[model_name + '_tokenizer']
-
-    translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
-    output = translator(text, max_length=400)
-
-    end_time = time.time()
-
-    output = output[0]['translation_text']
-    result = {'inference_time': end_time - start_time,
-              'source': source,
-              'target': target,
-              'result': output}
-    return result
-
-
-if __name__ == '__main__':
-    print('\tinit models')
-
-    global model_dict
-
-    model_dict = load_models()
-
-    # define gradio demo
-    lang_codes = list(flores_codes.keys())
-    #inputs = [gr.inputs.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
-    inputs = [gr.inputs.Dropdown(lang_codes, default='English', label='Source'),
-              gr.inputs.Dropdown(lang_codes, default='Korean', label='Target'),
-              gr.inputs.Textbox(lines=5, label="Input text"),
-              ]
-
-    outputs = gr.outputs.JSON()
-
-    title = "NLLB distilled 1.3B demo"
-
-    demo_status = "Demo is running on CPU"
-    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
-    examples = [
-        ['English', 'Korean', 'Hi. nice to meet you']
-    ]
-
-    gr.Interface(translation,
-                 inputs,
-                 outputs,
-                 title=title,
-                 description=description,
-                 ).launch()
-
-
+import torch
+from ui import title, description, examples
+from langs import LANGS
+#from langs_all import LANGS  # for 200+ languages
+
+TASK = "translation"
+CKPT = "facebook/nllb-200-distilled-1.3B"
+#CKPT = "facebook/nllb-200-distilled-600M"
+
+model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
+tokenizer = AutoTokenizer.from_pretrained(CKPT)
+
+# device = 0 if torch.cuda.is_available() else -1
+
+
+def translate(text, src_lang, tgt_lang, max_length=512):
+    """
+    Translate the text from source lang to target lang
+    """
+    translation_pipeline = pipeline(TASK,
+                                    model=model,
+                                    tokenizer=tokenizer,
+                                    src_lang=src_lang,
+                                    tgt_lang=tgt_lang,
+                                    max_length=max_length)
+
+    # translation_pipeline = pipeline(TASK,
+    #                                 model=model,
+    #                                 tokenizer=tokenizer,
+    #                                 src_lang=src_lang,
+    #                                 tgt_lang=tgt_lang,
+    #                                 max_length=max_length,
+    #                                 device=device)
+
+    result = translation_pipeline(text)
+    return result[0]['translation_text']
+
+
+gr.Interface(
+    translate,
+    [
+        gr.components.Textbox(label="Text"),
+        gr.components.Dropdown(label="Source Language", choices=LANGS),
+        gr.components.Dropdown(label="Target Language", choices=LANGS),
+        gr.components.Slider(8, 512, value=512, step=8, label="Max Length")
+    ],
+    ["text"],
+    examples=examples,
+    # article=article,
+    cache_examples=False,
+    title=title,
+    description=description
+).launch()
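The rewritten app loads the model and tokenizer once at import and only rebuilds the thin `pipeline` wrapper per request, so repeated calls stay cheap. A minimal sketch of the same flow outside Gradio, assuming the smaller facebook/nllb-200-distilled-600M checkpoint (the alternative the commit leaves commented out) to keep a local test download manageable:

```python
# Minimal sketch of the new translation flow, run outside Gradio.
# Assumes facebook/nllb-200-distilled-600M (the smaller checkpoint left
# commented out above); the Space itself pins the 1.3B distilled model.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

CKPT = "facebook/nllb-200-distilled-600M"
model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
tokenizer = AutoTokenizer.from_pretrained(CKPT)

# src_lang / tgt_lang take FLORES-200 codes, the same strings langs.py lists.
translator = pipeline("translation",
                      model=model,
                      tokenizer=tokenizer,
                      src_lang="eng_Latn",
                      tgt_lang="zho_Hans",
                      max_length=512)

print(translator("Hello, nice to meet you.")[0]["translation_text"])
```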
langs.py ADDED
@@ -0,0 +1,8 @@
+LANGS = [
+    "bod_Tibt",
+    "khk_Cyrl",
+    "uig_Arab",
+    "yue_Hant",
+    "zho_Hans",
+    "zho_Hant"
+]
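The entries are FLORES-200 codes: an ISO 639-3 language tag joined by an underscore to an ISO 15924 script tag, the identifier format the NLLB tokenizer expects for `src_lang`/`tgt_lang`. A hypothetical sanity check (not part of this commit) that any code added later keeps that shape:

```python
# Hypothetical check, not part of the commit: every LANGS entry should be
# <iso639-3>_<Iso15924>, e.g. "zho_Hans" -> language "zho", script "Hans".
from langs import LANGS

for code in LANGS:
    lang, script = code.split("_")
    assert len(lang) == 3 and len(script) == 4, code
```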
langs_all.py ADDED
@@ -0,0 +1,204 @@
+LANGS = [
+    "ace_Arab",
+    "ace_Latn",
+    "acm_Arab",
+    "acq_Arab",
+    "aeb_Arab",
+    "afr_Latn",
+    "ajp_Arab",
+    "aka_Latn",
+    "amh_Ethi",
+    "apc_Arab",
+    "arb_Arab",
+    "ars_Arab",
+    "ary_Arab",
+    "arz_Arab",
+    "asm_Beng",
+    "ast_Latn",
+    "awa_Deva",
+    "ayr_Latn",
+    "azb_Arab",
+    "azj_Latn",
+    "bak_Cyrl",
+    "bam_Latn",
+    "ban_Latn",
+    "bel_Cyrl",
+    "bem_Latn",
+    "ben_Beng",
+    "bho_Deva",
+    "bjn_Arab",
+    "bjn_Latn",
+    "bod_Tibt",
+    "bos_Latn",
+    "bug_Latn",
+    "bul_Cyrl",
+    "cat_Latn",
+    "ceb_Latn",
+    "ces_Latn",
+    "cjk_Latn",
+    "ckb_Arab",
+    "crh_Latn",
+    "cym_Latn",
+    "dan_Latn",
+    "deu_Latn",
+    "dik_Latn",
+    "dyu_Latn",
+    "dzo_Tibt",
+    "ell_Grek",
+    "eng_Latn",
+    "epo_Latn",
+    "est_Latn",
+    "eus_Latn",
+    "ewe_Latn",
+    "fao_Latn",
+    "pes_Arab",
+    "fij_Latn",
+    "fin_Latn",
+    "fon_Latn",
+    "fra_Latn",
+    "fur_Latn",
+    "fuv_Latn",
+    "gla_Latn",
+    "gle_Latn",
+    "glg_Latn",
+    "grn_Latn",
+    "guj_Gujr",
+    "hat_Latn",
+    "hau_Latn",
+    "heb_Hebr",
+    "hin_Deva",
+    "hne_Deva",
+    "hrv_Latn",
+    "hun_Latn",
+    "hye_Armn",
+    "ibo_Latn",
+    "ilo_Latn",
+    "ind_Latn",
+    "isl_Latn",
+    "ita_Latn",
+    "jav_Latn",
+    "jpn_Jpan",
+    "kab_Latn",
+    "kac_Latn",
+    "kam_Latn",
+    "kan_Knda",
+    "kas_Arab",
+    "kas_Deva",
+    "kat_Geor",
+    "knc_Arab",
+    "knc_Latn",
+    "kaz_Cyrl",
+    "kbp_Latn",
+    "kea_Latn",
+    "khm_Khmr",
+    "kik_Latn",
+    "kin_Latn",
+    "kir_Cyrl",
+    "kmb_Latn",
+    "kon_Latn",
+    "kor_Hang",
+    "kmr_Latn",
+    "lao_Laoo",
+    "lvs_Latn",
+    "lij_Latn",
+    "lim_Latn",
+    "lin_Latn",
+    "lit_Latn",
+    "lmo_Latn",
+    "ltg_Latn",
+    "ltz_Latn",
+    "lua_Latn",
+    "lug_Latn",
+    "luo_Latn",
+    "lus_Latn",
+    "mag_Deva",
+    "mai_Deva",
+    "mal_Mlym",
+    "mar_Deva",
+    "min_Latn",
+    "mkd_Cyrl",
+    "plt_Latn",
+    "mlt_Latn",
+    "mni_Beng",
+    "khk_Cyrl",
+    "mos_Latn",
+    "mri_Latn",
+    "zsm_Latn",
+    "mya_Mymr",
+    "nld_Latn",
+    "nno_Latn",
+    "nob_Latn",
+    "npi_Deva",
+    "nso_Latn",
+    "nus_Latn",
+    "nya_Latn",
+    "oci_Latn",
+    "gaz_Latn",
+    "ory_Orya",
+    "pag_Latn",
+    "pan_Guru",
+    "pap_Latn",
+    "pol_Latn",
+    "por_Latn",
+    "prs_Arab",
+    "pbt_Arab",
+    "quy_Latn",
+    "ron_Latn",
+    "run_Latn",
+    "rus_Cyrl",
+    "sag_Latn",
+    "san_Deva",
+    "sat_Beng",
+    "scn_Latn",
+    "shn_Mymr",
+    "sin_Sinh",
+    "slk_Latn",
+    "slv_Latn",
+    "smo_Latn",
+    "sna_Latn",
+    "snd_Arab",
+    "som_Latn",
+    "sot_Latn",
+    "spa_Latn",
+    "als_Latn",
+    "srd_Latn",
+    "srp_Cyrl",
+    "ssw_Latn",
+    "sun_Latn",
+    "swe_Latn",
+    "swh_Latn",
+    "szl_Latn",
+    "tam_Taml",
+    "tat_Cyrl",
+    "tel_Telu",
+    "tgk_Cyrl",
+    "tgl_Latn",
+    "tha_Thai",
+    "tir_Ethi",
+    "taq_Latn",
+    "taq_Tfng",
+    "tpi_Latn",
+    "tsn_Latn",
+    "tso_Latn",
+    "tuk_Latn",
+    "tum_Latn",
+    "tur_Latn",
+    "twi_Latn",
+    "tzm_Tfng",
+    "uig_Arab",
+    "ukr_Cyrl",
+    "umb_Latn",
+    "urd_Arab",
+    "uzn_Latn",
+    "vec_Latn",
+    "vie_Latn",
+    "war_Latn",
+    "wol_Latn",
+    "xho_Latn",
+    "ydd_Hebr",
+    "yor_Latn",
+    "yue_Hant",
+    "zho_Hans",
+    "zho_Hant",
+    "zul_Latn"
+]
requirements.txt CHANGED
@@ -1,7 +1,3 @@
 git+https://github.com/huggingface/transformers
-
-gradio==3.8
-
-torch
-
-httpx==0.24.1
+gradio
+torch
ui.cpython-310.pyc ADDED
Binary file (672 Bytes).
 
ui.py ADDED
@@ -0,0 +1,12 @@
+title = "NLLB-200 Translation Demo"
+description = """
+<p>
+<center>
+Translator using <a href='https://ai.facebook.com/research/no-language-left-behind/' target='_blank'>Facebook's NLLB</a> models.
+Code follows <a href='https://github.com/facebookresearch/fairseq/tree/nllb' target='_blank'>Facebook's fairseq NLLB</a>.
+Demo is running on CPU.
+</center>
+</p>
+"""
+
+examples = [["我非常喜欢这个地方", "zho_Hans", "yue_Hant", 512]]
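Each row in `examples` maps positionally onto the four Interface inputs defined in app.py (text, source language, target language, max length); the single example asks for a Simplified Chinese to Cantonese (Traditional) translation of 我非常喜欢这个地方 ("I really like this place"). A sketch of driving `translate` with that row directly, assuming it can be imported without triggering the blocking `.launch()` call:

```python
# Hypothetical direct use of the example row (assumes translate from
# app.py is in scope without launching the Gradio UI).
text, src_lang, tgt_lang, max_length = examples[0]
print(translate(text, src_lang, tgt_lang, max_length))
```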