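# T-MoENet / VideoLoader.py
# Source: Hugging Face repository file view. Uploaded by yixin1121
# ("Upload folder using huggingface_hub"), commit 513e1fb (verified), 4.3 kB.
# The page links "raw", "history" and "blame" were captured with the file.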
import torch as th
import os
import numpy as np
import ffmpeg
class Normalize(object):
    """Per-channel standardisation: ``(x - mean) / (std + 1e-8)``.

    ``mean`` and ``std`` must each hold exactly 3 values. They are reshaped to
    ``(1, 3, 1, 1)``, so they broadcast along dim 1 of a 4-D tensor such as
    ``(T, 3, H, W)``. The small epsilon guards against a zero ``std``.
    """

    def __init__(self, mean, std):
        broadcast_shape = (1, 3, 1, 1)
        self.mean = th.FloatTensor(mean).view(*broadcast_shape)
        self.std = th.FloatTensor(std).view(*broadcast_shape)

    def __call__(self, tensor):
        centered = tensor - self.mean
        safe_std = self.std + 1e-8
        return centered / safe_std
class Preprocessing(object):
    """Map raw 0-255 RGB frames to CLIP-normalised values.

    Pixels are divided by 255 to land in [0, 1], then standardised per
    channel with the mean/std constants used by OpenAI CLIP's image
    preprocessing.
    """

    def __init__(self):
        clip_mean = [0.48145466, 0.4578275, 0.40821073]
        clip_std = [0.26862954, 0.26130258, 0.27577711]
        self.norm = Normalize(mean=clip_mean, std=clip_std)

    def __call__(self, tensor):
        unit_range = tensor / 255.0
        return self.norm(unit_range)
class VideoLoader:
    """Pytorch video loader.

    Decodes a video file with ffmpeg at ``framerate`` frames per second and
    resizes it. With an int ``size``, the shorter side becomes ``size`` and
    the aspect ratio is kept. With a 2-tuple ``size``, frames are scaled to
    exactly ``(height, width) == size``. For an int ``size`` the frames are
    optionally center-cropped to ``size x size``. ``__call__`` then subsamples
    or zero-pads the clip to exactly ``max_feats`` frames and applies
    ``Preprocessing``.
    """

    def __init__(
        self,
        framerate=1,
        size=224,
        centercrop=True,
    ):
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate
        self.preprocess = Preprocessing()
        # Number of frames every call to __call__ returns.
        self.max_feats = 10
        # NOTE(review): not used by this class (padding frames take the clip's
        # own (3, H, W) shape); kept in case external code reads it - confirm.
        self.features_dim = 768

    def _get_video_dim(self, video_path):
        """Probe ``video_path`` and return ``(height, width, frame_rate)``.

        Uses the first stream whose ``codec_type`` is ``"video"``.

        Raises:
            ffmpeg.Error: if ffprobe itself fails.
            ValueError: if there is no video stream, or the stream's
                ``avg_frame_rate`` has a zero denominator (e.g. ``"0/0"``).
        """
        probe = ffmpeg.probe(video_path)
        video_stream = next(
            (stream for stream in probe["streams"] if stream["codec_type"] == "video"),
            None,
        )
        if video_stream is None:
            raise ValueError("No video stream in: {}".format(video_path))
        width = int(video_stream["width"])
        height = int(video_stream["height"])
        num, denum = video_stream["avg_frame_rate"].split("/")
        if int(denum) == 0:
            raise ValueError(
                "Undefined frame rate {!r} in: {}".format(
                    video_stream["avg_frame_rate"], video_path
                )
            )
        frame_rate = int(num) / int(denum)
        return height, width, frame_rate

    def _get_output_dim(self, h, w):
        """Return the ``(height, width)`` ffmpeg should scale an ``h x w`` frame to.

        A 2-tuple ``size`` is returned verbatim as ``(height, width)``. An
        int ``size`` becomes the length of the shorter side. The longer side
        is scaled to keep the aspect ratio, truncated to an int.
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        elif h >= w:
            return int(h * self.size / w), self.size
        else:
            return self.size, int(w * self.size / h)

    def _getvideo(self, video_path):
        """Decode ``video_path`` into a float tensor of RGB frames.

        Returns:
            dict with ``"input"`` set to ``video_path`` and ``"video"`` set to
            either:
            - a ``(T, 3, H, W)`` float32 tensor of pixel values in [0, 255], or
            - the failure sentinel ``th.zeros(1)``, when the file does not
              exist, cannot be probed, reports a frame rate below 1, or
              ffmpeg fails.
        """
        if not os.path.isfile(video_path):
            return {"video": th.zeros(1), "input": video_path}
        print("Decoding video: {}".format(video_path))
        try:
            h, w, fr = self._get_video_dim(video_path)
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit through
            print("ffprobe failed at: {}".format(video_path))
            return {"video": th.zeros(1), "input": video_path}
        if fr < 1:
            print("Corrupted Frame Rate: {}".format(video_path))
            return {"video": th.zeros(1), "input": video_path}
        height, width = self._get_output_dim(h, w)
        # Crop only for an int size. A tuple size is already scaled to the
        # exact output shape, and ``width - self.size`` would be a TypeError.
        crop = self.centercrop and isinstance(self.size, int)
        try:
            cmd = (
                ffmpeg.input(video_path)
                .filter("fps", fps=self.framerate)
                .filter("scale", width, height)
            )
            if crop:
                x = int((width - self.size) / 2.0)
                y = int((height - self.size) / 2.0)
                cmd = cmd.crop(x, y, self.size, self.size)
            out, _ = cmd.output("pipe:", format="rawvideo", pix_fmt="rgb24").run(
                capture_stdout=True, quiet=True
            )
        except Exception:
            print("ffmpeg error at: {}".format(video_path))
            return {"video": th.zeros(1), "input": video_path}
        if crop:
            height, width = self.size, self.size
        video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
        video = th.from_numpy(video.astype("float32"))
        video = video.permute(0, 3, 1, 2)  # t,c,h,w
        return {"video": video, "input": video_path}

    def __call__(self, video_path):
        """Load ``video_path`` as exactly ``self.max_feats`` preprocessed frames.

        Returns:
            ``(video, video_len)``:
            - ``video`` is a ``(max_feats, 3, H, W)`` tensor passed through
              ``self.preprocess``.
            - ``video_len`` is the number of real, non-padding frames (at most
              ``max_feats``). Longer clips are uniformly subsampled. Shorter
              clips are zero-padded at the end before preprocessing.
            If decoding failed, every frame is padding and ``video_len == 0``.
            Those padding frames are ``size x size``, or ``size`` itself for
            a tuple size.
        """
        video = self._getvideo(video_path)["video"]
        if video.dim() != 4:
            # _getvideo signals failure with th.zeros(1). Substitute an empty
            # clip so the padding below yields a well-shaped all-zero input.
            if isinstance(self.size, tuple):
                height, width = self.size
            else:
                height = width = self.size
            video = th.zeros(0, 3, height, width)
        num_frames = len(video)
        # Assigned unconditionally so the num_frames == max_feats case works.
        video_len = min(num_frames, self.max_feats)
        if num_frames > self.max_feats:
            # Uniform subsampling: frame j comes from floor(j * T / max_feats).
            idx = [(j * num_frames) // self.max_feats for j in range(self.max_feats)]
            video = video[idx]
        elif num_frames < self.max_feats:
            # Pad with zero frames that match the clip's (3, H, W) shape.
            padding = video.new_zeros(
                (self.max_feats - num_frames,) + tuple(video.shape[1:])
            )
            video = th.cat([video, padding], 0)
        video = self.preprocess(video)
        return video, video_len