Spaces:
Runtime error
Runtime error
waidhoferj
committed on
Commit
·
e6fd727
1
Parent(s):
0030bc6
added AST model
Browse files
- TODO.md +2 -1
- models/audio_spectrogram_transformer.py +72 -0
- models/config/train.yaml +7 -5
- models/residual.py +1 -0
- preprocessing/dataset.py +129 -2
- preprocessing/preprocess.py +7 -8
- train.py +91 -5
TODO.md
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
- ✅ Ensure app.py audio input sounds like training data
|
|
|
2 |
- Verify that the training spectrogram matches the predict spectrogram
|
3 |
- Count number of example misses in dataset loading
|
4 |
- Verify windowing and jitter params in Song Dataset
|
@@ -7,4 +8,4 @@
|
|
7 |
- Verify that labels really match what is on the music4dance site
|
8 |
- Read the Medium series about audio DL
|
9 |
- double check \_rectify_duration
|
10 |
-
- Filter out songs that have only one vote
|
|
|
1 |
- ✅ Ensure app.py audio input sounds like training data
|
2 |
+
- ✅ Use a huggingface transformer with the dataset
|
3 |
- Verify that the training spectrogram matches the predict spectrogram
|
4 |
- Count number of example misses in dataset loading
|
5 |
- Verify windowing and jitter params in Song Dataset
|
|
|
8 |
- Verify that labels really match what is on the music4dance site
|
9 |
- Read the Medium series about audio DL
|
10 |
- double check \_rectify_duration
|
11 |
+
- ✅ Filter out songs that have only one vote
|
models/audio_spectrogram_transformer.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import ASTFeatureExtractor, AutoFeatureExtractor, ASTConfig, AutoModelForAudioClassification, TrainingArguments, Trainer
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from sklearn.utils.class_weight import compute_class_weight
|
5 |
+
import evaluate
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
accuracy = evaluate.load("accuracy")
|
9 |
+
|
10 |
+
def compute_metrics(eval_pred):
    """Compute accuracy for one evaluation pass.

    Args:
        eval_pred: Object exposing ``predictions`` (per-class scores; the
            arg-max is taken along axis 1) and ``label_ids`` (the reference
            class ids).

    Returns:
        Whatever the module-level ``accuracy`` metric returns from
        ``compute`` for the arg-max predictions against ``label_ids``.
    """
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
|
13 |
+
|
14 |
+
def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
    """Build the two lookup tables that relate class ids and label names.

    Ids are the positions of the labels in ``labels``, stored as strings.

    Args:
        labels: Ordered class names.

    Returns:
        ``(id2label, label2id)`` where ``id2label`` maps ``"0", "1", ...`` to
        the label name and ``label2id`` maps each label name back to its
        string id.
    """
    id2label = {}
    label2id = {}
    # One pass keeps both tables in sync by construction.
    for index, label in enumerate(labels):
        key = str(index)
        id2label[key] = label
        label2id[label] = key

    return id2label, label2id
