aadnk committed on
Commit
48d8572
·
1 Parent(s): 1217d8b

Add prompt window to VAD

Browse files

This allows each speech section to access the text of the previous
speech section if it ends within a certain time window.

Files changed (4) hide show
  1. app.py +13 -10
  2. cli.py +3 -1
  3. src/utils.py +3 -3
  4. src/vad.py +106 -15
app.py CHANGED
@@ -53,7 +53,7 @@ class WhisperTranscriber:
53
  self.inputAudioMaxDuration = inputAudioMaxDuration
54
  self.deleteUploadedFiles = deleteUploadedFiles
55
 
56
- def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding):
57
  try:
58
  source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
59
 
@@ -68,7 +68,7 @@ class WhisperTranscriber:
68
  self.model_cache[selectedModel] = model
69
 
70
  # Execute whisper
71
- result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding)
72
 
73
  # Write result
74
  downloadDirectory = tempfile.mkdtemp()
@@ -88,9 +88,9 @@ class WhisperTranscriber:
88
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
89
 
90
  def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None,
91
- vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, **decodeOptions: dict):
92
  # Callable for processing an audio file
93
- whisperCallable = lambda audio : model.transcribe(audio, language=language, task=task, **decodeOptions)
94
 
95
  # The results
96
  if (vad == 'silero-vad'):
@@ -100,7 +100,8 @@ class WhisperTranscriber:
100
 
101
  process_gaps = VadSileroTranscription(transcribe_non_speech = True,
102
  max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
103
- segment_padding_left=vadPadding, segment_padding_right=vadPadding, copy=self.vad_model)
 
104
  result = process_gaps.transcribe(audio_path, whisperCallable)
105
  elif (vad == 'silero-vad-skip-gaps'):
106
  # Use Silero VAD
@@ -109,7 +110,8 @@ class WhisperTranscriber:
109
 
110
  skip_gaps = VadSileroTranscription(transcribe_non_speech = False,
111
  max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
112
- segment_padding_left=vadPadding, segment_padding_right=vadPadding, copy=self.vad_model)
 
113
  result = skip_gaps.transcribe(audio_path, whisperCallable)
114
  elif (vad == 'periodic-vad'):
115
  # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
@@ -118,7 +120,7 @@ class WhisperTranscriber:
118
  result = periodic_vad.transcribe(audio_path, whisperCallable)
119
  else:
120
  # Default VAD
121
- result = whisperCallable(audio_path)
122
 
123
  return result
124
 
@@ -217,9 +219,10 @@ def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
217
  gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
218
  gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
219
  gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "periodic-vad"], label="VAD"),
220
- gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
221
- gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=150),
222
- gr.Number(label="VAD - Padding (s)", precision=None, value=1)
 
223
  ], outputs=[
224
  gr.File(label="Download"),
225
  gr.Text(label="Transcription"),
 
53
  self.inputAudioMaxDuration = inputAudioMaxDuration
54
  self.deleteUploadedFiles = deleteUploadedFiles
55
 
56
+ def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
57
  try:
58
  source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
59
 
 
68
  self.model_cache[selectedModel] = model
69
 
70
  # Execute whisper
71
+ result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
72
 
73
  # Write result
74
  downloadDirectory = tempfile.mkdtemp()
 
88
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
89
 
90
  def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None,
91
+ vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
92
  # Callable for processing an audio file
93
+ whisperCallable = lambda audio, prompt : model.transcribe(audio, language=language, task=task, initial_prompt=prompt, **decodeOptions)
94
 
95
  # The results
96
  if (vad == 'silero-vad'):
 
100
 
101
  process_gaps = VadSileroTranscription(transcribe_non_speech = True,
102
  max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
103
+ segment_padding_left=vadPadding, segment_padding_right=vadPadding,
104
+ max_prompt_window=vadPromptWindow, copy=self.vad_model)
105
  result = process_gaps.transcribe(audio_path, whisperCallable)
106
  elif (vad == 'silero-vad-skip-gaps'):
107
  # Use Silero VAD
 
110
 
111
  skip_gaps = VadSileroTranscription(transcribe_non_speech = False,
112
  max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
113
+ segment_padding_left=vadPadding, segment_padding_right=vadPadding,
114
+ max_prompt_window=vadPromptWindow, copy=self.vad_model)
115
  result = skip_gaps.transcribe(audio_path, whisperCallable)
116
  elif (vad == 'periodic-vad'):
117
  # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
 
120
  result = periodic_vad.transcribe(audio_path, whisperCallable)
121
  else:
122
  # Default VAD
123
+ result = whisperCallable(audio_path, None)
124
 
125
  return result
126
 
 
219
  gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
220
  gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
221
  gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "periodic-vad"], label="VAD"),
