sedemkofi committed on
Commit
e2e2e8b
·
verified ·
1 Parent(s): 2e1e1ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -55
app.py CHANGED
@@ -7,12 +7,48 @@ import json
7
  from io import BytesIO
8
  import base64
9
 
10
- # Set page config
11
- st.set_page_config(page_title="Twi Speech API")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Get request method and body
14
- request_method = st.experimental_get_query_params().get("_stcore_method", ["GET"])[0]
15
- request_json = st.experimental_get_query_params().get("_stcore_body", ["{}"])[0]
 
 
 
 
 
 
 
 
 
 
16
 
17
  @st.cache_resource
18
  def get_model():
@@ -26,51 +62,81 @@ def get_model():
26
  model_data['max_length']
27
  )
28
  except Exception as e:
 
29
  return None
30
 
31
- def process_request():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  model = get_model()
33
 
34
  if model is None:
35
- return {
36
  'error': 'Failed to load model',
37
  'status': 'error'
38
- }
 
39
 
40
- try:
41
- # Parse request body
42
- if request_method == "POST":
 
 
43
  try:
44
- body = json.loads(request_json)
45
  except json.JSONDecodeError:
46
- return {
47
  'error': 'Invalid JSON data',
48
  'status': 'error'
49
- }
50
- else:
51
- return {
52
- 'error': 'Method not allowed. Use POST request.',
53
- 'status': 'error'
54
- }
55
-
56
- # Get audio data
57
- audio_base64 = body.get('audio')
58
- reference_text = body.get('reference_text')
59
-
60
- if not audio_base64:
61
- return {
62
- 'error': 'No audio data provided',
63
- 'status': 'error'
64
- }
65
 
66
- # Process audio
67
- try:
68
- audio_bytes = base64.b64decode(audio_base64)
69
- audio_data, sr = librosa.load(BytesIO(audio_bytes), sr=None)
70
 
71
- if len(audio_data.shape) > 1:
72
- audio_data = np.mean(audio_data, axis=1)
 
 
 
 
73
 
 
 
 
 
74
  # Extract features
75
  mfcc_features = extract_mfcc(audio_data, sr)
76
  mfcc_features = np.expand_dims(mfcc_features, axis=0)
@@ -96,27 +162,19 @@ def process_request():
96
  'word_error_rate': round(float(error_wer), 4),
97
  'character_error_rate': round(float(error_cer), 4)
98
  }
99
-
100
- return response
101
 
102
  except Exception as e:
103
- return {
104
- 'error': f'Error processing audio: {str(e)}',
105
  'status': 'error'
106
- }
107
-
108
- except Exception as e:
109
- return {
110
- 'error': str(e),
111
  'status': 'error'
112
- }
113
-
114
- # Main execution
115
- if "api" in st.experimental_get_query_params():
116
- response = process_request()
117
- st.write(json.dumps(response))
118
- else:
119
- st.write(json.dumps({
120
- 'error': 'Please use the API endpoint with ?api=true',
121
- 'status': 'error'
122
- }))
 
7
  from io import BytesIO
8
  import base64
9
 
10
+ class TwiTranscriptionModel:
11
+ def __init__(self, encoder_model, decoder_model, char_tokenizer, max_length=50):
12
+ self.encoder_model = encoder_model
13
+ self.decoder_model = decoder_model
14
+ self.char_tokenizer = char_tokenizer
15
+ self.max_length = max_length
16
+ self.sos_token = '<sos>'
17
+ self.eos_token = '<eos>'
18
+ self.sos_index = char_tokenizer.word_index[self.sos_token]
19
+ self.eos_index = char_tokenizer.word_index[self.eos_token]
20
+
21
+ def predict(self, audio_features):
22
+ batch_size = audio_features.shape[0]
23
+ transcriptions = []
24
+
25
+ for i in range(batch_size):
26
+ states_value = self.encoder_model.predict(
27
+ audio_features[i:i+1],
28
+ verbose=0
29
+ )
30
+ target_seq = np.array([[self.sos_index]])
31
+ decoded_chars = []
32
+
33
+ for _ in range(self.max_length):
34
+ output_tokens, h, c = self.decoder_model.predict(
35
+ [target_seq] + states_value,
36
+ verbose=0
37
+ )
38
 
39
+ sampled_token_index = np.argmax(output_tokens[0, -1, :])
40
+ sampled_char = self.char_tokenizer.index_word.get(sampled_token_index, '')
41
+
42
+ if sampled_char == self.eos_token or len(decoded_chars) > self.max_length:
43
+ break
44
+
45
+ decoded_chars.append(sampled_char)
46
+ target_seq = np.array([[sampled_token_index]])
47
+ states_value = [h, c]
48
+
49
+ transcriptions.append(''.join(decoded_chars))
50
+
51
+ return transcriptions
52
 
