bpietrzak committed on
Commit
91a9e54
·
1 Parent(s): 710c57b

Training fix

Browse files
Files changed (8) hide show
  1. .gitignore +1 -159
  2. dl/make_dataset.py +0 -42
  3. dl/push_model.py +0 -36
  4. dl/testing.ipynb +0 -394
  5. dl/train.py +0 -113
  6. main.py +0 -29
  7. requirements.txt +9 -7
  8. train.py +134 -0
.gitignore CHANGED
@@ -1,160 +1,2 @@
1
- # Byte-compiled / optimized / DLL files
2
- __pycache__/
3
- *.py[cod]
4
- *$py.class
5
-
6
- # C extensions
7
- *.so
8
-
9
- # Distribution / packaging
10
- .Python
11
- build/
12
- develop-eggs/
13
- dist/
14
- downloads/
15
- eggs/
16
- .eggs/
17
- lib/
18
- lib64/
19
- parts/
20
- sdist/
21
- var/
22
- wheels/
23
- share/python-wheels/
24
- *.egg-info/
25
- .installed.cfg
26
- *.egg
27
- MANIFEST
28
-
29
- # PyInstaller
30
- # Usually these files are written by a python script from a template
31
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
- *.manifest
33
- *.spec
34
-
35
- # Installer logs
36
- pip-log.txt
37
- pip-delete-this-directory.txt
38
-
39
- # Unit test / coverage reports
40
- htmlcov/
41
- .tox/
42
- .nox/
43
- .coverage
44
- .coverage.*
45
- .cache
46
- nosetests.xml
47
- coverage.xml
48
- *.cover
49
- *.py,cover
50
- .hypothesis/
51
- .pytest_cache/
52
- cover/
53
-
54
- # Translations
55
- *.mo
56
- *.pot
57
-
58
- # Django stuff:
59
- *.log
60
- local_settings.py
61
- db.sqlite3
62
- db.sqlite3-journal
63
-
64
- # Flask stuff:
65
- instance/
66
- .webassets-cache
67
-
68
- # Scrapy stuff:
69
- .scrapy
70
-
71
- # Sphinx documentation
72
- docs/_build/
73
-
74
- # PyBuilder
75
- .pybuilder/
76
- target/
77
-
78
- # Jupyter Notebook
79
- .ipynb_checkpoints
80
-
81
- # IPython
82
- profile_default/
83
- ipython_config.py
84
-
85
- # pyenv
86
- # For a library or package, you might want to ignore these files since the code is
87
- # intended to run in multiple environments; otherwise, check them in:
88
- # .python-version
89
-
90
- # pipenv
91
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
- # install all needed dependencies.
95
- #Pipfile.lock
96
-
97
- # poetry
98
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
- # This is especially recommended for binary packages to ensure reproducibility, and is more
100
- # commonly ignored for libraries.
101
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
- #poetry.lock
103
-
104
- # pdm
105
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
- #pdm.lock
107
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
- # in version control.
109
- # https://pdm.fming.dev/#use-with-ide
110
- .pdm.toml
111
-
112
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
- __pypackages__/
114
-
115
- # Celery stuff
116
- celerybeat-schedule
117
- celerybeat.pid
118
-
119
- # SageMath parsed files
120
- *.sage.py
121
-
122
- # Environments
123
- .env
124
- .venv
125
- env/
126
  venv/
