avans06 commited on
Commit
40311b7
·
1 Parent(s): e8762f9

Merge branch 'main' of https://huggingface.co/spaces/SoybeanMilk/whisper-webui-translate

Browse files

1.Thank you SoybeanMilk for assisting in the development and integration of the excellent ALMA translation model.

2.Add the missing packages for the GPTQ version of ALMA in the requirements.txt.

3.Include the 7B version of the ALMA model in addition to the 13B version in the web UI.

4.Write the ReadMe document for the Translate Model, containing a brief introduction and explanation of each translation model used in this project.

app.py CHANGED
@@ -40,7 +40,7 @@ from src.whisper.whisperFactory import create_whisper_container
40
  from src.translation.translationModel import TranslationModel
41
  from src.translation.translationLangs import (TranslationLang,
42
  _TO_LANG_CODE_WHISPER, get_lang_whisper_names, get_lang_from_whisper_name, get_lang_from_whisper_code,
43
- get_lang_nllb_names, get_lang_from_nllb_name, get_lang_m2m100_names, get_lang_from_m2m100_name)
44
  import shutil
45
  import zhconv
46
  import tqdm
@@ -773,8 +773,15 @@ class WhisperTranscriber:
773
  self.diarization = None
774
 
775
  def create_ui(app_config: ApplicationConfig):
 
776
  optionsMd: str = None
777
  readmeMd: str = None
 
 
 
 
 
 
778
  try:
779
  optionsPath = pathlib.Path("docs/options.md")
780
  with open(optionsPath, "r", encoding="utf-8") as optionsFile:
@@ -819,16 +826,6 @@ def create_ui(app_config: ApplicationConfig):
819
  uiDescription += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
820
 
821
  uiArticle = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
822
- uiArticle += "\n\nWhisper's Task 'translate' only implements the functionality of translating other languages into English. "
823
- uiArticle += "OpenAI does not guarantee translations between arbitrary languages. In such cases, you can choose to use the Translation Model to implement the translation task. "
824
- uiArticle += "However, it's important to note that the Translation Model runs slowly(in CPU), and the completion time may be twice as long as usual. "
825
- uiArticle += "\n\nThe larger the parameters of the Translation model, the better its performance is expected to be. "
826
- uiArticle += "However, it also requires higher computational resources, making it slower to operate. "
827
- uiArticle += "On the other hand, the version converted from ct2 ([CTranslate2](https://opennmt.net/CTranslate2/guides/transformers.html)) requires lower resources and operates at a faster speed."
828
- uiArticle += "\n\nCurrently, enabling `Highlight Words` timestamps cannot be used in conjunction with Translation Model translation "
829
- uiArticle += "because Highlight Words will split the source text, and after translation, it becomes a non-word-level string. "
830
- uiArticle += "\n\nThe 'mt5-zh-ja-en-trimmed' model is finetuned from Google's 'mt5-base' model. "
831
- uiArticle += "This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English. "
832
 
833
  whisper_models = app_config.get_model_names("whisper")
834
  nllb_models = app_config.get_model_names("nllb")
@@ -854,7 +851,7 @@ def create_ui(app_config: ApplicationConfig):
854
  }
855
  common_ALMA_inputs = lambda : {
856
  gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
857
- gr.Dropdown(label="ALMA - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="ALMALangName"),
858
  }
859
 
860
  common_translation_inputs = lambda : {
@@ -944,8 +941,10 @@ def create_ui(app_config: ApplicationConfig):
944
  simpleInputDict.update(common_translation_inputs())
945
  with gr.Column():
946
  simpleOutput = common_output()
947
- with gr.Accordion("Article"):
948
- gr.Markdown(uiArticle)
 
 
949
  if optionsMd is not None:
950
  with gr.Accordion("docs/options.md", open=False):
951
  gr.Markdown(optionsMd)
@@ -1056,7 +1055,7 @@ def create_ui(app_config: ApplicationConfig):
1056
  print("Queue mode enabled (concurrency count: " + str(app_config.queue_concurrency_count) + ")")
1057
  else:
1058
  print("Queue mode disabled - progress bars will not be shown.")
1059
-
1060
  demo.launch(inbrowser=app_config.autolaunch, share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)
1061
 
1062
  # Clean up
@@ -1138,6 +1137,16 @@ if __name__ == '__main__':
1138
  # updated_config.autolaunch = True
1139
  # updated_config.auto_parallel = False
1140
  # updated_config.save_downloaded_files = True
 
 
 
 
 
 
 
 
 
 
1141
 
1142
  if (threads := args.pop("threads")) > 0:
1143
  torch.set_num_threads(threads)
 
40
  from src.translation.translationModel import TranslationModel
41
  from src.translation.translationLangs import (TranslationLang,
42
  _TO_LANG_CODE_WHISPER, get_lang_whisper_names, get_lang_from_whisper_name, get_lang_from_whisper_code,
43
+ get_lang_nllb_names, get_lang_from_nllb_name, get_lang_m2m100_names, get_lang_from_m2m100_name, sort_lang_by_whisper_codes)
44
  import shutil
45
  import zhconv
46
  import tqdm
 
773
  self.diarization = None
774
 
775
  def create_ui(app_config: ApplicationConfig):
776
+ translateModelMd: str = None
777
  optionsMd: str = None
778
  readmeMd: str = None
779
+ try:
780
+ translateModelPath = pathlib.Path("docs/translateModel.md")
781
+ with open(translateModelPath, "r", encoding="utf-8") as translateModelFile:
782
+ translateModelMd = translateModelFile.read()
783
+ except Exception as e:
784
+ print("Error occurred during read translateModel.md file: ", str(e))
785
  try:
786
  optionsPath = pathlib.Path("docs/options.md")
787
  with open(optionsPath, "r", encoding="utf-8") as optionsFile:
 
826
  uiDescription += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
827
 
828
  uiArticle = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
 
 
 
 
 
 
 
 
 
 
829
 
830
  whisper_models = app_config.get_model_names("whisper")
831
  nllb_models = app_config.get_model_names("nllb")
 
851
  }
852
  common_ALMA_inputs = lambda : {
853
  gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
854
+ gr.Dropdown(label="ALMA - Language", choices=sort_lang_by_whisper_codes(["en", "de", "cs", "is", "ru", "zh", "ja"]), elem_id="ALMALangName"),
855
  }
856
 
857
  common_translation_inputs = lambda : {
 
941
  simpleInputDict.update(common_translation_inputs())
942
  with gr.Column():
943
  simpleOutput = common_output()
944
+ gr.Markdown(uiArticle)
945
+ if translateModelMd is not None:
946
+ with gr.Accordion("docs/translateModel.md", open=False):
947
+ gr.Markdown(translateModelMd)
948
  if optionsMd is not None:
949
  with gr.Accordion("docs/options.md", open=False):
950
  gr.Markdown(optionsMd)
 
1055
  print("Queue mode enabled (concurrency count: " + str(app_config.queue_concurrency_count) + ")")
1056
  else:
1057
  print("Queue mode disabled - progress bars will not be shown.")
1058
+
1059
  demo.launch(inbrowser=app_config.autolaunch, share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)
1060
 
1061
  # Clean up
 
1137
  # updated_config.autolaunch = True
1138
  # updated_config.auto_parallel = False
1139
  # updated_config.save_downloaded_files = True
1140
+
1141
+ try:
1142
+ if torch.cuda.is_available():
1143
+ deviceId = torch.cuda.current_device()
1144
+ totalVram = torch.cuda.get_device_properties(deviceId).total_memory
1145
+ if totalVram/(1024*1024*1024) <= 4: #VRAM <= 4 GB
1146
+ updated_config.vad_process_timeout = 0
1147
+ except Exception as e:
1148
+ print(traceback.format_exc())
1149
+ print("Error detect vram: " + str(e))
1150
 
1151
  if (threads := args.pop("threads")) > 0:
1152
  torch.set_num_threads(threads)
config.json5 CHANGED
@@ -193,10 +193,15 @@
193
  }
194
  ],
195
  "ALMA": [
 
 
 
 
 
196
  {
197
  "name": "ALMA-13B-GPTQ/TheBloke",
198
  "url": "TheBloke/ALMA-13B-GPTQ",
199
- "type": "huggingface",
200
  },
201
  ]
202
  },
 
193
  }
194
  ],
195
  "ALMA": [
196
+ {
197
+ "name": "ALMA-7B-GPTQ/TheBloke",
198
+ "url": "TheBloke/ALMA-7B-GPTQ",
199
+ "type": "huggingface"
200
+ },
201
  {
202
  "name": "ALMA-13B-GPTQ/TheBloke",
203
  "url": "TheBloke/ALMA-13B-GPTQ",
204
+ "type": "huggingface"
205
  },
206
  ]
207
  },
docs/options.md CHANGED
@@ -1,4 +1,4 @@
1
- # Standard Options
2
  To transcribe or translate an audio file, you can either copy an URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
3
  supported by YT-DLP will work, including YouTube). Otherwise, upload an audio file (choose "All Files (*.*)"
4
  in the file selector to select any file type, including video files) or use the microphone.
@@ -154,29 +154,29 @@ The minimum number of speakers for Pyannote to detect.
154
  The maximum number of speakers for Pyannote to detect.
155
 
156
  ## Repetition Penalty
