waidhoferj committed on
Commit
c914273
·
0 Parent(s):

first commit

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__
2
+ .DS_Store
3
+ data
4
+ logs
5
+ gradio_cached_examples
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import gradio as gr
3
+ import numpy as np
4
+ import torch
5
+ from preprocessing.preprocess import AudioPipeline
6
+ from preprocessing.preprocess import AudioPipeline
7
+ from dancer_net.dancer_net import ShortChunkCNN
8
+ import os
9
+ import json
10
+ from functools import cache
11
+ import pandas as pd
12
+
13
@cache
def get_model(device) -> tuple[ShortChunkCNN, np.ndarray]:
    """Load the trained classifier and its class labels, cached per device.

    Args:
        device: Torch device (or device string) the model is moved to.

    Returns:
        The model in eval mode and the sorted array of class labels read
        from the run's ``config.json``.
    """
    model_path = "logs/20221226-230930"
    weights = os.path.join(model_path, "dancer_net.pt")
    config_path = os.path.join(model_path, "config.json")

    with open(config_path) as f:
        config = json.load(f)
    labels = np.array(sorted(config["classes"]))

    model = ShortChunkCNN(n_class=len(labels))
    # map_location lets a checkpoint saved from one backend (e.g. "mps")
    # be loaded on a machine that only has another one.
    state_dict = torch.load(weights, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device).eval()
    return model, labels
27
+
28
@cache
def get_pipeline(sample_rate:int) -> AudioPipeline:
    """Return the preprocessing pipeline for ``sample_rate``, memoised per rate."""
    pipeline = AudioPipeline(input_freq=sample_rate)
    return pipeline
31
+
32
@cache
def get_dance_map() -> dict:
    """Read ``data/dance_mapping.csv`` once and return an ``id`` -> ``name`` lookup."""
    mapping = pd.read_csv("data/dance_mapping.csv")
    names_by_id = mapping.set_index("id")["name"]
    return names_by_id.to_dict()
36
+
37
+
38
def predict(audio: tuple[int, np.ndarray]) -> dict[str, float] | str:
    """Classify a recording into the dance styles that fit it.

    Args:
        audio: ``(sample_rate, waveform)`` pair; the waveform is indexed
            ``(samples,)`` or ``(samples, channels)``.

    Returns:
        A mapping of dance name to score for every class above the
        threshold, or a fallback message when no class clears it.

    Raises:
        gr.Error: If no audio was supplied or it is shorter than 6 seconds.
    """
    if audio is None:
        raise gr.Error("No audio was provided")
    sample_rate, waveform = audio

    expected_duration = 6
    threshold = 0.5
    sample_len = sample_rate * expected_duration
    # Prefer Apple's MPS backend but fall back so the demo runs elsewhere too.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    audio_pipeline = get_pipeline(sample_rate)
    model, labels = get_model(device)

    if sample_len > len(waveform):
        raise gr.Error("You must record for at least 6 seconds")

    # Convert to float before any arithmetic: subtracting the minimum from an
    # integer array can overflow the integer dtype.
    waveform = np.asarray(waveform, dtype="float64")
    if waveform.ndim > 1:
        # Average the channels; this also flattens a (samples, 1) mono array.
        waveform = waveform.mean(axis=1)
    waveform = waveform[np.newaxis, :sample_len]

    # Rescale to [-1, 1]. A constant clip (e.g. silence) has no range to
    # stretch and would otherwise divide by zero.
    low, high = waveform.min(), waveform.max()
    if high > low:
        waveform = (waveform - low) / (high - low) * 2 - 1
    else:
        waveform = np.zeros_like(waveform)
    waveform = torch.from_numpy(waveform.astype("float32"))

    spectrogram = audio_pipeline(waveform)
    spectrogram = spectrogram.unsqueeze(0).to(device)

    with torch.no_grad():
        results = model(spectrogram)
    dance_mapping = get_dance_map()
    results = results.squeeze(0).detach().cpu().numpy()
    result_mask = results > threshold
    probs = results[result_mask]
    dances = labels[result_mask]

    if not len(dances):
        return "Couldn't find a dance."
    return {dance_mapping[dance_id]: float(prob) for dance_id, prob in zip(dances, probs)}
