avans06 committed on
Commit
6fc9a01
·
2 Parent(s): 4f4c582 3cd7d59

Merge branch 'main' of https://huggingface.co/spaces/aadnk/whisper-webui

Browse files

Changes:

* Using pyannote/speaker-diarization-3.0.
* Display diarization min speakers and diarization max speakers options in the simple tab.
* When diarization speakers are set to 0, initialize diarization using only the min and max speakers options.

README.md CHANGED
@@ -72,7 +72,7 @@ pip install -r requirements-fasterWhisper.txt
72
  ```
73
  And then run the App or the CLI with the `--whisper_implementation faster-whisper` flag:
74
  ```
75
- python app.py --whisper_implementation faster-whisper --input_audio_max_duration -1 --server_name 127.0.0.1 --auto_parallel True
76
  ```
77
  You can also select the whisper implementation in `config.json5`:
78
  ```json5
 
72
  ```
73
  And then run the App or the CLI with the `--whisper_implementation faster-whisper` flag:
74
  ```
75
+ python app.py --whisper_implementation faster-whisper --input_audio_max_duration -1 --server_name 127.0.0.1 --server_port 7860 --auto_parallel True
76
  ```
77
  You can also select the whisper implementation in `config.json5`:
78
  ```json5
app.py CHANGED
@@ -7,6 +7,7 @@ import argparse
7
  from io import StringIO
8
  import time
9
  import os
 
10
  import tempfile
11
  import zipfile
12
  import numpy as np
@@ -14,6 +15,8 @@ import numpy as np
14
  import torch
15
 
16
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
 
 
17
  from src.hooks.progressListener import ProgressListener
18
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
19
  from src.hooks.whisperProgressHook import create_progress_listener_handle
@@ -33,7 +36,7 @@ import ffmpeg
33
  import gradio as gr
34
 
35
  from src.download import ExceededMaximumDuration, download_url
36
- from src.utils import optional_int, slugify, write_srt, write_vtt
37
  from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
38
  from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
39
  from src.whisper.whisperFactory import create_whisper_container
@@ -83,6 +86,10 @@ class WhisperTranscriber:
83
  self.deleteUploadedFiles = delete_uploaded_files
84
  self.output_dir = output_dir
85
 
 
 
 
 
86
  self.app_config = app_config
87
 
88
  def set_parallel_devices(self, vad_parallel_devices: str):
@@ -96,22 +103,49 @@ class WhisperTranscriber:
96
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
97
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  # Entry function for the simple tab
100
  def transcribe_webui_simple(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
101
  vad, vadMergeWindow, vadMaxMergeSize,
102
- word_timestamps: bool = False, highlight_words: bool = False):
 
 
103
  return self.transcribe_webui_simple_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
104
- vad, vadMergeWindow, vadMaxMergeSize,
105
- word_timestamps, highlight_words)
 
 
106
 
107
  # Entry function for the simple tab progress
108
  def transcribe_webui_simple_progress(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
109
  vad, vadMergeWindow, vadMaxMergeSize,
110
  word_timestamps: bool = False, highlight_words: bool = False,
 
 
111
  progress=gr.Progress()):
112
-
113
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
114
 
 
 
 
 
 
 
 
 
115
  return self.transcribe_webui(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task, vadOptions,
116
  word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
117
 
@@ -122,14 +156,18 @@ class WhisperTranscriber:
122
  word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
123
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
124
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
125
- compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
 
 
126
 
127
  return self.transcribe_webui_full_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
128
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
129
  word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
130
  initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
131
  condition_on_previous_text, fp16, temperature_increment_on_fallback,
132
- compression_ratio_threshold, logprob_threshold, no_speech_threshold)
 
 
133
 
134
  # Entry function for the full tab with progress
135
  def transcribe_webui_full_progress(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
@@ -139,6 +177,8 @@ class WhisperTranscriber:
139
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
140
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
141
  compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
 
 
142
  progress=gr.Progress()):
143
 
144
  # Handle temperature_increment_on_fallback
@@ -149,6 +189,15 @@ class WhisperTranscriber:
149
 
150
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
151
 
 
 
 
 
 
 
 
 
 
152
  return self.transcribe_webui(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task, vadOptions,
153
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
154
  condition_on_previous_text=condition_on_previous_text, fp16=fp16,
@@ -373,6 +422,19 @@ class WhisperTranscriber:
373
  else:
374
  # Default VAD
375
  result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
  return result
378
 
@@ -545,11 +607,15 @@ class WhisperTranscriber:
545
  if (self.cpu_parallel_context is not None):
546
  self.cpu_parallel_context.close()
547
 
 
 
 
 
548
 
549
  def create_ui(app_config: ApplicationConfig):
550
  ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
551
  app_config.delete_uploaded_files, app_config.output_dir, app_config)
552
-
553
  # Specify a list of devices to use for parallel processing
554
  ui.set_parallel_devices(app_config.vad_parallel_devices)
555
  ui.set_auto_parallel(app_config.auto_parallel)
@@ -619,6 +685,19 @@ def create_ui(app_config: ApplicationConfig):
619
  gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
620
  ]
621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
  common_output = lambda : [
623
  gr.File(label="Download"),
624
  gr.Text(label="Transcription"),
@@ -640,7 +719,7 @@ def create_ui(app_config: ApplicationConfig):
640
  with gr.Row():
641
  simple_input += common_nllb_inputs()
642
  with gr.Column():
643
- simple_input += common_audio_inputs() + common_vad_inputs() + common_word_timestamps_inputs()
644
  with gr.Column():
645
  simple_output = common_output()
646
  simple_flag = gr.Button("Flag")
@@ -689,7 +768,7 @@ def create_ui(app_config: ApplicationConfig):
689
  gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
690
  gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
691
  gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
692
- gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)]
693
 
694
  with gr.Column():
695
  full_output = common_output()
@@ -771,7 +850,14 @@ if __name__ == '__main__':
771
  help="Maximum length of a file name.")
772
  parser.add_argument("--autolaunch", action='store_true', \
773
  help="open the webui URL in the system's default browser upon launch")
774
-
 
 
 
 
 
 
 
775
 
776
  args = parser.parse_args().__dict__
777
 
@@ -788,4 +874,5 @@ if __name__ == '__main__':
788
  if (threads := args.pop("threads")) > 0:
789
  torch.set_num_threads(threads)
