Spaces:
Sleeping
Sleeping
bpietrzak
commited on
Commit
·
91a9e54
1
Parent(s):
710c57b
Training fix
Browse files- .gitignore +1 -159
- dl/make_dataset.py +0 -42
- dl/push_model.py +0 -36
- dl/testing.ipynb +0 -394
- dl/train.py +0 -113
- main.py +0 -29
- requirements.txt +9 -7
- train.py +134 -0
.gitignore
CHANGED
@@ -1,160 +1,2 @@
|
|
1 |
-
# Byte-compiled / optimized / DLL files
|
2 |
-
__pycache__/
|
3 |
-
*.py[cod]
|
4 |
-
*$py.class
|
5 |
-
|
6 |
-
# C extensions
|
7 |
-
*.so
|
8 |
-
|
9 |
-
# Distribution / packaging
|
10 |
-
.Python
|
11 |
-
build/
|
12 |
-
develop-eggs/
|
13 |
-
dist/
|
14 |
-
downloads/
|
15 |
-
eggs/
|
16 |
-
.eggs/
|
17 |
-
lib/
|
18 |
-
lib64/
|
19 |
-
parts/
|
20 |
-
sdist/
|
21 |
-
var/
|
22 |
-
wheels/
|
23 |
-
share/python-wheels/
|
24 |
-
*.egg-info/
|
25 |
-
.installed.cfg
|
26 |
-
*.egg
|
27 |
-
MANIFEST
|
28 |
-
|
29 |
-
# PyInstaller
|
30 |
-
# Usually these files are written by a python script from a template
|
31 |
-
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
-
*.manifest
|
33 |
-
*.spec
|
34 |
-
|
35 |
-
# Installer logs
|
36 |
-
pip-log.txt
|
37 |
-
pip-delete-this-directory.txt
|
38 |
-
|
39 |
-
# Unit test / coverage reports
|
40 |
-
htmlcov/
|
41 |
-
.tox/
|
42 |
-
.nox/
|
43 |
-
.coverage
|
44 |
-
.coverage.*
|
45 |
-
.cache
|
46 |
-
nosetests.xml
|
47 |
-
coverage.xml
|
48 |
-
*.cover
|
49 |
-
*.py,cover
|
50 |
-
.hypothesis/
|
51 |
-
.pytest_cache/
|
52 |
-
cover/
|
53 |
-
|
54 |
-
# Translations
|
55 |
-
*.mo
|
56 |
-
*.pot
|
57 |
-
|
58 |
-
# Django stuff:
|
59 |
-
*.log
|
60 |
-
local_settings.py
|
61 |
-
db.sqlite3
|
62 |
-
db.sqlite3-journal
|
63 |
-
|
64 |
-
# Flask stuff:
|
65 |
-
instance/
|
66 |
-
.webassets-cache
|
67 |
-
|
68 |
-
# Scrapy stuff:
|
69 |
-
.scrapy
|
70 |
-
|
71 |
-
# Sphinx documentation
|
72 |
-
docs/_build/
|
73 |
-
|
74 |
-
# PyBuilder
|
75 |
-
.pybuilder/
|
76 |
-
target/
|
77 |
-
|
78 |
-
# Jupyter Notebook
|
79 |
-
.ipynb_checkpoints
|
80 |
-
|
81 |
-
# IPython
|
82 |
-
profile_default/
|
83 |
-
ipython_config.py
|
84 |
-
|
85 |
-
# pyenv
|
86 |
-
# For a library or package, you might want to ignore these files since the code is
|
87 |
-
# intended to run in multiple environments; otherwise, check them in:
|
88 |
-
# .python-version
|
89 |
-
|
90 |
-
# pipenv
|
91 |
-
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
-
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
-
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
-
# install all needed dependencies.
|
95 |
-
#Pipfile.lock
|
96 |
-
|
97 |
-
# poetry
|
98 |
-
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
-
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
-
# commonly ignored for libraries.
|
101 |
-
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
-
#poetry.lock
|
103 |
-
|
104 |
-
# pdm
|
105 |
-
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
-
#pdm.lock
|
107 |
-
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
-
# in version control.
|
109 |
-
# https://pdm.fming.dev/#use-with-ide
|
110 |
-
.pdm.toml
|
111 |
-
|
112 |
-
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
-
__pypackages__/
|
114 |
-
|
115 |
-
# Celery stuff
|
116 |
-
celerybeat-schedule
|
117 |
-
celerybeat.pid
|
118 |
-
|
119 |
-
# SageMath parsed files
|
120 |
-
*.sage.py
|
121 |
-
|
122 |
-
# Environments
|
123 |
-
.env
|
124 |
-
.venv
|
125 |
-
env/
|
126 |
venv/
|
127 |
-
|
128 |
-
env.bak/
|
129 |
-
venv.bak/
|
130 |
-
|
131 |
-
# Spyder project settings
|
132 |
-
.spyderproject
|
133 |
-
.spyproject
|
134 |
-
|
135 |
-
# Rope project settings
|
136 |
-
.ropeproject
|
137 |
-
|
138 |
-
# mkdocs documentation
|
139 |
-
/site
|
140 |
-
|
141 |
-
# mypy
|
142 |
-
.mypy_cache/
|
143 |
-
.dmypy.json
|
144 |
-
dmypy.json
|
145 |
-
|
146 |
-
# Pyre type checker
|
147 |
-
.pyre/
|
148 |
-
|
149 |
-
# pytype static type analyzer
|
150 |
-
.pytype/
|
151 |
-
|
152 |
-
# Cython debug symbols
|
153 |
-
cython_debug/
|
154 |
-
|
155 |
-
# PyCharm
|
156 |
-
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
-
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
-
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
-
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
-
#.idea/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
venv/
|
2 |
+
Data/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dl/make_dataset.py
DELETED
@@ -1,42 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import json
|
3 |
-
import argparse
|
4 |
-
import librosa
|
5 |
-
import pandas as pd
|
6 |
-
|
7 |
-
def parse_args():
|
8 |
-
parser = argparse.ArgumentParser()
|
9 |
-
parser.add_argument("--dir", type=str, help="Directory containing OGG audio files.")
|
10 |
-
parser.add_argument("--file", type=str, help="JSON file mapping filenames to classes.")
|
11 |
-
parser.add_argument('-o', '--output', type=str, default="output_dataset.csv", help="Output CSV file.")
|
12 |
-
return vars(parser.parse_args())
|
13 |
-
|
14 |
-
def load_audio_files(audio_dir, file_class_mapping):
|
15 |
-
data = []
|
16 |
-
for filename, class_label in file_class_mapping.items():
|
17 |
-
file_path = os.path.join(audio_dir, filename)
|
18 |
-
if os.path.exists(file_path):
|
19 |
-
audio, sr = librosa.load(file_path, sr=None)
|
20 |
-
data.append({
|
21 |
-
'filename': filename,
|
22 |
-
'audio': audio,
|
23 |
-
'sampling_rate': sr,
|
24 |
-
'label': class_label
|
25 |
-
})
|
26 |
-
return data
|
27 |
-
|
28 |
-
def main(args):
|
29 |
-
audio_dir = args['dir']
|
30 |
-
json_file = args['file']
|
31 |
-
|
32 |
-
with open(json_file, 'r') as f:
|
33 |
-
file_class_mapping = json.load(f)
|
34 |
-
|
35 |
-
dataset = load_audio_files(audio_dir, file_class_mapping)
|
36 |
-
|
37 |
-
df = pd.DataFrame(dataset)
|
38 |
-
|
39 |
-
df.to_csv(args['output'], index=False)
|
40 |
-
|
41 |
-
if __name__ == "__main__":
|
42 |
-
main(parse_args())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dl/push_model.py
DELETED
@@ -1,36 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
from transformers import AutoModel, AutoTokenizer
|
3 |
-
from huggingface_hub import HfApi, HfFolder
|
4 |
-
|
5 |
-
def parse_args():
|
6 |
-
parser = argparse.ArgumentParser()
|
7 |
-
parser.add_argument("--username", type=str, required=True, help="Nazwa użytkownika Hugging Face.")
|
8 |
-
parser.add_argument("--model_dir", type=str, required=True, help="Ścieżka do zapisanego modelu.")
|
9 |
-
parser.add_argument("--repo_name", type=str, required=True, help="Nazwa repozytorium HuggingFace Hub.")
|
10 |
-
parser.add_argument("--private", type=bool, default=False, help="Flaga określająca, czy repozytorium powinno być prywatne.")
|
11 |
-
return parser.parse_args()
|
12 |
-
|
13 |
-
def main():
|
14 |
-
args = parse_args()
|
15 |
-
token = HfFolder.get_token()
|
16 |
-
if token is None:
|
17 |
-
raise ValueError("Token uwierzytelniający nie został znaleziony. Zaloguj się za pomocą CLI Hugging Face.")
|
18 |
-
|
19 |
-
model = AutoModel.from_pretrained(args.model_dir)
|
20 |
-
tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
|
21 |
-
|
22 |
-
repo_url = HfApi().create_repo(
|
23 |
-
token=token,
|
24 |
-
name=args.repo_name,
|
25 |
-
organization=args.username,
|
26 |
-
private=args.private,
|
27 |
-
exist_ok=True
|
28 |
-
)
|
29 |
-
|
30 |
-
model.push_to_hub(args.repo_name, use_auth_token=token)
|
31 |
-
tokenizer.push_to_hub(args.repo_name, use_auth_token=token)
|
32 |
-
|
33 |
-
print(f"Model i tokajzer zostały wysłane do {repo_url}")
|
34 |
-
|
35 |
-
if __name__ == "__main__":
|
36 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dl/testing.ipynb
DELETED
@@ -1,394 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": 1,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"from datasets import load_dataset, Audio\n",
|
10 |
-
"from transformers import AutoFeatureExtractor\n",
|
11 |
-
"import numpy as np"
|
12 |
-
]
|
13 |
-
},
|
14 |
-
{
|
15 |
-
"cell_type": "code",
|
16 |
-
"execution_count": 2,
|
17 |
-
"metadata": {},
|
18 |
-
"outputs": [
|
19 |
-
{
|
20 |
-
"name": "stderr",
|
21 |
-
"output_type": "stream",
|
22 |
-
"text": [
|
23 |
-
"/home/potato/.virtualenvs/studia/lib/python3.10/site-packages/datasets/load.py:1486: FutureWarning: The repository for marsyas/gtzan contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/marsyas/gtzan\n",
|
24 |
-
"You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
|
25 |
-
"Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
|
26 |
-
" warnings.warn(\n"
|
27 |
-
]
|
28 |
-
}
|
29 |
-
],
|
30 |
-
"source": [
|
31 |
-
"data = load_dataset(\"marsyas/gtzan\", \"all\")\n",
|
32 |
-
"data = data['train'].train_test_split(seed=42, shuffle=True, test_size=.1)"
|
33 |
-
]
|
34 |
-
},
|
35 |
-
{
|
36 |
-
"cell_type": "code",
|
37 |
-
"execution_count": 3,
|
38 |
-
"metadata": {},
|
39 |
-
"outputs": [],
|
40 |
-
"source": [
|
41 |
-
"map_class = data['train'].features['genre'].int2str"
|
42 |
-
]
|
43 |
-
},
|
44 |
-
{
|
45 |
-
"cell_type": "markdown",
|
46 |
-
"metadata": {},
|
47 |
-
"source": [
|
48 |
-
"Models to train:\n",
|
49 |
-
"\n",
|
50 |
-
"- ntu-spml/distilhubert\n",
|
51 |
-
"- dima806/music_genres_classification"
|
52 |
-
]
|
53 |
-
},
|
54 |
-
{
|
55 |
-
"cell_type": "code",
|
56 |
-
"execution_count": 4,
|
57 |
-
"metadata": {},
|
58 |
-
"outputs": [],
|
59 |
-
"source": [
|
60 |
-
"distilhubert = AutoFeatureExtractor.from_pretrained(\n",
|
61 |
-
" 'ntu-spml/distilhubert', do_normalize=True, return_attention_mask=True\n",
|
62 |
-
")\n",
|
63 |
-
"# music_genres_classification = AutoFeatureExtractor.from_pretrained(\n",
|
64 |
-
"# 'dima806/music_genres_classification', do_normalize=True, return_attention_mask=True\n",
|
65 |
-
"# )\n",
|
66 |
-
"\n",
|
67 |
-
"# models = {'distilhubert': distilhubert,\n",
|
68 |
-
"# 'music_genres_classification': music_genres_classification}\n",
|
69 |
-
"\n",
|
70 |
-
"# def get_sampling_rate(model):\n",
|
71 |
-
"# return model.sampling_rate\n",
|
72 |
-
"\n",
|
73 |
-
"# if np.all([ get_sampling_rate(model) == 16000 for model in models.values()]):\n",
|
74 |
-
"# sampling_rate = 16000\n",
|
75 |
-
"# else:\n",
|
76 |
-
"# raise ValueError('You need to setup different values than 16000 for a sampling rate')\n",
|
77 |
-
"\n",
|
78 |
-
"data = data.cast_column(\"audio\", Audio(sampling_rate=16000))"
|
79 |
-
]
|
80 |
-
},
|
81 |
-
{
|
82 |
-
"cell_type": "code",
|
83 |
-
"execution_count": 5,
|
84 |
-
"metadata": {},
|
85 |
-
"outputs": [],
|
86 |
-
"source": [
|
87 |
-
"class Preprocess:\n",
|
88 |
-
" def __init__(self, model):\n",
|
89 |
-
" self.model = model\n",
|
90 |
-
" \n",
|
91 |
-
" def __call__(self, examples):\n",
|
92 |
-
" audio_arrays = [x[\"array\"] for x in examples[\"audio\"]]\n",
|
93 |
-
" inputs = self.model(\n",
|
94 |
-
" audio_arrays,\n",
|
95 |
-
" sampling_rate=self.model.sampling_rate,\n",
|
96 |
-
" max_length=int(self.model.sampling_rate * 30.0),\n",
|
97 |
-
" truncation=True,\n",
|
98 |
-
" return_attention_mask=True)\n",
|
99 |
-
" return inputs"
|
100 |
-
]
|
101 |
-
},
|
102 |
-
{
|
103 |
-
"cell_type": "code",
|
104 |
-
"execution_count": 6,
|
105 |
-
"metadata": {},
|
106 |
-
"outputs": [],
|
107 |
-
"source": [
|
108 |
-
"distilhubert_preprocess = Preprocess(distilhubert)\n",
|
109 |
-
"# music_genres_classification_preprocess = Preprocess(music_genres_classification)"
|
110 |
-
]
|
111 |
-
},
|
112 |
-
{
|
113 |
-
"cell_type": "code",
|
114 |
-
"execution_count": 7,
|
115 |
-
"metadata": {},
|
116 |
-
"outputs": [],
|
117 |
-
"source": [
|
118 |
-
"def process_data(preprocess):\n",
|
119 |
-
" data_preprocessed = data.map(\n",
|
120 |
-
" preprocess,\n",
|
121 |
-
" remove_columns=[\"audio\", \"file\"],\n",
|
122 |
-
" batched=True,\n",
|
123 |
-
" batch_size=100,\n",
|
124 |
-
" num_proc=1)\n",
|
125 |
-
" return data_preprocessed\n",
|
126 |
-
"\n",
|
127 |
-
"distilhubert_data = process_data(distilhubert_preprocess)\n",
|
128 |
-
"# music_genres_classification_data = process_data(music_genres_classification_preprocess)"
|
129 |
-
]
|
130 |
-
},
|
131 |
-
{
|
132 |
-
"cell_type": "code",
|
133 |
-
"execution_count": 8,
|
134 |
-
"metadata": {},
|
135 |
-
"outputs": [],
|
136 |
-
"source": [
|
137 |
-
"distilhubert_data = distilhubert_data.rename_column(\"genre\", \"label\")\n",
|
138 |
-
"# music_genres_classification_data = music_genres_classification_data.rename_column(\"genre\", \"label\")"
|
139 |
-
]
|
140 |
-
},
|
141 |
-
{
|
142 |
-
"cell_type": "code",
|
143 |
-
"execution_count": 9,
|
144 |
-
"metadata": {},
|
145 |
-
"outputs": [],
|
146 |
-
"source": [
|
147 |
-
"id2label = {\n",
|
148 |
-
" str(i): map_class(i)\n",
|
149 |
-
" for i in range(len(distilhubert_data[\"train\"].features[\"label\"].names))\n",
|
150 |
-
"}\n",
|
151 |
-
"label2id = {v: k for k, v in id2label.items()}"
|
152 |
-
]
|
153 |
-
},
|
154 |
-
{
|
155 |
-
"cell_type": "code",
|
156 |
-
"execution_count": 10,
|
157 |
-
"metadata": {},
|
158 |
-
"outputs": [],
|
159 |
-
"source": [
|
160 |
-
"from transformers import AutoModelForAudioClassification\n",
|
161 |
-
"from transformers import TrainingArguments\n",
|
162 |
-
"import numpy as np\n",
|
163 |
-
"from transformers import Trainer\n",
|
164 |
-
"\n",
|
165 |
-
"\n",
|
166 |
-
"class Eval:\n",
|
167 |
-
" def __init__(self, metric) -> None:\n",
|
168 |
-
" self.metric = metric\n",
|
169 |
-
"\n",
|
170 |
-
" def __call__(self, eval_pred):\n",
|
171 |
-
" predictions = np.argmax(eval_pred.predictions, axis=1)\n",
|
172 |
-
" return self.metric.compute(predictions=predictions, references=eval_pred.label_ids)\n",
|
173 |
-
"\n",
|
174 |
-
"def train(model_name, class_nb, label2id, id2label, batch_size, epochs, eval_metric, data, feature_extractor):\n",
|
175 |
-
" model = AutoModelForAudioClassification.from_pretrained(\n",
|
176 |
-
" model_name,\n",
|
177 |
-
" num_labels=class_nb,\n",
|
178 |
-
" label2id=label2id,\n",
|
179 |
-
" id2label=id2label)\n",
|
180 |
-
"\n",
|
181 |
-
" training_args = TrainingArguments(\n",
|
182 |
-
" f\"{model_name.split('/')[-1]}-ft-gtzan-{batch_size}-{epochs}\",\n",
|
183 |
-
" evaluation_strategy=\"epoch\",\n",
|
184 |
-
" save_strategy=\"epoch\",\n",
|
185 |
-
" learning_rate=5e-5,\n",
|
186 |
-
" per_device_train_batch_size=batch_size,\n",
|
187 |
-
" gradient_accumulation_steps=2,\n",
|
188 |
-
" per_device_eval_batch_size=batch_size,\n",
|
189 |
-
" num_train_epochs=epochs,\n",
|
190 |
-
" warmup_ratio=0.1,\n",
|
191 |
-
" logging_steps=5,\n",
|
192 |
-
" load_best_model_at_end=True,\n",
|
193 |
-
" metric_for_best_model=\"accuracy\",\n",
|
194 |
-
" fp16=True,\n",
|
195 |
-
" push_to_hub=True)\n",
|
196 |
-
" \n",
|
197 |
-
" trainer = Trainer(\n",
|
198 |
-
" model,\n",
|
199 |
-
" training_args,\n",
|
200 |
-
" train_dataset=data[\"train\"],\n",
|
201 |
-
" eval_dataset=data[\"test\"],\n",
|
202 |
-
" tokenizer=feature_extractor,\n",
|
203 |
-
" compute_metrics=eval_metric)\n",
|
204 |
-
"\n",
|
205 |
-
" trainer.train()"
|
206 |
-
]
|
207 |
-
},
|
208 |
-
{
|
209 |
-
"cell_type": "code",
|
210 |
-
"execution_count": 11,
|
211 |
-
"metadata": {},
|
212 |
-
"outputs": [
|
213 |
-
{
|
214 |
-
"data": {
|
215 |
-
"application/vnd.jupyter.widget-view+json": {
|
216 |
-
"model_id": "0282b77db7d1478f8e96988688e4b049",
|
217 |
-
"version_major": 2,
|
218 |
-
"version_minor": 0
|
219 |
-
},
|
220 |
-
"text/plain": [
|
221 |
-
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
222 |
-
]
|
223 |
-
},
|
224 |
-
"metadata": {},
|
225 |
-
"output_type": "display_data"
|
226 |
-
}
|
227 |
-
],
|
228 |
-
"source": [
|
229 |
-
"from huggingface_hub import notebook_login\n",
|
230 |
-
"\n",
|
231 |
-
"notebook_login()"
|
232 |
-
]
|
233 |
-
},
|
234 |
-
{
|
235 |
-
"cell_type": "code",
|
236 |
-
"execution_count": 14,
|
237 |
-
"metadata": {},
|
238 |
-
"outputs": [
|
239 |
-
{
|
240 |
-
"ename": "NameError",
|
241 |
-
"evalue": "name 'Eval' is not defined",
|
242 |
-
"output_type": "error",
|
243 |
-
"traceback": [
|
244 |
-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
245 |
-
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
246 |
-
"Cell \u001b[0;32mIn[14], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mevaluate\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m acc \u001b[38;5;241m=\u001b[39m \u001b[43mEval\u001b[49m(evaluate\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124maccuracy\u001b[39m\u001b[38;5;124m'\u001b[39m))\n\u001b[1;32m 6\u001b[0m models \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 7\u001b[0m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_name\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mntu-spml/distilhubert\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mclass_nb\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28mlen\u001b[39m(id2label), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlabel2id\u001b[39m\u001b[38;5;124m'\u001b[39m: label2id, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mid2label\u001b[39m\u001b[38;5;124m'\u001b[39m: id2label, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m4\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mepochs\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m8\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124meval_metric\u001b[39m\u001b[38;5;124m'\u001b[39m: acc, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdata\u001b[39m\u001b[38;5;124m'\u001b[39m: distilhubert_data, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfeature_extractor\u001b[39m\u001b[38;5;124m'\u001b[39m: distilhubert},\n\u001b[1;32m 8\u001b[0m \u001b[38;5;66;03m# {'model_name': 'dima806/music_genres_classification', 'class_nb': len(id2label), 'label2id': label2id, 'id2label': id2label, 'batch_size': 25, 'epochs': 8, 'eval_metric': acc, 'data': music_genres_classification_data, 'feature_extractor': music_genres_classification}]\u001b[39;00m\n\u001b[1;32m 9\u001b[0m ]\n",
|
247 |
-
"\u001b[0;31mNameError\u001b[0m: name 'Eval' is not defined"
|
248 |
-
]
|
249 |
-
}
|
250 |
-
],
|
251 |
-
"source": [
|
252 |
-
"import evaluate\n",
|
253 |
-
"\n",
|
254 |
-
"acc = Eval(evaluate.load('accuracy'))\n",
|
255 |
-
"\n",
|
256 |
-
"\n",
|
257 |
-
"models = [\n",
|
258 |
-
" {'model_name': 'ntu-spml/distilhubert', 'class_nb': len(id2label), 'label2id': label2id, 'id2label': id2label, 'batch_size': 4, 'epochs': 8, 'eval_metric': acc, 'data': distilhubert_data, 'feature_extractor': distilhubert},\n",
|
259 |
-
" # {'model_name': 'dima806/music_genres_classification', 'class_nb': len(id2label), 'label2id': label2id, 'id2label': id2label, 'batch_size': 25, 'epochs': 8, 'eval_metric': acc, 'data': music_genres_classification_data, 'feature_extractor': music_genres_classification}]\n",
|
260 |
-
"]"
|
261 |
-
]
|
262 |
-
},
|
263 |
-
{
|
264 |
-
"cell_type": "code",
|
265 |
-
"execution_count": 13,
|
266 |
-
"metadata": {},
|
267 |
-
"outputs": [
|
268 |
-
{
|
269 |
-
"name": "stderr",
|
270 |
-
"output_type": "stream",
|
271 |
-
"text": [
|
272 |
-
"Some weights of HubertForSequenceClassification were not initialized from the model checkpoint at ntu-spml/distilhubert and are newly initialized: ['classifier.bias', 'classifier.weight', 'encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'projector.bias', 'projector.weight']\n",
|
273 |
-
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
274 |
-
"/home/potato/.virtualenvs/studia/lib/python3.10/site-packages/transformers/training_args.py:1474: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
|
275 |
-
" warnings.warn(\n"
|
276 |
-
]
|
277 |
-
},
|
278 |
-
{
|
279 |
-
"data": {
|
280 |
-
"application/vnd.jupyter.widget-view+json": {
|
281 |
-
"model_id": "abba0ecb6fb242fe8538bcdeec44ef9b",
|
282 |
-
"version_major": 2,
|
283 |
-
"version_minor": 0
|
284 |
-
},
|
285 |
-
"text/plain": [
|
286 |
-
" 0%| | 0/896 [00:00<?, ?it/s]"
|
287 |
-
]
|
288 |
-
},
|
289 |
-
"metadata": {},
|
290 |
-
"output_type": "display_data"
|
291 |
-
},
|
292 |
-
{
|
293 |
-
"name": "stderr",
|
294 |
-
"output_type": "stream",
|
295 |
-
"text": [
|
296 |
-
"/home/potato/.virtualenvs/studia/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
|
297 |
-
" return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n"
|
298 |
-
]
|
299 |
-
},
|
300 |
-
{
|
301 |
-
"name": "stdout",
|
302 |
-
"output_type": "stream",
|
303 |
-
"text": [
|
304 |
-
"{'loss': 2.2907, 'grad_norm': 1.133973240852356, 'learning_rate': 2.777777777777778e-06, 'epoch': 0.04}\n",
|
305 |
-
"{'loss': 2.2975, 'grad_norm': 1.526039719581604, 'learning_rate': 5.555555555555556e-06, 'epoch': 0.09}\n",
|
306 |
-
"{'loss': 2.2871, 'grad_norm': 1.1069303750991821, 'learning_rate': 8.333333333333334e-06, 'epoch': 0.13}\n",
|
307 |
-
"{'loss': 2.3107, 'grad_norm': 1.4785107374191284, 'learning_rate': 1.1111111111111112e-05, 'epoch': 0.18}\n",
|
308 |
-
"{'loss': 2.2712, 'grad_norm': 1.5087419748306274, 'learning_rate': 1.388888888888889e-05, 'epoch': 0.22}\n",
|
309 |
-
"{'loss': 2.3081, 'grad_norm': 1.904876708984375, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.27}\n",
|
310 |
-
"{'loss': 2.2583, 'grad_norm': 1.2942432165145874, 'learning_rate': 1.9444444444444445e-05, 'epoch': 0.31}\n",
|
311 |
-
"{'loss': 2.2572, 'grad_norm': 1.8840770721435547, 'learning_rate': 2.2222222222222223e-05, 'epoch': 0.36}\n",
|
312 |
-
"{'loss': 2.2733, 'grad_norm': 1.25327730178833, 'learning_rate': 2.5e-05, 'epoch': 0.4}\n",
|
313 |
-
"{'loss': 2.273, 'grad_norm': 1.499450922012329, 'learning_rate': 2.777777777777778e-05, 'epoch': 0.44}\n",
|
314 |
-
"{'loss': 2.2221, 'grad_norm': 1.6644848585128784, 'learning_rate': 3.055555555555556e-05, 'epoch': 0.49}\n",
|
315 |
-
"{'loss': 2.2279, 'grad_norm': 1.5860854387283325, 'learning_rate': 3.3333333333333335e-05, 'epoch': 0.53}\n",
|
316 |
-
"{'loss': 2.2253, 'grad_norm': 1.8796266317367554, 'learning_rate': 3.611111111111111e-05, 'epoch': 0.58}\n",
|
317 |
-
"{'loss': 2.1556, 'grad_norm': 2.5186994075775146, 'learning_rate': 3.888888888888889e-05, 'epoch': 0.62}\n",
|
318 |
-
"{'loss': 2.1329, 'grad_norm': 2.4733965396881104, 'learning_rate': 4.166666666666667e-05, 'epoch': 0.67}\n",
|
319 |
-
"{'loss': 2.1214, 'grad_norm': 1.7492904663085938, 'learning_rate': 4.4444444444444447e-05, 'epoch': 0.71}\n",
|
320 |
-
"{'loss': 2.0805, 'grad_norm': 3.5089523792266846, 'learning_rate': 4.722222222222222e-05, 'epoch': 0.76}\n",
|
321 |
-
"{'loss': 1.9769, 'grad_norm': 1.759109377861023, 'learning_rate': 5e-05, 'epoch': 0.8}\n",
|
322 |
-
"{'loss': 2.0086, 'grad_norm': 3.3685412406921387, 'learning_rate': 4.968982630272953e-05, 'epoch': 0.84}\n",
|
323 |
-
"{'loss': 1.9664, 'grad_norm': 4.404444694519043, 'learning_rate': 4.937965260545906e-05, 'epoch': 0.89}\n",
|
324 |
-
"{'loss': 1.9937, 'grad_norm': 9.0780611038208, 'learning_rate': 4.9069478908188585e-05, 'epoch': 0.93}\n",
|
325 |
-
"{'loss': 1.876, 'grad_norm': 4.4436798095703125, 'learning_rate': 4.8759305210918115e-05, 'epoch': 0.98}\n"
|
326 |
-
]
|
327 |
-
},
|
328 |
-
{
|
329 |
-
"data": {
|
330 |
-
"application/vnd.jupyter.widget-view+json": {
|
331 |
-
"model_id": "7c5670c3679841ab80f531ecb12310a1",
|
332 |
-
"version_major": 2,
|
333 |
-
"version_minor": 0
|
334 |
-
},
|
335 |
-
"text/plain": [
|
336 |
-
" 0%| | 0/25 [00:00<?, ?it/s]"
|
337 |
-
]
|
338 |
-
},
|
339 |
-
"metadata": {},
|
340 |
-
"output_type": "display_data"
|
341 |
-
},
|
342 |
-
{
|
343 |
-
"ename": "TypeError",
|
344 |
-
"evalue": "'Accuracy' object is not callable",
|
345 |
-
"output_type": "error",
|
346 |
-
"traceback": [
|
347 |
-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
348 |
-
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
349 |
-
"Cell \u001b[0;32mIn[13], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m m \u001b[38;5;129;01min\u001b[39;00m models:\n\u001b[0;32m----> 2\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mm\u001b[49m\u001b[43m)\u001b[49m\n",
|
350 |
-
"Cell \u001b[0;32mIn[10], line 42\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model_name, class_nb, label2id, id2label, batch_size, epochs, eval_metric, data, feature_extractor)\u001b[0m\n\u001b[1;32m 18\u001b[0m training_args \u001b[38;5;241m=\u001b[39m TrainingArguments(\n\u001b[1;32m 19\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/\u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m-ft-gtzan-\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbatch_size\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m-\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepochs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 20\u001b[0m evaluation_strategy\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mepoch\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 31\u001b[0m fp16\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 32\u001b[0m push_to_hub\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 34\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer(\n\u001b[1;32m 35\u001b[0m model,\n\u001b[1;32m 36\u001b[0m training_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 39\u001b[0m tokenizer\u001b[38;5;241m=\u001b[39mfeature_extractor,\n\u001b[1;32m 40\u001b[0m compute_metrics\u001b[38;5;241m=\u001b[39meval_metric)\n\u001b[0;32m---> 42\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
351 |
-
"File \u001b[0;32m~/.virtualenvs/studia/lib/python3.10/site-packages/transformers/trainer.py:1876\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1873\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1874\u001b[0m \u001b[38;5;66;03m# Disable progress bars when uploading models during checkpoints to avoid polluting stdout\u001b[39;00m\n\u001b[1;32m 1875\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39mdisable_progress_bars()\n\u001b[0;32m-> 1876\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1877\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1878\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1879\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1880\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1881\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1882\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 1883\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n",
|
352 |
-
"File \u001b[0;32m~/.virtualenvs/studia/lib/python3.10/site-packages/transformers/trainer.py:2311\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2308\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_training_stop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 2310\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_epoch_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m-> 2311\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maybe_log_save_evaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtr_loss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_norm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2313\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m DebugOption\u001b[38;5;241m.\u001b[39mTPU_METRICS_DEBUG \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdebug:\n\u001b[1;32m 2314\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_torch_xla_available():\n\u001b[1;32m 2315\u001b[0m \u001b[38;5;66;03m# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)\u001b[39;00m\n",
|
353 |
-
"File \u001b[0;32m~/.virtualenvs/studia/lib/python3.10/site-packages/transformers/trainer.py:2721\u001b[0m, in \u001b[0;36mTrainer._maybe_log_save_evaluate\u001b[0;34m(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2719\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 2720\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_evaluate:\n\u001b[0;32m-> 2721\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2722\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_report_to_hp_search(trial, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step, metrics)\n\u001b[1;32m 2724\u001b[0m \u001b[38;5;66;03m# Run delayed LR scheduler now that metrics are populated\u001b[39;00m\n",
|
354 |
-
"File \u001b[0;32m~/.virtualenvs/studia/lib/python3.10/site-packages/transformers/trainer.py:3572\u001b[0m, in \u001b[0;36mTrainer.evaluate\u001b[0;34m(self, eval_dataset, ignore_keys, metric_key_prefix)\u001b[0m\n\u001b[1;32m 3569\u001b[0m start_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 3571\u001b[0m eval_loop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprediction_loop \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39muse_legacy_prediction_loop \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mevaluation_loop\n\u001b[0;32m-> 3572\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43meval_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3573\u001b[0m \u001b[43m \u001b[49m\u001b[43meval_dataloader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3574\u001b[0m \u001b[43m \u001b[49m\u001b[43mdescription\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mEvaluation\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3575\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# No point gathering the predictions if there are no metrics, otherwise we defer to\u001b[39;49;00m\n\u001b[1;32m 3576\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# self.args.prediction_loss_only\u001b[39;49;00m\n\u001b[1;32m 3577\u001b[0m \u001b[43m \u001b[49m\u001b[43mprediction_loss_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_metrics\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 3578\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3579\u001b[0m \u001b[43m \u001b[49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3580\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3582\u001b[0m total_batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39meval_batch_size \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mworld_size\n\u001b[1;32m 3583\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmetric_key_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_jit_compilation_time\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m output\u001b[38;5;241m.\u001b[39mmetrics:\n",
|
355 |
-
"File \u001b[0;32m~/.virtualenvs/studia/lib/python3.10/site-packages/transformers/trainer.py:3854\u001b[0m, in \u001b[0;36mTrainer.evaluation_loop\u001b[0;34m(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)\u001b[0m\n\u001b[1;32m 3850\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_metrics(\n\u001b[1;32m 3851\u001b[0m EvalPrediction(predictions\u001b[38;5;241m=\u001b[39mall_preds, label_ids\u001b[38;5;241m=\u001b[39mall_labels, inputs\u001b[38;5;241m=\u001b[39mall_inputs)\n\u001b[1;32m 3852\u001b[0m )\n\u001b[1;32m 3853\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3854\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_metrics\u001b[49m\u001b[43m(\u001b[49m\u001b[43mEvalPrediction\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredictions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mall_preds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mall_labels\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3855\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m metrics \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3856\u001b[0m metrics \u001b[38;5;241m=\u001b[39m {}\n",
|
356 |
-
"\u001b[0;31mTypeError\u001b[0m: 'Accuracy' object is not callable"
|
357 |
-
]
|
358 |
-
}
|
359 |
-
],
|
360 |
-
"source": [
|
361 |
-
"for m in models:\n",
|
362 |
-
" train(**m)"
|
363 |
-
]
|
364 |
-
},
|
365 |
-
{
|
366 |
-
"cell_type": "code",
|
367 |
-
"execution_count": null,
|
368 |
-
"metadata": {},
|
369 |
-
"outputs": [],
|
370 |
-
"source": []
|
371 |
-
}
|
372 |
-
],
|
373 |
-
"metadata": {
|
374 |
-
"kernelspec": {
|
375 |
-
"display_name": "studia",
|
376 |
-
"language": "python",
|
377 |
-
"name": "python3"
|
378 |
-
},
|
379 |
-
"language_info": {
|
380 |
-
"codemirror_mode": {
|
381 |
-
"name": "ipython",
|
382 |
-
"version": 3
|
383 |
-
},
|
384 |
-
"file_extension": ".py",
|
385 |
-
"mimetype": "text/x-python",
|
386 |
-
"name": "python",
|
387 |
-
"nbconvert_exporter": "python",
|
388 |
-
"pygments_lexer": "ipython3",
|
389 |
-
"version": "3.10.12"
|
390 |
-
}
|
391 |
-
},
|
392 |
-
"nbformat": 4,
|
393 |
-
"nbformat_minor": 2
|
394 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dl/train.py
DELETED
@@ -1,113 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import numpy as np
|
3 |
-
from datasets import load_dataset, Audio
|
4 |
-
from transformers import (AutoFeatureExtractor,
|
5 |
-
AutoModelForAudioClassification, TrainingArguments,
|
6 |
-
Trainer)
|
7 |
-
import os
|
8 |
-
import evaluate
|
9 |
-
import random
|
10 |
-
|
11 |
-
|
12 |
-
accuracy_metric = evaluate.load("accuracy")
|
13 |
-
|
14 |
-
def parse_args() -> dict:
|
15 |
-
parser = argparse.ArgumentParser(description="Skrypt do trenowania modelu klasyfikacji audio.")
|
16 |
-
parser.add_argument("--learning_rate", type=float, default=5e-5,
|
17 |
-
help="Współczynnik uczenia podczas treningu modelu.")
|
18 |
-
parser.add_argument("--train_eval_split", type=float, default=0.9,
|
19 |
-
help="Stosunek danych trenujących do całego zbioru; reszta to dane walidacyjne.")
|
20 |
-
parser.add_argument("--model_id", type=str, required=True,
|
21 |
-
help="Identyfikator modelu z Hugging Face lub ścieżka do lokalnego modelu.")
|
22 |
-
parser.add_argument("--num_epochs", type=int, default=20,
|
23 |
-
help="Liczba epok treningowych.")
|
24 |
-
parser.add_argument("--seed", type=int, default=42,
|
25 |
-
help="Ziarno liczb losowych.")
|
26 |
-
parser.add_argument("--save_dir", type=str, default=".",
|
27 |
-
help="Ścieżka do katalogu wag tranowanego modelu.")
|
28 |
-
parser.add_argument("--dataset", type=str, default="marsyas/gtzan",
|
29 |
-
help="Nazwa/lokalizacja zbioru danych.")
|
30 |
-
return vars(parser.parse_args())
|
31 |
-
|
32 |
-
|
33 |
-
def compute_metrics(eval_pred):
|
34 |
-
predictions = np.argmax(eval_pred.predictions, axis=1)
|
35 |
-
return accuracy_metric.compute(predictions=predictions,
|
36 |
-
references=eval_pred.label_ids)
|
37 |
-
|
38 |
-
def main(args: dict) -> None:
|
39 |
-
random.seed(args["seed"])
|
40 |
-
max_duration = 30.0
|
41 |
-
|
42 |
-
gtzan = load_dataset(args["dataset"], "all")
|
43 |
-
gtzan = gtzan["train"].train_test_split(seed=42, shuffle=True,
|
44 |
-
test_size=1 - args["train_eval_split"])
|
45 |
-
|
46 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
47 |
-
args["model_id"], do_normalize=True, return_attention_mask=True)
|
48 |
-
sampling_rate = feature_extractor.sampling_rate
|
49 |
-
|
50 |
-
|
51 |
-
def preprocess_function(examples):
|
52 |
-
audio_arrays = [x["array"] for x in examples["audio"]]
|
53 |
-
inputs = feature_extractor(
|
54 |
-
audio_arrays,
|
55 |
-
sampling_rate=sampling_rate,
|
56 |
-
max_length=int(sampling_rate * max_duration),
|
57 |
-
truncation=True,
|
58 |
-
return_attention_mask=True,
|
59 |
-
)
|
60 |
-
return inputs
|
61 |
-
|
62 |
-
gtzan = gtzan.cast_column("audio", Audio(sampling_rate=sampling_rate))
|
63 |
-
gtzan_encoded = gtzan.map(
|
64 |
-
preprocess_function,
|
65 |
-
remove_columns=["audio", "file"],
|
66 |
-
batched=True,
|
67 |
-
batch_size=100,
|
68 |
-
num_proc=1)
|
69 |
-
|
70 |
-
gtzan_encoded = gtzan_encoded.rename_column("genre", "label")
|
71 |
-
|
72 |
-
id2label = {str(i): gtzan["train"].features["genre"].int2str(i)
|
73 |
-
for i in range(len(gtzan_encoded["train"].features["label"].names))}
|
74 |
-
label2id = {v: k for k, v in id2label.items()}
|
75 |
-
num_labels = len(id2label)
|
76 |
-
|
77 |
-
model = AutoModelForAudioClassification.from_pretrained(
|
78 |
-
args["model_id"],
|
79 |
-
num_labels=num_labels,
|
80 |
-
label2id=label2id,
|
81 |
-
id2label=id2label)
|
82 |
-
|
83 |
-
dir_name = f"{args["model_id"]}-{args["seed"]}-{args["dataset"]}-{args['learning_rate']}".replace("/", "-")
|
84 |
-
|
85 |
-
training_args = TrainingArguments(
|
86 |
-
output_dir=os.path.join(args["save_dir"], dir_name),
|
87 |
-
evaluation_strategy="epoch",
|
88 |
-
save_strategy="epoch",
|
89 |
-
learning_rate=args["learning_rate"],
|
90 |
-
per_device_train_batch_size=5,
|
91 |
-
gradient_accumulation_steps=2,
|
92 |
-
per_device_eval_batch_size=5,
|
93 |
-
num_train_epochs=args["num_epochs"],
|
94 |
-
warmup_ratio=0.1,
|
95 |
-
logging_dir="./logs",
|
96 |
-
logging_steps=5,
|
97 |
-
load_best_model_at_end=True,
|
98 |
-
metric_for_best_model="accuracy",
|
99 |
-
fp16=True)
|
100 |
-
|
101 |
-
trainer = Trainer(
|
102 |
-
model=model,
|
103 |
-
args=training_args,
|
104 |
-
train_dataset=gtzan_encoded["train"],
|
105 |
-
eval_dataset=gtzan_encoded["test"],
|
106 |
-
tokenizer=feature_extractor,
|
107 |
-
compute_metrics=compute_metrics)
|
108 |
-
|
109 |
-
trainer.train()
|
110 |
-
|
111 |
-
|
112 |
-
if __name__ == "__main__":
|
113 |
-
main(parse_args())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
DELETED
@@ -1,29 +0,0 @@
|
|
1 |
-
from transformers import pipeline
|
2 |
-
import librosa
|
3 |
-
import json
|
4 |
-
import gradio as gr
|
5 |
-
|
6 |
-
|
7 |
-
def audio_pipeline(file_path: str, top_k: int = 7) -> dict[str, float]:
|
8 |
-
y, _ = librosa.load(file_path, sr=config['sampling_rate'])
|
9 |
-
out = pipe(y, top_k=top_k)
|
10 |
-
print(out)
|
11 |
-
return {clas['label']: clas['score'] for clas in out}
|
12 |
-
|
13 |
-
|
14 |
-
with open('config.json', 'r') as f:
|
15 |
-
config = json.load(f)
|
16 |
-
|
17 |
-
pipe = pipeline("audio-classification", model=config['models_path'])
|
18 |
-
|
19 |
-
demo = gr.Interface(
|
20 |
-
fn=audio_pipeline,
|
21 |
-
inputs=[gr.Audio(type="filepath"), gr.Slider(1, 10, 1,
|
22 |
-
label="Top K Results")],
|
23 |
-
outputs=gr.Label(num_top_classes=7),
|
24 |
-
title="Music Mind Audio Classification",
|
25 |
-
description="Upload an .mp3 or .ogg audio file "
|
26 |
-
"to classify the content using a pre-trained model.")
|
27 |
-
|
28 |
-
if __name__ == "__main__":
|
29 |
-
demo.launch(debug=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
-
torch
|
2 |
-
|
3 |
-
transformers
|
4 |
-
|
5 |
-
numpy
|
6 |
-
|
7 |
-
|
|
|
|
|
|
1 |
+
torch --index-url https://download.pytorch.org/whl/cu121
|
2 |
+
torchvision --index-url https://download.pytorch.org/whl/cu121
|
3 |
+
transformers==4.41.2
|
4 |
+
gradio==4.36.1
|
5 |
+
numpy==1.26.4
|
6 |
+
evaluate==0.4.2
|
7 |
+
tqdm==4.66.4
|
8 |
+
mlflow==2.13.2
|
9 |
+
librosa==0.10.2.post1
|
train.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForAudioClassification
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
import evaluate
|
4 |
+
import torch
|
5 |
+
from tqdm import tqdm
|
6 |
+
import argparse
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import shutil
|
10 |
+
import mlflow
|
11 |
+
import mlflow.pytorch
|
12 |
+
|
13 |
+
from gtzan import GtzanDataset
|
14 |
+
|
15 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
+
|
17 |
+
metric = evaluate.load("accuracy")
|
18 |
+
|
19 |
+
def parse_args():
|
20 |
+
ap = argparse.ArgumentParser()
|
21 |
+
ap.add_argument("--label2id", type=str)
|
22 |
+
ap.add_argument("--model_id", type=str)
|
23 |
+
ap.add_argument("--batch_size", type=int, default=32)
|
24 |
+
ap.add_argument("--train_dir", type=str, default="data/train")
|
25 |
+
ap.add_argument("--val_dir", type=str, default="data/val")
|
26 |
+
ap.add_argument("--num_workers", type=int, default=4)
|
27 |
+
ap.add_argument("--lr", type=float, default=1e-4)
|
28 |
+
ap.add_argument("--epochs", type=int, default=10)
|
29 |
+
ap.add_argument("--output_dir", type=str, default="./weights")
|
30 |
+
ap.add_argument("--seed", type=int, default=42)
|
31 |
+
ap.add_argument("--name", type=str, default="model")
|
32 |
+
return vars(ap.parse_args())
|
33 |
+
|
34 |
+
def train(args):
|
35 |
+
torch.manual_seed(args["seed"])
|
36 |
+
|
37 |
+
label2id = json.load(open(args["label2id"]))
|
38 |
+
id2label = {v: k for k, v in label2id.items()}
|
39 |
+
num_labels = len(label2id)
|
40 |
+
if not os.path.exists(args["output_dir"]):
|
41 |
+
os.makedirs(args["output_dir"])
|
42 |
+
|
43 |
+
train_dataset = GtzanDataset(args["train_dir"], label2id)
|
44 |
+
val_dataset = GtzanDataset(args["val_dir"], label2id)
|
45 |
+
|
46 |
+
train_loader = DataLoader(
|
47 |
+
train_dataset,
|
48 |
+
batch_size=args["batch_size"],
|
49 |
+
shuffle=True,
|
50 |
+
num_workers=args["num_workers"])
|
51 |
+
|
52 |
+
val_loader = DataLoader(
|
53 |
+
val_dataset,
|
54 |
+
batch_size=args["batch_size"],
|
55 |
+
shuffle=False,
|
56 |
+
num_workers=args["num_workers"])
|
57 |
+
|
58 |
+
model = AutoModelForAudioClassification.from_pretrained(
|
59 |
+
args['model_id'],
|
60 |
+
num_labels=num_labels,
|
61 |
+
label2id=label2id,
|
62 |
+
id2label=id2label,
|
63 |
+
).to(device)
|
64 |
+
|
65 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr"])
|
66 |
+
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
67 |
+
optimizer, T_max=len(train_loader) * args["epochs"]
|
68 |
+
)
|
69 |
+
|
70 |
+
max_val_accuracy = 0
|
71 |
+
best_path = ""
|
72 |
+
|
73 |
+
with mlflow.start_run():
|
74 |
+
mlflow.log_params({
|
75 |
+
"model_id": args["model_id"],
|
76 |
+
"batch_size": args["batch_size"],
|
77 |
+
"lr": args["lr"],
|
78 |
+
"epochs": args["epochs"],
|
79 |
+
"seed": args["seed"]
|
80 |
+
})
|
81 |
+
|
82 |
+
for epoch in tqdm(range(args["epochs"])):
|
83 |
+
model.train()
|
84 |
+
train_progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}")
|
85 |
+
for batch in train_progress_bar:
|
86 |
+
input_values, attention_mask, label = [b.to(device) for b in batch]
|
87 |
+
outputs = model(input_values=input_values,
|
88 |
+
attention_mask=attention_mask,
|
89 |
+
labels=label)
|
90 |
+
loss = outputs.loss
|
91 |
+
loss.backward()
|
92 |
+
optimizer.step()
|
93 |
+
lr_scheduler.step()
|
94 |
+
optimizer.zero_grad()
|
95 |
+
|
96 |
+
train_progress_bar.set_postfix({"loss": loss.item()})
|
97 |
+
train_progress_bar.update(1)
|
98 |
+
mlflow.log_metric("train_loss", loss.item()) # Log training loss
|
99 |
+
|
100 |
+
torch.cuda.empty_cache()
|
101 |
+
model.eval()
|
102 |
+
|
103 |
+
val_progress_bar = tqdm(val_loader, desc="Validation")
|
104 |
+
for batch in val_progress_bar:
|
105 |
+
input_values, attention_mask, label = [b.to(device) for b in batch]
|
106 |
+
with torch.no_grad():
|
107 |
+
outputs = model(input_values=input_values,
|
108 |
+
attention_mask=attention_mask,
|
109 |
+
labels=label)
|
110 |
+
|
111 |
+
logits = outputs.logits
|
112 |
+
predictions = torch.argmax(logits, dim=-1)
|
113 |
+
metric.add_batch(predictions=predictions, references=label)
|
114 |
+
val_progress_bar.update(1)
|
115 |
+
|
116 |
+
val_accuracy = metric.compute()
|
117 |
+
mlflow.log_metric("val_accuracy", val_accuracy["accuracy"], step=epoch) # Log validation accuracy
|
118 |
+
torch.cuda.empty_cache()
|
119 |
+
if val_accuracy["accuracy"] > max_val_accuracy:
|
120 |
+
if best_path:
|
121 |
+
shutil.rmtree(best_path)
|
122 |
+
model_save_dir = os.path.join(
|
123 |
+
args["output_dir"],
|
124 |
+
args['name'],
|
125 |
+
f"{int(round(val_accuracy['accuracy'], 2) * 100)}")
|
126 |
+
if not os.path.exists(model_save_dir):
|
127 |
+
os.makedirs(model_save_dir, exist_ok=True)
|
128 |
+
model.save_pretrained(model_save_dir)
|
129 |
+
max_val_accuracy = val_accuracy["accuracy"]
|
130 |
+
best_path = model_save_dir
|
131 |
+
|
132 |
+
mlflow.pytorch.log_model(model, "model")
|
133 |
+
if __name__ == "__main__":
|
134 |
+
train(parse_args())
|