SoybeanMilk commited on
Commit
4e2f72e
·
1 Parent(s): 2def7a1

Add madlad400 support.

Browse files
Files changed (4) hide show
  1. app.py +19 -0
  2. config.json5 +14 -0
  3. src/config.py +2 -2
  4. src/translation/translationModel.py +11 -1
app.py CHANGED
@@ -233,6 +233,8 @@ class WhisperTranscriber:
233
  mt5LangName: str = decodeOptions.pop("mt5LangName")
234
  ALMAModelName: str = decodeOptions.pop("ALMAModelName")
235
  ALMALangName: str = decodeOptions.pop("ALMALangName")
 
 
236
 
237
  translationBatchSize: int = decodeOptions.pop("translationBatchSize")
238
  translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
@@ -368,6 +370,10 @@ class WhisperTranscriber:
368
  selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
369
  selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
370
  translationLang = get_lang_from_m2m100_name(ALMALangName)
 
 
 
 
371
 
372
  if translationLang is not None:
373
  translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
@@ -929,6 +935,7 @@ def create_ui(app_config: ApplicationConfig):
929
  m2m100_models = app_config.get_model_names("m2m100")
930
  mt5_models = app_config.get_model_names("mt5")
931
  ALMA_models = app_config.get_model_names("ALMA")
 
932
  if not torch.cuda.is_available(): #Due to the poor support of GPTQ for CPUs, the execution time per iteration exceeds a thousand seconds when operating on a CPU. Therefore, when the system does not support a GPU, the GPTQ model is removed from the list.
933
  ALMA_models = list(filter(lambda alma: "GPTQ" not in alma, ALMA_models))
934
 
@@ -952,6 +959,10 @@ def create_ui(app_config: ApplicationConfig):
952
  gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
953
  gr.Dropdown(label="ALMA - Language", choices=sort_lang_by_whisper_codes(["en", "de", "cs", "is", "ru", "zh", "ja"]), elem_id="ALMALangName"),
954
  }
 
 
 
 
955
 
956
  common_translation_inputs = lambda : {
957
  gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
@@ -1036,10 +1047,14 @@ def create_ui(app_config: ApplicationConfig):
1036
  with gr.Tab(label="ALMA") as simpleALMATab:
1037
  with gr.Row():
1038
  simpleInputDict.update(common_ALMA_inputs())
 
 
 
1039
  simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
1040
  simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
1041
  simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
1042
  simpleALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [simpleTranslateInput] )
 
1043
  with gr.Column():
1044
  with gr.Tab(label="URL") as simpleUrlTab:
1045
  simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
@@ -1103,10 +1118,14 @@ def create_ui(app_config: ApplicationConfig):
1103
  with gr.Tab(label="ALMA") as fullALMATab:
1104
  with gr.Row():
1105
  fullInputDict.update(common_ALMA_inputs())
 
 
 
1106
  fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
1107
  fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
1108
  fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
1109
  fullALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [fullTranslateInput] )
 
1110
  with gr.Column():
1111
  with gr.Tab(label="URL") as fullUrlTab:
1112
  fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
 
233
  mt5LangName: str = decodeOptions.pop("mt5LangName")
234
  ALMAModelName: str = decodeOptions.pop("ALMAModelName")
235
  ALMALangName: str = decodeOptions.pop("ALMALangName")
236
+ madlad400ModelName: str = decodeOptions.pop("madlad400ModelName")
237
+ madlad400LangName: str = decodeOptions.pop("madlad400LangName")
238
 
239
  translationBatchSize: int = decodeOptions.pop("translationBatchSize")
240
  translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
 
370
  selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
371
  selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
372
  translationLang = get_lang_from_m2m100_name(ALMALangName)
373
+ elif translateInput == "madlad400" and madlad400LangName is not None and len(madlad400LangName) > 0:
374
+ selectedModelName = madlad400ModelName if madlad400ModelName is not None and len(madlad400ModelName) > 0 else "madlad400-10b-mt-ct2-int8_float16"
375
+ selectedModel = next((modelConfig for modelConfig in self.app_config.models["madlad400"] if modelConfig.name == selectedModelName), None)
376
+ translationLang = get_lang_from_m2m100_name(madlad400LangName)
377
 
378
  if translationLang is not None:
379
  translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
 
935
  m2m100_models = app_config.get_model_names("m2m100")
936
  mt5_models = app_config.get_model_names("mt5")
937
  ALMA_models = app_config.get_model_names("ALMA")
938
+ madlad400_models = app_config.get_model_names("madlad400")
939
  if not torch.cuda.is_available(): #Due to the poor support of GPTQ for CPUs, the execution time per iteration exceeds a thousand seconds when operating on a CPU. Therefore, when the system does not support a GPU, the GPTQ model is removed from the list.
940
  ALMA_models = list(filter(lambda alma: "GPTQ" not in alma, ALMA_models))
941
 
 
959
  gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
960
  gr.Dropdown(label="ALMA - Language", choices=sort_lang_by_whisper_codes(["en", "de", "cs", "is", "ru", "zh", "ja"]), elem_id="ALMALangName"),
961
  }
962
+ common_madlad400_inputs = lambda : {
963
+ gr.Dropdown(label="madlad400 - Model (for translate)", choices=madlad400_models, elem_id="madlad400ModelName"),
964
+ gr.Dropdown(label="madlad400 - Language", choices=sorted(get_lang_m2m100_names()), elem_id="madlad400LangName"),
965
+ }
966
 
