Update app.py
app.py
CHANGED
@@ -6,10 +6,51 @@ from jiwer import wer, cer
 import json
 from io import BytesIO
 import base64
+import tensorflow as tf
 
-
+class TwiTranscriptionModel:
+    def __init__(self, encoder_model, decoder_model, char_tokenizer, max_length=50):
+        self.encoder_model = encoder_model
+        self.decoder_model = decoder_model
+        self.char_tokenizer = char_tokenizer
+        self.max_length = max_length
+        self.sos_token = '<sos>'
+        self.eos_token = '<eos>'
+        self.sos_index = char_tokenizer.word_index[self.sos_token]
+        self.eos_index = char_tokenizer.word_index[self.eos_token]
+
+    def predict(self, audio_features):
+        batch_size = audio_features.shape[0]
+        transcriptions = []
+
+        for i in range(batch_size):
+            states_value = self.encoder_model.predict(
+                audio_features[i:i+1],
+                verbose=0
+            )
+            target_seq = np.array([[self.sos_index]])
+            decoded_chars = []
+
+            for _ in range(self.max_length):
+                output_tokens, h, c = self.decoder_model.predict(
+                    [target_seq] + states_value,
+                    verbose=0
+                )
+
+                sampled_token_index = np.argmax(output_tokens[0, -1, :])
+                sampled_char = self.char_tokenizer.index_word.get(sampled_token_index, '')
+
+                if sampled_char == self.eos_token or len(decoded_chars) > self.max_length:
+                    break
+
+                decoded_chars.append(sampled_char)
+                target_seq = np.array([[sampled_token_index]])
+                states_value = [h, c]
+
+            transcriptions.append(''.join(decoded_chars))
+
+        return transcriptions
 
-# Add this at the top of your file
 class ChunkedUploader:
     def __init__(self):
         if 'chunks' not in st.session_state:
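The class added above runs greedy, character-level seq2seq decoding: the encoder turns one clip's MFCC features into LSTM states, and the decoder is fed one character index at a time, starting from <sos>, until it emits <eos> or hits max_length. The pickled encoder and decoder themselves are not shown in this commit; a hypothetical pair of Keras inference models that would satisfy the interface the loop relies on (encoder states returned as a list, decoder returning a softmax over characters plus new states) might look like the sketch below, where the layer sizes and vocabulary size are assumptions:

    from tensorflow.keras.layers import Dense, Embedding, Input, LSTM
    from tensorflow.keras.models import Model

    latent_dim = 256   # assumed hidden size
    vocab_size = 60    # assumed character-vocabulary size

    # Encoder: MFCC sequence -> initial decoder states [h, c]
    encoder_inputs = Input(shape=(1000, 13))   # matches the extract_mfcc output added in the next hunk
    _, state_h, state_c = LSTM(latent_dim, return_state=True)(encoder_inputs)
    encoder_model = Model(encoder_inputs, [state_h, state_c])

    # Decoder: previous character index + states -> next-character distribution + new states
    decoder_inputs = Input(shape=(1,))
    state_h_in = Input(shape=(latent_dim,))
    state_c_in = Input(shape=(latent_dim,))
    x = Embedding(vocab_size, latent_dim)(decoder_inputs)
    x, h, c = LSTM(latent_dim, return_sequences=True, return_state=True)(
        x, initial_state=[state_h_in, state_c_in]
    )
    output_tokens = Dense(vocab_size, activation='softmax')(x)
    decoder_model = Model([decoder_inputs, state_h_in, state_c_in], [output_tokens, h, c])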
@@ -38,10 +79,63 @@ class ChunkedUploader:
             del st.session_state.chunks[upload_id]
             return complete_data
 
+@st.cache_resource
+def get_model():
+    try:
+        model_path = 'twi_transcription_model.pkl'
+        st.write(f"Looking for model at: {model_path}")  # Debug info
+
+        with open(model_path, 'rb') as f:
+            model_data = pickle.load(f)
+
+        model = TwiTranscriptionModel(
+            model_data['encoder_model'],
+            model_data['decoder_model'],
+            model_data['char_tokenizer'],
+            model_data['max_length']
+        )
+        st.write("Model loaded successfully")  # Debug info
+        return model
+    except Exception as e:
+        st.error(f"Error loading model: {str(e)}")
+        return None
+
+def extract_mfcc(audio_data, sr=16000, n_mfcc=13):
+    if sr != 16000:
+        audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=16000)
+
+    mfcc = librosa.feature.mfcc(y=audio_data, sr=16000, n_mfcc=n_mfcc)
+
+    max_length = 1000  # Adjust based on your model's requirements
+    if mfcc.shape[1] > max_length:
+        mfcc = mfcc[:, :max_length]
+    else:
+        mfcc = np.pad(mfcc, ((0, 0), (0, max_length - mfcc.shape[1])), mode='constant')
+
+    return mfcc.T
+
+def calculate_error_rates(reference, hypothesis):
+    try:
+        error_wer = wer(reference, hypothesis)
+        error_cer = cer(reference, hypothesis)
+        return error_wer, error_cer
+    except Exception as e:
+        return None, None
+
+def process_audio_bytes(audio_bytes):
+    try:
+        audio_data, sr = librosa.load(BytesIO(audio_bytes), sr=None)
+        if len(audio_data.shape) > 1:
+            audio_data = np.mean(audio_data, axis=1)
+        return audio_data, sr
+    except Exception as e:
+        raise Exception(f"Error processing audio: {str(e)}")
+
 def main():
-
-    chunked_uploader = ChunkedUploader()
+    st.set_page_config(page_title="Twi Speech API")
 
+    # Initialize model
+    model = get_model()
     if model is None:
         st.write(json.dumps({
             'error': 'Failed to load model',
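calculate_error_rates() above is a thin wrapper over jiwer's wer and cer; both return 0.0 on an exact match and the fraction of word or character edits otherwise. A quick illustration with made-up strings (not from the app):

    from jiwer import wer, cer

    reference = "me din de kofi"
    hypothesis = "me din kofi"

    print(wer(reference, hypothesis))   # 1 deleted word out of 4 -> 0.25
    print(cer(reference, hypothesis))   # character error rate for the same pair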
@@ -49,6 +143,9 @@ def main():
         }))
         return
 
+    # Initialize chunked uploader
+    chunked_uploader = ChunkedUploader()
+
     # Check if this is an API request
     if "api" in st.query_params:
         try:
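Taken together, once the API branch has reassembled the uploaded audio bytes (that handling sits outside these hunks), the new helpers are presumably combined roughly as in the sketch below; the function name and the batch handling are illustrative, not part of the diff:

    import numpy as np

    def transcribe_bytes(model, audio_bytes):
        # Sketch: decode one audio payload with the helpers added in this commit.
        audio_data, sr = process_audio_bytes(audio_bytes)  # raw bytes -> mono float array + rate
        features = extract_mfcc(audio_data, sr)            # (1000, 13) MFCC matrix at 16 kHz
        batch = np.expand_dims(features, axis=0)           # TwiTranscriptionModel.predict expects a batch
        return model.predict(batch)[0]                     # greedy character-level decoding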
|