790
 
 
791
  create_ui(app_config=updated_config)
 
7
  from io import StringIO
8
  import time
9
  import os
10
+ import pathlib
11
  import tempfile
12
  import zipfile
13
  import numpy as np
 
15
  import torch
16
 
17
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
18
+ from src.diarization.diarization import Diarization
19
+ from src.diarization.diarizationContainer import DiarizationContainer
20
  from src.hooks.progressListener import ProgressListener
21
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
22
  from src.hooks.whisperProgressHook import create_progress_listener_handle
 
36
  import gradio as gr
37
 
38
  from src.download import ExceededMaximumDuration, download_url
39
+ from src.utils import optional_int, slugify, str2bool, write_srt, write_vtt
40
  from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
41
  from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
42
  from src.whisper.whisperFactory import create_whisper_container
 
86
  self.deleteUploadedFiles = delete_uploaded_files
87
  self.output_dir = output_dir
88
 
89
+ # Support for diarization
90
+ self.diarization: DiarizationContainer = None
91
+ # Dictionary with parameters to pass to diarization.run - if None, diarization is not enabled
92
+ self.diarization_kwargs = None
93
  self.app_config = app_config
94
 
95
  def set_parallel_devices(self, vad_parallel_devices: str):
 
103
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
104
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
105
 
106
+ def set_diarization(self, auth_token: str, enable_daemon_process: bool = True, **kwargs):
107
+ if self.diarization is None:
108
+ self.diarization = DiarizationContainer(auth_token=auth_token, enable_daemon_process=enable_daemon_process,
109
+ auto_cleanup_timeout_seconds=self.app_config.diarization_process_timeout,
110
+ cache=self.model_cache)
111
+ # Set parameters
112
+ self.diarization_kwargs = kwargs
113
+
114
+ def unset_diarization(self):
115
+ if self.diarization is not None:
116
+ self.diarization.cleanup()
117
+ self.diarization_kwargs = None
118
+
119
  # Entry function for the simple tab
120
  def transcribe_webui_simple(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
121
  vad, vadMergeWindow, vadMaxMergeSize,
122
+ word_timestamps: bool = False, highlight_words: bool = False,
123
+ diarization: bool = False, diarization_speakers: int = 2,
124
+ diarization_min_speakers = 1, diarization_max_speakers = 8):
125
  return self.transcribe_webui_simple_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
126
+ vad, vadMergeWindow, vadMaxMergeSize,
127
+ word_timestamps, highlight_words,
128
+ diarization, diarization_speakers,
129
+ diarization_min_speakers, diarization_max_speakers)
130
 
131
  # Entry function for the simple tab progress
132
  def transcribe_webui_simple_progress(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
133
  vad, vadMergeWindow, vadMaxMergeSize,
134
  word_timestamps: bool = False, highlight_words: bool = False,
135
+ diarization: bool = False, diarization_speakers: int = 2,
136
+ diarization_min_speakers = 1, diarization_max_speakers = 8,
137
  progress=gr.Progress()):
138
+
139
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
140
 
141
+ if diarization:
142
+ if diarization_speakers < 1:
143
+ self.set_diarization(auth_token=self.app_config.auth_token, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
144
+ else:
145
+ self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
146
+ else:
147
+ self.unset_diarization()
148
+
149
  return self.transcribe_webui(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task, vadOptions,
150
  word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
151
 
 
156
  word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
157
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
158
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
159
+ compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
160
+ diarization: bool = False, diarization_speakers: int = 2,
161
+ diarization_min_speakers = 1, diarization_max_speakers = 8):
162
 
163
  return self.transcribe_webui_full_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
164
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
165
  word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
166
  initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
167
  condition_on_previous_text, fp16, temperature_increment_on_fallback,
168
+ compression_ratio_threshold, logprob_threshold, no_speech_threshold,
169
+ diarization, diarization_speakers,
170
+ diarization_min_speakers, diarization_max_speakers)
171
 
172
  # Entry function for the full tab with progress
173
  def transcribe_webui_full_progress(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
 
177
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
178
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
179
  compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
180
+ diarization: bool = False, diarization_speakers: int = 2,
181
+ diarization_min_speakers = 1, diarization_max_speakers = 8,
182
  progress=gr.Progress()):
183
 
184
  # Handle temperature_increment_on_fallback
 
189
 
190
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
191
 
192
+ # Set diarization
193
+ if diarization:
194
+ if diarization_speakers < 1:
195
+ self.set_diarization(auth_token=self.app_config.auth_token, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
196
+ else:
197
+ self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
198
+ else:
199
+ self.unset_diarization()
200
+
201
  return self.transcribe_webui(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task, vadOptions,
202
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
203
  condition_on_previous_text=condition_on_previous_text, fp16=fp16,
 
422
  else:
423
  # Default VAD
424
  result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
425
+
426
+ # Diarization
427
+ if self.diarization and self.diarization_kwargs:
428
+ print("Diarizing ", audio_path)
429
+ diarization_result = list(self.diarization.run(audio_path, **self.diarization_kwargs))
430
+
431
+ # Print result
432
+ print("Diarization result: ")
433
+ for entry in diarization_result:
434
+ print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
435
+
436
+ # Add speakers to result
437
+ result = self.diarization.mark_speakers(diarization_result, result)
438
 
439
  return result
440
 
 
607
  if (self.cpu_parallel_context is not None):
608
  self.cpu_parallel_context.close()
609
 
610
+ # Cleanup diarization
611
+ if (self.diarization is not None):
612
+ self.diarization.cleanup()
613
+ self.diarization = None
614
 
615
  def create_ui(app_config: ApplicationConfig):
616
  ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
617
  app_config.delete_uploaded_files, app_config.output_dir, app_config)
618
+
619
  # Specify a list of devices to use for parallel processing
620
  ui.set_parallel_devices(app_config.vad_parallel_devices)
621
  ui.set_auto_parallel(app_config.auto_parallel)
 
685
  gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
686
  ]
687
 
