MASAI / app /load_models.py
DmitryRyumin's picture
Summary
15b7f31
"""
File: load_models.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Load pretrained models.
License: MIT License
"""
import math
import numpy as np
import cv2
import torch.nn.functional as F
import torch.nn as nn
import torch
from typing import Optional
from PIL import Image
from ultralytics import YOLO
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
from transformers import (
AutoConfig,
Wav2Vec2Processor,
AutoTokenizer,
AutoModel,
logging,
)
logging.set_verbosity_error()
from app.utils import pth_processing, get_idx_frames_in_windows
# Importing necessary components for the Gradio app
from app.utils import load_model
class ScaledDotProductAttention_MultiHead(nn.Module):
    """Scaled dot-product attention evaluated independently for every head.

    Masking is not implemented: passing any ``mask`` raises ``ValueError``.
    """

    def __init__(self):
        super(ScaledDotProductAttention_MultiHead, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key`` and return the re-weighted ``value``.

        Args:
            query, key, value: Tensors shaped
                [batch_size, num_heads, seq_len, dim].
            mask: Must be ``None``.

        Returns:
            Tuple ``(attended_value, attention_weights)``.

        Raises:
            ValueError: If ``mask`` is given.
        """
        # Reject masks once, up front, before any computation happens.
        if mask is not None:
            raise ValueError("Mask is not supported yet")

        # Scale the similarity scores by sqrt of the per-head embedding size.
        emb_dim = key.shape[-1]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
            emb_dim
        )
        # Normalise over the key positions.
        attention_weights = self.softmax(attention_weights)
        # Weighted sum of the values.
        value = torch.matmul(attention_weights, value)
        return value, attention_weights
class PositionWiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward network.

    Args:
        input_dim: Size of the input and output features.
        hidden_dim: Size of the hidden layer.
        dropout: Dropout probability applied after the first linear layer.
            ``None`` disables dropout (the same convention the other
            layers in this file use for their ``Optional`` dropout).
    """

    def __init__(self, input_dim, hidden_dim, dropout: Optional[float] = 0.1):
        super().__init__()
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, input_dim)
        # Not used in forward(); kept registered so the module's state-dict
        # keys stay exactly as they were.
        self.layer_norm = nn.LayerNorm(input_dim)
        # nn.Dropout(None) raises, so fall back to a parameter-free no-op.
        self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity()

    def forward(self, x):
        """Return ``layer_2(relu(dropout(layer_1(x))))``."""
        x = self.layer_1(x)
        x = self.dropout(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x
class Add_and_Norm(nn.Module):
    """Residual connection followed by layer normalisation.

    Args:
        input_dim: Size of the normalised (last) dimension.
        dropout: Dropout probability applied to the sub-layer output before
            the residual is added. With ``None`` no dropout module is created.
    """

    def __init__(self, input_dim, dropout: Optional[float] = 0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        if dropout is not None:
            self.dropout = nn.Dropout(dropout)

    def forward(self, x1, residual):
        """Return ``LayerNorm(dropout(x1) + residual)``."""
        # The dropout attribute only exists when a probability was supplied.
        branch = self.dropout(x1) if hasattr(self, "dropout") else x1
        return self.layer_norm(branch + residual)
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from bias-free linear projections.

    Args:
        input_dim: Feature size of queries, keys and values.
        num_heads: Number of heads; must divide ``input_dim``.
        dropout: Dropout probability, or ``None``.

    Raises:
        ValueError: If ``input_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, num_heads, dropout: Optional[float] = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        if input_dim % num_heads != 0:
            raise ValueError("input_dim must be divisible by num_heads")
        self.head_dim = input_dim // num_heads

        # Projections applied before and after the attention itself.
        self.query_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.keys_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.values_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.ff_layer_after_concat = nn.Linear(
            self.num_heads * self.head_dim, input_dim, bias=False
        )
        self.attention = ScaledDotProductAttention_MultiHead()
        # NOTE(review): this dropout module is created but never applied in
        # forward(); left as is so existing behaviour does not change.
        self.dropout = nn.Dropout(dropout) if dropout is not None else None

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape [batch, seq_len, input_dim] to [batch, heads, seq_len, head_dim]."""
        return projected.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        Args:
            queries, keys, values: Tensors shaped
                [batch_size, seq_len, input_dim]; keys and values share the
                same sequence length, queries may differ.
            mask: Forwarded to the attention module.

        Returns:
            Tensor shaped [batch_size, query_seq_len, input_dim].
        """
        batch_size = queries.size(0)
        len_query, len_keys, len_values = (
            queries.size(1),
            keys.size(1),
            values.size(1),
        )

        # Linear transformation, then split into heads.
        queries = self._split_heads(self.query_w(queries), batch_size, len_query)
        keys = self._split_heads(self.keys_w(keys), batch_size, len_keys)
        values = self._split_heads(self.values_w(values), batch_size, len_values)

        # Attention output: [batch_size, num_heads, len_query, head_dim].
        values, attention_weights = self.attention(queries, keys, values, mask=mask)

        # Concatenate the heads. The result has one row per *query* position,
        # so it is reshaped with len_query rather than len_values.
        out = (
            values.transpose(1, 2)
            .contiguous()
            .view(batch_size, len_query, self.num_heads * self.head_dim)
        )
        return self.ff_layer_after_concat(out)
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to a batch-first sequence.

    Args:
        d_model: Embedding size. Odd sizes are supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        # Table stored batch-first: [1, max_len, d_model].
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]

        Returns:
            ``x`` plus the encoding of its first ``seq_len`` positions,
            passed through dropout.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
class TransformerLayer(nn.Module):
    """One encoder-style block: attention and feed-forward, each with add & norm.

    Args:
        input_dim: Feature size of the inputs.
        num_heads: Number of attention heads.
        dropout: Dropout probability handed to the sub-layers.
        positional_encoding: When truthy, a ``PositionalEncoding`` module is
            created and applied to key, value and query in ``forward``.
    """

    def __init__(
        self,
        input_dim,
        num_heads,
        dropout: Optional[float] = 0.1,
        positional_encoding: bool = True,
    ):
        super(TransformerLayer, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        self.dropout = dropout

        # Sub-layers
        self.self_attention = MultiHeadAttention(input_dim, num_heads, dropout=dropout)
        self.feed_forward = PositionWiseFeedForward(
            input_dim, input_dim, dropout=dropout
        )
        self.add_norm_after_attention = Add_and_Norm(input_dim, dropout=dropout)
        self.add_norm_after_ff = Add_and_Norm(input_dim, dropout=dropout)

        # The attribute holds either the encoding module or the falsy flag.
        if positional_encoding:
            self.positional_encoding = PositionalEncoding(input_dim)
        else:
            self.positional_encoding = positional_encoding

    def forward(self, key, value, query, mask=None):
        """Run the block.

        Note the argument order: ``key, value, query``. All three are shaped
        [batch_size, seq_len, input_dim].
        """
        if self.positional_encoding:
            key, value, query = (
                self.positional_encoding(tensor) for tensor in (key, value, query)
            )

        # Multi-head attention with the query as residual.
        attended = self.self_attention(
            queries=query, keys=key, values=value, mask=mask
        )
        attended = self.add_norm_after_attention(attended, query)

        # Feed-forward with the attention output as residual.
        transformed = self.feed_forward(attended)
        return self.add_norm_after_ff(transformed, attended)
class SelfTransformer(nn.Module):
    """Self-attention block built on ``torch.nn.MultiheadAttention``.

    Args:
        input_size: Feature size of the sequence elements.
        num_heads: Number of attention heads.
        dropout: Dropout probability inside the attention module.
    """

    def __init__(self, input_size: int = int(1024), num_heads=1, dropout=0.1):
        super(SelfTransformer, self).__init__()
        self.att = torch.nn.MultiheadAttention(
            input_size, num_heads, dropout, bias=True, batch_first=True
        )
        self.norm1 = nn.LayerNorm(input_size)
        self.fcl = nn.Linear(input_size, input_size)
        self.norm2 = nn.LayerNorm(input_size)

    def forward(self, video):
        """Return the attended, normalised representation of ``video``."""
        attended, _ = self.att(video, video, video)
        normed = self.norm1(video + attended)
        projected = self.fcl(normed)
        # NOTE(review): norm1 is applied a second time here and norm2 is never
        # used. Kept as is; presumably the released weights were trained with
        # this wiring, so switching to norm2 would change outputs.
        return self.norm1(normed + projected)
class SmallClassificationHead(nn.Module):
    """Pair of linear classifiers sharing one input: emotion and sentiment."""

    def __init__(self, input_size=256, out_emo=6, out_sen=3):
        super(SmallClassificationHead, self).__init__()
        self.fc_emo = nn.Linear(input_size, out_emo)
        self.fc_sen = nn.Linear(input_size, out_sen)

    def forward(self, x):
        """Return a dict with the ``"emo"`` and ``"sen"`` outputs for ``x``."""
        return {"emo": self.fc_emo(x), "sen": self.fc_sen(x)}
class AudioModelWT(Wav2Vec2PreTrainedModel):
    """Wav2Vec2 encoder followed by two transformer layers and a small head.

    ``config`` must provide ``out_emo`` and ``out_sen`` (sizes of the two
    classifier outputs).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)

        self.f_size = 1024
        transformer_kwargs = dict(
            input_dim=self.f_size, num_heads=4, dropout=0.1, positional_encoding=True
        )
        self.tl1 = TransformerLayer(**transformer_kwargs)
        self.tl2 = TransformerLayer(**transformer_kwargs)

        # fc1 reduces every time step to one value; after flattening, the
        # classification head therefore expects exactly 199 time steps.
        self.fc1 = nn.Linear(1024, 1)
        self.dp = nn.Dropout(p=0.5)
        self.selu = nn.SELU()
        self.relu = nn.ReLU()
        self.cl_head = SmallClassificationHead(
            input_size=199, out_emo=config.out_emo, out_sen=config.out_sen
        )

        self.init_weights()
        # freeze conv
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """Disable gradients for the wav2vec2 convolutional feature layers."""
        for weight in self.wav2vec2.feature_extractor.conv_layers.parameters():
            weight.requires_grad = False

    def forward(self, x, with_features=False):
        """Classify ``x``.

        Returns:
            The classification-head output, or a tuple of that output and the
            second transformer layer's output when ``with_features`` is true.
        """
        hidden = self.wav2vec2(x)[0]
        hidden = self.selu(self.tl1(hidden, hidden, hidden))
        features = self.tl2(hidden, hidden, hidden)

        reduced = self.dp(self.relu(self.fc1(self.selu(features))))
        flat = reduced.view(reduced.size(0), -1)
        preds = self.cl_head(flat)
        return (preds, features) if with_features else preds
class AudioFeatureExtractor:
    """Runs the fine-tuned audio model on one waveform window at a time."""

    def __init__(
        self,
        checkpoint_url: str,
        folder_path: str,
        device: torch.device,
        sr: int = 16000,
        win_max_length: int = 4,
        with_features: bool = True,
    ) -> None:
        """
        Args:
            checkpoint_url (str): Passed to ``load_model`` to obtain the checkpoint
            folder_path (str): Passed to ``load_model`` together with the URL
            device (torch.device): Device the model runs on
            sr (int, optional): Sample rate of audio. Defaults to 16000.
            win_max_length (int, optional): Max length of window. Defaults to 4.
            with_features (bool, optional): Extract features or not
        """
        self.device = device
        self.sr = sr
        self.win_max_length = win_max_length
        self.with_features = with_features

        checkpoint_path = load_model(checkpoint_url, folder_path)

        model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
        model_config = AutoConfig.from_pretrained(model_name)
        model_config.out_emo = 7
        model_config.out_sen = 3
        model_config.context_length = 199

        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = AudioModelWT.from_pretrained(
            pretrained_model_name_or_path=model_name, config=model_config
        )

        # NOTE(review): depending on the installed torch version, torch.load
        # may unpickle arbitrary objects - only use trusted checkpoint URLs.
        state = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(state["model_state_dict"])
        self.model.to(self.device)

    def preprocess_wave(self, x: torch.Tensor) -> torch.Tensor:
        """Prepare a waveform for wav2vec

        Pads the audio to ``sr * win_max_length`` samples.

        Args:
            x (torch.Tensor): Input data

        Returns:
            torch.Tensor: First entry of the processor's ``input_values``
        """
        target_length = self.sr * self.win_max_length
        processed = self.processor(
            x,
            sampling_rate=self.sr,
            return_tensors="pt",
            padding="max_length",
            max_length=target_length,
        )
        return processed["input_values"][0]

    def __call__(
        self, waveform: torch.Tensor
    ) -> tuple[dict[torch.Tensor], torch.Tensor]:
        """Predict emotion and sentiment for one waveform

        Args:
            waveform (torch.Tensor): wave

        Returns:
            Tuple of the softmax predictions (keys ``"emo"`` and ``"sen"``)
            and the extracted features, or ``None`` instead of the features
            when ``with_features`` is false
        """
        batch = self.preprocess_wave(waveform).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(batch, with_features=self.with_features)

        features = None
        if self.with_features:
            preds, raw_features = outputs
            features = raw_features.detach().cpu().squeeze()
        else:
            preds = outputs

        predicts = {
            name: F.softmax(preds[name], dim=-1).detach().cpu().squeeze()
            for name in ("emo", "sen")
        }
        return predicts, features
class Tmodel(nn.Module):
    """Text model with separate emotion and sentiment output branches.

    NOTE(review): ``forward`` pools the raw input ``t``; the transformer
    layers only populate ``self.features``. Every layer is applied to the
    same input and only the last result is kept. This is preserved as is,
    presumably to stay compatible with the released weights.
    """

    def __init__(
        self,
        input_size: int = int(1024),
        activation=nn.SELU(),
        feature_size1=256,
        feature_size2=64,
        num_heads=1,
        num_layers=2,
        n_emo=7,
        n_sent=3,
    ):
        super(Tmodel, self).__init__()
        self.feature_text_dynamic = nn.ModuleList(
            SelfTransformer(input_size=input_size, num_heads=num_heads)
            for _ in range(num_layers)
        )
        self.fcl = nn.Linear(input_size, feature_size1)
        self.activation = activation
        self.feature_emo = nn.Linear(feature_size1, feature_size2)
        self.feature_sent = nn.Linear(feature_size1, feature_size2)
        self.fc_emo = nn.Linear(feature_size2, n_emo)
        self.fc_sent = nn.Linear(feature_size2, n_sent)

    def get_features(self, t):
        """Store in ``self.features`` the last layer's output for input ``t``."""
        for layer in self.feature_text_dynamic:
            self.features = layer(t)

    def forward(self, t):
        """Return ``(emotion_output, sentiment_output)`` for the batch ``t``."""
        self.get_features(t)
        pooled = self.activation(t.mean(dim=1))
        shared = self.activation(self.fcl(pooled))
        emo_branch = self.activation(self.feature_emo(shared))
        sent_branch = self.activation(self.feature_sent(shared))
        return self.fc_emo(emo_branch), self.fc_sent(sent_branch)
class TextFeatureExtractor:
    """Embeds a text with a pretrained encoder and classifies it with ``Tmodel``.

    Args:
        checkpoint_url: Passed to ``load_model`` to obtain the ``Tmodel`` weights.
        folder_path: Passed to ``load_model`` together with the URL.
        device: Device both models run on.
        with_features: Whether ``__call__`` also returns the temporal features.
    """

    # Token count and embedding width shared by the encoder output and the
    # all-zero placeholder returned for empty text.
    _MAX_TOKENS = 6
    _EMBED_DIM = 1024

    def __init__(
        self,
        checkpoint_url: str,
        folder_path: str,
        device: torch.device,
        with_features: bool = True,
    ) -> None:
        self.device = device
        self.with_features = with_features

        model_name_bert = "julian-schelb/roberta-ner-multilingual"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_bert, add_prefix_space=True
        )
        # Move the encoder to the target device once instead of on every call.
        self.model_bert = AutoModel.from_pretrained(model_name_bert).to(self.device)
        self.model_bert.eval()

        checkpoint_path = load_model(checkpoint_url, folder_path)
        self.model = Tmodel()
        # NOTE(review): depending on the installed torch version, torch.load
        # may unpickle arbitrary objects - only use trusted checkpoint URLs.
        self.model.load_state_dict(
            torch.load(checkpoint_path, map_location=self.device)
        )
        # A freshly built nn.Module is in training mode; switch to eval so the
        # attention dropout is off and the returned features are deterministic.
        self.model.to(self.device).eval()

    def preprocess_text(self, text: str) -> torch.Tensor:
        """Return encoder embeddings for ``text``, shaped [1, 6, 1024].

        An empty string, or a value whose string form is ``"nan"``, yields an
        all-zero tensor of the same shape.
        """
        if text == "" or str(text) == "nan":
            return torch.zeros((1, self._MAX_TOKENS, self._EMBED_DIM))

        inputs = self.tokenizer(
            text.lower(),
            padding="max_length",
            truncation="longest_first",
            return_tensors="pt",
            max_length=self._MAX_TOKENS,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model_bert(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )
        return outputs.last_hidden_state.cpu().detach()

    def __call__(self, text: str) -> tuple[dict[torch.Tensor], torch.Tensor]:
        """Predict emotion and sentiment for ``text``.

        Returns:
            Tuple of the softmax predictions (keys ``"emo"`` and ``"sen"``)
            and the model's temporal features, or ``None`` instead of the
            features when ``with_features`` is false.
        """
        text_features = self.preprocess_text(text)

        with torch.no_grad():
            pred_emo, pred_sent = self.model(text_features.float().to(self.device))

        temporal_features = None
        if self.with_features:
            temporal_features = self.model.features.detach().cpu().squeeze()

        predicts = {
            "emo": F.softmax(pred_emo, dim=-1).detach().cpu().squeeze(),
            "sen": F.softmax(pred_sent, dim=-1).detach().cpu().squeeze(),
        }
        return predicts, temporal_features
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1, 3x3 and 1x1 convolutions plus a shortcut.

    The stride is applied in the first 1x1 convolution, and the block outputs
    ``out_channels * expansion`` channels.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Bottleneck, self).__init__()
        expanded_channels = out_channels * self.expansion

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=False,
        )
        self.batch_norm1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)

        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding="same", bias=False
        )
        self.batch_norm2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)

        self.conv3 = nn.Conv2d(
            out_channels,
            expanded_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.batch_norm3 = nn.BatchNorm2d(expanded_channels, eps=0.001, momentum=0.99)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(block(x) + shortcut(x))``."""
        shortcut = x.clone()

        out = self.relu(self.batch_norm1(self.conv1(x)))
        out = self.relu(self.batch_norm2(self.conv2(out)))
        out = self.batch_norm3(self.conv3(out))

        # Project the shortcut when a downsampling module was supplied.
        if self.i_downsample is not None:
            shortcut = self.i_downsample(shortcut)

        out += shortcut
        return self.relu(out)
class Conv2dSame(torch.nn.Conv2d):
    """``Conv2d`` that pads its input so the output size is ``ceil(input / stride)``."""

    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """Total padding for one axis.

        Args:
            i: Input size.
            k: Kernel size.
            s: Stride.
            d: Dilation.
        """
        out_size = math.ceil(i / s)
        needed = (out_size - 1) * s + (k - 1) * d + 1 - i
        return max(needed, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad ``x`` as required, then convolve."""
        in_h, in_w = x.shape[-2:]
        pad_h = self.calc_same_pad(
            i=in_h, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]
        )
        pad_w = self.calc_same_pad(
            i=in_w, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]
        )

        if pad_h > 0 or pad_w > 0:
            # With an odd total, the extra pixel goes on the right/bottom.
            left, top = pad_w // 2, pad_h // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])

        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
