aadnk commited on
Commit
01fddc0
·
1 Parent(s): 9934006

Fix CLI for parallel devices

Browse files
Files changed (4) hide show
  1. app.py +4 -1
  2. cli.py +8 -4
  3. src/vadParallel.py +12 -5
  4. src/whisperContainer.py +3 -2
app.py CHANGED
@@ -60,6 +60,9 @@ class WhisperTranscriber:
60
  self.inputAudioMaxDuration = input_audio_max_duration
61
  self.deleteUploadedFiles = delete_uploaded_files
62
 
 
 
 
63
  def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
64
  try:
65
  source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
@@ -255,7 +258,7 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
255
  ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout)
256
 
257
  # Specify a list of devices to use for parallel processing
258
- ui.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
259
 
260
  ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
261
  ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
 
60
  self.inputAudioMaxDuration = input_audio_max_duration
61
  self.deleteUploadedFiles = delete_uploaded_files
62
 
63
+ def set_parallel_devices(self, vad_parallel_devices: str):
64
+ self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
65
+
66
  def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
67
  try:
68
  source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
 
258
  ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout)
259
 
260
  # Specify a list of devices to use for parallel processing
261
+ ui.set_parallel_devices(vad_parallel_devices)
262
 
263
  ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
264
  ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
cli.py CHANGED
@@ -12,6 +12,7 @@ from app import LANGUAGES, WhisperTranscriber
12
  from src.download import download_url
13
 
14
  from src.utils import optional_float, optional_int, str2bool
 
15
 
16
 
17
  def cli():
@@ -31,7 +32,7 @@ def cli():
31
  parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
32
  parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
33
  parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
34
- parser.add_argument("--vad_parallel_devices", type=str, default="0", help="A commma delimited list of CUDA devices to use for paralell processing. If None, disable parallel processing.")
35
 
36
  parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
37
  parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
@@ -73,9 +74,12 @@ def cli():
73
  vad_padding = args.pop("vad_padding")
74
  vad_prompt_window = args.pop("vad_prompt_window")
75
 
76
- model = whisper.load_model(model_name, device=device, download_root=model_dir)
77
  transcriber = WhisperTranscriber(delete_uploaded_files=False)
78
- transcriber.parallel_device_list = args.pop("vad_parallel_devices")
 
 
 
79
 
80
  for audio_path in args.pop("audio"):
81
  sources = []
@@ -99,7 +103,7 @@ def cli():
99
 
100
  transcriber.write_result(result, source_name, output_dir)
101
 
102
- transcriber.clear_cache()
103
 
104
  def uri_validator(x):
105
  try:
 
12
  from src.download import download_url
13
 
14
  from src.utils import optional_float, optional_int, str2bool
15
+ from src.whisperContainer import WhisperContainer
16
 
17
 
18
  def cli():
 
32
  parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
33
  parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
34
  parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
35
+ parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for paralell processing. If None, disable parallel processing.")
36
 
37
  parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
38
  parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
 
74
  vad_padding = args.pop("vad_padding")
75
  vad_prompt_window = args.pop("vad_prompt_window")
76
 
77
+ model = WhisperContainer(model_name, device=device, download_root=model_dir)
78
  transcriber = WhisperTranscriber(delete_uploaded_files=False)
79
+ transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
80
+
81
+ if (transcriber._has_parallel_devices()):
82
+ print("Using parallel devices:", transcriber.parallel_device_list)
83
 
84
  for audio_path in args.pop("audio"):
85
  sources = []
 
103
 
104
  transcriber.write_result(result, source_name, output_dir)
105
 
106
+ transcriber.close()
107
 
108
  def uri_validator(x):
109
  try:
src/vadParallel.py CHANGED
@@ -88,14 +88,20 @@ class ParallelTranscription(AbstractTranscription):
88
 
89
  # Split into a list for each device
90
  # TODO: Split by time instead of by number of chunks
91
- merged_split = self._chunks(merged, max(len(merged) // len(devices), 1))
92
 
93
  # Parameters that will be passed to the transcribe function
94
  parameters = []
95
  segment_index = config.initial_segment_index
96
 
97
  for i in range(len(merged_split)):
98
- device_segment_list = merged_split[i]
 
 
 
 
 
 
99
 
100
  # Create a new config with the given device ID
101
  device_config = ParallelTranscriptionConfig(devices[i], device_segment_list, segment_index, config)
@@ -159,7 +165,8 @@ class ParallelTranscription(AbstractTranscription):
159
  os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
160
  return super().transcribe(audio, whisperCallable, config)
161
 
162
- def _chunks(self, lst, n):
163
- """Yield successive n-sized chunks from lst."""
164
- return [lst[i:i + n] for i in range(0, len(lst), n)]
 
165
 
 
88
 
89
  # Split into a list for each device
90
  # TODO: Split by time instead of by number of chunks
91
+ merged_split = list(self._split(merged, len(devices)))
92
 
93
  # Parameters that will be passed to the transcribe function
94
  parameters = []
95
  segment_index = config.initial_segment_index
96
 
97
  for i in range(len(merged_split)):
98
+ device_segment_list = list(merged_split[i])
99
+ device_id = devices[i]
100
+
101
+ if (len(device_segment_list) <= 0):
102
+ continue
103
+
104
+ print("Device " + device_id + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
105
 
106
  # Create a new config with the given device ID
107
  device_config = ParallelTranscriptionConfig(devices[i], device_segment_list, segment_index, config)
 
165
  os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
166
  return super().transcribe(audio, whisperCallable, config)
167
 
168
+ def _split(self, a, n):
169
+ """Split a list into n approximately equal parts."""
170
+ k, m = divmod(len(a), n)
171
+ return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
172
 
src/whisperContainer.py CHANGED
@@ -23,9 +23,10 @@ class WhisperModelCache:
23
  GLOBAL_WHISPER_MODEL_CACHE = WhisperModelCache()
24
 
25
  class WhisperContainer:
26
- def __init__(self, model_name: str, device: str = None, cache: WhisperModelCache = None):
27
  self.model_name = model_name
28
  self.device = device
 
29
  self.cache = cache
30
 
31
  # Will be created on demand
@@ -36,7 +37,7 @@ class WhisperContainer:
36
 
37
  if (self.cache is None):
38
  print("Loading whisper model " + self.model_name)
39
- self.model = whisper.load_model(self.model_name, device=self.device)
40
  else:
41
  self.model = self.cache.get(self.model_name, device=self.device)
42
  return self.model
 
23
  GLOBAL_WHISPER_MODEL_CACHE = WhisperModelCache()
24
 
25
  class WhisperContainer:
26
+ def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: WhisperModelCache = None):
27
  self.model_name = model_name
28
  self.device = device
29
+ self.download_root = download_root
30
  self.cache = cache
31
 
32
  # Will be created on demand
 
37
 
38
  if (self.cache is None):
39
  print("Loading whisper model " + self.model_name)
40
+ self.model = whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
41
  else:
42
  self.model = self.cache.get(self.model_name, device=self.device)
43
  return self.model