waidhoferj committed · Commit 5da8010 · Parent: 4078103
added hubert and updated preprocessing for wav2vec2
- models/hubert.py +86 -0
- models/wav2vec2.py +9 -5
models/hubert.py
ADDED
@@ -0,0 +1,86 @@
import os
from typing import Any

import pytorch_lightning as pl
from torch.utils.data import random_split
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    TrainingArguments,
    Trainer,
)

from preprocessing.dataset import (
    HuggingFaceDatasetWrapper,
    get_datasets,
)
from preprocessing.pipelines import WaveformTrainingPipeline

from .utils import get_id_label_mapping, compute_hf_metrics

MODEL_CHECKPOINT = "ntu-spml/distilhubert"


class HubertFeatureExtractor:
    """Runs the repo's waveform pipeline, then the DistilHuBERT feature extractor."""

    def __init__(self) -> None:
        self.waveform_pipeline = WaveformTrainingPipeline()
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)

    def __call__(self, waveform) -> Any:
        waveform = self.waveform_pipeline(waveform)
        return self.feature_extractor(waveform.squeeze(0), sampling_rate=16000)

    def __getattr__(self, attr):
        # Delegate any other attribute access to the underlying HF extractor.
        return getattr(self.feature_extractor, attr)


def train_huggingface(config: dict):
    TARGET_CLASSES = config["dance_ids"]
    DEVICE = config["device"]
    SEED = config["seed"]
    OUTPUT_DIR = "models/weights/hubert"
    batch_size = config["data_module"]["batch_size"]
    epochs = config["trainer"]["min_epochs"]
    test_proportion = config["data_module"].get("test_proportion", 0.2)
    pl.seed_everything(SEED, workers=True)
    feature_extractor = HubertFeatureExtractor()
    dataset = get_datasets(config["datasets"], feature_extractor)
    dataset = HuggingFaceDatasetWrapper(dataset)
    id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
    train_proportion = 1 - test_proportion
    train_ds, test_ds = random_split(dataset, [train_proportion, test_proportion])

    model = AutoModelForAudioClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(TARGET_CLASSES),
        label2id=label2id,
        id2label=id2label,
        # ignore_mismatched_sizes=True,
    ).to(DEVICE)
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=DEVICE == "mps",
        fp16=True,  # mixed precision; assumes a CUDA device
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        compute_metrics=compute_hf_metrics,
    )
    trainer.train()
    return model
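
For context, train_huggingface pulls everything it reads from one nested config dict. Below is a minimal sketch of such a config; the key names mirror the lookups in the function above, but the label values and the shape of config["datasets"] are hypothetical placeholders that depend on what this repo's get_datasets expects.

    from models.hubert import train_huggingface

    # Hypothetical config: key names match what train_huggingface reads;
    # every value here is a placeholder.
    config = {
        "dance_ids": ["waltz", "tango", "foxtrot"],  # placeholder class labels
        "device": "cuda",
        "seed": 42,
        "datasets": {},  # structure depends on preprocessing.dataset.get_datasets
        "data_module": {"batch_size": 8, "test_proportion": 0.2},
        "trainer": {"min_epochs": 10},
    }

    model = train_huggingface(config)
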
models/wav2vec2.py
CHANGED
@@ -3,7 +3,12 @@ from typing import Any
 import pytorch_lightning as pl
 from torch.utils.data import random_split
 from transformers import AutoFeatureExtractor
-from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
+from transformers import (
+    AutoModelForAudioClassification,
+    TrainingArguments,
+    Trainer,
+    AutoProcessor,
+)
 
 from preprocessing.dataset import (
     HuggingFaceDatasetWrapper,
@@ -14,18 +19,17 @@ from preprocessing.pipelines import WaveformTrainingPipeline
 from .utils import get_id_label_mapping, compute_hf_metrics
 
 MODEL_CHECKPOINT = "yuval6967/wav2vec2-base-finetuned-gtzan"
+PROCESSOR_CHECKPOINT = "facebook/wav2vec2-base"
 
 
 class Wav2VecFeatureExtractor:
     def __init__(self) -> None:
         self.waveform_pipeline = WaveformTrainingPipeline()
-        self.feature_extractor = AutoProcessor.from_pretrained(MODEL_CHECKPOINT)
+        self.feature_extractor = AutoProcessor.from_pretrained(PROCESSOR_CHECKPOINT)
 
     def __call__(self, waveform) -> Any:
         waveform = self.waveform_pipeline(waveform)
-        return self.feature_extractor(
-            waveform.squeeze(0), sampling_rate=16000
-        )
+        return self.feature_extractor(waveform.squeeze(0), sampling_rate=16000)
 
     def __getattr__(self, attr):
         return getattr(self.feature_extractor, attr)
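
As a quick sanity check of the updated extractor, a minimal sketch (assuming the repo's preprocessing package is importable and that WaveformTrainingPipeline emits 16 kHz audio, which the sampling_rate=16000 argument implies):

    import torch

    from models.wav2vec2 import Wav2VecFeatureExtractor

    extractor = Wav2VecFeatureExtractor()

    # Fake one-second mono clip; the (1, samples) shape matches the
    # squeeze(0) in __call__. Real inputs come from the dataset pipeline.
    waveform = torch.randn(1, 16000)
    features = extractor(waveform)
    print(features["input_values"][0].shape)  # processor output fed to the model
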