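"""Video dataset loaders.

``VideoFolder`` reads clips either from a video file (.mp4/.avi) or from a
directory of image frames, resizes them, and subsamples each clip to a fixed
number of frames. ``Dataset`` specializes it for action-conditioned clips:
each clip folder additionally carries an ``actions.txt`` with one action id
per frame, and ``__getitem__`` also returns a (t, h, w, action) coordinate
grid at latent resolution.
"""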
import os
import random
from typing import List
import av
import av.logging
import PIL
import numpy as np
from torch.utils.data import Dataset
from einops import rearrange
# from cache_decorator import Cache
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import cv2
class VideoFolder(Dataset):
IMG_EXTENSIONS = [
".png",
".PNG",
".jpg",
".JPG"
]
VIDEO_EXTENSIONS = [".mp4", ".MP4", ".avi", ".AVI"]
def __init__(
self,
path: str,
size: List[int],
nframes: int = 128,
):
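        """
        Args:
            path: root directory, scanned recursively for videos.
            size: target frame size; an int or 1-element list is expanded
                to a square size.
            nframes: number of frames each clip is subsampled to.
        """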
        if isinstance(size, int):
            size = [size, size]
        elif isinstance(size, (list, tuple)):
            if len(size) not in (1, 2):
                raise ValueError(
                    f"Size must be an int or a 1- or 2-element tuple/list, not a {len(size)}-element tuple/list"
                )
            if len(size) == 1:
                size = [size[0], size[0]]
        def _find_all_path(_path):
            _all_fnames = {
                os.path.relpath(os.path.join(root, fname), start=_path)
                for root, _dirs, files in os.walk(_path)
                for fname in files
            }
            # A "video" is either a video file or a directory of image frames.
            _video_fnames = [
                fname
                for fname in _all_fnames
                if self._file_ext(fname) in self.VIDEO_EXTENSIONS
            ]
            _video_fnames += {
                os.path.dirname(fname)
                for fname in _all_fnames
                if self._file_ext(fname) in self.IMG_EXTENSIONS
            }
            return sorted(_video_fnames)
_video_fnames = _find_all_path(path)
self.path = path
self.size = size
self.nframes = nframes
self._video_fnames = _video_fnames
self._total_size = len(self._video_fnames)
@staticmethod
def _file_ext(fname):
return os.path.splitext(fname)[1].lower()
def _read_video_opencv(self, video_path, nframes, size):
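        """Read a clip as uint8 RGB frames of shape (nframes, H, W, 3),
        decoding with OpenCV (or PIL when `video_path` is a frame directory)."""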
video = []
if os.path.isdir(video_path):
_all_fnames = {
os.path.relpath(os.path.join(root, fname), start=video_path)
for root, _dirs, files in os.walk(video_path)
for fname in files
}
_video_fnames = sorted(
fname
for fname in _all_fnames
if self._file_ext(fname) in self.IMG_EXTENSIONS
)
            for fname in _video_fnames:
                with open(os.path.join(video_path, fname), "rb") as f:
                    video.append(
                        np.array(
                            PIL.Image.open(f)
                            .convert("RGB")
                            .resize(size, resample=PIL.Image.BICUBIC)
                        )
                    )
        else:
            cap = cv2.VideoCapture(video_path)
            while cap.isOpened():
                success, image = cap.read()
                if not success:
                    break
                # OpenCV decodes BGR; reverse the channel axis to get RGB.
                video.append(
                    np.asarray(
                        cv2.resize(image, size, interpolation=cv2.INTER_CUBIC)[
                            :, :, ::-1
                        ]
                    )
                )
            cap.release()
        if len(video) != nframes:
            frame_scale = len(video) / nframes
            frame_scaled_idxs = [int(i * frame_scale) for i in range(nframes)]
            # Index directly so frames may repeat when the clip is shorter
            # than nframes; filtering by membership would drop duplicates.
            video = [video[i] for i in frame_scaled_idxs]
# for cache
video = np.stack(video).astype(np.uint8)
return video
# @Cache(
# cache_path="/data/kmei1/caches/{_hash}.pkl",
# )
def _read_video(self, video_path, nframes, size):
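        """Read a clip as uint8 RGB frames of shape (nframes, H, W, 3),
        decoding with PyAV (or PIL when `video_path` is a frame directory)."""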
video = []
if os.path.isdir(video_path):
_all_fnames = {
os.path.relpath(os.path.join(root, fname), start=video_path)
for root, _dirs, files in os.walk(video_path)
for fname in files
}
            _video_fnames = sorted(
                [
                    fname
                    for fname in _all_fnames
                    if self._file_ext(fname) in self.IMG_EXTENSIONS
                ],
                # Frames are assumed to be named by frame index, e.g. "17.png".
                key=lambda x: int(os.path.splitext(x)[0]),
            )
            for fname in _video_fnames:
                with open(os.path.join(video_path, fname), "rb") as f:
                    video.append(
                        np.array(
                            PIL.Image.open(f)
                            .convert("RGB")
                            .resize(self.size, resample=PIL.Image.LANCZOS)
                        )
                    )
else:
            with av.open(video_path) as container:
                container.streams.video[0].thread_type = "AUTO"
                container.streams.video[0].thread_count = 2
                total_frames = container.streams.video[0].frames
                frame_scale = total_frames / nframes
                # Target indices of the nframes frames to keep; a set makes
                # the per-frame membership test O(1).
                frame_scaled_idxs = {int(i * frame_scale) for i in range(nframes)}
                for idx, frame in enumerate(container.decode(video=0)):
                    if idx in frame_scaled_idxs:
                        video.append(
                            np.asarray(
                                frame.to_image().resize(
                                    size, resample=PIL.Image.LANCZOS
                                )
                            )
                        )
        if len(video) != nframes:
            frame_scale = len(video) / nframes
            frame_scaled_idxs = [int(i * frame_scale) for i in range(nframes)]
            video = [video[i] for i in frame_scaled_idxs]
        video = np.stack(video).astype(np.uint8)  # for cache
        return video
def _read_video_metric(self, video_path, nframes, size):
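        """Like `_read_video`, but resizes frames to a square of side
        `size[0]` with bilinear `F.interpolate`; this is the reader used by
        `VideoFolder.__getitem__`."""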
video = []
if os.path.isdir(video_path):
_all_fnames = {
os.path.relpath(os.path.join(root, fname), start=video_path)
for root, _dirs, files in os.walk(video_path)
for fname in files
}
            _video_fnames = sorted(
                [
                    fname
                    for fname in _all_fnames
                    if self._file_ext(fname) in self.IMG_EXTENSIONS
                ],
                # Match `_read_video`: frames are assumed to be named by index.
                key=lambda x: int(os.path.splitext(x)[0]),
            )
            for fname in _video_fnames:
                with open(os.path.join(video_path, fname), "rb") as f:
                    video.append(
                        np.array(
                            PIL.Image.open(f)
                            .convert("RGB")
                            .resize(self.size, resample=PIL.Image.LANCZOS)
                        )
                    )
        else:
            with av.open(video_path) as container:
                container.streams.video[0].thread_type = "AUTO"
                container.streams.video[0].thread_count = 2
                total_frames = container.streams.video[0].frames
                frame_scale = total_frames / nframes
                frame_scaled_idxs = {int(i * frame_scale) for i in range(nframes)}
                for idx, frame in enumerate(container.decode(video=0)):
                    if idx in frame_scaled_idxs:
                        # `pil_to_tensor` yields uint8; cast to float because
                        # bilinear interpolation is not implemented for bytes.
                        frame = F.interpolate(
                            TF.pil_to_tensor(frame.to_image()).float().unsqueeze(0),
                            size=size[0],  # square output of side size[0]
                            mode="bilinear",
                            align_corners=False,
                        )[0]
                        # C H W -> H W C, so both branches share one layout.
                        frame = frame.permute(1, 2, 0).numpy().clip(0, 255)
                        video.append(frame)
        if len(video) != nframes:
            frame_scale = len(video) / nframes
            frame_scaled_idxs = [int(i * frame_scale) for i in range(nframes)]
            video = [video[i] for i in frame_scaled_idxs]
        # for cache
        video = np.stack(video).astype(np.uint8)
        return video
    def __getitem__(self, index):
        video_path = os.path.join(self.path, self._video_fnames[index])
        try:
            video = self._read_video_metric(
                video_path=video_path, nframes=self.nframes, size=self.size
            )
        except Exception as e:
            print("=> error with loading video", video_path, e)
            # Fall back to the next clip and return directly, so the
            # recursive result is not normalized a second time below.
            return self.__getitem__((index + 1) % self._total_size)
        if video.shape[0] != self.nframes:
            print("=> inconsistent video frames", video_path, video.shape[0], "v.s.", self.nframes)
        video = video.astype(np.float32)
        video = (video - 127.5) / 127.5
        return video
def __len__(self):
return self._total_size
class Dataset(VideoFolder):
def __init__(
self,
data_root: str,
resolution: List[int],
video_length: int = 128,
        latent_scale: int = 8,
        actions: int = 7,
):
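        """
        Args:
            data_root: directory whose immediate children are clip folders,
                each holding index-named frames plus an `actions.txt`.
            resolution: target frame size, passed to `VideoFolder`.
            video_length: number of frames per clip.
            latent_scale: spatial downsampling factor of the latent grid.
            actions: size of the action vocabulary.
        """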
super().__init__(data_root, size=resolution, nframes=video_length)
self.data_root = data_root
self.actions = actions
videos = os.listdir(data_root)
videos.sort()
videos = videos[:100_000]
self._video_fnames = videos
self._total_size = len(self._video_fnames)
self.latent_scale = latent_scale
def __getitem__(self, index):
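        """Return (video, actions, grid): `video` is (C, T, H, W) float32 in
        [-1, 1], `actions` holds one action id per frame, and `grid` holds
        the (t, h, w, action) coordinate of every latent position under that
        frame's action."""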
index = index % self._total_size
video_path = os.path.join(self.data_root, self._video_fnames[index])
video = self._read_video(
video_path=video_path, nframes=self.nframes, size=self.size
)
video = video.astype(np.float32)
video = (video - 127.5) / 127.5
        if video.shape[0] != self.nframes:
            raise ValueError(
                f"{video_path} has {video.shape[0]} frames, expected {self.nframes}"
            )
        actions = []
        with open(
            os.path.join(self.data_root, self._video_fnames[index], "actions.txt"), "r"
        ) as f:
            for line in f:
                actions.append(int(line.strip()))
        # One action id per frame; as an array it both drives the advanced
        # indexing below and collates cleanly in a DataLoader.
        actions = np.asarray(actions, dtype=np.int64)
# SIMPLE_MOVEMENT = [
# ['NOOP'],
# ['right'],
# ['right', 'A'],
# ['right', 'B'],
# ['right', 'A', 'B'],
# ['A'],
# ['left'],
# ]
video = rearrange(video, 'T H W C -> C T H W')
        # Build a (t, h, w, action) coordinate grid at latent resolution,
        # then select, for each frame t, the slice for that frame's action.
        grid_size = [
            self.nframes,
            video.shape[2] // self.latent_scale,
            video.shape[3] // self.latent_scale,
            self.actions,
        ]
        grid_t = np.arange(grid_size[0], dtype=np.float32)
        grid_h = np.arange(grid_size[1], dtype=np.float32)
        grid_w = np.arange(grid_size[2], dtype=np.float32)
        grid_action = np.arange(grid_size[3], dtype=np.float32)
        grid = np.meshgrid(grid_t, grid_h, grid_w, grid_action, indexing="ij")
        grid = np.stack(grid, axis=0)  # (4, T, H, W, A)
        # Advanced indexing pairs frame t with actions[t]; the paired axis
        # comes out first, so rearrange back to coordinate-major order.
        grid = rearrange(
            grid[:, np.arange(grid_size[0]), :, :, actions], "T N H W -> N T H W"
        )
        return (video, actions, grid)
def __len__(self):
return self._total_size
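# Minimal usage sketch (hypothetical root path; assumes clip folders with
# index-named frames and an `actions.txt` of one action id per frame):
if __name__ == "__main__":
    dataset = Dataset(
        data_root="/path/to/clips",  # hypothetical location
        resolution=[256, 256],
        video_length=128,
    )
    video, actions, grid = dataset[0]
    print(video.shape)    # (3, 128, 256, 256), float32 in [-1, 1]
    print(actions.shape)  # (128,) if actions.txt has one line per frame
    print(grid.shape)     # (4, 128, 32, 32) with latent_scale=8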