waidhoferj commited on
Commit
e6fd727
Β·
1 Parent(s): 0030bc6

added AST model

Browse files
TODO.md CHANGED
@@ -1,4 +1,5 @@
1
  - βœ… Ensure app.py audio input sounds like training data
 
2
  - Verify that the training spectrogram matches the predict spectrogram
3
  - Count number of example misses in dataset loading
4
  - Verify windowing and jitter params in Song Dataset
@@ -7,4 +8,4 @@
7
  - Verify that labels really match what is on the music4dance site
8
  - Read the Medium series about audio DL
9
  - double check \_rectify_duration
10
- - Filter out songs that have only one vote
 
1
  - βœ… Ensure app.py audio input sounds like training data
2
+ - βœ… Use a huggingface transformer with the dataset
3
  - Verify that the training spectrogram matches the predict spectrogram
4
  - Count number of example misses in dataset loading
5
  - Verify windowing and jitter params in Song Dataset
 
8
  - Verify that labels really match what is on the music4dance site
9
  - Read the Medium series about audio DL
10
  - double check \_rectify_duration
11
+ - βœ… Filter out songs that have only one vote
models/audio_spectrogram_transformer.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ASTFeatureExtractor, AutoFeatureExtractor, ASTConfig, AutoModelForAudioClassification, TrainingArguments, Trainer
2
+ import torch
3
+ from torch import nn
4
+ from sklearn.utils.class_weight import compute_class_weight
5
+ import evaluate
6
+ import numpy as np
7
+
8
+ accuracy = evaluate.load("accuracy")
9
+
10
+ def compute_metrics(eval_pred):
11
+ predictions = np.argmax(eval_pred.predictions, axis=1)
12
+ return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
13
+
14
+ def get_id_label_mapping(labels:list[str]) -> tuple[dict, dict]:
15
+ id2label = {str(i) : label for i, label in enumerate(labels)}
16
+ label2id = {label : str(i) for i, label in enumerate(labels)}
17
+
18
+ return id2label, label2id
19
+
20
+ def train(
21
+ labels,
22
+ train_ds,
23
+ test_ds,
24
+ output_dir="models/weights/ast",
25
+ device="cpu",
26
+ batch_size=128,
27
+ epochs=10):
28
+ id2label, label2id = get_id_label_mapping(labels)
29
+ model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
30
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
31
+ preprocess_waveform = lambda wf : feature_extractor(wf, sampling_rate=train_ds.resample_frequency, padding="max_length", return_tensors="pt")
32
+ train_ds.map(preprocess_waveform)
33
+ test_ds.map(preprocess_waveform)
34
+
35
+ model = AutoModelForAudioClassification.from_pretrained(
36
+ model_checkpoint,
37
+ num_labels=len(labels),
38
+ label2id=label2id,
39
+ id2label=id2label,
40
+ ignore_mismatched_sizes=True
41
+ ).to(device)
42
+ training_args = TrainingArguments(
43
+ output_dir=output_dir,
44
+ evaluation_strategy="epoch",
45
+ save_strategy="epoch",
46
+ learning_rate=5e-5,
47
+ per_device_train_batch_size=batch_size,
48
+ gradient_accumulation_steps=5,
49
+ per_device_eval_batch_size=batch_size,
50
+ num_train_epochs=epochs,
51
+ warmup_ratio=0.1,
52
+ logging_steps=10,
53
+ load_best_model_at_end=True,
54
+ metric_for_best_model="accuracy",
55
+ push_to_hub=False,
56
+ use_mps_device=device == "mps"
57
+ )
58
+
59
+ trainer = Trainer(
60
+ model=model,
61
+ args=training_args,
62
+ train_dataset=train_ds,
63
+ eval_dataset=test_ds,
64
+ tokenizer=feature_extractor,
65
+ compute_metrics=compute_metrics,
66
+ )
67
+ trainer.train()
68
+ return model
69
+
70
+
71
+
72
+
models/config/train.yaml CHANGED
@@ -1,4 +1,5 @@
1
  global:
 