72
+
73
+
74
def demo():
    """Build the Gradio Blocks app with a microphone tab, an upload tab and examples.

    Returns:
        The assembled ``gr.Blocks`` app, not yet launched.
    """
    title = "Dance Classifier"
    description = "Record 6 seconds of a song and find out what dance fits the music."

    # Bundled sample songs offered as clickable examples; dotfiles are skipped.
    samples_dir = Path(os.path.dirname(__file__), "assets", "song-samples")

    with gr.Blocks() as app:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)

        with gr.Tab("Record Song"):
            mic_audio = gr.Audio(source="microphone", label="Song Recording")
            mic_submit = gr.Button("Predict")

        with gr.Tab("Upload Song"):
            audio_file = gr.Audio(label="Song Audio File")
            audio_file_submit = gr.Button("Predict")
            example_audio = []
            for song in samples_dir.iterdir():
                if song.name[0] != '.':
                    example_audio.append(str(song))

        labels = gr.Label(label="Dances")

        gr.Markdown("## Examples")
        gr.Examples(
            examples=example_audio,
            inputs=audio_file,
            outputs=labels,
            fn=predict,
        )

        # Both buttons feed the same classifier and share one output widget.
        audio_file_submit.click(fn=predict, inputs=audio_file, outputs=labels)
        mic_submit.click(fn=predict, inputs=mic_audio, outputs=labels)

    return app


if __name__ == "__main__":
    demo().launch()
assets/song-samples/alejandro.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85f9a65fc4adb1fc0cbdbfafb7f7268a0934d97a120110d3f3a43375e59cba54
3
+ size 5292078
assets/song-samples/exs_and_ohs.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e53fe157ff687b5464e98c7d0c03d0712527c3a7ed24b6b063a328fcf7bf608
3
+ size 5292082
assets/song-samples/take_it_to_the_limit.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c69e0eeb4321c44daaaaf95dd596b1d813b9f7e9b5ef4ac5ae9fe11878d4b13b
3
+ size 5292082
dancer_net/dancer_net.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchaudio import transforms as taT, functional as taF
5
+
6
+
7
+
8
+ DEVICE = "mps"
9
class ShortChunkCNN(nn.Module):
    """Residual CNN over a spectrogram followed by a two-layer sigmoid head.

    Args:
        n_channels: Base channel width; later stages use 2x and 4x this value.
        sample_rate: Accepted but not used by this class.
        n_class: Number of output scores per example.
    """

    def __init__(self,
                 n_channels=128,
                 sample_rate=16000,
                 n_class=50):
        super().__init__()
        double, quad = n_channels * 2, n_channels * 4

        # Spectrogram normalisation
        self.spec_bn = nn.BatchNorm2d(1)

        # Seven stride-2 residual stages, listed as (in_channels, out_channels)
        stage_channels = [
            (1, n_channels),
            (n_channels, n_channels),
            (n_channels, double),
            (double, double),
            (double, double),
            (double, double),
            (double, quad),
        ]
        self.res_layers = nn.Sequential(
            *(Res_2d(c_in, c_out, stride=2) for c_in, c_out in stage_channels)
        )

        # Dense head
        self.dense1 = nn.Linear(quad, quad)
        self.bn = nn.BatchNorm1d(quad)
        self.dense2 = nn.Linear(quad, n_class)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return per-class sigmoid scores for a batch of spectrograms.

        NOTE(review): presumably ``x`` is (batch, 1, mels, frames) and the
        residual stack reduces axis 2 to size 1 — confirm against callers.
        """
        features = self.res_layers(self.spec_bn(x))
        features = features.squeeze(2)

        # Global max pooling over the last axis
        if features.size(-1) != 1:
            features = F.max_pool1d(features, features.size(-1))
        pooled = features.squeeze(2)

        hidden = self.dropout(F.relu(self.bn(self.dense1(pooled))))
        return torch.sigmoid(self.dense2(hidden))
57
+
58
+
59
class Res_2d(nn.Module):
    """Two-convolution residual block with an optional projected shortcut.

    The shortcut is a conv + batch-norm projection whenever the stride is not
    1 or the channel count changes; otherwise the input is added unchanged.
    """

    def __init__(self, input_channels, output_channels, shape=3, stride=2):
        super().__init__()
        pad = shape // 2

        # main path
        self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=pad)
        self.bn_1 = nn.BatchNorm2d(output_channels)
        self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=pad)
        self.bn_2 = nn.BatchNorm2d(output_channels)

        # shortcut path
        needs_projection = (stride != 1) or (input_channels != output_channels)
        self.diff = needs_projection
        if needs_projection:
            self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=pad)
            self.bn_3 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the block and return the rectified sum of both paths."""
        main = self.relu(self.bn_1(self.conv_1(x)))
        main = self.bn_2(self.conv_2(main))

        shortcut = self.bn_3(self.conv_3(x)) if self.diff else x
        return self.relu(shortcut + main)
