aadnk commited on
Commit
8781620
·
1 Parent(s): c52f09b

Fix merging of VAD segments

Browse files

Also expand VAD segments rather than running whisper
on non-speech segments. That way, whisper will use the
previous section as a prompt, in case there is speech in
the non-speech section after all.

Files changed (1) hide show
  1. src/vad.py +61 -13
src/vad.py CHANGED
@@ -1,6 +1,7 @@
1
  from abc import ABC, abstractmethod
2
  from collections import Counter
3
  from dis import dis
 
4
  from typing import Any, Iterator, List, Dict
5
 
6
  from pprint import pprint
@@ -27,7 +28,7 @@ MAX_SILENT_PERIOD = 10 # seconds
27
  MAX_MERGE_SIZE = 150 # Do not create segments larger than 2.5 minutes
28
 
29
  SEGMENT_PADDING_LEFT = 1 # Start detected text segment early
30
- SEGMENT_PADDING_RIGHT = 3 # End detected segments late
31
 
32
  # Whether to attempt to transcribe non-speech
33
  TRANSCRIBE_NON_SPEECH = False
@@ -91,7 +92,9 @@ class AbstractTranscription(ABC):
91
 
92
  if self.transcribe_non_speech:
93
  max_audio_duration = float(ffmpeg.probe(audio)["format"]["duration"])
94
- merged = self.include_gaps(merged, min_gap_length=5, total_duration=max_audio_duration)
 
 
95
 
96
  print("Transcribing non-speech:")
97
  pprint(merged)
@@ -107,7 +110,7 @@ class AbstractTranscription(ABC):
107
  for segment in merged:
108
  segment_start = segment['start']
109
  segment_end = segment['end']
110
- segment_gap = segment.get('gap', False)
111
 
112
  segment_duration = segment_end - segment_start
113
 
@@ -116,12 +119,9 @@ class AbstractTranscription(ABC):
116
 
117
  segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
118
 
119
- print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ", segment_duration, "gap: ", segment_gap)
120
- if segment_gap:
121
- # TODO: Use different parameters for these segments, as they are less likely to contain speech
122
- segment_result = whisperCallable(segment_audio)
123
- else:
124
- segment_result = whisperCallable(segment_audio)
125
  adjusted_segments = self.adjust_whisper_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
126
 
127
  # Append to output
@@ -162,6 +162,44 @@ class AbstractTranscription(ABC):
162
 
163
  return result
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  def adjust_whisper_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
166
  result = []
167
 
@@ -184,17 +222,27 @@ class AbstractTranscription(ABC):
184
  return result
185
 
186
  def pad_timestamps(self, timestamps: List[Dict[str, Any]], padding_left: float, padding_right: float):
 
 
187
  result = []
188
 
189
- for entry in timestamps:
190
- segment_start = entry['start']
191
- segment_end = entry['end']
 
 
 
 
192
 
193
  if padding_left is not None:
194
- segment_start = max(0, segment_start - padding_left)
195
  if padding_right is not None:
196
  segment_end = segment_end + padding_right
197
 
 
 
 
 
198
  result.append({ 'start': segment_start, 'end': segment_end })
199
 
200
  return result
 
1
  from abc import ABC, abstractmethod
2
  from collections import Counter
3
  from dis import dis
4
+ import re
5
  from typing import Any, Iterator, List, Dict
6
 
7
  from pprint import pprint
 
28
  MAX_MERGE_SIZE = 150 # Do not create segments larger than 2.5 minutes
29
 
30
  SEGMENT_PADDING_LEFT = 1 # Start detected text segment early
31
+ SEGMENT_PADDING_RIGHT = 1 # End detected segments late
32
 
33
  # Whether to attempt to transcribe non-speech
34
  TRANSCRIBE_NON_SPEECH = False
 
92
 
93
  if self.transcribe_non_speech:
94
  max_audio_duration = float(ffmpeg.probe(audio)["format"]["duration"])
95
+
96
+ # Expand segments to include the gaps between them
97
+ merged = self.expand_gaps(merged, total_duration=max_audio_duration)
98
 
99
  print("Transcribing non-speech:")
100
  pprint(merged)
 
110
  for segment in merged:
111
  segment_start = segment['start']
112
  segment_end = segment['end']
113
+ segment_expand_amount = segment.get('expand_amount', 0)
114
 
115
  segment_duration = segment_end - segment_start
116
 
 
119
 
120
  segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
121
 
122
+ print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ", segment_duration, "expanded: ", segment_expand_amount)
123
+ segment_result = whisperCallable(segment_audio)
124
+
 
 
 
125
  adjusted_segments = self.adjust_whisper_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
126
 
127
  # Append to output
 
162
 
163
  return result
164
 
165
+ # Expand the end time of each segment to the start of the next segment
166
+ def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
167
+ result = []
168
+
169
+ if len(segments) == 0:
170
+ return result
171
+
172
+ # Add gap at the beginning if needed
173
+ if (segments[0]['start'] > 0):
174
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
175
+
176
+ for i in range(len(segments) - 1):
177
+ current_segment = segments[i]
178
+ next_segment = segments[i + 1]
179
+
180
+ delta = next_segment['start'] - current_segment['end']
181
+
182
+ # Expand if the gap actually exists
183
+ if (delta >= 0):
184
+ current_segment = current_segment.copy()
185
+ current_segment['expand_amount'] = delta
186
+ current_segment['end'] = next_segment['start']
187
+
188
+ result.append(current_segment)
189
+
190
+ last_segment = result[-1]
191
+
192
+ # Also include total duration if specified
193
+ if (total_duration is not None):
194
+ last_segment = result[-1]
195
+
196
+ if (last_segment['end'] < total_duration):
197
+ last_segment = last_segment.copy()
198
+ last_segment['end'] = total_duration
199
+ result[-1] = last_segment
200
+
201
+ return result
202
+
203
  def adjust_whisper_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
204
  result = []
205
 
 
222
  return result
223
 
224
  def pad_timestamps(self, timestamps: List[Dict[str, Any]], padding_left: float, padding_right: float):
225
+ if (padding_left == 0 and padding_right == 0):
226
+ return timestamps
227
  result = []
228
 
229
+ for i in range(len(timestamps)):
230
+ prev_entry = timestamps[i - 1] if i > 0 else None
231
+ curr_entry = timestamps[i]
232
+ next_entry = timestamps[i + 1] if i < len(timestamps) - 1 else None
233
+
234
+ segment_start = curr_entry['start']
235
+ segment_end = curr_entry['end']
236
 
237
  if padding_left is not None:
238
+ segment_start = max(prev_entry['end'] if prev_entry else 0, segment_start - padding_left)
239
  if padding_right is not None:
240
  segment_end = segment_end + padding_right
241
 
242
+ # Do not pad past the next segment
243
+ if (next_entry is not None):
244
+ segment_end = min(next_entry['start'], segment_end)
245
+
246
  result.append({ 'start': segment_start, 'end': segment_end })
247
 
248
  return result