688
+ has_diarization_libs = Diarization.has_libraries()
689
+
690
+ if not has_diarization_libs:
691
+ print("Diarization libraries not found - disabling diarization")
692
+ app_config.diarization = False
693
+
694
+ common_diarization_inputs = lambda : [
695
+ gr.Checkbox(label="Diarization", value=app_config.diarization, interactive=has_diarization_libs),
696
+ gr.Number(label="Diarization - Speakers", precision=0, value=app_config.diarization_speakers, interactive=has_diarization_libs),
697
+ gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs),
698
+ gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs)
699
+ ]
700
+
701
  common_output = lambda : [
702
  gr.File(label="Download"),
703
  gr.Text(label="Transcription"),
 
719
  with gr.Row():
720
  simple_input += common_nllb_inputs()
721
  with gr.Column():
722
+ simple_input += common_audio_inputs() + common_vad_inputs() + common_word_timestamps_inputs() + common_diarization_inputs()
723
  with gr.Column():
724
  simple_output = common_output()
725
  simple_flag = gr.Button("Flag")
 
768
  gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
769
  gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
770
  gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
771
+ gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)] + common_diarization_inputs()
772
 
773
  with gr.Column():
774
  full_output = common_output()
 
850
  help="Maximum length of a file name.")
851
  parser.add_argument("--autolaunch", action='store_true', \
852
  help="open the webui URL in the system's default browser upon launch")
853
+ parser.add_argument('--auth_token', type=str, default=default_app_config.auth_token, help='HuggingFace API Token (optional)')
854
+ parser.add_argument("--diarization", type=str2bool, default=default_app_config.diarization, \
855
+ help="whether to perform speaker diarization")
856
+ parser.add_argument("--diarization_num_speakers", type=int, default=default_app_config.diarization_speakers, help="Number of speakers")
857
+ parser.add_argument("--diarization_min_speakers", type=int, default=default_app_config.diarization_min_speakers, help="Minimum number of speakers")
858
+ parser.add_argument("--diarization_max_speakers", type=int, default=default_app_config.diarization_max_speakers, help="Maximum number of speakers")
859
+ parser.add_argument("--diarization_process_timeout", type=int, default=default_app_config.diarization_process_timeout, \
860
+ help="Number of seconds before inactivate diarization processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
861
 
862
  args = parser.parse_args().__dict__
863
 
 
874
  if (threads := args.pop("threads")) > 0:
875
  torch.set_num_threads(threads)
876
 
877
+ print("Using whisper implementation: " + updated_config.whisper_implementation)
878
  create_ui(app_config=updated_config)
cli.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
8
  import torch
9
  from app import VadOptions, WhisperTranscriber
10
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
 
11
  from src.download import download_url
12
  from src.languages import get_language_names
13
 
@@ -106,6 +107,14 @@ def cli():
106
  parser.add_argument("--threads", type=optional_int, default=0,
107
  help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
108
 
 
 
 
 
 
 
 
 
109
  args = parser.parse_args().__dict__
110
  model_name: str = args.pop("model")
111
  model_dir: str = args.pop("model_dir")
@@ -142,10 +151,19 @@ def cli():
142
  compute_type = args.pop("compute_type")
143
  highlight_words = args.pop("highlight_words")
144
 
 
 
 
 
 
 
145
  transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
146
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
147
  transcriber.set_auto_parallel(auto_parallel)
148
 
 
 
 
149
  model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
150
  device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
151
 
 
8
  import torch
9
  from app import VadOptions, WhisperTranscriber
10
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
11
+ from src.diarization.diarization import Diarization
12
  from src.download import download_url
13
  from src.languages import get_language_names
14
 
 
107
  parser.add_argument("--threads", type=optional_int, default=0,
108
  help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
109
 
110
+ # Diarization
111
+ parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
112
+ parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
113
+ help="whether to perform speaker diarization")
114
+ parser.add_argument("--diarization_num_speakers", type=int, default=None, help="Number of speakers")
115
+ parser.add_argument("--diarization_min_speakers", type=int, default=None, help="Minimum number of speakers")
116
+ parser.add_argument("--diarization_max_speakers", type=int, default=None, help="Maximum number of speakers")
117
+
118
  args = parser.parse_args().__dict__
119
  model_name: str = args.pop("model")
120
  model_dir: str = args.pop("model_dir")
 
151
  compute_type = args.pop("compute_type")
152
  highlight_words = args.pop("highlight_words")
153
 
154
+ auth_token = args.pop("auth_token")
155
+ diarization = args.pop("diarization")
156
+ num_speakers = args.pop("diarization_num_speakers")
157
+ min_speakers = args.pop("diarization_min_speakers")
158
+ max_speakers = args.pop("diarization_max_speakers")
159
+
160
  transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
161
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
162
  transcriber.set_auto_parallel(auto_parallel)
163
 
164
+ if diarization:
165
+ transcriber.set_diarization(auth_token=auth_token, enable_daemon_process=False, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
166
+
167
  model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
168
  device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
169
 
config.json5 CHANGED
@@ -234,4 +234,17 @@
234
  "append_punctuations": "\"\'.。,,!!??::”)]}、",
235
  // (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
236
  "highlight_words": false,
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  }
 
234
  "append_punctuations": "\"\'.。,,!!??::”)]}、",
235
  // (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
236
  "highlight_words": false,
237
+
238
+ // Diarization settings
239
+ "auth_token": null,
240
+ // Whether to perform speaker diarization
241
+ "diarization": false,
242
+ // The number of speakers to detect
243
+ "diarization_speakers": 2,
244
+ // The minimum number of speakers to detect
245
+ "diarization_min_speakers": 1,
246
+ // The maximum number of speakers to detect
247
+ "diarization_max_speakers": 8,
248
+ // The number of seconds before inactive diarization processes are terminated. Use 0 to close processes immediately, or null for no timeout.
249
+ "diarization_process_timeout": 60,
250
  }
docs/options.md CHANGED
@@ -80,6 +80,17 @@ number of seconds after the line has finished. For instance, if a line ends at 1
80
  Note that detected lines in gaps between speech sections will not be included in the prompt
81
  (if silero-vad or silero-vad-expand-into-gaps) is used.
82
 
 
 
 
 
 
 
 
 
 
 
 
83
  # Command Line Options
84
 
85
  Both `app.py` and `cli.py` also accept command line options, such as the ability to enable parallel execution on multiple
@@ -132,3 +143,11 @@ If the average log probability is lower than this value, treat the decoding as f
132
 
133
  ## No speech threshold
134
  If the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence. Default is 0.6.
 
 
 
 
 
 
 
 
 
80
  Note that detected lines in gaps between speech sections will not be included in the prompt
81
  (if silero-vad or silero-vad-expand-into-gaps) is used.
82
 
83
+ ## Diarization
84
+
85
+ If checked, Pyannote will be used to detect speakers in the audio, and label them as (SPEAKER 00), (SPEAKER 01), etc.
86
+
87
+ This requires a HuggingFace API key to function, which can be supplied with the `--auth_token` command line option for the CLI,
88
+ set in the `config.json5` file for the GUI, or provided via the `HF_ACCESS_TOKEN` environment variable.
89
+
90
+ ## Diarization - Speakers
91
+
92
+ The number of speakers to detect. If set to 0, Pyannote will attempt to detect the number of speakers automatically, using only the Min Speakers and Max Speakers options as bounds.
93
+
94
  # Command Line Options
95
 
96
  Both `app.py` and `cli.py` also accept command line options, such as the ability to enable parallel execution on multiple
 
143
 
144
  ## No speech threshold
145
  If the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence. Default is 0.6.
146
+
147
+ ## Diarization - Min Speakers
148
+
149
+ The minimum number of speakers for Pyannote to detect.
150
+
151
+ ## Diarization - Max Speakers
152
+
153
+ The maximum number of speakers for Pyannote to detect.
requirements-fasterWhisper.txt CHANGED
@@ -9,4 +9,10 @@ torch
9
  torchaudio
10
  more_itertools
11
  zhconv
12
- sentencepiece
 
 
 
 
 
 
 
9
  torchaudio
10
  more_itertools
11
  zhconv
12
+ sentencepiece
13
+
14
+ # Needed by diarization
15
+ intervaltree
16
+ srt
17
+ torch
18
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
requirements-whisper.txt CHANGED
@@ -8,4 +8,10 @@ torchaudio
8
  altair
9
  json5
10
  zhconv
11
- sentencepiece
 
 
 
 
 
 
 
8
  altair
9
  json5
10
  zhconv
11
+ sentencepiece
12
+
13
+ # Needed by diarization
14
+ intervaltree
15
+ srt
16
+ torch
17
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
requirements.txt CHANGED
@@ -9,4 +9,10 @@ torch
9
  torchaudio
10
  more_itertools
11
  zhconv
12
- sentencepiece
 
 
 
 
 
 
 
9
  torchaudio
10
  more_itertools
11
  zhconv
12
+ sentencepiece
13
+
14
+ # Needed by diarization
15
+ intervaltree
16
+ srt
17
+ torch
18
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
src/config.py CHANGED
@@ -69,7 +69,11 @@ class ApplicationConfig:
69
  # Word timestamp settings
70
  word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
71
  append_punctuations: str = "\"\'.。,,!!??::”)]}、",
72
- highlight_words: bool = False):
 
 
 
 
73
 