environment.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: dancer-net
2
+ channels:
3
+ - anaconda
4
+ - conda-forge
5
+ dependencies:
6
+ - torchvision
7
+ - pytorch
8
+ - numpy
9
+ - pandas
10
+ - seaborn
11
+ - python=3.10
12
+ - matplotlib
13
+ - torchaudio
14
+ - bs4
15
+ - requests
16
+ - bidict
17
+ - tqdm
18
+ - pip
19
+ - gradio
20
+ prefix: /opt/homebrew/Caskroom/miniforge/base/envs/dancer-net
main.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchaudio
2
+ from preprocessing.preprocess import AudioPipeline
3
+ from dancer_net.dancer_net import ShortChunkCNN
4
+ import torch
5
+ import numpy as np
6
+ import os
7
+ import json
8
+
9
if __name__ == "__main__":
    # One-off inference script: classify the first few seconds of one file.
    audio_file = "data/samples/mzm.iqskzxzx.aac.p.m4a.wav"
    seconds = 6
    model_path = "logs/20221226-230930"
    weights = os.path.join(model_path, "dancer_net.pt")
    config_path = os.path.join(model_path, "config.json")
    # Use Apple's MPS backend when present, otherwise run on the CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    threshold = 0.5

    with open(config_path) as f:
        config = json.load(f)
    labels = np.array(sorted(config["classes"]))

    waveform, sample_rate = torchaudio.load(audio_file)
    waveform = waveform[:, :seconds * sample_rate]
    # Build the pipeline from the file's own sample rate; the default input
    # rate would skip resampling for files recorded at any other rate.
    audio_pipeline = AudioPipeline(input_freq=sample_rate)
    spectrogram = audio_pipeline(waveform)
    spectrogram = spectrogram.unsqueeze(0).to(device)

    model = ShortChunkCNN(n_class=len(labels))
    # map_location allows loading a checkpoint saved from a different backend.
    model.load_state_dict(torch.load(weights, map_location=device))
    model = model.to(device).eval()

    with torch.no_grad():
        results = model(spectrogram)
    results = results.squeeze(0).detach().cpu().numpy()
    results = results > threshold
    results = labels[results]
    print(results)
39
+
40
+
41
+
42
+
43
+
44
+
45
+
46
+
preprocessing/dataset.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ import torchaudio as ta
5
+ from .preprocess import AudioPipeline
6
+
7
+
8
class SongDataset(Dataset):
    """Dataset serving fixed-length spectrogram windows cut from song clips.

    Every file is treated as ``audio_duration`` seconds long and divided into
    consecutive ``audio_window_duration``-second windows. Dataset index ``i``
    refers to window ``i % windows_per_file`` of file ``i // windows_per_file``
    and carries that file's label vector.

    Args:
        audio_paths: Paths of the audio files.
        dance_labels: One label vector per audio file.
        audio_duration: Assumed length of every file, in seconds.
        audio_window_duration: Length of each served window, in seconds.
    """

    def __init__(self,
                 audio_paths: list[str],
                 dance_labels: list[np.ndarray],
                 audio_duration=30,  # seconds
                 audio_window_duration=6,  # seconds
                 ):
        assert audio_duration % audio_window_duration == 0, "Audio window should divide duration evenly."

        self.audio_paths = audio_paths
        self.dance_labels = dance_labels
        # The sample rate is read from the first file; every later load is
        # checked against it.
        audio_info = ta.info(audio_paths[0])
        self.sample_rate = audio_info.sample_rate
        self.audio_window_duration = int(audio_window_duration)
        self.audio_duration = int(audio_duration)

        self.audio_pipeline = AudioPipeline(input_freq=self.sample_rate)

    def __len__(self):
        return len(self.audio_paths) * self.audio_duration // self.audio_window_duration

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        waveform = self._waveform_from_index(idx)
        spectrogram = self.audio_pipeline(waveform)

        dance_labels = self._label_from_index(idx)

        return spectrogram, dance_labels

    def _window_position(self, idx: int) -> tuple[int, int]:
        """Map a dataset index to ``(file index, starting frame within the file)``."""
        windows_per_file = self.audio_duration // self.audio_window_duration
        file_idx, window_idx = divmod(idx, windows_per_file)
        # The offset handed to the loader is measured in frames, so the window
        # number has to be scaled by the frames each window spans.
        frame_offset = window_idx * self.sample_rate * self.audio_window_duration
        return file_idx, frame_offset

    def _waveform_from_index(self, idx: int) -> torch.Tensor:
        """Load the waveform of the window addressed by ``idx``."""
        audio_file_idx, frame_offset = self._window_position(idx)
        num_frames = self.sample_rate * self.audio_window_duration
        waveform, sample_rate = ta.load(self.audio_paths[audio_file_idx], frame_offset=frame_offset, num_frames=num_frames)
        assert sample_rate == self.sample_rate, f"Expected sample rate of {self.sample_rate}. Found {sample_rate}"
        return waveform

    def _label_from_index(self, idx: int) -> torch.Tensor:
        """Return the label vector of the file that window ``idx`` belongs to."""
        label_idx, _ = self._window_position(idx)
        return torch.from_numpy(self.dance_labels[label_idx])
