sedemkofi committed
Commit cab68e0 · verified · Parent: 2df44cf

Update app.py

Files changed (1): app.py (+83, -62)
app.py CHANGED
@@ -1,3 +1,4 @@
+
 import streamlit as st
 import numpy as np
 import librosa
@@ -51,9 +52,8 @@ class TwiTranscriptionModel:
         return transcriptions
 
 @st.cache_resource
-def load_model():
+def get_model():
     try:
-        # Modify this path if your model is stored differently in Hugging Face
        with open('twi_transcription_model.pkl', 'rb') as f:
             model_data = pickle.load(f)
         return TwiTranscriptionModel(
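
Note on this hunk: the rename from load_model to get_model keeps the st.cache_resource decorator, so the pickled model is deserialized once per server process and the same object is reused across reruns and sessions rather than reloaded on every widget interaction. A minimal sketch of the pattern (standalone and simplified: the unpickled data is returned directly instead of being wrapped in TwiTranscriptionModel, and the None fallback mirrors the check in main()):

import pickle
import streamlit as st

@st.cache_resource  # runs once per process; later calls return the cached object
def get_model(path="twi_transcription_model.pkl"):
    try:
        with open(path, "rb") as f:
            return pickle.load(f)  # stand-in for TwiTranscriptionModel(...)
    except Exception:
        return None  # main() treats None as a failed load

model = get_model()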
@@ -88,6 +88,49 @@ def calculate_error_rates(reference, hypothesis):
     except Exception as e:
         return None, None
 
+def process_audio_file(audio_file, model, reference_text=None):
+    """Process uploaded audio file and return transcription"""
+    try:
+        # Read audio file
+        audio_data, sr = librosa.load(audio_file, sr=None)
+        if len(audio_data.shape) > 1:
+            audio_data = np.mean(audio_data, axis=1)
+
+        # Extract features
+        mfcc_features = extract_mfcc(audio_data, sr)
+        mfcc_features = np.expand_dims(mfcc_features, axis=0)
+
+        # Get transcription
+        transcription = model.predict(mfcc_features)[0]
+
+        # Prepare response
+        response = {
+            'status': 'success',
+            'transcription': transcription,
+            'audio_details': {
+                'sample_rate': int(sr),
+                'duration': float(len(audio_data) / sr)
+            },
+            'audio_data': audio_data,
+            'sample_rate': sr
+        }
+
+        # Add error metrics if reference provided
+        if reference_text:
+            error_wer, error_cer = calculate_error_rates(reference_text, transcription)
+            if error_wer is not None and error_cer is not None:
+                response['error_metrics'] = {
+                    'word_error_rate': round(float(error_wer), 4),
+                    'character_error_rate': round(float(error_cer), 4)
+                }
+
+        return response
+    except Exception as e:
+        return {
+            'status': 'error',
+            'error': str(e)
+        }
+
 def main():
     st.set_page_config(
         page_title="Twi Speech Recognition",
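
extract_mfcc itself is outside this diff; the call sites only show that it takes the mono signal and sample rate and returns a feature matrix that later gets a batch axis via np.expand_dims. A plausible sketch, assuming 13 coefficients and a fixed, padded/truncated time axis (n_mfcc and max_len are hypothetical parameters, not taken from app.py):

import numpy as np
import librosa

def extract_mfcc(audio_data, sr, n_mfcc=13, max_len=100):
    # librosa returns (n_mfcc, T); transpose to time-major frames
    mfcc = librosa.feature.mfcc(y=audio_data, sr=sr, n_mfcc=n_mfcc).T
    if mfcc.shape[0] < max_len:
        # pad short clips with zero frames up to the fixed length
        mfcc = np.pad(mfcc, ((0, max_len - mfcc.shape[0]), (0, 0)))
    else:
        # truncate long clips
        mfcc = mfcc[:max_len]
    return mfcc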
@@ -95,77 +138,55 @@ def main():
         layout="wide"
     )
 
-    # Load the model
-    model = load_model()
+    # Initialize model
+    model = get_model()
     if model is None:
-        st.error("Failed to load model. Please check model file.")
+        st.error("Failed to load model. Please try again later.")
         return
 
-    st.title("Twi Speech Transcription")
-    st.write("Upload an audio file to transcribe Twi speech")
+    st.title("Twi Speech Recognition")
+    st.write("Upload an audio file for transcription")
 
     # File uploader
     audio_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'ogg'])
 
     # Optional reference text
-    reference_text = st.text_area("Reference text (optional)",
-                                  help="Enter the correct transcription to calculate error rates")
+    reference_text = st.text_area("Reference text (optional)", "", help="Enter the correct transcription to calculate error rates")
 
     if audio_file is not None:
         if st.button("Transcribe"):
-            with st.spinner("Processing audio... This may take a moment."):
-                try:
-                    # Read audio file
-                    audio_data, sr = librosa.load(audio_file, sr=None)
-                    if len(audio_data.shape) > 1:
-                        audio_data = np.mean(audio_data, axis=1)
-
-                    # Extract features
-                    mfcc_features = extract_mfcc(audio_data, sr)
-                    mfcc_features = np.expand_dims(mfcc_features, axis=0)
-
-                    # Get transcription
-                    transcription = model.predict(mfcc_features)[0]
-
-                    # Display results
-                    st.success("Transcription completed!")
-
-                    # Audio Playback
-                    st.audio(audio_file, format='audio/wav')
-
-                    # Transcription Display
-                    st.write("### Transcription:")
-                    st.write(transcription)
-
-                    # Audio Details
-                    st.write("### Audio Details:")
-                    st.json({
-                        'sample_rate': int(sr),
-                        'duration': float(len(audio_data) / sr)
-                    })
-
-                    # Error Metrics (if reference text provided)
-                    if reference_text:
-                        error_wer, error_cer = calculate_error_rates(reference_text, transcription)
-                        if error_wer is not None and error_cer is not None:
-                            st.write("### Error Metrics:")
-                            st.json({
-                                'word_error_rate': round(float(error_wer), 4),
-                                'character_error_rate': round(float(error_cer), 4)
-                            })
+            with st.spinner("Processing audio... This may take a few minutes."):
+                result = process_audio_file(
+                    audio_file,
+                    model,
+                    reference_text if reference_text else None
+                )
+
+                if result['status'] == 'success':
+                    st.success("Transcription completed!")
+
+                    # Convert audio data to bytes for Streamlit audio player
+                    audio_bytes = BytesIO()
+                    sf.write(audio_bytes, result['audio_data'], result['sample_rate'], format='WAV')
+                    audio_bytes.seek(0)
+
+                    # Audio Playback
+                    st.audio(audio_bytes, format='audio/wav')
+
+                    # Transcription Display
+                    st.write("### Transcription:")
+                    st.write(result['transcription'])
+
+                    # Audio Details
+                    st.write("### Audio Details:")
+                    st.json(result['audio_details'])
 
-                except Exception as e:
-                    st.error(f"Error processing audio: {str(e)}")
+                    # Error Metrics
+                    if 'error_metrics' in result:
+                        st.write("### Error Metrics:")
+                        st.json(result['error_metrics'])
+                    else:
+                        st.error(f"Error: {result.get('error', 'Unknown error')}")
 
 if __name__ == "__main__":
-    main()
-
-    # Requirements for Hugging Face (create a requirements.txt)
-    """
-    streamlit
-    numpy
-    librosa
-    tensorflow
-    jiwer
-    soundfile
-    """
+    main()
 
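
The playback change is the substantive fix in the last hunk: the old code passed the original upload back to st.audio with a hard-coded 'audio/wav' format even for mp3/ogg uploads, after librosa.load had already read the stream. The new code re-encodes the decoded samples to an in-memory WAV, so the player receives valid WAV bytes whatever the upload container was. (BytesIO and sf are presumably imported elsewhere in app.py; the import hunk above only adds a blank line.) A standalone sketch of that round-trip, with a synthetic tone standing in for an upload:

from io import BytesIO

import numpy as np
import soundfile as sf
import streamlit as st

sr = 16000
t = np.arange(sr) / sr  # one second of sample times
tone = (0.2 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)  # 440 Hz test signal

audio_bytes = BytesIO()
sf.write(audio_bytes, tone, sr, format="WAV")  # encode float samples as WAV in memory
audio_bytes.seek(0)  # rewind before handing the buffer to the player, as the diff does
st.audio(audio_bytes, format="audio/wav")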