class ResNet(nn.Module):
    """ResNet backbone with a 512-unit feature layer before the classifier.

    Args:
        ResBlock: Block class; must expose an ``expansion`` attribute.
        layer_list: Number of blocks in each of the four stages.
        num_classes: Size of the final output.
        num_channels: Number of input image channels.
    """

    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super(ResNet, self).__init__()
        self.in_channels = 64

        # Stem
        self.conv_layer_s2_same = Conv2dSame(
            num_channels, 64, 7, stride=2, groups=1, bias=False
        )
        self.batch_norm1 = nn.BatchNorm2d(64, eps=0.001, momentum=0.99)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)

        # Four residual stages
        self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64, stride=1)
        self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
        self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
        self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)

        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * ResBlock.expansion, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)

    def extract_features_four(self, x):
        """Return the feature map produced by the stem and the four stages."""
        x = self.max_pool(self.relu(self.batch_norm1(self.conv_layer_s2_same(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x

    def extract_features(self, x):
        """Return the ``fc1`` output (before ``relu1``) for a batch ``x``."""
        pooled = self.avgpool(self.extract_features_four(x))
        flat = pooled.reshape(pooled.shape[0], -1)
        return self.fc1(flat)

    def forward(self, x):
        """Return the ``fc2`` output for a batch ``x``."""
        return self.fc2(self.relu1(self.extract_features(x)))

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        """Build one stage of ``blocks`` blocks and update ``self.in_channels``."""
        out_channels = planes * ResBlock.expansion

        # The first block needs a projection shortcut whenever it changes the
        # spatial size or the channel count.
        shortcut = None
        if stride != 1 or self.in_channels != out_channels:
            shortcut = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                    padding=0,
                ),
                nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99),
            )

        stage = [
            ResBlock(self.in_channels, planes, i_downsample=shortcut, stride=stride)
        ]
        self.in_channels = out_channels
        stage.extend(ResBlock(self.in_channels, planes) for _ in range(blocks - 1))
        return nn.Sequential(*stage)
def ResNet50(num_classes, channels=3):
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths 3, 4, 6, 3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(
        ResBlock=Bottleneck,
        layer_list=stage_depths,
        num_classes=num_classes,
        num_channels=channels,
    )
class Vmodel(nn.Module):
    """Video model: transformer layers, temporal mean pooling and two heads.

    Args:
        input_size: Feature size of every frame vector.
        activation: Activation module used between the linear layers.
        feature_size: Width of the shared and per-task hidden layers.
        num_heads: Attention heads per transformer layer.
        num_layers: Number of transformer layers.
        positional_encoding: Forwarded to every ``TransformerLayer``.
        n_emo: Size of the emotion output.
        n_sent: Size of the sentiment output.
    """

    def __init__(
        self,
        input_size=512,
        activation=nn.SELU(),
        feature_size=64,
        num_heads=1,
        num_layers=1,
        positional_encoding=False,
        n_emo=7,
        n_sent=3,
    ):
        super(Vmodel, self).__init__()
        self.feature_video_dynamic = nn.ModuleList(
            TransformerLayer(
                input_dim=input_size,
                num_heads=num_heads,
                positional_encoding=positional_encoding,
            )
            for _ in range(num_layers)
        )
        self.fcl = nn.Linear(input_size, feature_size)
        self.activation = activation
        self.feature_emo = nn.Linear(feature_size, feature_size)
        self.feature_sent = nn.Linear(feature_size, feature_size)
        self.fc_emo = nn.Linear(feature_size, n_emo)
        self.fc_sent = nn.Linear(feature_size, n_sent)

    def forward(self, x, with_features=False):
        """Classify a batch of frame-feature sequences.

        Returns:
            A dict with keys ``"emo"`` and ``"sen"``; when ``with_features``
            is true, a tuple of that dict and the transformer output.
        """
        # Each layer attends over its own input (key = value = query).
        for layer in self.feature_video_dynamic:
            x = layer(x, x, x)

        pooled = self.activation(x.mean(dim=1))
        shared = self.activation(self.fcl(pooled))
        outputs = {
            "emo": self.fc_emo(self.activation(self.feature_emo(shared))),
            "sen": self.fc_sent(self.activation(self.feature_sent(shared))),
        }
        return (outputs, x) if with_features else outputs
class VideoModelLoader:
    """Loads the three models used for video analysis.

    Args:
        face_checkpoint_url: Checkpoint of the YOLO face model.
        emotion_checkpoint_url: Checkpoint of the static ResNet50 emotion model.
        emo_sent_checkpoint_url: Checkpoint of the dynamic ``Vmodel``.
        folder_path: Passed to ``load_model`` together with each URL.
        device: Device the emotion models are moved to.
    """

    def __init__(
        self,
        face_checkpoint_url: str,
        emotion_checkpoint_url: str,
        emo_sent_checkpoint_url: str,
        folder_path: str,
        device: torch.device,
    ) -> None:
        self.device = device

        face_model_path = load_model(face_checkpoint_url, folder_path)
        emotion_video_model_path = load_model(emotion_checkpoint_url, folder_path)
        emo_sent_video_model_path = load_model(emo_sent_checkpoint_url, folder_path)

        # YOLO face model
        self.face_model = YOLO(face_model_path)

        # NOTE(review): depending on the installed torch version, torch.load
        # may unpickle arbitrary objects - only use trusted checkpoint URLs.

        # Static emotion model
        self.emo_affectnet_model = ResNet50(num_classes=7, channels=3)
        self.emo_affectnet_model.load_state_dict(
            torch.load(emotion_video_model_path, map_location=self.device)
        )
        self.emo_affectnet_model.to(self.device).eval()

        # Visual emotion and sentiment recognition model (dynamic model)
        self.emo_sent_video_model = Vmodel()
        self.emo_sent_video_model.load_state_dict(
            torch.load(emo_sent_video_model_path, map_location=self.device)
        )
        self.emo_sent_video_model.to(self.device).eval()

    def extract_zeros_features(self):
        """Return the static model's features for an all-zero 3x224x224 image.

        Returns:
            NumPy array holding the feature vector of the single-image batch.
        """
        zeros = torch.unsqueeze(torch.zeros((3, 224, 224)), 0).to(self.device)
        # Pure inference: do not record an autograd graph for this pass.
        with torch.no_grad():
            zeros_features = self.emo_affectnet_model.extract_features(zeros)
        return zeros_features.cpu().detach().numpy()[0]
class VideoFeatureExtractor:
    """Extracts per-frame facial features from a video and classifies windows.

    Args:
        model_loader: Provides the face, static-emotion and dynamic models.
        file_path: Path of the video file.
        target_fps: Approximate number of frames per second to analyse.
        with_features: Whether ``__call__`` also returns the model features.
    """

    def __init__(
        self,
        model_loader: VideoModelLoader,
        file_path: str,
        target_fps: int = 5,
        with_features: bool = True,
    ) -> None:
        self.model_loader = model_loader
        self.with_features = with_features

        # Video options (all truncated to int)
        self.cap = cv2.VideoCapture(file_path)
        self.w, self.h, self.fps, self.frame_number = (
            int(self.cap.get(x))
            for x in (
                cv2.CAP_PROP_FRAME_WIDTH,
                cv2.CAP_PROP_FRAME_HEIGHT,
                cv2.CAP_PROP_FPS,
                cv2.CAP_PROP_FRAME_COUNT,
            )
        )
        self.dur = self.frame_number / self.fps
        self.target_fps = target_fps
        # Never let the interval drop to 0 (fps below target_fps); it is used
        # as a modulus in preprocess_video().
        self.frame_interval = max(1, int(self.fps / target_fps))

        # Features used when no face is found in a frame
        self.zeros_features = self.model_loader.extract_zeros_features()

        # Dictionaries with facial features and faces
        self.facial_features = {}
        self.faces = {}

    def _fallback_features(self, counter: int) -> None:
        """Store features for a frame in which no usable face was found."""
        # NOTE(review): only the immediately preceding frame index is checked;
        # with frame_interval > 1 that index is never stored, so the zero
        # features are used. Confirm whether "previous sampled frame" was meant.
        if counter - 1 in self.facial_features:
            self.facial_features[counter] = self.facial_features[counter - 1]
        else:
            self.facial_features[counter] = self.zeros_features

    def preprocess_frame(self, frame: np.ndarray, counter: int) -> None:
        """Detect faces in ``frame`` and store their averaged features.

        Args:
            frame: BGR image as read by OpenCV.
            counter: Index of the frame; used as key in ``facial_features``
                and ``faces``.
        """
        curr_fr = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.model_loader.face_model.track(
            curr_fr,
            persist=True,
            imgsz=640,
            conf=0.01,
            iou=0.5,
            augment=False,
            device=self.model_loader.device,
            verbose=False,
        )
        boxes = results[0].boxes

        if boxes.xyxy.cpu().tolist() == []:
            self._fallback_features(counter)
            return

        need_features = np.zeros(512)
        count_face = 0
        for i in boxes:
            # Untracked detections carry no id and are grouped under -1.
            idx_box = i.id.int().cpu().tolist()[0] if i.id is not None else -1
            box = i.xyxy.int().cpu().tolist()[0]
            startX, startY = max(0, box[0]), max(0, box[1])
            endX, endY = min(self.w - 1, box[2]), min(self.h - 1, box[3])
            face_region = curr_fr[startY:endY, startX:endX]
            # A degenerate box gives an empty crop that cannot be processed.
            if face_region.size == 0:
                continue

            norm_face_region = pth_processing(Image.fromarray(face_region))
            with torch.no_grad():
                curr_features = self.model_loader.emo_affectnet_model.extract_features(
                    norm_face_region.to(self.model_loader.device)
                )
            need_features += curr_features.cpu().detach().numpy()[0]
            count_face += 1
            self.faces.setdefault(idx_box, {})[counter] = face_region

        if count_face == 0:
            self._fallback_features(counter)
            return

        # Average over all faces found in the frame.
        self.facial_features[counter] = need_features / count_face

    def preprocess_video(self) -> None:
        """Read the whole video and process every ``frame_interval``-th frame.

        The capture is released once reading has finished.
        """
        counter = 0
        try:
            while True:
                ret, frame = self.cap.read()
                if not ret:
                    break
                if counter % self.frame_interval == 0:
                    self.preprocess_frame(frame, counter)
                counter += 1
        finally:
            self.cap.release()

    def __call__(
        self, window: dict, win_max_length: int, sr: int = 16000
    ) -> tuple[dict[torch.Tensor], torch.Tensor]:
        """Predict emotion and sentiment for the frames inside ``window``.

        Args:
            window: Window description forwarded to
                ``get_idx_frames_in_windows``.
            win_max_length: Together with ``target_fps`` gives the minimum
                sequence length, ``target_fps * win_max_length``.
            sr: Forwarded to ``get_idx_frames_in_windows``.

        Returns:
            Tuple of the softmax predictions (keys ``"emo"`` and ``"sen"``)
            and the model features, or ``None`` instead of the features when
            ``with_features`` is false.
        """
        curr_idx_frames = get_idx_frames_in_windows(
            list(self.facial_features.keys()), window, self.fps, sr
        )
        video_features = np.array(list(self.facial_features.values()))
        curr_features = video_features[curr_idx_frames, :]

        # Pad short windows by repeating the last frame's features.
        missing = self.target_fps * win_max_length - len(curr_features)
        if missing > 0:
            curr_features = np.concatenate(
                [curr_features, np.repeat(curr_features[-1:], missing, axis=0)],
                axis=0,
            )

        curr_features = (
            torch.FloatTensor(curr_features).unsqueeze(0).to(self.model_loader.device)
        )

        with torch.no_grad():
            outputs = self.model_loader.emo_sent_video_model(
                curr_features, with_features=self.with_features
            )

        features = None
        if self.with_features:
            preds, raw_features = outputs
            features = raw_features.detach().cpu().squeeze()
        else:
            preds = outputs

        predicts = {
            "emo": F.softmax(preds["emo"], dim=-1).detach().cpu().squeeze(),
            "sen": F.softmax(preds["sen"], dim=-1).detach().cpu().squeeze(),
        }
        return predicts, features