import streamlit as st
import numpy as np
import librosa
import pickle
from jiwer import wer, cer
import tensorflow as tf  # kept so the pickled Keras encoder/decoder can be deserialized
from io import BytesIO
import soundfile as sf
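
# Seq2seq wrapper around the pickled Keras encoder/decoder pair. The encoder
# maps MFCC features to initial LSTM states; the decoder then emits one
# character per step (greedy argmax) until it produces <eos> or hits
# max_length.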
class TwiTranscriptionModel:
    def __init__(self, encoder_model, decoder_model, char_tokenizer, max_length=50):
        self.encoder_model = encoder_model
        self.decoder_model = decoder_model
        self.char_tokenizer = char_tokenizer
        self.max_length = max_length
        self.sos_token = '<sos>'
        self.eos_token = '<eos>'
        self.sos_index = char_tokenizer.word_index[self.sos_token]
        self.eos_index = char_tokenizer.word_index[self.eos_token]
    def predict(self, audio_features):
        batch_size = audio_features.shape[0]
        transcriptions = []
        for i in range(batch_size):
            # Encode one utterance to get the initial decoder states [h, c].
            states_value = self.encoder_model.predict(
                audio_features[i:i+1],
                verbose=0
            )
            # Start decoding from the <sos> token.
            target_seq = np.array([[self.sos_index]])
            decoded_chars = []
            for _ in range(self.max_length):
                output_tokens, h, c = self.decoder_model.predict(
                    [target_seq] + states_value,
                    verbose=0
                )
                sampled_token_index = np.argmax(output_tokens[0, -1, :])
                sampled_char = self.char_tokenizer.index_word.get(sampled_token_index, '')
                # The loop bound already caps output at max_length, so <eos>
                # is the only stop condition needed here.
                if sampled_char == self.eos_token:
                    break
                decoded_chars.append(sampled_char)
                # Feed the prediction and updated states back into the decoder.
                target_seq = np.array([[sampled_token_index]])
                states_value = [h, c]
            transcriptions.append(''.join(decoded_chars))
        return transcriptions
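
# Cached so the pickle is read once per process: Streamlit reruns this script
# on every interaction, and st.cache_resource hands back the same loaded model
# instead of deserializing it again.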
@st.cache_resource
def get_model():
    try:
        with open('twi_transcription_model.pkl', 'rb') as f:
            model_data = pickle.load(f)
        return TwiTranscriptionModel(
            model_data['encoder_model'],
            model_data['decoder_model'],
            model_data['char_tokenizer'],
            model_data['max_length']
        )
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None
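
# Feature extraction: 13 MFCCs per frame at 16 kHz, padded or truncated to a
# fixed 1000 frames so every input matches the shape the encoder was trained
# on, then transposed to (time_steps, n_mfcc).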
def extract_mfcc(audio_data, sr=16000, n_mfcc=13):
    if sr != 16000:
        audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=16000)
    mfcc = librosa.feature.mfcc(y=audio_data, sr=16000, n_mfcc=n_mfcc)
    max_length = 1000  # Adjust based on your model's requirements
    if mfcc.shape[1] > max_length:
        mfcc = mfcc[:, :max_length]
    else:
        mfcc = np.pad(mfcc, ((0, 0), (0, max_length - mfcc.shape[1])), mode='constant')
    return mfcc.T
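
# jiwer's wer/cer compare the reference text with the model's hypothesis;
# (None, None) signals that the metrics could not be computed (e.g. an empty
# reference).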
def calculate_error_rates(reference, hypothesis):
    try:
        error_wer = wer(reference, hypothesis)
        error_cer = cer(reference, hypothesis)
        return error_wer, error_cer
    except Exception:
        return None, None
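
# End-to-end pipeline for a single upload: load audio, downmix to mono,
# extract MFCCs, transcribe, and package everything the UI needs into one
# dict (including the raw samples for playback).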
def process_audio_file(audio_file, model, reference_text=None):
    """Process uploaded audio file and return transcription"""
    try:
        # Read audio file (sr=None keeps the original sample rate)
        audio_data, sr = librosa.load(audio_file, sr=None)
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)
        # Extract features and add a batch dimension
        mfcc_features = extract_mfcc(audio_data, sr)
        mfcc_features = np.expand_dims(mfcc_features, axis=0)
        # Get transcription
        transcription = model.predict(mfcc_features)[0]
        # Prepare response
        response = {
            'status': 'success',
            'transcription': transcription,
            'audio_details': {
                'sample_rate': int(sr),
                'duration': float(len(audio_data) / sr)
            },
            'audio_data': audio_data,
            'sample_rate': sr
        }
        # Add error metrics if reference provided
        if reference_text:
            error_wer, error_cer = calculate_error_rates(reference_text, transcription)
            if error_wer is not None and error_cer is not None:
                response['error_metrics'] = {
                    'word_error_rate': round(float(error_wer), 4),
                    'character_error_rate': round(float(error_cer), 4)
                }
        return response
    except Exception as e:
        return {
            'status': 'error',
            'error': str(e)
        }
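
# Streamlit UI: upload an audio file, optionally paste a reference
# transcription, and display the transcription, audio details, and WER/CER.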
def main():
    st.set_page_config(
        page_title="Twi Speech Recognition",
        page_icon="🎤",
        layout="wide"
    )
    # Initialize model
    model = get_model()
    if model is None:
        st.error("Failed to load model. Please try again later.")
        return
    st.title("Twi Speech Recognition")
    st.write("Upload an audio file for transcription")
    # File uploader
    audio_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'ogg'])
    # Optional reference text
    reference_text = st.text_area(
        "Reference text (optional)",
        "",
        help="Enter the correct transcription to calculate error rates"
    )
    if audio_file is not None:
        if st.button("Transcribe"):
            with st.spinner("Processing audio... This may take a few minutes."):
                result = process_audio_file(
                    audio_file,
                    model,
                    reference_text if reference_text else None
                )
            if result['status'] == 'success':
                st.success("Transcription completed!")
                # Convert audio data to bytes for Streamlit audio player
                audio_bytes = BytesIO()
                sf.write(audio_bytes, result['audio_data'], result['sample_rate'], format='WAV')
                audio_bytes.seek(0)
                # Audio Playback
                st.audio(audio_bytes, format='audio/wav')
                # Transcription Display
                st.write("### Transcription:")
                st.write(result['transcription'])
                # Audio Details
                st.write("### Audio Details:")
                st.json(result['audio_details'])
                # Error Metrics
                if 'error_metrics' in result:
                    st.write("### Error Metrics:")
                    st.json(result['error_metrics'])
            else:
                st.error(f"Error: {result.get('error', 'Unknown error')}")
if __name__ == "__main__":
    main()
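
# A minimal way to run this locally (assuming this file is saved as app.py
# and twi_transcription_model.pkl sits next to it):
#   streamlit run app.py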