aadnk commited on
Commit
0cb931d
·
1 Parent(s): d906b98

Refactor pad and merge timestamps into one function

Browse files

This also fixes a bunch of issues regarding when the timestamps
should be merged.

Files changed (3) hide show
  1. src/segments.py +47 -0
  2. src/vad.py +6 -66
  3. tests/segments_test.py +48 -0
src/segments.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ import copy
4
+
5
+ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
6
+ result = []
7
+
8
+ if len(timestamps) == 0:
9
+ return result
10
+
11
+ processed_time = 0
12
+ current_segment = None
13
+
14
+ for i in range(len(timestamps)):
15
+ next_segment = timestamps[i]
16
+
17
+ delta = next_segment['start'] - processed_time
18
+
19
+ # Note that segments can still be longer than the max merge size, they just won't be merged in that case
20
+ if current_segment is None or delta > merge_window or next_segment['end'] - current_segment['start'] > max_merge_size:
21
+ # Finish the current segment
22
+ if current_segment is not None:
23
+ # Add right padding
24
+ finish_padding = min(padding_right, delta / 2) if delta < padding_left + padding_right else padding_right
25
+ current_segment['end'] += finish_padding
26
+ delta -= finish_padding
27
+
28
+ result.append(current_segment)
29
+
30
+ # Start a new segment
31
+ current_segment = copy.deepcopy(next_segment)
32
+
33
+ # Pad the segment
34
+ current_segment['start'] = current_segment['start'] - min(padding_left, delta)
35
+ processed_time = current_segment['end']
36
+
37
+ else:
38
+ # Merge the segment
39
+ current_segment['end'] = next_segment['end']
40
+ processed_time = current_segment['end']
41
+
42
+ # Add the last segment
43
+ if current_segment is not None:
44
+ current_segment['end'] += padding_right
45
+ result.append(current_segment)
46
+
47
+ return result
src/vad.py CHANGED
@@ -5,6 +5,8 @@ from typing import Any, Deque, Iterator, List, Dict
5
 
6
  from pprint import pprint
7
 
 
 
8
  # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
9
  try:
10
  import tensorflow as tf
@@ -110,8 +112,10 @@ class AbstractTranscription(ABC):
110
  # get speech timestamps from full audio file
111
  seconds_timestamps = self.get_transcribe_timestamps(audio)
112
 
113
- padded = self.pad_timestamps(seconds_timestamps, self.segment_padding_left, self.segment_padding_right)
114
- merged = self.merge_timestamps(padded, self.max_silent_period, self.max_merge_size, self.min_force_merge_gap, self.max_force_merge_size)
 
 
115
 
116
  # A deque of transcribed segments that is passed to the next segment as a prompt
117
  prompt_window = deque()
@@ -346,70 +350,6 @@ class AbstractTranscription(ABC):
346
  result.append(new_segment)
347
  return result
348
 