2
  device: mps
3
  seed: 42
4
  dance_ids:
@@ -18,11 +19,11 @@ global:
18
  - VWZ
19
  - WCS
20
  data_module:
21
- batch_size: 1024
22
- num_workers: 10
23
- min_votes: 2
24
- song_data_path: data/songs_cleaned.csv
25
  song_audio_path: data/samples
 
 
 
26
  dataset_kwargs:
27
  audio_window_duration: 6
28
  audio_window_jitter: 1.5
@@ -40,7 +41,8 @@ trainer:
40
  fast_dev_run: False
41
  track_grad_norm: 2
42
  # gradient_clip_val: 0.5
 
43
  training_environment:
44
- learning_rate: 0.0033
45
  model:
46
  n_channels: 128
 
1
  global:
2
+ id: ast_ptl
3
  device: mps
4
  seed: 42
5
  dance_ids:
 
19
  - VWZ
20
  - WCS
21
  data_module:
22
+ song_data_path: data/samples/songs_cleaned.csv
 
 
 
23
  song_audio_path: data/samples
24
+ batch_size: 256
25
+ num_workers: 10
26
+ min_votes: 1
27
  dataset_kwargs:
28
  audio_window_duration: 6
29
  audio_window_jitter: 1.5
 
41
  fast_dev_run: False
42
  track_grad_norm: 2
43
  # gradient_clip_val: 0.5
44
+ overfit_batches: 1
45
  training_environment:
46
+ learning_rate: 0.00053
47
  model:
48
  n_channels: 128
models/residual.py CHANGED
@@ -136,6 +136,7 @@ class TrainingEnvironment(pl.LightningModule):
136
  optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
137
  # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
138
  return [optimizer]
 
139
 
140
 
141
  class DancePredictor:
 
136
  optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
137
  # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
138
  return [optimizer]
139
+
140
 
141
 
142
  class DancePredictor:
preprocessing/dataset.py CHANGED
@@ -7,6 +7,9 @@ from .pipelines import AudioTrainingPipeline
7
  import pytorch_lightning as pl
8
  from .preprocess import get_examples
9
  from sklearn.model_selection import train_test_split
 
 
 
10
 
11
 
12
 
@@ -81,6 +84,54 @@ class SongDataset(Dataset):
81
 
82
  def _label_from_index(self, idx:int) -> torch.Tensor:
83
  return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  class DanceDataModule(pl.LightningDataModule):
86
  def __init__(self,
@@ -92,6 +143,7 @@ class DanceDataModule(pl.LightningDataModule):
92
  min_votes=1,
93
  batch_size:int=64,
94
  num_workers=10,
 
95
  dataset_kwargs={}
96
  ):
97
  super().__init__()
@@ -104,6 +156,7 @@ class DanceDataModule(pl.LightningDataModule):
104
  self.batch_size = batch_size
105
  self.num_workers = num_workers
106
  self.dataset_kwargs = dataset_kwargs
 
107
 
108
  df = pd.read_csv(song_data_path)
109
  self.x,self.y = get_examples(df, self.song_audio_path,class_list=self.target_classes, multi_label=True, min_votes=min_votes)
@@ -115,7 +168,7 @@ class DanceDataModule(pl.LightningDataModule):
115
  self.test_ds = self._dataset_from_indices(test_i)
116
 
117
  def _dataset_from_indices(self, idx:list[int]) -> SongDataset:
118
- return SongDataset(self.x[idx], self.y[idx], **self.dataset_kwargs)
119
 
120
  def train_dataloader(self):
121
  return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
@@ -128,4 +181,78 @@ class DanceDataModule(pl.LightningDataModule):
128
 
129
  def get_label_weights(self):
130
  n_examples, n_classes = self.y.shape
131
- return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import pytorch_lightning as pl
8
  from .preprocess import get_examples
9
  from sklearn.model_selection import train_test_split
10
+ from torchaudio import transforms as taT
11
+ from torch import nn
12
+ from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
13
 
14
 
15
 
 
84
 