222
+ gr.Number(label="VAD - Merge Window (s)", precision=0, value=4),
223
+ gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
224
+ gr.Number(label="VAD - Padding (s)", precision=None, value=1),
225
+ gr.Number(label="VAD - Prompt Window (s)", precision=None, value=10)
226
  ], outputs=[
227
  gr.File(label="Download"),
228
  gr.Text(label="Transcription"),
cli.py CHANGED
@@ -30,6 +30,7 @@ def cli():
30
  parser.add_argument("--vad_merge_window", type=optional_float, default=5, help="The window size (in seconds) to merge voice segments")
31
  parser.add_argument("--vad_max_merge_size", type=optional_float, default=150, help="The maximum size (in seconds) of a voice segment")
32
  parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
 
33
 
34
  parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
35
  parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
@@ -69,6 +70,7 @@ def cli():
69
  vad_merge_window = args.pop("vad_merge_window")
70
  vad_max_merge_size = args.pop("vad_max_merge_size")
71
  vad_padding = args.pop("vad_padding")
 
72
 
73
  model = whisper.load_model(model_name, device=device, download_root=model_dir)
74
  transcriber = WhisperTranscriber(deleteUploadedFiles=False)
@@ -91,7 +93,7 @@ def cli():
91
 
92
  result = transcriber.transcribe_file(model, source_path, temperature=temperature,
93
  vad=vad, vadMergeWindow=vad_merge_window, vadMaxMergeSize=vad_max_merge_size,
94
- vadPadding=vad_padding, **args)
95
 
96
  transcriber.write_result(result, source_name, output_dir)
97
 
 
30
  parser.add_argument("--vad_merge_window", type=optional_float, default=5, help="The window size (in seconds) to merge voice segments")
31
  parser.add_argument("--vad_max_merge_size", type=optional_float, default=150, help="The maximum size (in seconds) of a voice segment")
32
  parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
33
+ parser.add_argument("--vad_prompt_window", type=optional_float, default=0, help="The window size of the prompt to pass to Whisper")
34
 
35
  parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
36
  parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
 
70
  vad_merge_window = args.pop("vad_merge_window")
71
  vad_max_merge_size = args.pop("vad_max_merge_size")
72
  vad_padding = args.pop("vad_padding")
73
+ vad_prompt_window = args.pop("vad_prompt_window")
74
 
75
  model = whisper.load_model(model_name, device=device, download_root=model_dir)
76
  transcriber = WhisperTranscriber(deleteUploadedFiles=False)
 
93
 
94
  result = transcriber.transcribe_file(model, source_path, temperature=temperature,
95
  vad=vad, vadMergeWindow=vad_merge_window, vadMaxMergeSize=vad_max_merge_size,
96
+ vadPadding=vad_padding, vadPromptWindow=vad_prompt_window, **args)
97
 
98
  transcriber.write_result(result, source_name, output_dir)
99
 
src/utils.py CHANGED
@@ -56,7 +56,7 @@ def write_txt(transcript: Iterator[dict], file: TextIO):
56
  def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
57
  print("WEBVTT\n", file=file)
58
  for segment in transcript:
59
- text = processText(segment['text'], maxLineWidth).replace('-->', '->')
60
 
61
  print(
62
  f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
@@ -79,7 +79,7 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
79
  write_srt(result["segments"], file=srt)
80
  """
81
  for i, segment in enumerate(transcript, start=1):
82
- text = processText(segment['text'].strip(), maxLineWidth).replace('-->', '->')
83
 
84
  # write srt lines
85
  print(
@@ -91,7 +91,7 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
91
  flush=True,
92
  )
93
 
94
- def processText(text: str, maxLineWidth=None):
95
  if (maxLineWidth is None or maxLineWidth < 0):
96
  return text
97
 
 
56
  def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
57
  print("WEBVTT\n", file=file)
58
  for segment in transcript:
59
+ text = process_text(segment['text'], maxLineWidth).replace('-->', '->')
60
 
61
  print(
62
  f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
 
79
  write_srt(result["segments"], file=srt)
80
  """
81
  for i, segment in enumerate(transcript, start=1):
82
+ text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
83
 
84
  # write srt lines
85
  print(
 
91
  flush=True,
92
  )
93
 
94
+ def process_text(text: str, maxLineWidth=None):
95
  if (maxLineWidth is None or maxLineWidth < 0):
96
  return text
97
 
src/vad.py CHANGED
@@ -1,5 +1,5 @@
1
  from abc import ABC, abstractmethod
2
- from collections import Counter
3
  from typing import Any, Iterator, List, Dict
4
 
5
  from pprint import pprint
@@ -17,10 +17,9 @@ import ffmpeg
17
  import numpy as np
18
 
19
  from src.utils import format_timestamp
 
20
 
21
  # Defaults for Silero
22
- # TODO: Make these configurable?
23
-
24
  SPEECH_TRESHOLD = 0.3
25
  MAX_SILENT_PERIOD = 10 # seconds
26
  MAX_MERGE_SIZE = 150 # Do not create segments larger than 2.5 minutes
@@ -35,16 +34,29 @@ TRANSCRIBE_NON_SPEECH = False
35
  # Minimum size of segments to process
36
  MIN_SEGMENT_DURATION = 1
37
 
 
 
 
 
 
 
 
 
38
  VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
39
 
40
  class AbstractTranscription(ABC):
41
- def __init__(self, segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None, max_merge_size: float = None, transcribe_non_speech: bool = False):
 
42
  self.sampling_rate = 16000
43
  self.segment_padding_left = segment_padding_left
44
  self.segment_padding_right = segment_padding_right
45
  self.max_silent_period = max_silent_period
46
  self.max_merge_size = max_merge_size
47
  self.transcribe_non_speech = transcribe_non_speech
 
 
 
 
48
 
49
  def get_audio_segment(self, str, start_time: str = None, duration: str = None):
50
  return load_audio(str, self.sampling_rate, start_time, duration)
@@ -74,8 +86,9 @@ class AbstractTranscription(ABC):
74
  audio: str
75
  The audio file.
76
 
77
- whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor]], dict[str, Union[dict, Any]]]
78
- The callback that is used to invoke Whisper on an audio file/buffer.
 
