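"""GreatestHit onset dataloader for the SpecVQGAN onset baseline in FoleyCrafter.

Each sample is a 2-second window of video frames plus mono audio from the
Greatest Hits dataset; per-frame binary onset labels are derived with
librosa's onset detector.
"""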
try:
    from data import *
except ImportError:
    from foleycrafter.models.specvqgan.onset_baseline.data import *
import glob
import json
import os
import random
import sys

import librosa
import numpy as np
import soundfile as sf
from PIL import Image

import torch
import torchvision.transforms as transforms

sys.path.append('..')


class GreatestHitDataset(torch.utils.data.Dataset):
    def __init__(self, args, split='train'):
        self.split = split
        if split == 'train':
            list_sample = './data/greatesthit_train_2.00.json'
        elif split == 'val':
            list_sample = './data/greatesthit_valid_2.00.json'
        elif split == 'test':
            list_sample = './data/greatesthit_test_2.00.json'
        else:
            raise ValueError(f'Unknown split: {split}')

        # Cache the args fields this loader uses.
        self.repeat = args.repeat if split == 'train' else 1
        self.max_sample = args.max_sample

        self.video_transform = transforms.Compose(
            self.generate_video_transform(args))

        with open(list_sample, "r") as f:
            self.list_sample = json.load(f)

        if self.max_sample > 0:
            self.list_sample = self.list_sample[0:self.max_sample]
        self.list_sample = self.list_sample * self.repeat

        random.seed(1234)
        np.random.seed(1234)
        num_sample = len(self.list_sample)
        if self.split == 'train':
            random.shuffle(self.list_sample)

        print('Greatesthit Dataloader: # sample of {}: {}'.format(self.split, num_sample))

    def __getitem__(self, index):
        # Sample ids look like '<video_id>_<hit_sample_offset>'.
        info = self.list_sample[index].split('_')[0]
        video_path = os.path.join('data', 'greatesthit', 'greatesthit_processed', info)
        frame_path = os.path.join(video_path, 'frames')
        audio_path = os.path.join(video_path, 'audio')
        audio_path = glob.glob(f"{audio_path}/*.wav")[0]

        # Query the wav header for the sample rate; the clip itself is read
        # below once the window is known.
        audio_sample_rate = sf.info(audio_path).samplerate
        frame_rate = 15   # frames were extracted at 15 fps
        duration = 2.0    # each clip covers 2 seconds (30 frames)
        frame_list = sorted(glob.glob(f'{frame_path}/*.jpg'))

        # The hit offset in the sample id is given in samples at 22050 Hz.
        hit_time = float(self.list_sample[index].split('_')[-1]) / 22050
        frame_start = hit_time * frame_rate
        if self.split == 'train':
            # Randomly jitter the window start by -5..+4 frames during training.
            frame_start += np.random.randint(10) - 5
        # Clamp so the 2 s window stays inside the available frames.
        frame_start = max(frame_start, 0)
        frame_start = min(frame_start, len(frame_list) - duration * frame_rate)
        frame_start = int(frame_start)

        frame_list = frame_list[frame_start: int(
            frame_start + np.ceil(duration * frame_rate))]
        audio_start = int(frame_start / frame_rate * audio_sample_rate)
        audio_end = int(audio_start + duration * audio_sample_rate)

        imgs = self.read_image(frame_list)
        audio, audio_rate = sf.read(audio_path, start=audio_start, stop=audio_end, dtype='float64', always_2d=True)
        audio = audio.mean(-1)  # downmix to mono

        # Per-frame binary onset labels: detect onsets in the clip and mark
        # the video frame nearest each onset time.
        onsets = librosa.onset.onset_detect(y=audio, sr=audio_rate, units='time', delta=0.3)
        onsets = np.rint(onsets * frame_rate).astype(int)
        onsets = np.minimum(onsets, len(frame_list) - 1)  # clamp into range
        label = torch.zeros(len(frame_list))
        label[onsets] = 1

        batch = {
            'frames': imgs,
            'label': label
        }
        return batch

    def getitem_test(self, index):
        return self.__getitem__(index)

    def __len__(self):
        return len(self.list_sample)

    def read_image(self, frame_list):
        convert_tensor = transforms.ToTensor()
        imgs = [convert_tensor(Image.open(p).convert('RGB')) for p in frame_list]
        # (T, C, H, W)
        imgs = torch.stack(imgs, dim=0)
        imgs = self.video_transform(imgs)
        # (C, T, H, W)
        imgs = imgs.permute(1, 0, 2, 3)
        return imgs

    def generate_video_transform(self, args):
        resize_funct = transforms.Resize((128, 128))
        if self.split == 'train':
            crop_funct = transforms.RandomCrop((112, 112))
            color_funct = transforms.ColorJitter(
                brightness=0.1, contrast=0.1, saturation=0, hue=0)
        else:
            crop_funct = transforms.CenterCrop((112, 112))
            color_funct = transforms.Lambda(lambda img: img)  # identity at eval

        vision_transform_list = [
            resize_funct,
            crop_funct,
            color_funct,
            # ImageNet mean/std normalization
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]
        return vision_transform_list
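

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the original
# module): it assumes the ./data/greatesthit_*.json lists and the processed
# frame/audio layout above exist. `args` only needs the two fields that
# __init__ reads (`repeat`, `max_sample`).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(repeat=1, max_sample=-1)  # max_sample <= 0 keeps all samples
    dataset = GreatestHitDataset(args, split='val')
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)

    batch = next(iter(loader))
    # frames: (B, C, T, H, W) after read_image's permute; label: (B, T)
    print(batch['frames'].shape, batch['label'].shape)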