SoybeanMilk commited on
Commit
e8762f9
·
1 Parent(s): 50167d4

Add support for the ALMA model.

Browse files
Files changed (5) hide show
  1. app.py +19 -0
  2. config.json5 +7 -0
  3. src/config.py +2 -2
  4. src/translation/translationModel.py +18 -1
  5. src/utils.py +1 -1
app.py CHANGED
@@ -231,6 +231,8 @@ class WhisperTranscriber:
231
  nllbLangName: str = decodeOptions.pop("nllbLangName")
232
  mt5ModelName: str = decodeOptions.pop("mt5ModelName")
233
  mt5LangName: str = decodeOptions.pop("mt5LangName")
 
 
234
 
235
  translationBatchSize: int = decodeOptions.pop("translationBatchSize")
236
  translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
@@ -337,6 +339,10 @@ class WhisperTranscriber:
337
  selectedModelName = mt5ModelName if mt5ModelName is not None and len(mt5ModelName) > 0 else "mt5-zh-ja-en-trimmed/K024"
338
  selectedModel = next((modelConfig for modelConfig in self.app_config.models["mt5"] if modelConfig.name == selectedModelName), None)
339
  translationLang = get_lang_from_m2m100_name(mt5LangName)
 
 
 
 
340
 
341
  if translationLang is not None:
342
  translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
@@ -828,6 +834,7 @@ def create_ui(app_config: ApplicationConfig):
828
  nllb_models = app_config.get_model_names("nllb")
829
  m2m100_models = app_config.get_model_names("m2m100")
830
  mt5_models = app_config.get_model_names("mt5")
 
831
 
832
  common_whisper_inputs = lambda : {
833
  gr.Dropdown(label="Whisper - Model (for audio)", choices=whisper_models, value=app_config.default_model_name, elem_id="whisperModelName"),
@@ -845,6 +852,10 @@ def create_ui(app_config: ApplicationConfig):
845
  gr.Dropdown(label="MT5 - Model (for translate)", choices=mt5_models, elem_id="mt5ModelName"),
846
  gr.Dropdown(label="MT5 - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="mt5LangName"),
847
  }
 
 
 
 
848
 
849
  common_translation_inputs = lambda : {
850
  gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
@@ -905,9 +916,13 @@ def create_ui(app_config: ApplicationConfig):
905
  with gr.Tab(label="MT5") as simpleMT5Tab:
906
  with gr.Row():
907
  simpleInputDict.update(common_mt5_inputs())
 
 
 
908
  simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
909
  simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
910
  simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
 
911
  with gr.Column():
912
  with gr.Tab(label="URL") as simpleUrlTab:
913
  simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
@@ -964,9 +979,13 @@ def create_ui(app_config: ApplicationConfig):
964
  with gr.Tab(label="MT5") as fullMT5Tab:
965
  with gr.Row():
966
  fullInputDict.update(common_mt5_inputs())
 
 
 
967
  fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
968
  fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
969
  fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
 
970
  with gr.Column():
971
  with gr.Tab(label="URL") as fullUrlTab:
972
  fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
 
231
  nllbLangName: str = decodeOptions.pop("nllbLangName")
232
  mt5ModelName: str = decodeOptions.pop("mt5ModelName")
233
  mt5LangName: str = decodeOptions.pop("mt5LangName")
234
+ ALMAModelName: str = decodeOptions.pop("ALMAModelName")
235
+ ALMALangName: str = decodeOptions.pop("ALMALangName")
236
 
237
  translationBatchSize: int = decodeOptions.pop("translationBatchSize")
238
  translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
 
339
  selectedModelName = mt5ModelName if mt5ModelName is not None and len(mt5ModelName) > 0 else "mt5-zh-ja-en-trimmed/K024"
340
  selectedModel = next((modelConfig for modelConfig in self.app_config.models["mt5"] if modelConfig.name == selectedModelName), None)
341
  translationLang = get_lang_from_m2m100_name(mt5LangName)
342
+ elif translateInput == "ALMA" and ALMALangName is not None and len(ALMALangName) > 0:
343
+ selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
344
+ selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
345
+ translationLang = get_lang_from_m2m100_name(ALMALangName)
346
 
347
  if translationLang is not None:
348
  translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
 
834
  nllb_models = app_config.get_model_names("nllb")
835
  m2m100_models = app_config.get_model_names("m2m100")
836
  mt5_models = app_config.get_model_names("mt5")
837
+ ALMA_models = app_config.get_model_names("ALMA")
838
 
839
  common_whisper_inputs = lambda : {
840
  gr.Dropdown(label="Whisper - Model (for audio)", choices=whisper_models, value=app_config.default_model_name, elem_id="whisperModelName"),
 
852
  gr.Dropdown(label="MT5 - Model (for translate)", choices=mt5_models, elem_id="mt5ModelName"),
853
  gr.Dropdown(label="MT5 - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="mt5LangName"),
854
  }
855
+ common_ALMA_inputs = lambda : {
856
+ gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
857
+ gr.Dropdown(label="ALMA - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="ALMALangName"),
858
+ }
859
 
860
  common_translation_inputs = lambda : {
861
  gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
 
916
  with gr.Tab(label="MT5") as simpleMT5Tab:
917
  with gr.Row():
918
  simpleInputDict.update(common_mt5_inputs())
919
+ with gr.Tab(label="ALMA") as simpleALMATab:
920
+ with gr.Row():
921
+ simpleInputDict.update(common_ALMA_inputs())
922
  simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
923
  simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
924
  simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
925
+ simpleALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [simpleTranslateInput] )
926
  with gr.Column():
927
  with gr.Tab(label="URL") as simpleUrlTab:
928
  simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
 
979
  with gr.Tab(label="MT5") as fullMT5Tab:
980
  with gr.Row():
981
  fullInputDict.update(common_mt5_inputs())
982
+ with gr.Tab(label="ALMA") as fullALMATab:
983
+ with gr.Row():
984
+ fullInputDict.update(common_ALMA_inputs())
985
  fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
986
  fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
987
  fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
988
+ fullALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [fullTranslateInput] )
989
  with gr.Column():
990
  with gr.Tab(label="URL") as fullUrlTab:
991
  fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
config.json5 CHANGED
@@ -191,6 +191,13 @@
191
  "url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
192
  "type": "huggingface"
193
  }
 
 
 
 
 
 
 
194
  ]
195
  },