127
- ENV/
128
- env.bak/
129
- venv.bak/
130
-
131
- # Spyder project settings
132
- .spyderproject
133
- .spyproject
134
-
135
- # Rope project settings
136
- .ropeproject
137
-
138
- # mkdocs documentation
139
- /site
140
-
141
- # mypy
142
- .mypy_cache/
143
- .dmypy.json
144
- dmypy.json
145
-
146
- # Pyre type checker
147
- .pyre/
148
-
149
- # pytype static type analyzer
150
- .pytype/
151
-
152
- # Cython debug symbols
153
- cython_debug/
154
-
155
- # PyCharm
156
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
- # and can be added to the global gitignore or merged into this file. For a more nuclear
159
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
- #.idea/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  venv/
2
+ Data/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dl/make_dataset.py DELETED
@@ -1,42 +0,0 @@
1
- import os
2
- import json
3
- import argparse
4
- import librosa
5
- import pandas as pd
6
-
7
- def parse_args():
8
- parser = argparse.ArgumentParser()
9
- parser.add_argument("--dir", type=str, help="Directory containing OGG audio files.")
10
- parser.add_argument("--file", type=str, help="JSON file mapping filenames to classes.")
11
- parser.add_argument('-o', '--output', type=str, default="output_dataset.csv", help="Output CSV file.")
12
- return vars(parser.parse_args())
13
-
14
- def load_audio_files(audio_dir, file_class_mapping):
15
- data = []
16
- for filename, class_label in file_class_mapping.items():
17
- file_path = os.path.join(audio_dir, filename)
18
- if os.path.exists(file_path):
19
- audio, sr = librosa.load(file_path, sr=None)
20
- data.append({
21
- 'filename': filename,
22
- 'audio': audio,
23
- 'sampling_rate': sr,
24
- 'label': class_label
25
- })
26
- return data
27
-
28
- def main(args):
29
- audio_dir = args['dir']
30
- json_file = args['file']
31
-
32
- with open(json_file, 'r') as f:
33
- file_class_mapping = json.load(f)
34
-
35
- dataset = load_audio_files(audio_dir, file_class_mapping)
36
-
37
- df = pd.DataFrame(dataset)
38
-
39
- df.to_csv(args['output'], index=False)
40
-
41
- if __name__ == "__main__":
42
- main(parse_args())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dl/push_model.py DELETED
@@ -1,36 +0,0 @@
1
- import argparse
2
- from transformers import AutoModel, AutoTokenizer
3
- from huggingface_hub import HfApi, HfFolder
4
-
5
- def parse_args():
6
- parser = argparse.ArgumentParser()
7
- parser.add_argument("--username", type=str, required=True, help="Nazwa użytkownika Hugging Face.")
8
- parser.add_argument("--model_dir", type=str, required=True, help="Ścieżka do zapisanego modelu.")
9
- parser.add_argument("--repo_name", type=str, required=True, help="Nazwa repozytorium HuggingFace Hub.")
10
- parser.add_argument("--private", type=bool, default=False, help="Flaga określająca, czy repozytorium powinno być prywatne.")
11
- return parser.parse_args()
12
-
13
- def main():
14
- args = parse_args()
15
- token = HfFolder.get_token()
16
- if token is None:
17
- raise ValueError("Token uwierzytelniający nie został znaleziony. Zaloguj się za pomocą CLI Hugging Face.")
18
-
19
- model = AutoModel.from_pretrained(args.model_dir)
20
- tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
21
-
22
- repo_url = HfApi().create_repo(
23
- token=token,
24
- name=args.repo_name,
25
- organization=args.username,
26
- private=args.private,
27
- exist_ok=True
28
- )
29
-
30
- model.push_to_hub(args.repo_name, use_auth_token=token)
31
- tokenizer.push_to_hub(args.repo_name, use_auth_token=token)
32
-
33
- print(f"Model i tokajzer zostały wysłane do {repo_url}")
34
-
35
- if __name__ == "__main__":
36
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dl/testing.ipynb DELETED
@@ -1,394 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "from datasets import load_dataset, Audio\n",
10
- "from transformers import AutoFeatureExtractor\n",
11
- "import numpy as np"
12
- ]
13
- },
14
- {
15
- "cell_type": "code",
16
- "execution_count": 2,
17
- "metadata": {},
18
- "outputs": [
19
- {
20
- "name": "stderr",
21
- "output_type": "stream",
22
- "text": [
23
- "/home/potato/.virtualenvs/studia/lib/python3.10/site-packages/datasets/load.py:1486: FutureWarning: The repository for marsyas/gtzan contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/marsyas/gtzan\n",
24
- "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
25
- "Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
26
- " warnings.warn(\n"
27
- ]
28
- }
29
- ],
30
- "source": [
31
- "data = load_dataset(\"marsyas/gtzan\", \"all\")\n",
32
- "data = data['train'].train_test_split(seed=42, shuffle=True, test_size=.1)"
33
- ]
34
- },
35
- {
36
- "cell_type": "code",
37
- "execution_count": 3,
38
- "metadata": {},
39
- "outputs": [],
40
- "source": [
41
- "map_class = data['train'].features['genre'].int2str"
42
- ]
43
- },
44
- {
45
- "cell_type": "markdown",
46
- "metadata": {},
47
- "source": [
48
- "Models to train:\n",
49
- "\n",
50
- "- ntu-spml/distilhubert\n",
51
- "- dima806/music_genres_classification"
52
- ]
53
- },
54
- {
55
- "cell_type": "code",
56
- "execution_count": 4,
57
- "metadata": {},
58
- "outputs": [],
59
- "source": [
60
- "distilhubert = AutoFeatureExtractor.from_pretrained(\n",
61
- " 'ntu-spml/distilhubert', do_normalize=True, return_attention_mask=True\n",
62
- ")\n",
63
- "# music_genres_classification = AutoFeatureExtractor.from_pretrained(\n",
64
- "# 'dima806/music_genres_classification', do_normalize=True, return_attention_mask=True\n",
65
- "# )\n",
66
- "\n",
67
- "# models = {'distilhubert': distilhubert,\n",
68
- "# 'music_genres_classification': music_genres_classification}\n",
69
- "\n",
70
- "# def get_sampling_rate(model):\n",
71
- "# return model.sampling_rate\n",
72
- "\n",
73
- "# if np.all([ get_sampling_rate(model) == 16000 for model in models.values()]):\n",
74
- "# sampling_rate = 16000\n",
75
- "# else:\n",
76
- "# raise ValueError('You need to setup different values than 16000 for a sampling rate')\n",
77
- "\n",
78
- "data = data.cast_column(\"audio\", Audio(sampling_rate=16000))"
79
- ]
80
- },
81
- {
82
- "cell_type": "code",
83
- "execution_count": 5,
84
- "metadata": {},
85
- "outputs": [],
86
- "source": [
87
- "class Preprocess:\n",
88
- " def __init__(self, model):\n",
89
- " self.model = model\n",
90
- " \n",
91
- " def __call__(self, examples):\n",
92
- " audio_arrays = [x[\"array\"] for x in examples[\"audio\"]]\n",
93
- " inputs = self.model(\n",
94
- " audio_arrays,\n",
95
- " sampling_rate=self.model.sampling_rate,\n",
96
- " max_length=int(self.model.sampling_rate * 30.0),\n",
97
- " truncation=True,\n",
98
- " return_attention_mask=True)\n",
99
- " return inputs"
100
- ]
101
- },
102
- {
103
- "cell_type": "code",
104
- "execution_count": 6,
105
- "metadata": {},
106
- "outputs": [],
107
- "source": [
108
- "distilhubert_preprocess = Preprocess(distilhubert)\n",
109
- "# music_genres_classification_preprocess = Preprocess(music_genres_classification)"
110
- ]
111
- },
112
- {
113
- "cell_type": "code",
114
- "execution_count": 7,
115
- "metadata": {},
116
- "outputs": [],
117
- "source": [
118
- "def process_data(preprocess):\n",
119
- " data_preprocessed = data.map(\n",
120
- " preprocess,\n",
121
- " remove_columns=[\"audio\", \"file\"],\n",
122
- " batched=True,\n",
123
- " batch_size=100,\n",
124
- " num_proc=1)\n",
125
- " return data_preprocessed\n",
126
- "\n",
127
- "distilhubert_data = process_data(distilhubert_preprocess)\n",
128
- "# music_genres_classification_data = process_data(music_genres_classification_preprocess)"
129
- ]
130
- },
131
- {
132
- "cell_type": "code",
133
- "execution_count": 8,
134
- "metadata": {},
135
- "outputs": [],
136
- "source": [
137
- "distilhubert_data = distilhubert_data.rename_column(\"genre\", \"label\")\n",
138
- "# music_genres_classification_data = music_genres_classification_data.rename_column(\"genre\", \"label\")"
139
- ]
140
- },
141
- {
142
- "cell_type": "code",
143
- "execution_count": 9,
144
- "metadata": {},
145
- "outputs": [],
146
- "source": [
147
- "id2label = {\n",
148
- " str(i): map_class(i)\n",
149
- " for i in range(len(distilhubert_data[\"train\"].features[\"label\"].names))\n",
150
- "}\n",
151
- "label2id = {v: k for k, v in id2label.items()}"
152
- ]
153
- },
154
- {
155
- "cell_type": "code",
156
- "execution_count": 10,
157
- "metadata": {},
158
- "outputs": [],
159
- "source": [
160
- "from transformers import AutoModelForAudioClassification\n",
161
- "from transformers import TrainingArguments\n",
162
- "import numpy as np\n",
163
- "from transformers import Trainer\n",
164
- "\n",
165
- "\n",
166
- "class Eval:\n",
167
- " def __init__(self, metric) -> None:\n",
168
- " self.metric = metric\n",
169
- "\n",
170
- " def __call__(self, eval_pred):\n",
171
- " predictions = np.argmax(eval_pred.predictions, axis=1)\n",
172
- " return self.metric.compute(predictions=predictions, references=eval_pred.label_ids)\n",
173
- "\n",
174
- "def train(model_name, class_nb, label2id, id2label, batch_size, epochs, eval_metric, data, feature_extractor):\n",
175
- " model = AutoModelForAudioClassification.from_pretrained(\n",
176
- " model_name,\n",
177
- " num_labels=class_nb,\n",
178
- " label2id=label2id,\n",
179
- " id2label=id2label)\n",
180
- "\n",
181
- " training_args = TrainingArguments(\n",
182
- " f\"{model_name.split('/')[-1]}-ft-gtzan-{batch_size}-{epochs}\",\n",
183
- " evaluation_strategy=\"epoch\",\n",
184
- " save_strategy=\"epoch\",\n",
185
- " learning_rate=5e-5,\n",
186
- " per_device_train_batch_size=batch_size,\n",
187
- " gradient_accumulation_steps=2,\n",
188
- " per_device_eval_batch_size=batch_size,\n",
189
- " num_train_epochs=epochs,\n",
190
- " warmup_ratio=0.1,\n",
191
- " logging_steps=5,\n",
192
- " load_best_model_at_end=True,\n",
193
- " metric_for_best_model=\"accuracy\",\n",
194
- " fp16=True,\n",
195
- " push_to_hub=True)\n",
196
- " \n",
197
- " trainer = Trainer(\n",
198
- " model,\n",
199
- " training_args,\n",
200
- " train_dataset=data[\"train\"],\n",
201
- " eval_dataset=data[\"test\"],\n",
202
- " tokenizer=feature_extractor,\n",
203
- " compute_metrics=eval_metric)\n",
204
- "\n",
205
- " trainer.train()"
206
- ]
207
- },
208
- {
209
- "cell_type": "code",
210
- "execution_count": 11,
211
- "metadata": {},
212
- "outputs": [
213
- {
214
- "data": {
215
- "application/vnd.jupyter.widget-view+json": {
216
- "model_id": "0282b77db7d1478f8e96988688e4b049",
217
- "version_major": 2,
218
- "version_minor": 0
219
- },
220
- "text/plain": [
221
- "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
222
- ]
223
- },
224
- "metadata": {},
225
- "output_type": "display_data"
226
- }
227
- ],
228
- "source": [
229
- "from huggingface_hub import notebook_login\n",
230
- "\n",
231
- "notebook_login()"
232
- ]
233
- },
234
- {
235
- "cell_type": "code",
236
- "execution_count": 14,
237
- "metadata": {},
238
- "outputs": [
239
- {
240
- "ename": "NameError",
241
- "evalue": "name 'Eval' is not defined",
242
- "output_type": "error",
243
- "traceback": [
244
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
245
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
246
- "Cell \u001b[0;32mIn[14], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mevaluate\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m acc \u001b[38;5;241m=\u001b[39m \u001b[43mEval\u001b[49m(evaluate\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124maccuracy\u001b[39m\u001b[38;5;124m'\u001b[39m))\n\u001b[1;32m 6\u001b[0m models \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 7\u001b[0m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_name\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mntu-spml/distilhubert\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mclass_nb\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28mlen\u001b[39m(id2label), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlabel2id\u001b[39m\u001b[38;5;124m'\u001b[39m: label2id, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mid2label\u001b[39m\u001b[38;5;124m'\u001b[39m: id2label, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m4\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mepochs\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m8\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124meval_metric\u001b[39m\u001b[38;5;124m'\u001b[39m: acc, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdata\u001b[39m\u001b[38;5;124m'\u001b[39m: distilhubert_data, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfeature_extractor\u001b[39m\u001b[38;5;124m'\u001b[39m: distilhubert},\n\u001b[1;32m 8\u001b[0m \u001b[38;5;66;03m# {'model_name': 'dima806/music_genres_classification', 'class_nb': len(id2label), 'label2id': label2id, 'id2label': id2label, 'batch_size': 25, 'epochs': 8, 'eval_metric': acc, 'data': music_genres_classification_data, 'feature_extractor': music_genres_classification}]\u001b[39;00m\n\u001b[1;32m 9\u001b[0m ]\n",
247
- "\u001b[0;31mNameError\u001b[0m: name 'Eval' is not defined"
248
- ]
249
- }
250
- ],
251
- "source": [
252
- "import evaluate\n",
253
- "\n",
254
- "acc = Eval(evaluate.load('accuracy'))\n",
255
- "\n",
256
- "\n",
257
- "models = [\n",
258
- " {'model_name': 'ntu-spml/distilhubert', 'class_nb': len(id2label), 'label2id': label2id, 'id2label': id2label, 'batch_size': 4, 'epochs': 8, 'eval_metric': acc, 'data': distilhubert_data, 'feature_extractor': distilhubert},\n",
259
- " # {'model_name': 'dima806/music_genres_classification', 'class_nb': len(id2label), 'label2id': label2id, 'id2label': id2label, 'batch_size': 25, 'epochs': 8, 'eval_metric': acc, 'data': music_genres_classification_data, 'feature_extractor': music_genres_classification}]\n",
260
- "]"
261
- ]
262
- },
263
- {
264
- "cell_type": "code",
265
- "execution_count": 13,
266
- "metadata": {},
267
- "outputs": [
268
- {
269
- "name": "stderr",
270
- "output_type": "stream",
271
- "text": [
272
- "Some weights of HubertForSequenceClassification were not initialized from the model checkpoint at ntu-spml/distilhubert and are newly initialized: ['classifier.bias', 'classifier.weight', 'encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'projector.bias', 'projector.weight']\n",
273
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
274
- "/home/potato/.virtualenvs/studia/lib/python3.10/site-packages/transformers/training_args.py:1474: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
275
- " warnings.warn(\n"
276
- ]
277
- },
278
- {
279
- "data": {
280
- "application/vnd.jupyter.widget-view+json": {
281
- "model_id": "abba0ecb6fb242fe8538bcdeec44ef9b",
282
- "version_major": 2,
283
- "version_minor": 0
284
- },
285
- "text/plain": [
286
- " 0%| | 0/896 [00:00<?, ?it/s]"
287
- ]
288
- },
289
- "metadata": {},
290
- "output_type": "display_data"
291
- },
292
- {
293
- "name": "stderr",
294
- "output_type": "stream",
295
- "text": [
296
- "/home/potato/.virtualenvs/studia/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
297
- " return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n"
298
- ]
299
- },
300
- {
301
- "name": "stdout",
302
- "output_type": "stream",
303
- "text": [
304
- "{'loss': 2.2907, 'grad_norm': 1.133973240852356, 'learning_rate': 2.777777777777778e-06, 'epoch': 0.04}\n",
305
- "{'loss': 2.2975, 'grad_norm': 1.526039719581604, 'learning_rate': 5.555555555555556e-06, 'epoch': 0.09}\n",
306
- "{'loss': 2.2871, 'grad_norm': 1.1069303750991821, 'learning_rate': 8.333333333333334e-06, 'epoch': 0.13}\n",
307
- "{'loss': 2.3107, 'grad_norm': 1.4785107374191284, 'learning_rate': 1.1111111111111112e-05, 'epoch': 0.18}\n",
308
- "{'loss': 2.2712, 'grad_norm': 1.5087419748306274, 'learning_rate': 1.388888888888889e-05, 'epoch': 0.22}\n",
309
- "{'loss': 2.3081, 'grad_norm': 1.904876708984375, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.27}\n",
310
- "{'loss': 2.2583, 'grad_norm': 1.2942432165145874, 'learning_rate': 1.9444444444444445e-05, 'epoch': 0.31}\n",
311
- "{'loss': 2.2572, 'grad_norm': 1.8840770721435547, 'learning_rate': 2.2222222222222223e-05, 'epoch': 0.36}\n",
312
- "{'loss': 2.2733, 'grad_norm': 1.25327730178833, 'learning_rate': 2.5e-05, 'epoch': 0.4}\n",
313
- "{'loss': 2.273, 'grad_norm': 1.499450922012329, 'learning_rate': 2.777777777777778e-05, 'epoch': 0.44}\n",
314
- "{'loss': 2.2221, 'grad_norm': 1.6644848585128784, 'learning_rate': 3.055555555555556e-05, 'epoch': 0.49}\n",
315
- "{'loss': 2.2279, 'grad_norm': 1.5860854387283325, 'learning_rate': 3.3333333333333335e-05, 'epoch': 0.53}\n",
316
- "{'loss': 2.2253, 'grad_norm': 1.8796266317367554, 'learning_rate': 3.611111111111111e-05, 'epoch': 0.58}\n",
317
- "{'loss': 2.1556, 'grad_norm': 2.5186994075775146, 'learning_rate': 3.888888888888889e-05, 'epoch': 0.62}\n",
318
- "{'loss': 2.1329, 'grad_norm': 2.4733965396881104, 'learning_rate': 4.166666666666667e-05, 'epoch': 0.67}\n",
319
- "{'loss': 2.1214, 'grad_norm': 1.7492904663085938, 'learning_rate': 4.4444444444444447e-05, 'epoch': 0.71}\n",
320
- "{'loss': 2.0805, 'grad_norm': 3.5089523792266846, 'learning_rate': 4.722222222222222e-05, 'epoch': 0.76}\n",
321
- "{'loss': 1.9769, 'grad_norm': 1.759109377861023, 'learning_rate': 5e-05, 'epoch': 0.8}\n",
322
- "{'loss': 2.0086, 'grad_norm': 3.3685412406921387, 'learning_rate': 4.968982630272953e-05, 'epoch': 0.84}\n",
323
- "{'loss': 1.9664, 'grad_norm': 4.404444694519043, 'learning_rate': 4.937965260545906e-05, 'epoch': 0.89}\n",
324
- "{'loss': 1.9937, 'grad_norm': 9.0780611038208, 'learning_rate': 4.9069478908188585e-05, 'epoch': 0.93}\n",
325
- "{'loss': 1.876, 'grad_norm': 4.4436798095703125, 'learning_rate': 4.8759305210918115e-05, 'epoch': 0.98}\n"
326
- ]
327
- },
328
- {
329
- "data": {
330
- "application/vnd.jupyter.widget-view+json": {
331
- "model_id": "7c5670c3679841ab80f531ecb12310a1",
332
- "version_major": 2,
333
- "version_minor": 0
334
- },
335
- "text/plain": [
336
- " 0%| | 0/25 [00:00<?, ?it/s]"
337
- ]
338
- },
339
- "metadata": {},
340
- "output_type": "display_data"
341
- },
342
- {
343
- "ename": "TypeError",
344
- "evalue": "'Accuracy' object is not callable",
345
- "output_type": "error",
346
- "traceback": [
347
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
348
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
349
- "Cell \u001b[0;32mIn[13], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m m \u001b[38;5;129;01min\u001b[39;00m models:\n\u001b[0;32m----> 2\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mm\u001b[49m\u001b[43m)\u001b[49m\n",
350
- "Cell \u001b[0;32mIn[10], line 42\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model_name, class_nb, label2id, id2label, batch_size, epochs, eval_metric, data, feature_extractor)\u001b[0m\n\u001b[1;32m 18\u001b[0m training_args \u001b[38;5;241m=\u001b[39m TrainingArguments(\n\u001b[1;32m 19\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/\u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m-ft-gtzan-\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbatch_size\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m-\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepochs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 20\u001b[0m evaluation_strategy\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mepoch\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 31\u001b[0m fp16\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 32\u001b[0m push_to_hub\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 34\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer(\n\u001b[1;32m 35\u001b[0m model,\n\u001b[1;32m 36\u001b[0m training_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 39\u001b[0m tokenizer\u001b[38;5;241m=\u001b[39mfeature_extractor,\n\u001b[1;32m 40\u001b[0m compute_metrics\u001b[38;5;241m=\u001b[39meval_metric)\n\u001b[0;32m---> 42\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
351
- "File \u001b[0;32m~/.virtualenvs/studia/lib/python3.10/site-packages/transformers/trainer.py:1876\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1873\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1874\u001b[0m \u001b[38;5;66;03m# Disable progress bars when uploading models during checkpoints to avoid polluting stdout\u001b[39;00m\n\u001b[1;32m 1875\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39mdisable_progress_bars()\n\u001b[0;32m-> 1876\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1877\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1878\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1879\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1880\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1881\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1882\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 1883\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n",
352
- "File \u001b[0;32m~/.virtualenvs/studia/lib/python3.10/site-packages/transformers/trainer.py:2311\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2308\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_training_stop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 2310\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_epoch_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m-> 2311\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maybe_log_save_evaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtr_loss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_norm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2313\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m DebugOption\u001b[38;5;241m.\u001b[39mTPU_METRICS_DEBUG \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdebug:\n\u001b[1;32m 2314\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_torch_xla_available():\n\u001b[1;32m 2315\u001b[0m \u001b[38;5;66;03m# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)\u001b[39;00m\n",
353
- "File \u001b[0;32m~/.virtualenvs/studia/lib/python3.10/site-packages/transformers/trainer.py:2721\u001b[0m, in \u001b[0;36mTrainer._maybe_log_save_evaluate\u001b[0;34m(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2719\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 2720\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_evaluate:\n\u001b[0;32m-> 2721\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2722\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_report_to_hp_search(trial, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step, metrics)\n\u001b[1;32m 2724\u001b[0m \u001b[38;5;66;03m# Run delayed LR scheduler now that metrics are populated\u001b[39;00m\n",
354
- "File \u001b[0;32m~/.virtualenvs/studia/lib/python3.10/site-packages/transformers/trainer.py:3572\u001b[0m, in \u001b[0;36mTrainer.evaluate\u001b[0;34m(self, eval_dataset, ignore_keys, metric_key_prefix)\u001b[0m\n\u001b[1;32m 3569\u001b[0m start_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 3571\u001b[0m eval_loop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprediction_loop \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39muse_legacy_prediction_loop \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mevaluation_loop\n\u001b[0;32m-> 3572\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43meval_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3573\u001b[0m \u001b[43m \u001b[49m\u001b[43meval_dataloader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3574\u001b[0m \u001b[43m \u001b[49m\u001b[43mdescription\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mEvaluation\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3575\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# No point gathering the predictions if there are no metrics, otherwise we defer to\u001b[39;49;00m\n\u001b[1;32m 3576\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# self.args.prediction_loss_only\u001b[39;49;00m\n\u001b[1;32m 3577\u001b[0m \u001b[43m \u001b[49m\u001b[43mprediction_loss_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_metrics\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m 
\u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 3578\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3579\u001b[0m \u001b[43m \u001b[49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3580\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3582\u001b[0m total_batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39meval_batch_size \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mworld_size\n\u001b[1;32m 3583\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmetric_key_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_jit_compilation_time\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m output\u001b[38;5;241m.\u001b[39mmetrics:\n",
355
- "File \u001b[0;32m~/.virtualenvs/studia/lib/python3.10/site-packages/transformers/trainer.py:3854\u001b[0m, in \u001b[0;36mTrainer.evaluation_loop\u001b[0;34m(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)\u001b[0m\n\u001b[1;32m 3850\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_metrics(\n\u001b[1;32m 3851\u001b[0m EvalPrediction(predictions\u001b[38;5;241m=\u001b[39mall_preds, label_ids\u001b[38;5;241m=\u001b[39mall_labels, inputs\u001b[38;5;241m=\u001b[39mall_inputs)\n\u001b[1;32m 3852\u001b[0m )\n\u001b[1;32m 3853\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3854\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_metrics\u001b[49m\u001b[43m(\u001b[49m\u001b[43mEvalPrediction\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredictions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mall_preds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mall_labels\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3855\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m metrics \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3856\u001b[0m metrics \u001b[38;5;241m=\u001b[39m {}\n",
356
- "\u001b[0;31mTypeError\u001b[0m: 'Accuracy' object is not callable"
357
- ]
358
- }
359
- ],
360
- "source": [
361
- "for m in models:\n",
362
- " train(**m)"
363
- ]
364
- },
365
- {
366
- "cell_type": "code",
367
- "execution_count": null,
368
- "metadata": {},
369
- "outputs": [],
370
- "source": []
371
- }
372
- ],
373
- "metadata": {
374
- "kernelspec": {
375
- "display_name": "studia",
376
- "language": "python",
377
- "name": "python3"
378
- },
379
- "language_info": {
380
- "codemirror_mode": {
381
- "name": "ipython",
382
- "version": 3
383
- },
384
- "file_extension": ".py",
385
- "mimetype": "text/x-python",
386
- "name": "python",
387
- "nbconvert_exporter": "python",
388
- "pygments_lexer": "ipython3",
389
- "version": "3.10.12"
390
- }
391
- },
392
- "nbformat": 4,
393
- "nbformat_minor": 2
394
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dl/train.py DELETED
@@ -1,113 +0,0 @@
1
- import argparse
2
- import numpy as np
3
- from datasets import load_dataset, Audio
4
- from transformers import (AutoFeatureExtractor,
5
- AutoModelForAudioClassification, TrainingArguments,
6
- Trainer)
7
- import os
8
- import evaluate
9
- import random
10
-
11
-
12
- accuracy_metric = evaluate.load("accuracy")
13
-
14
def parse_args() -> dict:
    """Parse the command-line options of the audio-classification trainer.

    Returns:
        The parsed options as a plain dict keyed by option name.
    """
    # (flag, type, extra argparse keywords), listed in the order --help shows.
    option_specs = (
        ("--learning_rate", float, {
            "default": 5e-5,
            "help": "Współczynnik uczenia podczas treningu modelu.",
        }),
        ("--train_eval_split", float, {
            "default": 0.9,
            "help": "Stosunek danych trenujących do całego zbioru; reszta to dane walidacyjne.",
        }),
        ("--model_id", str, {
            "required": True,
            "help": "Identyfikator modelu z Hugging Face lub ścieżka do lokalnego modelu.",
        }),
        ("--num_epochs", int, {
            "default": 20,
            "help": "Liczba epok treningowych.",
        }),
        ("--seed", int, {
            "default": 42,
            "help": "Ziarno liczb losowych.",
        }),
        ("--save_dir", str, {
            "default": ".",
            "help": "Ścieżka do katalogu wag tranowanego modelu.",
        }),
        ("--dataset", str, {
            "default": "marsyas/gtzan",
            "help": "Nazwa/lokalizacja zbioru danych.",
        }),
    )

    cli = argparse.ArgumentParser(
        description="Skrypt do trenowania modelu klasyfikacji audio.")
    for flag, value_type, extra in option_specs:
        cli.add_argument(flag, type=value_type, **extra)

    namespace = cli.parse_args()
    return vars(namespace)
31
-
32
-
33
def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level accuracy metric.

    The index of the largest value along axis 1 of ``eval_pred.predictions``
    is taken as the predicted class and compared with ``eval_pred.label_ids``.

    Returns:
        Whatever ``accuracy_metric.compute`` returns for those inputs.
    """
    predicted_ids = np.argmax(eval_pred.predictions, axis=1)
    reference_ids = eval_pred.label_ids
    return accuracy_metric.compute(predictions=predicted_ids,
                                   references=reference_ids)
37
-
38
def main(args: dict) -> None:
    """Fine-tune an audio-classification model with the Hugging Face Trainer.

    Loads the dataset named by ``args["dataset"]``, splits its ``train``
    split into train/test parts, runs the model's feature extractor over the
    audio, and trains with per-epoch evaluation and checkpointing.

    Args:
        args: Option dict as produced by ``parse_args``. Reads the keys
            ``seed``, ``dataset``, ``train_eval_split``, ``model_id``,
            ``learning_rate``, ``num_epochs`` and ``save_dir``.
    """
    random.seed(args["seed"])
    max_duration = 30.0  # seconds of audio kept per clip (longer is truncated)

    gtzan = load_dataset(args["dataset"], "all")
    # Use the user-supplied seed for the split instead of a hard-coded 42 so
    # that --seed actually controls which clips end up in the eval set.
    gtzan = gtzan["train"].train_test_split(
        seed=args["seed"],
        shuffle=True,
        test_size=1 - args["train_eval_split"])

    feature_extractor = AutoFeatureExtractor.from_pretrained(
        args["model_id"], do_normalize=True, return_attention_mask=True)
    sampling_rate = feature_extractor.sampling_rate

    def preprocess_function(examples):
        """Turn a batch of decoded audio into feature-extractor inputs."""
        audio_arrays = [x["array"] for x in examples["audio"]]
        inputs = feature_extractor(
            audio_arrays,
            sampling_rate=sampling_rate,
            max_length=int(sampling_rate * max_duration),
            truncation=True,
            return_attention_mask=True,
        )
        return inputs

    # Resample on the fly to the rate the feature extractor reports.
    gtzan = gtzan.cast_column("audio", Audio(sampling_rate=sampling_rate))
    gtzan_encoded = gtzan.map(
        preprocess_function,
        remove_columns=["audio", "file"],
        batched=True,
        batch_size=100,
        num_proc=1)

    gtzan_encoded = gtzan_encoded.rename_column("genre", "label")

    label_names = gtzan_encoded["train"].features["label"].names
    id2label = {str(i): gtzan["train"].features["genre"].int2str(i)
                for i in range(len(label_names))}
    label2id = {v: k for k, v in id2label.items()}
    num_labels = len(id2label)

    model = AutoModelForAudioClassification.from_pretrained(
        args["model_id"],
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label)

    # Single quotes inside the f-string: reusing the outer double quotes is a
    # SyntaxError on interpreters older than Python 3.12.
    dir_name = (
        f"{args['model_id']}-{args['seed']}-{args['dataset']}"
        f"-{args['learning_rate']}"
    ).replace("/", "-")

    training_args = TrainingArguments(
        output_dir=os.path.join(args["save_dir"], dir_name),
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args["learning_rate"],
        per_device_train_batch_size=5,
        gradient_accumulation_steps=2,
        per_device_eval_batch_size=5,
        num_train_epochs=args["num_epochs"],
        warmup_ratio=0.1,
        logging_dir="./logs",
        logging_steps=5,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        fp16=True,
        # Forward the CLI seed so the Trainer's own seeding matches --seed.
        seed=args["seed"])

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=gtzan_encoded["train"],
        eval_dataset=gtzan_encoded["test"],
        tokenizer=feature_extractor,
        compute_metrics=compute_metrics)

    trainer.train()


if __name__ == "__main__":
    main(parse_args())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py DELETED
@@ -1,29 +0,0 @@
1
- from transformers import pipeline
2
- import librosa
3
- import json
4
- import gradio as gr
5
-
6
-
7
def audio_pipeline(file_path: str, top_k: int = 7) -> dict[str, float]:
    """Classify one audio file and return its best-scoring labels.

    Args:
        file_path: Path of the audio file; it is loaded with ``librosa`` at
            the sampling rate stored under ``config['sampling_rate']``.
        top_k: Number of labels to request from the module-level ``pipe``.

    Returns:
        A mapping of label to score built from the pipeline output.
    """
    # The value arrives from a UI slider and may be a float (e.g. 3.0);
    # a top-k count has to be an integer, so coerce it once here.
    top_k = int(top_k)
    y, _ = librosa.load(file_path, sr=config['sampling_rate'])
    out = pipe(y, top_k=top_k)
    print(out)
    return {clas['label']: clas['score'] for clas in out}
12
-
13
-
14
# Runtime settings: this script reads the 'sampling_rate' and 'models_path'
# keys from config.json.
with open('config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

pipe = pipeline("audio-classification", model=config['models_path'])

demo = gr.Interface(
    fn=audio_pipeline,
    inputs=[
        gr.Audio(type="filepath"),
        # Slider's third positional parameter is the initial value, not the
        # step. Spell both out so the control starts at 1 (as before) and
        # only moves in whole numbers.
        gr.Slider(1, 10, value=1, step=1, label="Top K Results"),
    ],
    outputs=gr.Label(num_top_classes=7),
    title="Music Mind Audio Classification",
    description="Upload an .mp3 or .ogg audio file "
                "to classify the content using a pre-trained model.")

if __name__ == "__main__":
    demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,7 +1,9 @@
1
- torch
2
- datasets
3
- transformers[torch]
4
- evaluate
5
- numpy
6
- librosa
7
- soundfile
 
 
 
1
+ torch --index-url https://download.pytorch.org/whl/cu121
2
+ torchvision --index-url https://download.pytorch.org/whl/cu121
3
+ transformers==4.41.2
4
+ gradio==4.36.1
5
+ numpy==1.26.4
6
+ evaluate==0.4.2
7
+ tqdm==4.66.4
8
+ mlflow==2.13.2
9
+ librosa==0.10.2.post1
train.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForAudioClassification
2
+ from torch.utils.data import DataLoader
3
+ import evaluate
4
+ import torch
5
+ from tqdm import tqdm
6
+ import argparse
7
+ import json
8
+ import os
9
+ import shutil
10
+ import mlflow
11
+ import mlflow.pytorch
12
+
13
+ from gtzan import GtzanDataset
14
+
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ metric = evaluate.load("accuracy")
18
+
19
def parse_args():
    """Parse the training script's command-line options.

    ``--label2id`` and ``--model_id`` have no usable default, so they are
    required: leaving them out now fails immediately with an argparse usage
    error instead of later when the missing value is used.

    Returns:
        The parsed options as a plain dict keyed by option name.
    """
    ap = argparse.ArgumentParser(
        description="Train an audio-classification model.")
    ap.add_argument("--label2id", type=str, required=True,
                    help="Path to a JSON file mapping label name to id.")
    ap.add_argument("--model_id", type=str, required=True,
                    help="Model identifier or path passed to from_pretrained.")
    ap.add_argument("--batch_size", type=int, default=32,
                    help="Batch size for the train and validation loaders.")
    ap.add_argument("--train_dir", type=str, default="data/train",
                    help="Directory with the training data.")
    ap.add_argument("--val_dir", type=str, default="data/val",
                    help="Directory with the validation data.")
    ap.add_argument("--num_workers", type=int, default=4,
                    help="Number of DataLoader worker processes.")
    ap.add_argument("--lr", type=float, default=1e-4,
                    help="Learning rate of the optimizer.")
    ap.add_argument("--epochs", type=int, default=10,
                    help="Number of training epochs.")
    ap.add_argument("--output_dir", type=str, default="./weights",
                    help="Root directory for saved checkpoints.")
    ap.add_argument("--seed", type=int, default=42,
                    help="Random seed.")
    ap.add_argument("--name", type=str, default="model",
                    help="Sub-directory of output_dir used for this run.")
    return vars(ap.parse_args())
33
+
34
def _accuracy_dir_name(accuracy):
    """Return the checkpoint folder name: the accuracy as a whole percent."""
    # Scale first, then round. Rounding to two decimals and truncating the
    # product afterwards loses a point to float error (0.57 * 100 -> 56).
    return str(int(round(float(accuracy) * 100)))


def _evaluate(model, val_loader):
    """Run one validation pass and return the accuracy from ``metric``."""
    model.eval()
    for batch in tqdm(val_loader, desc="Validation"):
        input_values, attention_mask, label = [b.to(device) for b in batch]
        with torch.no_grad():
            outputs = model(input_values=input_values,
                            attention_mask=attention_mask,
                            labels=label)
        predictions = torch.argmax(outputs.logits, dim=-1)
        metric.add_batch(predictions=predictions, references=label)
    return metric.compute()["accuracy"]


def train(args):
    """Train an audio classifier and keep the best checkpoint on disk.

    Each epoch runs one pass over the training loader followed by one
    validation pass. Whenever validation accuracy improves, the previous best
    checkpoint directory is deleted and the model is saved under
    ``<output_dir>/<name>/<accuracy in percent>``. Parameters and metrics are
    logged to an MLflow run, and the model as it stands after the last epoch
    is logged to that run at the end.

    Args:
        args: Option dict as produced by ``parse_args``. Reads the keys
            ``seed``, ``label2id``, ``output_dir``, ``train_dir``,
            ``val_dir``, ``batch_size``, ``num_workers``, ``model_id``,
            ``lr``, ``epochs`` and ``name``.
    """
    torch.manual_seed(args["seed"])

    # Close the mapping file deterministically instead of leaking the handle.
    with open(args["label2id"], encoding="utf-8") as label_file:
        label2id = json.load(label_file)
    id2label = {v: k for k, v in label2id.items()}
    num_labels = len(label2id)
    os.makedirs(args["output_dir"], exist_ok=True)

    train_dataset = GtzanDataset(args["train_dir"], label2id)
    val_dataset = GtzanDataset(args["val_dir"], label2id)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args["batch_size"],
        shuffle=True,
        num_workers=args["num_workers"])

    val_loader = DataLoader(
        val_dataset,
        batch_size=args["batch_size"],
        shuffle=False,
        num_workers=args["num_workers"])

    model = AutoModelForAudioClassification.from_pretrained(
        args['model_id'],
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])
    # The scheduler is stepped once per batch, so the cosine period spans
    # every optimizer step of the whole run.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(train_loader) * args["epochs"]
    )

    max_val_accuracy = 0
    best_path = ""
    global_step = 0  # x-axis for the per-batch training-loss curve

    with mlflow.start_run():
        mlflow.log_params({
            "model_id": args["model_id"],
            "batch_size": args["batch_size"],
            "lr": args["lr"],
            "epochs": args["epochs"],
            "seed": args["seed"]
        })

        for epoch in tqdm(range(args["epochs"])):
            model.train()
            train_progress_bar = tqdm(
                train_loader, desc=f"Training Epoch {epoch + 1}")
            # Iterating the tqdm wrapper already advances the bar; calling
            # update(1) as well would count every batch twice.
            for batch in train_progress_bar:
                input_values, attention_mask, label = [
                    b.to(device) for b in batch]
                outputs = model(input_values=input_values,
                                attention_mask=attention_mask,
                                labels=label)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                loss_value = loss.item()
                train_progress_bar.set_postfix({"loss": loss_value})
                # Give every point its own step so the values form a curve
                # rather than all landing on the default step.
                mlflow.log_metric("train_loss", loss_value, step=global_step)
                global_step += 1

            torch.cuda.empty_cache()
            val_accuracy = _evaluate(model, val_loader)
            mlflow.log_metric("val_accuracy", val_accuracy, step=epoch)
            torch.cuda.empty_cache()

            if val_accuracy > max_val_accuracy:
                # Only one checkpoint is kept: drop the previous best first.
                if best_path:
                    shutil.rmtree(best_path)
                model_save_dir = os.path.join(
                    args["output_dir"],
                    args['name'],
                    _accuracy_dir_name(val_accuracy))
                os.makedirs(model_save_dir, exist_ok=True)
                model.save_pretrained(model_save_dir)
                max_val_accuracy = val_accuracy
                best_path = model_save_dir

        # Logged inside the run so the model is attached to the same MLflow
        # run as the parameters and metrics above.
        mlflow.pytorch.log_model(model, "model")


if __name__ == "__main__":
    train(parse_args())