import math

import decord
import torch
from einops import rearrange
from torch.nn import functional as F

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


class DecordInit(object):
    """Use Decord (https://github.com/dmlc/decord) to initialize the video reader."""

    def __init__(self, num_threads=1):
        self.num_threads = num_threads
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """Perform the Decord initialization.

        Args:
            filename (str): Path of the video file to open.

        Returns:
            decord.VideoReader: A reader over the decoded frames.
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'num_threads={self.num_threads})')
        return repr_str
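
# A minimal usage sketch (the file path and the sampling step are hypothetical,
# not taken from this file):
#
#     v_decoder = DecordInit(num_threads=1)
#     vr = v_decoder('sample.mp4')
#     frames = vr.get_batch(range(0, len(vr), 4)).asnumpy()  # (t, h, w, c) uint8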


def pad_to_multiple(number, ds_stride):
    """Round `number` up to the nearest multiple of `ds_stride`."""
    remainder = number % ds_stride
    if remainder == 0:
        return number
    padding = ds_stride - remainder
    return number + padding
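
# For example, pad_to_multiple(17, 8) == 24 while pad_to_multiple(16, 8) == 16,
# so padded extents are always divisible by the combined downsampling stride.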


class Collate:
    """Pad variable-size video (and, optionally, image) samples to a shared
    shape and build the matching attention mask at latent resolution."""

    def __init__(self, args):
        self.max_image_size = args.max_image_size
        self.ae_stride = args.ae_stride
        self.ae_stride_t = args.ae_stride_t
        self.ae_stride_thw = (self.ae_stride_t, self.ae_stride, self.ae_stride)
        self.ae_stride_1hw = (1, self.ae_stride, self.ae_stride)

        self.patch_size = args.patch_size
        self.patch_size_t = args.patch_size_t
        self.patch_size_thw = (self.patch_size_t, self.patch_size, self.patch_size)
        self.patch_size_1hw = (1, self.patch_size, self.patch_size)

        self.num_frames = args.num_frames
        self.use_image_num = args.use_image_num
        self.max_thw = (self.num_frames, self.max_image_size, self.max_image_size)
        self.max_1hw = (1, self.max_image_size, self.max_image_size)

    def package(self, batch):
        batch_tubes_vid = [i['video_data']['video'] for i in batch]  # b [c t h w]
        input_ids_vid = torch.stack([i['video_data']['input_ids'] for i in batch])  # b 1 l
        cond_mask_vid = torch.stack([i['video_data']['cond_mask'] for i in batch])  # b 1 l
        batch_tubes_img, input_ids_img, cond_mask_img = None, None, None
        if self.use_image_num != 0:
            batch_tubes_img = [j for i in batch for j in i['image_data']['image']]  # b*num_img [c 1 h w]
            input_ids_img = torch.stack([i['image_data']['input_ids'] for i in batch])  # b num_img l
            cond_mask_img = torch.stack([i['image_data']['cond_mask'] for i in batch])  # b num_img l
        return batch_tubes_vid, input_ids_vid, cond_mask_vid, batch_tubes_img, input_ids_img, cond_mask_img
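
    # Expected per-sample layout, inferred from the indexing in `package` (the
    # exact field sizes are an assumption, not confirmed by this file):
    #     {'video_data': {'video': (c, t, h, w), 'input_ids': (1, l), 'cond_mask': (1, l)},
    #      'image_data': {'image': [num_img x (c, 1, h, w)],
    #                     'input_ids': (num_img, l), 'cond_mask': (num_img, l)}}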

    def __call__(self, batch):
        batch_tubes_vid, input_ids_vid, cond_mask_vid, batch_tubes_img, input_ids_img, cond_mask_img = self.package(batch)

        ds_stride = self.ae_stride * self.patch_size
        t_ds_stride = self.ae_stride_t * self.patch_size_t
        if self.use_image_num == 0:
            pad_batch_tubes, attention_mask = self.process(batch_tubes_vid, t_ds_stride, ds_stride,
                                                           self.max_thw, self.ae_stride_thw, self.patch_size_thw, extra_1=True)
            # attention_mask: b t h w
            input_ids, cond_mask = input_ids_vid.squeeze(1), cond_mask_vid.squeeze(1)  # b 1 l -> b l
        else:
            pad_batch_tubes_vid, attention_mask_vid = self.process(batch_tubes_vid, t_ds_stride, ds_stride,
                                                                   self.max_thw, self.ae_stride_thw, self.patch_size_thw, extra_1=True)
            # attention_mask_vid: b t h w
            pad_batch_tubes_img, attention_mask_img = self.process(batch_tubes_img, 1, ds_stride,
                                                                   self.max_1hw, self.ae_stride_1hw, self.patch_size_1hw, extra_1=False)
            pad_batch_tubes_img = rearrange(pad_batch_tubes_img, '(b i) c 1 h w -> b c i h w', i=self.use_image_num)
            attention_mask_img = rearrange(attention_mask_img, '(b i) 1 h w -> b i h w', i=self.use_image_num)
            # attention_mask_img: b num_img h w
            pad_batch_tubes = torch.cat([pad_batch_tubes_vid, pad_batch_tubes_img], dim=2)  # concat along time, video first
            attention_mask = torch.cat([attention_mask_vid, attention_mask_img], dim=1)  # b t+num_img h w
            input_ids = torch.cat([input_ids_vid, input_ids_img], dim=1)  # b 1+num_img l
            cond_mask = torch.cat([cond_mask_vid, cond_mask_img], dim=1)  # b 1+num_img l
        return pad_batch_tubes, attention_mask, input_ids, cond_mask

    def process(self, batch_tubes, t_ds_stride, ds_stride, max_thw, ae_stride_thw, patch_size_thw, extra_1):
        # Pad each sample up to the shared max size, rounded to a multiple of
        # the downsampling stride. With `extra_1`, the first frame is excluded
        # from the rounding and added back, so the padded length is
        # 1 + k * t_ds_stride.
        batch_input_size = [i.shape for i in batch_tubes]  # [(c t h w), (c t h w)]
        max_t, max_h, max_w = max_thw
        pad_max_t, pad_max_h, pad_max_w = pad_to_multiple(max_t - 1 if extra_1 else max_t, t_ds_stride), \
            pad_to_multiple(max_h, ds_stride), \
            pad_to_multiple(max_w, ds_stride)
        pad_max_t = pad_max_t + 1 if extra_1 else pad_max_t
        each_pad_t_h_w = [[pad_max_t - i.shape[1],
                           pad_max_h - i.shape[2],
                           pad_max_w - i.shape[3]] for i in batch_tubes]
        pad_batch_tubes = [F.pad(im,
                                 (0, pad_w,
                                  0, pad_h,
                                  0, pad_t), value=0) for (pad_t, pad_h, pad_w), im in zip(each_pad_t_h_w, batch_tubes)]
        pad_batch_tubes = torch.stack(pad_batch_tubes, dim=0)

        # Build the attention mask at latent resolution: ones over each
        # sample's valid (unpadded) region, zeros over the padding.
        max_tube_size = [pad_max_t, pad_max_h, pad_max_w]
        max_latent_size = [((max_tube_size[0] - 1) // ae_stride_thw[0] + 1) if extra_1 else (max_tube_size[0] // ae_stride_thw[0]),
                           max_tube_size[1] // ae_stride_thw[1],
                           max_tube_size[2] // ae_stride_thw[2]]
        valid_latent_size = [[int(math.ceil((i[1] - 1) / ae_stride_thw[0])) + 1 if extra_1 else int(math.ceil(i[1] / ae_stride_thw[0])),
                              int(math.ceil(i[2] / ae_stride_thw[1])),
                              int(math.ceil(i[3] / ae_stride_thw[2]))] for i in batch_input_size]
        attention_mask = [F.pad(torch.ones(i),
                                (0, max_latent_size[2] - i[2],
                                 0, max_latent_size[1] - i[1],
                                 0, max_latent_size[0] - i[0]), value=0) for i in valid_latent_size]
        attention_mask = torch.stack(attention_mask)  # b t h w
        return pad_batch_tubes, attention_mask
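

if __name__ == '__main__':
    # Minimal smoke test (the argument values below are hypothetical, chosen
    # only so the strides divide the sizes): collate two variable-size,
    # video-only samples and print the padded batch and mask shapes.
    from types import SimpleNamespace

    args = SimpleNamespace(max_image_size=64, ae_stride=8, ae_stride_t=4,
                           patch_size=2, patch_size_t=1,
                           num_frames=17, use_image_num=0)
    collate = Collate(args)

    def fake_sample(t, h, w, l=10):
        return {'video_data': {'video': torch.randn(3, t, h, w),
                               'input_ids': torch.ones(1, l, dtype=torch.long),
                               'cond_mask': torch.ones(1, l, dtype=torch.long)}}

    videos, attn_mask, input_ids, cond_mask = collate([fake_sample(17, 48, 64),
                                                       fake_sample(9, 64, 56)])
    # Expect videos: (2, 3, 17, 64, 64), attn_mask: (2, 5, 8, 8),
    # input_ids / cond_mask: (2, 10).
    print(videos.shape, attn_mask.shape, input_ids.shape, cond_mask.shape)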