79
 
80
  Returns
81
  -------
@@ -86,7 +99,10 @@ class AbstractTranscription(ABC):
86
  seconds_timestamps = self.get_transcribe_timestamps(audio)
87
 
88
  padded = self.pad_timestamps(seconds_timestamps, self.segment_padding_left, self.segment_padding_right)
89
- merged = self.merge_timestamps(padded, self.max_silent_period, self.max_merge_size)
 
 
 
90
 
91
  print("Timestamps:")
92
  pprint(merged)
@@ -95,7 +111,12 @@ class AbstractTranscription(ABC):
95
  max_audio_duration = get_audio_duration(audio)
96
 
97
  # Expand segments to include the gaps between them
98
- merged = self.expand_gaps(merged, total_duration=max_audio_duration)
 
 
 
 
 
99
 
100
  print("Transcribing non-speech:")
101
  pprint(merged)
@@ -118,10 +139,14 @@ class AbstractTranscription(ABC):
118
  if segment_duration < MIN_SEGMENT_DURATION:
119
  continue;
120
 
 
121
  segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
122
-
123
- print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ", segment_duration, "expanded: ", segment_expand_amount)
124
- segment_result = whisperCallable(segment_audio)
 
 
 
125
 
126
  adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
127
 
@@ -132,6 +157,16 @@ class AbstractTranscription(ABC):
132
  # Increment detected language
133
  languageCounter[segment_result['language']] += 1
134
 
 
 
 
 
 
 
 
 
 
 
135
  if len(languageCounter) > 0:
136
  result['language'] = languageCounter.most_common(1)[0][0]
137
 
@@ -203,6 +238,58 @@ class AbstractTranscription(ABC):
203
 
204
  return result
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
207
  result = []
208
 
@@ -253,7 +340,8 @@ class AbstractTranscription(ABC):
253
 
254
  return result
255
 
256
- def merge_timestamps(self, timestamps: List[Dict[str, Any]], max_merge_gap: float, max_merge_size: float):
 
257
  if max_merge_gap is None:
258
  return timestamps
259
 
@@ -270,7 +358,10 @@ class AbstractTranscription(ABC):
270
  current_entry_size = current_entry['end'] - current_entry['start']
271
 
272
  if distance <= max_merge_gap and (max_merge_size is None or current_entry_size <= max_merge_size):
273
- # Merge
 
 
 
