sedemkofi committed
Commit 346a962 · verified · 1 Parent(s): e2e2e8b

Update app.py

Files changed (1)
  1. app.py +58 -91
app.py CHANGED
@@ -7,100 +7,40 @@ import json
  from io import BytesIO
  import base64
 
- class TwiTranscriptionModel:
-     def __init__(self, encoder_model, decoder_model, char_tokenizer, max_length=50):
-         self.encoder_model = encoder_model
-         self.decoder_model = decoder_model
-         self.char_tokenizer = char_tokenizer
-         self.max_length = max_length
-         self.sos_token = '<sos>'
-         self.eos_token = '<eos>'
-         self.sos_index = char_tokenizer.word_index[self.sos_token]
-         self.eos_index = char_tokenizer.word_index[self.eos_token]
-
-     def predict(self, audio_features):
-         batch_size = audio_features.shape[0]
-         transcriptions = []
-
-         for i in range(batch_size):
-             states_value = self.encoder_model.predict(
-                 audio_features[i:i+1],
-                 verbose=0
-             )
-             target_seq = np.array([[self.sos_index]])
-             decoded_chars = []
-
-             for _ in range(self.max_length):
-                 output_tokens, h, c = self.decoder_model.predict(
-                     [target_seq] + states_value,
-                     verbose=0
-                 )
-
-                 sampled_token_index = np.argmax(output_tokens[0, -1, :])
-                 sampled_char = self.char_tokenizer.index_word.get(sampled_token_index, '')
-
-                 if sampled_char == self.eos_token or len(decoded_chars) > self.max_length:
-                     break
-
-                 decoded_chars.append(sampled_char)
-                 target_seq = np.array([[sampled_token_index]])
-                 states_value = [h, c]
-
-             transcriptions.append(''.join(decoded_chars))
-
-         return transcriptions
-
- @st.cache_resource
- def get_model():
-     try:
-         with open('twi_transcription_model.pkl', 'rb') as f:
-             model_data = pickle.load(f)
-         return TwiTranscriptionModel(
-             model_data['encoder_model'],
-             model_data['decoder_model'],
-             model_data['char_tokenizer'],
-             model_data['max_length']
-         )
-     except Exception as e:
-         st.error(f"Error loading model: {str(e)}")
-         return None
-
- def extract_mfcc(audio_data, sr=16000, n_mfcc=13):
-     if sr != 16000:
-         audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=16000)
-
-     mfcc = librosa.feature.mfcc(y=audio_data, sr=16000, n_mfcc=n_mfcc)
-
-     max_length = 1000  # Adjust based on your model's requirements
-     if mfcc.shape[1] > max_length:
-         mfcc = mfcc[:, :max_length]
-     else:
-         mfcc = np.pad(mfcc, ((0, 0), (0, max_length - mfcc.shape[1])), mode='constant')
-
-     return mfcc.T
-
- def calculate_error_rates(reference, hypothesis):
-     try:
-         error_wer = wer(reference, hypothesis)
-         error_cer = cer(reference, hypothesis)
-         return error_wer, error_cer
-     except Exception as e:
-         return None, None
-
- def process_audio_bytes(audio_bytes):
-     try:
-         audio_data, sr = librosa.load(BytesIO(audio_bytes), sr=None)
-         if len(audio_data.shape) > 1:
-             audio_data = np.mean(audio_data, axis=1)
-         return audio_data, sr
-     except Exception as e:
-         raise Exception(f"Error processing audio: {str(e)}")
-
- # Set page config
- st.set_page_config(page_title="Twi Speech API")
+ # ... (keep your existing imports and TwiTranscriptionModel class) ...
+
+ # Add this at the top of your file
+ class ChunkedUploader:
+     def __init__(self):
+         if 'chunks' not in st.session_state:
+             st.session_state.chunks = {}
+         if 'current_upload_id' not in st.session_state:
+             st.session_state.current_upload_id = None
+
+     def add_chunk(self, upload_id, chunk_num, total_chunks, chunk_data):
+         if upload_id not in st.session_state.chunks:
+             st.session_state.chunks[upload_id] = {'data': {}, 'total': total_chunks}
+         st.session_state.chunks[upload_id]['data'][chunk_num] = chunk_data
+
+     def is_upload_complete(self, upload_id):
+         if upload_id not in st.session_state.chunks:
+             return False
+         upload = st.session_state.chunks[upload_id]
+         return len(upload['data']) == upload['total']
+
+     def get_complete_data(self, upload_id):
+         if not self.is_upload_complete(upload_id):
+             return None
+         chunks = st.session_state.chunks[upload_id]['data']
+         sorted_chunks = [chunks[i] for i in range(len(chunks))]
+         complete_data = ''.join(sorted_chunks)
+         # Clean up after getting data
+         del st.session_state.chunks[upload_id]
+         return complete_data
 
  def main():
      model = get_model()
+     chunked_uploader = ChunkedUploader()
 
      if model is None:
          st.write(json.dumps({
@@ -123,7 +63,34 @@ def main():
          }))
          return
 
-     audio_base64 = data.get('audio')
+     # Handle chunked upload
+     if 'chunk_data' in data:
+         upload_id = data.get('upload_id')
+         chunk_num = data.get('chunk_num')
+         total_chunks = data.get('total_chunks')
+         chunk_data = data.get('chunk_data')
+
+         if not all([upload_id, chunk_num is not None, total_chunks, chunk_data]):
+             st.write(json.dumps({
+                 'error': 'Missing chunked upload parameters',
+                 'status': 'error'
+             }))
+             return
+
+         chunked_uploader.add_chunk(upload_id, chunk_num, total_chunks, chunk_data)
+
+         if not chunked_uploader.is_upload_complete(upload_id):
+             st.write(json.dumps({
+                 'status': 'pending',
+                 'message': f'Received chunk {chunk_num + 1} of {total_chunks}'
+             }))
+             return
+
+         # Get complete data if upload is finished
+         audio_base64 = chunked_uploader.get_complete_data(upload_id)
+     else:
+         audio_base64 = data.get('audio')
+
      reference_text = data.get('reference_text')
 
      if not audio_base64:
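
A possible way to drive the new chunked-upload path from a client is sketched below. This is a minimal sketch under assumptions: the commit only shows the server reading a `data` dict, so the HTTP POST transport, the `requests` dependency, and the `API_URL` and `CHUNK_SIZE` values are illustrative, not part of the commit. The field names (`upload_id`, `chunk_num`, `total_chunks`, `chunk_data`, `reference_text`) come from the diff itself; `chunk_num` is 0-based so that `get_complete_data` can reassemble chunks with `range(len(chunks))`.

import base64
import math
import uuid

import requests  # assumed HTTP client; the actual transport is not shown in the diff

API_URL = "https://example.com/twi-speech"  # hypothetical endpoint
CHUNK_SIZE = 200_000  # characters of base64 per request; arbitrary illustrative value


def send_audio_in_chunks(wav_path, reference_text=None):
    """Split a base64-encoded audio file into chunks matching the ChunkedUploader fields."""
    with open(wav_path, "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode("ascii")

    upload_id = str(uuid.uuid4())
    total_chunks = math.ceil(len(audio_b64) / CHUNK_SIZE)

    last_response = None
    for chunk_num in range(total_chunks):
        chunk = audio_b64[chunk_num * CHUNK_SIZE:(chunk_num + 1) * CHUNK_SIZE]
        payload = {
            "upload_id": upload_id,
            "chunk_num": chunk_num,        # 0-based, as get_complete_data expects
            "total_chunks": total_chunks,
            "chunk_data": chunk,
        }
        # reference_text is read from the same request after reassembly,
        # so include it (at least on the final chunk) if error rates are wanted.
        if reference_text is not None:
            payload["reference_text"] = reference_text
        last_response = requests.post(API_URL, json=payload).json()

    # Intermediate chunks should return {'status': 'pending', ...};
    # the final chunk should trigger reassembly and transcription.
    return last_response

The sketch sends every chunk under one `upload_id` generated by the client; the server keeps partial uploads in `st.session_state.chunks` and deletes the entry once the complete base64 payload has been joined and returned.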