OpenSound committed on
Commit
0b57247
·
1 Parent(s): a67873a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -57
app.py CHANGED
@@ -79,6 +79,10 @@ def get_mask_interval(transcribe_state, word_span):
79
  end = float(data[e][0]) if e < len(data) else float(data[-1][1])
80
 
81
  return (start, end)
 
 
 
 
82
 
83
  @spaces.GPU
84
  class WhisperxAlignModel:
@@ -91,22 +95,6 @@ class WhisperxAlignModel:
91
  audio = load_audio(audio_path)
92
  return align(segments, self.model, self.metadata, audio, device, return_char_alignments=False)["segments"]
93
 
94
- @spaces.GPU
95
- class WhisperModel:
96
- def __init__(self, model_name, language):
97
- from whisper import load_model
98
- self.model = load_model(model_name, device, language=language)
99
-
100
- from whisper.tokenizer import get_tokenizer
101
- tokenizer = get_tokenizer(multilingual=False, language=language)
102
- self.supress_tokens = [-1] + [
103
- i
104
- for i in range(tokenizer.eot)
105
- if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
106
- ]
107
-
108
- def transcribe(self, audio_path):
109
- return self.model.transcribe(audio_path, suppress_tokens=self.supress_tokens, word_timestamps=True)["segments"]
110
 
111
  @spaces.GPU
112
  class WhisperxModel:
@@ -121,41 +109,45 @@ class WhisperxModel:
121
  segment['text'] = replace_numbers_with_words(segment['text'])
122
  return self.align_model.align(segments, audio_path)
123
 
124
- @spaces.GPU
125
- def load_models():
126
- ssrspeech_model_name = "English"
127
- text_tokenizer = TextTokenizer(backend="espeak")
128
- language = "en"
129
- transcribe_model_name = "base.en"
130
-
131
- align_model = WhisperxAlignModel(language)
132
- transcribe_model = WhisperxModel(transcribe_model_name, align_model, language)
133
-
134
- ssrspeech_fn = f"{MODELS_PATH}/{ssrspeech_model_name}.pth"
135
- if not os.path.exists(ssrspeech_fn):
136
- os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-{ssrspeech_model_name}/resolve/main/{ssrspeech_model_name}.pth -O " + ssrspeech_fn)
137
-
138
- ckpt = torch.load(ssrspeech_fn)
139
- model = ssr.SSR_Speech(ckpt["config"])
140
- model.load_state_dict(ckpt["model"])
141
- config = model.args
142
- phn2num = ckpt["phn2num"]
143
- model.to(device)
144
-
145
- encodec_fn = f"{MODELS_PATH}/wmencodec.th"
146
- if not os.path.exists(encodec_fn):
147
- os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th -O " + encodec_fn)
148
-
149
- ssrspeech_model = {
150
- "config": config,
151
- "phn2num": phn2num,
152
- "model": model,
153
- "text_tokenizer": text_tokenizer,
154
- "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
155
- }
156
- return transcribe_model, align_model, ssrspeech_model
 
 
 
 
 
157
 
158
- transcribe_model, align_model, ssrspeech_model = load_models()
159
 
160
  def get_transcribe_state(segments):
161
  transcript = " ".join([segment["text"] for segment in segments])
@@ -167,9 +159,8 @@ def get_transcribe_state(segments):
167
 
168
  @spaces.GPU
169
  def transcribe(audio_path):
170
- if transcribe_model is None:
171
- raise gr.Error("Transcription model not loaded")
172
-
173
  segments = transcribe_model.transcribe(audio_path)
174
  state = get_transcribe_state(segments)
175
  success_message = "<span style='color:green;'>Success: Transcribe completed successfully!</span>"
@@ -182,10 +173,9 @@ def transcribe(audio_path):
182
 
183
  @spaces.GPU
184
  def align(segments, audio_path):
185
- if align_model is None:
186
- raise gr.Error("Align model not loaded")
187
-
188
- segments = align_model.align(segments, audio_path)
189
  state = get_transcribe_state(segments)
190
 
191
  return state
 
79
  end = float(data[e][0]) if e < len(data) else float(data[-1][1])
80
 
81
  return (start, end)
82
+
83
+
84
+ from whisperx import load_align_model
85
+
86
 
87
  @spaces.GPU
88
  class WhisperxAlignModel:
 
95
  audio = load_audio(audio_path)
96
  return align(segments, self.model, self.metadata, audio, device, return_char_alignments=False)["segments"]
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  @spaces.GPU
100
  class WhisperxModel:
 
109
  segment['text'] = replace_numbers_with_words(segment['text'])
110
  return self.align_model.align(segments, audio_path)
111
 
112
+ from whisperx import load_align_model, load_model, load_audio
113
+ from whisperx import align as align_func
114
+
115
+
116
+ ssrspeech_model_name = "English"
117
+ text_tokenizer = TextTokenizer(backend="espeak")
118
+ language = "en"
119
+ transcribe_model_name = "base.en"
120
+
121
+ # align_model = WhisperxAlignModel(language)
122
+ # transcribe_model = WhisperxModel(transcribe_model_name, align_model, language)
123
+
124
+ # align_model, align_model_metadata = load_align_model(language_code=language, device=device)
125
+ # transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language)
126
+
127
+ ssrspeech_fn = f"{MODELS_PATH}/{ssrspeech_model_name}.pth"
128
+ if not os.path.exists(ssrspeech_fn):
129
+ os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-{ssrspeech_model_name}/resolve/main/{ssrspeech_model_name}.pth -O " + ssrspeech_fn)
130
+
131
+ ckpt = torch.load(ssrspeech_fn)
132
+ model = ssr.SSR_Speech(ckpt["config"])
133
+ model.load_state_dict(ckpt["model"])
134
+ config = model.args
135
+ phn2num = ckpt["phn2num"]
136
+ model.to(device)
137
+
138
+ encodec_fn = f"{MODELS_PATH}/wmencodec.th"
139
+ if not os.path.exists(encodec_fn):
140
+ os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th -O " + encodec_fn)
141
+
142
+ ssrspeech_model = {
143
+ "config": config,
144
+ "phn2num": phn2num,
145
+ "model": model,
146
+ "text_tokenizer": text_tokenizer,
147
+ "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
148
+ }
149
+
150
 
 
151
 
152
  def get_transcribe_state(segments):
153
  transcript = " ".join([segment["text"] for segment in segments])
 
159
 
160
  @spaces.GPU
161
  def transcribe(audio_path):
162
+ align_model, _ = load_align_model(language_code=language, device=device)
163
+ transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language)
 
164
  segments = transcribe_model.transcribe(audio_path)
165
  state = get_transcribe_state(segments)
166
  success_message = "<span style='color:green;'>Success: Transcribe completed successfully!</span>"
 
173
 
174
  @spaces.GPU
175
  def align(segments, audio_path):
176
+ align_model, metadata = load_align_model(language_code=language, device=device)
177
+ audio = load_audio(audio_path)
178
+ segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
 
179
  state = get_transcribe_state(segments)
180
 
181
  return state