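"""Greatest Hits dataset for the onset-detection baseline.

Loads 2-second clips (15 fps frames plus audio) around annotated hit times
and produces per-frame onset labels via librosa onset detection.
"""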
try:
    from data import *
except ImportError:
    from foleycrafter.models.specvqgan.onset_baseline.data import *
import pdb
import csv
import glob
import h5py
import io
import json
import librosa
import numpy as np
import os
import pickle
from PIL import Image
from PIL import ImageFilter
import random
import scipy
import soundfile as sf
import time
from tqdm import tqdm

import cv2
import torch
import torch.nn as nn
import torchaudio
import torchvision.transforms as transforms
# import kornia as K

import sys
sys.path.append('..')


class GreatestHitDataset(object):
    def __init__(self, args, split='train'):
        self.split = split
        if split == 'train':
            list_sample = './data/greatesthit_train_2.00.json'
        elif split == 'val':
            list_sample = './data/greatesthit_valid_2.00.json'
        elif split == 'test':
            list_sample = './data/greatesthit_test_2.00.json'
        else:
            raise ValueError(f'Unknown split: {split}')

        # save args parameters
        self.repeat = args.repeat if split == 'train' else 1
        self.max_sample = args.max_sample
        self.video_transform = transforms.Compose(
            self.generate_video_transform(args))

        if isinstance(list_sample, str):
            with open(list_sample, 'r') as f:
                self.list_sample = json.load(f)

        if self.max_sample > 0:
            self.list_sample = self.list_sample[0:self.max_sample]
        self.list_sample = self.list_sample * self.repeat

        random.seed(1234)
        np.random.seed(1234)
        num_sample = len(self.list_sample)
        if self.split == 'train':
            random.shuffle(self.list_sample)

        # self.class_dist = self.unbalanced_dist()
        print('Greatesthit Dataloader: # sample of {}: {}'.format(self.split, num_sample))

    def __getitem__(self, index):
        # Sample ids look like '<video_id>_<hit_sample_index>'.
        info = self.list_sample[index].split('_')[0]
        video_path = os.path.join('data', 'greatesthit', 'greatesthit_processed', info)
        frame_path = os.path.join(video_path, 'frames')
        audio_path = os.path.join(video_path, 'audio')
        audio_path = glob.glob(f'{audio_path}/*.wav')[0]

        # Unused; consider removing.
        meta_path = os.path.join(video_path, 'hit_record.json')
        if os.path.exists(meta_path):
            with open(meta_path, 'r') as f:
                meta_dict = json.load(f)

        # Read a short stub only to obtain the audio sample rate.
        _, audio_sample_rate = sf.read(audio_path, start=0, stop=1000, dtype='float64', always_2d=True)

        frame_rate = 15
        duration = 2.0
        frame_list = glob.glob(f'{frame_path}/*.jpg')
        frame_list.sort()

        # The hit time is stored as a sample index at 22050 Hz.
        hit_time = float(self.list_sample[index].split('_')[-1]) / 22050
        if self.split == 'train':
            # Jitter the window start by -5..+4 frames around the hit.
            frame_start = hit_time * frame_rate + np.random.randint(10) - 5
        else:
            frame_start = hit_time * frame_rate
        frame_start = max(frame_start, 0)
        frame_start = min(frame_start, len(frame_list) - duration * frame_rate)
        frame_start = int(frame_start)

        frame_list = frame_list[frame_start: int(
            frame_start + np.ceil(duration * frame_rate))]
        audio_start = int(frame_start / frame_rate * audio_sample_rate)
        audio_end = int(audio_start + duration * audio_sample_rate)

        imgs = self.read_image(frame_list)
        audio, audio_rate = sf.read(audio_path, start=audio_start, stop=audio_end, dtype='float64', always_2d=True)
        audio = audio.mean(-1)

        # Detect onsets in the clip and map them to frame indices, clipped
        # to the last frame of the 2 s window (30 frames at 15 fps).
        onsets = librosa.onset.onset_detect(y=audio, sr=audio_rate, units='time', delta=0.3)
        onsets = np.rint(onsets * frame_rate).astype(int)
        onsets[onsets > 29] = 29
        label = torch.zeros(len(frame_list))
        label[onsets] = 1

        batch = {
            'frames': imgs,
            'label': label
        }
        return batch

    def getitem_test(self, index):
        return self.__getitem__(index)

    def __len__(self):
        return len(self.list_sample)

    def read_image(self, frame_list):
        imgs = []
        convert_tensor = transforms.ToTensor()
        for img_path in frame_list:
            image = Image.open(img_path).convert('RGB')
            image = convert_tensor(image)
            imgs.append(image.unsqueeze(0))
        # (T, C, H, W)
        imgs = torch.cat(imgs, dim=0).squeeze()
        imgs = self.video_transform(imgs)
        imgs = imgs.permute(1, 0, 2, 3)
        # (C, T, H, W)
        return imgs

    def generate_video_transform(self, args):
        resize_funct = transforms.Resize((128, 128))
        if self.split == 'train':
            crop_funct = transforms.RandomCrop((112, 112))
            color_funct = transforms.ColorJitter(
                brightness=0.1, contrast=0.1, saturation=0, hue=0)
        else:
            crop_funct = transforms.CenterCrop((112, 112))
            color_funct = transforms.Lambda(lambda img: img)

        vision_transform_list = [
            resize_funct,
            crop_funct,
            color_funct,
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]
        return vision_transform_list
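

# A minimal usage sketch, not part of the original file: it assumes an
# argparse-style namespace with the `repeat` and `max_sample` fields the
# constructor reads, plus the ./data JSON lists and processed videos on disk.
if __name__ == '__main__':
    import argparse
    from torch.utils.data import DataLoader

    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat', type=int, default=1)
    parser.add_argument('--max_sample', type=int, default=-1)  # -1 keeps all samples
    args = parser.parse_args()

    dataset = GreatestHitDataset(args, split='val')
    loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2)
    batch = next(iter(loader))
    # frames: (B, C, T, H, W); label: (B, T) with 1s at onset frames.
    print(batch['frames'].shape, batch['label'].shape)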