196
  // Configuration options that will be used if they are not specified in the command line arguments.
 
191
  "url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
192
  "type": "huggingface"
193
  }
194
+ ],
195
+ "ALMA": [
196
+ {
197
+ "name": "ALMA-13B-GPTQ/TheBloke",
198
+ "url": "TheBloke/ALMA-13B-GPTQ",
199
+ "type": "huggingface",
200
+ },
201
  ]
202
  },
203
  // Configuration options that will be used if they are not specified in the command line arguments.
src/config.py CHANGED
@@ -43,7 +43,7 @@ class VadInitialPromptMode(Enum):
43
  return None
44
 
45
  class ApplicationConfig:
46
- def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5"], List[ModelConfig]],
47
  input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
48
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
49
  whisper_implementation: str = "whisper", default_model_name: str = "medium",
@@ -169,7 +169,7 @@ class ApplicationConfig:
169
  # Load using json5
170
  data = json5.load(f)
171
  data_models = data.pop("models", [])
172
- models: Dict[Literal["whisper", "m2m100", "nllb", "mt5"], List[ModelConfig]] = {
173
  key: [ModelConfig(**item) for item in value]
174
  for key, value in data_models.items()
175
  }
 
43
  return None
44
 
45
  class ApplicationConfig:
46
+ def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]],
47
  input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
48
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
49
  whisper_implementation: str = "whisper", default_model_name: str = "medium",
 
169
  # Load using json5
170
  data = json5.load(f)
171
  data_models = data.pop("models", [])
172
+ models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]] = {
173
  key: [ModelConfig(**item) for item in value]
174
  for key, value in data_models.items()
175
  }
src/translation/translationModel.py CHANGED
@@ -7,6 +7,8 @@ import torch
7
  import ctranslate2
8
  import transformers
9
 
 
 
10
  from typing import Optional
11
  from src.config import ModelConfig
12
  from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
@@ -97,6 +99,11 @@ class TranslationModel:
97
  self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
98
  self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
99
  self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
 
 
 
 
 
100
  else:
101
  self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
102
  self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
@@ -130,6 +137,12 @@ class TranslationModel:
130
  elif "mt5" in self.modelPath:
131
  output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
132
  result = output[0]['generated_text']
 
 
 
 
 
 
133
  else: #M2M100 & NLLB
134
  output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
135
  result = output[0]['translation_text']
@@ -148,7 +161,8 @@ _MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
148
  "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
149
  "m2m100_1.2B", "m2m100_418M",
150
  "mt5-zh-ja-en-trimmed",
151
- "mt5-zh-ja-en-trimmed-fine-tuned-v1"]
 
152
 
153
  def check_model_name(name):
154
  return any(allowed_name in name for allowed_name in _MODELS)
@@ -206,6 +220,9 @@ def download_model(
206
  "special_tokens_map.json",
207
  "spiece.model",
208
  "vocab.json", #m2m100
 
 
 
209
  ]
210
 
211
  kwargs = {
 
7
  import ctranslate2
8
  import transformers
9
 
10
+ import re
11
+
12
  from typing import Optional
13
  from src.config import ModelConfig
14
  from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
 
99
  self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
100
  self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
101
  self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
102
+ elif "ALMA" in self.modelPath:
103
+ self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
104
+ self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
105
+ self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", trust_remote_code=False, revision="main")
106
+ self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, batch_size=2, do_sample=True, temperature=0.7, top_p=0.95, top_k=40, repetition_penalty=1.1)
107
  else:
108
  self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
109
  self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
 
137
  elif "mt5" in self.modelPath:
138
  output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
139
  result = output[0]['generated_text']
140
+ elif "ALMA" in self.modelPath:
141
+ output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.code + ":", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
142
+ result = output[0]['generated_text']
143
+ result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result) # Remove the prompt from the result
144
+ result = re.sub(rf'^(Translate this from .* to .*:)', '', result) # Remove the translation instruction
145
+ return result.strip()
146
  else: #M2M100 & NLLB
147
  output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
148
  result = output[0]['translation_text']
 
161
  "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
162
  "m2m100_1.2B", "m2m100_418M",
163
  "mt5-zh-ja-en-trimmed",
164
+ "mt5-zh-ja-en-trimmed-fine-tuned-v1",
165
+ "ALMA-13B-GPTQ"]
166
 
167
  def check_model_name(name):
168
  return any(allowed_name in name for allowed_name in _MODELS)
 
220
  "special_tokens_map.json",
221
  "spiece.model",
222
  "vocab.json", #m2m100
223
+ "model.safetensors",
224
+ "quantize_config.json",
225
+ "tokenizer.model"
226
  ]
227
 
228
  kwargs = {
src/utils.py CHANGED
@@ -130,7 +130,7 @@ def write_srt_original(transcript: Iterator[dict], file: TextIO,
130
  flush=True,
131
  )
132
 
133
- if original is not None: print(f"{original}",
134
  file=file,
135
  flush=True)
136
 
 
130
  flush=True,
131
  )
132
 
133
+ if original is not None: print(f"{original}\n",
134
  file=file,
135
  flush=True)
136