[update] add sent_tokenize model
Browse files- cache/huggingface/hub/version.txt +1 -0
- main.py +44 -5
cache/huggingface/hub/version.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
1
|
main.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
#!/usr/bin/python3
|
2 |
# -*- coding: utf-8 -*-
|
3 |
import argparse
|
|
|
4 |
import os
|
5 |
|
6 |
from project_settings import project_path
|
@@ -14,6 +15,36 @@ from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
|
|
14 |
from transformers.generation.streamers import TextIteratorStreamer
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
def main():
|
18 |
model_dict = {
|
19 |
"facebook/m2m100_418M": {
|
@@ -27,6 +58,7 @@ def main():
|
|
27 |
tgt_lang: str,
|
28 |
model_name: str,
|
29 |
):
|
|
|
30 |
model_group = model_dict.get(model_name)
|
31 |
if model_group is None:
|
32 |
for k in list(model_dict.keys()):
|
@@ -41,15 +73,20 @@ def main():
|
|
41 |
model = model_group["model"]
|
42 |
tokenizer = model_group["tokenizer"]
|
43 |
|
44 |
-
|
|
|
45 |
|
46 |
-
|
|
|
|
|
|
|
47 |
|
|
|
48 |
result = ""
|
49 |
for src_t in src_t_list:
|
50 |
encoded_src = tokenizer(src_t, return_tensors="pt")
|
51 |
generated_tokens = model.generate(**encoded_src,
|
52 |
-
forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
|
53 |
)
|
54 |
text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
55 |
result += text_decoded[0]
|
@@ -83,8 +120,10 @@ It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first re
|
|
83 |
]
|
84 |
inputs = [
|
85 |
gr.Textbox(lines=4, placeholder="text", label="Input Text"),
|
86 |
-
gr.
|
87 |
-
gr.
|
|
|
|
|
88 |
gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name")
|
89 |
]
|
90 |
|
|
|
1 |
#!/usr/bin/python3
|
2 |
# -*- coding: utf-8 -*-
|
3 |
import argparse
|
4 |
+
import json
|
5 |
import os
|
6 |
|
7 |
from project_settings import project_path
|
|
|
15 |
from transformers.generation.streamers import TextIteratorStreamer
|
16 |
|
17 |
|
18 |
+
language_map = {
|
19 |
+
"Chinese": "zh",
|
20 |
+
"Czech": "cs",
|
21 |
+
"Danish": "da",
|
22 |
+
"Dutch": "nl",
|
23 |
+
"Flemish": "nl",
|
24 |
+
"English": "en",
|
25 |
+
"Estonian": "et",
|
26 |
+
"Finnish": "fi",
|
27 |
+
"French": "fr",
|
28 |
+
"German": "de",
|
29 |
+
"Italian": "it",
|
30 |
+
"Norwegian": "no",
|
31 |
+
"Polish": "pl",
|
32 |
+
"Portuguese": "pt",
|
33 |
+
"Russian": "ru",
|
34 |
+
"Spanish": "es",
|
35 |
+
"Swedish": "sv",
|
36 |
+
"Turkish": "tr",
|
37 |
+
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
nltk_sent_tokenize_languages = [
|
42 |
+
"czech", "danish", "dutch", "flemish", "english", "estonian",
|
43 |
+
"finnish", "french", "german", "italian", "norwegian",
|
44 |
+
"polish", "portuguese", "russian", "spanish", "swedish", "turkish"
|
45 |
+
]
|
46 |
+
|
47 |
+
|
48 |
def main():
|
49 |
model_dict = {
|
50 |
"facebook/m2m100_418M": {
|
|
|
58 |
tgt_lang: str,
|
59 |
model_name: str,
|
60 |
):
|
61 |
+
# model
|
62 |
model_group = model_dict.get(model_name)
|
63 |
if model_group is None:
|
64 |
for k in list(model_dict.keys()):
|
|
|
73 |
model = model_group["model"]
|
74 |
tokenizer = model_group["tokenizer"]
|
75 |
|
76 |
+
# tokenize
|
77 |
+
tokenizer.src_lang = language_map[src_lang]
|
78 |
|
79 |
+
if src_lang.lower() in nltk_sent_tokenize_languages:
|
80 |
+
# Pass the lowercased language name (already validated against the NLTK list) so the matching Punkt model is loaded; an empty name selects no model.
src_t_list = nltk.sent_tokenize(src_text, language=src_lang.lower())
|
81 |
+
else:
|
82 |
+
src_t_list = [src_text]
|
83 |
|
84 |
+
# infer
|
85 |
result = ""
|
86 |
for src_t in src_t_list:
|
87 |
encoded_src = tokenizer(src_t, return_tensors="pt")
|
88 |
generated_tokens = model.generate(**encoded_src,
|
89 |
+
forced_bos_token_id=tokenizer.get_lang_id(language_map[tgt_lang]),
|
90 |
)
|
91 |
text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
92 |
result += text_decoded[0]
|
|
|
120 |
]
|
121 |
inputs = [
|
122 |
gr.Textbox(lines=4, placeholder="text", label="Input Text"),
|
123 |
+
gr.Dropdown(choices=list(language_map.keys()), value="English", label="Source Language"),
|
124 |
+
gr.Dropdown(choices=list(language_map.keys()), value="Chinese", label="Target Language"),
|
125 |
+
# gr.Textbox(lines=1, value="en", label="Source Language"),
|
126 |
+
# gr.Textbox(lines=1, value="zh", label="Target Language"),
|
127 |
gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name")
|
128 |
]
|
129 |
|