aadnk commited on
Commit
052fe7e
·
1 Parent(s): 46b5f28

Fix unit test

Browse files
src/whisper/abstractWhisperContainer.py CHANGED
@@ -1,5 +1,5 @@
1
  import abc
2
- from typing import List
3
 
4
  from src.config import ModelConfig, VadInitialPromptMode
5
 
@@ -9,7 +9,7 @@ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
9
 
10
  class AbstractWhisperCallback:
11
  def __init__(self):
12
- self.__prompt_mode_gpt = None
13
 
14
  @abc.abstractmethod
15
  def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
@@ -29,6 +29,14 @@ class AbstractWhisperCallback:
29
  """
30
  raise NotImplementedError()
31
 
 
 
 
 
 
 
 
 
32
  class AbstractWhisperContainer:
33
  def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
34
  download_root: str = None,
 
1
  import abc
2
+ from typing import Any, Callable, List
3
 
4
  from src.config import ModelConfig, VadInitialPromptMode
5
 
 
9
 
10
  class AbstractWhisperCallback:
11
  def __init__(self):
12
+ pass
13
 
14
  @abc.abstractmethod
15
  def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
 
29
  """
30
  raise NotImplementedError()
31
 
32
+ class LambdaWhisperCallback(AbstractWhisperCallback):
33
+ def __init__(self, callback_lambda: Callable[[Any, int, str, str, ProgressListener], None]):
34
+ super().__init__()
35
+ self.callback_lambda = callback_lambda
36
+
37
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
38
+ return self.callback_lambda(audio, segment_index, prompt, detected_language, progress_listener)
39
+
40
  class AbstractWhisperContainer:
41
  def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
42
  download_root: str = None,
tests/vad_test.py CHANGED
@@ -1,10 +1,11 @@
1
- import pprint
2
  import unittest
3
  import numpy as np
4
  import sys
5
 
6
  sys.path.append('../whisper-webui')
 
7
 
 
8
  from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
9
 
10
  class TestVad(unittest.TestCase):
@@ -13,10 +14,11 @@ class TestVad(unittest.TestCase):
13
  self.transcribe_calls = []
14
 
15
  def test_transcript(self):
16
- mock = MockVadTranscription()
 
17
 
18
  self.transcribe_calls.clear()
19
- result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
20
 
21
  self.assertListEqual(self.transcribe_calls, [
22
  [30, 30],
@@ -45,8 +47,9 @@ class TestVad(unittest.TestCase):
45
  }
46
 
47
  class MockVadTranscription(AbstractTranscription):
48
- def __init__(self):
49
  super().__init__()
 
50
 
51
  def get_audio_segment(self, str, start_time: str = None, duration: str = None):
52
  start_time_seconds = float(start_time.removesuffix("s"))
@@ -61,6 +64,9 @@ class MockVadTranscription(AbstractTranscription):
61
  result.append( { 'start': 30, 'end': 60 } )
62
  result.append( { 'start': 100, 'end': 200 } )
63
  return result
 
 
 
64
 
65
  if __name__ == '__main__':
66
  unittest.main()
 
 
1
  import unittest
2
  import numpy as np
3
  import sys
4
 
5
  sys.path.append('../whisper-webui')
6
+ #print("Sys path: " + str(sys.path))
7
 
8
+ from src.whisper.abstractWhisperContainer import LambdaWhisperCallback
9
  from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
10
 
11
  class TestVad(unittest.TestCase):
 
14
  self.transcribe_calls = []
15
 
16
  def test_transcript(self):
17
+ mock = MockVadTranscription(mock_audio_length=120)
18
+ config = TranscriptionConfig()
19
 
20
  self.transcribe_calls.clear()
21
+ result = mock.transcribe("mock", LambdaWhisperCallback(lambda segment, _1, _2, _3, _4: self.transcribe_segments(segment)), config)
22
 
23
  self.assertListEqual(self.transcribe_calls, [
24
  [30, 30],
 
47
  }
48
 
49
  class MockVadTranscription(AbstractTranscription):
50
+ def __init__(self, mock_audio_length: float = 1000):
51
  super().__init__()
52
+ self.mock_audio_length = mock_audio_length
53
 
54
  def get_audio_segment(self, str, start_time: str = None, duration: str = None):
55
  start_time_seconds = float(start_time.removesuffix("s"))
 
64
  result.append( { 'start': 30, 'end': 60 } )
65
  result.append( { 'start': 100, 'end': 200 } )
66
  return result
67
+
68
+ def get_audio_duration(self, audio: str, config: TranscriptionConfig):
69
+ return self.mock_audio_length
70
 
71
  if __name__ == '__main__':
72
  unittest.main()