|
19 |
+
|
20 |
+
def train(
    labels,
    train_ds,
    test_ds,
    output_dir="models/weights/ast",
    device="cpu",
    batch_size=128,
    epochs=10):
    """Fine-tune the pretrained AST checkpoint with the Hugging Face ``Trainer``.

    Args:
        labels: Ordered class names; their count sets the size of the
            classification head.
        train_ds: Training dataset. It must expose ``resample_frequency`` and
            a ``map(fn)`` method that registers ``fn`` as a preprocessing step.
        test_ds: Evaluation dataset with the same ``map(fn)`` method.
        output_dir: Directory handed to ``TrainingArguments``.
        device: Device the model is moved to; ``"mps"`` also enables
            ``use_mps_device``.
        batch_size: Per-device batch size for both training and evaluation.
        epochs: Number of training epochs.

    Returns:
        The model object after ``trainer.train()`` has run.
    """
    id2label, label2id = get_id_label_mapping(labels)
    model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

    def preprocess_waveform(wf):
        # Turn a raw waveform into the model's input features.
        return feature_extractor(
            wf,
            sampling_rate=train_ds.resample_frequency,
            padding="max_length",
            return_tensors="pt",
        )

    # The datasets' ``map`` registers the step in place; nothing is returned.
    train_ds.map(preprocess_waveform)
    test_ds.map(preprocess_waveform)

    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
        # The checkpoint's classification head has a different number of
        # classes, so mismatched head weights are re-initialised.
        ignore_mismatched_sizes=True
    ).to(device)
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=5,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=device == "mps"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        tokenizer=feature_extractor,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    return model
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
|
models/config/train.yaml
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
global:
|
|
|
2 |
device: mps
|
3 |
seed: 42
|
4 |
dance_ids:
|
@@ -18,11 +19,11 @@ global:
|
|
18 |
- VWZ
|
19 |
- WCS
|
20 |
data_module:
|
21 |
-
|
22 |
-
num_workers: 10
|
23 |
-
min_votes: 2
|
24 |
-
song_data_path: data/songs_cleaned.csv
|
25 |
song_audio_path: data/samples
|
|
|
|
|
|
|
26 |
dataset_kwargs:
|
27 |
audio_window_duration: 6
|
28 |
audio_window_jitter: 1.5
|
@@ -40,7 +41,8 @@ trainer:
|
|
40 |
fast_dev_run: False
|
41 |
track_grad_norm: 2
|
42 |
# gradient_clip_val: 0.5
|
|
|
43 |
training_environment:
|
44 |
-
learning_rate: 0.
|
45 |
model:
|
46 |
n_channels: 128
|
|
|
1 |
global:
|
2 |
+
id: ast_ptl
|
3 |
device: mps
|
4 |
seed: 42
|
5 |
dance_ids:
|
|
|
19 |
- VWZ
|
20 |
- WCS
|
21 |
data_module:
|
22 |
+
song_data_path: data/samples/songs_cleaned.csv
|
|
|
|
|
|
|
23 |
song_audio_path: data/samples
|
24 |
+
batch_size: 256
|
25 |
+
num_workers: 10
|
26 |
+
min_votes: 1
|
27 |
dataset_kwargs:
|
28 |
audio_window_duration: 6
|
29 |
audio_window_jitter: 1.5
|
|
|
41 |
fast_dev_run: False
|
42 |
track_grad_norm: 2
|
43 |
# gradient_clip_val: 0.5
|
44 |
+
overfit_batches: 1
|
45 |
training_environment:
|
46 |
+
learning_rate: 0.00053
|
47 |
model:
|
48 |
n_channels: 128
|
models/residual.py
CHANGED
@@ -136,6 +136,7 @@ class TrainingEnvironment(pl.LightningModule):
|
|
136 |
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
137 |
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
|
138 |
return [optimizer]
|
|
|
139 |
|
140 |
|
141 |
class DancePredictor:
|
|
|
136 |
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
137 |
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
|
138 |
return [optimizer]
|
139 |
+
|
140 |
|
141 |
|
142 |
class DancePredictor:
|
preprocessing/dataset.py
CHANGED
@@ -7,6 +7,9 @@ from .pipelines import AudioTrainingPipeline
|
|
7 |
import pytorch_lightning as pl
|
8 |
from .preprocess import get_examples
|
9 |
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
|
@@ -81,6 +84,54 @@ class SongDataset(Dataset):
|
|
81 |
|
82 |
def _label_from_index(self, idx:int) -> torch.Tensor:
|
83 |
return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
class DanceDataModule(pl.LightningDataModule):
|
86 |
def __init__(self,
|
@@ -92,6 +143,7 @@ class DanceDataModule(pl.LightningDataModule):
|
|
92 |
min_votes=1,
|
93 |
batch_size:int=64,
|
94 |
num_workers=10,
|
|
|
95 |
dataset_kwargs={}
|
96 |
):
|
97 |
super().__init__()
|
@@ -104,6 +156,7 @@ class DanceDataModule(pl.LightningDataModule):
|
|
104 |
self.batch_size = batch_size
|
105 |
self.num_workers = num_workers
|
106 |
self.dataset_kwargs = dataset_kwargs
|
|
|
107 |
|
108 |
df = pd.read_csv(song_data_path)
|
109 |
self.x,self.y = get_examples(df, self.song_audio_path,class_list=self.target_classes, multi_label=True, min_votes=min_votes)
|
@@ -115,7 +168,7 @@ class DanceDataModule(pl.LightningDataModule):
|
|
115 |
self.test_ds = self._dataset_from_indices(test_i)
|
116 |
|
117 |
def _dataset_from_indices(self, idx:list[int]) -> SongDataset:
|
118 |
-
return
|
119 |
|
120 |
def train_dataloader(self):
|
121 |
return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
|
@@ -128,4 +181,78 @@ class DanceDataModule(pl.LightningDataModule):
|
|
128 |
|
129 |
def get_label_weights(self):
|
130 |
n_examples, n_classes = self.y.shape
|
131 |
-
return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import pytorch_lightning as pl
|
8 |
from .preprocess import get_examples
|
9 |
from sklearn.model_selection import train_test_split
|
10 |
+
from torchaudio import transforms as taT
|
11 |
+
from torch import nn
|
12 |
+
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
|
13 |
|
14 |
|
15 |
|
|
|
84 |
|
85 |
def _label_from_index(self, idx:int) -> torch.Tensor:
|
86 |
return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])
|
87 |
+
|
88 |
+
|
89 |
+
class WaveformSongDataset(SongDataset):
    """
    Outputs raw waveforms of the data instead of a spectrogram.

    Each item is resampled to ``resample_frequency`` and averaged over its
    first dimension before being returned with its label vector.
    """

    def __init__(self, *args, resample_frequency=16000, **kwargs):
        """Initialise the parent dataset and build the resampler.

        Args:
            *args: Forwarded to ``SongDataset``.
            resample_frequency: Target sample rate of the returned waveforms.
            **kwargs: Forwarded to ``SongDataset``.
        """
        super().__init__(*args, **kwargs)
        self.resample_frequency = resample_frequency
        # ``self.sample_rate`` is presumably set by SongDataset.__init__ — TODO confirm.
        self.resampler = taT.Resample(self.sample_rate, self.resample_frequency)
        self.pipeline = []

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(waveform, dance_labels)`` for the example at ``idx``.

        Raises:
            AssertionError: If the loaded waveform has 10 or fewer samples
                along dimension 1.
        """
        waveform = self._waveform_from_index(idx)
        assert waveform.shape[1] > 10, f"No data found: {self._backtrace_audio_path(idx)}"
        # resample the waveform
        waveform = self.resampler(waveform)

        # Collapse dimension 0 (presumably audio channels — TODO confirm).
        waveform = waveform.mean(0)

        dance_labels = self._label_from_index(idx)
        return waveform, dance_labels
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
class HuggingFaceWaveformSongDataset(WaveformSongDataset):
    """Waveform dataset that yields ``{"input_values", "label"}`` dictionaries.

    Functions registered through ``map`` are applied, in registration order,
    to every waveform before it is returned.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``WaveformSongDataset`` and reset the pipeline."""
        super().__init__(*args, **kwargs)
        self.pipeline = []

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return the processed waveform and the arg-max class index for ``idx``."""
        x, y = super().__getitem__(idx)
        # Iterating an empty pipeline is a no-op, so no length guard is needed.
        for fn in self.pipeline:
            x = fn(x)

        dance_labels = y.argmax()
        # If the pipeline produced an object exposing ``input_values``, take
        # its first entry; otherwise pass the value through unchanged.
        input_values = x["input_values"][0] if hasattr(x, "input_values") else x
        return {"input_values": input_values, "label": dance_labels}

    def map(self, fn):
        """
        NOTE this mutates the original, doesn't return a copy like normal maps.
        """
        self.pipeline.append(fn)
