import streamlit as st
import numpy as np
import librosa
import pickle
from jiwer import wer, cer
import tensorflow as tf
from io import BytesIO
import soundfile as sf


class TwiTranscriptionModel:
    def __init__(self, encoder_model, decoder_model, char_tokenizer, max_length=50):
        self.encoder_model = encoder_model
        self.decoder_model = decoder_model
        self.char_tokenizer = char_tokenizer
        self.max_length = max_length
        # Start/end-of-sequence markers. The original source showed these as
        # empty strings (the angle-bracketed tokens were likely stripped as
        # HTML in transit); '<sos>'/'<eos>' is the assumed convention the
        # tokenizer was fitted with.
        self.sos_token = '<sos>'
        self.eos_token = '<eos>'
        self.sos_index = char_tokenizer.word_index[self.sos_token]
        self.eos_index = char_tokenizer.word_index[self.eos_token]

    def predict(self, audio_features):
        batch_size = audio_features.shape[0]
        transcriptions = []
        for i in range(batch_size):
            # Encode the audio features into the initial decoder states [h, c].
            states_value = self.encoder_model.predict(
                audio_features[i:i + 1], verbose=0
            )
            target_seq = np.array([[self.sos_index]])
            decoded_chars = []
            # Greedy decoding: feed each predicted character back in until the
            # end token appears or max_length characters have been emitted.
            for _ in range(self.max_length):
                output_tokens, h, c = self.decoder_model.predict(
                    [target_seq] + states_value, verbose=0
                )
                sampled_token_index = np.argmax(output_tokens[0, -1, :])
                sampled_char = self.char_tokenizer.index_word.get(sampled_token_index, '')
                if sampled_char == self.eos_token:
                    break
                decoded_chars.append(sampled_char)
                target_seq = np.array([[sampled_token_index]])
                states_value = [h, c]
            transcriptions.append(''.join(decoded_chars))
        return transcriptions


@st.cache_resource
def get_model():
    """Load the pickled encoder/decoder models and tokenizer once per session."""
    try:
        with open('twi_transcription_model.pkl', 'rb') as f:
            model_data = pickle.load(f)
        return TwiTranscriptionModel(
            model_data['encoder_model'],
            model_data['decoder_model'],
            model_data['char_tokenizer'],
            model_data['max_length']
        )
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None


def extract_mfcc(audio_data, sr=16000, n_mfcc=13):
    """Extract MFCC features, resampling to 16 kHz and padding/truncating to a
    fixed number of frames so the model sees a consistent input shape."""
    if sr != 16000:
        audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=16000)
    mfcc = librosa.feature.mfcc(y=audio_data, sr=16000, n_mfcc=n_mfcc)
    max_length = 1000  # Adjust based on your model's requirements
    if mfcc.shape[1] > max_length:
        mfcc = mfcc[:, :max_length]
    else:
        mfcc = np.pad(mfcc, ((0, 0), (0, max_length - mfcc.shape[1])), mode='constant')
    return mfcc.T


def calculate_error_rates(reference, hypothesis):
    """Compute word and character error rates; return (None, None) on failure."""
    try:
        error_wer = wer(reference, hypothesis)
        error_cer = cer(reference, hypothesis)
        return error_wer, error_cer
    except Exception:
        return None, None


def process_audio_file(audio_file, model, reference_text=None):
    """Process an uploaded audio file and return the transcription with metadata."""
    try:
        # Read the audio file and collapse multi-channel audio to mono.
        audio_data, sr = librosa.load(audio_file, sr=None)
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Extract features and add a batch dimension.
        mfcc_features = extract_mfcc(audio_data, sr)
        mfcc_features = np.expand_dims(mfcc_features, axis=0)

        # Get transcription
        transcription = model.predict(mfcc_features)[0]

        # Prepare response
        response = {
            'status': 'success',
            'transcription': transcription,
            'audio_details': {
                'sample_rate': int(sr),
                'duration': float(len(audio_data) / sr)
            },
            'audio_data': audio_data,
            'sample_rate': sr
        }

        # Add error metrics if a reference transcription was provided.
        if reference_text:
            error_wer, error_cer = calculate_error_rates(reference_text, transcription)
            if error_wer is not None and error_cer is not None:
                response['error_metrics'] = {
                    'word_error_rate': round(float(error_wer), 4),
                    'character_error_rate': round(float(error_cer), 4)
                }

        return response
    except Exception as e:
        return {
            'status': 'error',
            'error': str(e)
        }


def main():
    st.set_page_config(
        page_title="Twi Speech Recognition",
        page_icon="🎤",
        layout="wide"
    )

    # Initialize model
    model = get_model()
    if model is None:
        st.error("Failed to load model. Please try again later.")
        return

    st.title("Twi Speech Recognition")
    st.write("Upload an audio file for transcription")

    # File uploader
    audio_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'ogg'])

    # Optional reference text
    reference_text = st.text_area(
        "Reference text (optional)", "",
        help="Enter the correct transcription to calculate error rates"
    )

    if audio_file is not None:
        if st.button("Transcribe"):
            with st.spinner("Processing audio... This may take a few minutes."):
                result = process_audio_file(
                    audio_file,
                    model,
                    reference_text if reference_text else None
                )

            if result['status'] == 'success':
                st.success("Transcription completed!")

                # Convert audio data to bytes for the Streamlit audio player
                audio_bytes = BytesIO()
                sf.write(audio_bytes, result['audio_data'], result['sample_rate'], format='WAV')
                audio_bytes.seek(0)

                # Audio playback
                st.audio(audio_bytes, format='audio/wav')

                # Transcription display
                st.write("### Transcription:")
                st.write(result['transcription'])

                # Audio details
                st.write("### Audio Details:")
                st.json(result['audio_details'])

                # Error metrics
                if 'error_metrics' in result:
                    st.write("### Error Metrics:")
                    st.json(result['error_metrics'])
            else:
                st.error(f"Error: {result.get('error', 'Unknown error')}")


if __name__ == "__main__":
    main()
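
# ---------------------------------------------------------------------------
# Note on the model artifact: get_model() assumes 'twi_transcription_model.pkl'
# is a dict holding the trained encoder/decoder models and the fitted character
# tokenizer. The training code is not part of this file; the commented sketch
# below is a hypothetical example of how such an artifact could be saved, with
# encoder_model, decoder_model, and char_tokenizer standing in for objects
# produced elsewhere during training.
#
#   model_data = {
#       'encoder_model': encoder_model,    # Keras model: MFCC batch -> [h, c] states
#       'decoder_model': decoder_model,    # Keras model: [char seq, h, c] -> char probs, h, c
#       'char_tokenizer': char_tokenizer,  # tokenizer exposing word_index / index_word
#       'max_length': 50,                  # cap on decoded transcription length
#   }
#   with open('twi_transcription_model.pkl', 'wb') as f:
#       pickle.dump(model_data, f)
# ---------------------------------------------------------------------------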