74
  self.models = models
75
  self.nllb_models = nllb_models
@@ -123,6 +127,14 @@ class ApplicationConfig:
123
  self.append_punctuations = append_punctuations
124
  self.highlight_words = highlight_words
125
 
 
 
 
 
 
 
 
 
126
  def get_model_names(self):
127
  return [ x.name for x in self.models ]
128
 
 
69
  # Word timestamp settings
70
  word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
71
  append_punctuations: str = "\"\'.。,,!!??::”)]}、",
72
+ highlight_words: bool = False,
73
+ # Diarization
74
+ auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
75
+ diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
76
+ diarization_process_timeout: int = 60):
77
 
78
  self.models = models
79
  self.nllb_models = nllb_models
 
127
  self.append_punctuations = append_punctuations
128
  self.highlight_words = highlight_words
129
 
130
+ # Diarization settings
131
+ self.auth_token = auth_token
132
+ self.diarization = diarization
133
+ self.diarization_speakers = diarization_speakers
134
+ self.diarization_min_speakers = diarization_min_speakers
135
+ self.diarization_max_speakers = diarization_max_speakers
136
+ self.diarization_process_timeout = diarization_process_timeout
137
+
138
  def get_model_names(self):
139
  return [ x.name for x in self.models ]
140
 
src/diarization/diarization.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gc
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ import tempfile
7
+ from typing import TYPE_CHECKING, List
8
+ import torch
9
+
10
+ import ffmpeg
11
+
12
class DiarizationEntry:
    """A single speaker turn: a speaker label together with the time span it covers."""

    def __init__(self, start, end, speaker):
        # Values are stored exactly as given; nothing is converted or validated here.
        self.start, self.end, self.speaker = start, end, speaker

    def __repr__(self):
        return "<DiarizationEntry start={} end={} speaker={}>".format(self.start, self.end, self.speaker)

    def toJson(self):
        """Return the entry as a plain dict with the keys "start", "end" and "speaker"."""
        return dict(start=self.start, end=self.end, speaker=self.speaker)
