aadnk commited on
Commit
d906b98
·
1 Parent(s): 48d8572

Make it easier to use the old segmentation strategy

Browse files
Files changed (3) hide show
  1. app.py +24 -20
  2. cli.py +3 -3
  3. src/vad.py +58 -22
app.py CHANGED
@@ -14,7 +14,7 @@ import gradio as gr
14
 
15
  from src.download import ExceededMaximumDuration, download_url
16
  from src.utils import slugify, write_srt, write_vtt
17
- from src.vad import VadPeriodicTranscription, VadSileroTranscription
18
 
19
  # Limitations (set to -1 to disable)
20
  DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
@@ -94,25 +94,17 @@ class WhisperTranscriber:
94
 
95
  # The results
96
  if (vad == 'silero-vad'):
97
- # Use Silero VAD and include gaps
98
- if (self.vad_model is None):
99
- self.vad_model = VadSileroTranscription()
100
-
101
- process_gaps = VadSileroTranscription(transcribe_non_speech = True,
102
- max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
103
- segment_padding_left=vadPadding, segment_padding_right=vadPadding,
104
- max_prompt_window=vadPromptWindow, copy=self.vad_model)
105
  result = process_gaps.transcribe(audio_path, whisperCallable)
106
  elif (vad == 'silero-vad-skip-gaps'):
107
- # Use Silero VAD
108
- if (self.vad_model is None):
109
- self.vad_model = VadSileroTranscription()
110
-
111
- skip_gaps = VadSileroTranscription(transcribe_non_speech = False,
112
- max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
113
- segment_padding_left=vadPadding, segment_padding_right=vadPadding,
114
- max_prompt_window=vadPromptWindow, copy=self.vad_model)
115
  result = skip_gaps.transcribe(audio_path, whisperCallable)
 
 
 
 
116
  elif (vad == 'periodic-vad'):
117
  # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
118
  # it may create a break in the middle of a sentence, causing some artifacts.
@@ -124,6 +116,18 @@ class WhisperTranscriber:
124
 
125
  return result
126
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  def write_result(self, result: dict, source_name: str, output_dir: str):
128
  if not os.path.exists(output_dir):
129
  os.makedirs(output_dir)
@@ -218,11 +222,11 @@ def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
218
  gr.Audio(source="upload", type="filepath", label="Upload Audio"),
219
  gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
220
  gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
221
- gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "periodic-vad"], label="VAD"),
222
- gr.Number(label="VAD - Merge Window (s)", precision=0, value=4),
223
  gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
224
  gr.Number(label="VAD - Padding (s)", precision=None, value=1),
225
- gr.Number(label="VAD - Prompt Window (s)", precision=None, value=10)
226
  ], outputs=[
227
  gr.File(label="Download"),
228
  gr.Text(label="Transcription"),
 
14
 
15
  from src.download import ExceededMaximumDuration, download_url
16
  from src.utils import slugify, write_srt, write_vtt
17
+ from src.vad import NonSpeechStrategy, VadPeriodicTranscription, VadSileroTranscription
18
 
19
  # Limitations (set to -1 to disable)
20
  DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
 
94
 
95
  # The results
96
  if (vad == 'silero-vad'):
97
+ # Silero VAD where non-speech gaps are transcribed
98
+ process_gaps = self._create_silero_vad(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
 
 
 
 
 
 
99
  result = process_gaps.transcribe(audio_path, whisperCallable)
100
  elif (vad == 'silero-vad-skip-gaps'):
101
+ # Silero VAD where non-speech gaps are simply ignored
102
+ skip_gaps = self._create_silero_vad(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
 
 
 
 
 
 
103
  result = skip_gaps.transcribe(audio_path, whisperCallable)
104
+ elif (vad == 'silero-vad-expand-into-gaps'):
105
+ # Use Silero VAD where speech-segments are expanded into non-speech gaps
106
+ expand_gaps = self._create_silero_vad(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
107
+ result = expand_gaps.transcribe(audio_path, whisperCallable)
108
  elif (vad == 'periodic-vad'):
109
  # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
110
  # it may create a break in the middle of a sentence, causing some artifacts.
 
116
 
117
  return result
118
 
119
+ def _create_silero_vad(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
120
+ # Use Silero VAD
121
+ if (self.vad_model is None):
122
+ self.vad_model = VadSileroTranscription()
123
+
124
+ result = VadSileroTranscription(non_speech_strategy = non_speech_strategy,
125
+ max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
126
+ segment_padding_left=vadPadding, segment_padding_right=vadPadding,
127
+ max_prompt_window=vadPromptWindow, copy=self.vad_model)
128
+
129
+ return result
130
+
131
  def write_result(self, result: dict, source_name: str, output_dir: str):
132
  if not os.path.exists(output_dir):
133
  os.makedirs(output_dir)
 
222
  gr.Audio(source="upload", type="filepath", label="Upload Audio"),
223
  gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
224
  gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
225
+ gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], label="VAD"),
226
+ gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
227
  gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
228
  gr.Number(label="VAD - Padding (s)", precision=None, value=1),
229
+ gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3)
230
  ], outputs=[
231
  gr.File(label="Download"),
232
  gr.Text(label="Transcription"),
cli.py CHANGED
@@ -26,11 +26,11 @@ def cli():
26
  parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
27
  parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), help="language spoken in the audio, specify None to perform language detection")