preprocessing/preprocess.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import re
4
+ import json
5
+ from pathlib import Path
6
+ import os
7
+ import torch
8
+ import torchaudio.transforms as taT
9
+
10
def url_to_filename(url:str) -> str:
    """Return the last ``/``-separated segment of ``url`` with ``.wav`` appended."""
    last_segment = url.rsplit("/", 1)[-1]
    return last_segment + ".wav"
12
+
13
def get_songs_with_audio(df:pd.DataFrame, audio_dir:str) -> pd.DataFrame:
    """Keep only the rows whose sample URL has a matching file in ``audio_dir``.

    Args:
        df: Song table with a ``"Sample"`` column of URLs; ``"."`` marks a
            missing sample.
        audio_dir: Directory holding the downloaded audio files.

    Returns:
        The rows of ``df`` whose derived file name exists in ``audio_dir``.
    """
    audio_urls = df["Sample"].replace(".", np.nan)
    audio_files = {f.name for f in Path(audio_dir).iterdir()}
    # pd.notna recognises every missing-value form (None, any float NaN),
    # whereas an identity test against np.nan only matches that one object.
    valid_audio = audio_urls.apply(
        lambda url: pd.notna(url) and url_to_filename(url) in audio_files
    ).astype(bool)
    return df[valid_audio]
19
+
20
+ def fix_dance_rating_counts(dance_ratings:pd.Series) -> pd.Series:
21
+ tag_pattern = re.compile("([A-Za-z]+)(\+|-)(\d+)")
22
+ dance_ratings = dance_ratings.apply(lambda v : json.loads(v.replace("'", "\"")))
23
+ def fix_labels(labels:dict) -> dict | float:
24
+ new_labels = {}
25
+ for k, v in labels.items():
26
+ match = tag_pattern.search(k)
27
+ if match is None:
28
+ new_labels[k] = new_labels.get(k, 0) + v
29
+ else:
30
+ k = match[1]
31
+ sign = 1 if match[2] == '+' else -1
32
+ scale = int(match[3])
33
+ new_labels[k] = new_labels.get(k, 0) + v * scale * sign
34
+ valid = any(v > 0 for v in new_labels.values())
35
+ return new_labels if valid else np.nan
36
+ return dance_ratings.apply(fix_labels)
37
+
38
+
39
def get_unique_labels(dance_labels:pd.Series) -> list:
    """Return the sorted list of distinct keys found across all label dicts."""
    return sorted(set().union(*dance_labels))
44
+
45
def vectorize_label_probs(labels: dict[str,int], unique_labels:np.ndarray) -> np.ndarray:
    """
    Turns label dict into probability distribution vector based on each label count.

    Negative totals are clipped to zero before normalising. Labels absent
    from ``unique_labels`` are ignored.

    Raises:
        ValueError: If no label has a positive count, so there is nothing to
            normalise.
    """
    label_vec = np.zeros((len(unique_labels),), dtype="float32")
    for k, v in labels.items():
        label_vec += (unique_labels == k) * v
    label_vec[label_vec < 0] = 0
    total = label_vec.sum()
    # Validate before dividing: an explicit raise still runs under ``-O`` and
    # avoids a divide-by-zero producing NaNs.
    if not total > 0:
        raise ValueError(f"Provided labels are invalid: {labels}")
    label_vec /= total
    return label_vec
58
+
59
def vectorize_multi_label(labels: dict[str,int], unique_labels:np.ndarray) -> np.ndarray:
    """
    Builds the binary multi-label target: 1.0 wherever the label distribution
    is positive, everything else left as returned by ``vectorize_label_probs``.
    """
    target = vectorize_label_probs(labels, unique_labels)
    positive = target > 0.0
    target[positive] = 1.0
    return target
