import os
import warnings
import huggingface_hub
import requests
import torch

import ctranslate2
import transformers

import re

from typing import Optional
from src.config import ModelConfig
from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code


class TranslationModel:
    def __init__(
        self,
        modelConfig: ModelConfig,
        device: str = None,
        whisperLang: TranslationLang = None,
        translationLang: TranslationLang = None,
        batchSize: int = 2,
        noRepeatNgramSize: int = 3,
        numBeams: int = 2,
        downloadRoot: Optional[str] = None,
        localFilesOnly: bool = False,
        loadModel: bool = False,
    ):
        """Initializes the M2M100 / NLLB-200 / MT5 / ALMA translation model.

        Args:
          modelConfig: Config of the model to use (distilled-600M, distilled-1.3B,
            1.3B, 3.3B...) or a path to a converted model directory. When a size is
            configured, the converted model is downloaded from the Hugging Face Hub.
          device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
            ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia). If not set,
            "cuda" is selected automatically when available, otherwise "cpu".
          whisperLang: Source language of the text, as detected or configured for Whisper.
          translationLang: Target language to translate into. If None, no model is prepared
            and the instance only stores the language settings.
          batchSize: Batch size used when running the translator or pipeline.
          noRepeatNgramSize: Size of n-grams that must not repeat in the generated output.
          numBeams: Number of beams used for beam search decoding.
          downloadRoot: Directory where the models should be saved. If not set, the models
            are saved in the standard Hugging Face cache directory.
          localFilesOnly: If True, avoid downloading the file and return the path to the
            local cached file if it exists.
          loadModel: If True, load the model immediately; otherwise call load_model() later.
        """
        self.modelConfig = modelConfig
        self.whisperLang = whisperLang
        self.translationLang = translationLang

        if translationLang is None:
            return

        self.batchSize = batchSize
        self.noRepeatNgramSize = noRepeatNgramSize
        self.numBeams = numBeams

        if os.path.isdir(modelConfig.url):
            self.modelPath = modelConfig.url
        else:
            self.modelPath = download_model(
                modelConfig,
                localFilesOnly=localFilesOnly,
                cacheDir=downloadRoot,
            )

        if device is None:
            if torch.cuda.is_available():
                device = "cuda" if "ct2" in self.modelPath else "cuda:0"
            else:
                device = "cpu"

        self.device = device

        if loadModel:
            self.load_model()

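    # Note on backend selection (derived from the path checks in load_model below):
    # paths containing "ct2" are loaded with ctranslate2.Translator, "mt5" paths with an
    # MT5 text2text-generation pipeline, "ALMA" paths with a causal-LM text-generation
    # pipeline, and anything else with a transformers translation pipeline.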
    def load_model(self):
        print('\n\nLoading model: %s\n\n' % self.modelPath)
        if "ct2" in self.modelPath:
            # Use the configured tokenizer repository when one is set, otherwise the model path.
            tokenizerPath = (self.modelConfig.tokenizer_url
                             if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0
                             else self.modelPath)
            if "nllb" in self.modelPath:
                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(tokenizerPath, src_lang=self.whisperLang.nllb.code)
                self.targetPrefix = [self.translationLang.nllb.code]
            elif "m2m100" in self.modelPath:
                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(tokenizerPath, src_lang=self.whisperLang.m2m100.code)
                self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
            self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
        elif "mt5" in self.modelPath:
            self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
            self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False)
            self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
            self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel,
                                                         device=self.device, tokenizer=self.transTokenizer)
        elif "ALMA" in self.modelPath:
            self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
            self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto",
                                                                                trust_remote_code=False, revision="main")
            self.transTranslator = transformers.pipeline("text-generation", model=self.transModel,
                                                         tokenizer=self.transTokenizer, batch_size=2,
                                                         do_sample=True, temperature=0.7, top_p=0.95,
                                                         top_k=40, repetition_penalty=1.1)
        else:
            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
            self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
            if "m2m100" in self.modelPath:
                self.transTranslator = transformers.pipeline('translation', model=self.transModel,
                                                             device=self.device, tokenizer=self.transTokenizer,
                                                             src_lang=self.whisperLang.m2m100.code,
                                                             tgt_lang=self.translationLang.m2m100.code)
            else:
                self.transTranslator = transformers.pipeline('translation', model=self.transModel,
                                                             device=self.device, tokenizer=self.transTokenizer,
                                                             src_lang=self.whisperLang.nllb.code,
                                                             tgt_lang=self.translationLang.nllb.code)

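    # Only the transformers-backed models are moved to the CPU before the reference is
    # dropped (CTranslate2 translators on "ct2" paths are not torch modules), so that
    # torch.cuda.empty_cache() can reclaim the GPU memory they held.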
    def release_vram(self):
        try:
            if torch.cuda.is_available():
                if "ct2" not in self.modelPath:
                    device = torch.device("cpu")
                    self.transModel.to(device)
                del self.transModel
                torch.cuda.empty_cache()
            print("release vram end.")
        except Exception as e:
            print("Error releasing VRAM: " + str(e))

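    # translation() builds the request that matches the loaded backend: raw token batches
    # for CTranslate2, a "<src>2<tgt>: " prefixed prompt for MT5, an instruction-style
    # prompt for ALMA, and plain text for the transformers translation pipeline. On
    # failure it logs the error and returns None.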
    def translation(self, text: str, max_length: int = 400):
        output = None
        result = None
        try:
            if "ct2" in self.modelPath:
                source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
                output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix],
                                                         max_batch_size=self.batchSize,
                                                         no_repeat_ngram_size=self.noRepeatNgramSize,
                                                         beam_size=self.numBeams)
                target = output[0].hypotheses[0][1:]
                result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
            elif "mt5" in self.modelPath:
                output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize,
                                              no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
                result = output[0]['generated_text']
            elif "ALMA" in self.modelPath:
                output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.code + ":",
                                              max_length=max_length, batch_size=self.batchSize,
                                              no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
                result = output[0]['generated_text']
                # Strip the echoed prompt so only the generated translation remains.
                result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result)
                result = re.sub(rf'^(Translate this from .* to .*:)', '', result)
                return result.strip()
            else:
                output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize,
                                              no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
                result = output[0]['translation_text']
        except Exception as e:
            print("Error translating text: " + str(e))

        return result


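# Substrings of the model names accepted by check_model_name(); download_model() rejects
# any ModelConfig whose name does not contain one of these entries.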
_MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
           "ct2fast-nllb-200-distilled-1.3B-int8_float16",
           "ct2fast-nllb-200-3.3B-int8_float16",
           "nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
           "nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
           "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
           "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
           "m2m100_1.2B", "m2m100_418M",
           "mt5-zh-ja-en-trimmed",
           "mt5-zh-ja-en-trimmed-fine-tuned-v1",
           "ALMA-13B-GPTQ"]


def check_model_name(name):
    return any(allowed_name in name for allowed_name in _MODELS)


def download_model(
    modelConfig: ModelConfig,
    outputDir: Optional[str] = None,
    localFilesOnly: bool = False,
    cacheDir: Optional[str] = None,
):
    """Downloads a translation model from the Hugging Face Hub.

    "download_model" is adapted from the "utils.py" script of the "faster_whisper"
    project, authored by guillaumekln.

    The NLLB-200 and M2M100 models are downloaded from https://huggingface.co/facebook.

    Args:
      modelConfig: Config of the model to download (facebook/nllb-distilled-600M,
        facebook/nllb-distilled-1.3B, facebook/nllb-1.3B, facebook/nllb-3.3B...).
      outputDir: Directory where the model should be saved. If not set, the model is saved
        in the cache directory.
      localFilesOnly: If True, avoid downloading the file and return the path to the local
        cached file if it exists.
      cacheDir: Path to the folder where cached files are stored.

    Returns:
      The path to the downloaded model.

    Raises:
      ValueError: if the model name is invalid.
    """
    if not check_model_name(modelConfig.name):
        raise ValueError(
            "Invalid model name '%s', expected one of: %s" % (modelConfig.name, ", ".join(_MODELS))
        )

    repoId = modelConfig.url

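    # Only fetch the files needed to rebuild the model and tokenizer; anything else in the
    # repository (e.g. TensorFlow or Flax weights, READMEs) is excluded by the
    # allow_patterns filter passed to snapshot_download below.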
    allowPatterns = [
        "config.json",
        "generation_config.json",
        "model.bin",
        "pytorch_model.bin",
        "pytorch_model.bin.index.json",
        "pytorch_model-*.bin",
        "pytorch_model-00001-of-00003.bin",
        "pytorch_model-00002-of-00003.bin",
        "pytorch_model-00003-of-00003.bin",
        "sentencepiece.bpe.model",
        "tokenizer.json",
        "tokenizer_config.json",
        "shared_vocabulary.txt",
        "shared_vocabulary.json",
        "special_tokens_map.json",
        "spiece.model",
        "vocab.json",
        "model.safetensors",
        "quantize_config.json",
        "tokenizer.model",
    ]

    kwargs = {
        "local_files_only": localFilesOnly,
        "allow_patterns": allowPatterns,
    }

    if outputDir is not None:
        kwargs["local_dir"] = outputDir
        kwargs["local_dir_use_symlinks"] = False

    if cacheDir is not None:
        kwargs["cache_dir"] = cacheDir

    try:
        return huggingface_hub.snapshot_download(repoId, **kwargs)
    except (
        huggingface_hub.utils.HfHubHTTPError,
        requests.exceptions.ConnectionError,
    ) as exception:
        warnings.warn(
            "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
            % (repoId, exception)
        )
        warnings.warn(
            "Trying to load the model directly from the local cache, if it exists."
        )

        kwargs["local_files_only"] = True
        return huggingface_hub.snapshot_download(repoId, **kwargs)
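

# --------------------------------------------------------------------------------------
# Usage sketch (illustrative only). It assumes ModelConfig can be constructed with the
# `name`, `url` and `tokenizer_url` fields that this module reads from it, and that
# get_lang_from_whisper_code() maps Whisper language codes ("ja", "en", ...) to
# TranslationLang entries; both assumptions come from how they are used above, not from
# their actual definitions in src.config / src.translation.translationLangs.
if __name__ == "__main__":
    exampleConfig = ModelConfig(                  # hypothetical constructor arguments
        name="m2m100_418M",                       # must contain one of the _MODELS entries
        url="facebook/m2m100_418M",               # Hugging Face repository or local directory
        tokenizer_url="",                         # only used for the CTranslate2 ("ct2") models
    )
    translator = TranslationModel(
        modelConfig=exampleConfig,
        whisperLang=get_lang_from_whisper_code("ja"),
        translationLang=get_lang_from_whisper_code("en"),
        loadModel=True,
    )
    print(translator.translation("こんにちは、世界"))
    translator.release_vram()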