# T-MoENet / Infer.py
# yixin1121's picture
# Upload folder using huggingface_hub
# 85af14c verified
# raw
# history blame
# 5.16 kB
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from tqdm import tqdm
import argparse
from collections import OrderedDict
import json
from collections import defaultdict
from model.deberta_moe import DebertaV2ForMaskedLM
from transformers import DebertaV2Tokenizer
import clip
import ffmpeg
from VideoLoader import VideoLoader
def get_mask(lengths, max_length):
    """Build integer padding masks from per-sample lengths.

    For a 1-D ``lengths`` tensor, row ``i`` of the result holds ``1`` at every
    position ``j < lengths[i]`` and ``0`` elsewhere, for ``j`` in
    ``range(max_length)``.
    """
    # Column vector of positions; the comparison broadcasts it against `lengths`.
    positions = torch.arange(max_length).unsqueeze(1)
    is_valid = (positions < lengths).transpose(0, 1)
    # Multiplying by 1 turns the boolean mask into an integer mask.
    return 1 * is_valid
class Infer:
    """Answer multiple-choice video questions with a masked-LM checkpoint.

    Each candidate answer is turned into an "Is it '<candidate>'?" prompt with
    a mask token; the candidate whose prompt gets the highest probability for
    the first answer token ("Yes") is returned by :meth:`generate`.
    """

    def __init__(self, device):
        """Load the checkpoint, CLIP encoder, tokenizer and language model.

        Args:
            device: Device passed to ``clip.load`` and used for the language
                model and for all tensors fed to it.
        """
        # NOTE: torch.load unpickles the file (the checkpoint stores an `args`
        # object next to the weights) -- only load checkpoints you trust.
        pretrained_ckpt = torch.load("ckpts/model.pth", map_location="cpu")
        args = pretrained_ckpt["args"]
        # Inference-time overrides of the stored training arguments.
        args.n_ans = 2
        args.max_tokens = 256
        self.args = args

        self.clip_model = clip.load("ViT-L/14", device=device)[0]
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(
            "ckpts/deberta-v2-xlarge", local_files_only=True
        )
        self.model = DebertaV2ForMaskedLM.from_pretrained(
            features_dim=args.features_dim if args.use_video else 0,
            max_feats=args.max_feats,
            freeze_lm=args.freeze_lm,
            freeze_mlm=args.freeze_mlm,
            ft_ln=args.ft_ln,
            ds_factor_attn=args.ds_factor_attn,
            ds_factor_ff=args.ds_factor_ff,
            dropout=args.dropout,
            n_ans=args.n_ans,
            freeze_last=args.freeze_last,
            pretrained_model_name_or_path="ckpts/deberta-v2-xlarge",
            local_files_only=False,
            add_video_feat=args.add_video_feat,
            freeze_ad=args.freeze_ad,
        )

        # Load the fine-tuned weights stored under the "model" key. Passing the
        # whole checkpoint dict here would match no parameter name and, with
        # strict=False, silently leave the model at its base weights.
        state_dict = self._strip_module_prefix(pretrained_ckpt["model"])
        self.model.load_state_dict(state_dict, strict=False)

        self.model.eval()
        self.model.to(device)
        self.device = device
        self.video_loader = VideoLoader()
        self.set_answer()

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` without a leading ``module.`` on keys.

        Only a prefix is removed; ``module.`` occurring later in a key is kept.
        """
        prefix = "module."
        cleaned = OrderedDict()
        for key, value in state_dict.items():
            if key.startswith(prefix):
                key = key[len(prefix):]
            cleaned[key] = value
        return cleaned

    @staticmethod
    def _normalize_question(text):
        """Capitalize and strip ``text`` and make sure it ends with ``?``."""
        question = text.capitalize().strip()
        # endswith() also covers the empty string, where [-1] would raise.
        if not question.endswith("?"):
            question = str(question) + "?"
        return question

    def _get_clip_feature(self, video):
        """Encode ``video`` with the CLIP image encoder on ``self.device``.

        The features are returned as produced by ``encode_image`` (no
        normalization is applied).
        """
        feat = self.clip_model.encode_image(video.to(self.device))
        return feat

    def _answer_token_ids(self, word):
        """Tokenize ``word`` into exactly one token id (truncated/padded to 1)."""
        return torch.tensor(
            self.tokenizer(
                word,
                add_special_tokens=False,
                max_length=1,
                truncation=True,
                padding="max_length",
            )["input_ids"],
            dtype=torch.long,
        )

    def set_answer(self):
        """Register "Yes" (index 0) and "No" (index 1) as the answer tokens."""
        a2tok = torch.stack(
            [self._answer_token_ids(word) for word in ("Yes", "No")]
        )
        self.model.set_answer_embeddings(
            a2tok.to(self.model.device), freeze_last=self.args.freeze_last
        )

    def generate(self, text, video_path, candidates=None):
        """Pick the candidate answer the model considers most likely.

        Args:
            text: The question; it is capitalized, stripped and given a
                trailing ``?`` if it lacks one.
            video_path: Passed to ``self.video_loader`` to obtain the frames
                and their length.
            candidates: Non-empty sequence of candidate answers.

        Returns:
            The element of ``candidates`` whose prompt receives the highest
            probability for answer index 0 ("Yes", see :meth:`set_answer`).

        Raises:
            ValueError: If ``candidates`` is ``None`` or empty.
        """
        if candidates is None or len(candidates) == 0:
            raise ValueError(
                "candidates must be a non-empty sequence of answer options"
            )

        question = self._normalize_question(text)

        # Inference only: do not build autograd graphs for CLIP or the LM.
        with torch.no_grad():
            video, video_len = self.video_loader(video_path)
            video = self._get_clip_feature(video).unsqueeze(0).float()
            # NOTE(review): 10 video slots are hard-coded here; presumably this
            # mirrors args.max_feats -- confirm against VideoLoader/the model.
            video_mask = get_mask(video_len, 10)
            # One extra always-visible position is prepended to the video mask.
            video_mask = torch.cat([torch.ones((1, 1)), video_mask], dim=1)
            video = video.to(self.device)
            video_mask = video_mask.to(self.device)

            yes_probs = []
            for candidate in candidates:
                prompt = (
                    f" Question: {question} Is it '{candidate}'? {self.tokenizer.mask_token}. Subtitles: "
                )
                prompt = prompt.strip()
                encoded = self.tokenizer(
                    prompt,
                    add_special_tokens=True,
                    max_length=self.args.max_tokens,
                    padding="longest",
                    truncation=True,
                    return_tensors="pt",
                )
                output = self.model(
                    video=video,
                    video_mask=video_mask,
                    input_ids=encoded["input_ids"].to(self.device),
                    attention_mask=encoded["attention_mask"].to(self.device),
                )
                logits = output["logits"]
                # Skip the leading non-text positions of the output, then keep
                # only the logits at the mask token.
                # NOTE(review): 11 presumably equals the 10 video slots plus
                # the one prepended position -- confirm against the model.
                delay = 11
                logits = logits[:, delay : encoded["input_ids"].size(1) + delay][
                    encoded["input_ids"] == self.tokenizer.mask_token_id
                ]
                # Probability of answer index 0 ("Yes") for this candidate.
                yes_probs.append(logits.softmax(-1)[:, 0])

            scores = torch.stack(yes_probs, 1)
            # Always take the arg-max. Rounding the probability for a single
            # candidate could yield index 1, which is out of range.
            preds = scores.max(1).indices
        return candidates[int(preds)]