66
+
67
def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None) -> tuple[list[str], list[np.ndarray]]:
    """Build parallel lists of audio paths and multi-label target vectors.

    Args:
        df: Song table with ``"Sample"`` and ``"DanceRating"`` columns.
        audio_dir: Directory holding the downloaded audio files.
        class_list: Optional iterable of dance ids to keep; songs without a
            positive rating for any of them are dropped.

    Returns:
        ``(audio_paths, labels)`` with one binary label vector per path,
        ordered by the sorted set of labels present in the kept songs.
    """
    # Work on a copy so assigning the cleaned column never writes through a
    # filtered view of the caller's frame.
    sampled_songs = get_songs_with_audio(df, audio_dir).copy()
    sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])

    if class_list is not None:
        class_list = set(class_list)

        def keep_target_classes(labels):
            # Rows already marked invalid hold NaN rather than a dict.
            if not isinstance(labels, dict):
                return np.nan
            if not any(label in class_list and amt > 0 for label, amt in labels.items()):
                return np.nan
            return {k: v for k, v in labels.items() if k in class_list}

        sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(keep_target_classes)

    sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
    labels = sampled_songs["DanceRating"]
    unique_labels = np.array(get_unique_labels(labels))
    labels = labels.apply(lambda i : vectorize_multi_label(i, unique_labels))

    audio_paths = [os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]]

    return audio_paths, list(labels)
84
+
85
class AudioPipeline(torch.nn.Module):
    """Waveform -> mono -> resample -> mel spectrogram -> decibel scale.

    Args:
        input_freq: Sample rate of the incoming waveform.
        resample_freq: Sample rate the waveform is converted to before the
            spectrogram is computed.
    """

    def __init__(self, input_freq=16000, resample_freq=16000):
        super().__init__()
        self.resample = taT.Resample(orig_freq=input_freq, new_freq=resample_freq)
        self.spec = taT.MelSpectrogram(sample_rate=resample_freq, n_mels=64, n_fft=1024)
        self.to_db = taT.AmplitudeToDB()

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the dB-scaled mel spectrogram of ``waveform``.

        When the first axis has more than one entry it is averaged down to
        one (keeping the axis) before any further processing.
        """
        has_several_channels = waveform.shape[0] > 1
        mono = waveform.mean(0, keepdim=True) if has_several_channels else waveform
        return self.to_db(self.spec(self.resample(mono)))
104
+
scrapers/music4dance.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup as bs
3
+ import json
4
+ import argparse
5
+ from pathlib import Path
6
+ import os
7
+ import pandas as pd
8
+ import re
9
+ from tqdm import tqdm
10
+
11
+
12
+
13
+
14
def scrape_song_library(page_count=2054) -> pd.DataFrame:
    """Scrape the paginated song index and return all songs as one table.

    Args:
        page_count: Number of index pages to fetch, starting at page 1.

    Returns:
        DataFrame with one row per song.

    Raises:
        requests.HTTPError: If a page request returns an error status.
    """
    columns = [
        "Title",
        "Artist",
        "Length",
        "Tempo",
        "Beat",
        "Energy",
        "Danceability",
        "Valence",
        "Sample",
        "Tags",
        "DanceRating",
    ]
    # Start with an empty frame so the column order is fixed, then collect
    # the per-page frames and concatenate once: concatenating inside the loop
    # copies the whole table on every page.
    pages = [pd.DataFrame(columns=columns)]
    for i in tqdm(range(1, page_count + 1), desc="Pages processed"):
        link = "https://www.music4dance.net/song/Index?filter=v2-Index&page=" + str(i)
        page = requests.get(link, timeout=30)
        page.raise_for_status()
        soup = bs(page.content, "html.parser")
        pages.append(pd.DataFrame(get_songs(soup)))
    return pd.concat(pages, axis=0, ignore_index=True)
36
+
37
+
38
def get_songs(soup: "bs") -> list[dict]:
    """Extract one record per song from a parsed song-index page.

    The page embeds its data as a JSON object inside a ``<script>`` tag whose
    text contains ``"histories"``. Each history is a list of name/value
    properties that is replayed in order to build the song record.

    Args:
        soup: Parsed page exposing ``find_all("script")``.

    Returns:
        List of song dicts. Each has a ``"Tags"`` set, a ``"DanceRating"``
        dict of dance -> vote count, and whichever scalar fields were present.

    Raises:
        IndexError: If no script tag containing ``"histories"`` is found.
    """
    # Greedy match from the first "{" to the last "}" across newlines.
    js_obj = re.compile(r"\{.*\}", re.DOTALL)
    # Properties whose latest value simply overwrites the previous one.
    reset_keys = (
        "Title",
        "Artist",
        "Length",
        "Tempo",
        "Beat",
        "Energy",
        "Danceability",
        "Valence",
        "Sample",
    )
    song_text = [str(v) for v in soup.find_all("script") if "histories" in str(v)][0]
    songs_data = json.loads(js_obj.search(song_text).group(0))
    songs = []
    for song_data in songs_data["histories"]:
        song = {"Tags": set(), "DanceRating": {}}
        for feature in song_data["properties"]:
            if "name" not in feature or "value" not in feature:
                continue
            key = feature["name"]
            value = feature["value"]
            if key in reset_keys:
                song[key] = value
            elif key == "Tag+":
                song["Tags"].add(value)
            elif key == "DeleteTag":
                # discard ignores tags that were never added, without a bare
                # except that would also hide unrelated errors.
                song["Tags"].discard(value)
            elif key == "DanceRating":
                dance = value.replace("+1", "")
                prev = song["DanceRating"].get(dance, 0)
                song["DanceRating"][dance] = prev + 1
        songs.append(song)
    return songs
76
+
77
+
78
def download_song(url: str, out_dir: str):
    """Download ``url`` into ``out_dir`` as ``<last url segment>.mp3``.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never saved as if it were audio.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    filename = url.split("/")[-1]
    out_file = Path(out_dir, f"{filename}.mp3")
    with open(out_file, "wb") as f:
        f.write(response.content)
