import os
from typing import List

import av
import cv2
import numpy as np
import PIL.Image
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import rearrange
from torch.utils.data import Dataset
# from cache_decorator import Cache


class VideoFolder(Dataset):
    """Loads videos either from video files or from directories of image frames."""

    IMG_EXTENSIONS = [".png", ".PNG", ".jpg", ".JPG"]
    VIDEO_EXTENSIONS = [".mp4", ".MP4", ".avi", ".AVI"]

    def __init__(
        self,
        path: str,
        size: List[int],
        nframes: int = 128,
    ):
        if isinstance(size, (list, tuple)):
            if len(size) not in [1, 2]:
                raise ValueError(
                    f"Size must be an int or a 1- or 2-element tuple/list, "
                    f"not a {len(size)}-element tuple/list"
                )
            if len(size) == 1:
                size = [size[0], size[0]]
        if isinstance(size, int):
            size = [size, size]

        def _find_all_path(_path):
            # Collect video files plus directories that contain image frames.
            _all_fnames = {
                os.path.relpath(os.path.join(root, fname), start=_path)
                for root, _dirs, files in os.walk(_path)
                for fname in files
            }
            _video_fnames = [
                fname
                for fname in _all_fnames
                if self._file_ext(fname) in self.VIDEO_EXTENSIONS
            ] + list(
                {
                    os.path.dirname(fname)
                    for fname in _all_fnames
                    if self._file_ext(fname) in self.IMG_EXTENSIONS
                }
            )
            return sorted(_video_fnames)

        self.path = path
        self.size = size
        self.nframes = nframes
        self._video_fnames = _find_all_path(path)
        self._total_size = len(self._video_fnames)

    @staticmethod
    def _file_ext(fname):
        return os.path.splitext(fname)[1].lower()

    def _read_video_opencv(self, video_path, nframes, size):
        video = []
        if os.path.isdir(video_path):
            # Directory of image frames: read and resize each frame with PIL.
            _all_fnames = {
                os.path.relpath(os.path.join(root, fname), start=video_path)
                for root, _dirs, files in os.walk(video_path)
                for fname in files
            }
            _video_fnames = sorted(
                fname
                for fname in _all_fnames
                if self._file_ext(fname) in self.IMG_EXTENSIONS
            )
            for fname in _video_fnames:
                with open(os.path.join(video_path, fname), "rb") as f:
                    video.append(
                        np.array(
                            PIL.Image.open(f)
                            .convert("RGB")
                            .resize(size, resample=3)  # 3 == PIL.Image.Resampling.BICUBIC
                        )
                    )
        else:
            # Video file: decode with OpenCV, convert BGR -> RGB.
            cap = cv2.VideoCapture(video_path)
            while cap.isOpened():
                success, image = cap.read()
                if not success:
                    break
                video.append(
                    np.asarray(
                        cv2.resize(image, size, interpolation=cv2.INTER_CUBIC)[:, :, ::-1]
                    )
                )
            cap.release()
        if len(video) != nframes:
            # Uniformly subsample (or repeat) frames to exactly nframes.
            frame_scale = len(video) / nframes
            video = [video[int(i * frame_scale)] for i in range(nframes)]
        video = np.stack(video).astype(np.uint8)  # uint8 keeps the cache small
        return video

    # @Cache(
    #     cache_path="/data/kmei1/caches/{_hash}.pkl",
    # )
    def _read_video(self, video_path, nframes, size):
        video = []
        if os.path.isdir(video_path):
            _all_fnames = {
                os.path.relpath(os.path.join(root, fname), start=video_path)
                for root, _dirs, files in os.walk(video_path)
                for fname in files
            }
            # Frame files are assumed to have purely numeric stems, e.g. "12.png".
            _video_fnames = sorted(
                [
                    fname
                    for fname in _all_fnames
                    if self._file_ext(fname) in self.IMG_EXTENSIONS
                ],
                key=lambda x: int(os.path.splitext(x)[0]),
            )
            for fname in _video_fnames:
                with open(os.path.join(video_path, fname), "rb") as f:
                    video.append(
                        np.array(
                            PIL.Image.open(f)
                            .convert("RGB")
                            .resize(size, resample=1)  # 1 == PIL.Image.Resampling.LANCZOS
                        )
                    )
        else:
            with av.open(video_path) as container:
                container.streams.video[0].thread_type = "AUTO"
                container.streams.video[0].thread_count = 2
                total_frames = container.streams.video[0].frames
                frame_scale = total_frames / nframes
                # Keep only nframes uniformly spaced frames while decoding.
                frame_scaled_idxs = {int(i * frame_scale) for i in range(nframes)}
                for idx, frame in enumerate(container.decode(video=0)):
                    if idx in frame_scaled_idxs:
                        video.append(
                            np.asarray(
                                frame.to_image().resize(size, resample=1)  # LANCZOS
                            )
                        )
        if len(video) != nframes:
            frame_scale = len(video) / nframes
            video = [video[int(i * frame_scale)] for i in range(nframes)]
        video = np.stack(video).astype(np.uint8)  # uint8 keeps the cache small
        return video
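    # The three readers differ mainly in decoding backend and resize filter:
    # _read_video_opencv decodes with cv2 (bicubic), _read_video decodes with
    # PyAV and resizes with PIL (Lanczos), and _read_video_metric below
    # resizes with torch.nn.functional.interpolate (bilinear); judging by its
    # name, the metric variant is intended for evaluation.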
    def _read_video_metric(self, video_path, nframes, size):
        video = []
        if os.path.isdir(video_path):
            _all_fnames = {
                os.path.relpath(os.path.join(root, fname), start=video_path)
                for root, _dirs, files in os.walk(video_path)
                for fname in files
            }
            _video_fnames = sorted(
                fname
                for fname in _all_fnames
                if self._file_ext(fname) in self.IMG_EXTENSIONS
            )
            for fname in _video_fnames:
                with open(os.path.join(video_path, fname), "rb") as f:
                    video.append(
                        np.array(
                            PIL.Image.open(f)
                            .convert("RGB")
                            .resize(size, resample=1)  # 1 == PIL.Image.Resampling.LANCZOS
                        )
                    )
        else:
            with av.open(video_path) as container:
                container.streams.video[0].thread_type = "AUTO"
                container.streams.video[0].thread_count = 2
                total_frames = container.streams.video[0].frames
                frame_scale = total_frames / nframes
                frame_scaled_idxs = {int(i * frame_scale) for i in range(nframes)}
                for idx, frame in enumerate(container.decode(video=0)):
                    if idx in frame_scaled_idxs:
                        # Bilinear resize to a square of side size[0]; cast to
                        # float for interpolate, then back to HWC layout so
                        # both branches produce the same frame shape.
                        frame = F.interpolate(
                            TF.pil_to_tensor(frame.to_image()).unsqueeze(0).float(),
                            size=size[0],
                            mode="bilinear",
                            align_corners=False,
                        )[0].numpy().clip(0, 255).transpose(1, 2, 0)
                        video.append(frame)
        if len(video) != nframes:
            frame_scale = len(video) / nframes
            video = [video[int(i * frame_scale)] for i in range(nframes)]
        video = np.stack(video).astype(np.uint8)  # uint8 keeps the cache small
        return video

    def __getitem__(self, index):
        video_path = os.path.join(self.path, self._video_fnames[index])
        try:
            video = self._read_video_metric(
                video_path=video_path, nframes=self.nframes, size=self.size
            )
        except Exception as e:
            print("=> error loading video", video_path, e)
            # Fall back to the next sample (wrapping around) and return it
            # directly, since the recursive call already normalizes it.
            return self.__getitem__((index + 1) % self._total_size)
        if video.shape[0] != self.nframes:
            print(
                "=> inconsistent video frames", video_path,
                video.shape[0], "vs.", self.nframes,
            )
        video = video.astype(np.float32)
        video = (video - 127.5) / 127.5  # scale to [-1, 1]
        return video

    def __len__(self):
        return self._total_size
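# A minimal, self-contained sketch (not part of the original loaders, and the
# function name is hypothetical) of the uniform temporal subsampling all three
# readers share: pick nframes indices spaced len(video) / nframes apart.
# Duplicated indices simply repeat frames when the source is shorter than
# nframes.
def _uniform_frame_indices(total: int, nframes: int) -> List[int]:
    scale = total / nframes
    return [int(i * scale) for i in range(nframes)]

# e.g. _uniform_frame_indices(10, 4) == [0, 2, 5, 7]
# and  _uniform_frame_indices(3, 5)  == [0, 0, 1, 1, 2]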
class Dataset(VideoFolder):
    def __init__(
        self,
        data_root: str,
        resolution: List[int],
        video_length: int = 128,
        latent_scale: int = 8,
        actions: int = 7,
    ):
        super().__init__(data_root, size=resolution, nframes=video_length)
        self.data_root = data_root
        self.actions = actions

        # Each entry under data_root is one video directory; cap at 100k.
        videos = sorted(os.listdir(data_root))[:100_000]
        self._video_fnames = videos
        self._total_size = len(self._video_fnames)
        self.latent_scale = latent_scale

    def __getitem__(self, index):
        index = index % self._total_size
        video_path = os.path.join(self.data_root, self._video_fnames[index])
        video = self._read_video(
            video_path=video_path, nframes=self.nframes, size=self.size
        )
        video = video.astype(np.float32)
        video = (video - 127.5) / 127.5  # scale to [-1, 1]
        if video.shape[0] != self.nframes:
            raise ValueError(
                f"{video_path} has only {video.shape[0]} frames, "
                f"expected {self.nframes}"
            )

        # One integer action label per frame, stored next to the frames.
        actions = []
        with open(
            os.path.join(self.data_root, self._video_fnames[index], "actions.txt"), "r"
        ) as f:
            for line in f:
                actions.append(int(line.strip()))
        # SIMPLE_MOVEMENT = [
        #     ['NOOP'],
        #     ['right'],
        #     ['right', 'A'],
        #     ['right', 'B'],
        #     ['right', 'A', 'B'],
        #     ['A'],
        #     ['left'],
        # ]

        video = rearrange(video, "T H W C -> C T H W")

        # Build a (time, latent-height, latent-width, action) coordinate grid,
        # then select each frame's action slice.
        grid_size = [
            self.nframes,
            video.shape[2] // self.latent_scale,
            video.shape[3] // self.latent_scale,
            self.actions,
        ]
        grid_t = np.arange(grid_size[0], dtype=np.float32)
        grid_h = np.arange(grid_size[1], dtype=np.float32)
        grid_w = np.arange(grid_size[2], dtype=np.float32)
        grid_action = np.arange(grid_size[3], dtype=np.float32)
        grid = np.meshgrid(grid_t, grid_h, grid_w, grid_action, indexing="ij")
        grid = np.stack(grid, axis=0)  # (4, T, H, W, A)
        # Advanced indexing with the per-frame actions yields (T, 4, H, W).
        grid = rearrange(
            grid[:, np.arange(grid_size[0]), :, :, actions], "T N H W -> N T H W"
        )

        return (video, actions, grid)

    def __len__(self):
        return self._total_size
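# A minimal usage sketch, assuming a data root laid out as one directory per
# video (numeric frame images plus an actions.txt). The path and shapes below
# are illustrative, not part of the original module.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = Dataset(
        data_root="/path/to/videos",  # hypothetical location
        resolution=[256, 256],
        video_length=16,
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
    video, actions, grid = dataset[0]
    print(video.shape)  # (3, 16, 256, 256): C T H W, values in [-1, 1]
    print(grid.shape)   # (4, 16, 32, 32): coordinate grid at latent_scale=8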