53
  @st.cache_resource
54
  def get_model():
 
62
  model_data['max_length']
63
  )
64
  except Exception as e:
65
+ st.error(f"Error loading model: {str(e)}")
66
  return None
67
 
68
def extract_mfcc(audio_data, sr=16000, n_mfcc=13, max_length=1000):
    """Extract a fixed-size MFCC feature matrix from a mono waveform.

    Args:
        audio_data: 1-D audio signal.
        sr: sample rate of ``audio_data``; resampled to 16 kHz if different.
        n_mfcc: number of MFCC coefficients per frame.
        max_length: number of time frames to pad/truncate to. Previously a
            hard-coded constant; now a parameter (default unchanged) so the
            width can follow the model's expected input shape.

    Returns:
        Array of shape ``(max_length, n_mfcc)`` — time frames along axis 0.
    """
    # The model operates on 16 kHz audio, so normalize the rate first.
    if sr != 16000:
        audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=16000)

    mfcc = librosa.feature.mfcc(y=audio_data, sr=16000, n_mfcc=n_mfcc)

    # Truncate or zero-pad along the time axis to a fixed width.
    if mfcc.shape[1] > max_length:
        mfcc = mfcc[:, :max_length]
    else:
        mfcc = np.pad(mfcc, ((0, 0), (0, max_length - mfcc.shape[1])), mode='constant')

    # Transpose to (time, features) as the downstream model expects.
    return mfcc.T
81
+
82
def calculate_error_rates(reference, hypothesis):
    """Compute (WER, CER) for ``hypothesis`` against ``reference``.

    Returns ``(None, None)`` when either metric cannot be computed, so
    callers can degrade gracefully instead of failing the whole request.
    """
    try:
        return wer(reference, hypothesis), cer(reference, hypothesis)
    except Exception:
        # Best-effort: a metric failure must not break the transcription response.
        return None, None
89
+
90
def process_audio_bytes(audio_bytes):
    """Decode raw audio file bytes into a mono waveform.

    Args:
        audio_bytes: encoded audio file contents (e.g. WAV) as ``bytes``.

    Returns:
        Tuple ``(audio_data, sr)``: 1-D float waveform and its native
        sample rate (``sr=None`` preserves the file's own rate).

    Raises:
        Exception: wrapping any decode failure with a descriptive message.
    """
    try:
        audio_data, sr = librosa.load(BytesIO(audio_bytes), sr=None)
        # Downmix multi-channel audio to mono by averaging channels.
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)
        return audio_data, sr
    except Exception as e:
        # Chain the cause so the underlying decode error's traceback is
        # preserved instead of being discarded by the re-raise.
        raise Exception(f"Error processing audio: {str(e)}") from e
98
+
99
+ # Set page config
100
+ st.set_page_config(page_title="Twi Speech API")
101
+
102
+ def main():
103
  model = get_model()
104
 
105
  if model is None:
106
+ st.write(json.dumps({
107
  'error': 'Failed to load model',
108
  'status': 'error'
109
+ }))
110
+ return
111
 
112
+ # Check if this is an API request
113
+ if "api" in st.query_params:
114
+ try:
115
+ # Get the request body
116
+ body_data = st.query_params.get("data", "{}")
117
  try:
118
+ data = json.loads(body_data)
119
  except json.JSONDecodeError:
120
+ st.write(json.dumps({
121
  'error': 'Invalid JSON data',
122
  'status': 'error'
123
+ }))
124
+ return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ audio_base64 = data.get('audio')
127
+ reference_text = data.get('reference_text')
 
 
128
 
129
+ if not audio_base64:
130
+ st.write(json.dumps({
131
+ 'error': 'No audio data provided',
132
+ 'status': 'error'
133
+ }))
134
+ return
135
 
136
+ # Process audio
137
+ audio_bytes = base64.b64decode(audio_base64)
138
+ audio_data, sr = process_audio_bytes(audio_bytes)
139
+
140
  # Extract features
141
  mfcc_features = extract_mfcc(audio_data, sr)
142
  mfcc_features = np.expand_dims(mfcc_features, axis=0)
 
162
  'word_error_rate': round(float(error_wer), 4),
163
  'character_error_rate': round(float(error_cer), 4)
164
  }
165
+
166
+ st.write(json.dumps(response))
167
 
168
  except Exception as e:
169
+ st.write(json.dumps({
170
+ 'error': str(e),
171
  'status': 'error'
172
+ }))
173
+ else:
174
+ st.write(json.dumps({
175
+ 'error': 'Please use the API endpoint with ?api=true',
 
176
  'status': 'error'
177
+ }))
178
+
179
+ if __name__ == "__main__":
180
+ main()