import streamlit as st
import numpy as np
import librosa
import pickle
from jiwer import wer, cer
import tensorflow as tf
from io import BytesIO
import soundfile as sf
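
# Seq2seq transcription wrapper: the pickled encoder summarizes MFCC frames
# into LSTM states, and the decoder emits one Twi character per step.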
class TwiTranscriptionModel:
    def __init__(self, encoder_model, decoder_model, char_tokenizer, max_length=50):
        self.encoder_model = encoder_model
        self.decoder_model = decoder_model
        self.char_tokenizer = char_tokenizer
        self.max_length = max_length
        self.sos_token = '<sos>'
        self.eos_token = '<eos>'
        self.sos_index = char_tokenizer.word_index[self.sos_token]
        self.eos_index = char_tokenizer.word_index[self.eos_token]
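
    # Greedy decoding: seed the decoder with <sos> and the encoder's states,
    # then feed each argmax character back in until <eos> or max_length steps.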
    def predict(self, audio_features):
        batch_size = audio_features.shape[0]
        transcriptions = []
        for i in range(batch_size):
            states_value = self.encoder_model.predict(
                audio_features[i:i+1],
                verbose=0
            )
            target_seq = np.array([[self.sos_index]])
            decoded_chars = []
            for _ in range(self.max_length):
                output_tokens, h, c = self.decoder_model.predict(
                    [target_seq] + states_value,
                    verbose=0
                )
                sampled_token_index = np.argmax(output_tokens[0, -1, :])
                sampled_char = self.char_tokenizer.index_word.get(sampled_token_index, '')
                if sampled_char == self.eos_token or len(decoded_chars) > self.max_length:
                    break
                decoded_chars.append(sampled_char)
                target_seq = np.array([[sampled_token_index]])
                states_value = [h, c]
            transcriptions.append(''.join(decoded_chars))
        return transcriptions
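
# The pickle is expected to hold the keys used below ('encoder_model',
# 'decoder_model', 'char_tokenizer', 'max_length'); st.cache_resource keeps
# the loaded model in memory across Streamlit reruns.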
@st.cache_resource
def get_model():
    try:
        with open('twi_transcription_model.pkl', 'rb') as f:
            model_data = pickle.load(f)
        return TwiTranscriptionModel(
            model_data['encoder_model'],
            model_data['decoder_model'],
            model_data['char_tokenizer'],
            model_data['max_length']
        )
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None
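
# Feature extraction: 13 MFCCs at 16 kHz, truncated or zero-padded to a fixed
# 1000 frames so every clip matches the encoder's expected input shape
# (time, n_mfcc) after the transpose below.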
def extract_mfcc(audio_data, sr=16000, n_mfcc=13):
    if sr != 16000:
        audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=16000)
    mfcc = librosa.feature.mfcc(y=audio_data, sr=16000, n_mfcc=n_mfcc)
    max_length = 1000  # Adjust based on your model's requirements
    if mfcc.shape[1] > max_length:
        mfcc = mfcc[:, :max_length]
    else:
        mfcc = np.pad(mfcc, ((0, 0), (0, max_length - mfcc.shape[1])), mode='constant')
    return mfcc.T
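
# WER/CER via jiwer: edit distance (substitutions + deletions + insertions)
# divided by the reference length, at word and character level respectively.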
def calculate_error_rates(reference, hypothesis):
    try:
        error_wer = wer(reference, hypothesis)
        error_cer = cer(reference, hypothesis)
        return error_wer, error_cer
    except Exception:
        return None, None
def process_audio_file(audio_file, model, reference_text=None):
    """Process uploaded audio file and return transcription"""
    try:
        # Read audio file
        audio_data, sr = librosa.load(audio_file, sr=None)
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)
        # Extract features
        mfcc_features = extract_mfcc(audio_data, sr)
        mfcc_features = np.expand_dims(mfcc_features, axis=0)
        # Get transcription
        transcription = model.predict(mfcc_features)[0]
        # Prepare response
        response = {
            'status': 'success',
            'transcription': transcription,
            'audio_details': {
                'sample_rate': int(sr),
                'duration': float(len(audio_data) / sr)
            },
            'audio_data': audio_data,
            'sample_rate': sr
        }
        # Add error metrics if reference provided
        if reference_text:
            error_wer, error_cer = calculate_error_rates(reference_text, transcription)
            if error_wer is not None and error_cer is not None:
                response['error_metrics'] = {
                    'word_error_rate': round(float(error_wer), 4),
                    'character_error_rate': round(float(error_cer), 4)
                }
        return response
    except Exception as e:
        return {
            'status': 'error',
            'error': str(e)
        }
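
# Streamlit UI entry point; run locally with: streamlit run app.py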
def main():
    st.set_page_config(
        page_title="Twi Speech Recognition",
        page_icon="🎤",
        layout="wide"
    )
    # Initialize model
    model = get_model()
    if model is None:
        st.error("Failed to load model. Please try again later.")
        return
    st.title("Twi Speech Recognition")
    st.write("Upload an audio file for transcription")
    # File uploader
    audio_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'ogg'])
    # Optional reference text
    reference_text = st.text_area(
        "Reference text (optional)",
        "",
        help="Enter the correct transcription to calculate error rates"
    )
    if audio_file is not None:
        if st.button("Transcribe"):
            with st.spinner("Processing audio... This may take a few minutes."):
                result = process_audio_file(
                    audio_file,
                    model,
                    reference_text if reference_text else None
                )
            if result['status'] == 'success':
                st.success("Transcription completed!")
                # Convert audio data to bytes for Streamlit audio player
                audio_bytes = BytesIO()
                sf.write(audio_bytes, result['audio_data'], result['sample_rate'], format='WAV')
                audio_bytes.seek(0)
                # Audio playback
                st.audio(audio_bytes, format='audio/wav')
                # Transcription display
                st.write("### Transcription:")
                st.write(result['transcription'])
                # Audio details
                st.write("### Audio Details:")
                st.json(result['audio_details'])
                # Error metrics
                if 'error_metrics' in result:
                    st.write("### Error Metrics:")
                    st.json(result['error_metrics'])
            else:
                st.error(f"Error: {result.get('error', 'Unknown error')}")


if __name__ == "__main__":
    main()