""" File: submit.py Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov Description: Event handler for Gradio app to submit. License: MIT License """ import spaces import torch import pandas as pd import cv2 import gradio as gr # Importing necessary components for the Gradio app from app.config import config_data from app.utils import ( Timer, convert_video_to_audio, readetect_speech, slice_audio, find_intersections, calculate_mode, find_nearest_frames, convert_webm_to_mp4, ) from app.plots import ( get_evenly_spaced_frame_indices, plot_audio, display_frame_info, plot_images, plot_predictions, ) from app.data_init import ( read_audio, get_speech_timestamps, vad_model, video_model, asr, audio_model, text_model, ) from app.load_models import VideoFeatureExtractor @spaces.GPU def event_handler_submit( video: str, ) -> tuple[ gr.Textbox, gr.Plot, gr.Plot, gr.Plot, gr.Plot, gr.Row, gr.Textbox, gr.Textbox, ]: with Timer() as timer: if video: if video.split(".")[-1] == "webm": video = convert_webm_to_mp4(video) audio_file_path = convert_video_to_audio( file_path=video, sr=config_data.General_SR ) wav, vad_info = readetect_speech( file_path=audio_file_path, read_audio=read_audio, get_speech_timestamps=get_speech_timestamps, vad_model=vad_model, sr=config_data.General_SR, ) audio_windows = slice_audio( start_time=config_data.General_START_TIME, end_time=int(len(wav)), win_max_length=int( config_data.General_WIN_MAX_LENGTH * config_data.General_SR ), win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR), win_min_length=int( config_data.General_WIN_MIN_LENGTH * config_data.General_SR ), ) intersections = find_intersections( x=audio_windows, y=vad_info, min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR, ) vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False) vfe.preprocess_video() transcriptions, total_text = asr(wav, audio_windows) window_frames = [] preds_emo = [] preds_sen = [] for w_idx, window in enumerate(audio_windows): a_w = intersections[w_idx] if not a_w["speech"]: a_pred = None else: wave = wav[a_w["start"] : a_w["end"]].clone() a_pred, _ = audio_model(wave) v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH) t_pred, _ = text_model(transcriptions[w_idx][0]) if a_pred: pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3 pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3 else: pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2 pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2 frames = list( range( int(window["start"] * vfe.fps / config_data.General_SR) + 1, int(window["end"] * vfe.fps / config_data.General_SR) + 2, ) ) preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames)) preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames)) window_frames.extend(frames) if max(window_frames) < vfe.frame_number: missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1)) window_frames.extend(missed_frames) preds_emo.extend([preds_emo[-1]] * len(missed_frames)) preds_sen.extend([preds_sen[-1]] * len(missed_frames)) df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"]) df_pred["frames"] = window_frames df_pred["pred_emo"] = preds_emo df_pred["pred_sent"] = preds_sen df_pred = df_pred.groupby("frames").agg( { "pred_emo": calculate_mode, "pred_sent": calculate_mode, } ) frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9) num_frames = len(wav) time_axis = [i / config_data.General_SR for i in range(num_frames)] plt_audio 
        plt_audio = plot_audio(
            time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2)
        )

        all_idx_faces = list(vfe.faces[1].keys())
        need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)

        faces = []
        for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
            cur_face = cv2.resize(
                vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
            )
            faces.append(
                display_frame_info(
                    cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
                )
            )

        plt_faces = plot_images(faces)

        plt_emo = plot_predictions(
            df_pred,
            "pred_emo",
            "Emotion",
            list(config_data.General_DICT_EMO),
            (12, 2.5),
            [i + 1 for i in frame_indices],
            3,
        )
        plt_sent = plot_predictions(
            df_pred,
            "pred_sent",
            "Sentiment",
            list(config_data.General_DICT_SENT),
            (12, 1.5),
            [i + 1 for i in frame_indices],
            3,
        )

    return (
        gr.Textbox(
            value=" ".join(total_text).strip(),
            info=config_data.InformationMessages_REC_TEXT,
            container=True,
            elem_classes="noti-results",
        ),
        gr.Plot(value=plt_audio, visible=True),
        gr.Plot(value=plt_faces, visible=True),
        gr.Plot(value=plt_emo, visible=True),
        gr.Plot(value=plt_sent, visible=True),
        gr.Row(visible=True),
        gr.Textbox(
            value=config_data.OtherMessages_SEC.format(vfe.dur),
            info=config_data.InformationMessages_VIDEO_DURATION,
            container=True,
            visible=True,
        ),
        gr.Textbox(
            value=timer.execution_time,
            info=config_data.InformationMessages_INFERENCE_TIME,
            container=True,
            visible=True,
        ),
    )
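

if __name__ == "__main__":
    # Minimal wiring sketch (not part of the original app): it only illustrates how
    # this handler could be attached to a Gradio Blocks UI. All component names below
    # ("video_input", "submit_btn", etc.) and the layout are assumptions for
    # illustration; the real application defines its own interface elsewhere.
    with gr.Blocks() as demo:
        video_input = gr.Video(label="Video")
        submit_btn = gr.Button("Submit")
        rec_text = gr.Textbox(label="Recognized text")
        plot_audio_out = gr.Plot(visible=False)
        plot_faces_out = gr.Plot(visible=False)
        plot_emo_out = gr.Plot(visible=False)
        plot_sent_out = gr.Plot(visible=False)
        with gr.Row(visible=False) as stats_row:
            video_dur = gr.Textbox(label="Video duration")
            infer_time = gr.Textbox(label="Inference time")

        # The outputs list must follow the order of the tuple returned by the handler
        submit_btn.click(
            fn=event_handler_submit,
            inputs=[video_input],
            outputs=[
                rec_text,
                plot_audio_out,
                plot_faces_out,
                plot_emo_out,
                plot_sent_out,
                stats_row,
                video_dur,
                infer_time,
            ],
        )

    demo.launch()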