28
 
29
- parser.add_argument("--vad", type=str, default="none", choices=["none", "silero-vad", "silero-vad-skip-gaps", "periodic-vad"], help="The voice activity detection algorithm to use")
30
  parser.add_argument("--vad_merge_window", type=optional_float, default=5, help="The window size (in seconds) to merge voice segments")
31
- parser.add_argument("--vad_max_merge_size", type=optional_float, default=150, help="The maximum size (in seconds) of a voice segment")
32
  parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
33
- parser.add_argument("--vad_prompt_window", type=optional_float, default=0, help="The window size of the prompt to pass to Whisper")
34
 
35
  parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
36
  parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
 
26
  parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
27
  parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), help="language spoken in the audio, specify None to perform language detection")
28
 
29
+ parser.add_argument("--vad", type=str, default="none", choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], help="The voice activity detection algorithm to use")
30
  parser.add_argument("--vad_merge_window", type=optional_float, default=5, help="The window size (in seconds) to merge voice segments")
31
+ parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
32
  parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
33
+ parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
34
 
35
  parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
36
  parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
src/vad.py CHANGED
@@ -1,6 +1,7 @@
1
  from abc import ABC, abstractmethod
2
  from collections import Counter, deque
3
- from typing import Any, Iterator, List, Dict
 
4
 
5
  from pprint import pprint
6
 
@@ -19,6 +20,20 @@ import numpy as np
19
  from src.utils import format_timestamp
20
  from enum import Enum
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # Defaults for Silero
23
  SPEECH_TRESHOLD = 0.3
24
  MAX_SILENT_PERIOD = 10 # seconds
@@ -28,9 +43,6 @@ MAX_MERGE_SIZE = 150 # Do not create segments larger than 2.5 minutes
28
  SEGMENT_PADDING_LEFT = 1 # Start detected text segment early
29
  SEGMENT_PADDING_RIGHT = 1 # End detected segments late
30
 
31
- # Whether to attempt to transcribe non-speech
32
- TRANSCRIBE_NON_SPEECH = False
33
-
34
  # Minimum size of segments to process
35
  MIN_SEGMENT_DURATION = 1
36
 
@@ -46,13 +58,13 @@ VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
46
 
47
  class AbstractTranscription(ABC):
48
  def __init__(self, segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
49
- max_merge_size: float = None, transcribe_non_speech: bool = False, max_prompt_window: float = None):
50
  self.sampling_rate = 16000
51
  self.segment_padding_left = segment_padding_left
52
  self.segment_padding_right = segment_padding_right
53
  self.max_silent_period = max_silent_period
54
  self.max_merge_size = max_merge_size
55
- self.transcribe_non_speech = transcribe_non_speech
56
  self.max_prompt_window = max_prompt_window
57
 
58
  self.min_force_merge_gap = MIN_FORCE_MERGE_GAP
@@ -107,16 +119,18 @@ class AbstractTranscription(ABC):
107
  print("Timestamps:")
108
  pprint(merged)
109
 
110
- if self.transcribe_non_speech:
111
  max_audio_duration = get_audio_duration(audio)
112
 
