import logging
import os
import pickle
import random

import cv2
import cvbase
import imageio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset

from .util.STTN_mask import create_random_shape_with_random_motion
from .util.flow_utils import region_fill as rf

logger = logging.getLogger('base')


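# Expected on-disk layout, inferred from the path handling in `load_item` below
# (directory and file-name patterns come from the code; everything else is illustrative):
#   <frame_path>/<video>/00000.jpg, 00001.jpg, ...
#   <flow_path>/<video>/forward_flo/00000.flo, ...
#   <flow_path>/<video>/backward_flo/00000.flo, ...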
class VideoBasedDataset(Dataset):
    def __init__(self, opt, dataInfo):
        self.opt = opt
        self.sampleMethod = opt['sample']
        self.dataInfo = dataInfo
        self.height, self.width = self.opt['input_resolution']
        self.frame_path = dataInfo['frame_path']
        self.flow_path = dataInfo['flow_path']
        self.train_list = os.listdir(self.frame_path)
        # 'name2len' stores the path to a pickled dict mapping video name -> frame count.
        self.name2length = self.dataInfo['name2len']
        with open(self.name2length, 'rb') as f:
            self.name2length = pickle.load(f)
        self.sequenceLen = self.opt['num_frames']
        self.flow2rgb = opt['flow2rgb']
        self.flow_direction = opt['flow_direction']

    def __len__(self):
        return len(self.train_list)

    def __getitem__(self, idx):
        try:
            item = self.load_item(idx)
        except Exception:
            logger.error('Loading error: %s', self.train_list[idx])
            # Fall back to the first sample so training is not interrupted.
            item = self.load_item(0)
        return item

    def frameSample(self, frameLen, sequenceLen):
        if self.sampleMethod == 'random':
            indices = [i for i in range(frameLen)]
            sampleIndices = random.sample(indices, sequenceLen)
        elif self.sampleMethod == 'seq':
            # Choose a random start so that `sequenceLen` consecutive frames
            # fit inside the video (requires frameLen >= sequenceLen).
            pivot = random.randint(0, frameLen - sequenceLen)
            sampleIndices = [i for i in range(pivot, pivot + sequenceLen)]
        else:
            raise ValueError('Cannot determine the sample method {}'.format(self.sampleMethod))
        return sampleIndices

    def load_item(self, idx):
        video = self.train_list[idx]
        frame_dir = os.path.join(self.frame_path, video)
        forward_flow_dir = os.path.join(self.flow_path, video, 'forward_flo')
        backward_flow_dir = os.path.join(self.flow_path, video, 'backward_flo')
        frameLen = self.name2length[video]
        flowLen = frameLen - 1
        assert frameLen > self.sequenceLen, \
            'Frame length {} is not greater than sequence length {}'.format(frameLen, self.sequenceLen)
        sampledIndices = self.frameSample(frameLen, self.sequenceLen)

        # Generate one random free-form mask (with random motion) per frame.
        candidateMasks = create_random_shape_with_random_motion(frameLen, 0.9, 1.1, 1, 10)

        frames, masks, forward_flows, backward_flows = [], [], [], []
        for i in range(len(sampledIndices)):
            frame = self.read_frame(os.path.join(frame_dir, '{:05d}.jpg'.format(sampledIndices[i])),
                                    self.height, self.width)
            mask = self.read_mask(candidateMasks[sampledIndices[i]], self.height, self.width)
            frames.append(frame)
            masks.append(mask)
            if self.flow_direction == 'for':
                forward_flow = self.read_forward_flow(forward_flow_dir, sampledIndices[i], flowLen)
                forward_flow = self.diffusion_flow(forward_flow, mask)
                forward_flows.append(forward_flow)
            elif self.flow_direction == 'back':
                backward_flow = self.read_backward_flow(backward_flow_dir, sampledIndices[i])
                backward_flow = self.diffusion_flow(backward_flow, mask)
                backward_flows.append(backward_flow)
            elif self.flow_direction == 'bi':
                forward_flow = self.read_forward_flow(forward_flow_dir, sampledIndices[i], flowLen)
                forward_flow = self.diffusion_flow(forward_flow, mask)
                forward_flows.append(forward_flow)
                backward_flow = self.read_backward_flow(backward_flow_dir, sampledIndices[i])
                backward_flow = self.diffusion_flow(backward_flow, mask)
                backward_flows.append(backward_flow)
            else:
                raise ValueError('Unknown flow direction mode: {}'.format(self.flow_direction))
        inputs = {'frames': frames, 'masks': masks, 'forward_flo': forward_flows, 'backward_flo': backward_flows}
        inputs = self.to_tensor(inputs)
        # Normalize frames from [0, 255] to [-1, 1].
        inputs['frames'] = (inputs['frames'] / 255.) * 2 - 1
        return inputs

    def diffusion_flow(self, flow, mask):
        # Fill the masked flow region by diffusing the surrounding flow values.
        flow_filled = np.zeros(flow.shape)
        flow_filled[:, :, 0] = rf.regionfill(flow[:, :, 0] * (1 - mask), mask)
        flow_filled[:, :, 1] = rf.regionfill(flow[:, :, 1] * (1 - mask), mask)
        return flow_filled

    def read_frame(self, path, height, width):
        frame = imageio.imread(path)
        frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR)
        return frame

    def read_mask(self, mask, height, width):
        mask = np.array(mask)
        mask = mask / 255.
        raw_mask = (mask > 0.5).astype(np.uint8)
        raw_mask = cv2.resize(raw_mask, dsize=(width, height), interpolation=cv2.INTER_NEAREST)
        return raw_mask

    def read_forward_flow(self, forward_flow_dir, sampledIndex, flowLen):
        # The last frame has no forward flow; clamp to the last available one.
        if sampledIndex >= flowLen:
            sampledIndex = flowLen - 1
        flow = cvbase.read_flow(os.path.join(forward_flow_dir, '{:05d}.flo'.format(sampledIndex)))
        height, width = flow.shape[:2]
        flow = cv2.resize(flow, (self.width, self.height), interpolation=cv2.INTER_LINEAR)
        # Rescale flow magnitudes to match the resized resolution.
        flow[:, :, 0] = flow[:, :, 0] / width * self.width
        flow[:, :, 1] = flow[:, :, 1] / height * self.height
        return flow

    def read_backward_flow(self, backward_flow_dir, sampledIndex):
        # Backward flow of frame i is stored at index i - 1; frame 0 reuses index 0.
        if sampledIndex > 0:
            sampledIndex -= 1
        flow = cvbase.read_flow(os.path.join(backward_flow_dir, '{:05d}.flo'.format(sampledIndex)))
        height, width = flow.shape[:2]
        flow = cv2.resize(flow, (self.width, self.height), interpolation=cv2.INTER_LINEAR)
        # Rescale flow magnitudes to match the resized resolution.
        flow[:, :, 0] = flow[:, :, 0] / width * self.width
        flow[:, :, 1] = flow[:, :, 1] / height * self.height
        return flow

    def to_tensor(self, data_list):
        """Convert a dict of numpy arrays / lists of arrays into float tensors.

        Args:
            data_list: A dict whose values are numpy arrays or lists of numpy arrays.

        Returns:
            The same dict with empty entries removed and the remaining values
            converted to tensors in (C, H, W) or stacked (N, C, H, W) layout.
        """
        keys = list(data_list.keys())
        for key in keys:
            if data_list[key] is None or data_list[key] == []:
                # Drop empty entries (e.g. the unused flow direction).
                data_list.pop(key)
            else:
                item = data_list[key]
                if not isinstance(item, list):
                    item = torch.from_numpy(np.transpose(item, (2, 0, 1))).float()
                else:
                    item = np.stack(item, axis=0)
                    if len(item.shape) == 3:
                        # Single-channel data (e.g. masks): add a channel axis.
                        item = item[:, :, :, np.newaxis]
                    item = torch.from_numpy(np.transpose(item, (0, 3, 1, 2))).float()
                data_list[key] = item
        return data_list
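

# ---------------------------------------------------------------------------
# Usage sketch (not executed on import). The option and dataInfo keys mirror
# the ones read in __init__; every concrete value below is a placeholder.
#
#   opt = {
#       'sample': 'seq',                  # or 'random'
#       'input_resolution': (240, 432),   # (height, width)
#       'num_frames': 5,
#       'flow2rgb': True,
#       'flow_direction': 'for',          # 'for' | 'back' | 'bi'
#   }
#   dataInfo = {
#       'frame_path': '/path/to/frames',      # one sub-directory per video
#       'flow_path': '/path/to/flows',        # <video>/forward_flo, <video>/backward_flo
#       'name2len': '/path/to/name2len.pkl',  # pickled dict: video name -> frame count
#   }
#   dataset = VideoBasedDataset(opt, dataInfo)
#   sample = dataset[0]
#   # sample['frames'] has shape (num_frames, 3, height, width), scaled to [-1, 1].
# ---------------------------------------------------------------------------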