84
+
85
def scrape_dance_info() -> pd.DataFrame:
    """Fetch the dance catalogue embedded in the song-index page.

    Returns:
        DataFrame with one row per dance, restricted to the ``name``, ``id``,
        ``synonyms``, ``tempoRange`` and ``songCount`` fields that are present.

    Raises:
        requests.HTTPError: If the page request returns an error status.
        IndexError: If no script tag containing ``"environment"`` is found.
    """
    # Greedy match from the first "{" to the last "}" across newlines.
    js_obj = re.compile(r"\{.*\}", re.DOTALL)
    link = "https://www.music4dance.net/song/Index?filter=v2-Index"
    page = requests.get(link, timeout=30)
    page.raise_for_status()
    soup = bs(page.content, "html.parser")

    dance_info_text = [str(v) for v in soup.find_all("script") if "environment" in str(v)][0]
    dance_info = json.loads(js_obj.search(dance_info_text).group(0))
    dance_info = dance_info["dances"]
    wanted_keys = ("name", "id", "synonyms", "tempoRange", "songCount")
    records = [
        {k: v for k, v in dance.items() if k in wanted_keys}
        for dance in dance_info
    ]
    return pd.DataFrame(records)
99
+
100
+
101
+
102
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--page-count", default=2, type=int)
    # NOTE(review): train.py reads "data/songs.csv" while this default writes
    # "data/song.csv" — confirm which name is intended.
    parser.add_argument("--out", default="data/song.csv")

    args = parser.parse_args()
    out_path = Path(args.out)
    # Create the destination before scraping: discovering a missing directory
    # only when the CSV is written would throw away the whole scrape.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    df = scrape_song_library(args.page_count)
    df.to_csv(out_path)
train.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ import torch.nn as nn
6
+ from tqdm import tqdm
7
+ import pandas as pd
8
+ import numpy as np
9
+ from torch.utils.data import random_split, SubsetRandomSampler
10
+ import json
11
+ from sklearn.model_selection import KFold
12
+
13
+ from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
14
+ from preprocessing.dataset import SongDataset
15
+ from preprocessing.preprocess import get_examples
16
+ from dancer_net.dancer_net import ShortChunkCNN
17
+
18
+ DEVICE = "mps"
19
+ SEED = 42
20
+
21
def get_timestamp() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD_HH:MM:SS``."""
    now = datetime.datetime.now()
    return f"{now:%Y-%m-%d_%H:%M:%S}"
23
+
24
class EarlyStopping:
    """Signal a stop once a measure has failed to decrease too many times in a row.

    Each value is compared with the one from the previous call (not with the
    best value seen so far).
    """

    def __init__(self, patience=0):
        self.patience = patience
        self.last_measure = np.inf
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        """Record ``val`` and return True when the run of non-decreasing values exceeds ``patience``."""
        did_not_decrease = val >= self.last_measure
        self.consecutive_increase = self.consecutive_increase + 1 if did_not_decrease else 0
        self.last_measure = val
        return self.consecutive_increase > self.patience
38
+
39
+
40
+
41
def calculate_metrics(pred, target, threshold=0.5, prefix=""):
    """Compute macro precision/recall/F1 and accuracy for thresholded predictions.

    Args:
        pred: Tensor of scores; values above ``threshold`` count as positive.
        target: Tensor of ground-truth labels.
        threshold: Decision threshold applied to ``pred``.
        prefix: Optional string prepended to every metric name.

    Returns:
        Dict mapping metric name to value.
    """
    y_true = target.detach().cpu().numpy()
    scores = pred.detach().cpu().numpy()
    y_pred = (scores > threshold).astype(float)

    macro = dict(average='macro', zero_division=0)
    metrics = {
        'precision': precision_score(y_true=y_true, y_pred=y_pred, **macro),
        'recall': recall_score(y_true=y_true, y_pred=y_pred, **macro),
        'f1': f1_score(y_true=y_true, y_pred=y_pred, **macro),
        'accuracy': accuracy_score(y_true=y_true, y_pred=y_pred),
    }
    if prefix != "":
        return {prefix + name: value for name, value in metrics.items()}
    return metrics
55
+
56
+
57
def evaluate(model:nn.Module, data_loader:DataLoader, criterion, device="mps") -> pd.Series:
    """Run the model over ``data_loader`` and average the per-batch metrics.

    The model is switched to eval mode for the duration of the call so that
    dropout is disabled and batch-norm uses its running statistics; its
    previous train/eval mode is restored afterwards.

    Args:
        model: Network to evaluate.
        data_loader: Yields ``(features, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Series of metric means, each named with a ``val_`` prefix.
    """
    was_training = model.training
    model.eval()
    val_metrics = []
    try:
        with torch.no_grad():
            for features, labels in (prog_bar := tqdm(data_loader)):
                features = features.to(device)
                labels = labels.to(device)
                outputs = model(features)
                loss = criterion(outputs, labels)
                batch_metrics = calculate_metrics(outputs, labels, prefix="val_")
                batch_metrics["val_loss"] = loss.item()
                prog_bar.set_description(f'Validation - Loss: {batch_metrics["val_loss"]:.2f}, Accuracy: {batch_metrics["val_accuracy"]:.2f}')
                val_metrics.append(batch_metrics)
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)
    return pd.DataFrame(val_metrics).mean()
70
+
71
+
72
+
73
def train(
        model: nn.Module,
        data_loader: DataLoader,
        val_loader=None,
        epochs=3,
        lr=1e-3,
        device="mps"):
    """Train ``model`` with BCE loss, Adam and a one-cycle learning-rate schedule.

    Args:
        model: Network whose outputs are probabilities (BCELoss is applied directly).
        data_loader: Yields ``(features, labels)`` training batches.
        val_loader: Optional validation loader; enables per-epoch validation
            and early stopping on validation F1.
        epochs: Maximum number of epochs.
        lr: Peak learning rate of the one-cycle schedule.
        device: Device the batches are moved to.

    Returns:
        ``(model, metrics)`` where ``metrics`` holds one dict of averaged
        metrics per completed epoch.
    """
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    early_stop = EarlyStopping(1)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr,
                                                    steps_per_epoch=int(len(data_loader)),
                                                    epochs=epochs,
                                                    anneal_strategy='linear')
    metrics = []
    for epoch in range(1, epochs + 1):
        # Validation may leave the model in eval mode; re-enable training
        # behaviour at the start of every epoch.
        model.train()
        train_metrics = []
        prog_bar = tqdm(data_loader)
        for features, labels in prog_bar:
            features = features.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            batch_metrics = calculate_metrics(outputs, labels)
            batch_metrics["loss"] = loss.item()
            train_metrics.append(batch_metrics)
            prog_bar.set_description(f'Training - Epoch: {epoch}/{epochs}, Loss: {batch_metrics["loss"]:.2f}, Accuracy: {batch_metrics["accuracy"]:.2f}')
        train_metrics = pd.DataFrame(train_metrics).mean()

        if val_loader is None:
            metrics.append(dict(train_metrics))
            continue

        val_metrics = evaluate(model, val_loader, criterion, device=device)
        # Record the epoch before deciding to stop so the final epoch's
        # metrics are not lost.
        metrics.append(dict(pd.concat([train_metrics, val_metrics], axis=0)))
        # EarlyStopping counts consecutive non-decreasing values, so F1
        # (higher is better) is negated before being handed over.
        if early_stop.step(-val_metrics["val_f1"]):
            break

    return model, metrics
115
+
116
+
117
def cross_validation(seed=42, batch_size=64, k=5, device="mps"):
    """Run k-fold cross-validation and write the per-fold metrics to ``logs/<timestamp>/cross_val.csv``.

    Folds are drawn over songs and then expanded to the dataset indices of
    every window of those songs, so all windows of a song land in the same
    fold.

    Args:
        seed: Random state for the fold shuffling.
        batch_size: Batch size of the train and validation loaders.
        k: Number of folds.
        device: Device used for training and evaluation.
    """
    target_classes = ['ATN',
                      'BBA',
                      'BCH',
                      'BLU',
                      'CHA',
                      'CMB',
                      'CSG',
                      'ECS',
                      'HST',
                      'JIV',
                      'LHP',
                      'QST',
                      'RMB',
                      'SFT',
                      'SLS',
                      'SMB',
                      'SWZ',
                      'TGO',
                      'VWZ',
                      'WCS']
    df = pd.read_csv("data/songs.csv")
    x, y = get_examples(df, "data/samples", class_list=target_classes)

    dataset = SongDataset(x, y)
    # The dataset exposes several consecutive indices per song (one per
    # window), so song-level fold indices must be expanded before sampling.
    windows_per_song = len(dataset) // len(x)
    window_offsets = np.arange(windows_per_song)

    def to_window_indices(song_idx):
        return (np.asarray(song_idx)[:, None] * windows_per_song + window_offsets).ravel()

    splits = KFold(n_splits=k, shuffle=True, random_state=seed)
    metrics = []
    for fold, (train_songs, val_songs) in enumerate(splits.split(x)):
        print(f"Fold {fold+1}")

        train_sampler = SubsetRandomSampler(to_window_indices(train_songs))
        test_sampler = SubsetRandomSampler(to_window_indices(val_songs))
        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
        test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
        n_classes = len(y[0])
        model = ShortChunkCNN(n_class=n_classes).to(device)
        model, _ = train(model, train_loader, epochs=2, device=device)
        val_metrics = evaluate(model, test_loader, nn.BCELoss(), device=device)
        metrics.append(val_metrics)
    metrics = pd.DataFrame(metrics)
    log_dir = os.path.join(
        "logs", get_timestamp()
    )
    os.makedirs(log_dir, exist_ok=True)

    # to_csv takes the destination path as its first argument.
    metrics.to_csv(os.path.join(log_dir, "cross_val.csv"))
163
+
164
+
165
+
166
def train_model():
    """Train on a 90/10 split and save weights, metrics and config under ``logs/<timestamp>/``.

    Writes ``dancer_net.pt`` (state dict), ``metrics.csv`` (per-epoch metrics)
    and ``config.json`` (the list of target classes).
    """
    target_classes = ['ATN',
                      'BBA',
                      'BCH',
                      'BLU',
                      'CHA',
                      'CMB',
                      'CSG',
                      'ECS',
                      'HST',
                      'JIV',
                      'LHP',
                      'QST',
                      'RMB',
                      'SFT',
                      'SLS',
                      'SMB',
                      'SWZ',
                      'TGO',
                      'VWZ',
                      'WCS']
    df = pd.read_csv("data/songs.csv")
    x, y = get_examples(df, "data/samples", class_list=target_classes)
    dataset = SongDataset(x, y)
    train_count = int(len(dataset) * 0.9)
    # NOTE(review): the split is drawn over dataset indices (windows), so
    # windows of one song can fall on both sides — confirm this is intended.
    datasets = random_split(dataset, [train_count, len(dataset) - train_count], torch.Generator().manual_seed(SEED))
    data_loaders = [DataLoader(data, batch_size=64, shuffle=True) for data in datasets]
    train_data, val_data = data_loaders
    example_spec, example_label = dataset[0]
    n_classes = len(example_label)
    model = ShortChunkCNN(n_class=n_classes).to(DEVICE)
    model, metrics = train(model, train_data, val_data, epochs=3, device=DEVICE)

    log_dir = os.path.join(
        "logs", get_timestamp()
    )
    os.makedirs(log_dir, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(log_dir, "dancer_net.pt"))
    metrics = pd.DataFrame(metrics)
    metrics.to_csv(os.path.join(log_dir, "metrics.csv"))
    config = {
        "classes": target_classes
    }
    # Open for writing: the default read mode fails because the file does not
    # exist yet.
    with open(os.path.join(log_dir, "config.json"), "w") as f:
        json.dump(config, f)
    print("Training information saved!")


if __name__ == "__main__":
    cross_validation()