85
  def _label_from_index(self, idx:int) -> torch.Tensor:
86
  return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])
87
+
88
+
89
+ class WaveformSongDataset(SongDataset):
90
+ """
91
+ Outputs raw waveforms of the data instead of a spectrogram.
92
+ """
93
+
94
+ def __init__(self, *args,resample_frequency=16000, **kwargs):
95
+ super().__init__(*args, **kwargs)
96
+ self.resample_frequency = resample_frequency
97
+ self.resampler = taT.Resample(self.sample_rate, self.resample_frequency)
98
+ self.pipeline = []
99
+
100
+ def __getitem__(self, idx:int) -> dict[str, torch.Tensor]:
101
+ waveform = self._waveform_from_index(idx)
102
+ assert waveform.shape[1] > 10, f"No data found: {self._backtrace_audio_path(idx)}"
103
+ # resample the waveform
104
+ waveform = self.resampler(waveform)
105
+
106
+ waveform = waveform.mean(0)
107
+
108
+ dance_labels = self._label_from_index(idx)
109
+ return waveform, dance_labels
110
+
111
+
112
+
113
+
114
+ class HuggingFaceWaveformSongDataset(WaveformSongDataset):
115
+
116
+ def __init__(self, *args, **kwargs):
117
+ super().__init__(*args, **kwargs)
118
+ self.pipeline = []
119
+
120
+
121
+ def __getitem__(self, idx:int) -> dict[str, torch.Tensor]:
122
+ x,y = super().__getitem__(idx)
123
+ if len(self.pipeline) > 0:
124
+ for fn in self.pipeline:
125
+ x = fn(x)
126
+
127
+ dance_labels = y.argmax()
128
+ return {"input_values": x["input_values"][0] if hasattr(x, "input_values") else x, "label": dance_labels}
129
+
130
+ def map(self,fn):
131
+ """
132
+ NOTE this mutates the original, doesn't return a copy like normal maps.
133
+ """
134
+ self.pipeline.append(fn)
135
 
136
  class DanceDataModule(pl.LightningDataModule):
137
  def __init__(self,
 
143
  min_votes=1,
144
  batch_size:int=64,
145
  num_workers=10,
146
+ dataset_cls = None,
147
  dataset_kwargs={}
148
  ):
149
  super().__init__()
 
156
  self.batch_size = batch_size
157
  self.num_workers = num_workers
158
  self.dataset_kwargs = dataset_kwargs
159
+ self.dataset_cls = dataset_cls if dataset_cls is not None else SongDataset
160
 
161
  df = pd.read_csv(song_data_path)
162
  self.x,self.y = get_examples(df, self.song_audio_path,class_list=self.target_classes, multi_label=True, min_votes=min_votes)
 
168
  self.test_ds = self._dataset_from_indices(test_i)
169
 
170
  def _dataset_from_indices(self, idx:list[int]) -> SongDataset:
171
+ return self.dataset_cls(self.x[idx], self.y[idx], **self.dataset_kwargs)
172
 
173
  def train_dataloader(self):
174
  return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
 
181
 
182
  def get_label_weights(self):
183
  n_examples, n_classes = self.y.shape