157
- - ctranslate2: repetition_penalty
158
  This parameter only takes effect in [faster-whisper (ctranslate2)](https://github.com/SYSTRAN/faster-whisper/issues/478).
159
  Penalty applied to the score of previously generated tokens (set > 1 to penalize).
160
 
161
  ## No Repeat Ngram Size
162
- - ctranslate2: no_repeat_ngram_size
163
  This parameter only takes effect in [faster-whisper (ctranslate2)](https://github.com/SYSTRAN/faster-whisper/issues/478).
164
  Prevent repetitions of ngrams with this size (set 0 to disable).
165
 
166
  ## Translation - Batch Size
167
- - transformers: batch_size
168
  When the pipeline will use DataLoader (when passing a dataset, on GPU for a Pytorch model), the size of the batch to use, for inference this is not always beneficial.
169
- - ctranslate2: max_batch_size
170
  The maximum batch size.
171
 
172
  ## Translation - No Repeat Ngram Size
173
- - transformers: no_repeat_ngram_size
174
  Value that will be used by default in the generate method of the model for no_repeat_ngram_size. If set to int > 0, all ngrams of that size can only occur once.
175
- - ctranslate2: no_repeat_ngram_size
176
  Prevent repetitions of ngrams with this size (set 0 to disable).
177
 
178
  ## Translation - Num Beams
179
- - transformers: num_beams
180
  Number of beams for beam search that will be used by default in the generate method of the model. 1 means no beam search.
181
- - ctranslate2: beam_size
182
  Beam size (1 for greedy search).
 
1
+ # Standard Options
2
  To transcribe or translate an audio file, you can either copy an URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
3
  supported by YT-DLP will work, including YouTube). Otherwise, upload an audio file (choose "All Files (*.*)"
4
  in the file selector to select any file type, including video files) or use the microphone.
 
154
  The maximum number of speakers for Pyannote to detect.
155
 
156
  ## Repetition Penalty
157
+ - ctranslate2: repetition_penalty
158
  This parameter only takes effect in [faster-whisper (ctranslate2)](https://github.com/SYSTRAN/faster-whisper/issues/478).
159
  Penalty applied to the score of previously generated tokens (set > 1 to penalize).
160
 
161
  ## No Repeat Ngram Size
162
+ - ctranslate2: no_repeat_ngram_size
163
  This parameter only takes effect in [faster-whisper (ctranslate2)](https://github.com/SYSTRAN/faster-whisper/issues/478).
164
  Prevent repetitions of ngrams with this size (set 0 to disable).
165
 
166
  ## Translation - Batch Size
167
+ - transformers: batch_size
168
  When the pipeline will use DataLoader (when passing a dataset, on GPU for a Pytorch model), the size of the batch to use, for inference this is not always beneficial.
169
+ - ctranslate2: max_batch_size
170
  The maximum batch size.
171
 
172
  ## Translation - No Repeat Ngram Size
173
+ - transformers: no_repeat_ngram_size
174
  Value that will be used by default in the generate method of the model for no_repeat_ngram_size. If set to int > 0, all ngrams of that size can only occur once.
175
+ - ctranslate2: no_repeat_ngram_size
176
  Prevent repetitions of ngrams with this size (set 0 to disable).
177
 
178
  ## Translation - Num Beams
179
+ - transformers: num_beams
180
  Number of beams for beam search that will be used by default in the generate method of the model. 1 means no beam search.
181
+ - ctranslate2: beam_size
182
  Beam size (1 for greedy search).
docs/translateModel.md ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 
2
+ # Describe
3
+
4
+ The `translate` task in `Whisper` only supports translating other languages `into English`. `OpenAI` does not guarantee translations between arbitrary languages. In such cases, you can opt to use the Translation Model for translation tasks. However, it's important to note that the `Translation Model runs very slowly on CPU`, and the completion time may be twice as long as usual. It is recommended to run the Translation Model on devices with `GPUs` for better performance.
5
+
6
+ The larger the parameters of the Translation model, the better its translation capability is expected. However, this also requires higher computational resources and slower running speed.
7
+
8
+ Currently, when the `Highlight Words timestamps` option is enabled in the Whisper `Word Timestamps options`, it cannot be used simultaneously with the Translation Model. This is because Highlight Words splits the source text, and after translation, it becomes a non-word-level string.
9
+
10
+
11
+ # Translation Model
12
+
13
+ The required VRAM is provided for reference and may not apply to everyone. If the model's VRAM requirement exceeds the available capacity of the system, the model will operate on the CPU, resulting in significantly longer execution times.
14
+
15
+ [CTranslate2](https://opennmt.net/CTranslate2/guides/transformers.html) is a C++ and Python library for efficient inference with Transformer models. Models converted from CTranslate2 can run with lower resources and faster speed. Encoder-decoder models currently supported: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper.
16
+
17
+ ## M2M100
18
+
19
+ 2M100 is a multilingual translation model introduced by Facebook AI in October 2020. It supports arbitrary translation among 101 languages. The paper is titled "`Beyond English-Centric Multilingual Machine Translation`" ([arXiv:2010.11125](https://arxiv.org/abs/2010.11125)).
20
+
21
+ | Name | Parameters | Size | type/quantize | Required VRAM |
22
+ |------|------------|------|---------------|---------------|
23
+ | [facebook/m2m100_418M](https://huggingface.co/facebook/m2m100_418M) | 480M | 1.94 GB | float32 | ≈2 GB |
24
+ | [facebook/m2m100_1.2B](https://huggingface.co/facebook/m2m100_1.2B) | 1.2B | 4.96 GB | float32 | ≈5 GB |
25
+ | [facebook/m2m100-12B-last-ckpt](https://huggingface.co/facebook/m2m100-12B-last-ckpt) | 12B | 47.2 GB | float32 | N/A |
26
+
27
+ ## M2M100-CTranslate2
28
+
29
+ | Name | Parameters | Size | type/quantize | Required VRAM |
30
+ |------|------------|------|---------------|---------------|
31
+ | [michaelfeil/ct2fast-m2m100_418M](https://huggingface.co/michaelfeil/ct2fast-m2m100_418M) | 480M | 970 MB | float16 | ≈0.6 GB |
32
+ | [michaelfeil/ct2fast-m2m100_1.2B](https://huggingface.co/michaelfeil/ct2fast-m2m100_1.2B) | 1.2B | 2.48 GB | float16 | ≈1.3 GB |
33
+ | [michaelfeil/ct2fast-m2m100-12B-last-ckpt](https://huggingface.co/michaelfeil/ct2fast-m2m100-12B-last-ckpt) | 12B | 23.6 GB | float16 | N/A |
34
+
35
+ ## NLLB-200
36
+
37
+ NLLB-200 is a multilingual translation model introduced by Meta AI in July 2022. It supports arbitrary translation among 202 languages. The paper is titled "`No Language Left Behind: Scaling Human-Centered Machine Translation`" ([arXiv:2207.04672](https://arxiv.org/abs/2207.04672)).
38
+
39
+ | Name | Parameters | Size | type/quantize | Required VRAM |
40
+ |------|------------|------|---------------|---------------|
41
+ | [facebook/nllb-200-distilled-600M](https://huggingface.co/facebook/nllb-200-distilled-600M) | 600M | 2.46 GB | float32 | ≈2.5 GB |
42
+ | [facebook/nllb-200-distilled-1.3B](https://huggingface.co/facebook/nllb-200-distilled-1.3B) | 1.3B | 5.48 GB | float32 | ≈5.9 GB |
43
+ | [facebook/nllb-200-1.3B](https://huggingface.co/facebook/nllb-200-1.3B) | 1.3B | 5.48 GB | float32 | 5.8 GB |
44
+ | [facebook/nllb-200-3.3B](https://huggingface.co/facebook/nllb-200-3.3B) | 3.3B | 17.58 GB | float32 | 13.4 GB |
45
+
46
+ ## NLLB-200-CTranslate2
47
+
48
+ | Name | Parameters | Size | type/quantize | Required VRAM |
49
+ |------|------------|------|---------------|---------------|
50
+ | [michaelfeil/ct2fast-nllb-200-distilled-1.3B](https://huggingface.co/michaelfeil/ct2fast-nllb-200-distilled-1.3B) | 1.3B | 1.38 GB | int8_float16 | ≈1.3 GB |
51
+ | [michaelfeil/ct2fast-nllb-200-3.3B](https://huggingface.co/michaelfeil/ct2fast-nllb-200-3.3B) | 3.3B | 3.36 GB | int8_float16 | ≈3.2 GB |
52
+ | [JustFrederik/nllb-200-1.3B-ct2-int8](https://huggingface.co/JustFrederik/nllb-200-1.3B-ct2-int8) | 1.3B | 1.38 GB | int8 | ≈1.3 GB |
53
+ | [JustFrederik/nllb-200-1.3B-ct2-float16](https://huggingface.co/JustFrederik/nllb-200-1.3B-ct2-float16) | 1.3B | 2.74 GB | float16 | ≈1.3 GB |
54
+ | [JustFrederik/nllb-200-distilled-600M-ct2](https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2) | 600M | 2.46 GB | float32 | ≈0.6 GB |
55
+ | [JustFrederik/nllb-200-distilled-600M-ct2-float16](https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-float16) | 600M | 1.23 GB | float16 | ≈0.6 GB |
56
+ | [JustFrederik/nllb-200-distilled-600M-ct2-int8](https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8) | 600M | 623 MB | int8 | ≈0.6 GB |
57
+ | [JustFrederik/nllb-200-distilled-1.3B-ct2-float16](https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-float16) | 1.3B | 2.74 GB | float16 | ≈1.3 GB |
58
+ | [JustFrederik/nllb-200-distilled-1.3B-ct2-int8](https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8) | 1.3B | 1.38 GB | int8 | ≈1.3 GB |
59
+ | [JustFrederik/nllb-200-distilled-1.3B-ct2](https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2) | 1.3B | 5.49 GB | float32 | ≈1.3 GB |
60
+ | [JustFrederik/nllb-200-1.3B-ct2](https://huggingface.co/JustFrederik/nllb-200-1.3B-ct2) | 1.3B | 5.49 GB | float32 | ≈1.3 GB |
61
+ | [JustFrederik/nllb-200-3.3B-ct2-float16](https://huggingface.co/JustFrederik/nllb-200-3.3B-ct2-float16) | 3.3B | 6.69 GB | float16 | ≈3.2 GB |
62
+
63
+ ## MT5
64
+
65
+ mT5 is a multilingual pre-trained Text-to-Text Transformer introduced by Google Research in October 2020. It is a multilingual variant of the T5 model, pre-trained on datasets in 101 languages. Further fine-tuning is required to transform it into a translation model. The paper is titled "`mT5: A Massively Multilingual Pre-trained Text-to-Text Transformer`" ([arXiv:2010.11934](https://arxiv.org/abs/2010.11934)).
66
+ The 'mt5-zh-ja-en-trimmed' model is finetuned from Google's 'mt5-base' model. This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English.
67
+
68
+ | Name | Parameters | Size | type/quantize | Required VRAM |
69
+ |------|------------|------|---------------|---------------|
70
+ | [mt5-base](https://huggingface.co/google/mt5-base) | N/A | 2.33 GB | float32 | N/A |
71
+ | [K024/mt5-zh-ja-en-trimmed](https://huggingface.co/K024/mt5-zh-ja-en-trimmed) | N/A | 1.32 GB | float32 | ≈1.4 GB |
72
+ | [engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1](https://huggingface.co/engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1) | N/A | 1.32 GB | float32 | ≈1.4 GB |
73
+
74
+ ## ALMA
75
+
76
+ ALMA is a many-to-many LLM-based translation model introduced by Haoran Xu and colleagues in September 2023. It is based on the fine-tuning of a large language model (LLaMA-2). The approach used for this model is referred to as Advanced Language Model-based trAnslator (ALMA). The paper is titled "`A Paradigm Shift in Machine Translation: Boosting Translation Performance of Large Language Models`" ([arXiv:2309.11674](https://arxiv.org/abs/2309.11674)).
77
+ The official support for ALMA currently includes 10 language directions: English↔German, English↔Czech, English↔Icelandic, English↔Chinese, and English↔Russian. However, the author hints that there might be surprises in other directions, so there are currently no restrictions on the languages that ALMA can be chosen for in the web UI.
78
+
79
+ | Name | Parameters | Size | type/quantize | Required VRAM |
80
+ |------|------------|------|---------------|---------------|
81
+ | [haoranxu/ALMA-7B](https://huggingface.co/haoranxu/ALMA-7B) | 7B | 26.95 GB | float32 | N/A |
82
+ | [haoranxu/ALMA-13B](https://huggingface.co/haoranxu/ALMA-13B) | 13B | 52.07 GB | float32 | N/A |
83
+
84
+ ## ALMA-GPTQ
85
+
86
+ GPTQ is a technique used to quantize the parameters of large language models into integer formats such as int8 or int4. Although the quantization process may lead to a loss in model performance, it significantly reduces both file size and the required VRAM.
87
+
88
+ | Name | Parameters | Size | type/quantize | Required VRAM |
89
+ |------|------------|------|---------------|---------------|
90
+ | [TheBloke/ALMA-7B-GPTQ](https://huggingface.co/TheBloke/ALMA-7B-GPTQ) | 7B | 3.9 GB | 4 Bits | ≈4.3 GB |
91
+ | [TheBloke/ALMA-13B-GPTQ](https://huggingface.co/TheBloke/ALMA-13B-GPTQ) | 13B | 7.26 GB | 4 Bits | ≈8.1 |
92
+
93
+
94
+ # Options
95
+
96
+ ## Translation - Batch Size
97
+ - transformers: batch_size
98
+ When the pipeline will use DataLoader (when passing a dataset, on GPU for a Pytorch model), the size of the batch to use, for inference this is not always beneficial.
99
+ - ctranslate2: max_batch_size
100
+ The maximum batch size.
101
+
102
+ ## Translation - No Repeat Ngram Size
103
+ - transformers: no_repeat_ngram_size
104
+ Value that will be used by default in the generate method of the model for no_repeat_ngram_size. If set to int > 0, all ngrams of that size can only occur once.
105
+ - ctranslate2: no_repeat_ngram_size
106
+ Prevent repetitions of ngrams with this size (set 0 to disable).
107
+
108
+ ## Translation - Num Beams
109
+ - transformers: num_beams
110
+ Number of beams for beam search that will be used by default in the generate method of the model. 1 means no beam search.
111
+ - ctranslate2: beam_size
112
+ Beam size (1 for greedy search).
requirements-fasterWhisper.txt CHANGED
@@ -15,4 +15,9 @@ sentencepiece
15
  intervaltree
16
  srt
17
  torch
18
- https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
 
 
 
 
 
 
15
  intervaltree
16
  srt
17
  torch
18
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
19
+
20
+ # Needed by ALMA(GPTQ)
21
+ accelerate
22
+ auto-gptq
23
+ optimum
requirements-whisper.txt CHANGED
@@ -1,4 +1,5 @@
1
  git+https://github.com/huggingface/transformers
 
2
  git+https://github.com/openai/whisper.git
3
  ffmpeg-python==0.2.0
4
  gradio==3.50.2
@@ -13,4 +14,9 @@ sentencepiece
13
  intervaltree
14
  srt
15
  torch
16
- https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
 
 
 
 
 
 
1
  git+https://github.com/huggingface/transformers
2
+ ctranslate2>=3.21.0
3
  git+https://github.com/openai/whisper.git
4
  ffmpeg-python==0.2.0
5
  gradio==3.50.2
 
14
  intervaltree
15
  srt
16
  torch
17
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
18
+
19
+ # Needed by ALMA(GPTQ)
20
+ accelerate
21
+ auto-gptq
22
+ optimum
requirements.txt CHANGED
@@ -15,4 +15,9 @@ sentencepiece
15
  intervaltree
16
  srt
17
  torch
18
- https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
 
 
 
 
 
 
15
  intervaltree
16
  srt
17
  torch
18
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
19
+
20
+ # Needed by ALMA(GPTQ)
21
+ accelerate
22
+ auto-gptq
23
+ optimum
src/config.py CHANGED
@@ -5,7 +5,7 @@ from typing import List, Dict, Literal
5
 
6
 
7
  class ModelConfig:
8
- def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None):
9
  """
10
  Initialize a model configuration.
11
 
@@ -13,12 +13,17 @@ class ModelConfig:
13
  url: URL to download the model from
14
  path: Path to the model file. If not set, the model will be downloaded from the URL.
15
  type: Type of model. Can be whisper or huggingface.
 
 
 
 
16
  """
17
  self.name = name
18
  self.url = url
19
  self.path = path
20
  self.type = type
21
  self.tokenizer_url = tokenizer_url
 
22
 
23
  VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
24
 
 
5
 
6
 
7
  class ModelConfig:
8
+ def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None, revision: str = None):
9
  """
10
  Initialize a model configuration.
11
 
 
13
  url: URL to download the model from
14
  path: Path to the model file. If not set, the model will be downloaded from the URL.
15
  type: Type of model. Can be whisper or huggingface.
16
+ revision: [by transformers] The specific model version to use.
17
+ It can be a branch name, a tag name, or a commit id,
18
+ since we use a git-based system for storing models and other artifacts on huggingface.co,
19
+ so revision can be any identifier allowed by git.
20
  """
21
  self.name = name
22
  self.url = url
23
  self.path = path
24
  self.type = type
25
  self.tokenizer_url = tokenizer_url
26
+ self.revision = revision
27
 
28
  VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
29
 
src/translation/translationLangs.py CHANGED
@@ -1,4 +1,6 @@
1
- class Lang():
 
 
2
  def __init__(self, code: str, *names: str):
3
  self.code = code
4
  self.names = names
@@ -292,12 +294,30 @@ def get_lang_whisper_names():
292
  """Return a list of whisper language names."""
293
  return list(_TO_LANG_NAME_WHISPER.keys())
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  if __name__ == "__main__":
296
  # Test lookup
297
  print("name:Chinese (Traditional)", get_lang_from_nllb_name("Chinese (Traditional)"))
298
  print("name:moldavian", get_lang_from_m2m100_name("moldavian"))
299
  print("code:ja", get_lang_from_whisper_code("ja"))
300
  print("name:English", get_lang_from_nllb_name('English'))
 
301
 
302
  print(get_lang_m2m100_names(["en", "ja", "zh"]))
303
- print(get_lang_nllb_names())
 
 
1
+ from functools import cmp_to_key
2
+
3
+ class Lang():
4
  def __init__(self, code: str, *names: str):
5
  self.code = code
6
  self.names = names
 
294
  """Return a list of whisper language names."""
295
  return list(_TO_LANG_NAME_WHISPER.keys())
296
 
297
+ def sort_lang_by_whisper_codes(specified_order: list = []):
298
+ def sort_by_whisper_code(lang: TranslationLang, specified_order: list):
299
+ return (specified_order.index(lang.whisper.code), lang.whisper.names[0]) if lang.whisper.code in specified_order else (len(specified_order), lang.whisper.names[0])
300
+
301
+ def cmp_by_whisper_code(lang1: TranslationLang, lang2: TranslationLang):
302
+ val1 = sort_by_whisper_code(lang1, specified_order)
303
+ val2 = sort_by_whisper_code(lang2, specified_order)
304
+ if val1 > val2:
305
+ return 1
306
+ elif val1 == val2:
307
+ return 0
308
+ else: return -1
309
+
310
+ sorted_translations = sorted(_TO_LANG_NAME_WHISPER.values(), key=cmp_to_key(cmp_by_whisper_code))
311
+ return list({name.lower(): None for language in sorted_translations for name in language.whisper.names}.keys())
312
+
313
  if __name__ == "__main__":
314
  # Test lookup
315
  print("name:Chinese (Traditional)", get_lang_from_nllb_name("Chinese (Traditional)"))
316
  print("name:moldavian", get_lang_from_m2m100_name("moldavian"))
317
  print("code:ja", get_lang_from_whisper_code("ja"))
318
  print("name:English", get_lang_from_nllb_name('English'))
319
+ print("\n\n")
320
 
321
  print(get_lang_m2m100_names(["en", "ja", "zh"]))
322
+ print("\n\n")
323
+ print(sort_lang_by_whisper_codes(["en", "de", "cs", "is", "ru", "zh", "ja"]))
src/translation/translationModel.py CHANGED
@@ -3,11 +3,9 @@ import warnings
3
  import huggingface_hub
4
  import requests
5
  import torch
6
-
7
  import ctranslate2
8
  import transformers
9
-
10
- import re
11
 
12
  from typing import Optional
13
  from src.config import ModelConfig
@@ -85,84 +83,175 @@ class TranslationModel:
85
  self.load_model()
86
 
87
  def load_model(self):
88
- print('\n\nLoading model: %s\n\n' % self.modelPath)
89
- if "ct2" in self.modelPath:
90
- if "nllb" in self.modelPath:
91
- self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
92
- self.targetPrefix = [self.translationLang.nllb.code]
93
- elif "m2m100" in self.modelPath:
94
- self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
95
- self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
96
- self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
97
- elif "mt5" in self.modelPath:
98
- self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
99
- self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
100
- self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
101
- self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
102
- elif "ALMA" in self.modelPath:
103
- self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
104
- self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
105
- self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", trust_remote_code=False, revision="main")
106
- self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, batch_size=2, do_sample=True, temperature=0.7, top_p=0.95, top_k=40, repetition_penalty=1.1)
107
- else:
108
- self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
109
- self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
110
- if "m2m100" in self.modelPath:
111
- self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
112
- else: #NLLB
113
- self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  def release_vram(self):
116
  try:
117
  if torch.cuda.is_available():
118
  if "ct2" not in self.modelPath:
119
- device = torch.device("cpu")
120
- self.transModel.to(device)
 
 
 
 
 
 
121
  del self.transModel
122
- torch.cuda.empty_cache()
 
 
 
 
 
 
123
  print("release vram end.")
124
  except Exception as e:
125
  print("Error release vram: " + str(e))
126
 
127
 
128
  def translation(self, text: str, max_length: int = 400):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  output = None
130
  result = None
131
  try:
132
  if "ct2" in self.modelPath:
133
- source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
134
- output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
135
- target = output[0].hypotheses[0][1:]
136
- result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
 
 
 
 
 
 
137
  elif "mt5" in self.modelPath:
138
  output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
139
  result = output[0]['generated_text']
140
  elif "ALMA" in self.modelPath:
141
- output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.code + ":", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
142
  result = output[0]['generated_text']
143
- result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result) # Remove the prompt from the result
144
- result = re.sub(rf'^(Translate this from .* to .*:)', '', result) # Remove the translation instruction
145
- return result.strip()
146
  else: #M2M100 & NLLB
147
  output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
148
  result = output[0]['translation_text']
149
  except Exception as e:
 
150
  print("Error translation text: " + str(e))
151
 
152
  return result
153
 
154
 
155
- _MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
156
- "ct2fast-nllb-200-distilled-1.3B-int8_float16",
157
- "ct2fast-nllb-200-3.3B-int8_float16",
158
- "nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
159
- "nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
160
- "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
161
- "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
162
- "m2m100_1.2B", "m2m100_418M",
163
- "mt5-zh-ja-en-trimmed",
164
- "mt5-zh-ja-en-trimmed-fine-tuned-v1",
165
- "ALMA-13B-GPTQ"]
166
 
167
  def check_model_name(name):
168
  return any(allowed_name in name for allowed_name in _MODELS)
@@ -230,6 +319,9 @@ def download_model(
230
  "allow_patterns": allowPatterns,
231
  #"tqdm_class": disabled_tqdm,
232
  }
 
 
 
233
 
234
  if outputDir is not None:
235
  kwargs["local_dir"] = outputDir
 
3
  import huggingface_hub
4
  import requests
5
  import torch
 
6
  import ctranslate2
7
  import transformers
8
+ import traceback
 
9
 
10
  from typing import Optional
11
  from src.config import ModelConfig
 
83
  self.load_model()
84
 
85
  def load_model(self):
86
+ """
87
+ [from_pretrained]
88
+ low_cpu_mem_usage(bool, optional)
89
+ Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is an experimental feature and a subject to change at any moment.
90
+
91
+ [transformers.AutoTokenizer.from_pretrained]
92
+ use_fast (bool, optional, defaults to True):
93
+ Use a fast Rust-based tokenizer if it is supported for a given model.
94
+ If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
95
+
96
+ [transformers.AutoModelForCausalLM.from_pretrained]
97
+ device_map (str or Dict[str, Union[int, str, torch.device], optional):
98
+ Sent directly as model_kwargs (just a simpler shortcut). When accelerate library is present,
99
+ set device_map="auto" to compute the most optimized device_map automatically.
100
+ revision (str, optional, defaults to "main"):
101
+ The specific model version to use. It can be a branch name, a tag name, or a commit id,
102
+ since we use a git-based system for storing models and other artifacts on huggingface.co,
103
+ so revision can be any identifier allowed by git.
104
+ code_revision (str, optional, defaults to "main")
105
+ The specific revision to use for the code on the Hub, if the code leaves in a different repository than the rest of the model.
106
+ It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co,
107
+ so revision can be any identifier allowed by git.
108
+ trust_remote_code (bool, optional, defaults to False):
109
+ Whether or not to allow for custom models defined on the Hub in their own modeling files.
110
+ This option should only be set to True for repositories you trust and in which you have read the code,
111
+ as it will execute code present on the Hub on your local machine.
112
+
113
+ [transformers.pipeline "text-generation"]
114
+ do_sample:
115
+ if set to True, this parameter enables decoding strategies such as multinomial sampling,
116
+ beam-search multinomial sampling, Top-K sampling and Top-p sampling.
117
+ All these strategies select the next token from the probability distribution
118
+ over the entire vocabulary with various strategy-specific adjustments.
119
+ temperature (float, optional, defaults to 1.0):
120
+ The value used to modulate the next token probabilities.
121
+ top_k (int, optional, defaults to 50):
122
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
123
+ top_p (float, optional, defaults to 1.0):
124
+ If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
125
+ repetition_penalty (float, optional, defaults to 1.0)
126
+ The parameter for repetition penalty. 1.0 means no penalty. See this paper for more details.
127
+ """
128
+ try:
129
+ print('\n\nLoading model: %s\n\n' % self.modelPath)
130
+ if "ct2" in self.modelPath:
131
+ if any(name in self.modelPath for name in ["nllb", "m2m100"]):
132
+ if "nllb" in self.modelPath:
133
+ self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
134
+ self.targetPrefix = [self.translationLang.nllb.code]
135
+ elif "m2m100" in self.modelPath:
136
+ self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
137
+ self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
138
+ self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
139
+ elif "ALMA" in self.modelPath:
140
+ self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath)
141
+ self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
142
+ self.transModel = ctranslate2.Generator(self.modelPath, device=self.device)
143
+ elif "mt5" in self.modelPath:
144
+ self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
145
+ self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
146
+ self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath, low_cpu_mem_usage=True)
147
+ self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
148
+ elif "ALMA" in self.modelPath:
149
+ self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
150
+ self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
151
+ self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", low_cpu_mem_usage=True, trust_remote_code=False, revision=self.modelConfig.revision)
152
+ self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, do_sample=True, temperature=0.7, top_k=40, top_p=0.95, repetition_penalty=1.1)
153
+ else:
154
+ self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
155
+ self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
156
+ if "m2m100" in self.modelPath:
157
+ self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
158
+ else: #NLLB
159
+ self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)
160
+
161
+ except Exception as e:
162
+ print(traceback.format_exc())
163
+ self.release_vram()
164
 
165
  def release_vram(self):
166
  try:
167
  if torch.cuda.is_available():
168
  if "ct2" not in self.modelPath:
169
+ try:
170
+ device = torch.device("cpu")
171
+ self.transModel.to(device)
172
+ except Exception as e:
173
+ print(traceback.format_exc())
174
+ print("\tself.transModel.to cpu, error: " + str(e))
175
+ del self.transTranslator
176
+ del self.transTokenizer
177
  del self.transModel
178
+ try:
179
+ torch.cuda.empty_cache()
180
+ except Exception as e:
181
+ print(traceback.format_exc())
182
+ print("\tcuda empty cache, error: " + str(e))
183
+ import gc
184
+ gc.collect()
185
  print("release vram end.")
186
  except Exception as e:
187
  print("Error release vram: " + str(e))
188
 
189
 
190
  def translation(self, text: str, max_length: int = 400):
191
+ """
192
+ [ctranslate2]
193
+ max_batch_size:
194
+ The maximum batch size. If the number of inputs is greater than max_batch_size,
195
+ the inputs are sorted by length and split by chunks of max_batch_size examples
196
+ so that the number of padding positions is minimized.
197
+ no_repeat_ngram_size:
198
+ Prevent repetitions of ngrams with this size (set 0 to disable).
199
+ beam_size:
200
+ Beam size (1 for greedy search).
201
+
202
+ [ctranslate2.Generator.generate_batch]
203
+ sampling_temperature:
204
+ Sampling temperature to generate more random samples.
205
+ sampling_topk:
206
+ Randomly sample predictions from the top K candidates.
207
+ sampling_topp:
208
+ Keep the most probable tokens whose cumulative probability exceeds this value.
209
+ repetition_penalty:
210
+ Penalty applied to the score of previously generated tokens (set > 1 to penalize).
211
+ include_prompt_in_result:
212
+ Include the start_tokens in the result.
213
+ If include_prompt_in_result is True (the default), the decoding loop is constrained to generate the start tokens that are then included in the result.
214
+ If include_prompt_in_result is False, the start tokens are forwarded in the decoder at once to initialize its state (i.e. the KV cache for Transformer models).
215
+ For variable-length inputs, only the tokens up to the minimum length in the batch are forwarded at once. The remaining tokens are generated in the decoding loop with constrained decoding.
216
+
217
+ [transformers.TextGenerationPipeline.__call__]
218
+ return_full_text (bool, optional, defaults to True):
219
+ If set to False only added text is returned, otherwise the full text is returned. Only meaningful if return_text is set to True.
220
+ """
221
  output = None
222
  result = None
223
  try:
224
  if "ct2" in self.modelPath:
225
+ if any(name in self.modelPath for name in ["nllb", "m2m100"]):
226
+ source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
227
+ output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
228
+ target = output[0].hypotheses[0][1:]
229
+ result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
230
+ elif "ALMA" in self.modelPath:
231
+ source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.ALMAPrefix + text + "\n" + self.translationLang.whisper.names[0] + ": "))
232
+ output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
233
+ target = output[0]
234
+ result = self.transTokenizer.decode(target.sequences_ids[0])
235
  elif "mt5" in self.modelPath:
236
  output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
237
  result = output[0]['generated_text']
238
  elif "ALMA" in self.modelPath:
239
+ output = self.transTranslator(self.ALMAPrefix + text + "\n" + self.translationLang.whisper.names[0] + ": ", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams, return_full_text=False)
240
  result = output[0]['generated_text']
 
 
 
241
  else: #M2M100 & NLLB
242
  output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
243
  result = output[0]['translation_text']
244
  except Exception as e:
245
+ print(traceback.format_exc())
246
  print("Error translation text: " + str(e))
247
 
248
  return result
249
 
250
 
251
+ _MODELS = ["nllb-200",
252
+ "m2m100",
253
+ "mt5",
254
+ "ALMA"]
 
 
 
 
 
 
 
255
 
256
  def check_model_name(name):
257
  return any(allowed_name in name for allowed_name in _MODELS)
 
319
  "allow_patterns": allowPatterns,
320
  #"tqdm_class": disabled_tqdm,
321
  }
322
+
323
+ if modelConfig.revision is not None:
324
+ kwargs["revision"] = modelConfig.revision
325
 
326
  if outputDir is not None:
327
  kwargs["local_dir"] = outputDir
src/utils.py CHANGED
@@ -130,7 +130,7 @@ def write_srt_original(transcript: Iterator[dict], file: TextIO,
130
  flush=True,
131
  )
132
 
133
- if original is not None: print(f"{original}\n",
134
  file=file,
135
  flush=True)
136
 
 
130
  flush=True,
131
  )
132
 
133
+ if original is not None: print(f"{original}",
134
  file=file,
135
  flush=True)
136