import json
import os
import copy
import math
import random
import traceback
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import numpy as np
from tqdm import tqdm
from random import sample
import torchaudio
import logging
import collections
from glob import glob
import sys
import albumentations
import soundfile
from PIL import Image

sys.path.insert(0, '.')  # nopep8
from train import instantiate_from_config
from foleycrafter.models.specvqgan.data.transforms import *

torchaudio.set_audio_backend("sox_io")
logger = logging.getLogger(f'main.{__name__}')

SR = 22050  # audio sample rate (Hz)
FPS = 15    # video frame rate
MAX_SAMPLE_ITER = 10  # upper bound on re-sampling attempts


def non_negative(x): return int(np.round(max(0, x), 0))


def rms(x): return np.sqrt(np.mean(x**2))


def get_GH_data_identifier(video_name, start_idx, split='_'):
    if isinstance(start_idx, str):
        return video_name + split + start_idx
    elif isinstance(start_idx, int):
        return video_name + split + str(start_idx)
    else:
        raise NotImplementedError
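

# Crop and its subclasses wrap albumentations center/random crops so the same
# preprocessor can be applied to the dict keys ('image', 'cond_image',
# 'feature', 'coord') produced by the datasets below.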
class Crop(object):
    def __init__(self, cropped_shape=None, random_crop=False):
        self.cropped_shape = cropped_shape
        if cropped_shape is not None:
            mel_num, spec_len = cropped_shape
            if random_crop:
                self.cropper = albumentations.RandomCrop
            else:
                self.cropper = albumentations.CenterCrop
            self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __call__(self, item):
        item['image'] = self.preprocessor(image=item['image'])['image']
        if 'cond_image' in item.keys():
            item['cond_image'] = self.preprocessor(image=item['cond_image'])['image']
        return item


class CropImage(Crop):
    def __init__(self, *crop_args):
        super().__init__(*crop_args)


class CropFeats(Crop):
    def __init__(self, *crop_args):
        super().__init__(*crop_args)

    def __call__(self, item):
        item['feature'] = self.preprocessor(image=item['feature'])['image']
        return item


class CropCoords(Crop):
    def __init__(self, *crop_args):
        super().__init__(*crop_args)

    def __call__(self, item):
        item['coord'] = self.preprocessor(image=item['coord'])['image']
        return item
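

# ResampleFrames picks `feat_sample_size` evenly spaced frame features (with a
# half-stride shift so samples sit mid-interval) and optionally repeats each
# sampled frame, e.g. abc -> aaabbbccc.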
class ResampleFrames(object):
    def __init__(self, feat_sample_size, times_to_repeat_after_resample=None):
        self.feat_sample_size = feat_sample_size
        self.times_to_repeat_after_resample = times_to_repeat_after_resample

    def __call__(self, item):
        feat_len = item['feature'].shape[0]

        ## resample
        assert feat_len >= self.feat_sample_size
        # evenly spaced points (abcdefghkl -> aoooofoooo)
        idx = np.linspace(0, feat_len, self.feat_sample_size, dtype=int, endpoint=False)
        # xoooo xoooo -> ooxoo ooxoo
        shift = feat_len // (self.feat_sample_size + 1)
        idx = idx + shift

        ## repeat after resampling (abc -> aaaabbbbcccc)
        if self.times_to_repeat_after_resample is not None and self.times_to_repeat_after_resample > 1:
            idx = np.repeat(idx, self.times_to_repeat_after_resample)

        item['feature'] = item['feature'][idx, :]
        return item
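

# Mel-spectrogram dataset over the Greatest Hits data (recordings of a
# drumstick striking various materials). Items are identified as
# '<video>_<start_idx>', where start_idx appears to be an offset in audio
# samples at SR; videos containing only a single hit are removed.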
class GreatestHitSpecs(torch.utils.data.Dataset):

    def __init__(self, split, spec_dir_path, spec_len, random_crop, mel_num,
                 spec_crop_len, L=2.0, rand_shift=False, spec_transforms=None, splits_path='./data',
                 meta_path='./data/info_r2plus1d_dim1024_15fps.json'):
        super().__init__()
        self.split = split
        self.specs_dir = spec_dir_path
        self.spec_transforms = spec_transforms
        self.splits_path = splits_path
        self.meta_path = meta_path
        self.spec_len = spec_len
        self.rand_shift = rand_shift
        self.L = L
        self.spec_take_first = int(math.ceil(860 * (L / 10.) / 32) * 32)
        self.spec_take_first = 860 if self.spec_take_first > 860 else self.spec_take_first

        greatesthit_meta = json.load(open(self.meta_path, 'r'))
        unique_classes = sorted(list(set(ht for ht in greatesthit_meta['hit_type'])))
        self.label2target = {label: target for target, label in enumerate(unique_classes)}
        self.target2label = {target: label for label, target in self.label2target.items()}
        self.video_idx2label = {
            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
            greatesthit_meta['hit_type'][i] for i in range(len(greatesthit_meta['video_name']))
        }
        self.available_video_hit = list(self.video_idx2label.keys())
        self.video_idx2path = {
            vh: os.path.join(self.specs_dir,
                             vh.replace('_', '_denoised_') + '_' + self.video_idx2label[vh].replace(' ', '_') + '_mel.npy')
            for vh in self.available_video_hit
        }
        self.video_idx2idx = {
            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
            i for i in range(len(greatesthit_meta['video_name']))
        }

        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
        if not os.path.exists(split_clip_ids_path):
            raise NotImplementedError()
        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
        self.dataset = clip_video_hit

        spec_crop_len = self.spec_take_first if self.spec_take_first <= spec_crop_len else spec_crop_len
        self.spec_transforms = transforms.Compose([
            CropImage([mel_num, spec_crop_len], random_crop),
            # transforms.RandomApply([FrequencyMasking(freq_mask_param=20)], p=0),
            # transforms.RandomApply([TimeMasking(time_mask_param=int(32 * self.L))], p=0)
        ])

        self.video2indexes = {}
        for video_idx in self.dataset:
            video, start_idx = video_idx.split('_')
            if video not in self.video2indexes.keys():
                self.video2indexes[video] = []
            self.video2indexes[video].append(start_idx)
        for video in self.video2indexes.keys():
            if len(self.video2indexes[video]) == 1:  # given video contains only one hit
                self.dataset.remove(
                    get_GH_data_identifier(video, self.video2indexes[video][0])
                )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video_idx = self.dataset[idx]
        spec_path = self.video_idx2path[video_idx]
        spec = np.load(spec_path)  # (80, 860)

        if self.rand_shift:
            shift = random.uniform(0, 0.5)
            spec_shift = int(shift * spec.shape[1] // 10)
            # Since only the first second is used
            spec = np.roll(spec, -spec_shift, 1)

        # concat spec outside dataload
        item['image'] = 2 * spec - 1  # (80, 860)
        item['image'] = item['image'][:, :self.spec_take_first]
        item['file_path'] = spec_path

        item['label'] = self.video_idx2label[video_idx]
        item['target'] = self.label2target[item['label']]

        if self.spec_transforms is not None:
            item = self.spec_transforms(item)
        return item


class GreatestHitSpecsTrain(GreatestHitSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class GreatestHitSpecsValidation(GreatestHitSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class GreatestHitSpecsTest(GreatestHitSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)
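

# Raw-waveform variant of the dataset above: reads L seconds of audio per hit
# straight from each video's denoised, resampled wav file instead of loading a
# precomputed mel spectrogram.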
class GreatestHitWave(torch.utils.data.Dataset):

    def __init__(self, split, wav_dir, random_crop, mel_num, spec_crop_len, spec_len,
                 L=2.0, splits_path='./data', rand_shift=True,
                 data_path='data/greatesthit/greatesthit-process-resized'):
        super().__init__()
        self.split = split
        self.wav_dir = wav_dir
        self.splits_path = splits_path
        self.data_path = data_path
        self.L = L
        self.rand_shift = rand_shift

        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
        if not os.path.exists(split_clip_ids_path):
            raise NotImplementedError()
        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))

        video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) // 2 for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
        self.dataset = clip_video_hit

        self.video2indexes = {}
        for video_idx in self.dataset:
            video, start_idx = video_idx.split('_')
            if video not in self.video2indexes.keys():
                self.video2indexes[video] = []
            self.video2indexes[video].append(start_idx)
        for video in self.video2indexes.keys():
            if len(self.video2indexes[video]) == 1:  # given video contains only one hit
                self.dataset.remove(
                    get_GH_data_identifier(video, self.video2indexes[video][0])
                )

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video_idx = self.dataset[idx]
        video, start_idx = video_idx.split('_')
        start_idx = int(start_idx)
        if self.rand_shift:
            shift = int(random.uniform(-0.5, 0.5) * SR)
            start_idx = non_negative(start_idx + shift)

        wave_path = self.video_audio_path[video]
        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
        assert sr == SR
        wav = self.wav_transforms(wav)

        item['image'] = wav  # (44100,)
        # item['wav'] = wav
        item['file_path_wav_'] = wave_path

        item['label'] = 'None'
        item['target'] = 'None'
        return item


class GreatestHitWaveTrain(GreatestHitWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class GreatestHitWaveValidation(GreatestHitWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class GreatestHitWaveTest(GreatestHitWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)
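

# Conditional spectrogram dataset: each item pairs a target spectrogram and its
# video frames with a conditioning hit, drawn from the same video or, with
# probability p_outside_cond, from a different one.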
class CondGreatestHitSpecsCondOnImage(torch.utils.data.Dataset):
    def __init__(self, split, specs_dir, spec_len, feat_len, feat_depth, feat_crop_len, random_crop, mel_num, spec_crop_len,
                 vqgan_L=10.0, L=1.0, rand_shift=False, spec_transforms=None, frame_transforms=None, splits_path='./data',
                 meta_path='./data/info_r2plus1d_dim1024_15fps.json', frame_path='data/greatesthit/greatesthit_processed',
                 p_outside_cond=0., p_audio_aug=0.5):
        super().__init__()
        self.split = split
        self.specs_dir = specs_dir
        self.spec_transforms = spec_transforms
        self.frame_transforms = frame_transforms
        self.splits_path = splits_path
        self.meta_path = meta_path
        self.frame_path = frame_path
        self.feat_len = feat_len
        self.feat_depth = feat_depth
        self.feat_crop_len = feat_crop_len
        self.spec_len = spec_len
        self.rand_shift = rand_shift
        self.L = L
        self.spec_take_first = int(math.ceil(860 * (vqgan_L / 10.) / 32) * 32)
        self.spec_take_first = 860 if self.spec_take_first > 860 else self.spec_take_first
        self.p_outside_cond = torch.tensor(p_outside_cond)

        greatesthit_meta = json.load(open(self.meta_path, 'r'))
        unique_classes = sorted(list(set(ht for ht in greatesthit_meta['hit_type'])))
        self.label2target = {label: target for target, label in enumerate(unique_classes)}
        self.target2label = {target: label for label, target in self.label2target.items()}
        self.video_idx2label = {
            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
            greatesthit_meta['hit_type'][i] for i in range(len(greatesthit_meta['video_name']))
        }
        self.available_video_hit = list(self.video_idx2label.keys())
        self.video_idx2path = {
            vh: os.path.join(self.specs_dir,
                             vh.replace('_', '_denoised_') + '_' + self.video_idx2label[vh].replace(' ', '_') + '_mel.npy')
            for vh in self.available_video_hit
        }
        for value in self.video_idx2path.values():
            assert os.path.exists(value)
        self.video_idx2idx = {
            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
            i for i in range(len(greatesthit_meta['video_name']))
        }

        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
        if not os.path.exists(split_clip_ids_path):
            self.make_split_files()
        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
        self.dataset = clip_video_hit

        spec_crop_len = self.spec_take_first if self.spec_take_first <= spec_crop_len else spec_crop_len
        self.spec_transforms = transforms.Compose([
            CropImage([mel_num, spec_crop_len], random_crop),
            # transforms.RandomApply([FrequencyMasking(freq_mask_param=20)], p=p_audio_aug),
            # transforms.RandomApply([TimeMasking(time_mask_param=int(32 * self.L))], p=p_audio_aug)
        ])
        if self.frame_transforms is None:
            self.frame_transforms = transforms.Compose([
                Resize3D(128),
                RandomResizedCrop3D(112, scale=(0.5, 1.0)),
                RandomHorizontalFlip3D(),
                ColorJitter3D(brightness=0.1, saturation=0.1),
                ToTensor3D(),
                Normalize3D(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])

        self.video2indexes = {}
        for video_idx in self.dataset:
            video, start_idx = video_idx.split('_')
            if video not in self.video2indexes.keys():
                self.video2indexes[video] = []
            self.video2indexes[video].append(start_idx)
        for video in self.video2indexes.keys():
            if len(self.video2indexes[video]) == 1:  # given video contains only one hit
                self.dataset.remove(
                    get_GH_data_identifier(video, self.video2indexes[video][0])
                )

        clip_classes = [self.label2target[self.video_idx2label[vh]] for vh in clip_video_hit]
        class2count = collections.Counter(clip_classes)
        self.class_counts = torch.tensor([class2count[cls] for cls in range(len(class2count))])

        if self.L != 1.0:
            print(split, L)
            self.validate_data()
            self.video2indexes = {}
            for video_idx in self.dataset:
                video, start_idx = video_idx.split('_')
                if video not in self.video2indexes.keys():
                    self.video2indexes[video] = []
                self.video2indexes[video].append(start_idx)

    def __len__(self):
        return len(self.dataset)
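
    # __getitem__ loads the target spectrogram and frames, samples a
    # conditioning hit, optionally applies a random temporal shift to both,
    # and returns everything in a single dict.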
    def __getitem__(self, idx):
        item = {}
        try:
            video_idx = self.dataset[idx]
            spec_path = self.video_idx2path[video_idx]
            spec = np.load(spec_path)  # (80, 860)
            video, start_idx = video_idx.split('_')
            frame_path = os.path.join(self.frame_path, video, 'frames')
            start_frame_idx = non_negative(FPS * int(start_idx) / SR)
            end_frame_idx = non_negative(start_frame_idx + FPS * self.L)

            if self.rand_shift:
                shift = random.uniform(0, 0.5)
                spec_shift = int(shift * spec.shape[1] // 10)
                # Since only the first second is used
                spec = np.roll(spec, -spec_shift, 1)
                start_frame_idx += int(FPS * shift)
                end_frame_idx += int(FPS * shift)

            frames = [Image.open(os.path.join(
                frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
                range(start_frame_idx, end_frame_idx)]

            # Sample condition
            if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
                # Sample condition from outside video
                # (random.sample needs a sequence, not a set, on Python >= 3.11)
                all_idx = set(list(range(len(self.dataset))))
                all_idx.remove(idx)
                cond_video_idx = self.dataset[sample(sorted(all_idx), k=1)[0]]
                cond_video, cond_start_idx = cond_video_idx.split('_')
            else:
                cond_video = video
                video_hits_idx = copy.copy(self.video2indexes[video])
                video_hits_idx.remove(start_idx)
                cond_start_idx = sample(video_hits_idx, k=1)[0]
                cond_video_idx = get_GH_data_identifier(cond_video, cond_start_idx)

            cond_spec_path = self.video_idx2path[cond_video_idx]
            cond_spec = np.load(cond_spec_path)  # (80, 860)
            cond_video, cond_start_idx = cond_video_idx.split('_')
            cond_frame_path = os.path.join(self.frame_path, cond_video, 'frames')
            cond_start_frame_idx = non_negative(FPS * int(cond_start_idx) / SR)
            cond_end_frame_idx = non_negative(cond_start_frame_idx + FPS * self.L)

            if self.rand_shift:
                cond_shift = random.uniform(0, 0.5)
                cond_spec_shift = int(cond_shift * cond_spec.shape[1] // 10)
                # Since only the first second is used
                cond_spec = np.roll(cond_spec, -cond_spec_shift, 1)
                cond_start_frame_idx += int(FPS * cond_shift)
                cond_end_frame_idx += int(FPS * cond_shift)

            cond_frames = [Image.open(os.path.join(
                cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
                range(cond_start_frame_idx, cond_end_frame_idx)]

            # concat spec outside dataload
            item['image'] = 2 * spec - 1  # (80, 860)
            item['cond_image'] = 2 * cond_spec - 1  # (80, 860)
            item['image'] = item['image'][:, :self.spec_take_first]
            item['cond_image'] = item['cond_image'][:, :self.spec_take_first]
            item['file_path_specs_'] = spec_path
            item['file_path_cond_specs_'] = cond_spec_path

            if self.frame_transforms is not None:
                cond_frames = self.frame_transforms(cond_frames)
                frames = self.frame_transforms(frames)

            item['feature'] = np.stack(cond_frames + frames, axis=0)  # (30 * L, 112, 112, 3)
            item['file_path_feats_'] = (frame_path, start_frame_idx)
            item['file_path_cond_feats_'] = (cond_frame_path, cond_start_frame_idx)

            item['label'] = self.video_idx2label[video_idx]
            item['target'] = self.label2target[item['label']]

            if self.spec_transforms is not None:
                item = self.spec_transforms(item)
        except Exception:
            traceback.print_exc()
            print('!!!!!!!!!!!!!!!!!!!!', video_idx, cond_video_idx)
            print('!!!!!!!!!!!!!!!!!!!!', end_frame_idx, cond_end_frame_idx)
            exit(1)

        return item
    def validate_data(self):
        original_len = len(self.dataset)
        valid_dataset = []
        for video_idx in tqdm(self.dataset):
            video, start_idx = video_idx.split('_')
            frame_path = os.path.join(self.frame_path, video, 'frames')
            start_frame_idx = non_negative(FPS * int(start_idx) / SR)
            end_frame_idx = non_negative(start_frame_idx + FPS * (self.L + 0.6))
            if os.path.exists(os.path.join(frame_path, f'frame{end_frame_idx:0>6d}.jpg')):
                valid_dataset.append(video_idx)
            else:
                self.video2indexes[video].remove(start_idx)
        # iterate over a copy: removing from a list while iterating it skips entries
        for video_idx in list(valid_dataset):
            video, start_idx = video_idx.split('_')
            if len(self.video2indexes[video]) == 1:
                valid_dataset.remove(video_idx)
        if original_len != len(valid_dataset):
            print(f'Validated dataset with enough frames: {len(valid_dataset)}')
        self.dataset = valid_dataset

        split_clip_ids_path = os.path.join(self.splits_path, f'greatesthit_{self.split}_{self.L:.2f}.json')
        if not os.path.exists(split_clip_ids_path):
            with open(split_clip_ids_path, 'w') as f:
                json.dump(valid_dataset, f)
    def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
        random.seed(1337)
        print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')

        # The downloaded videos (some went missing on YouTube and are no longer available)
        available_mel_paths = set(glob(os.path.join(self.specs_dir, '*_mel.npy')))
        self.available_video_hit = [vh for vh in self.available_video_hit if self.video_idx2path[vh] in available_mel_paths]

        all_video = list(self.video2indexes.keys())
        print(f'The number of clips available after download: {len(self.available_video_hit)}')
        print(f'The number of videos available after download: {len(all_video)}')

        available_idx = list(range(len(all_video)))
        random.shuffle(available_idx)
        assert sum(ratio) == 1.
        cut_train = int(ratio[0] * len(all_video))
        cut_test = cut_train + int(ratio[1] * len(all_video))

        train_idx = available_idx[:cut_train]
        test_idx = available_idx[cut_train:cut_test]
        valid_idx = available_idx[cut_test:]

        train_video = [all_video[i] for i in train_idx]
        test_video = [all_video[i] for i in test_idx]
        valid_video = [all_video[i] for i in valid_idx]

        train_video_hit = []
        for v in train_video:
            train_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]
        test_video_hit = []
        for v in test_video:
            test_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]
        valid_video_hit = []
        for v in valid_video:
            valid_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]

        # mix train and valid for a better validation loss
        mixed = train_video_hit + valid_video_hit
        random.shuffle(mixed)
        split = int(len(mixed) * ratio[0] / (ratio[0] + ratio[2]))
        train_video_hit = mixed[:split]
        valid_video_hit = mixed[split:]

        with open(os.path.join(self.splits_path, 'greatesthit_train.json'), 'w') as train_file,\
             open(os.path.join(self.splits_path, 'greatesthit_test.json'), 'w') as test_file,\
             open(os.path.join(self.splits_path, 'greatesthit_valid.json'), 'w') as valid_file:
            json.dump(train_video_hit, train_file)
            json.dump(test_video_hit, test_file)
            json.dump(valid_video_hit, valid_file)

        print(f'Put {len(train_idx)} clips to the train set and saved it to ./data/greatesthit_train.json')
        print(f'Put {len(test_idx)} clips to the test set and saved it to ./data/greatesthit_test.json')
        print(f'Put {len(valid_idx)} clips to the valid set and saved it to ./data/greatesthit_valid.json')


class CondGreatestHitSpecsCondOnImageTrain(CondGreatestHitSpecsCondOnImage):
    def __init__(self, dataset_cfg):
        train_transforms = transforms.Compose([
            Resize3D(256),
            RandomResizedCrop3D(224, scale=(0.5, 1.0)),
            RandomHorizontalFlip3D(),
            ColorJitter3D(brightness=0.1, saturation=0.1),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)


class CondGreatestHitSpecsCondOnImageValidation(CondGreatestHitSpecsCondOnImage):
    def __init__(self, dataset_cfg):
        valid_transforms = transforms.Compose([
            Resize3D(256),
            CenterCrop3D(224),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)


class CondGreatestHitSpecsCondOnImageTest(CondGreatestHitSpecsCondOnImage):
    def __init__(self, dataset_cfg):
        test_transforms = transforms.Compose([
            Resize3D(256),
            CenterCrop3D(224),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
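

# Waveform counterpart of CondGreatestHitSpecsCondOnImage: target and condition
# are L-second audio snippets plus their corresponding frame stacks.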
class CondGreatestHitWaveCondOnImage(torch.utils.data.Dataset):
    def __init__(self, split, wav_dir, spec_len, random_crop, mel_num, spec_crop_len,
                 L=2.0, frame_transforms=None, splits_path='./data',
                 data_path='data/greatesthit/greatesthit-process-resized',
                 p_outside_cond=0., p_audio_aug=0.5, rand_shift=True):
        super().__init__()
        self.split = split
        self.wav_dir = wav_dir
        self.frame_transforms = frame_transforms
        self.splits_path = splits_path
        self.data_path = data_path
        self.spec_len = spec_len
        self.L = L
        self.rand_shift = rand_shift
        self.p_outside_cond = torch.tensor(p_outside_cond)

        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
        if not os.path.exists(split_clip_ids_path):
            raise NotImplementedError()
        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))

        video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) // 2 for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
        self.dataset = clip_video_hit

        self.video2indexes = {}
        for video_idx in self.dataset:
            video, start_idx = video_idx.split('_')
            if video not in self.video2indexes.keys():
                self.video2indexes[video] = []
            self.video2indexes[video].append(start_idx)
        for video in self.video2indexes.keys():
            if len(self.video2indexes[video]) == 1:  # given video contains only one hit
                self.dataset.remove(
                    get_GH_data_identifier(video, self.video2indexes[video][0])
                )

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])
        if self.frame_transforms is None:
            self.frame_transforms = transforms.Compose([
                Resize3D(256),
                RandomResizedCrop3D(224, scale=(0.5, 1.0)),
                RandomHorizontalFlip3D(),
                ColorJitter3D(brightness=0.1, saturation=0.1),
                ToTensor3D(),
                Normalize3D(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video_idx = self.dataset[idx]
        video, start_idx = video_idx.split('_')
        start_idx = int(start_idx)
        frame_path = os.path.join(self.data_path, video, 'frames')
        start_frame_idx = non_negative(FPS * int(start_idx) / SR)
        if self.rand_shift:
            shift = random.uniform(-0.5, 0.5)
            start_frame_idx = non_negative(start_frame_idx + int(FPS * shift))
            start_idx = non_negative(start_idx + int(SR * shift))
        if start_frame_idx > self.video_frame_cnt[video] - self.left_over:
            start_frame_idx = self.video_frame_cnt[video] - self.left_over
            start_idx = non_negative(SR * (start_frame_idx / FPS))
        end_frame_idx = non_negative(start_frame_idx + FPS * self.L)

        # target
        wave_path = self.video_audio_path[video]
        frames = [Image.open(os.path.join(
            frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
            range(start_frame_idx, end_frame_idx)]
        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
        assert sr == SR
        wav = self.wav_transforms(wav)

        # cond
        if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
            # (random.sample needs a sequence, not a set, on Python >= 3.11)
            all_idx = set(list(range(len(self.dataset))))
            all_idx.remove(idx)
            cond_video_idx = self.dataset[sample(sorted(all_idx), k=1)[0]]
            cond_video, cond_start_idx = cond_video_idx.split('_')
        else:
            cond_video = video
            video_hits_idx = copy.copy(self.video2indexes[video])
            if str(start_idx) in video_hits_idx:
                video_hits_idx.remove(str(start_idx))
            cond_start_idx = sample(video_hits_idx, k=1)[0]
            cond_video_idx = get_GH_data_identifier(cond_video, cond_start_idx)

        cond_video, cond_start_idx = cond_video_idx.split('_')
        cond_start_idx = int(cond_start_idx)
        cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
        cond_start_frame_idx = non_negative(FPS * int(cond_start_idx) / SR)
        cond_wave_path = self.video_audio_path[cond_video]

        if self.rand_shift:
            cond_shift = random.uniform(-0.5, 0.5)
            cond_start_frame_idx = non_negative(cond_start_frame_idx + int(FPS * cond_shift))
            cond_start_idx = non_negative(cond_start_idx + int(cond_shift * SR))
        if cond_start_frame_idx > self.video_frame_cnt[cond_video] - self.left_over:
            cond_start_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
            cond_start_idx = non_negative(SR * (cond_start_frame_idx / FPS))
        cond_end_frame_idx = non_negative(cond_start_frame_idx + FPS * self.L)

        cond_frames = [Image.open(os.path.join(
            cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
            range(cond_start_frame_idx, cond_end_frame_idx)]
        cond_wav, _ = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_start_idx)
        cond_wav = self.wav_transforms(cond_wav)

        item['image'] = wav  # (44100,)
        item['cond_image'] = cond_wav  # (44100,)
        item['file_path_wav_'] = wave_path
        item['file_path_cond_wav_'] = cond_wave_path

        if self.frame_transforms is not None:
            cond_frames = self.frame_transforms(cond_frames)
            frames = self.frame_transforms(frames)

        item['feature'] = np.stack(cond_frames + frames, axis=0)  # (30 * L, 112, 112, 3)
        item['file_path_feats_'] = (frame_path, start_idx)
        item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)

        item['label'] = 'None'
        item['target'] = 'None'
        return item

    def validate_data(self):
        raise NotImplementedError()

    def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
        random.seed(1337)
        print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')

        all_video = sorted(os.listdir(self.data_path))
        print(f'The number of videos available after download: {len(all_video)}')

        available_idx = list(range(len(all_video)))
        random.shuffle(available_idx)
        assert sum(ratio) == 1.
        cut_train = int(ratio[0] * len(all_video))
        cut_test = cut_train + int(ratio[1] * len(all_video))

        train_idx = available_idx[:cut_train]
        test_idx = available_idx[cut_train:cut_test]
        valid_idx = available_idx[cut_test:]

        train_video = [all_video[i] for i in train_idx]
        test_video = [all_video[i] for i in test_idx]
        valid_video = [all_video[i] for i in valid_idx]

        with open(os.path.join(self.splits_path, 'greatesthit_video_train.json'), 'w') as train_file,\
             open(os.path.join(self.splits_path, 'greatesthit_video_test.json'), 'w') as test_file,\
             open(os.path.join(self.splits_path, 'greatesthit_video_valid.json'), 'w') as valid_file:
            json.dump(train_video, train_file)
            json.dump(test_video, test_file)
            json.dump(valid_video, valid_file)

        print(f'Put {len(train_idx)} videos to the train set and saved it to ./data/greatesthit_video_train.json')
        print(f'Put {len(test_idx)} videos to the test set and saved it to ./data/greatesthit_video_test.json')
        print(f'Put {len(valid_idx)} videos to the valid set and saved it to ./data/greatesthit_video_valid.json')


class CondGreatestHitWaveCondOnImageTrain(CondGreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        train_transforms = transforms.Compose([
            Resize3D(128),
            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
            RandomHorizontalFlip3D(),
            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)


class CondGreatestHitWaveCondOnImageValidation(CondGreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        valid_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)


class CondGreatestHitWaveCondOnImageTest(CondGreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        test_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
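

# Unconditional waveform dataset: one L-second audio snippet and its frame
# stack per hit, with no conditioning example.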
class GreatestHitWaveCondOnImage(torch.utils.data.Dataset):
    def __init__(self, split, wav_dir, spec_len, random_crop, mel_num, spec_crop_len,
                 L=2.0, frame_transforms=None, splits_path='./data',
                 data_path='data/greatesthit/greatesthit-process-resized',
                 p_outside_cond=0., p_audio_aug=0.5, rand_shift=True):
        super().__init__()
        self.split = split
        self.wav_dir = wav_dir
        self.frame_transforms = frame_transforms
        self.splits_path = splits_path
        self.data_path = data_path
        self.spec_len = spec_len
        self.L = L
        self.rand_shift = rand_shift
        self.p_outside_cond = torch.tensor(p_outside_cond)

        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
        if not os.path.exists(split_clip_ids_path):
            raise NotImplementedError()
        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))

        video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) // 2 for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
        self.dataset = clip_video_hit

        self.video2indexes = {}
        for video_idx in self.dataset:
            video, start_idx = video_idx.split('_')
            if video not in self.video2indexes.keys():
                self.video2indexes[video] = []
            self.video2indexes[video].append(start_idx)
        for video in self.video2indexes.keys():
            if len(self.video2indexes[video]) == 1:  # given video contains only one hit
                self.dataset.remove(
                    get_GH_data_identifier(video, self.video2indexes[video][0])
                )

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])
        if self.frame_transforms is None:
            self.frame_transforms = transforms.Compose([
                Resize3D(256),
                RandomResizedCrop3D(224, scale=(0.5, 1.0)),
                RandomHorizontalFlip3D(),
                ColorJitter3D(brightness=0.1, saturation=0.1),
                ToTensor3D(),
                Normalize3D(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video_idx = self.dataset[idx]
        video, start_idx = video_idx.split('_')
        start_idx = int(start_idx)
        frame_path = os.path.join(self.data_path, video, 'frames')
        start_frame_idx = non_negative(FPS * int(start_idx) / SR)
        if self.rand_shift:
            shift = random.uniform(-0.5, 0.5)
            start_frame_idx = non_negative(start_frame_idx + int(FPS * shift))
            start_idx = non_negative(start_idx + int(SR * shift))
        if start_frame_idx > self.video_frame_cnt[video] - self.left_over:
            start_frame_idx = self.video_frame_cnt[video] - self.left_over
            start_idx = non_negative(SR * (start_frame_idx / FPS))
        end_frame_idx = non_negative(start_frame_idx + FPS * self.L)

        # target
        wave_path = self.video_audio_path[video]
        frames = [Image.open(os.path.join(
            frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
            range(start_frame_idx, end_frame_idx)]
        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
        assert sr == SR
        wav = self.wav_transforms(wav)

        item['image'] = wav  # (44100,)
        item['file_path_wav_'] = wave_path

        if self.frame_transforms is not None:
            frames = self.frame_transforms(frames)

        item['feature'] = torch.stack(frames, dim=0)  # (15 * L, 112, 112, 3)
        item['file_path_feats_'] = (frame_path, start_idx)

        item['label'] = 'None'
        item['target'] = 'None'
        return item

    def validate_data(self):
        raise NotImplementedError()

    def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
        random.seed(1337)
        print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')

        all_video = sorted(os.listdir(self.data_path))
        print(f'The number of videos available after download: {len(all_video)}')

        available_idx = list(range(len(all_video)))
        random.shuffle(available_idx)
        assert sum(ratio) == 1.
        cut_train = int(ratio[0] * len(all_video))
        cut_test = cut_train + int(ratio[1] * len(all_video))

        train_idx = available_idx[:cut_train]
        test_idx = available_idx[cut_train:cut_test]
        valid_idx = available_idx[cut_test:]

        train_video = [all_video[i] for i in train_idx]
        test_video = [all_video[i] for i in test_idx]
        valid_video = [all_video[i] for i in valid_idx]

        with open(os.path.join(self.splits_path, 'greatesthit_video_train.json'), 'w') as train_file,\
             open(os.path.join(self.splits_path, 'greatesthit_video_test.json'), 'w') as test_file,\
             open(os.path.join(self.splits_path, 'greatesthit_video_valid.json'), 'w') as valid_file:
            json.dump(train_video, train_file)
            json.dump(test_video, test_file)
            json.dump(valid_video, valid_file)

        print(f'Put {len(train_idx)} videos to the train set and saved it to ./data/greatesthit_video_train.json')
        print(f'Put {len(test_idx)} videos to the test set and saved it to ./data/greatesthit_video_test.json')
        print(f'Put {len(valid_idx)} videos to the valid set and saved it to ./data/greatesthit_video_valid.json')


class GreatestHitWaveCondOnImageTrain(GreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        train_transforms = transforms.Compose([
            Resize3D(128),
            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
            RandomHorizontalFlip3D(),
            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)


class GreatestHitWaveCondOnImageValidation(GreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        valid_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)


class GreatestHitWaveCondOnImageTest(GreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        test_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
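

# Helper to dump a spectrogram to an image file for quick visual inspection.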
def draw_spec(spec, dest, cmap='magma'):
    plt.imshow(spec, cmap=cmap, origin='lower')
    plt.axis('off')
    plt.savefig(dest, bbox_inches='tight', pad_inches=0., dpi=300)
    plt.close()
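

# Minimal smoke test: builds the datasets from a config and prints one item.
# Assumes the referenced YAML config and the local data layout exist.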
if __name__ == '__main__':
    from omegaconf import OmegaConf

    # cfg = OmegaConf.load('configs/greatesthit_transformer_with_vNet_randshift_2s_GH_vqgan_no_earlystop.yaml')
    cfg = OmegaConf.load('configs/greatesthit_codebook.yaml')
    data = instantiate_from_config(cfg.data)
    data.prepare_data()
    data.setup()
    print(len(data.datasets['train']))
    print(data.datasets['train'][24])