import os
import json

import decord
import torchvision.transforms as T
from einops import rearrange
from torch.utils.data import Dataset

from utils.dataset_utils import *

# decord's torch bridge makes VideoReader.get_batch return torch tensors.
decord.bridge.set_bridge('torch')

# https://github.com/ExponentialML/Video-BLIP2-Preprocessor
class VideoJsonDataset(Dataset):
    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 4,
        sample_start_idx: int = 1,
        frame_step: int = 1,
        json_path: str = "",
        json_data=None,
        vid_data_key: str = "video_path",
        preprocessed: bool = False,
        use_bucketing: bool = False,
        **kwargs
    ):
        self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
        self.use_bucketing = use_bucketing
        self.tokenizer = tokenizer
        self.preprocessed = preprocessed
        self.vid_data_key = vid_data_key

        self.train_data = self.load_from_json(json_path, json_data)

        self.width = width
        self.height = height

        self.n_sample_frames = n_sample_frames
        self.sample_start_idx = sample_start_idx
        self.frame_step = frame_step
    def build_json(self, json_data):
        extended_data = []

        # Flatten the nested per-clip entries into one flat list of samples.
        for data in json_data['data']:
            for nested_data in data['data']:
                self.build_json_dict(
                    data,
                    nested_data,
                    extended_data
                )
        return extended_data

    def build_json_dict(self, data, nested_data, extended_data):
        clip_path = nested_data.get('clip_path', None)

        extended_data.append({
            self.vid_data_key: data[self.vid_data_key],
            'frame_index': nested_data['frame_index'],
            'prompt': nested_data['prompt'],
            'clip_path': clip_path
        })
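    # The expected input layout, assumed from build_json above (matching the
    # Video-BLIP2-Preprocessor output; all field values are illustrative only):
    #
    # {
    #   "data": [
    #     {
    #       "video_path": "videos/example.mp4",
    #       "data": [
    #         {"frame_index": 0,  "prompt": "a dog running", "clip_path": "clips/example_0.mp4"},
    #         {"frame_index": 24, "prompt": "a dog jumping", "clip_path": "clips/example_1.mp4"}
    #       ]
    #     }
    #   ]
    # }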
    def load_from_json(self, path, json_data):
        try:
            with open(path) as jpath:
                print(f"Loading JSON from {path}")
                json_data = json.load(jpath)

            return self.build_json(json_data)

        except (FileNotFoundError, TypeError, json.JSONDecodeError):
            print("Non-existent JSON path. Skipping.")
            return []
    def validate_json(self, base_path, path):
        return os.path.exists(f"{base_path}/{path}")
    def get_frame_range(self, vr):
        # get_video_frames (from utils.dataset_utils) is assumed to return
        # n_sample_frames indices starting at sample_start_idx, spaced
        # frame_step apart and bounded by len(vr).
        return get_video_frames(
            vr,
            self.sample_start_idx,
            self.frame_step,
            self.n_sample_frames
        )
    def get_vid_idx(self, vr, vid_data=None):
        if vid_data is not None:
            idx = vid_data['frame_index']
        else:
            idx = self.sample_start_idx

        return idx
    def get_frame_buckets(self, vr):
        _, h, w = vr[0].shape
        # sensible_buckets (from utils.dataset_utils) is assumed to pick a
        # target resolution near (self.width, self.height) that preserves the
        # source aspect ratio.
        width, height = sensible_buckets(self.width, self.height, h, w)
        resize = T.Resize((height, width), antialias=True)

        return resize
    def get_frame_batch(self, vr, resize=None):
        frame_range = self.get_frame_range(vr)
        frames = vr.get_batch(frame_range)
        # (frames, height, width, channels) -> (frames, channels, height, width)
        video = rearrange(frames, "f h w c -> f c h w")

        if resize is not None:
            video = resize(video)

        return video
    def process_video_wrapper(self, vid_path):
        # process_video (from utils.dataset_utils) is assumed to open vid_path
        # with decord.VideoReader and use get_frame_buckets / get_frame_batch
        # as callbacks to resize and extract the sampled frames.
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch
        )
        return video, vr
    def train_data_batch(self, index):
        # If we are training on individual clips.
        if 'clip_path' in self.train_data[index] and \
                self.train_data[index]['clip_path'] is not None:
            vid_data = self.train_data[index]

            clip_path = vid_data['clip_path']

            # Get video prompt
            prompt = vid_data['prompt']

            video, _ = self.process_video_wrapper(clip_path)

            prompt_ids = get_prompt_ids(prompt, self.tokenizer)

            return video, prompt, prompt_ids

        # Assign train data
        train_data = self.train_data[index]

        # Get the frame of the current index. Note that this mutates dataset
        # state, so items are expected to be processed one at a time.
        self.sample_start_idx = train_data['frame_index']

        video, vr = self.process_video_wrapper(train_data[self.vid_data_key])

        # Get video prompt
        prompt = train_data['prompt']
        vr.seek(0)

        prompt_ids = get_prompt_ids(prompt, self.tokenizer)

        return video, prompt, prompt_ids
    @staticmethod
    def __getname__(): return 'json'
    def __len__(self):
        if self.train_data is not None:
            return len(self.train_data)
        else:
            return 0
    def __getitem__(self, index):
        # Initialize variables
        video = None
        prompt = None
        prompt_ids = None

        # Use default JSON training
        if self.train_data is not None:
            video, prompt, prompt_ids = self.train_data_batch(index)

        example = {
            # Scale uint8 pixel values from [0, 255] to [-1, 1].
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            'dataset': self.__getname__()
        }

        return example
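
# Minimal usage sketch (not part of the original file): the tokenizer choice,
# JSON path, and loader settings below are hypothetical, and the JSON is
# expected to come from the Video-BLIP2-Preprocessor linked above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = VideoJsonDataset(
        tokenizer=tokenizer,
        width=256,
        height=256,
        n_sample_frames=4,
        json_path="data/videos.json",  # hypothetical path
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    for batch in loader:
        # pixel_values: (batch, frames, channels, height, width), in [-1, 1]
        print(batch["pixel_values"].shape, batch["text_prompt"])
        break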