|
135 |
|
136 |
class DanceDataModule(pl.LightningDataModule):
|
137 |
def __init__(self,
|
|
|
143 |
min_votes=1,
|
144 |
batch_size:int=64,
|
145 |
num_workers=10,
|
146 |
+
dataset_cls = None,
|
147 |
dataset_kwargs={}
|
148 |
):
|
149 |
super().__init__()
|
|
|
156 |
self.batch_size = batch_size
|
157 |
self.num_workers = num_workers
|
158 |
self.dataset_kwargs = dataset_kwargs
|
159 |
+
self.dataset_cls = dataset_cls if dataset_cls is not None else SongDataset
|
160 |
|
161 |
df = pd.read_csv(song_data_path)
|
162 |
self.x,self.y = get_examples(df, self.song_audio_path,class_list=self.target_classes, multi_label=True, min_votes=min_votes)
|
|
|
168 |
self.test_ds = self._dataset_from_indices(test_i)
|
169 |
|
170 |
def _dataset_from_indices(self, idx:list[int]) -> SongDataset:
|
171 |
+
return self.dataset_cls(self.x[idx], self.y[idx], **self.dataset_kwargs)
|
172 |
|
173 |
def train_dataloader(self):
|
174 |
return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
|
|
|
181 |
|
182 |
def get_label_weights(self):
|
183 |
n_examples, n_classes = self.y.shape
|
184 |
+
return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
|
185 |
+
|
186 |
+
|
187 |
+
class WaveformTrainingEnvironment(pl.LightningModule):
    """Lightning module that trains a waveform classifier.

    Each batch of raw waveforms is passed through ``feature_extractor``,
    fed to ``model``, and the resulting logits are squashed with a sigmoid
    before the loss and metrics are computed.
    """

    def __init__(self, model: nn.Module, criterion: nn.Module, feature_extractor, config: dict, learning_rate=1e-4, *args, **kwargs):
        """Store the collaborators and record hyperparameters.

        Args:
            model: Network whose output exposes ``.logits``.
            criterion: Loss called as ``criterion(probabilities, labels)``.
            feature_extractor: Callable converting a list of waveforms into a
                mapping containing ``"input_values"``.
            config: Training configuration; saved with the hyperparameters.
            learning_rate: Learning rate for the Adam optimizer.
            *args, **kwargs: Forwarded to ``pl.LightningModule``; ``kwargs``
                are also recorded as hyperparameters.
        """
        super().__init__(*args, **kwargs)
        self.model = model
        self.criterion = criterion
        self.learning_rate = learning_rate
        self.config = config
        self.feature_extractor = feature_extractor
        self.save_hyperparameters({
            "model": type(model).__name__,
            "loss": type(criterion).__name__,
            "config": config,
            **kwargs
        })

    def preprocess_inputs(self, x):
        """Run the feature extractor on a batch and return it on the input's device."""
        device = x.device
        # The extractor is given NumPy arrays, so the batch goes via the CPU.
        x = x.squeeze(1).cpu().numpy()
        x = self.feature_extractor(list(x), return_tensors='pt', sampling_rate=16000)
        return x["input_values"].to(device)

    def _predict_probabilities(self, waveforms: torch.Tensor) -> torch.Tensor:
        """Shared forward pass: extract features, run the model, apply sigmoid."""
        features = self.preprocess_inputs(waveforms)
        logits = self.model(features).logits
        # good for multi label classification, should be softmax otherwise
        return torch.sigmoid(logits)

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int) -> torch.Tensor:
        """Compute the training loss for one batch and log ``train/`` metrics."""
        features, labels = batch
        outputs = self._predict_probabilities(features)
        loss = self.criterion(outputs, labels)
        metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
        self.log_dict(metrics, prog_bar=True)
        return loss

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        """Log ``val/`` metrics, including ``val/loss``, for one batch."""
        x, y = batch
        preds = self._predict_probabilities(x)
        metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
        metrics["val/loss"] = self.criterion(preds, y)
        self.log_dict(metrics, prog_bar=True)

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        """Log ``test/`` metrics for one batch."""
        x, y = batch
        preds = self._predict_probabilities(x)
        self.log_dict(calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True)

    def configure_optimizers(self):
        """Return a single Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
        return [optimizer]
|
240 |
+
|
241 |
+
|
242 |
+
|
243 |
+
def calculate_metrics(pred, target, threshold=0.5, prefix="", multi_label=True) -> dict[str, torch.Tensor]:
    """Compute macro precision, recall, F1 and accuracy for a batch.

    Args:
        pred: Tensor of model outputs.
        target: Tensor of reference labels.
        threshold: In multi-label mode, outputs strictly above this value
            count as positive.
        prefix: String prepended to every metric name (e.g. ``"val/"``).
        multi_label: If true, compare thresholded outputs against ``target``
            as given; otherwise compare the arg-max along axis 1 of both.

    Returns:
        Mapping from ``prefix + name`` to a float32 scalar tensor for
        ``precision``, ``recall``, ``f1`` and ``accuracy``.
    """
    target = target.detach().cpu().numpy()
    pred = pred.detach().cpu().numpy()

    if multi_label:
        y_true = target
        y_pred = np.array(pred > threshold, dtype=float)
    else:
        y_true = target.argmax(1)
        y_pred = pred.argmax(1)

    params = {
        "y_true": y_true,
        "y_pred": y_pred,
        # Report 0 instead of warning when a class has no predicted/true samples.
        "zero_division": 0,
        "average": "macro"
    }
    metrics = {
        'precision': precision_score(**params),
        'recall': recall_score(**params),
        'f1': f1_score(**params),
        'accuracy': accuracy_score(y_true=y_true, y_pred=y_pred),
    }
    return {prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()}
|
preprocessing/preprocess.py
CHANGED
@@ -11,12 +11,11 @@ from tqdm import tqdm
|
|
11 |
def url_to_filename(url:str) -> str:
|
12 |
return f"{url.split('/')[-1]}.wav"
|
13 |
|
14 |
-
def
|
15 |
-
audio_urls =
|
16 |
audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
|
17 |
-
|
18 |
-
|
19 |
-
return df
|
20 |
|
21 |
def validate_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
|
22 |
"""
|
@@ -95,11 +94,11 @@ def vectorize_multi_label(labels: dict[str,int], unique_labels:np.ndarray) -> np
|
|
95 |
return probs
|
96 |
|
97 |
def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None, multi_label=True, min_votes=1) -> tuple[np.ndarray, np.ndarray]:
|
98 |
-
sampled_songs =
|
99 |
-
sampled_songs
|
100 |
if class_list is not None:
|
101 |
class_list = set(class_list)
|
102 |
-
sampled_songs
|
103 |
lambda labels : {k: v for k,v in labels.items() if k in class_list}
|
104 |
if not pd.isna(labels) and any(label in class_list and amt > 0 for label, amt in labels.items())
|
105 |
else np.nan)
|
|
|
11 |
def url_to_filename(url:str) -> str:
    """Return the local file name for a sample URL: its last ``/`` segment plus ``.wav``."""
    return f"{url.split('/')[-1]}.wav"


