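# NOTE: the lines below are residue from the hosting site's file viewer, not
# part of the module. They are kept as comments so the file parses:
#   MKFMIKU's picture | A new start | 596242b | raw | history blame | 11.4 kB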
import os
import random
from typing import List
import av
import av.logging
import PIL
import numpy as np
from torch.utils.data import Dataset
from einops import rearrange
# from cache_decorator import Cache
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import cv2
class VideoFolder(Dataset):
    """Dataset over videos stored as video files or as folders of image frames.

    ``path`` is walked recursively. Every file whose extension is listed in
    ``VIDEO_EXTENSIONS`` is one sample, and every directory that directly
    contains at least one file listed in ``IMG_EXTENSIONS`` is one sample
    (its images are the frames).

    Args:
        path: Root directory to scan.
        size: Target frame size. An int or a 1-element list/tuple is expanded
            to a square ``[s, s]``; a 2-element list/tuple is used as given.
            The value is handed to ``PIL.Image.resize`` / ``cv2.resize``.
        nframes: Number of frames each sample is subsampled to.

    Raises:
        ValueError: If ``size`` is a list/tuple with a length other than 1 or 2.
    """

    # ``_file_ext`` lower-cases the extension before comparing, so the
    # upper-case entries are redundant; they are kept so these public
    # constants stay unchanged.
    IMG_EXTENSIONS = [
        ".png",
        ".PNG",
        ".jpg",
        ".JPG"
    ]
    VIDEO_EXTENSIONS = [".mp4", ".MP4", ".avi", ".AVI"]

    def __init__(
        self,
        path: str,
        size: List[int],
        nframes: int = 128,
    ):
        if isinstance(size, (list, tuple)):
            if len(size) not in [1, 2]:
                raise ValueError(
                    f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
                )
            if len(size) == 1:
                # A 1-element size is accepted by the check above, so expand
                # it to a square just like a bare int.
                size = [size[0], size[0]]
        if isinstance(size, int):
            size = [size, size]

        self.path = path
        self.size = size
        self.nframes = nframes
        self._video_fnames = self._find_all_path(path)
        self._total_size = len(self._video_fnames)

    @staticmethod
    def _file_ext(fname):
        """Return the lower-cased extension of ``fname``, including the dot."""
        return os.path.splitext(fname)[1].lower()

    @staticmethod
    def _list_files(root):
        """Return the set of all file paths under ``root``, relative to it."""
        return {
            os.path.relpath(os.path.join(dirpath, fname), start=root)
            for dirpath, _dirs, files in os.walk(root)
            for fname in files
        }

    def _find_all_path(self, _path):
        """Return the sorted sample names (video files and frame folders) under ``_path``."""
        all_fnames = self._list_files(_path)
        video_files = [
            fname
            for fname in all_fnames
            if self._file_ext(fname) in self.VIDEO_EXTENSIONS
        ]
        # Images sitting directly in ``_path`` yield the entry "" (the root
        # itself is then treated as one frame folder).
        frame_dirs = {
            os.path.dirname(fname)
            for fname in all_fnames
            if self._file_ext(fname) in self.IMG_EXTENSIONS
        }
        return sorted([*video_files, *frame_dirs])

    def _load_image_folder(self, folder, size, resample, sort_key=None):
        """Load every image under ``folder`` as a resized RGB array.

        Args:
            folder: Directory holding the frames (searched recursively).
            size: Size passed to ``PIL.Image.resize``.
            resample: PIL resampling filter id
                (PIL.Image.Resampling.LANCZOS = 1, BICUBIC = 3).
            sort_key: Optional key for ordering the relative file names;
                plain string order when ``None``.

        Returns:
            List of arrays, one per image, in sorted order.
        """
        names = sorted(
            (
                fname
                for fname in self._list_files(folder)
                if self._file_ext(fname) in self.IMG_EXTENSIONS
            ),
            key=sort_key,
        )
        target = tuple(size)
        frames = []
        for fname in names:
            with open(os.path.join(folder, fname), "rb") as f:
                frames.append(
                    np.array(
                        PIL.Image.open(f)
                        .convert("RGB")
                        .resize(target, resample=resample)
                    )
                )
        return frames

    @staticmethod
    def _decode_with_av(video_path, nframes, convert):
        """Decode ``video_path`` with PyAV, keeping about ``nframes`` evenly spaced frames.

        ``convert`` maps each kept ``av.VideoFrame`` to an array.
        """
        frames = []
        # The ``with`` block closes the container; no explicit close() needed.
        with av.open(video_path) as container:
            container.streams.video[0].thread_type = "AUTO"
            container.streams.video[0].thread_count = 2
            total_frames = container.streams.video[0].frames
            if total_frames > 0:
                frame_scale = total_frames / nframes
                # A set gives O(1) membership tests in the decode loop.
                wanted = {int(i * frame_scale) for i in range(nframes)}
            else:
                # The container did not report a frame count: keep every
                # frame and let the caller subsample afterwards.
                wanted = None
            for idx, frame in enumerate(container.decode(video=0)):
                if wanted is None or idx in wanted:
                    frames.append(convert(frame))
        return frames

    @staticmethod
    def _subsample_frames(video, nframes):
        """Pick evenly spaced entries of ``video``, at most ``nframes`` of them.

        When ``video`` has fewer than ``nframes`` entries, all are returned
        (frames are never repeated).
        """
        frame_scale = len(video) / nframes
        keep = sorted({int(i * frame_scale) for i in range(nframes)})
        return [video[i] for i in keep]

    def _stack_frames(self, frames, nframes, video_path):
        """Subsample ``frames`` and stack them into one uint8 array.

        Raises:
            ValueError: If ``frames`` is empty.
        """
        if not frames:
            raise ValueError(f"no frames could be read from {video_path}")
        frames = self._subsample_frames(frames, nframes)
        # uint8 output (the original code marked this step "for cache").
        return np.stack(frames).astype(np.uint8)

    def _read_video_opencv(self, video_path, nframes, size):
        """Read a sample with OpenCV (video file) or PIL bicubic (frame folder).

        Returns:
            uint8 array with one entry per kept frame, channels in RGB order.

        Raises:
            ValueError: If no frame could be read.
        """
        if os.path.isdir(video_path):
            frames = self._load_image_folder(video_path, size, resample=3)
        else:
            frames = []
            target = tuple(size)
            cap = cv2.VideoCapture(video_path)
            try:
                while cap.isOpened():
                    success, image = cap.read()
                    if not success:
                        break
                    frames.append(
                        np.asarray(
                            # Reverse the channel axis: OpenCV decodes BGR.
                            cv2.resize(image, target, interpolation=cv2.INTER_CUBIC)[
                                :, :, ::-1
                            ]
                        )
                    )
            finally:
                # Release the capture even if resizing raises.
                cap.release()
        return self._stack_frames(frames, nframes, video_path)

    def _read_video(self, video_path, nframes, size):
        """Read a sample with PyAV (video file) or PIL (frame folder), Lanczos resize.

        Frame-folder file names must be integers followed by an extension
        (e.g. ``12.png``), because they are sorted numerically.

        Returns:
            uint8 array with one ``(H, W, 3)`` entry per kept frame.

        Raises:
            ValueError: If no frame could be read, or a frame file name is
                not an integer.
        """
        if os.path.isdir(video_path):
            frames = self._load_image_folder(
                video_path,
                size,
                resample=1,
                sort_key=lambda fname: int(os.path.splitext(fname)[0]),
            )
        else:
            target = tuple(size)
            frames = self._decode_with_av(
                video_path,
                nframes,
                lambda frame: np.asarray(frame.to_image().resize(target, resample=1)),
            )
        return self._stack_frames(frames, nframes, video_path)

    def _read_video_metric(self, video_path, nframes, size):
        """Read a sample for metric computation.

        Frame folders are resized with PIL (Lanczos); video files are decoded
        with PyAV and resized with bilinear ``F.interpolate``.

        NOTE(review): the two branches do not produce the same layout. The
        folder branch yields ``(H, W, 3)`` frames, while the video branch
        keeps the ``(C, H, W)`` layout of ``pil_to_tensor`` and resizes to a
        square of side ``size[0]``. Confirm which one callers expect.

        Raises:
            ValueError: If no frame could be read.
        """
        if os.path.isdir(video_path):
            frames = self._load_image_folder(video_path, size, resample=1)
        else:
            frames = self._decode_with_av(
                video_path,
                nframes,
                lambda frame: F.interpolate(
                    TF.pil_to_tensor(frame.to_image()).unsqueeze(0),
                    size=size[0],
                    mode='bilinear',
                    align_corners=False,
                )[0].numpy().clip(0, 255),
            )
        return self._stack_frames(frames, nframes, video_path)

    def __getitem__(self, index):
        """Return sample ``index`` as a float32 array scaled to ``[-1, 1]``.

        If a sample fails to load, the error is printed and the next sample
        (wrapping around at the end) is returned instead.

        Raises:
            IndexError: If ``index`` is out of range.
            RuntimeError: If no sample at all could be loaded.
        """
        total = len(self._video_fnames)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f"index {index} is out of range for {total} videos")

        # Iterating instead of recursing means the fallback sample is
        # normalised exactly once and the search always terminates.
        for offset in range(total):
            video_path = os.path.join(
                self.path, self._video_fnames[(index + offset) % total]
            )
            try:
                video = self._read_video_metric(
                    video_path=video_path, nframes=self.nframes, size=self.size
                )
            except Exception as e:
                print("=> error with loading video", video_path, e)
                continue
            if video.shape[0] != self.nframes:
                print("=> unconsisitent video frames", video_path, video.shape[0], "v.s.", self.nframes)
            video = video.astype(np.float32)
            return (video - 127.5) / 127.5
        raise RuntimeError(f"none of the {total} videos under {self.path} could be loaded")

    def __len__(self):
        """Return the number of samples."""
        return self._total_size
class Dataset(VideoFolder):
    """Video dataset whose samples also carry per-frame action ids.

    Every entry of ``data_root`` (as listed by ``os.listdir``) is one sample.
    Its frames are read with ``_read_video`` and its actions from an
    ``actions.txt`` file inside the entry, one integer per line.

    NOTE: this class shadows the ``Dataset`` name imported from
    ``torch.utils.data`` for the rest of the module.

    Args:
        data_root: Directory whose entries are the samples.
        resolution: Frame size, forwarded to ``VideoFolder`` as ``size``.
        video_length: Number of frames per sample.
        latent_scale: Divisor applied to the frame height and width to get
            the spatial extent of the returned grid.
        actions: Number of distinct action ids (size of the grid's action axis).
        max_videos: Keep only the first ``max_videos`` entries after sorting;
            ``None`` keeps all of them.
    """

    def __init__(
        self,
        data_root: str,
        resolution: List[int],
        video_length: int = 128,
        latent_scale = 8,
        actions = 7,
        max_videos = 100_000,
    ):
        super().__init__(data_root, size=resolution, nframes=video_length)
        self.data_root = data_root
        self.actions = actions
        self.latent_scale = latent_scale

        # Replace the sample list built by the parent's recursive scan with
        # the top-level entries of ``data_root``.
        videos = sorted(os.listdir(data_root))
        if max_videos is not None:
            videos = videos[:max_videos]
        self._video_fnames = videos
        self._total_size = len(self._video_fnames)

    def __getitem__(self, index):
        """Return ``(video, actions, grid)`` for sample ``index`` (wrapped modulo the length).

        Returns:
            video: float32 array in ``[-1, 1]``, laid out ``(C, T, H, W)``.
            actions: list of ints read from ``actions.txt``.
            grid: float32 array ``(4, T, H // latent_scale, W // latent_scale)``
                whose channels hold the frame index, row index, column index
                and that frame's action id.

        Raises:
            ValueError: If the sample does not yield exactly ``nframes`` frames.
        """
        index = index % self._total_size
        video_path = os.path.join(self.data_root, self._video_fnames[index])
        video = self._read_video(
            video_path=video_path, nframes=self.nframes, size=self.size
        )
        # Check the frame count before doing any further work on the array.
        if video.shape[0] != self.nframes:
            raise ValueError(
                f"{video_path} has less than {self.nframes} frames only have {video.shape[0]}"
            )
        video = video.astype(np.float32)
        video = (video - 127.5) / 127.5

        action_ids = []
        with open(os.path.join(video_path, "actions.txt"), "r") as f:
            for line in f:
                # Skip blank lines (e.g. a trailing newline at end of file).
                if line.strip():
                    action_ids.append(int(line))
        # SIMPLE_MOVEMENT = [
        #     ['NOOP'],
        #     ['right'],
        #     ['right', 'A'],
        #     ['right', 'B'],
        #     ['right', 'A', 'B'],
        #     ['A'],
        #     ['left'],
        # ]

        video = rearrange(video, 'T H W C -> C T H W')

        grid_size = [
            self.nframes,
            video.shape[2] // self.latent_scale,
            video.shape[3] // self.latent_scale,
            self.actions,
        ]
        axes = [np.arange(extent, dtype=np.float32) for extent in grid_size]
        # indexing='ij' keeps the axis order (t, h, w, action).
        grid = np.stack(np.meshgrid(*axes, indexing='ij'), axis=0)
        # For each frame t, select the slice at that frame's action id. The
        # two index arrays are separated by slices, so NumPy puts the
        # selected (frame) axis first: the result is (T, 4, H, W).
        grid = rearrange(
            grid[:, np.arange(grid_size[0]), :, :, action_ids],
            "T N H W -> N T H W",
        )
        return (video, action_ids, grid)

    def __len__(self):
        """Return the number of samples."""
        return self._total_size