349
- def pad_timestamps(self, timestamps: List[Dict[str, Any]], padding_left: float, padding_right: float):
350
- if (padding_left == 0 and padding_right == 0):
351
- return timestamps
352
-
353
- result = []
354
- prev_entry = None
355
-
356
- for i in range(len(timestamps)):
357
- curr_entry = timestamps[i]
358
- next_entry = timestamps[i + 1] if i < len(timestamps) - 1 else None
359
-
360
- segment_start = curr_entry['start']
361
- segment_end = curr_entry['end']
362
-
363
- if padding_left is not None:
364
- segment_start = max(prev_entry['end'] if prev_entry else 0, segment_start - padding_left)
365
- if padding_right is not None:
366
- segment_end = segment_end + padding_right
367
-
368
- # Do not pad past the next segment
369
- if (next_entry is not None):
370
- segment_end = min(next_entry['start'], segment_end)
371
-
372
- new_entry = { 'start': segment_start, 'end': segment_end }
373
- prev_entry = new_entry
374
- result.append(new_entry)
375
-
376
- return result
377
-
378
- def merge_timestamps(self, timestamps: List[Dict[str, Any]], max_merge_gap: float, max_merge_size: float,
379
- min_force_merge_gap: float, max_force_merge_size: float):
380
- if max_merge_gap is None:
381
- return timestamps
382
-
383
- result = []
384
- current_entry = None
385
-
386
- for entry in timestamps:
387
- if current_entry is None:
388
- current_entry = entry
389
- continue
390
-
391
- # Get distance to the previous entry
392
- distance = entry['start'] - current_entry['end']
393
- current_entry_size = current_entry['end'] - current_entry['start']
394
-
395
- if distance <= max_merge_gap and (max_merge_size is None or current_entry_size <= max_merge_size):
396
- # Regular merge
397
- current_entry['end'] = entry['end']
398
- elif min_force_merge_gap is not None and distance <= min_force_merge_gap and \
399
- (max_force_merge_size is None or current_entry_size <= max_force_merge_size):
400
- # Force merge if the distance is small (up to a certain maximum size)
401
- current_entry['end'] = entry['end']
402
- else:
403
- # Output current entry
404
- result.append(current_entry)
405
- current_entry = entry
406
-
407
- # Add final entry
408
- if current_entry is not None:
409
- result.append(current_entry)
410
-
411
- return result
412
-
413
  def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
414
  result = []
415
 
 
5
 
6
  from pprint import pprint
7
 
8
+ from src.segments import merge_timestamps
9
+
10
  # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
11
  try:
12
  import tensorflow as tf
 
112
  # get speech timestamps from full audio file
113
  seconds_timestamps = self.get_transcribe_timestamps(audio)
114
 
115
+ #for seconds_timestamp in seconds_timestamps:
116
+ # print("VAD timestamp ", format_timestamp(seconds_timestamp['start']), " to ", format_timestamp(seconds_timestamp['end']))
117
+
118
+ merged = merge_timestamps(seconds_timestamps, self.max_silent_period, self.max_merge_size, self.segment_padding_left, self.segment_padding_right)
119
 
120
  # A deque of transcribed segments that is passed to the next segment as a prompt
121
  prompt_window = deque()
 
350
  result.append(new_segment)
351
  return result
352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
354
  result = []
355
 
tests/segments_test.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import unittest
3
+
4
+ sys.path.append('../whisper-webui')
5
+
6
+ from src.segments import merge_timestamps
7
+
8
+ class TestSegments(unittest.TestCase):
9
+ def __init__(self, *args, **kwargs):
10
+ super(TestSegments, self).__init__(*args, **kwargs)
11
+
12
+ def test_merge_segments(self):
13
+ segments = [
14
+ {'start': 10.0, 'end': 20.0},
15
+ {'start': 22.0, 'end': 27.0},
16
+ {'start': 31.0, 'end': 35.0},
17
+ {'start': 45.0, 'end': 60.0},
18
+ {'start': 61.0, 'end': 65.0},
19
+ {'start': 68.0, 'end': 98.0},
20
+ {'start': 100.0, 'end': 102.0},
21
+ {'start': 110.0, 'end': 112.0}
22
+ ]
23
+
24
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
25
+
26
+ self.assertListEqual(result, [
27
+ {'start': 9.0, 'end': 36.0},
28
+ {'start': 44.0, 'end': 66.0},
29
+ {'start': 67.0, 'end': 99.0},
30
+ {'start': 99.0, 'end': 103.0},
31
+ {'start': 109.0, 'end': 113.0}
32
+ ])
33
+
34
+ def test_overlap_next(self):
35
+ segments = [
36
+ {'start': 5.0, 'end': 39.182},
37
+ {'start': 39.986, 'end': 40.814}
38
+ ]
39
+
40
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
41
+
42
+ self.assertListEqual(result, [
43
+ {'start': 4.0, 'end': 39.584},
44
+ {'start': 39.584, 'end': 41.814}
45
+ ])
46
+
47
+ if __name__ == '__main__':
48
+ unittest.main()