def has_valid_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
    """Flag which sample URLs have a matching audio file in ``audio_dir``.

    A URL counts as valid when it is not missing (the placeholder ``"."`` is
    treated as missing) and ``url_to_filename(url)`` names an entry of
    ``audio_dir``.

    Args:
        audio_urls: Series of sample URLs; may contain ``"."``, NaN or None.
        audio_dir: Directory holding the downloaded audio files.

    Returns:
        Boolean Series aligned with ``audio_urls``.
    """
    audio_urls = audio_urls.replace(".", np.nan)
    audio_files = {entry.name for entry in Path(audio_dir).iterdir()}
    # pd.notna catches every missing marker (any NaN object, None). An
    # identity check against np.nan would let other NaN objects or None
    # through and then crash inside url_to_filename.
    valid_audio_mask = audio_urls.apply(
        lambda url: bool(pd.notna(url)) and url_to_filename(url) in audio_files
    )
    return valid_audio_mask
|
|
|
19 |
|
20 |
def validate_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
|
21 |
"""
|
|
|
94 |
return probs
|
95 |
|
96 |
def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None, multi_label=True, min_votes=1) -> tuple[np.ndarray, np.ndarray]:
|
97 |
+
sampled_songs = df[has_valid_audio(df["Sample"], audio_dir)]
|
98 |
+
sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
|
99 |
if class_list is not None:
|
100 |
class_list = set(class_list)
|
101 |
+
sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(
|
102 |
lambda labels : {k: v for k,v in labels.items() if k in class_list}
|
103 |
if not pd.isna(labels) and any(label in class_list and amt > 0 for label, amt in labels.items())
|
104 |
else np.nan)
|
train.py
CHANGED
@@ -1,23 +1,47 @@
|
|
1 |
from torch.utils.data import DataLoader
|
2 |
import pandas as pd
|
|
|
3 |
from torch import nn
|
4 |
from torch.utils.data import SubsetRandomSampler
|
5 |
from sklearn.model_selection import KFold
|
6 |
import pytorch_lightning as pl
|
7 |
from pytorch_lightning import callbacks as cb
|
8 |
from models.utils import LabelWeightedBCELoss
|
9 |
-
from
|
|
|
10 |
from preprocessing.preprocess import get_examples
|
11 |
from models.residual import ResidualDancer, TrainingEnvironment
|
12 |
import yaml
|
13 |
-
from preprocessing.dataset import DanceDataModule
|
|
|
14 |
from wakepy import keepawake
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
def get_config(filepath:str) -> dict:
|
17 |
with open(filepath, "r") as f:
|
18 |
config = yaml.safe_load(f)
|
19 |
return config
|
20 |
|
|
|
21 |
def cross_validation(config, k=5):
|
22 |
df = pd.read_csv("data/songs.csv")
|
23 |
g_config = config["global"]
|
@@ -52,7 +76,8 @@ def train_model(config:dict):
|
|
52 |
# cb.LearningRateFinder(update_attr=True),
|
53 |
cb.EarlyStopping("val/loss", patience=5),
|
54 |
cb.StochasticWeightAveraging(1e-2),
|
55 |
-
cb.RichProgressBar()
|
|
|
56 |
]
|
57 |
trainer = pl.Trainer(
|
58 |
callbacks=callbacks,
|
@@ -62,8 +87,69 @@ def train_model(config:dict):
|
|
62 |
trainer.test(train_env, datamodule=data)
|
63 |
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
if __name__ == "__main__":
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
68 |
with keepawake():
|
69 |
-
|
|
|
1 |
from torch.utils.data import DataLoader
|
2 |
import pandas as pd
|
3 |
+
from typing import Callable
|
4 |
from torch import nn
|
5 |
from torch.utils.data import SubsetRandomSampler
|
6 |
from sklearn.model_selection import KFold
|
7 |
import pytorch_lightning as pl
|
8 |
from pytorch_lightning import callbacks as cb
|
9 |
from models.utils import LabelWeightedBCELoss
|
10 |
+
from models.audio_spectrogram_transformer import train as train_audio_spectrogram_transformer, get_id_label_mapping
|
11 |
+
from preprocessing.dataset import SongDataset, WaveformTrainingEnvironment
|
12 |
from preprocessing.preprocess import get_examples
|
13 |
from models.residual import ResidualDancer, TrainingEnvironment
|
14 |
import yaml
|
15 |
+
from preprocessing.dataset import DanceDataModule, WaveformSongDataset, HuggingFaceWaveformSongDataset
|
16 |
+
from torch.utils.data import random_split
|
17 |
from wakepy import keepawake
|
18 |
+
import numpy as np
|
19 |
+
from transformers import ASTFeatureExtractor, AutoFeatureExtractor, ASTConfig, AutoModelForAudioClassification
|
20 |
+
from argparse import ArgumentParser
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from torch import nn
|
26 |
+
from sklearn.utils.class_weight import compute_class_weight
|
27 |
+
|
28 |
+
def get_training_fn(id:str) -> Callable:
    """Look up the training routine registered under ``id``.

    Args:
        id: One of ``"ast_ptl"``, ``"ast_hf"`` or ``"residual_dancer"``.

    Returns:
        The matching training function; each takes the config dict.

    Raises:
        ValueError: If ``id`` names no known training routine.
    """
    # Names are resolved only in the branch taken, so this function can be
    # defined above the training functions it returns.
    if id == "ast_ptl":
        return train_ast_lightning
    if id == "ast_hf":
        return train_ast
    if id == "residual_dancer":
        return train_model
    raise ValueError(f"Couldn't find a training function for '{id}'.")
|
38 |
|
39 |
def get_config(filepath:str) -> dict:
    """Load the YAML training configuration stored at ``filepath``.

    Args:
        filepath: Path to the YAML file.

    Returns:
        The parsed document as returned by ``yaml.safe_load``.
    """
    # Explicit encoding keeps parsing independent of the platform locale.
    with open(filepath, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
|
43 |
|
44 |
+
|
45 |
def cross_validation(config, k=5):
|
46 |
df = pd.read_csv("data/songs.csv")
|
47 |
g_config = config["global"]
|
|
|
76 |
# cb.LearningRateFinder(update_attr=True),
|
77 |
cb.EarlyStopping("val/loss", patience=5),
|
78 |
cb.StochasticWeightAveraging(1e-2),
|
79 |
+
cb.RichProgressBar(),
|
80 |
+
cb.DeviceStatsMonitor(),
|
81 |
]
|
82 |
trainer = pl.Trainer(
|
83 |
callbacks=callbacks,
|
|
|
87 |
trainer.test(train_env, datamodule=data)
|
88 |
|
89 |
|
90 |
+
def train_ast(
    config:dict
):
    """Train the AST model through the Hugging Face ``Trainer`` path.

    Reads ``global.dance_ids``, ``global.device`` and ``global.seed`` plus
    ``data_module.dataset_kwargs`` and the optional
    ``data_module.test_proportion`` (default 0.2) from ``config``, splits
    the examples randomly into train and test sets, and hands both datasets
    to ``train_audio_spectrogram_transformer``.

    NOTE(review): the song CSV and audio directory are hard-coded here rather
    than read from ``config["data_module"]`` — confirm that is intended.
    """
    TARGET_CLASSES = config["global"]["dance_ids"]
    DEVICE = config["global"]["device"]
    SEED = config["global"]["seed"]
    dataset_kwargs = config["data_module"]["dataset_kwargs"]
    test_proportion = config["data_module"].get("test_proportion", 0.2)
    train_proportion = 1. - test_proportion
    song_data_path="data/songs_cleaned.csv"
    song_audio_path = "data/samples"
    pl.seed_everything(SEED, workers=True)

    df = pd.read_csv(song_data_path)
    x, y = get_examples(df, song_audio_path,class_list=TARGET_CLASSES, multi_label=True)
    train_split, test_split = random_split(np.arange(len(x)), [train_proportion, test_proportion])
    # random_split yields Subset objects over arange(len(x)), whose elements
    # equal their positions. Use the explicit index lists rather than relying
    # on NumPy iterating the Subset.
    train_i = np.asarray(train_split.indices, dtype=int)
    test_i = np.asarray(test_split.indices, dtype=int)
    train_ds = HuggingFaceWaveformSongDataset(x[train_i], y[train_i], **dataset_kwargs, resample_frequency=16000)
    test_ds = HuggingFaceWaveformSongDataset(x[test_i], y[test_i], **dataset_kwargs, resample_frequency=16000)
    train_audio_spectrogram_transformer(TARGET_CLASSES, train_ds, test_ds, device=DEVICE)
|
109 |
+
|
110 |
+
|
111 |
+
def train_ast_lightning(config:dict):
    """
    Train the AST model with PyTorch Lightning on raw-waveform data.

    Work on integration between waveform dataset and environment. Should work for both HF and PTL.

    Reads ``global.dance_ids``, ``global.device`` and ``global.seed`` from
    ``config``; ``config["data_module"]`` is forwarded to ``DanceDataModule``
    and ``config["trainer"]`` to ``pl.Trainer``. Fits the model, then runs
    the test loop on the same data module.
    """
    TARGET_CLASSES = config["global"]["dance_ids"]
    DEVICE = config["global"]["device"]
    SEED = config["global"]["seed"]
    pl.seed_everything(SEED, workers=True)
    data = DanceDataModule(target_classes=TARGET_CLASSES, dataset_cls=WaveformSongDataset, **config['data_module'])
    id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
    model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(label2id),
        label2id=label2id,
        id2label=id2label,
        # The checkpoint's classification head has a different number of
        # classes, so mismatched head weights are re-initialised.
        ignore_mismatched_sizes=True
    ).to(DEVICE)
    label_weights = data.get_label_weights().to(DEVICE)
    criterion = LabelWeightedBCELoss(label_weights) #nn.CrossEntropyLoss(label_weights)
    train_env = WaveformTrainingEnvironment(model, criterion,feature_extractor, config)
    callbacks = [
        # cb.LearningRateFinder(update_attr=True),
        cb.EarlyStopping("val/loss", patience=5),
        cb.StochasticWeightAveraging(1e-2),
        cb.RichProgressBar()
    ]
    trainer = pl.Trainer(
        callbacks=callbacks,
        **config["trainer"]
    )
    trainer.fit(train_env, datamodule=data)
    trainer.test(train_env, datamodule=data)
|
146 |
|
147 |
if __name__ == "__main__":
    # Command-line entry point: load the YAML config, pick the training
    # routine named by its global.id, and run it.
    parser = ArgumentParser(description="Trains models on the dance dataset and saves weights.")
    parser.add_argument("--config", help="Path to the yaml file that defines the training configuration.", default="models/config/train.yaml")
    args = parser.parse_args()
    config = get_config(args.config)
    training_id = config["global"]["id"]
    train = get_training_fn(training_id)
    # The training call runs inside wakepy's keepawake context.
    with keepawake():
        train(config)
|