sedemkofi committed on
Commit
62a2e2b
·
verified ·
1 Parent(s): 346a962

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -4
app.py CHANGED
@@ -6,10 +6,51 @@ from jiwer import wer, cer
6
  import json
7
  from io import BytesIO
8
  import base64
 
9
 
10
- # ... (keep your existing imports and TwiTranscriptionModel class) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Add this at the top of your file
13
  class ChunkedUploader:
14
  def __init__(self):
15
  if 'chunks' not in st.session_state:
@@ -38,10 +79,63 @@ class ChunkedUploader:
38
  del st.session_state.chunks[upload_id]
39
  return complete_data
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def main():
42
- model = get_model()
43
- chunked_uploader = ChunkedUploader()
44
 
 
 
45
  if model is None:
46
  st.write(json.dumps({
47
  'error': 'Failed to load model',
@@ -49,6 +143,9 @@ def main():
49
  }))
50
  return
51
 
 
 
 
52
  # Check if this is an API request
53
  if "api" in st.query_params:
54
  try:
 
6
  import json
7
  from io import BytesIO
8
  import base64
9
+ import tensorflow as tf
10
 
11
+ class TwiTranscriptionModel:
12
+ def __init__(self, encoder_model, decoder_model, char_tokenizer, max_length=50):
13
+ self.encoder_model = encoder_model
14
+ self.decoder_model = decoder_model
15
+ self.char_tokenizer = char_tokenizer
16
+ self.max_length = max_length
17
+ self.sos_token = '<sos>'
18
+ self.eos_token = '<eos>'
19
+ self.sos_index = char_tokenizer.word_index[self.sos_token]
20
+ self.eos_index = char_tokenizer.word_index[self.eos_token]
21
+
22
+ def predict(self, audio_features):
23
+ batch_size = audio_features.shape[0]
24
+ transcriptions = []
25
+
26
+ for i in range(batch_size):
27
+ states_value = self.encoder_model.predict(
28
+ audio_features[i:i+1],
29
+ verbose=0
30
+ )
31
+ target_seq = np.array([[self.sos_index]])
32
+ decoded_chars = []
33
+
34
+ for _ in range(self.max_length):
35
+ output_tokens, h, c = self.decoder_model.predict(
36
+ [target_seq] + states_value,
37
+ verbose=0
38
+ )
39
+
40
+ sampled_token_index = np.argmax(output_tokens[0, -1, :])
41
+ sampled_char = self.char_tokenizer.index_word.get(sampled_token_index, '')
42
+
43
+ if sampled_char == self.eos_token or len(decoded_chars) > self.max_length:
44
+ break
45
+
46
+ decoded_chars.append(sampled_char)
47
+ target_seq = np.array([[sampled_token_index]])
48
+ states_value = [h, c]
49
+
50
+ transcriptions.append(''.join(decoded_chars))
51
+
52
+ return transcriptions
53
 
 
54
  class ChunkedUploader:
55
  def __init__(self):
56
  if 'chunks' not in st.session_state:
 
79
  del st.session_state.chunks[upload_id]
80
  return complete_data
81
 
82
+ @st.cache_resource
83
+ def get_model():
84
+ try:
85
+ model_path = 'twi_transcription_model.pkl'
86
+ st.write(f"Looking for model at: {model_path}") # Debug info
87
+
88
+ with open(model_path, 'rb') as f:
89
+ model_data = pickle.load(f)
90
+
91
+ model = TwiTranscriptionModel(
92
+ model_data['encoder_model'],
93
+ model_data['decoder_model'],
94
+ model_data['char_tokenizer'],
95
+ model_data['max_length']
96
+ )
97
+ st.write("Model loaded successfully") # Debug info
98
+ return model
99
+ except Exception as e:
100
+ st.error(f"Error loading model: {str(e)}")
101
+ return None
102
+
103
def extract_mfcc(audio_data, sr=16000, n_mfcc=13, max_length=1000):
    """Compute a fixed-size MFCC feature matrix for one audio clip.

    Parameters
    ----------
    audio_data : 1-D float array of audio samples.
    sr : sample rate of ``audio_data``; resampled to 16 kHz when it
        differs, since the model was trained on 16 kHz features.
    n_mfcc : number of MFCC coefficients per frame.
    max_length : number of time frames to truncate/pad to. New optional
        parameter — the default preserves the previously hard-coded
        value, so existing callers are unaffected.

    Returns
    -------
    ndarray of shape (max_length, n_mfcc): time-major MFCC features.
    """
    if sr != 16000:
        audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=16000)

    mfcc = librosa.feature.mfcc(y=audio_data, sr=16000, n_mfcc=n_mfcc)

    # Truncate or zero-pad along the time axis so every clip yields the
    # exact fixed shape the downstream model expects.
    if mfcc.shape[1] > max_length:
        mfcc = mfcc[:, :max_length]
    else:
        mfcc = np.pad(mfcc, ((0, 0), (0, max_length - mfcc.shape[1])), mode='constant')

    # Transpose to (time, coefficients) — time-major layout.
    return mfcc.T
116
+
117
def calculate_error_rates(reference, hypothesis):
    """Score a hypothesis transcript against a reference transcript.

    Returns a ``(word_error_rate, char_error_rate)`` tuple, or
    ``(None, None)`` when jiwer cannot score the pair (for example an
    empty reference) — callers treat that as "no score available".
    """
    try:
        return wer(reference, hypothesis), cer(reference, hypothesis)
    except Exception:
        # Best-effort: scoring failure is not fatal to the request.
        return None, None
124
+
125
def process_audio_bytes(audio_bytes):
    """Decode raw encoded audio bytes into a mono sample array.

    Parameters
    ----------
    audio_bytes : bytes of an encoded audio file (wav/mp3/ogg/...).

    Returns
    -------
    (audio_data, sr) : 1-D float sample array and its native sample rate
        (``sr=None`` asks librosa to keep the file's own rate).

    Raises
    ------
    Exception with a descriptive message; the original decoding error is
    chained as ``__cause__`` so the root cause is preserved.
    """
    try:
        audio_data, sr = librosa.load(BytesIO(audio_bytes), sr=None)
        # librosa.load defaults to mono=True, so this branch should
        # normally be dead; kept as a safety net for multi-channel
        # arrays — TODO confirm against librosa version in use.
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)
        return audio_data, sr
    except Exception as e:
        # Fix: chain with `from e` so the underlying decoder error is
        # not discarded when re-wrapping.
        raise Exception(f"Error processing audio: {str(e)}") from e
133
+
134
  def main():
135
+ st.set_page_config(page_title="Twi Speech API")
 
136
 
137
+ # Initialize model
138
+ model = get_model()
139
  if model is None:
140
  st.write(json.dumps({
141
  'error': 'Failed to load model',
 
143
  }))
144
  return
145
 
146
+ # Initialize chunked uploader
147
+ chunked_uploader = ChunkedUploader()
148
+
149
  # Check if this is an API request
150
  if "api" in st.query_params:
151
  try: