Spaces:
Runtime error
Runtime error
waidhoferj
commited on
Commit
·
3b31903
1
Parent(s):
3a0f0a5
updated packages
Browse files- .gitignore +1 -0
- TODO.md +4 -0
- environment.yml +1 -1
- models/config/train_local.yaml +2 -2
- models/decision_tree.py +62 -39
- preprocessing/dataset.py +165 -104
- train.py +79 -54
.gitignore
CHANGED
@@ -8,3 +8,4 @@ scrapers/auth
|
|
8 |
lightning_logs
|
9 |
.lr_find_*
|
10 |
.cache
|
|
|
|
8 |
lightning_logs
|
9 |
.lr_find_*
|
10 |
.cache
|
11 |
+
.vscode
|
TODO.md
CHANGED
@@ -9,3 +9,7 @@
|
|
9 |
- Read the Medium series about audio DL
|
10 |
- double check \_rectify_duration
|
11 |
- ✅ Filter out songs that have only one vote
|
|
|
|
|
|
|
|
|
|
9 |
- Read the Medium series about audio DL
|
10 |
- double check \_rectify_duration
|
11 |
- ✅ Filter out songs that have only one vote
|
12 |
+
|
13 |
+
## Notes
|
14 |
+
|
15 |
+
2xM60 insufficient memory.
|
environment.yml
CHANGED
@@ -22,7 +22,7 @@ dependencies:
|
|
22 |
- rich
|
23 |
- scikit-learn
|
24 |
- tensorboard
|
|
|
25 |
- pip:
|
26 |
-
- git+https://github.com/huggingface/transformers.git
|
27 |
- evaluate
|
28 |
- wakepy
|
|
|
22 |
- rich
|
23 |
- scikit-learn
|
24 |
- tensorboard
|
25 |
+
- transformers
|
26 |
- pip:
|
|
|
27 |
- evaluate
|
28 |
- wakepy
|
models/config/train_local.yaml
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
global:
|
2 |
-
id: decision_tree
|
3 |
device: mps
|
4 |
seed: 42
|
5 |
dance_ids:
|
@@ -22,7 +22,7 @@ data_module:
|
|
22 |
song_data_path: data/songs_cleaned.csv
|
23 |
song_audio_path: data/samples
|
24 |
batch_size: 32
|
25 |
-
num_workers:
|
26 |
min_votes: 1
|
27 |
dataset_kwargs:
|
28 |
audio_window_duration: 6
|
|
|
1 |
global:
|
2 |
+
id: ast_ptl # decision_tree
|
3 |
device: mps
|
4 |
seed: 42
|
5 |
dance_ids:
|
|
|
22 |
song_data_path: data/songs_cleaned.csv
|
23 |
song_audio_path: data/samples
|
24 |
batch_size: 32
|
25 |
+
num_workers: 7
|
26 |
min_votes: 1
|
27 |
dataset_kwargs:
|
28 |
audio_window_duration: 6
|
models/decision_tree.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
from sklearn.base import ClassifierMixin, BaseEstimator
|
3 |
import pandas as pd
|
4 |
from torch import nn
|
@@ -8,10 +7,15 @@ import numpy as np
|
|
8 |
import json
|
9 |
from tqdm import tqdm
|
10 |
import librosa
|
|
|
11 |
DANCE_INFO_FILE = "data/dance_info.csv"
|
12 |
-
dance_info_df = pd.read_csv(
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
class DanceTreeClassifier(BaseEstimator,ClassifierMixin):
|
15 |
"""
|
16 |
Trains a series of binary classifiers to classify each dance when a song falls into its bpm range.
|
17 |
|
@@ -21,20 +25,20 @@ class DanceTreeClassifier(BaseEstimator,ClassifierMixin):
|
|
21 |
"""
|
22 |
|
23 |
def __init__(self, device="cpu", lr=1e-4, epochs=5, verbose=True) -> None:
|
24 |
-
self.device=device
|
25 |
-
self.epochs=epochs
|
26 |
self.verbose = verbose
|
27 |
self.lr = lr
|
28 |
self.classifiers = {}
|
29 |
self.optimizers = {}
|
30 |
self.criterion = nn.BCELoss()
|
31 |
|
32 |
-
def get_valid_dances_from_bpm(self,bpm:float) -> list[str]:
|
33 |
-
mask = dance_info_df["tempoRange"].apply(
|
|
|
|
|
34 |
return list(dance_info_df["id"][mask])
|
35 |
|
36 |
-
|
37 |
-
|
38 |
def fit(self, x, y):
|
39 |
"""
|
40 |
x: (specs, bpms). The first element is the spectrogram, second element is the bpm. spec shape should be (channel, freq_bins, sr * time)
|
@@ -45,57 +49,73 @@ class DanceTreeClassifier(BaseEstimator,ClassifierMixin):
|
|
45 |
# TODO: Introduce batches
|
46 |
epoch_loss = 0
|
47 |
pred_count = 0
|
|
|
48 |
for (spec, bpm), label in zip(x, y):
|
|
|
49 |
# find all models that are in the bpm range
|
50 |
matching_dances = self.get_valid_dances_from_bpm(bpm)
|
|
|
51 |
for dance in matching_dances:
|
52 |
if dance not in self.classifiers or dance not in self.optimizers:
|
53 |
-
classifier = DanceCNN()
|
54 |
self.classifiers[dance] = classifier
|
55 |
-
self.optimizers[dance] = torch.optim.Adam(
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
opt.zero_grad()
|
59 |
-
spec = torch.from_numpy(spec).to(self.device)
|
60 |
output = model(spec)
|
61 |
-
target = torch.tensor(float(dance == label))
|
62 |
loss = self.criterion(output, target)
|
63 |
epoch_loss += loss.item()
|
64 |
-
pred_count +=1
|
65 |
loss.backward()
|
66 |
opt.step()
|
67 |
-
|
|
|
|
|
68 |
|
69 |
def predict(self, x) -> list[str]:
|
70 |
results = []
|
71 |
for spec, bpm in zip(*x):
|
72 |
matching_dances = self.get_valid_dances_from_bpm(bpm)
|
73 |
-
dance_i = torch.tensor(
|
|
|
|
|
74 |
results.append(matching_dances[dance_i])
|
75 |
return results
|
76 |
|
77 |
-
|
78 |
-
|
79 |
|
80 |
class DanceCNN(nn.Module):
|
81 |
def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
|
82 |
super().__init__(*args, **kwargs)
|
83 |
-
kernel_size=(3,9)
|
84 |
self.cnn = nn.Sequential(
|
85 |
-
nn.Conv2d(1,16, kernel_size=kernel_size),
|
86 |
nn.ReLU(),
|
87 |
-
nn.MaxPool2d((2,10)),
|
88 |
-
nn.Conv2d(16,32, kernel_size=kernel_size),
|
89 |
nn.ReLU(),
|
90 |
-
nn.MaxPool2d((2,10))
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
)
|
92 |
|
93 |
-
embedding_dimension =
|
94 |
self.classifier = nn.Sequential(
|
95 |
nn.Linear(embedding_dimension, 200),
|
96 |
nn.ReLU(),
|
97 |
nn.Linear(200, 1),
|
98 |
-
nn.Sigmoid()
|
99 |
)
|
100 |
|
101 |
def forward(self, x):
|
@@ -103,22 +123,25 @@ class DanceCNN(nn.Module):
|
|
103 |
x = x.flatten() if len(x.shape) == 3 else x.flatten(1)
|
104 |
return self.classifier(x)
|
105 |
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
"""
|
111 |
Loads audio and bpm from an audio path.
|
112 |
"""
|
113 |
-
|
114 |
for path in paths:
|
115 |
waveform, sr = librosa.load(path, mono=True, sr=resample_freq)
|
116 |
-
num_frames =
|
117 |
tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
|
|
|
118 |
mfccs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=20)
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
1 |
from sklearn.base import ClassifierMixin, BaseEstimator
|
2 |
import pandas as pd
|
3 |
from torch import nn
|
|
|
7 |
import json
|
8 |
from tqdm import tqdm
|
9 |
import librosa
|
10 |
+
|
11 |
DANCE_INFO_FILE = "data/dance_info.csv"
|
12 |
+
dance_info_df = pd.read_csv(
|
13 |
+
DANCE_INFO_FILE,
|
14 |
+
converters={"tempoRange": lambda s: json.loads(s.replace("'", '"'))},
|
15 |
+
)
|
16 |
+
|
17 |
|
18 |
+
class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
|
19 |
"""
|
20 |
Trains a series of binary classifiers to classify each dance when a song falls into its bpm range.
|
21 |
|
|
|
25 |
"""
|
26 |
|
27 |
def __init__(self, device="cpu", lr=1e-4, epochs=5, verbose=True) -> None:
|
28 |
+
self.device = device
|
29 |
+
self.epochs = epochs
|
30 |
self.verbose = verbose
|
31 |
self.lr = lr
|
32 |
self.classifiers = {}
|
33 |
self.optimizers = {}
|
34 |
self.criterion = nn.BCELoss()
|
35 |
|
36 |
+
def get_valid_dances_from_bpm(self, bpm: float) -> list[str]:
|
37 |
+
mask = dance_info_df["tempoRange"].apply(
|
38 |
+
lambda interval: interval["min"] <= bpm <= interval["max"]
|
39 |
+
)
|
40 |
return list(dance_info_df["id"][mask])
|
41 |
|
|
|
|
|
42 |
def fit(self, x, y):
|
43 |
"""
|
44 |
x: (specs, bpms). The first element is the spectrogram, second element is the bpm. spec shape should be (channel, freq_bins, sr * time)
|
|
|
49 |
# TODO: Introduce batches
|
50 |
epoch_loss = 0
|
51 |
pred_count = 0
|
52 |
+
step = 0
|
53 |
for (spec, bpm), label in zip(x, y):
|
54 |
+
step += 1
|
55 |
# find all models that are in the bpm range
|
56 |
matching_dances = self.get_valid_dances_from_bpm(bpm)
|
57 |
+
spec = torch.from_numpy(spec).to(self.device)
|
58 |
for dance in matching_dances:
|
59 |
if dance not in self.classifiers or dance not in self.optimizers:
|
60 |
+
classifier = DanceCNN().to(self.device)
|
61 |
self.classifiers[dance] = classifier
|
62 |
+
self.optimizers[dance] = torch.optim.Adam(
|
63 |
+
classifier.parameters(), lr=self.lr
|
64 |
+
)
|
65 |
+
models = [
|
66 |
+
(dance, model, self.optimizers[dance])
|
67 |
+
for dance, model in self.classifiers.items()
|
68 |
+
if dance in matching_dances
|
69 |
+
]
|
70 |
+
for model_i, (dance, model, opt) in enumerate(models):
|
71 |
opt.zero_grad()
|
|
|
72 |
output = model(spec)
|
73 |
+
target = torch.tensor([float(dance == label)], device=self.device)
|
74 |
loss = self.criterion(output, target)
|
75 |
epoch_loss += loss.item()
|
76 |
+
pred_count += 1
|
77 |
loss.backward()
|
78 |
opt.step()
|
79 |
+
progress_bar.set_description(
|
80 |
+
f"Loss: {epoch_loss / pred_count}, Step: {step}, Model: {model_i+1}/{len(models)}"
|
81 |
+
)
|
82 |
|
83 |
def predict(self, x) -> list[str]:
|
84 |
results = []
|
85 |
for spec, bpm in zip(*x):
|
86 |
matching_dances = self.get_valid_dances_from_bpm(bpm)
|
87 |
+
dance_i = torch.tensor(
|
88 |
+
[self.classifiers[dance](spec) for dance in matching_dances]
|
89 |
+
).argmax()
|
90 |
results.append(matching_dances[dance_i])
|
91 |
return results
|
92 |
|
|
|
|
|
93 |
|
94 |
class DanceCNN(nn.Module):
|
95 |
def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
|
96 |
super().__init__(*args, **kwargs)
|
97 |
+
kernel_size = (3, 9)
|
98 |
self.cnn = nn.Sequential(
|
99 |
+
nn.Conv2d(1, 16, kernel_size=kernel_size),
|
100 |
nn.ReLU(),
|
101 |
+
nn.MaxPool2d((2, 10)),
|
102 |
+
nn.Conv2d(16, 32, kernel_size=kernel_size),
|
103 |
nn.ReLU(),
|
104 |
+
nn.MaxPool2d((2, 10)),
|
105 |
+
nn.Conv2d(32, 32, kernel_size=kernel_size),
|
106 |
+
nn.ReLU(),
|
107 |
+
nn.MaxPool2d((2, 10)),
|
108 |
+
nn.Conv2d(32, 16, kernel_size=kernel_size),
|
109 |
+
nn.ReLU(),
|
110 |
+
nn.MaxPool2d((2, 10)),
|
111 |
)
|
112 |
|
113 |
+
embedding_dimension = 16 * 6 * 8
|
114 |
self.classifier = nn.Sequential(
|
115 |
nn.Linear(embedding_dimension, 200),
|
116 |
nn.ReLU(),
|
117 |
nn.Linear(200, 1),
|
118 |
+
nn.Sigmoid(),
|
119 |
)
|
120 |
|
121 |
def forward(self, x):
|
|
|
123 |
x = x.flatten() if len(x.shape) == 3 else x.flatten(1)
|
124 |
return self.classifier(x)
|
125 |
|
126 |
+
|
127 |
+
def features_from_path(
|
128 |
+
paths: list[str], audio_window_duration=6, audio_duration=30, resample_freq=16000
|
129 |
+
) -> Iterator[tuple[np.array, float]]:
|
130 |
"""
|
131 |
Loads audio and bpm from an audio path.
|
132 |
"""
|
133 |
+
|
134 |
for path in paths:
|
135 |
waveform, sr = librosa.load(path, mono=True, sr=resample_freq)
|
136 |
+
num_frames = audio_window_duration * sr
|
137 |
tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
|
138 |
+
spec = librosa.feature.melspectrogram(y=waveform, sr=sr)
|
139 |
mfccs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=20)
|
140 |
+
spec_normalized = (spec - spec.mean()) / spec.std()
|
141 |
+
spec_padded = librosa.util.fix_length(
|
142 |
+
spec_normalized, size=sr * audio_duration, axis=1
|
143 |
+
)
|
144 |
+
batched_spec = np.expand_dims(spec_padded, axis=0)
|
145 |
+
for i in range(audio_duration // audio_window_duration):
|
146 |
+
spec_window = batched_spec[:, :, i * num_frames : (i + 1) * num_frames]
|
147 |
+
yield (spec_window, tempo)
|
preprocessing/dataset.py
CHANGED
@@ -12,19 +12,23 @@ from torch import nn
|
|
12 |
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
|
13 |
|
14 |
|
15 |
-
|
16 |
class SongDataset(Dataset):
|
17 |
-
def __init__(
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
25 |
):
|
26 |
-
assert
|
27 |
-
|
|
|
|
|
|
|
|
|
28 |
|
29 |
self.audio_paths = audio_paths
|
30 |
self.dance_labels = dance_labels
|
@@ -34,14 +38,21 @@ class SongDataset(Dataset):
|
|
34 |
self.audio_window_jitter = audio_window_jitter
|
35 |
self.audio_duration = int(audio_duration)
|
36 |
|
37 |
-
self.audio_pipeline = AudioTrainingPipeline(
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
def __len__(self):
|
40 |
-
return len(self.audio_paths) * self.audio_duration // self.audio_window_duration
|
41 |
|
42 |
-
def __getitem__(self, idx:int) -> tuple[torch.Tensor, torch.Tensor]:
|
43 |
waveform = self._waveform_from_index(idx)
|
44 |
-
assert
|
|
|
|
|
45 |
spectrogram = self.audio_pipeline(waveform)
|
46 |
|
47 |
dance_labels = self._label_from_index(idx)
|
@@ -53,206 +64,256 @@ class SongDataset(Dataset):
|
|
53 |
# Try the previous one
|
54 |
# This happens when some of the audio recordings are really quiet
|
55 |
# This WILL NOT leak into other data partitions because songs belong entirely to a partition
|
56 |
-
return self[idx-1]
|
57 |
|
58 |
-
def _convert_idx(self,idx:int) -> int:
|
59 |
return idx * self.audio_window_duration // self.audio_duration
|
60 |
|
61 |
-
def _backtrace_audio_path(self, index:int) -> str:
|
62 |
return self.audio_paths[self._convert_idx(index)]
|
63 |
|
64 |
-
def _validate_output(self,x,y):
|
65 |
-
is_finite =
|
66 |
is_numerical = not torch.any(torch.isnan(x))
|
67 |
has_data = torch.any(x != 0.0)
|
68 |
is_binary = len(torch.unique(y)) < 3
|
69 |
-
return all((is_finite,is_numerical, has_data, is_binary))
|
70 |
|
71 |
-
def _waveform_from_index(self, idx:int) -> torch.Tensor:
|
72 |
audio_filepath = self.audio_paths[self._convert_idx(idx)]
|
73 |
num_windows = self.audio_duration // self.audio_window_duration
|
74 |
frame_index = idx % num_windows
|
75 |
jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
|
76 |
jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
|
77 |
-
jitter = int(
|
78 |
-
|
|
|
|
|
|
|
|
|
79 |
num_frames = self.sample_rate * self.audio_window_duration
|
80 |
-
waveform, sample_rate = ta.load(
|
81 |
-
|
|
|
|
|
|
|
|
|
82 |
return waveform
|
83 |
|
84 |
-
|
85 |
-
def _label_from_index(self, idx:int) -> torch.Tensor:
|
86 |
return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])
|
87 |
-
|
88 |
|
89 |
class WaveformSongDataset(SongDataset):
|
90 |
"""
|
91 |
Outputs raw waveforms of the data instead of a spectrogram.
|
92 |
"""
|
93 |
|
94 |
-
def __init__(self, *args,resample_frequency=16000, **kwargs):
|
95 |
super().__init__(*args, **kwargs)
|
96 |
self.resample_frequency = resample_frequency
|
97 |
self.resampler = taT.Resample(self.sample_rate, self.resample_frequency)
|
98 |
self.pipeline = []
|
99 |
|
100 |
-
def __getitem__(self, idx:int) -> dict[str, torch.Tensor]:
|
101 |
waveform = self._waveform_from_index(idx)
|
102 |
-
assert
|
|
|
|
|
103 |
# resample the waveform
|
104 |
waveform = self.resampler(waveform)
|
105 |
-
|
106 |
waveform = waveform.mean(0)
|
107 |
|
108 |
dance_labels = self._label_from_index(idx)
|
109 |
return waveform, dance_labels
|
110 |
-
|
111 |
-
|
112 |
|
113 |
|
114 |
class HuggingFaceWaveformSongDataset(WaveformSongDataset):
|
115 |
-
|
116 |
def __init__(self, *args, **kwargs):
|
117 |
super().__init__(*args, **kwargs)
|
118 |
self.pipeline = []
|
119 |
|
120 |
-
|
121 |
-
|
122 |
-
x,y = super().__getitem__(idx)
|
123 |
if len(self.pipeline) > 0:
|
124 |
for fn in self.pipeline:
|
125 |
x = fn(x)
|
126 |
|
127 |
dance_labels = y.argmax()
|
128 |
-
return {
|
|
|
|
|
|
|
129 |
|
130 |
-
def map(self,fn):
|
131 |
"""
|
132 |
NOTE this mutates the original, doesn't return a copy like normal maps.
|
133 |
"""
|
134 |
self.pipeline.append(fn)
|
135 |
|
|
|
136 |
class DanceDataModule(pl.LightningDataModule):
|
137 |
-
def __init__(
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
148 |
):
|
149 |
super().__init__()
|
150 |
self.song_data_path = song_data_path
|
151 |
self.song_audio_path = song_audio_path
|
152 |
-
self.val_proportion=val_proportion
|
153 |
-
self.test_proportion=test_proportion
|
154 |
-
self.train_proportion= 1
|
155 |
-
self.target_classes=target_classes
|
156 |
self.batch_size = batch_size
|
157 |
self.num_workers = num_workers
|
158 |
self.dataset_kwargs = dataset_kwargs
|
159 |
self.dataset_cls = dataset_cls if dataset_cls is not None else SongDataset
|
160 |
|
161 |
df = pd.read_csv(song_data_path)
|
162 |
-
self.x,self.y = get_examples(
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
def setup(self, stage: str):
|
165 |
-
train_i, val_i, test_i = random_split(
|
|
|
|
|
|
|
166 |
self.train_ds = self._dataset_from_indices(train_i)
|
167 |
self.val_ds = self._dataset_from_indices(val_i)
|
168 |
self.test_ds = self._dataset_from_indices(test_i)
|
169 |
-
|
170 |
-
def _dataset_from_indices(self, idx:list[int]) -> SongDataset:
|
171 |
return self.dataset_cls(self.x[idx], self.y[idx], **self.dataset_kwargs)
|
172 |
-
|
173 |
def train_dataloader(self):
|
174 |
-
return DataLoader(
|
|
|
|
|
|
|
|
|
|
|
175 |
|
176 |
def val_dataloader(self):
|
177 |
-
return DataLoader(
|
|
|
|
|
178 |
|
179 |
def test_dataloader(self):
|
180 |
-
return DataLoader(
|
|
|
|
|
181 |
|
182 |
def get_label_weights(self):
|
183 |
n_examples, n_classes = self.y.shape
|
184 |
return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
|
185 |
-
|
186 |
|
187 |
-
class WaveformTrainingEnvironment(pl.LightningModule):
|
188 |
|
189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
super().__init__(*args, **kwargs)
|
191 |
self.model = model
|
192 |
self.criterion = criterion
|
193 |
self.learning_rate = learning_rate
|
194 |
-
self.config=config
|
195 |
-
self.feature_extractor=feature_extractor
|
196 |
-
self.save_hyperparameters(
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
|
|
|
|
202 |
|
203 |
def preprocess_inputs(self, x):
|
204 |
device = x.device
|
205 |
-
x = x.squeeze(1).cpu().numpy()
|
206 |
-
x = self.feature_extractor(
|
207 |
return x["input_values"].to(device)
|
208 |
-
|
209 |
-
def training_step(
|
|
|
|
|
210 |
features, labels = batch
|
211 |
features = self.preprocess_inputs(features)
|
212 |
outputs = self.model(features).logits
|
213 |
-
outputs = nn.Sigmoid()(
|
|
|
|
|
214 |
loss = self.criterion(outputs, labels)
|
215 |
metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
|
216 |
self.log_dict(metrics, prog_bar=True)
|
217 |
return loss
|
218 |
|
219 |
-
|
220 |
-
|
221 |
-
|
|
|
222 |
x = self.preprocess_inputs(x)
|
223 |
preds = self.model(x).logits
|
224 |
-
preds = nn.Sigmoid()(preds)
|
225 |
metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
|
226 |
metrics["val/loss"] = self.criterion(preds, y)
|
227 |
-
self.log_dict(metrics,prog_bar=True)
|
228 |
|
229 |
-
def test_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
|
230 |
x, y = batch
|
231 |
x = self.preprocess_inputs(x)
|
232 |
preds = self.model(x).logits
|
233 |
-
preds = nn.Sigmoid()(preds)
|
234 |
-
self.log_dict(
|
235 |
-
|
|
|
|
|
236 |
def configure_optimizers(self):
|
237 |
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
238 |
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
|
239 |
-
return [optimizer]
|
240 |
-
|
241 |
|
242 |
|
243 |
-
def calculate_metrics(
|
|
|
|
|
244 |
target = target.detach().cpu().numpy()
|
245 |
pred = pred.detach().cpu().numpy()
|
246 |
params = {
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
|
|
|
|
|
|
|
|
|
12 |
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
|
13 |
|
14 |
|
|
|
15 |
class SongDataset(Dataset):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
audio_paths: list[str],
|
19 |
+
dance_labels: list[np.ndarray],
|
20 |
+
audio_duration=30, # seconds
|
21 |
+
audio_window_duration=6, # seconds
|
22 |
+
audio_window_jitter=0.0, # seconds
|
23 |
+
audio_pipeline_kwargs={},
|
24 |
+
resample_frequency=16000,
|
25 |
):
|
26 |
+
assert (
|
27 |
+
audio_duration % audio_window_duration == 0
|
28 |
+
), "Audio window should divide duration evenly."
|
29 |
+
assert (
|
30 |
+
audio_window_duration > audio_window_jitter
|
31 |
+
), "Jitter should be a small fraction of the audio window duration."
|
32 |
|
33 |
self.audio_paths = audio_paths
|
34 |
self.dance_labels = dance_labels
|
|
|
38 |
self.audio_window_jitter = audio_window_jitter
|
39 |
self.audio_duration = int(audio_duration)
|
40 |
|
41 |
+
self.audio_pipeline = AudioTrainingPipeline(
|
42 |
+
self.sample_rate,
|
43 |
+
resample_frequency,
|
44 |
+
audio_window_duration,
|
45 |
+
**audio_pipeline_kwargs,
|
46 |
+
)
|
47 |
|
48 |
def __len__(self):
|
49 |
+
return len(self.audio_paths) * self.audio_duration // self.audio_window_duration
|
50 |
|
51 |
+
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
52 |
waveform = self._waveform_from_index(idx)
|
53 |
+
assert (
|
54 |
+
waveform.shape[1] > 10
|
55 |
+
), f"No data found: {self._backtrace_audio_path(idx)}"
|
56 |
spectrogram = self.audio_pipeline(waveform)
|
57 |
|
58 |
dance_labels = self._label_from_index(idx)
|
|
|
64 |
# Try the previous one
|
65 |
# This happens when some of the audio recordings are really quiet
|
66 |
# This WILL NOT leak into other data partitions because songs belong entirely to a partition
|
67 |
+
return self[idx - 1]
|
68 |
|
69 |
+
def _convert_idx(self, idx: int) -> int:
|
70 |
return idx * self.audio_window_duration // self.audio_duration
|
71 |
|
72 |
+
def _backtrace_audio_path(self, index: int) -> str:
|
73 |
return self.audio_paths[self._convert_idx(index)]
|
74 |
|
75 |
+
def _validate_output(self, x, y):
|
76 |
+
is_finite = not torch.any(torch.isinf(x))
|
77 |
is_numerical = not torch.any(torch.isnan(x))
|
78 |
has_data = torch.any(x != 0.0)
|
79 |
is_binary = len(torch.unique(y)) < 3
|
80 |
+
return all((is_finite, is_numerical, has_data, is_binary))
|
81 |
|
82 |
+
def _waveform_from_index(self, idx: int) -> torch.Tensor:
|
83 |
audio_filepath = self.audio_paths[self._convert_idx(idx)]
|
84 |
num_windows = self.audio_duration // self.audio_window_duration
|
85 |
frame_index = idx % num_windows
|
86 |
jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
|
87 |
jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
|
88 |
+
jitter = int(
|
89 |
+
torch.FloatTensor(1).uniform_(jitter_start, jitter_end) * self.sample_rate
|
90 |
+
)
|
91 |
+
frame_offset = (
|
92 |
+
frame_index * self.audio_window_duration * self.sample_rate + jitter
|
93 |
+
)
|
94 |
num_frames = self.sample_rate * self.audio_window_duration
|
95 |
+
waveform, sample_rate = ta.load(
|
96 |
+
audio_filepath, frame_offset=frame_offset, num_frames=num_frames
|
97 |
+
)
|
98 |
+
assert (
|
99 |
+
sample_rate == self.sample_rate
|
100 |
+
), f"Expected sample rate of {self.sample_rate}. Found {sample_rate}"
|
101 |
return waveform
|
102 |
|
103 |
+
def _label_from_index(self, idx: int) -> torch.Tensor:
|
|
|
104 |
return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])
|
105 |
+
|
106 |
|
107 |
class WaveformSongDataset(SongDataset):
|
108 |
"""
|
109 |
Outputs raw waveforms of the data instead of a spectrogram.
|
110 |
"""
|
111 |
|
112 |
+
def __init__(self, *args, resample_frequency=16000, **kwargs):
|
113 |
super().__init__(*args, **kwargs)
|
114 |
self.resample_frequency = resample_frequency
|
115 |
self.resampler = taT.Resample(self.sample_rate, self.resample_frequency)
|
116 |
self.pipeline = []
|
117 |
|
118 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
119 |
waveform = self._waveform_from_index(idx)
|
120 |
+
assert (
|
121 |
+
waveform.shape[1] > 10
|
122 |
+
), f"No data found: {self._backtrace_audio_path(idx)}"
|
123 |
# resample the waveform
|
124 |
waveform = self.resampler(waveform)
|
125 |
+
|
126 |
waveform = waveform.mean(0)
|
127 |
|
128 |
dance_labels = self._label_from_index(idx)
|
129 |
return waveform, dance_labels
|
|
|
|
|
130 |
|
131 |
|
132 |
class HuggingFaceWaveformSongDataset(WaveformSongDataset):
|
|
|
133 |
def __init__(self, *args, **kwargs):
|
134 |
super().__init__(*args, **kwargs)
|
135 |
self.pipeline = []
|
136 |
|
137 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
138 |
+
x, y = super().__getitem__(idx)
|
|
|
139 |
if len(self.pipeline) > 0:
|
140 |
for fn in self.pipeline:
|
141 |
x = fn(x)
|
142 |
|
143 |
dance_labels = y.argmax()
|
144 |
+
return {
|
145 |
+
"input_values": x["input_values"][0] if hasattr(x, "input_values") else x,
|
146 |
+
"label": dance_labels,
|
147 |
+
}
|
148 |
|
149 |
+
def map(self, fn):
|
150 |
"""
|
151 |
NOTE this mutates the original, doesn't return a copy like normal maps.
|
152 |
"""
|
153 |
self.pipeline.append(fn)
|
154 |
|
155 |
+
|
156 |
class DanceDataModule(pl.LightningDataModule):
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
song_data_path="data/songs_cleaned.csv",
|
160 |
+
song_audio_path="data/samples",
|
161 |
+
test_proportion=0.15,
|
162 |
+
val_proportion=0.1,
|
163 |
+
target_classes: list[str] = None,
|
164 |
+
min_votes=1,
|
165 |
+
batch_size: int = 64,
|
166 |
+
num_workers=10,
|
167 |
+
dataset_cls=None,
|
168 |
+
dataset_kwargs={},
|
169 |
):
|
170 |
super().__init__()
|
171 |
self.song_data_path = song_data_path
|
172 |
self.song_audio_path = song_audio_path
|
173 |
+
self.val_proportion = val_proportion
|
174 |
+
self.test_proportion = test_proportion
|
175 |
+
self.train_proportion = 1.0 - test_proportion - val_proportion
|
176 |
+
self.target_classes = target_classes
|
177 |
self.batch_size = batch_size
|
178 |
self.num_workers = num_workers
|
179 |
self.dataset_kwargs = dataset_kwargs
|
180 |
self.dataset_cls = dataset_cls if dataset_cls is not None else SongDataset
|
181 |
|
182 |
df = pd.read_csv(song_data_path)
|
183 |
+
self.x, self.y = get_examples(
|
184 |
+
df,
|
185 |
+
self.song_audio_path,
|
186 |
+
class_list=self.target_classes,
|
187 |
+
multi_label=True,
|
188 |
+
min_votes=min_votes,
|
189 |
+
)
|
190 |
|
191 |
def setup(self, stage: str):
|
192 |
+
train_i, val_i, test_i = random_split(
|
193 |
+
np.arange(len(self.x)),
|
194 |
+
[self.train_proportion, self.val_proportion, self.test_proportion],
|
195 |
+
)
|
196 |
self.train_ds = self._dataset_from_indices(train_i)
|
197 |
self.val_ds = self._dataset_from_indices(val_i)
|
198 |
self.test_ds = self._dataset_from_indices(test_i)
|
199 |
+
|
200 |
+
def _dataset_from_indices(self, idx: list[int]) -> SongDataset:
|
201 |
return self.dataset_cls(self.x[idx], self.y[idx], **self.dataset_kwargs)
|
202 |
+
|
203 |
def train_dataloader(self):
|
204 |
+
return DataLoader(
|
205 |
+
self.train_ds,
|
206 |
+
batch_size=self.batch_size,
|
207 |
+
num_workers=self.num_workers,
|
208 |
+
shuffle=True,
|
209 |
+
)
|
210 |
|
211 |
def val_dataloader(self):
|
212 |
+
return DataLoader(
|
213 |
+
self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers
|
214 |
+
)
|
215 |
|
216 |
def test_dataloader(self):
|
217 |
+
return DataLoader(
|
218 |
+
self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers
|
219 |
+
)
|
220 |
|
221 |
def get_label_weights(self):
|
222 |
n_examples, n_classes = self.y.shape
|
223 |
return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
|
|
|
224 |
|
|
|
225 |
|
226 |
+
class WaveformTrainingEnvironment(pl.LightningModule):
|
227 |
+
def __init__(
|
228 |
+
self,
|
229 |
+
model: nn.Module,
|
230 |
+
criterion: nn.Module,
|
231 |
+
feature_extractor,
|
232 |
+
config: dict,
|
233 |
+
learning_rate=1e-4,
|
234 |
+
*args,
|
235 |
+
**kwargs,
|
236 |
+
):
|
237 |
super().__init__(*args, **kwargs)
|
238 |
self.model = model
|
239 |
self.criterion = criterion
|
240 |
self.learning_rate = learning_rate
|
241 |
+
self.config = config
|
242 |
+
self.feature_extractor = feature_extractor
|
243 |
+
self.save_hyperparameters(
|
244 |
+
{
|
245 |
+
"model": type(model).__name__,
|
246 |
+
"loss": type(criterion).__name__,
|
247 |
+
"config": config,
|
248 |
+
**kwargs,
|
249 |
+
}
|
250 |
+
)
|
251 |
|
252 |
def preprocess_inputs(self, x):
|
253 |
device = x.device
|
254 |
+
x = list(x.squeeze(1).cpu().numpy())
|
255 |
+
x = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000)
|
256 |
return x["input_values"].to(device)
|
257 |
+
|
258 |
+
def training_step(
|
259 |
+
self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
|
260 |
+
) -> torch.Tensor:
|
261 |
features, labels = batch
|
262 |
features = self.preprocess_inputs(features)
|
263 |
outputs = self.model(features).logits
|
264 |
+
outputs = nn.Sigmoid()(
|
265 |
+
outputs
|
266 |
+
) # good for multi label classification, should be softmax otherwise
|
267 |
loss = self.criterion(outputs, labels)
|
268 |
metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
|
269 |
self.log_dict(metrics, prog_bar=True)
|
270 |
return loss
|
271 |
|
272 |
+
def validation_step(
|
273 |
+
self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
|
274 |
+
):
|
275 |
+
x, y = batch
|
276 |
x = self.preprocess_inputs(x)
|
277 |
preds = self.model(x).logits
|
278 |
+
preds = nn.Sigmoid()(preds)
|
279 |
metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
|
280 |
metrics["val/loss"] = self.criterion(preds, y)
|
281 |
+
self.log_dict(metrics, prog_bar=True)
|
282 |
|
283 |
+
def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
|
284 |
x, y = batch
|
285 |
x = self.preprocess_inputs(x)
|
286 |
preds = self.model(x).logits
|
287 |
+
preds = nn.Sigmoid()(preds)
|
288 |
+
self.log_dict(
|
289 |
+
calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True
|
290 |
+
)
|
291 |
+
|
292 |
def configure_optimizers(self):
|
293 |
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
294 |
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
|
295 |
+
return [optimizer]
|
|
|
296 |
|
297 |
|
298 |
+
def calculate_metrics(
|
299 |
+
pred, target, threshold=0.5, prefix="", multi_label=True
|
300 |
+
) -> dict[str, torch.Tensor]:
|
301 |
target = target.detach().cpu().numpy()
|
302 |
pred = pred.detach().cpu().numpy()
|
303 |
params = {
|
304 |
+
"y_true": target if multi_label else target.argmax(1),
|
305 |
+
"y_pred": np.array(pred > threshold, dtype=float)
|
306 |
+
if multi_label
|
307 |
+
else pred.argmax(1),
|
308 |
+
"zero_division": 0,
|
309 |
+
"average": "macro",
|
310 |
+
}
|
311 |
+
metrics = {
|
312 |
+
"precision": precision_score(**params),
|
313 |
+
"recall": recall_score(**params),
|
314 |
+
"f1": f1_score(**params),
|
315 |
+
"accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
|
316 |
+
}
|
317 |
+
return {
|
318 |
+
prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
|
319 |
+
}
|
train.py
CHANGED
@@ -7,25 +7,32 @@ from sklearn.model_selection import KFold
|
|
7 |
import pytorch_lightning as pl
|
8 |
from pytorch_lightning import callbacks as cb
|
9 |
from models.utils import LabelWeightedBCELoss
|
10 |
-
from models.audio_spectrogram_transformer import
|
|
|
|
|
|
|
11 |
from preprocessing.dataset import SongDataset, WaveformTrainingEnvironment
|
12 |
from preprocessing.preprocess import get_examples
|
13 |
from models.residual import ResidualDancer, TrainingEnvironment
|
14 |
from models.decision_tree import DanceTreeClassifier, features_from_path
|
15 |
import yaml
|
16 |
-
from preprocessing.dataset import
|
|
|
|
|
|
|
|
|
17 |
from torch.utils.data import random_split
|
18 |
import numpy as np
|
19 |
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
|
20 |
from argparse import ArgumentParser
|
21 |
|
22 |
|
23 |
-
|
24 |
import torch
|
25 |
from torch import nn
|
26 |
from sklearn.utils.class_weight import compute_class_weight
|
27 |
|
28 |
-
|
|
|
29 |
match id:
|
30 |
case "ast_ptl":
|
31 |
return train_ast_lightning
|
@@ -38,7 +45,8 @@ def get_training_fn(id:str) -> Callable:
|
|
38 |
case _:
|
39 |
raise Exception(f"Couldn't find a training function for '{id}'.")
|
40 |
|
41 |
-
|
|
|
42 |
with open(filepath, "r") as f:
|
43 |
config = yaml.safe_load(f)
|
44 |
return config
|
@@ -48,14 +56,14 @@ def cross_validation(config, k=5):
|
|
48 |
df = pd.read_csv("data/songs.csv")
|
49 |
g_config = config["global"]
|
50 |
batch_size = config["data_module"]["batch_size"]
|
51 |
-
x,y = get_examples(df, "data/samples",class_list=g_config["dance_ids"])
|
52 |
-
dataset = SongDataset(x,y)
|
53 |
-
splits=KFold(n_splits=k,shuffle=True,random_state=g_config["seed"])
|
54 |
trainer = pl.Trainer(accelerator=g_config["device"])
|
55 |
-
for fold, (train_idx,val_idx) in enumerate(splits.split(x,y)):
|
56 |
print(f"Fold {fold+1}")
|
57 |
model = ResidualDancer(n_classes=len(g_config["dance_ids"]))
|
58 |
-
train_env = TrainingEnvironment(model,nn.BCELoss())
|
59 |
train_sampler = SubsetRandomSampler(train_idx)
|
60 |
test_sampler = SubsetRandomSampler(val_idx)
|
61 |
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
|
@@ -64,15 +72,17 @@ def cross_validation(config, k=5):
|
|
64 |
trainer.test(train_env, test_loader)
|
65 |
|
66 |
|
67 |
-
def train_model(config:dict):
|
68 |
TARGET_CLASSES = config["global"]["dance_ids"]
|
69 |
DEVICE = config["global"]["device"]
|
70 |
SEED = config["global"]["seed"]
|
71 |
pl.seed_everything(SEED, workers=True)
|
72 |
-
data = DanceDataModule(target_classes=TARGET_CLASSES, **config[
|
73 |
-
model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config[
|
74 |
label_weights = data.get_label_weights().to(DEVICE)
|
75 |
-
criterion = LabelWeightedBCELoss(
|
|
|
|
|
76 |
train_env = TrainingEnvironment(model, criterion, config)
|
77 |
callbacks = [
|
78 |
# cb.LearningRateFinder(update_attr=True),
|
@@ -81,36 +91,41 @@ def train_model(config:dict):
|
|
81 |
cb.RichProgressBar(),
|
82 |
cb.DeviceStatsMonitor(),
|
83 |
]
|
84 |
-
trainer = pl.Trainer(
|
85 |
-
callbacks=callbacks,
|
86 |
-
**config["trainer"]
|
87 |
-
)
|
88 |
trainer.fit(train_env, datamodule=data)
|
89 |
trainer.test(train_env, datamodule=data)
|
90 |
|
91 |
|
92 |
-
def train_ast(
|
93 |
-
config:dict
|
94 |
-
):
|
95 |
TARGET_CLASSES = config["global"]["dance_ids"]
|
96 |
DEVICE = config["global"]["device"]
|
97 |
SEED = config["global"]["seed"]
|
98 |
dataset_kwargs = config["data_module"]["dataset_kwargs"]
|
99 |
test_proportion = config["data_module"].get("test_proportion", 0.2)
|
100 |
-
train_proportion = 1. - test_proportion
|
101 |
-
song_data_path="data/songs_cleaned.csv"
|
102 |
song_audio_path = "data/samples"
|
103 |
pl.seed_everything(SEED, workers=True)
|
104 |
|
105 |
df = pd.read_csv(song_data_path)
|
106 |
-
x, y = get_examples(
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
"""
|
115 |
work on integration between waveform dataset and environment. Should work for both HF and PTL.
|
116 |
"""
|
@@ -118,45 +133,50 @@ def train_ast_lightning(config:dict):
|
|
118 |
DEVICE = config["global"]["device"]
|
119 |
SEED = config["global"]["seed"]
|
120 |
pl.seed_everything(SEED, workers=True)
|
121 |
-
data = DanceDataModule(
|
|
|
|
|
|
|
|
|
122 |
id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
|
123 |
model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
|
124 |
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
|
125 |
|
126 |
model = AutoModelForAudioClassification.from_pretrained(
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
).to(DEVICE)
|
133 |
label_weights = data.get_label_weights().to(DEVICE)
|
134 |
-
criterion = LabelWeightedBCELoss(
|
135 |
-
|
|
|
|
|
136 |
callbacks = [
|
137 |
# cb.LearningRateFinder(update_attr=True),
|
138 |
cb.EarlyStopping("val/loss", patience=5),
|
139 |
cb.StochasticWeightAveraging(1e-2),
|
140 |
-
cb.RichProgressBar()
|
141 |
]
|
142 |
-
trainer = pl.Trainer(
|
143 |
-
callbacks=callbacks,
|
144 |
-
**config["trainer"]
|
145 |
-
)
|
146 |
trainer.fit(train_env, datamodule=data)
|
147 |
trainer.test(train_env, datamodule=data)
|
148 |
|
149 |
|
150 |
-
def train_decision_tree(config:dict):
|
151 |
TARGET_CLASSES = config["global"]["dance_ids"]
|
152 |
DEVICE = config["global"]["device"]
|
153 |
SEED = config["global"]["seed"]
|
154 |
-
song_data_path=config[
|
155 |
-
song_audio_path = config[
|
156 |
pl.seed_everything(SEED, workers=True)
|
157 |
|
158 |
df = pd.read_csv(song_data_path)
|
159 |
-
x, y = get_examples(
|
|
|
|
|
160 |
# Convert y back to string classes
|
161 |
y = np.array(TARGET_CLASSES)[y.argmax(-1)]
|
162 |
train_i, test_i = random_split(np.arange(len(x)), [0.8, 0.2])
|
@@ -166,13 +186,18 @@ def train_decision_tree(config:dict):
|
|
166 |
model.fit(train_x, train_y)
|
167 |
model.save()
|
168 |
|
|
|
169 |
if __name__ == "__main__":
|
170 |
-
parser = ArgumentParser(
|
171 |
-
|
172 |
-
|
173 |
-
|
|
|
|
|
|
|
|
|
174 |
args = parser.parse_args()
|
175 |
config = get_config(args.config)
|
176 |
training_id = config["global"]["id"]
|
177 |
train = get_training_fn(training_id)
|
178 |
-
train(config)
|
|
|
7 |
import pytorch_lightning as pl
|
8 |
from pytorch_lightning import callbacks as cb
|
9 |
from models.utils import LabelWeightedBCELoss
|
10 |
+
from models.audio_spectrogram_transformer import (
|
11 |
+
train as train_audio_spectrogram_transformer,
|
12 |
+
get_id_label_mapping,
|
13 |
+
)
|
14 |
from preprocessing.dataset import SongDataset, WaveformTrainingEnvironment
|
15 |
from preprocessing.preprocess import get_examples
|
16 |
from models.residual import ResidualDancer, TrainingEnvironment
|
17 |
from models.decision_tree import DanceTreeClassifier, features_from_path
|
18 |
import yaml
|
19 |
+
from preprocessing.dataset import (
|
20 |
+
DanceDataModule,
|
21 |
+
WaveformSongDataset,
|
22 |
+
HuggingFaceWaveformSongDataset,
|
23 |
+
)
|
24 |
from torch.utils.data import random_split
|
25 |
import numpy as np
|
26 |
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
|
27 |
from argparse import ArgumentParser
|
28 |
|
29 |
|
|
|
30 |
import torch
|
31 |
from torch import nn
|
32 |
from sklearn.utils.class_weight import compute_class_weight
|
33 |
|
34 |
+
|
35 |
+
def get_training_fn(id: str) -> Callable:
|
36 |
match id:
|
37 |
case "ast_ptl":
|
38 |
return train_ast_lightning
|
|
|
45 |
case _:
|
46 |
raise Exception(f"Couldn't find a training function for '{id}'.")
|
47 |
|
48 |
+
|
49 |
+
def get_config(filepath: str) -> dict:
|
50 |
with open(filepath, "r") as f:
|
51 |
config = yaml.safe_load(f)
|
52 |
return config
|
|
|
56 |
df = pd.read_csv("data/songs.csv")
|
57 |
g_config = config["global"]
|
58 |
batch_size = config["data_module"]["batch_size"]
|
59 |
+
x, y = get_examples(df, "data/samples", class_list=g_config["dance_ids"])
|
60 |
+
dataset = SongDataset(x, y)
|
61 |
+
splits = KFold(n_splits=k, shuffle=True, random_state=g_config["seed"])
|
62 |
trainer = pl.Trainer(accelerator=g_config["device"])
|
63 |
+
for fold, (train_idx, val_idx) in enumerate(splits.split(x, y)):
|
64 |
print(f"Fold {fold+1}")
|
65 |
model = ResidualDancer(n_classes=len(g_config["dance_ids"]))
|
66 |
+
train_env = TrainingEnvironment(model, nn.BCELoss())
|
67 |
train_sampler = SubsetRandomSampler(train_idx)
|
68 |
test_sampler = SubsetRandomSampler(val_idx)
|
69 |
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
|
|
|
72 |
trainer.test(train_env, test_loader)
|
73 |
|
74 |
|
75 |
+
def train_model(config: dict):
|
76 |
TARGET_CLASSES = config["global"]["dance_ids"]
|
77 |
DEVICE = config["global"]["device"]
|
78 |
SEED = config["global"]["seed"]
|
79 |
pl.seed_everything(SEED, workers=True)
|
80 |
+
data = DanceDataModule(target_classes=TARGET_CLASSES, **config["data_module"])
|
81 |
+
model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
|
82 |
label_weights = data.get_label_weights().to(DEVICE)
|
83 |
+
criterion = LabelWeightedBCELoss(
|
84 |
+
label_weights
|
85 |
+
) # nn.CrossEntropyLoss(label_weights)
|
86 |
train_env = TrainingEnvironment(model, criterion, config)
|
87 |
callbacks = [
|
88 |
# cb.LearningRateFinder(update_attr=True),
|
|
|
91 |
cb.RichProgressBar(),
|
92 |
cb.DeviceStatsMonitor(),
|
93 |
]
|
94 |
+
trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
|
|
|
|
|
|
|
95 |
trainer.fit(train_env, datamodule=data)
|
96 |
trainer.test(train_env, datamodule=data)
|
97 |
|
98 |
|
99 |
+
def train_ast(config: dict):
|
|
|
|
|
100 |
TARGET_CLASSES = config["global"]["dance_ids"]
|
101 |
DEVICE = config["global"]["device"]
|
102 |
SEED = config["global"]["seed"]
|
103 |
dataset_kwargs = config["data_module"]["dataset_kwargs"]
|
104 |
test_proportion = config["data_module"].get("test_proportion", 0.2)
|
105 |
+
train_proportion = 1.0 - test_proportion
|
106 |
+
song_data_path = "data/songs_cleaned.csv"
|
107 |
song_audio_path = "data/samples"
|
108 |
pl.seed_everything(SEED, workers=True)
|
109 |
|
110 |
df = pd.read_csv(song_data_path)
|
111 |
+
x, y = get_examples(
|
112 |
+
df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
|
113 |
+
)
|
114 |
+
train_i, test_i = random_split(
|
115 |
+
np.arange(len(x)), [train_proportion, test_proportion]
|
116 |
+
)
|
117 |
+
train_ds = HuggingFaceWaveformSongDataset(
|
118 |
+
x[train_i], y[train_i], **dataset_kwargs, resample_frequency=16000
|
119 |
+
)
|
120 |
+
test_ds = HuggingFaceWaveformSongDataset(
|
121 |
+
x[test_i], y[test_i], **dataset_kwargs, resample_frequency=16000
|
122 |
+
)
|
123 |
+
train_audio_spectrogram_transformer(
|
124 |
+
TARGET_CLASSES, train_ds, test_ds, device=DEVICE
|
125 |
+
)
|
126 |
+
|
127 |
+
|
128 |
+
def train_ast_lightning(config: dict):
|
129 |
"""
|
130 |
work on integration between waveform dataset and environment. Should work for both HF and PTL.
|
131 |
"""
|
|
|
133 |
DEVICE = config["global"]["device"]
|
134 |
SEED = config["global"]["seed"]
|
135 |
pl.seed_everything(SEED, workers=True)
|
136 |
+
data = DanceDataModule(
|
137 |
+
target_classes=TARGET_CLASSES,
|
138 |
+
dataset_cls=WaveformSongDataset,
|
139 |
+
**config["data_module"],
|
140 |
+
)
|
141 |
id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
|
142 |
model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
|
143 |
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
|
144 |
|
145 |
model = AutoModelForAudioClassification.from_pretrained(
|
146 |
+
model_checkpoint,
|
147 |
+
num_labels=len(label2id),
|
148 |
+
label2id=label2id,
|
149 |
+
id2label=id2label,
|
150 |
+
ignore_mismatched_sizes=True,
|
151 |
+
).to(DEVICE)
|
152 |
label_weights = data.get_label_weights().to(DEVICE)
|
153 |
+
criterion = LabelWeightedBCELoss(
|
154 |
+
label_weights
|
155 |
+
) # nn.CrossEntropyLoss(label_weights)
|
156 |
+
train_env = WaveformTrainingEnvironment(model, criterion, feature_extractor, config)
|
157 |
callbacks = [
|
158 |
# cb.LearningRateFinder(update_attr=True),
|
159 |
cb.EarlyStopping("val/loss", patience=5),
|
160 |
cb.StochasticWeightAveraging(1e-2),
|
161 |
+
cb.RichProgressBar(),
|
162 |
]
|
163 |
+
trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
|
|
|
|
|
|
|
164 |
trainer.fit(train_env, datamodule=data)
|
165 |
trainer.test(train_env, datamodule=data)
|
166 |
|
167 |
|
168 |
+
def train_decision_tree(config: dict):
|
169 |
TARGET_CLASSES = config["global"]["dance_ids"]
|
170 |
DEVICE = config["global"]["device"]
|
171 |
SEED = config["global"]["seed"]
|
172 |
+
song_data_path = config["data_module"]["song_data_path"]
|
173 |
+
song_audio_path = config["data_module"]["song_audio_path"]
|
174 |
pl.seed_everything(SEED, workers=True)
|
175 |
|
176 |
df = pd.read_csv(song_data_path)
|
177 |
+
x, y = get_examples(
|
178 |
+
df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
|
179 |
+
)
|
180 |
# Convert y back to string classes
|
181 |
y = np.array(TARGET_CLASSES)[y.argmax(-1)]
|
182 |
train_i, test_i = random_split(np.arange(len(x)), [0.8, 0.2])
|
|
|
186 |
model.fit(train_x, train_y)
|
187 |
model.save()
|
188 |
|
189 |
+
|
190 |
if __name__ == "__main__":
|
191 |
+
parser = ArgumentParser(
|
192 |
+
description="Trains models on the dance dataset and saves weights."
|
193 |
+
)
|
194 |
+
parser.add_argument(
|
195 |
+
"--config",
|
196 |
+
help="Path to the yaml file that defines the training configuration.",
|
197 |
+
default="models/config/train_local.yaml",
|
198 |
+
)
|
199 |
args = parser.parse_args()
|
200 |
config = get_config(args.config)
|
201 |
training_id = config["global"]["id"]
|
202 |
train = get_training_fn(training_id)
|
203 |
+
train(config)
|