File size: 9,089 Bytes
b6ac700 1e744c4 b6ac700 1e744c4 b6ac700 1e744c4 b6ac700 1e744c4 b6ac700 1e744c4 b6ac700 1e744c4 b6ac700 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
# External programs
import abc
import os
import sys
from typing import List
from urllib.parse import urlparse
import torch
import urllib3
from src.hooks.progressListener import ProgressListener
import whisper
from whisper import Whisper
from src.config import ModelConfig, VadInitialPromptMode
from src.hooks.whisperProgressHook import create_progress_listener_handle
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
from src.utils import download_file
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
class WhisperContainer(AbstractWhisperContainer):
    """
    Model container for the reference OpenAI Whisper implementation (the ``whisper`` package).

    The model is looked up in the configured ``models`` list by name. Names that are not configured are passed
    straight to ``whisper.load_model``.
    """

    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                 download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = None):
        """
        Parameters
        ----------
        model_name: str
            Name of the model to load. Matched against ModelConfig.name in ``models``.
        device: str
            Torch device for the model. Defaults to "cuda" when available, otherwise "cpu".
        compute_type: str
            Compute type. WhisperCallback enables FP16 decoding for "fp16"/"float16".
        download_root: str
            Directory for downloaded/converted models. None means ~/.cache/whisper.
        cache: ModelCache
            Optional model cache, passed through to AbstractWhisperContainer.
        models: List[ModelConfig]
            Configured models. Defaults to a new empty list per container.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Avoid a shared mutable default: every container gets its own list
        if models is None:
            models = []
        super().__init__(model_name, device, compute_type, download_root, cache, models)

    def ensure_downloaded(self):
        """
        Ensure that the model is downloaded (or converted) before it is needed. This is useful if you want the
        download to happen before passing the container to a subprocess.

        Returns
        -------
        True if no error occurred, False otherwise. Errors are printed rather than raised.
        """
        # Warning: Using private API here (whisper._MODELS / whisper._download)
        try:
            root_dir = self.download_root
            if root_dir is None:
                root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

            if self.model_name in whisper._MODELS:
                whisper._download(whisper._MODELS[self.model_name], root_dir, False)
            else:
                # If the model is not in the official list, see if it needs to be downloaded
                model_config = self._get_model_config()
                if model_config is None:
                    print("Error pre-downloading model: " + self.model_name +
                          " is neither an official Whisper model nor a configured model")
                    return False
                # Same download/convert routine that _create_model relies on; the result is cached in
                # model_config.path, so a later _create_model call does not repeat the work
                self._get_model_path(model_config, root_dir)
            return True
        except Exception as e:
            # Given that the API is private, it could change at any time. We don't want to crash the program
            print("Error pre-downloading model: " + str(e))
            return False

    def _get_model_config(self) -> ModelConfig:
        """
        Get the configuration whose name equals self.model_name, or None if there is no such entry.
        """
        return next((model for model in self.models if model.name == self.model_name), None)

    def _create_model(self):
        """
        Load the Whisper model onto self.device.
        """
        print("Loading whisper model " + self.model_name)
        model_config = self._get_model_config()

        if model_config is None:
            # Not configured: let whisper.load_model resolve the name itself
            # (an official model name or a checkpoint path)
            model_path = self.model_name
        else:
            # Note that the model will not be downloaded in the case of an official Whisper model
            model_path = self._get_model_path(model_config, self.download_root)
        return whisper.load_model(model_path, device=self.device, download_root=self.download_root)

    def create_callback(self, languageCode: str = None, task: str = None,
                        prompt_strategy: AbstractPromptStrategy = None,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        languageCode: str
            The target language code of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        prompt_strategy: AbstractPromptStrategy
            The prompt strategy to use. If not specified, the prompt from Whisper will be used.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        return WhisperCallback(self, languageCode=languageCode, task=task, prompt_strategy=prompt_strategy, **decodeOptions)

    def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
        """
        Resolve the value to pass to whisper.load_model, downloading or converting the model if needed.

        The result is stored in model_config.path, only after the download/conversion succeeded, and is
        returned unchanged on subsequent calls.

        Parameters
        ----------
        model_config: ModelConfig
            The model configuration.
        root_dir: str
            Directory for downloaded/converted models. Defaults to ~/.cache/whisper.

        Returns
        -------
        A local file path, or an official Whisper model name (which whisper.load_model handles itself).

        Raises
        ------
        ValueError: If the model type is unknown, or the config has neither a path nor a URL.
        RuntimeError: If the download target exists but is not a regular file.
        """
        # See if path is already set
        if model_config.path is not None:
            return model_config.path

        url = model_config.url
        if not url:
            raise ValueError(f"Model {model_config.name} has neither a path nor a URL")

        if root_dir is None:
            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

        model_type = model_config.type.lower() if model_config.type is not None else "whisper"

        if model_type in ["huggingface", "hf"]:
            destination_target = os.path.join(root_dir, model_config.name + ".pt")

            # Convert from HuggingFace format to Whisper format
            if os.path.exists(destination_target):
                print(f"File {destination_target} already exists, skipping conversion")
            else:
                # Imported lazily: the converter is only needed for HuggingFace models
                from src.conversion.hf_converter import convert_hf_whisper

                print("Saving HuggingFace model in Whisper format to " + destination_target)
                os.makedirs(root_dir, exist_ok=True)
                convert_hf_whisper(url, destination_target)
            model_config.path = destination_target

        elif model_type in ["whisper", "w"]:
            if url in whisper._MODELS:
                # No need to download anything - Whisper will handle it
                model_config.path = url
            elif url.startswith("file://"):
                # Decode percent-escapes and convert to a native path (e.g. drive letters on Windows)
                from urllib.request import url2pathname

                model_config.path = url2pathname(urlparse(url).path)
            elif url.startswith(("http://", "https://")):
                # Take the extension from the URL path only, so a query string or fragment
                # does not end up in the file name
                extension = os.path.splitext(urlparse(url).path)[-1]
                download_target = os.path.join(root_dir, model_config.name + extension)

                if os.path.exists(download_target) and not os.path.isfile(download_target):
                    raise RuntimeError(f"{download_target} exists and is not a regular file")

                if not os.path.isfile(download_target):
                    os.makedirs(root_dir, exist_ok=True)
                    download_file(url, download_target)
                else:
                    print(f"File {download_target} already exists, skipping download")
                model_config.path = download_target
            else:
                # Must be a local file
                model_config.path = url
        else:
            raise ValueError(f"Unknown model type {model_type}")

        return model_config.path
class WhisperCallback(AbstractWhisperCallback):
    """
    Callback that transcribes audio segments with the model held by a WhisperContainer.
    """

    def __init__(self, model_container: WhisperContainer, languageCode: str = None, task: str = None,
                 prompt_strategy: AbstractPromptStrategy = None,
                 **decodeOptions: dict):
        self.model_container = model_container
        self.languageCode = languageCode
        self.task = task
        self.prompt_strategy = prompt_strategy
        self.decodeOptions = decodeOptions

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            Index of the segment being transcribed; passed to the prompt strategy.
        prompt: str
            Prompt for this segment. Used directly as the initial prompt when there is no prompt strategy,
            otherwise handed to the strategy.
        detected_language: str
            Language detected for this segment; used when no languageCode was given.
        progress_listener: ProgressListener
            A callback to receive progress updates.

        Returns
        -------
        The result of the model's transcribe call.
        """
        model = self.model_container.get_model()

        if progress_listener is not None:
            with create_progress_listener_handle(progress_listener):
                return self._transcribe(model, audio, segment_index, prompt, detected_language)
        else:
            return self._transcribe(model, audio, segment_index, prompt, detected_language)

    def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
        # Copy so the per-call fp16 setting does not leak into self.decodeOptions
        decodeOptions = self.decodeOptions.copy()

        # Map the container's compute type onto Whisper's fp16 option. whisper's transcribe defaults to
        # fp16=True, so an explicit float32 request must turn it off; other compute types keep the default.
        compute_type = self.model_container.compute_type
        if compute_type in ["fp16", "float16"]:
            decodeOptions["fp16"] = True
        elif compute_type in ["fp32", "float32"]:
            decodeOptions["fp16"] = False

        if self.prompt_strategy:
            initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language)
        else:
            initial_prompt = prompt

        result = model.transcribe(
            audio,
            language=self.languageCode or detected_language,
            task=self.task,
            initial_prompt=initial_prompt,
            **decodeOptions
        )

        # If we have a prompt strategy, we need to increment the current prompt
        if self.prompt_strategy:
            self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)

        return result