aadnk committed on
Commit
bc0cb58
·
1 Parent(s): 55fc070

Transcribe non-speech areas too in "silero vad"

Browse files

The previous behaviour is now available as "silero-vad-skip-gaps".
Transcribing the non-speech gaps may introduce more noise in the
transcript, but some of the additional text will be correct as well.

Files changed (2) hide show
  1. app.py +10 -3
  2. vad.py +60 -11
app.py CHANGED
@@ -72,10 +72,17 @@ class UI:
72
 
73
  # The results
74
  if (vad == 'silero-vad'):
75
- # Use Silero VAD
76
  if (self.vad_model is None):
77
- self.vad_model = VadSileroTranscription()
78
  result = self.vad_model.transcribe(source, whisperCallable)
 
 
 
 
 
 
 
79
  elif (vad == 'periodic-vad'):
80
  # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
81
  # it may create a break in the middle of a sentence, causing some artifacts.
@@ -184,7 +191,7 @@ def createUi(inputAudioMaxDuration, share=False, server_name: str = None):
184
  gr.Audio(source="upload", type="filepath", label="Upload Audio"),
185
  gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
186
  gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
187
- gr.Dropdown(choices=["none", "silero-vad", "periodic-vad"], label="VAD"),
188
  ], outputs=[
189
  gr.File(label="Download"),
190
  gr.Text(label="Transcription"),
 
72
 
73
  # The results
74
  if (vad == 'silero-vad'):
75
+ # Use Silero VAD and include gaps
76
  if (self.vad_model is None):
77
+ self.vad_model = VadSileroTranscription(transcribe_non_speech= True)
78
  result = self.vad_model.transcribe(source, whisperCallable)
79
+ elif (vad == 'silero-vad-skip-gaps'):
80
+ # Use Silero VAD, but skip the non-speech gaps
81
+ if (self.vad_model is None):
82
+ self.vad_model = VadSileroTranscription(transcribe_non_speech= True)
83
+
84
+ skip_gaps = VadSileroTranscription(transcribe_non_speech = False, copy=self.vad_model)
85
+ result = skip_gaps.transcribe(source, whisperCallable)
86
  elif (vad == 'periodic-vad'):
87
  # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
88
  # it may create a break in the middle of a sentence, causing some artifacts.
 
191
  gr.Audio(source="upload", type="filepath", label="Upload Audio"),
192
  gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
193
  gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
194
+ gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "periodic-vad"], label="VAD"),
195
  ], outputs=[
196
  gr.File(label="Download"),
197
  gr.Text(label="Transcription"),
vad.py CHANGED
@@ -9,19 +9,24 @@ import torch
9
  import ffmpeg
10
  import numpy as np
11
 
12
- # Defaults
 
 
13
  SPEECH_TRESHOLD = 0.3
14
  MAX_SILENT_PERIOD = 10 # seconds
15
  SEGMENT_PADDING_LEFT = 1 # Start detected text segment early
16
- SEGMENT_PADDING_RIGHT = 4 # End detected segments late
17
 
 
 
18
 
19
  class AbstractTranscription(ABC):
20
- def __init__(self, segment_padding_left: int = None, segment_padding_right = None, max_silent_period: int = None):
21
  self.sampling_rate = 16000
22
  self.segment_padding_left = segment_padding_left
23
  self.segment_padding_right = segment_padding_right
24
  self.max_silent_period = max_silent_period
 
25
 
26
  def get_audio_segment(self, str, start_time: str = None, duration: str = None):
27
  return load_audio(str, self.sampling_rate, start_time, duration)
@@ -68,6 +73,13 @@ class AbstractTranscription(ABC):
68
  print("Timestamps:")
69
  pprint(merged)
70
 
 
 
 
 
 
 
 
71
  result = {
72
  'text': "",
73
  'segments': [],
@@ -78,12 +90,19 @@ class AbstractTranscription(ABC):
78
  # For each time segment, run whisper
79
  for segment in merged:
80
  segment_start = segment['start']
81
- segment_duration = segment['end'] - segment_start
 
 
 
82
 
83
  segment_audio = self.get_audio_segment(audio, start_time = str(segment_start) + "s", duration = str(segment_duration) + "s")
84
 
85
- print("Running whisper on " + str(segment_start) + ", duration: " + str(segment_duration))
86
- segment_result = whisperCallable(segment_audio)
 
 
 
 
87
  adjusted_segments = self.adjust_whisper_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
88
 
89
  # Append to output
@@ -98,6 +117,32 @@ class AbstractTranscription(ABC):
98
 
99
  return result
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  def adjust_whisper_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
102
  result = []
103
 
@@ -178,11 +223,15 @@ class AbstractTranscription(ABC):
178
  return result
179
 
180
  class VadSileroTranscription(AbstractTranscription):
181
- def __init__(self):
182
- super().__init__(SEGMENT_PADDING_LEFT, SEGMENT_PADDING_RIGHT, MAX_SILENT_PERIOD)
183
- self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
184
-
185
- (self.get_speech_timestamps, _, _, _, _) = utils
 
 
 
 
186
 
187
  def get_transcribe_timestamps(self, audio: str):
188
  wav = self.get_audio_segment(audio)
 
9
  import ffmpeg
10
  import numpy as np
11
 
12
+ from utils import format_timestamp
13
+
14
+ # Defaults for Silero
15
  SPEECH_TRESHOLD = 0.3
16
  MAX_SILENT_PERIOD = 10 # seconds
17
  SEGMENT_PADDING_LEFT = 1 # Start detected text segment early
18
+ SEGMENT_PADDING_RIGHT = 3 # End detected segments late
19
 
20
+ # Whether to attempt to transcribe non-speech
21
+ TRANSCRIBE_NON_SPEECH = False
22
 
23
  class AbstractTranscription(ABC):
24
def __init__(self, segment_padding_left: int = None, segment_padding_right = None, max_silent_period: int = None, transcribe_non_speech: bool = False):
    """Store the configuration shared by all transcription strategies.

    Parameters
    ----------
    segment_padding_left:
        Stored as ``self.segment_padding_left``.
    segment_padding_right:
        Stored as ``self.segment_padding_right``.
    max_silent_period:
        Stored as ``self.max_silent_period``.
    transcribe_non_speech:
        Stored as ``self.transcribe_non_speech``; when true, the gaps
        between detected segments are transcribed as well.
    """
    # Behaviour switch first, then the timing configuration.
    self.transcribe_non_speech = transcribe_non_speech
    self.max_silent_period = max_silent_period
    self.segment_padding_right = segment_padding_right
    self.segment_padding_left = segment_padding_left

    # Audio is always loaded at a fixed 16 kHz sampling rate.
    self.sampling_rate = 16000
30
 
31
  def get_audio_segment(self, str, start_time: str = None, duration: str = None):
32
  return load_audio(str, self.sampling_rate, start_time, duration)
 
73
  print("Timestamps:")
74
  pprint(merged)
75
 
76
+ if self.transcribe_non_speech:
77
+ max_audio_duration = float(ffmpeg.probe(audio)["format"]["duration"])
78
+ merged = self.include_gaps(merged, min_gap_length=5, total_duration=max_audio_duration)
79
+
80
+ print("Transcribing non-speech:")
81
+ pprint(merged)
82
+
83
  result = {
84
  'text': "",
85
  'segments': [],
 
90
  # For each time segment, run whisper
91
  for segment in merged:
92
  segment_start = segment['start']
93
+ segment_end = segment['end']
94
+ segment_gap = segment.get('gap', False)
95
+
96
+ segment_duration = segment_end - segment_start
97
 
98
  segment_audio = self.get_audio_segment(audio, start_time = str(segment_start) + "s", duration = str(segment_duration) + "s")
99
 
100
+ print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ", segment_duration, "gap: ", segment_gap)
101
+ if segment_gap:
102
+ # TODO: Use different parameters for these segments, as they are less likely to contain speech
103
+ segment_result = whisperCallable(segment_audio)
104
+ else:
105
+ segment_result = whisperCallable(segment_audio)
106
  adjusted_segments = self.adjust_whisper_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
107
 
108
  # Append to output
 
117
 
118
  return result
119
 
120
def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
    """Return ``segments`` with the uncovered stretches between them added.

    Every stretch of time not covered by a segment (before the first
    segment, between two segments, and after the last segment up to
    ``total_duration``) is inserted as ``{'start', 'end', 'gap': True}``,
    provided it is at least ``min_gap_length`` long. The original segment
    dicts are passed through unchanged and in their original order.

    Parameters
    ----------
    segments:
        Dicts with numeric ``'start'`` and ``'end'`` entries, expected to
        be ordered by start time.
    min_gap_length:
        Minimum gap length to include, or ``None`` to include every gap.
    total_duration:
        End of the timeline, or ``None`` to skip the trailing gap.

    Returns
    -------
    list of dict
        The segments interleaved with the generated gap entries.
    """
    result = []
    last_end_time = 0

    for segment in segments:
        segment_start = float(segment['start'])
        segment_end = float(segment['end'])

        # Only a strictly positive distance is a gap; overlapping segments
        # must not yield an entry whose end precedes its start.
        delta = segment_start - last_end_time

        if delta > 0 and (min_gap_length is None or delta >= min_gap_length):
            result.append({'start': last_end_time, 'end': segment_start, 'gap': True})

        # Never move the cursor backwards if a segment is nested in an earlier one.
        last_end_time = max(last_end_time, segment_end)
        result.append(segment)

    # Also cover the tail of the timeline if a total duration is specified.
    if total_duration is not None and last_end_time < total_duration:
        # Measure from the end of the covered audio, not from the start of
        # the last segment (which is also undefined when there are no segments).
        delta = total_duration - last_end_time

        if min_gap_length is None or delta >= min_gap_length:
            result.append({'start': last_end_time, 'end': total_duration, 'gap': True})

    return result
145
+
146
  def adjust_whisper_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
147
  result = []
148
 
 
223
  return result
224
 
225
  class VadSileroTranscription(AbstractTranscription):
226
def __init__(self, transcribe_non_speech: bool = False, copy = None):
    """Set up a Silero-VAD-backed transcription.

    Parameters
    ----------
    transcribe_non_speech:
        Forwarded to the base class; controls whether the gaps between
        detected speech segments are transcribed too.
    copy:
        Optional existing instance. When given, its ``model`` and
        ``get_speech_timestamps`` are shared with this instance and
        ``torch.hub.load`` is not called.
    """
    super().__init__(SEGMENT_PADDING_LEFT, SEGMENT_PADDING_RIGHT, MAX_SILENT_PERIOD, transcribe_non_speech)

    if not copy:
        # No donor instance: load the Silero VAD model and its helper tuple via torch.hub.
        silero_model, silero_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
        self.model = silero_model
        # Only the first of the five helpers is kept.
        (self.get_speech_timestamps, _, _, _, _) = silero_utils
    else:
        # Share the donor's model and timestamp helper.
        self.model = copy.model
        self.get_speech_timestamps = copy.get_speech_timestamps
235
 
236
  def get_transcribe_timestamps(self, audio: str):
237
  wav = self.get_audio_segment(audio)