27
+
28
class Diarization:
    """
    Thin wrapper around the pyannote speaker diarization pipeline.

    The pipeline itself is loaded lazily by initialize(), so constructing an instance
    (or calling mark_speakers) does not import pyannote.audio.
    """

    def __init__(self, auth_token=None):
        """
        Parameters
        ----------
        auth_token: str
            HuggingFace access token. Falls back to the HF_ACCESS_TOKEN environment variable.

        Raises
        ------
        ValueError
            If no token was given and HF_ACCESS_TOKEN is not set.
        """
        if auth_token is None:
            auth_token = os.environ.get("HF_ACCESS_TOKEN")
        if auth_token is None:
            raise ValueError("No HuggingFace API Token provided - please use the --auth_token argument or set the HF_ACCESS_TOKEN environment variable")

        self.auth_token = auth_token
        self.initialized = False
        self.pipeline = None

    @staticmethod
    def has_libraries():
        """Return True if both pyannote.audio and intervaltree can be imported."""
        try:
            import pyannote.audio
            import intervaltree
            return True
        except ImportError:
            return False

    def initialize(self):
        """
        Load the diarization pipeline (only done once per instance).

        1.Install pyannote.audio 3.0 with pip install pyannote.audio
        2.Accept pyannote/segmentation-3.0 user conditions
        3.Accept pyannote/speaker-diarization-3.0 user conditions
        4.Create access token at hf.co/settings/tokens.
        https://huggingface.co/pyannote/speaker-diarization-3.0
        """
        if self.initialized:
            return
        from pyannote.audio import Pipeline

        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.0", use_auth_token=self.auth_token)

        # Fail here with a readable message instead of a confusing AttributeError/TypeError later on.
        # NOTE(review): presumably from_pretrained yields None when the download is refused - confirm.
        if pipeline is None:
            raise RuntimeError("Could not load pyannote/speaker-diarization-3.0 - check the access token and that the model's user conditions have been accepted")

        # Load GPU mode if available
        if torch.cuda.is_available():
            print("Diarization - using GPU")
            pipeline = pipeline.to(torch.device(0))
        else:
            print("Diarization - using CPU")

        # Only mark the instance as initialized once the pipeline is fully set up
        self.pipeline = pipeline
        self.initialized = True

    def run(self, audio_file, **kwargs):
        """
        Run speaker diarization on an audio file and yield one DiarizationEntry per speaker turn.

        This is a generator: nothing happens until it is iterated.

        Parameters
        ----------
        audio_file: str
            Path to the audio file. Anything that is not WAV/FLAC/OGG/MAT is first converted
            to a temporary mono WAV file using ffmpeg.
        kwargs: dict
            Passed on to the pipeline as-is (e.g. num_speakers, min_speakers, max_speakers).

        Raises
        ------
        ffmpeg.Error
            If the conversion to WAV fails.
        """
        self.initialize()

        # Supported file types in soundfile is WAV, FLAC, OGG and MAT (compared case-insensitively)
        if Path(audio_file).suffix.lower() in (".wav", ".flac", ".ogg", ".mat"):
            temp_file = None
            target_file = audio_file
        else:
            # mkstemp creates the file atomically, unlike the deprecated and race-prone mktemp
            handle, temp_file = tempfile.mkstemp(prefix="diarization_", suffix=".wav")
            os.close(handle)
            target_file = temp_file

        try:
            if temp_file is not None:
                try:
                    # The temp file already exists, so ffmpeg must be allowed to overwrite it
                    ffmpeg.input(audio_file).output(target_file, ac=1).run(overwrite_output=True)
                except ffmpeg.Error as e:
                    print(f"Error occurred during audio conversion: {e.stderr}")
                    # Without a converted file there is nothing to diarize
                    raise

            diarization = self.pipeline(target_file, **kwargs)
        finally:
            # Delete temp file, also when the conversion or the pipeline failed
            if temp_file is not None and os.path.exists(temp_file):
                os.remove(temp_file)

        # Yield result
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            yield DiarizationEntry(turn.start, turn.end, speaker)

    def mark_speakers(self, diarization_result: List["DiarizationEntry"], whisper_result: dict):
        """
        Annotate the segments of a Whisper result with the speakers found by diarization.

        Every segment that overlaps at least one speaker turn gets two extra keys:
        "longest_speaker" (the speaker with the longest overlap) and "speakers"
        (all overlapping turns as dicts, ordered by start time).

        Returns a new result; neither whisper_result nor its segment dicts are modified.
        """
        from intervaltree import IntervalTree

        # Copy the result and each segment, as the segments get new keys below
        result = whisper_result.copy()
        result["segments"] = [dict(segment) for segment in whisper_result["segments"]]

        # Create an interval tree from the diarization results
        tree = IntervalTree()
        for entry in diarization_result:
            # IntervalTree rejects null intervals (begin >= end), and they cannot overlap anything anyway
            if entry.end > entry.start:
                tree[entry.start:entry.end] = entry

        # Iterate through each segment in the Whisper JSON
        for segment in result["segments"]:
            segment_start = segment["start"]
            segment_end = segment["end"]

            # Find overlapping speakers using the interval tree (sorted, as the query returns a set)
            overlapping_speakers = sorted(tree[segment_start:segment_end], key=lambda interval: (interval.begin, interval.end))

            # If no speakers overlap with this segment, skip it
            if not overlapping_speakers:
                continue

            # If multiple speakers overlap with this segment, choose the one with the longest duration
            longest_speaker = None
            longest_duration = 0

            for speaker_interval in overlapping_speakers:
                overlap_start = max(speaker_interval.begin, segment_start)
                overlap_end = min(speaker_interval.end, segment_end)
                overlap_duration = overlap_end - overlap_start

                if overlap_duration > longest_duration:
                    longest_speaker = speaker_interval.data.speaker
                    longest_duration = overlap_duration

            # Add speakers
            segment["longest_speaker"] = longest_speaker
            segment["speakers"] = [speaker_interval.data.toJson() for speaker_interval in overlapping_speakers]

        # The write_srt will use the longest_speaker if it exist, and add it to the text field

        return result
137
+
138
def _write_file(input_file: str, output_path: str, output_extension: str, file_writer: lambda f: None):
    """
    Open the output file as UTF-8 text and let file_writer fill it.

    When output_path is None, the path is derived from input_file by replacing its
    extension with "_output" + output_extension.

    Raises ValueError if input_file or file_writer is None.
    """
    if input_file is None:
        raise ValueError("input_file is required")
    if file_writer is None:
        raise ValueError("file_writer is required")

    # Fall back to a path next to the input file
    effective_path = output_path
    if effective_path is None:
        base_name, _ = os.path.splitext(input_file)
        effective_path = base_name + "_output" + output_extension

    # Write file
    with open(effective_path, 'w+', encoding="utf-8") as target:
        file_writer(target)

    print(f"Output saved to {effective_path}")
154
+
155
def main():
    """
    Command-line entry point: diarize an audio file, mark the speakers in an existing
    Whisper JSON/SRT transcript, and write the result as both a JSON and a SRT file.
    """
    from src.utils import write_srt
    from src.diarization.transcriptLoader import load_transcript

    parser = argparse.ArgumentParser(description='Add speakers to a SRT file or Whisper JSON file using pyannote/speaker-diarization.')
    parser.add_argument('audio_file', type=str, help='Input audio file')
    parser.add_argument('whisper_file', type=str, help='Input Whisper JSON/SRT file')

    # Optional arguments as (flag, type, default, help), registered in this order
    optional_arguments = (
        ('--output_json_file', str, None, 'Output JSON file (optional)'),
        ('--output_srt_file', str, None, 'Output SRT file (optional)'),
        ('--auth_token', str, None, 'HuggingFace API Token (optional)'),
        ("--max_line_width", int, 40, "Maximum line width for SRT file (default: 40)"),
        ("--num_speakers", int, None, "Number of speakers"),
        ("--min_speakers", int, None, "Minimum number of speakers"),
        ("--max_speakers", int, None, "Maximum number of speakers"),
    )
    for flag, value_type, default_value, help_text in optional_arguments:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = parser.parse_args()

    print("\nReading whisper JSON from " + args.whisper_file)

    # Read whisper JSON or SRT file
    transcript = load_transcript(args.whisper_file)

    # Run the diarization, forcing the generator to complete
    diarizer = Diarization(auth_token=args.auth_token)
    speaker_options = {
        "num_speakers": args.num_speakers,
        "min_speakers": args.min_speakers,
        "max_speakers": args.max_speakers,
    }
    turns = list(diarizer.run(args.audio_file, **speaker_options))

    # Print result
    print("Diarization result:")
    for turn in turns:
        print(f" start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{turn.speaker}")

    marked_transcript = diarizer.mark_speakers(turns, transcript)

    def dump_json(f):
        json.dump(marked_transcript, f, indent=4, ensure_ascii=False)

    def dump_srt(f):
        write_srt(marked_transcript["segments"], f, maxLineWidth=args.max_line_width)

    # Write output JSON to file
    _write_file(args.whisper_file, args.output_json_file, ".json", dump_json)

    # Write SRT
    _write_file(args.whisper_file, args.output_srt_file, ".srt", dump_srt)