184
+ return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
185
+
186
+
187
+ class WaveformTrainingEnvironment(pl.LightningModule):
188
+
189
+ def __init__(self, model: nn.Module, criterion: nn.Module, feature_extractor, config:dict, learning_rate=1e-4, *args, **kwargs):
190
+ super().__init__(*args, **kwargs)
191
+ self.model = model
192
+ self.criterion = criterion
193
+ self.learning_rate = learning_rate
194
+ self.config=config
195
+ self.feature_extractor=feature_extractor
196
+ self.save_hyperparameters({
197
+ "model": type(model).__name__,
198
+ "loss": type(criterion).__name__,
199
+ "config": config,
200
+ **kwargs
201
+ })
202
+
203
+ def preprocess_inputs(self, x):
204
+ device = x.device
205
+ x = x.squeeze(1).cpu().numpy()
206
+ x = self.feature_extractor(list(x),return_tensors='pt', sampling_rate=16000)
207
+ return x["input_values"].to(device)
208
+
209
+ def training_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int) -> torch.Tensor:
210
+ features, labels = batch
211
+ features = self.preprocess_inputs(features)
212
+ outputs = self.model(features).logits
213
+ outputs = nn.Sigmoid()(outputs) # good for multi label classification, should be softmax otherwise
214
+ loss = self.criterion(outputs, labels)
215
+ metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
216
+ self.log_dict(metrics, prog_bar=True)
217
+ return loss
218
+
219
+
220
+ def validation_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
221
+ x,y = batch
222
+ x = self.preprocess_inputs(x)
223
+ preds = self.model(x).logits
224
+ preds = nn.Sigmoid()(preds)
225
+ metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
226
+ metrics["val/loss"] = self.criterion(preds, y)
227
+ self.log_dict(metrics,prog_bar=True)
228
+
229
+ def test_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
230
+ x, y = batch
231
+ x = self.preprocess_inputs(x)
232
+ preds = self.model(x).logits
233
+ preds = nn.Sigmoid()(preds)
234
+ self.log_dict(calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True)
235
+
236
+ def configure_optimizers(self):
237
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
238
+ # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
239
+ return [optimizer]
240
+
241
+
242
+
243
+ def calculate_metrics(pred, target, threshold=0.5, prefix="", multi_label=True) -> dict[str, torch.Tensor]:
244
+ target = target.detach().cpu().numpy()
245
+ pred = pred.detach().cpu().numpy()
246
+ params = {
247
+ "y_true": target if multi_label else target.argmax(1) ,
248
+ "y_pred": np.array(pred > threshold, dtype=float) if multi_label else pred.argmax(1),
249
+ "zero_division": 0,
250
+ "average":"macro"
251
+ }
252
+ metrics= {
253
+ 'precision': precision_score(**params),
254
+ 'recall': recall_score(**params),
255
+ 'f1': f1_score(**params),
256
+ 'accuracy': accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
257
+ }
258
+ return {prefix + k: torch.tensor(v,dtype=torch.float32) for k,v in metrics.items()}
preprocessing/preprocess.py CHANGED
@@ -11,12 +11,11 @@ from tqdm import tqdm
11
  def url_to_filename(url:str) -> str:
12
  return f"{url.split('/')[-1]}.wav"
13
 
14
- def get_songs_with_audio(df:pd.DataFrame, audio_dir:str) -> pd.DataFrame:
15
- audio_urls = df["Sample"].replace(".", np.nan)
16
  audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
17
- valid_audio = audio_urls.apply(lambda url : url is not np.nan and url_to_filename(url) in audio_files)
18
- df = df[valid_audio]
19
- return df
20
 
21
  def validate_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
22
  """
@@ -95,11 +94,11 @@ def vectorize_multi_label(labels: dict[str,int], unique_labels:np.ndarray) -> np
95
  return probs
96
 
97
  def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None, multi_label=True, min_votes=1) -> tuple[np.ndarray, np.ndarray]:
98
- sampled_songs = get_songs_with_audio(df, audio_dir)
99
- sampled_songs.loc[:,"DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
100
  if class_list is not None:
101
  class_list = set(class_list)
102
- sampled_songs.loc[:,"DanceRating"] = sampled_songs["DanceRating"].apply(
103
  lambda labels : {k: v for k,v in labels.items() if k in class_list}
104
  if not pd.isna(labels) and any(label in class_list and amt > 0 for label, amt in labels.items())
105
  else np.nan)
 
11
  def url_to_filename(url:str) -> str:
12
  return f"{url.split('/')[-1]}.wav"
13
 
14
+ def has_valid_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
15
+ audio_urls = audio_urls.replace(".", np.nan)
16
  audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
17
+ valid_audio_mask = audio_urls.apply(lambda url : url is not np.nan and url_to_filename(url) in audio_files)
18
+ return valid_audio_mask
 
19
 
20
  def validate_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
21
  """
 
94
  return probs
95
 
