# MotionInversion/dataset/video_json_dataset.py
# Expected to provide get_video_frames, sensible_buckets, process_video,
# and get_prompt_ids, among other dataset helpers.
from utils.dataset_utils import *

import json
import os

import torchvision.transforms as T
from einops import rearrange
from torch.utils.data import Dataset

# JSON format from https://github.com/ExponentialML/Video-BLIP2-Preprocessor
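
# A sketch of the JSON layout this dataset expects, inferred from build_json /
# build_json_dict below. Field names match the preprocessor linked above; the
# concrete values are made up for illustration:
#
# {
#     "data": [
#         {
#             "video_path": "videos/clip_000.mp4",
#             "data": [
#                 {"frame_index": 0, "prompt": "a dog running", "clip_path": "clips/clip_000_0.mp4"},
#                 {"frame_index": 24, "prompt": "a dog jumping"}
#             ]
#         }
#     ]
# }
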
class VideoJsonDataset(Dataset):
def __init__(
self,
        tokenizer=None,
width: int = 256,
height: int = 256,
n_sample_frames: int = 4,
sample_start_idx: int = 1,
frame_step: int = 1,
        json_path: str = "",
json_data = None,
vid_data_key: str = "video_path",
preprocessed: bool = False,
use_bucketing: bool = False,
**kwargs
):
self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
self.use_bucketing = use_bucketing
self.tokenizer = tokenizer
self.preprocessed = preprocessed
self.vid_data_key = vid_data_key
self.train_data = self.load_from_json(json_path, json_data)
self.width = width
self.height = height
self.n_sample_frames = n_sample_frames
self.sample_start_idx = sample_start_idx
self.frame_step = frame_step

    def build_json(self, json_data):
        # Flatten the nested per-video entries into one flat list of samples.
        extended_data = []
        for data in json_data['data']:
            for nested_data in data['data']:
                self.build_json_dict(
                    data,
                    nested_data,
                    extended_data
                )
        return extended_data

    def build_json_dict(self, data, nested_data, extended_data):
        extended_data.append({
            self.vid_data_key: data[self.vid_data_key],
            'frame_index': nested_data['frame_index'],
            'prompt': nested_data['prompt'],
            'clip_path': nested_data.get('clip_path')
        })

    def load_from_json(self, path, json_data):
        try:
            with open(path) as jpath:
                print(f"Loading JSON from {path}")
                json_data = json.load(jpath)
            return self.build_json(json_data)
        except (TypeError, FileNotFoundError, json.JSONDecodeError):
            # Fall back to JSON passed in directly, if any; otherwise
            # train on nothing rather than returning None.
            if json_data is not None:
                return self.build_json(json_data)
            print("Nonexistent JSON path. Skipping.")
            return []

    def validate_json(self, base_path, path):
        return os.path.exists(os.path.join(base_path, path))
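
    # A worked example, assuming get_video_frames from utils.dataset_utils
    # returns an evenly strided index list capped at n_sample_frames:
    # sample_start_idx=1, frame_step=2, n_sample_frames=4 -> [1, 3, 5, 7].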
def get_frame_range(self, vr):
return get_video_frames(
vr,
self.sample_start_idx,
self.frame_step,
self.n_sample_frames
)

    def get_vid_idx(self, vr, vid_data=None):
        # Prefer the per-sample frame index when available.
        if vid_data is not None:
            idx = vid_data['frame_index']
        else:
            idx = self.sample_start_idx
        return idx

    def get_frame_buckets(self, vr):
        # Decord frames are (h, w, c) -- see the rearrange in get_frame_batch.
        h, w, _ = vr[0].shape
        # Pick a target resolution that respects the source aspect ratio.
        width, height = sensible_buckets(self.width, self.height, h, w)
        resize = T.Resize((height, width), antialias=True)
        return resize

    def get_frame_batch(self, vr, resize=None):
        # Assumes decord's bridge is set to torch upstream, so get_batch
        # returns a tensor of shape (f, h, w, c).
        frame_range = self.get_frame_range(vr)
        frames = vr.get_batch(frame_range)
        video = rearrange(frames, "f h w c -> f c h w")
        if resize is not None:
            video = resize(video)
        return video

    def process_video_wrapper(self, vid_path):
video, vr = process_video(
vid_path,
self.use_bucketing,
self.width,
self.height,
self.get_frame_buckets,
self.get_frame_batch
)
return video, vr

    def train_data_batch(self, index):
        vid_data = self.train_data[index]

        # If we are training on individual, pre-cut clips.
        if vid_data.get('clip_path') is not None:
            clip_path = vid_data['clip_path']
            prompt = vid_data['prompt']
            video, _ = self.process_video_wrapper(clip_path)
            prompt_ids = get_prompt_ids(prompt, self.tokenizer)
            return video, prompt, prompt_ids

        # Otherwise sample from the full video, starting at this entry's
        # frame index (note: this mutates the shared sample_start_idx).
        self.sample_start_idx = vid_data['frame_index']
        video, vr = self.process_video_wrapper(vid_data[self.vid_data_key])
        prompt = vid_data['prompt']
        vr.seek(0)
        prompt_ids = get_prompt_ids(prompt, self.tokenizer)
        return video, prompt, prompt_ids

    @staticmethod
    def __getname__():
        return 'json'

    def __len__(self):
        if self.train_data is not None:
            return len(self.train_data)
        return 0

    def __getitem__(self, index):
        video, prompt, prompt_ids = None, None, None

        # Use default JSON training
        if self.train_data is not None:
            video, prompt, prompt_ids = self.train_data_batch(index)

        return {
            # Scale uint8 frames from [0, 255] to [-1, 1].
            "pixel_values": (video / 127.5 - 1.0),
            # Drop the batch dimension added by the tokenizer.
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            "dataset": self.__getname__()
        }
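

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The tokenizer and
    # all paths below are hypothetical; any CLIP-style tokenizer accepted by
    # get_prompt_ids should work.
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = VideoJsonDataset(
        tokenizer=tokenizer,
        width=256,
        height=256,
        n_sample_frames=8,
        json_path="data/videos.json",  # hypothetical path
    )
    print(f"{len(dataset)} samples")
    if len(dataset) > 0:
        example = dataset[0]
        print(example["text_prompt"], example["pixel_values"].shape)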