967
  common_translation_inputs = lambda : {
968
  gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
 
1047
  with gr.Tab(label="ALMA") as simpleALMATab:
1048
  with gr.Row():
1049
  simpleInputDict.update(common_ALMA_inputs())
1050
+ with gr.Tab(label="madlad400") as simplemadlad400Tab:
1051
+ with gr.Row():
1052
+ simpleInputDict.update(common_madlad400_inputs())
1053
  simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
1054
  simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
1055
  simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
1056
  simpleALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [simpleTranslateInput] )
1057
+ simplemadlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [simpleTranslateInput] )
1058
  with gr.Column():
1059
  with gr.Tab(label="URL") as simpleUrlTab:
1060
  simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
 
1118
  with gr.Tab(label="ALMA") as fullALMATab:
1119
  with gr.Row():
1120
  fullInputDict.update(common_ALMA_inputs())
1121
+ with gr.Tab(label="madlad400") as fullmadlad400Tab:
1122
+ with gr.Row():
1123
+ fullInputDict.update(common_madlad400_inputs())
1124
  fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
1125
  fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
1126
  fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
1127
  fullALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [fullTranslateInput] )
1128
+ fullmadlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [fullTranslateInput] )
1129
  with gr.Column():
1130
  with gr.Tab(label="URL") as fullUrlTab:
1131
  fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
config.json5 CHANGED
@@ -229,6 +229,20 @@
229
  "type": "huggingface",
230
  "tokenizer_url": "haoranxu/ALMA-13B"
231
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  ]
233
  },
234
  // Configuration options that will be used if they are not specified in the command line arguments.
 
229
  "type": "huggingface",
230
  "tokenizer_url": "haoranxu/ALMA-13B"
231
  },
232
+ ],
233
+ "madlad400": [
234
+ {
235
+ "name": "madlad400-3b-mt-ct2-int8_float16/SoybeanMilk",
236
+ "url": "SoybeanMilk/madlad400-3b-mt-ct2-int8_float16",
237
+ "type": "huggingface",
238
+ "tokenizer_url": "jbochi/madlad400-3b-mt"
239
+ },
240
+ {
241
+ "name": "madlad400-10b-mt-ct2-int8_float16/SoybeanMilk",
242
+ "url": "SoybeanMilk/madlad400-10b-mt-ct2-int8_float16",
243
+ "type": "huggingface",
244
+ "tokenizer_url": "jbochi/madlad400-10b-mt"
245
+ },
246
  ]
247
  },
248
  // Configuration options that will be used if they are not specified in the command line arguments.
src/config.py CHANGED
@@ -50,7 +50,7 @@ class VadInitialPromptMode(Enum):
50
  return None
51
 
52
  class ApplicationConfig:
53
- def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]],
54
  input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
55
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
56
  whisper_implementation: str = "whisper", default_model_name: str = "medium",
@@ -181,7 +181,7 @@ class ApplicationConfig:
181
  # Load using json5
182
  data = json5.load(f)
183
  data_models = data.pop("models", [])
184
- models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]] = {
185
  key: [ModelConfig(**item) for item in value]
186
  for key, value in data_models.items()
187
  }
 
50
  return None
51
 
52
  class ApplicationConfig:
53
+ def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400"], List[ModelConfig]],
54
  input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
55
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
56
  whisper_implementation: str = "whisper", default_model_name: str = "medium",
 
181
  # Load using json5
182
  data = json5.load(f)
183
  data_models = data.pop("models", [])
184
+ models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400"], List[ModelConfig]] = {
185
  key: [ModelConfig(**item) for item in value]
186
  for key, value in data_models.items()
187
  }
src/translation/translationModel.py CHANGED
@@ -159,6 +159,10 @@ class TranslationModel:
159
  self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath)
160
  self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
161
  self.transModel = ctranslate2.Generator(self.modelPath, compute_type="auto", device=self.device)
 
 
 
 
162
  elif "mt5" in self.modelPath:
163
  self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
164
  self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
@@ -277,6 +281,11 @@ class TranslationModel:
277
  output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
278
  target = output[0]
279
  result = self.transTokenizer.decode(target.sequences_ids[0])
 
 
 
 
 
280
  elif "mt5" in self.modelPath:
281
  output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
282
  result = output[0]['generated_text']
@@ -299,7 +308,8 @@ class TranslationModel:
299
  _MODELS = ["nllb-200",
300
  "m2m100",
301
  "mt5",
302
- "ALMA"]
 
303
 
304
  def check_model_name(name):
305
  return any(allowed_name in name for allowed_name in _MODELS)
 
159
  self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath)
160
  self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
161
  self.transModel = ctranslate2.Generator(self.modelPath, compute_type="auto", device=self.device)
162
+ elif "madlad400" in self.modelPath:
163
+ self.madlad400Prefix = "<2" + self.translationLang.whisper.code + "> "
164
+ self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
165
+ self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
166
  elif "mt5" in self.modelPath:
167
  self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
168
  self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
 
281
  output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
282
  target = output[0]
283
  result = self.transTokenizer.decode(target.sequences_ids[0])
284
+ elif "madlad400" in self.modelPath:
285
+ source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.madlad400Prefix + text))
286
+ output = self.transModel.translate_batch([source], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
287
+ target = output[0].hypotheses[0]
288
+ result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
289
  elif "mt5" in self.modelPath:
290
  output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
291
  result = output[0]['generated_text']
 
308
  _MODELS = ["nllb-200",
309
  "m2m100",
310
  "mt5",
311
+ "ALMA",
312
+ "madlad400"]
313
 
314
  def check_model_name(name):
315
  return any(allowed_name in name for allowed_name in _MODELS)