96
  def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None, multi_label=True, min_votes=1) -> tuple[np.ndarray, np.ndarray]:
97
+ sampled_songs = df[has_valid_audio(df["Sample"], audio_dir)]
98
+ sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
99
  if class_list is not None:
100
  class_list = set(class_list)
101
+ sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(
102
  lambda labels : {k: v for k,v in labels.items() if k in class_list}
103
  if not pd.isna(labels) and any(label in class_list and amt > 0 for label, amt in labels.items())
104
  else np.nan)
train.py CHANGED
@@ -1,23 +1,47 @@
1
  from torch.utils.data import DataLoader
2
  import pandas as pd
 
3
  from torch import nn
4
  from torch.utils.data import SubsetRandomSampler
5
  from sklearn.model_selection import KFold
6
  import pytorch_lightning as pl
7
  from pytorch_lightning import callbacks as cb
8
  from models.utils import LabelWeightedBCELoss
9
- from preprocessing.dataset import SongDataset
 
10
  from preprocessing.preprocess import get_examples
11
  from models.residual import ResidualDancer, TrainingEnvironment
12
  import yaml
13
- from preprocessing.dataset import DanceDataModule
 
14
  from wakepy import keepawake
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def get_config(filepath:str) -> dict:
17
  with open(filepath, "r") as f:
18
  config = yaml.safe_load(f)
19
  return config
20
 
 
21
  def cross_validation(config, k=5):
22
  df = pd.read_csv("data/songs.csv")
23
  g_config = config["global"]
@@ -52,7 +76,8 @@ def train_model(config:dict):
52
  # cb.LearningRateFinder(update_attr=True),
53
  cb.EarlyStopping("val/loss", patience=5),
54
  cb.StochasticWeightAveraging(1e-2),
55
- cb.RichProgressBar()
 
56
  ]
57
  trainer = pl.Trainer(
58
  callbacks=callbacks,
@@ -62,8 +87,69 @@ def train_model(config:dict):
62
  trainer.test(train_env, datamodule=data)
63
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  if __name__ == "__main__":
67
- config = get_config("models/config/train.yaml")
 
 
 
 
 
68
  with keepawake():
69
- train_model(config)
 
1
  from torch.utils.data import DataLoader
2
  import pandas as pd
3
+ from typing import Callable
4
  from torch import nn
5
  from torch.utils.data import SubsetRandomSampler
6
  from sklearn.model_selection import KFold
7
  import pytorch_lightning as pl
8
  from pytorch_lightning import callbacks as cb
9
  from models.utils import LabelWeightedBCELoss
10
+ from models.audio_spectrogram_transformer import train as train_audio_spectrogram_transformer, get_id_label_mapping
11
+ from preprocessing.dataset import SongDataset, WaveformTrainingEnvironment
12
  from preprocessing.preprocess import get_examples
13
  from models.residual import ResidualDancer, TrainingEnvironment
14
  import yaml
15
+ from preprocessing.dataset import DanceDataModule, WaveformSongDataset, HuggingFaceWaveformSongDataset
16
+ from torch.utils.data import random_split
17
  from wakepy import keepawake
18
+ import numpy as np
19
+ from transformers import ASTFeatureExtractor, AutoFeatureExtractor, ASTConfig, AutoModelForAudioClassification
20
+ from argparse import ArgumentParser
21
+
22
+
23
+
24
+ import torch
25
+ from torch import nn
26
+ from sklearn.utils.class_weight import compute_class_weight
27
+
28
+ def get_training_fn(id:str) -> Callable:
29
+ match id:
30
+ case "ast_ptl":
31
+ return train_ast_lightning
32
+ case "ast_hf":
33
+ return train_ast
34
+ case "residual_dancer":
35
+ return train_model
36
+ case _:
37
+ raise Exception(f"Couldn't find a training function for '{id}'.")
38
 
39
  def get_config(filepath:str) -> dict:
40
  with open(filepath, "r") as f:
41
  config = yaml.safe_load(f)
42
  return config
43
 
44
+
45
  def cross_validation(config, k=5):
46
  df = pd.read_csv("data/songs.csv")
47
  g_config = config["global"]
 
76
  # cb.LearningRateFinder(update_attr=True),
77
  cb.EarlyStopping("val/loss", patience=5),
78
  cb.StochasticWeightAveraging(1e-2),
79
+ cb.RichProgressBar(),
80
+ cb.DeviceStatsMonitor(),
81
  ]
