qgyd2021 committed on
Commit 9b27f4a · 1 Parent(s): 6d06dfd

[update]add main

Files changed (1)
  1. main.py +26 -15
main.py CHANGED
@@ -13,27 +13,32 @@ import gradio as gr
 from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
 
 
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--pretrained_model_name_or_path",
-        default="facebook/m2m100_418M",
-        type=str
-    )
-    args = parser.parse_args()
-    return args
-
-
 def main():
-    args = get_args()
-
-    model = M2M100ForConditionalGeneration.from_pretrained(args.pretrained_model_name_or_path)
-    tokenizer = M2M100Tokenizer.from_pretrained(args.pretrained_model_name_or_path)
+    model_dict = {
+        "facebook/m2m100_418M": {
+            "model": M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M"),
+            "tokenizer": M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
+        }
+    }
 
     def multilingual_translate(src_text: str,
                                src_lang: str,
                                tgt_lang: str,
+                               model_name: str,
                                ):
+        model_group = model_dict.get(model_name)
+        if model_group is None:
+            for k, mg in model_dict.items():
+                del mg["model"]
+            model_dict[model_name] = {
+                "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
+                "tokenizer": M2M100Tokenizer.from_pretrained(model_name)
+            }
+            model_group = model_dict[model_name]
+
+        model = model_group["model"]
+        tokenizer = model_group["tokenizer"]
+
         tokenizer.src_lang = src_lang
         encoded_src = tokenizer(src_text, return_tensors="pt")
         generated_tokens = model.generate(**encoded_src,
@@ -77,13 +82,19 @@ Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
             "Hello world!",
             "en",
             "zh",
+            "facebook/m2m100_418M",
         ],
     ]
 
+    model_choices = [
+        "facebook/m2m100_418M",
+        "facebook/m2m100_1.2B"
+    ]
     inputs = [
        gr.Textbox(lines=4, value="", label="Input Text"),
        gr.Textbox(lines=1, value="", label="Source Language"),
        gr.Textbox(lines=1, value="", label="Target Language"),
+       gr.Dropdown(choices=model_choices, label="model_name")
     ]
 
     output = gr.Textbox(lines=4, label="Output Text")
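Both hunks cut off at the opening of model.generate(**encoded_src, ...), so the rest of the generation and decoding code is not visible in this diff. As a point of reference only, here is a minimal sketch of how M2M100 translation is normally completed, following the facebook/m2m100_418M model card; the forced_bos_token_id and batch_decode steps below are assumptions about the omitted lines, not part of this commit:

# Sketch only: standard M2M100 generate/decode flow from the model card,
# shown because the diff truncates at model.generate(**encoded_src, ...).
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

src_text, src_lang, tgt_lang = "Hello world!", "en", "zh"
tokenizer.src_lang = src_lang                              # tell the tokenizer the source language
encoded_src = tokenizer(src_text, return_tensors="pt")
generated_tokens = model.generate(
    **encoded_src,
    forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),   # force decoding into the target language
)
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0])

forced_bos_token_id is how M2M100 selects the output language: the first generated token is forced to the target-language code, and decoding then proceeds normally.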