Suprath committed · verified
Commit 7b6cd70 · 1 Parent(s): 8c1767f

Update app.py

Files changed (1): app.py (+42 −2)
app.py CHANGED
@@ -88,6 +88,46 @@ def detect_landmark(image, detector, predictor):
         coords[i] = (shape.part(i).x, shape.part(i).y)
     return coords
 
+def predict_and_save(process_video):
+    num_frames = int(cv2.VideoCapture(process_video).get(cv2.CAP_PROP_FRAME_COUNT))
+
+    tsv_cont = ["/\n", f"test-0\t{process_video}\t{None}\t{num_frames}\t{int(16_000*num_frames/25)}\n"]
+    label_cont = ["DUMMY\n"]
+    with open(f"{data_dir}/test.tsv", "w") as fo:
+        fo.write("".join(tsv_cont))
+    with open(f"{data_dir}/test.wrd", "w") as fo:
+        fo.write("".join(label_cont))
+    task.load_dataset(gen_subset, task_cfg=saved_cfg.task)
+
+    def decode_fn(x):
+        dictionary = task.target_dictionary
+        symbols_ignore = generator.symbols_to_strip_from_output
+        symbols_ignore.add(dictionary.pad())
+        return task.datasets[gen_subset].label_processors[0].decode(x, symbols_ignore)
+
+    itr = task.get_batch_iterator(dataset=task.dataset(gen_subset)).next_epoch_itr(shuffle=False)
+    sample = next(itr)
+    if torch.cuda.is_available():
+        sample = utils.move_to_cuda(sample)
+    hypos = task.inference_step(generator, models, sample)
+    ref = decode_fn(sample['target'][0].int().cpu())
+    hypo = hypos[0][0]['tokens'].int().cpu()
+    hypo = decode_fn(hypo)
+
+    # Collect timestamps and texts
+    transcript = []
+    for i, (start, end) in enumerate(sample['net_input']['video_lengths'], 1):
+        start_time = float(start) / 16_000
+        end_time = float(end) / 16_000
+        text = hypo[i].strip()
+        transcript.append({"timestamp": [start_time, end_time], "text": text})
+
+    # Save transcript to a JSON file
+    with open('speech_transcript.json', 'w') as outfile:
+        json.dump(transcript, outfile, indent=4)
+
+    return hypo
+
 def preprocess_video(input_video_path):
     if torch.cuda.is_available():
         detector = dlib.cnn_face_detection_model_v1(face_detector_path)
 
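Note on the manifest this hunk writes: test.tsv and test.wrd follow the fairseq-style test-split layout implied by the constants in the code (25 fps video, 16 kHz audio). A sketch of what the two files would contain for a hypothetical 100-frame clip at /path/to/roi.mp4 (the path and frame count are illustrative, not from the commit):

    test.tsv
    /
    test-0	/path/to/roi.mp4	None	100	64000

    test.wrd
    DUMMY

Reading off the f-string, the tab-separated columns are: utterance id, video path, a literal "None" (presumably the audio path, left empty for video-only inference), the video frame count, and an audio sample count computed as 16_000 * num_frames / 25. The label file holds a single "DUMMY" line because no reference transcript exists at inference time.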
@@ -189,8 +229,8 @@ with demo:
     detect_landmark_btn.click(preprocess_video, [video_in], [
                               video_out])
     predict_btn = gr.Button("Predict")
-    predict_btn.click(predict, [video_out], [
-                      text_output])
+    #predict_btn.click(predict, [video_out], [text_output])
+    predict_btn.click(predict_and_save, [video_out], [text_output])
     with gr.Row():
         # video_lip = gr.Video(label="Audio Visual Video", mirror_webcam=False)
         text_output.render()
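With this change, the Predict button routes to predict_and_save instead of predict, so the decoded hypothesis still lands in text_output while a timestamped transcript is also written to disk as a side effect. Assuming the (start, end) pairs the loop unpacks from sample['net_input']['video_lengths'] are offsets in audio samples (the division by 16_000 implies this), speech_transcript.json would look roughly like:

    [
        {
            "timestamp": [0.0, 1.28],
            "text": "..."
        }
    ]

The timestamp values and text here are placeholders, not output produced by the commit.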