aadnk commited on
Commit
6250a98
·
1 Parent(s): 052fe7e

Adding a dummy implementation of Whisper for testing

Browse files
app.py CHANGED
@@ -624,4 +624,5 @@ if __name__ == '__main__':
624
  if (threads := args.pop("threads")) > 0:
625
  torch.set_num_threads(threads)
626
 
 
627
  create_ui(app_config=updated_config)
 
624
  if (threads := args.pop("threads")) > 0:
625
  torch.set_num_threads(threads)
626
 
627
+ print("Using whisper implementation: " + updated_config.whisper_implementation)
628
  create_ui(app_config=updated_config)
src/whisper/dummyWhisperContainer.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import ffmpeg
4
+ from src.config import ModelConfig
5
+ from src.hooks.progressListener import ProgressListener
6
+ from src.modelCache import ModelCache
7
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
8
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
9
+
10
+ class DummyWhisperContainer(AbstractWhisperContainer):
11
+ def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
12
+ download_root: str = None,
13
+ cache: ModelCache = None, models: List[ModelConfig] = []):
14
+ super().__init__(model_name, device, compute_type, download_root, cache, models)
15
+
16
+ def ensure_downloaded(self):
17
+ """
18
+ Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
19
+ passing the container to a subprocess.
20
+ """
21
+ print("[Dummy] Ensuring that the model is downloaded")
22
+
23
+ def _create_model(self):
24
+ print("[Dummy] Creating dummy whisper model " + self.model_name + " for device " + str(self.device))
25
+ return None
26
+
27
+ def create_callback(self, language: str = None, task: str = None,
28
+ prompt_strategy: AbstractPromptStrategy = None,
29
+ **decodeOptions: dict) -> AbstractWhisperCallback:
30
+ """
31
+ Create a WhisperCallback object that can be used to transcript audio files.
32
+
33
+ Parameters
34
+ ----------
35
+ language: str
36
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
37
+ task: str
38
+ The task - either translate or transcribe.
39
+ prompt_strategy: AbstractPromptStrategy
40
+ The prompt strategy to use. If not specified, the prompt from Whisper will be used.
41
+ decodeOptions: dict
42
+ Additional options to pass to the decoder. Must be pickleable.
43
+
44
+ Returns
45
+ -------
46
+ A WhisperCallback object.
47
+ """
48
+ return DummyWhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
49
+
50
+ class DummyWhisperCallback(AbstractWhisperCallback):
51
+ def __init__(self, model_container: DummyWhisperContainer, **decodeOptions: dict):
52
+ self.model_container = model_container
53
+ self.decodeOptions = decodeOptions
54
+
55
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
56
+ """
57
+ Peform the transcription of the given audio file or data.
58
+
59
+ Parameters
60
+ ----------
61
+ audio: Union[str, np.ndarray, torch.Tensor]
62
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
63
+ segment_index: int
64
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
65
+ task: str
66
+ The task - either translate or transcribe.
67
+ progress_listener: ProgressListener
68
+ A callback to receive progress updates.
69
+ """
70
+ print("[Dummy] Invoking dummy whisper callback for segment " + str(segment_index))
71
+
72
+ # Estimate length
73
+ if isinstance(audio, str):
74
+ audio_length = ffmpeg.probe(audio)["format"]["duration"]
75
+ # Format is pcm_s16le at a sample rate of 16000, loaded as a float32 array.
76
+ else:
77
+ audio_length = len(audio) / 16000
78
+
79
+ # Convert the segments to a format that is easier to serialize
80
+ whisper_segments = [{
81
+ "text": "Dummy text for segment " + str(segment_index),
82
+ "start": 0,
83
+ "end": audio_length,
84
+
85
+ # Extra fields added by faster-whisper
86
+ "words": []
87
+ }]
88
+
89
+ result = {
90
+ "segments": whisper_segments,
91
+ "text": "Dummy text for segment " + str(segment_index),
92
+ "language": "en" if detected_language is None else detected_language,
93
+
94
+ # Extra fields added by faster-whisper
95
+ "language_probability": 1.0,
96
+ "duration": audio_length,
97
+ }
98
+
99
+ if progress_listener is not None:
100
+ progress_listener.on_finished()
101
+ return result
src/whisper/whisperFactory.py CHANGED
@@ -15,5 +15,9 @@ def create_whisper_container(whisper_implementation: str,
15
  elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
16
  from src.whisper.fasterWhisperContainer import FasterWhisperContainer
17
  return FasterWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
 
 
 
 
18
  else:
19
  raise ValueError("Unknown Whisper implementation: " + whisper_implementation)
 
15
  elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
16
  from src.whisper.fasterWhisperContainer import FasterWhisperContainer
17
  return FasterWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
18
+ elif (whisper_implementation == "dummy-whisper" or whisper_implementation == "dummy_whisper" or whisper_implementation == "dummy"):
19
+ # This is useful for testing
20
+ from src.whisper.dummyWhisperContainer import DummyWhisperContainer
21
+ return DummyWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
22
  else:
23
  raise ValueError("Unknown Whisper implementation: " + whisper_implementation)