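"""Streamlit app for Twi speech recognition.

Loads a pickled encoder-decoder transcription model, extracts MFCC
features from an uploaded audio file, and shows the predicted
transcription plus optional WER/CER metrics against a reference text.
"""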
import streamlit as st
import numpy as np
import librosa
import pickle
from jiwer import wer, cer
import tensorflow as tf  # TensorFlow must be importable for the pickled Keras models to load
from io import BytesIO
import soundfile as sf


class TwiTranscriptionModel:
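    """Character-level seq2seq transcriber for Twi audio.

    Wraps a pre-trained encoder (audio features -> initial decoder states)
    and decoder (previous character + states -> next character), decoding
    greedily until <eos> or max_length.
    """
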
    def __init__(self, encoder_model, decoder_model, char_tokenizer, max_length=50):
        self.encoder_model = encoder_model
        self.decoder_model = decoder_model
        self.char_tokenizer = char_tokenizer
        self.max_length = max_length
        self.sos_token = '<sos>'
        self.eos_token = '<eos>'
        self.sos_index = char_tokenizer.word_index[self.sos_token]
        self.eos_index = char_tokenizer.word_index[self.eos_token]

    def predict(self, audio_features):
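        """Greedily decode each item in a batch of MFCC feature arrays.

        audio_features has shape (batch, time, n_mfcc); returns one
        transcription string per input.
        """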
        batch_size = audio_features.shape[0]
        transcriptions = []
        for i in range(batch_size):
            # Encode the audio into the decoder's initial [h, c] states
            states_value = self.encoder_model.predict(
                audio_features[i:i + 1],
                verbose=0
            )
            # Start decoding from the <sos> token
            target_seq = np.array([[self.sos_index]])
            decoded_chars = []
            # The loop bound already caps the output at max_length characters
            for _ in range(self.max_length):
                output_tokens, h, c = self.decoder_model.predict(
                    [target_seq] + states_value,
                    verbose=0
                )
                sampled_token_index = np.argmax(output_tokens[0, -1, :])
                sampled_char = self.char_tokenizer.index_word.get(sampled_token_index, '')
                if sampled_char == self.eos_token:
                    break
                decoded_chars.append(sampled_char)
                # Feed the sampled character and updated states back in
                target_seq = np.array([[sampled_token_index]])
                states_value = [h, c]
            transcriptions.append(''.join(decoded_chars))
        return transcriptions


@st.cache_resource  # load the pickled model once and reuse it across reruns
def get_model():
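    """Load the pickled seq2seq model bundle.

    Expects 'twi_transcription_model.pkl' to hold a dict with the keys
    'encoder_model', 'decoder_model', 'char_tokenizer', and 'max_length'.
    Returns a TwiTranscriptionModel, or None if loading fails.
    """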
    try:
        with open('twi_transcription_model.pkl', 'rb') as f:
            model_data = pickle.load(f)
        return TwiTranscriptionModel(
            model_data['encoder_model'],
            model_data['decoder_model'],
            model_data['char_tokenizer'],
            model_data['max_length']
        )
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None


def extract_mfcc(audio_data, sr=16000, n_mfcc=13):
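    """Compute MFCCs at 16 kHz with the time axis fixed to 1000 frames.

    Audio at other sample rates is resampled first; shorter clips are
    zero-padded and longer ones truncated. Returns shape (1000, n_mfcc).
    """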
    if sr != 16000:
        audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=16000)
    mfcc = librosa.feature.mfcc(y=audio_data, sr=16000, n_mfcc=n_mfcc)
    max_length = 1000  # Adjust based on your model's requirements
    if mfcc.shape[1] > max_length:
        mfcc = mfcc[:, :max_length]
    else:
        mfcc = np.pad(mfcc, ((0, 0), (0, max_length - mfcc.shape[1])), mode='constant')
    return mfcc.T


def calculate_error_rates(reference, hypothesis):
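    """Return (WER, CER) for hypothesis against reference via jiwer,
    or (None, None) if the metrics cannot be computed.
    """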
    try:
        error_wer = wer(reference, hypothesis)
        error_cer = cer(reference, hypothesis)
        return error_wer, error_cer
    except Exception:
        return None, None


def process_audio_file(audio_file, model, reference_text=None):
    """Process an uploaded audio file and return the transcription result."""
    try:
        # Read audio at its native sample rate, keeping all channels;
        # librosa returns multi-channel audio as (channels, samples)
        audio_data, sr = librosa.load(audio_file, sr=None, mono=False)
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=0)  # downmix to mono
        # Extract features
        mfcc_features = extract_mfcc(audio_data, sr)
        mfcc_features = np.expand_dims(mfcc_features, axis=0)
        # Get transcription
        transcription = model.predict(mfcc_features)[0]
        # Prepare response
        response = {
            'status': 'success',
            'transcription': transcription,
            'audio_details': {
                'sample_rate': int(sr),
                'duration': float(len(audio_data) / sr)
            },
            'audio_data': audio_data,
            'sample_rate': sr
        }
        # Add error metrics if a reference transcription was provided
        if reference_text:
            error_wer, error_cer = calculate_error_rates(reference_text, transcription)
            if error_wer is not None and error_cer is not None:
                response['error_metrics'] = {
                    'word_error_rate': round(float(error_wer), 4),
                    'character_error_rate': round(float(error_cer), 4)
                }
        return response
    except Exception as e:
        return {
            'status': 'error',
            'error': str(e)
        }


def main():
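    """Streamlit entry point: configure the page, load the model, and
    render the upload/transcribe UI.
    """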
    st.set_page_config(
        page_title="Twi Speech Recognition",
        page_icon="🎤",
        layout="wide"
    )
    # Initialize model
    model = get_model()
    if model is None:
        st.error("Failed to load model. Please try again later.")
        return
    st.title("Twi Speech Recognition")
    st.write("Upload an audio file for transcription")
    # File uploader
    audio_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'ogg'])
    # Optional reference text
    reference_text = st.text_area(
        "Reference text (optional)",
        "",
        help="Enter the correct transcription to calculate error rates"
    )
    if audio_file is not None:
        if st.button("Transcribe"):
            with st.spinner("Processing audio... This may take a few minutes."):
                result = process_audio_file(
                    audio_file,
                    model,
                    reference_text if reference_text else None
                )
            if result['status'] == 'success':
                st.success("Transcription completed!")
                # Convert audio data to bytes for the Streamlit audio player
                audio_bytes = BytesIO()
                sf.write(audio_bytes, result['audio_data'], result['sample_rate'], format='WAV')
                audio_bytes.seek(0)
                # Audio playback
                st.audio(audio_bytes, format='audio/wav')
                # Transcription display
                st.write("### Transcription:")
                st.write(result['transcription'])
                # Audio details
                st.write("### Audio Details:")
                st.json(result['audio_details'])
                # Error metrics
                if 'error_metrics' in result:
                    st.write("### Error Metrics:")
                    st.json(result['error_metrics'])
            else:
                st.error(f"Error: {result.get('error', 'Unknown error')}")


if __name__ == "__main__":
    main()
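
# To run locally (assuming this file is saved as app.py next to
# twi_transcription_model.pkl):
#   streamlit run app.py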