194
+
195
+ if __name__ == "__main__":
196
+ main()
197
+
198
+ #test = Diarization()
199
+ #print("Initializing")
200
+ #test.initialize()
201
+
202
+ #input("Press Enter to continue...")
src/diarization/diarizationContainer.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from src.diarization.diarization import Diarization, DiarizationEntry
3
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
4
+ from src.vadParallel import ParallelContext
5
+
6
class DiarizationContainer:
    """
    Owns a lazily created Diarization model and optionally runs it in a separate
    process (through a ParallelContext), so the model's memory can be released again.
    """

    def __init__(self, auth_token: str = None, enable_daemon_process: bool = True, auto_cleanup_timeout_seconds=60, cache: ModelCache = None):
        self.auth_token = auth_token
        self.enable_daemon_process = enable_daemon_process
        self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
        self.cache = cache
        # Both are created on first use
        self.diarization_context: ParallelContext = None
        self.model = None

    def run(self, audio_file, **kwargs):
        """Diarize audio_file, either in-process or in the daemon process, and return the entries."""
        # Create parallel context if needed
        if self.enable_daemon_process and self.diarization_context is None:
            # Number of processes is set to 1 as we mainly use this in order to clean up GPU memory
            self.diarization_context = ParallelContext(num_processes=1, auto_cleanup_timeout_seconds=self.auto_cleanup_timeout_seconds)
            print("Created diarization context with auto cleanup timeout of %d seconds" % self.auto_cleanup_timeout_seconds)

        context = self.diarization_context

        # Run directly
        if context is None:
            return self.execute(audio_file, **kwargs)

        # Otherwise run in a separate process
        pool = context.get_pool()
        try:
            return pool.apply(self.execute, (audio_file,), kwargs)
        finally:
            context.return_pool(pool)

    def mark_speakers(self, diarization_result: List[DiarizationEntry], whisper_result: dict):
        """Delegate to Diarization.mark_speakers, using the loaded model when there is one."""
        # Otherwise create a new diarization model (calling mark_speakers will not initialize pyannote.audio)
        model = self.model if self.model is not None else Diarization(self.auth_token)
        return model.mark_speakers(diarization_result, whisper_result)

    def get_model(self):
        """Return the Diarization model, creating it (through the cache, if any) on first use."""
        if self.model is not None:
            return self.model

        # Lazy load the model
        if self.cache:
            print("Loading diarization model from cache")
            self.model = self.cache.get("diarization", lambda : Diarization(self.auth_token))
        else:
            print("Loading diarization model")
            self.model = Diarization(self.auth_token)
        return self.model

    def execute(self, audio_file, **kwargs):
        """Run the model on audio_file and return the entries as a list."""
        # We must use list() here to force the iterator to run, as generators are not picklable
        return list(self.get_model().run(audio_file, **kwargs))

    def cleanup(self):
        """Close the parallel context, if one was created."""
        if self.diarization_context is None:
            return
        self.diarization_context.close()

    def __getstate__(self):
        # Only the plain settings are pickled; the context, cache and model are left out
        pickled_names = ("auth_token", "enable_daemon_process", "auto_cleanup_timeout_seconds")
        return { name: getattr(self, name) for name in pickled_names }

    def __setstate__(self, state):
        self.auth_token = state["auth_token"]
        self.enable_daemon_process = state["enable_daemon_process"]
        self.auto_cleanup_timeout_seconds = state["auto_cleanup_timeout_seconds"]
        # The restored instance starts without context or model and uses the global cache
        self.diarization_context = None
        self.model = None
        self.cache = GLOBAL_MODEL_CACHE
src/diarization/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ intervaltree
2
+ srt
3
+ torch
4
+ ffmpeg-python==0.2.0
5
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
src/diarization/transcriptLoader.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ from pathlib import Path
4
+
5
def load_transcript_json(transcript_file: str):
    """
    Read a Whisper JSON file (UTF-8) and return the decoded JSON object.

    # Parameters:
        transcript_file (str): Path to the Whisper JSON file
    """
    # Format of Whisper JSON file:
    # {
    #   "text": " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.",
    #   "segments": [
    #     {
    #        "text": " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.",
    #        "start": 0.0,
    #        "end": 10.36,
    #        "words": [
    #          {
    #            "start": 0.0,
    #            "end": 0.56,
    #            "word": " And",
    #            "probability": 0.61767578125
    #          },
    #          {
    #            "start": 0.56,
    #            "end": 0.88,
    #            "word": " so",
    #            "probability": 0.9033203125
    #          },
    #       etc.
    raw_json = Path(transcript_file).read_text(encoding="utf-8")
    return json.loads(raw_json)
39
+
40
+
41
def load_transcript_srt(subtitle_file: str):
    """
    Parse a SRT file into a Whisper JSON object

    The result has the keys "text" (the content of all subtitles, concatenated) and
    "segments" (one dict per subtitle with "text", "start" and "end" in seconds, and
    an empty "words" list).

    # Parameters:
        subtitle_file (str): Path to the SRT file
    """
    # Imported here so that the optional "srt" package is only needed when SRT files are loaded
    import srt

    with open(subtitle_file, "r", encoding="utf-8") as f:
        # srt.parse is lazy, so consume it completely while the file is still open. The content
        # is passed as a string, which the srt package is documented to accept.
        subs = list(srt.parse(f.read()))

    # Subtitle(index=1, start=datetime.timedelta(seconds=33, microseconds=843000), end=datetime.timedelta(seconds=38, microseconds=97000), content='地球上只有3%的水是淡水', proprietary='')
    segments = [
        {
            "text": sub.content,
            "start": sub.start.total_seconds(),
            "end": sub.end.total_seconds(),
            "words": []
        }
        for sub in subs
    ]

    return {
        "text": "".join(sub.content for sub in subs),
        "segments": segments
    }
