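"""Streamlit app for Twi speech recognition.

Loads a pickled encoder-decoder (seq2seq) model, extracts MFCC features
from an uploaded audio file, decodes a transcription greedily character by
character, and optionally reports word/character error rates against a
user-supplied reference transcript.
"""
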
import streamlit as st
import numpy as np
import librosa
import pickle
from jiwer import wer, cer
import tensorflow as tf  # the pickled encoder/decoder are Keras/TensorFlow models
from io import BytesIO
import soundfile as sf

class TwiTranscriptionModel:
    """Greedy seq2seq decoder wrapping a pickled encoder/decoder model pair."""

    def __init__(self, encoder_model, decoder_model, char_tokenizer, max_length=50):
        self.encoder_model = encoder_model
        self.decoder_model = decoder_model
        self.char_tokenizer = char_tokenizer
        self.max_length = max_length
        self.sos_token = '<sos>'
        self.eos_token = '<eos>'
        self.sos_index = char_tokenizer.word_index[self.sos_token]
        self.eos_index = char_tokenizer.word_index[self.eos_token]

    def predict(self, audio_features):
        """Decode one transcription per batch item using greedy search."""
        batch_size = audio_features.shape[0]
        transcriptions = []

        for i in range(batch_size):
            # Encode the MFCC sequence into the initial decoder states (h, c).
            states_value = self.encoder_model.predict(
                audio_features[i:i+1],
                verbose=0
            )
            # Start decoding from the <sos> token.
            target_seq = np.array([[self.sos_index]])
            decoded_chars = []

            for _ in range(self.max_length):
                output_tokens, h, c = self.decoder_model.predict(
                    [target_seq] + states_value,
                    verbose=0
                )

                # Greedy choice: take the most probable next character.
                sampled_token_index = np.argmax(output_tokens[0, -1, :])
                sampled_char = self.char_tokenizer.index_word.get(sampled_token_index, '')

                # Stop at <eos>; the loop itself caps output at max_length chars.
                if sampled_char == self.eos_token:
                    break

                decoded_chars.append(sampled_char)
                # Feed the sampled character and updated states back in.
                target_seq = np.array([[sampled_token_index]])
                states_value = [h, c]

            transcriptions.append(''.join(decoded_chars))

        return transcriptions
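
# Sketch of a direct call (see process_audio_file below for the full flow):
#   feats = np.expand_dims(extract_mfcc(audio, sr), axis=0)  # shape (1, 1000, 13)
#   text = model.predict(feats)[0]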

@st.cache_resource
def get_model():
    """Load the pickled seq2seq model once and cache it across Streamlit reruns."""
    try:
        with open('twi_transcription_model.pkl', 'rb') as f:
            model_data = pickle.load(f)
            return TwiTranscriptionModel(
                model_data['encoder_model'],
                model_data['decoder_model'],
                model_data['char_tokenizer'],
                model_data['max_length']
            )
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None
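
# The pickle is assumed (inferred from the key lookups above; the training-side
# export script is not shown here) to contain a dict of the form:
#   {
#       'encoder_model':  Keras model, MFCC frames -> initial (h, c) states,
#       'decoder_model':  Keras model, (prev char, h, c) -> next-char probs,
#       'char_tokenizer': Keras Tokenizer with '<sos>'/'<eos>' in word_index,
#       'max_length':     int decoding cap, e.g. 50,
#   }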

def extract_mfcc(audio_data, sr=16000, n_mfcc=13):
    """Compute a fixed-length MFCC matrix shaped (frames, n_mfcc)."""
    # Normalize everything to 16 kHz, the rate the feature pipeline expects.
    if sr != 16000:
        audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=16000)

    mfcc = librosa.feature.mfcc(y=audio_data, sr=16000, n_mfcc=n_mfcc)

    # Truncate or zero-pad to a fixed number of frames so every input has the
    # same shape. Adjust based on your model's requirements.
    max_length = 1000
    if mfcc.shape[1] > max_length:
        mfcc = mfcc[:, :max_length]
    else:
        mfcc = np.pad(mfcc, ((0, 0), (0, max_length - mfcc.shape[1])), mode='constant')

    # librosa returns (n_mfcc, frames); transpose to (frames, n_mfcc).
    return mfcc.T

def calculate_error_rates(reference, hypothesis):
    """Return (WER, CER) from jiwer, or (None, None) if they cannot be computed."""
    try:
        error_wer = wer(reference, hypothesis)
        error_cer = cer(reference, hypothesis)
        return error_wer, error_cer
    except Exception:
        # jiwer raises on degenerate input such as an empty reference string.
        return None, None
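
# Worked example (strings illustrative): with reference "me din de kofi"
# (4 words) and hypothesis "me din kofi", one word is deleted, so
# wer(...) == 1/4 == 0.25, and cer(...) == 3/14 (three character edits,
# counting spaces) ~= 0.214.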

def process_audio_file(audio_file, model, reference_text=None):
    """Process an uploaded audio file and return a transcription result dict."""
    try:
        # Load at the native sample rate; librosa downmixes to mono by
        # default, so the channel check below is only a safety net.
        audio_data, sr = librosa.load(audio_file, sr=None)
        if audio_data.ndim > 1:
            # Multi-channel audio from librosa is laid out (channels, samples).
            audio_data = np.mean(audio_data, axis=0)
        
        # Extract features
        mfcc_features = extract_mfcc(audio_data, sr)
        mfcc_features = np.expand_dims(mfcc_features, axis=0)
        
        # Get transcription
        transcription = model.predict(mfcc_features)[0]
        
        # Prepare response
        response = {
            'status': 'success',
            'transcription': transcription,
            'audio_details': {
                'sample_rate': int(sr),
                'duration': float(len(audio_data) / sr)
            },
            'audio_data': audio_data,
            'sample_rate': sr
        }
        
        # Add error metrics if reference provided
        if reference_text:
            error_wer, error_cer = calculate_error_rates(reference_text, transcription)
            if error_wer is not None and error_cer is not None:
                response['error_metrics'] = {
                    'word_error_rate': round(float(error_wer), 4),
                    'character_error_rate': round(float(error_cer), 4)
                }
        
        return response
    except Exception as e:
        return {
            'status': 'error',
            'error': str(e)
        }

def main():
    st.set_page_config(
        page_title="Twi Speech Recognition",
        page_icon="🎤",
        layout="wide"
    )
    
    # Initialize model
    model = get_model()
    if model is None:
        st.error("Failed to load model. Please try again later.")
        return

    st.title("Twi Speech Recognition")
    st.write("Upload an audio file for transcription")
    
    # File uploader
    audio_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'ogg'])
    
    # Optional reference text
    reference_text = st.text_area("Reference text (optional)", "", help="Enter the correct transcription to calculate error rates")
    
    if audio_file is not None:
        if st.button("Transcribe"):
            with st.spinner("Processing audio... This may take a few minutes."):
                result = process_audio_file(
                    audio_file,
                    model,
                    reference_text if reference_text else None
                )
            
            if result['status'] == 'success':
                st.success("Transcription completed!")
                
                # Convert audio data to bytes for Streamlit audio player
                audio_bytes = BytesIO()
                sf.write(audio_bytes, result['audio_data'], result['sample_rate'], format='WAV')
                audio_bytes.seek(0)
                
                # Audio Playback
                st.audio(audio_bytes, format='audio/wav')
                
                # Transcription Display
                st.write("### Transcription:")
                st.write(result['transcription'])
                
                # Audio Details
                st.write("### Audio Details:")
                st.json(result['audio_details'])
                
                # Error Metrics
                if 'error_metrics' in result:
                    st.write("### Error Metrics:")
                    st.json(result['error_metrics'])
            else:
                st.error(f"Error: {result.get('error', 'Unknown error')}")

if __name__ == "__main__":
    main()
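
# To launch locally (the filename here is assumed; use this file's actual name):
#   streamlit run twi_app.py
# The pickled model, 'twi_transcription_model.pkl', must sit in the working
# directory, since get_model() opens it by relative path.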