82
  trainer = pl.Trainer(
83
  callbacks=callbacks,
 
87
  trainer.test(train_env, datamodule=data)
88
 
89
 
90
+ def train_ast(
91
+ config:dict
92
+ ):
93
+ TARGET_CLASSES = config["global"]["dance_ids"]
94
+ DEVICE = config["global"]["device"]
95
+ SEED = config["global"]["seed"]
96
+ dataset_kwargs = config["data_module"]["dataset_kwargs"]
97
+ test_proportion = config["data_module"].get("test_proportion", 0.2)
98
+ train_proportion = 1. - test_proportion
99
+ song_data_path="data/songs_cleaned.csv"
100
+ song_audio_path = "data/samples"
101
+ pl.seed_everything(SEED, workers=True)
102
+
103
+ df = pd.read_csv(song_data_path)
104
+ x, y = get_examples(df, song_audio_path,class_list=TARGET_CLASSES, multi_label=True)
105
+ train_i, test_i = random_split(np.arange(len(x)), [train_proportion, test_proportion])
106
+ train_ds = HuggingFaceWaveformSongDataset(x[train_i], y[train_i], **dataset_kwargs, resample_frequency=16000)
107
+ test_ds = HuggingFaceWaveformSongDataset(x[test_i], y[test_i], **dataset_kwargs, resample_frequency=16000)
108
+ train_audio_spectrogram_transformer(TARGET_CLASSES, train_ds, test_ds, device=DEVICE)
109
+
110
+
111
+ def train_ast_lightning(config:dict):
112
+ """
113
+ work on integration between waveform dataset and environment. Should work for both HF and PTL.
114
+ """
115
+ TARGET_CLASSES = config["global"]["dance_ids"]
116
+ DEVICE = config["global"]["device"]
117
+ SEED = config["global"]["seed"]
118
+ pl.seed_everything(SEED, workers=True)
119
+ data = DanceDataModule(target_classes=TARGET_CLASSES, dataset_cls=WaveformSongDataset, **config['data_module'])
120
+ id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
121
+ model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
122
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
123
+
124
+ model = AutoModelForAudioClassification.from_pretrained(
125
+ model_checkpoint,
126
+ num_labels=len(label2id),
127
+ label2id=label2id,
128
+ id2label=id2label,
129
+ ignore_mismatched_sizes=True
130
+ ).to(DEVICE)
131
+ label_weights = data.get_label_weights().to(DEVICE)
132
+ criterion = LabelWeightedBCELoss(label_weights) #nn.CrossEntropyLoss(label_weights)
133
+ train_env = WaveformTrainingEnvironment(model, criterion,feature_extractor, config)
134
+ callbacks = [
135
+ # cb.LearningRateFinder(update_attr=True),
136
+ cb.EarlyStopping("val/loss", patience=5),
137
+ cb.StochasticWeightAveraging(1e-2),
138
+ cb.RichProgressBar()
139
+ ]
140
+ trainer = pl.Trainer(
141
+ callbacks=callbacks,
142
+ **config["trainer"]
143
+ )
144
+ trainer.fit(train_env, datamodule=data)
145
+ trainer.test(train_env, datamodule=data)
146
 
147
  if __name__ == "__main__":
148
+ parser = ArgumentParser(description="Trains models on the dance dataset and saves weights.")
149
+ parser.add_argument("--config", help="Path to the yaml file that defines the training configuration.", default="models/config/train.yaml")
150
+ args = parser.parse_args()
151
+ config = get_config(args.config)
152
+ training_id = config["global"]["id"]
153
+ train = get_training_fn(training_id)
154
  with keepawake():
155
+ train(config)