qgyd2021 committed on
Commit
e6a62de
·
1 Parent(s): 6bebdfc

[update]add sent_tokenize model

Browse files
Files changed (2) hide show
  1. cache/huggingface/hub/version.txt +1 -0
  2. main.py +44 -5
cache/huggingface/hub/version.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 1
main.py CHANGED
@@ -1,6 +1,7 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
 
4
  import os
5
 
6
  from project_settings import project_path
@@ -14,6 +15,36 @@ from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
14
  from transformers.generation.streamers import TextIteratorStreamer
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def main():
18
  model_dict = {
19
  "facebook/m2m100_418M": {
@@ -27,6 +58,7 @@ def main():
27
  tgt_lang: str,
28
  model_name: str,
29
  ):
 
30
  model_group = model_dict.get(model_name)
31
  if model_group is None:
32
  for k in list(model_dict.keys()):
@@ -41,15 +73,20 @@ def main():
41
  model = model_group["model"]
42
  tokenizer = model_group["tokenizer"]
43
 
44
- tokenizer.src_lang = src_lang
 
45
 
46
- src_t_list = nltk.sent_tokenize(src_text)
 
 
 
47
 
 
48
  result = ""
49
  for src_t in src_t_list:
50
  encoded_src = tokenizer(src_t, return_tensors="pt")
51
  generated_tokens = model.generate(**encoded_src,
52
- forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
53
  )
54
  text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
55
  result += text_decoded[0]
@@ -83,8 +120,10 @@ It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first re
83
  ]
84
  inputs = [
85
  gr.Textbox(lines=4, placeholder="text", label="Input Text"),
86
- gr.Textbox(lines=1, value="en", label="Source Language"),
87
- gr.Textbox(lines=1, value="zh", label="Target Language"),
 
 
88
  gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name")
89
  ]
90
 
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
4
+ import json
5
  import os
6
 
7
  from project_settings import project_path
 
15
  from transformers.generation.streamers import TextIteratorStreamer
16
 
17
 
18
+ language_map = {
19
+ "Chinese": "zh",
20
+ "Czech": "cs",
21
+ "Danish": "da",
22
+ "Dutch": "nl",
23
+ "Flemish": "nl",
24
+ "English": "en",
25
+ "Estonian": "et",
26
+ "Finnish": "fi",
27
+ "French": "fr",
28
+ "German": "de",
29
+ "Italian": "it",
30
+ "Norwegian": "no",
31
+ "Polish": "pl",
32
+ "Portuguese": "pt",
33
+ "Russian": "ru",
34
+ "Spanish": "es",
35
+ "Swedish": "sv",
36
+ "Turkish": "tr",
37
+
38
+ }
39
+
40
+
41
+ nltk_sent_tokenize_languages = [
42
+ "czech", "danish", "dutch", "flemish", "english", "estonian",
43
+ "finnish", "french", "german", "italian", "norwegian",
44
+ "polish", "portuguese", "russian", "spanish", "swedish", "turkish"
45
+ ]
46
+
47
+
48
  def main():
49
  model_dict = {
50
  "facebook/m2m100_418M": {
 
58
  tgt_lang: str,
59
  model_name: str,
60
  ):
61
+ # model
62
  model_group = model_dict.get(model_name)
63
  if model_group is None:
64
  for k in list(model_dict.keys()):
 
73
  model = model_group["model"]
74
  tokenizer = model_group["tokenizer"]
75
 
76
+ # tokenize
77
+ tokenizer.src_lang = language_map[src_lang]
78
 
79
+ if src_lang.lower() in nltk_sent_tokenize_languages:
80
+ src_t_list = nltk.sent_tokenize(src_text, language="")
81
+ else:
82
+ src_t_list = [src_text]
83
 
84
+ # infer
85
  result = ""
86
  for src_t in src_t_list:
87
  encoded_src = tokenizer(src_t, return_tensors="pt")
88
  generated_tokens = model.generate(**encoded_src,
89
+ forced_bos_token_id=tokenizer.get_lang_id(language_map[tgt_lang]),
90
  )
91
  text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
92
  result += text_decoded[0]
 
120
  ]
121
  inputs = [
122
  gr.Textbox(lines=4, placeholder="text", label="Input Text"),
123
+ gr.Dropdown(choices=list(language_map.keys()), value="English", label="Source Language"),
124
+ gr.Dropdown(choices=list(language_map.keys()), value="Chinese", label="Target Language"),
125
+ # gr.Textbox(lines=1, value="en", label="Source Language"),
126
+ # gr.Textbox(lines=1, value="zh", label="Target Language"),
127
  gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name")
128
  ]
129