[update]add main
main.py CHANGED
@@ -13,27 +13,32 @@ import gradio as gr
 from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
 
 
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--pretrained_model_name_or_path",
-        default="facebook/m2m100_418M",
-        type=str
-    )
-    args = parser.parse_args()
-    return args
-
-
 def main():
-
-
-
-
+    model_dict = {
+        "facebook/m2m100_418M": {
+            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
+            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
+        }
+    }
 
 def multilingual_translate(src_text: str,
                            src_lang: str,
                            tgt_lang: str,
+                           model_name: str,
                            ):
+    model_group = model_dict.get(model_name)
+    if model_group is None:
+        for k, mg in model_dict.items():
+            del mg["model"]
+        model_dict[model_name] = {
+            "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
+            "tokenizer": M2M100Tokenizer.from_pretrained(model_name)
+        }
+        model_group = model_dict[model_name]
+
+    model = model_group["model"]
+    tokenizer = model_group["tokenizer"]
+
     tokenizer.src_lang = src_lang
     encoded_src = tokenizer(src_text, return_tensors="pt")
     generated_tokens = model.generate(**encoded_src,
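For context, here is a minimal standalone sketch of the lazy model cache this hunk introduces. It assumes model_dict lives at module scope so multilingual_translate can see it, simplifies the eviction to a full cache clear so a re-requested checkpoint always reloads cleanly, and completes the truncated generate(...) call the way the M2M100 documentation does, with forced_bos_token_id and batch_decode. None of these details are confirmed by the hunk itself.

# Sketch only, not the committed file: assumes a module-level model_dict cache.
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model_dict = {}  # model_name -> {"model": ..., "tokenizer": ...}

def multilingual_translate(src_text: str, src_lang: str, tgt_lang: str, model_name: str):
    model_group = model_dict.get(model_name)
    if model_group is None:
        # Drop previously cached checkpoints so at most one model stays in memory
        # (the diff deletes only the "model" field; clearing the whole cache is a
        # simplification that avoids stale tokenizer-only entries).
        model_dict.clear()
        model_dict[model_name] = {
            "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
            "tokenizer": M2M100Tokenizer.from_pretrained(model_name),
        }
        model_group = model_dict[model_name]

    model = model_group["model"]
    tokenizer = model_group["tokenizer"]

    tokenizer.src_lang = src_lang
    encoded_src = tokenizer(src_text, return_tensors="pt")
    # M2M100 requires the target language to be forced as the first generated token.
    generated_tokens = model.generate(
        **encoded_src, forced_bos_token_id=tokenizer.get_lang_id(tgt_lang)
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]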
@@ -77,13 +82,19 @@ Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
             "Hello world!",
             "en",
             "zh",
+            "facebook/m2m100_418M",
         ],
     ]
 
+    model_choices = [
+        "facebook/m2m100_418M",
+        "facebook/m2m100_1.2B"
+    ]
     inputs = [
         gr.Textbox(lines=4, value="", label="Input Text"),
         gr.Textbox(lines=1, value="", label="Source Language"),
         gr.Textbox(lines=1, value="", label="Target Language"),
+        gr.Dropdown(choices=model_choices, label="model_name")
     ]
 
     output = gr.Textbox(lines=4, label="Output Text")
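The gr.Interface call that consumes these components sits outside this hunk. The sketch below shows one typical way the new Dropdown, the extended examples row, and multilingual_translate would be wired together; the exact arguments are an assumption, not the committed code.

# Sketch only: assumes multilingual_translate as defined in the file above.
import gradio as gr

model_choices = ["facebook/m2m100_418M", "facebook/m2m100_1.2B"]

inputs = [
    gr.Textbox(lines=4, value="", label="Input Text"),
    gr.Textbox(lines=1, value="", label="Source Language"),
    gr.Textbox(lines=1, value="", label="Target Language"),
    gr.Dropdown(choices=model_choices, label="model_name"),
]
output = gr.Textbox(lines=4, label="Output Text")

# Each example row supplies one value per input component, in the same order,
# which is why the commit appends the model name to the existing example.
examples = [
    ["Hello world!", "en", "zh", "facebook/m2m100_418M"],
]

gr.Interface(
    fn=multilingual_translate,
    inputs=inputs,
    outputs=output,
    examples=examples,
).launch()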