113
  # Expand segments to include the gaps between them
114
- if (self.max_prompt_window is not None and self.max_prompt_window > 0):
115
  # When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
116
  merged = self.fill_gaps(merged, total_duration=max_audio_duration, max_expand_size=self.max_merge_size)
117
- else:
118
- # With no prompt window, it is better to expand the segments
119
  merged = self.expand_gaps(merged, total_duration=max_audio_duration)
 
 
120
 
121
  print("Transcribing non-speech:")
122
  pprint(merged)
@@ -150,6 +164,17 @@ class AbstractTranscription(ABC):
150
 
151
  adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
152
 
 
 
 
 
 
 
 
 
 
 
 
153
  # Append to output
154
  result['text'] += segment_result['text']
155
  result['segments'].extend(adjusted_segments)
@@ -158,20 +183,30 @@ class AbstractTranscription(ABC):
158
  languageCounter[segment_result['language']] += 1
159
 
160
  # Update prompt window
161
- if (self.max_prompt_window is not None and self.max_prompt_window > 0):
162
- # Add segments to the current prompt window
163
- for segment in adjusted_segments:
164
- if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
165
- prompt_window.append(segment)
166
-
167
- while (len(prompt_window) > 0 and prompt_window[0]['end'] < segment_end - self.max_prompt_window):
168
- prompt_window.popleft()
169
-
170
  if len(languageCounter) > 0:
171
  result['language'] = languageCounter.most_common(1)[0][0]
172
 
173
  return result
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
176
  result = []
177
  last_end_time = 0
@@ -360,7 +395,8 @@ class AbstractTranscription(ABC):
360
  if distance <= max_merge_gap and (max_merge_size is None or current_entry_size <= max_merge_size):
361
  # Regular merge
362
  current_entry['end'] = entry['end']
363
- elif min_force_merge_gap is not None and distance <= min_force_merge_gap and (max_force_merge_size is None or current_entry_size <= max_force_merge_size):
 
364
  # Force merge if the distance is small (up to a certain maximum size)
365
  current_entry['end'] = entry['end']
366
  else:
@@ -389,10 +425,10 @@ class AbstractTranscription(ABC):
389
 
390
  class VadSileroTranscription(AbstractTranscription):
391
  def __init__(self, segment_padding_left=SEGMENT_PADDING_LEFT, segment_padding_right=SEGMENT_PADDING_RIGHT,
392
- max_silent_period=MAX_SILENT_PERIOD, max_merge_size=MAX_MERGE_SIZE, transcribe_non_speech: bool = False,
393
  max_prompt_window=MAX_PROMPT_WINDOW, copy = None):
394
  super().__init__(segment_padding_left=segment_padding_left, segment_padding_right=segment_padding_right,
395
- max_silent_period=max_silent_period, max_merge_size=max_merge_size, transcribe_non_speech=transcribe_non_speech, max_prompt_window=max_prompt_window)
396
 
397
  if copy:
398
  self.model = copy.model
 
1
  from abc import ABC, abstractmethod
2
  from collections import Counter, deque
3
+
4
+ from typing import Any, Deque, Iterator, List, Dict
5
 
6
  from pprint import pprint
7
 
 
20
  from src.utils import format_timestamp
21
  from enum import Enum
22
 
23
+ class NonSpeechStrategy(Enum):
24
+ """
25
+ Ignore non-speech frames segments.
26
+ """
27
+ SKIP = 1
28
+ """
29
+ Just treat non-speech segments as speech.
30
+ """
31
+ CREATE_SEGMENT = 2
32
+ """
33
+ Expand speech segments into subsequent non-speech segments.
34
+ """
35
+ EXPAND_SEGMENT = 3
36
+
37
  # Defaults for Silero
38
  SPEECH_TRESHOLD = 0.3
39
  MAX_SILENT_PERIOD = 10 # seconds
 
43
  SEGMENT_PADDING_LEFT = 1 # Start detected text segment early
44
  SEGMENT_PADDING_RIGHT = 1 # End detected segments late
45
 
 
 
 
46
  # Minimum size of segments to process
47
  MIN_SEGMENT_DURATION = 1
48
 
 
58
 
59
  class AbstractTranscription(ABC):