274
  current_entry['end'] = entry['end']
275
  else:
276
  # Output current entry
@@ -299,9 +390,9 @@ class AbstractTranscription(ABC):
299
  class VadSileroTranscription(AbstractTranscription):
300
  def __init__(self, segment_padding_left=SEGMENT_PADDING_LEFT, segment_padding_right=SEGMENT_PADDING_RIGHT,
301
  max_silent_period=MAX_SILENT_PERIOD, max_merge_size=MAX_MERGE_SIZE, transcribe_non_speech: bool = False,
302
- copy = None):
303
  super().__init__(segment_padding_left=segment_padding_left, segment_padding_right=segment_padding_right,
304
- max_silent_period=max_silent_period, max_merge_size=max_merge_size, transcribe_non_speech=transcribe_non_speech)
305
 
306
  if copy:
307
  self.model = copy.model
 
1
  from abc import ABC, abstractmethod
2
+ from collections import Counter, deque
3
  from typing import Any, Iterator, List, Dict
4
 
5
  from pprint import pprint
 
17
  import numpy as np
18
 
19
  from src.utils import format_timestamp
20
+ from enum import Enum
21
 
22
  # Defaults for Silero
 
 
23
  SPEECH_TRESHOLD = 0.3
24
  MAX_SILENT_PERIOD = 10 # seconds
25
  MAX_MERGE_SIZE = 150 # Do not create segments larger than 2.5 minutes
 
34
  # Minimum size of segments to process
35
  MIN_SEGMENT_DURATION = 1
36
 
37
+ # Always merge segments that are less than this duration apart
38
+ MIN_FORCE_MERGE_GAP = 0.5
39
+ FORCE_MERGE_SEGMENT_MULTIPLIER = 1.5
40
+
41
+ # The maximum time for texts from old segments to be used in the next segment
42
+ MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
43
+ PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
44
+
45
  VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
46
 
47
  class AbstractTranscription(ABC):
48
+ def __init__(self, segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
49
+ max_merge_size: float = None, transcribe_non_speech: bool = False, max_prompt_window: float = None):
50
  self.sampling_rate = 16000
51
  self.segment_padding_left = segment_padding_left
52
  self.segment_padding_right = segment_padding_right
53
  self.max_silent_period = max_silent_period
54
  self.max_merge_size = max_merge_size
55
  self.transcribe_non_speech = transcribe_non_speech
56
+ self.max_prompt_window = max_prompt_window
57
+
58
+ self.min_force_merge_gap = MIN_FORCE_MERGE_GAP
59
+ self.max_force_merge_size = max_merge_size * FORCE_MERGE_SEGMENT_MULTIPLIER if max_merge_size is not None else None
60
 
61
  def get_audio_segment(self, str, start_time: str = None, duration: str = None):
62
  return load_audio(str, self.sampling_rate, start_time, duration)
 
86
  audio: str
87
  The audio file.
88
 
89
+ whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor], str], dict[str, Union[dict, Any]]]
90
+ The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer,
91
+ and the second parameter is an optional text prompt. The return value is the result of the Whisper call.
92
 
93
  Returns
94
  -------
 
99
  seconds_timestamps = self.get_transcribe_timestamps(audio)
100
 
101
  padded = self.pad_timestamps(seconds_timestamps, self.segment_padding_left, self.segment_padding_right)
102
+ merged = self.merge_timestamps(padded, self.max_silent_period, self.max_merge_size, self.min_force_merge_gap, self.max_force_merge_size)
103
+
104
+ # A deque of transcribed segments that is passed to the next segment as a prompt
105
+ prompt_window = deque()
106
 
107
  print("Timestamps:")
108
  pprint(merged)
 
111
  max_audio_duration = get_audio_duration(audio)
112
 
113
  # Expand segments to include the gaps between them
114
+ if (self.max_prompt_window is not None and self.max_prompt_window > 0):
115
+ # When we have a prompt window, we create speech segments between each segment if we exceed the merge size
116
+ merged = self.fill_gaps(merged, total_duration=max_audio_duration, max_expand_size=self.max_merge_size)
117
+ else:
118
+ # With no prompt window, it is better to expand the segments
119
+ merged = self.expand_gaps(merged, total_duration=max_audio_duration)
120
 
121
  print("Transcribing non-speech:")
122
  pprint(merged)
 
139
  if segment_duration < MIN_SEGMENT_DURATION:
140
  continue;
141
 
142
+ # Audio to run on Whisper
143
  segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