70
+
71
def load_transcript(file: str):
    """
    Load a transcript, picking the loader from the (case-insensitive) file extension.

    ".json" files go to load_transcript_json and ".srt" files to load_transcript_srt;
    any other extension raises a ValueError.
    """
    # Determine file type
    suffix = Path(file).suffix.lower()

    if suffix == ".srt":
        return load_transcript_srt(file)
    if suffix == ".json":
        return load_transcript_json(file)

    raise ValueError(f"Unsupported file type: {suffix}")
src/utils.py CHANGED
@@ -102,17 +102,28 @@ def write_srt(transcript: Iterator[dict], file: TextIO,
102
 
103
  def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
104
  for segment in transcript:
105
- words = segment.get('words', [])
 
 
 
 
 
106
 
107
  if len(words) == 0:
108
  # Yield the segment as-is or processed
109
- if maxLineWidth is None or maxLineWidth < 0:
110
  yield segment
111
  else:
 
 
 
 
 
 
112
  yield {
113
  'start': segment['start'],
114
  'end': segment['end'],
115
- 'text': process_text(segment['text'].strip(), maxLineWidth)
116
  }
117
  # We are done
118
  continue
@@ -120,9 +131,17 @@ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: i
120
  subtitle_start = segment['start']
121
  subtitle_end = segment['end']
122
 
 
 
 
 
 
 
 
 
123
  text_words = [ this_word["word"] for this_word in words ]
124
  subtitle_text = __join_words(text_words, maxLineWidth)
125
-
126
  # Iterate over the words in the segment
127
  if highlight_words:
128
  last = subtitle_start
 
102
 
103
  def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
104
  for segment in transcript:
105
+ words: list = segment.get('words', [])
106
+
107
+ # Append longest speaker ID if available
108
+ segment_longest_speaker = segment.get('longest_speaker', None)
109
+ if segment_longest_speaker is not None:
110
+ segment_longest_speaker = segment_longest_speaker.replace("SPEAKER", "S")
111
 
112
  if len(words) == 0:
113
  # Yield the segment as-is or processed
114
+ if (maxLineWidth is None or maxLineWidth < 0) and segment_longest_speaker is None:
115
  yield segment
116
  else:
117
+ text = segment['text'].strip()
118
+
119
+ # Prepend the longest speaker ID if available
120
+ if segment_longest_speaker is not None:
121
+ text = f"({segment_longest_speaker}) {text}"
122
+
123
  yield {
124
  'start': segment['start'],
125
  'end': segment['end'],
126
+ 'text': process_text(text, maxLineWidth)
127
  }
128
  # We are done
129
  continue
 
131
  subtitle_start = segment['start']
132
  subtitle_end = segment['end']
133
 
134
+ if segment_longest_speaker is not None:
135
+ # Add the beginning
136
+ words.insert(0, {
137
+ 'start': subtitle_start,
138
+ 'end': subtitle_start,
139
+ 'word': f"({segment_longest_speaker})"
140
+ })
141
+
142
  text_words = [ this_word["word"] for this_word in words ]
143
  subtitle_text = __join_words(text_words, maxLineWidth)
144
+
145
  # Iterate over the words in the segment
146
  if highlight_words:
147
  last = subtitle_start
src/whisper/abstractWhisperContainer.py CHANGED
@@ -1,5 +1,5 @@
1
  import abc
2
- from typing import List
3
 
4
  from src.config import ModelConfig, VadInitialPromptMode
5
 
@@ -9,7 +9,7 @@ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
9
 
10
  class AbstractWhisperCallback:
11
  def __init__(self):
12
- self.__prompt_mode_gpt = None
13
 
14
  @abc.abstractmethod
15
  def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
@@ -29,6 +29,14 @@ class AbstractWhisperCallback:
29
  """
30
  raise NotImplementedError()
31
 
 
 
 
 
 
 
 
 
32
  class AbstractWhisperContainer:
33
  def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
34
  download_root: str = None,
 
1
  import abc
2
+ from typing import Any, Callable, List
3
 
4
  from src.config import ModelConfig, VadInitialPromptMode
5
 
 
9
 
10
  class AbstractWhisperCallback:
11
  def __init__(self):
12
+ pass
13
 
14
  @abc.abstractmethod
15
  def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
 
29
  """
30
  raise NotImplementedError()
31
 
32
class LambdaWhisperCallback(AbstractWhisperCallback):
    """Whisper callback that hands every invocation over to a plain callable."""

    def __init__(self, callback_lambda: Callable[[Any, int, str, str, ProgressListener], None]):
        super().__init__()
        # Called with (audio, segment_index, prompt, detected_language, progress_listener)
        self.callback_lambda = callback_lambda

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """Call the wrapped callable with the arguments unchanged and return its result."""
        handler = self.callback_lambda
        return handler(audio, segment_index, prompt, detected_language, progress_listener)
39
+
40
  class AbstractWhisperContainer:
41
  def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
42
  download_root: str = None,
src/whisper/dummyWhisperContainer.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import ffmpeg
4
+ from src.config import ModelConfig
5
+ from src.hooks.progressListener import ProgressListener
6
+ from src.modelCache import ModelCache
7
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
8
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
9
+
10
class DummyWhisperContainer(AbstractWhisperContainer):
    """Whisper container without a real model; it only logs and hands out dummy callbacks."""

    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                 download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = []):
        # Everything is passed on to the base container unchanged
        super().__init__(model_name, device, compute_type, download_root, cache, models)

    def ensure_downloaded(self):
        """
        Counterpart of the real containers' download step, which is useful before passing
        a container to a subprocess. Nothing is downloaded here; a message is printed.
        """
        print("[Dummy] Ensuring that the model is downloaded")

    def _create_model(self):
        # There is no model to create, so report the request and return None
        message_parts = ("[Dummy] Creating dummy whisper model ", self.model_name, " for device ", str(self.device))
        print("".join(message_parts))
        return None

    def create_callback(self, language: str = None, task: str = None,
                        prompt_strategy: AbstractPromptStrategy = None,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a callback object that can be used to "transcribe" audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription.
        task: str
            The task - either translate or transcribe.
        prompt_strategy: AbstractPromptStrategy
            The prompt strategy to use.
        decodeOptions: dict
            Additional decoder options. Must be pickleable.

        Returns
        -------
        A DummyWhisperCallback that stores all of the above as its decode options.
        """
        callback_options = dict(language=language, task=task, prompt_strategy=prompt_strategy)
        return DummyWhisperCallback(self, **callback_options, **decodeOptions)
49
+
50
class DummyWhisperCallback(AbstractWhisperCallback):
    """Callback that returns a fixed dummy transcription instead of running Whisper."""

    def __init__(self, model_container: DummyWhisperContainer, **decodeOptions: dict):
        super().__init__()
        self.model_container = model_container
        self.decodeOptions = decodeOptions

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Produce a dummy transcription for the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to "transcribe", or the audio data as a numpy array or torch tensor.
        segment_index: int
            The index of the segment, which is included in the dummy text.
        prompt: str
            Not used by the dummy implementation.
        detected_language: str
            Reported as the language of the result. Defaults to "en" if None.
        progress_listener: ProgressListener
            A callback to receive progress updates. Only on_finished is called.

        Returns
        -------
        A dict in the same format as the other Whisper callbacks, with a single segment
        that spans the whole audio.
        """
        print("[Dummy] Invoking dummy whisper callback for segment " + str(segment_index))

        # Estimate length
        if isinstance(audio, str):
            # ffprobe reports the duration as a string, so convert it to match the numeric branch below
            audio_length = float(ffmpeg.probe(audio)["format"]["duration"])
        else:
            # Format is pcm_s16le at a sample rate of 16000, loaded as a float32 array.
            audio_length = len(audio) / 16000

        dummy_text = "Dummy text for segment " + str(segment_index)

        # Convert the segments to a format that is easier to serialize
        whisper_segments = [{
            "text": dummy_text,
            "start": 0,
            "end": audio_length,

            # Extra fields added by faster-whisper
            "words": []
        }]

        result = {
            "segments": whisper_segments,
            "text": dummy_text,
            "language": "en" if detected_language is None else detected_language,

            # Extra fields added by faster-whisper
            "language_probability": 1.0,
            "duration": audio_length,
        }

        if progress_listener is not None:
            progress_listener.on_finished()
        return result
src/whisper/whisperFactory.py CHANGED
@@ -15,5 +15,9 @@ def create_whisper_container(whisper_implementation: str,
15
  elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
16
  from src.whisper.fasterWhisperContainer import FasterWhisperContainer
17
  return FasterWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
 
 
 
 
18
  else:
19
  raise ValueError("Unknown Whisper implementation: " + whisper_implementation)
 
15
  elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
16
  from src.whisper.fasterWhisperContainer import FasterWhisperContainer
17
  return FasterWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
18
+ elif (whisper_implementation == "dummy-whisper" or whisper_implementation == "dummy_whisper" or whisper_implementation == "dummy"):
19
+ # This is useful for testing
20
+ from src.whisper.dummyWhisperContainer import DummyWhisperContainer
21
+ return DummyWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
22
  else:
23
  raise ValueError("Unknown Whisper implementation: " + whisper_implementation)
tests/vad_test.py CHANGED
@@ -1,10 +1,11 @@
1
- import pprint
2
  import unittest
3
  import numpy as np
4
  import sys
5
 
6
  sys.path.append('../whisper-webui')
 
7
 
 
8
  from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
9
 
10
  class TestVad(unittest.TestCase):
@@ -13,10 +14,11 @@ class TestVad(unittest.TestCase):
13
  self.transcribe_calls = []
14
 
15
  def test_transcript(self):
16
- mock = MockVadTranscription()
 
17
 
18
  self.transcribe_calls.clear()
19
- result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
20
 
21
  self.assertListEqual(self.transcribe_calls, [
22
  [30, 30],
@@ -45,8 +47,9 @@ class TestVad(unittest.TestCase):
45
  }
46
 
47
  class MockVadTranscription(AbstractTranscription):
48
- def __init__(self):
49
  super().__init__()
 
50
 
51
  def get_audio_segment(self, str, start_time: str = None, duration: str = None):
52
  start_time_seconds = float(start_time.removesuffix("s"))
@@ -61,6 +64,9 @@ class MockVadTranscription(AbstractTranscription):
61
  result.append( { 'start': 30, 'end': 60 } )
62
  result.append( { 'start': 100, 'end': 200 } )
63
  return result
 
 
 
64
 
65
  if __name__ == '__main__':
66
  unittest.main()
 
 
1
  import unittest
2
  import numpy as np
3
  import sys
4
 
5
  sys.path.append('../whisper-webui')
6
+ #print("Sys path: " + str(sys.path))
7
 
8
+ from src.whisper.abstractWhisperContainer import LambdaWhisperCallback
9
  from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
10
 
11
  class TestVad(unittest.TestCase):
 
14
  self.transcribe_calls = []
15
 
16
  def test_transcript(self):
17
+ mock = MockVadTranscription(mock_audio_length=120)
18
+ config = TranscriptionConfig()
19
 
20
  self.transcribe_calls.clear()
21
+ result = mock.transcribe("mock", LambdaWhisperCallback(lambda segment, _1, _2, _3, _4: self.transcribe_segments(segment)), config)
22
 
23
  self.assertListEqual(self.transcribe_calls, [
24
  [30, 30],
 
47
  }
48
 
49
  class MockVadTranscription(AbstractTranscription):
50
+ def __init__(self, mock_audio_length: float = 1000):
51
  super().__init__()
52
+ self.mock_audio_length = mock_audio_length
53
 
54
  def get_audio_segment(self, str, start_time: str = None, duration: str = None):
55
  start_time_seconds = float(start_time.removesuffix("s"))
 
64
  result.append( { 'start': 30, 'end': 60 } )
65
  result.append( { 'start': 100, 'end': 200 } )
66
  return result
67
+
68
+ def get_audio_duration(self, audio: str, config: TranscriptionConfig):
69
+ return self.mock_audio_length
70
 
71
  if __name__ == '__main__':
72
  unittest.main()