60
  def __init__(self, segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
61
+ max_merge_size: float = None, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP, max_prompt_window: float = None):
62
  self.sampling_rate = 16000
63
  self.segment_padding_left = segment_padding_left
64
  self.segment_padding_right = segment_padding_right
65
  self.max_silent_period = max_silent_period
66
  self.max_merge_size = max_merge_size
67
+ self.non_speech_strategy = non_speech_strategy
68
  self.max_prompt_window = max_prompt_window
69
 
70
  self.min_force_merge_gap = MIN_FORCE_MERGE_GAP
 
119
  print("Timestamps:")
120
  pprint(merged)
121
 
122
+ if self.non_speech_strategy != NonSpeechStrategy.SKIP:
123
  max_audio_duration = get_audio_duration(audio)
124
 
125
  # Expand segments to include the gaps between them
126
+ if (self.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
127
  # When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
128
  merged = self.fill_gaps(merged, total_duration=max_audio_duration, max_expand_size=self.max_merge_size)
129
+ elif self.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
130
+ # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
131
  merged = self.expand_gaps(merged, total_duration=max_audio_duration)
132
+ else:
133
+ raise Exception("Unknown non-speech strategy: " + str(self.non_speech_strategy))
134
 
135
  print("Transcribing non-speech:")
136
  pprint(merged)
 
164
 
165
  adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
166
 
167
+ # Propagate expand amount to the segments
168
+ if (segment_expand_amount > 0):
169
+ segment_without_expansion = segment_duration - segment_expand_amount
170
+
171
+ for adjusted_segment in adjusted_segments:
172
+ adjusted_segment_end = adjusted_segment['end']
173
+
174
+ # Add expand amount if the segment got expanded
175
+ if (adjusted_segment_end > segment_without_expansion):
176
+ adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
177
+
178
  # Append to output
179
  result['text'] += segment_result['text']
180
  result['segments'].extend(adjusted_segments)
 
183
  languageCounter[segment_result['language']] += 1
184
 
185
  # Update prompt window
186
+ self.__update_prompt_window(prompt_window, adjusted_segments, segment_end)
187
+
 
 
 
 
 
 
 
188
  if len(languageCounter) > 0:
189
  result['language'] = languageCounter.most_common(1)[0][0]
190
 
191
  return result
192
 
193
+ def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float):
194
+ if (self.max_prompt_window is not None and self.max_prompt_window > 0):
195
+ # Add segments to the current prompt window
196
+ for segment in adjusted_segments:
197
+ if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
198
+ prompt_window.append(segment)
199
+
200
+ while (len(prompt_window) > 0):
201
+ first_end_time = prompt_window[0].get('end', 0)
202
+ # Time expanded in the segments should be discounted from the prompt window
203
+ first_expand_time = prompt_window[0].get('expand_amount', 0)
204
+
205
+ if (first_end_time - first_expand_time < segment_end - self.max_prompt_window):
206
+ prompt_window.popleft()
207
+ else:
208
+ break
209
+
210
  def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
211
  result = []
212
  last_end_time = 0
 
395
  if distance <= max_merge_gap and (max_merge_size is None or current_entry_size <= max_merge_size):
396
  # Regular merge
397
  current_entry['end'] = entry['end']
398
+ elif min_force_merge_gap is not None and distance <= min_force_merge_gap and \
399
+ (max_force_merge_size is None or current_entry_size <= max_force_merge_size):
400
  # Force merge if the distance is small (up to a certain maximum size)
401
  current_entry['end'] = entry['end']
402
  else:
 
425
 
426
  class VadSileroTranscription(AbstractTranscription):
427
  def __init__(self, segment_padding_left=SEGMENT_PADDING_LEFT, segment_padding_right=SEGMENT_PADDING_RIGHT,
428
+ max_silent_period=MAX_SILENT_PERIOD, max_merge_size=MAX_MERGE_SIZE, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
429
  max_prompt_window=MAX_PROMPT_WINDOW, copy = None):
430
  super().__init__(segment_padding_left=segment_padding_left, segment_padding_right=segment_padding_right,
431
+ max_silent_period=max_silent_period, max_merge_size=max_merge_size, non_speech_strategy=non_speech_strategy, max_prompt_window=max_prompt_window)
432
 
433
  if copy:
434
  self.model = copy.model