144
+ # Previous segments to use as a prompt
145
+ segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
146
+
147
+ print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
148
+ segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt)
149
+ segment_result = whisperCallable(segment_audio, segment_prompt)
150
 
151
  adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
152
 
 
157
  # Increment detected language
158
  languageCounter[segment_result['language']] += 1
159
 
160
+ # Update prompt window
161
+ if (self.max_prompt_window is not None and self.max_prompt_window > 0):
162
+ # Add segments to the current prompt window
163
+ for segment in adjusted_segments:
164
+ if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
165
+ prompt_window.append(segment)
166
+
167
+ while (len(prompt_window) > 0 and prompt_window[0]['end'] < segment_end - self.max_prompt_window):
168
+ prompt_window.popleft()
169
+
170
  if len(languageCounter) > 0:
171
  result['language'] = languageCounter.most_common(1)[0][0]
172
 
 
238
 
239
  return result
240
 
241
+ def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
242
+ result = []
243
+
244
+ if len(segments) == 0:
245
+ return result
246
+
247
+ # Add gap at the beginning if needed
248
+ if (segments[0]['start'] > 0):
249
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
250
+
251
+ for i in range(len(segments) - 1):
252
+ expanded = False
253
+ current_segment = segments[i]
254
+ next_segment = segments[i + 1]
255
+
256
+ delta = next_segment['start'] - current_segment['end']
257
+
258
+ if (max_expand_size is not None and delta <= max_expand_size):
259
+ # Just expand the current segment
260
+ current_segment = current_segment.copy()
261
+ current_segment['expand_amount'] = delta
262
+ current_segment['end'] = next_segment['start']
263
+ expanded = True
264
+
265
+ result.append(current_segment)
266
+
267
+ # Add a gap to the next segment if needed
268
+ if (delta >= 0 and not expanded):
269
+ result.append({ 'start': current_segment['end'], 'end': next_segment['start'], 'gap': True } )
270
+
271
+ # Add last segment
272
+ last_segment = segments[-1]
273
+ result.append(last_segment)
274
+
275
+ # Also include total duration if specified
276
+ if (total_duration is not None):
277
+ last_segment = result[-1]
278
+
279
+ delta = total_duration - last_segment['end']
280
+
281
+ if (delta > 0):
282
+ if (max_expand_size is not None and delta <= max_expand_size):
283
+ # Expand the last segment
284
+ last_segment = last_segment.copy()
285
+ last_segment['expand_amount'] = delta
286
+ last_segment['end'] = total_duration
287
+ result[-1] = last_segment
288
+ else:
289
+ result.append({ 'start': last_segment['end'], 'end': total_duration, 'gap': True } )
290
+
291
+ return result
292
+
293
  def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
294
  result = []
295
 
 
340
 
341
  return result
342
 
343
+ def merge_timestamps(self, timestamps: List[Dict[str, Any]], max_merge_gap: float, max_merge_size: float,
344
+ min_force_merge_gap: float, max_force_merge_size: float):
345
  if max_merge_gap is None:
346
  return timestamps
347
 
 
358
  current_entry_size = current_entry['end'] - current_entry['start']
359
 
360
  if distance <= max_merge_gap and (max_merge_size is None or current_entry_size <= max_merge_size):
361
+ # Regular merge
362
+ current_entry['end'] = entry['end']
363
+ elif min_force_merge_gap is not None and distance <= min_force_merge_gap and (max_force_merge_size is None or current_entry_size <= max_force_merge_size):
364
+ # Force merge if the distance is small (up to a certain maximum size)
365
  current_entry['end'] = entry['end']
366
  else:
367
  # Output current entry
 
390
  class VadSileroTranscription(AbstractTranscription):
391
  def __init__(self, segment_padding_left=SEGMENT_PADDING_LEFT, segment_padding_right=SEGMENT_PADDING_RIGHT,
392
  max_silent_period=MAX_SILENT_PERIOD, max_merge_size=MAX_MERGE_SIZE, transcribe_non_speech: bool = False,
393
+ max_prompt_window=MAX_PROMPT_WINDOW, copy = None):
394
  super().__init__(segment_padding_left=segment_padding_left, segment_padding_right=segment_padding_right,
395
+ max_silent_period=max_silent_period, max_merge_size=max_merge_size, transcribe_non_speech=transcribe_non_speech, max_prompt_window=max_prompt_window)
396
 
397
  if copy:
398
  self.model = copy.model