SleepyJesse committed
Commit d6a066c · verified · Parent: a479cdd

Upload ai_music_detection_new_large_60.ipynb

Files changed (1)
  1. ai_music_detection_new_large_60.ipynb +988 -0
ai_music_detection_new_large_60.ipynb ADDED
@@ -0,0 +1,988 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: librosa in /opt/conda/lib/python3.10/site-packages (0.10.2.post1)\n",
+ "Requirement already satisfied: soundfile in /opt/conda/lib/python3.10/site-packages (0.12.1)\n",
+ "Requirement already satisfied: torchaudio in /opt/conda/lib/python3.10/site-packages (2.2.0)\n",
+ "Requirement already satisfied: audiomentations in /opt/conda/lib/python3.10/site-packages (0.38.0)\n",
+ "Requirement already satisfied: evaluate in /opt/conda/lib/python3.10/site-packages (0.4.3)\n",
+ "Requirement already satisfied: ipywidgets in /opt/conda/lib/python3.10/site-packages (8.1.5)\n",
+ "Requirement already satisfied: matplotlib in /opt/conda/lib/python3.10/site-packages (3.9.3)\n",
+ "Requirement already satisfied: tensorboard in /opt/conda/lib/python3.10/site-packages (2.18.0)\n",
+ "Requirement already satisfied: datasets[audio] in /opt/conda/lib/python3.10/site-packages (3.2.0)\n",
+ "Requirement already satisfied: transformers[torch] in /opt/conda/lib/python3.10/site-packages (4.47.0)\n",
+ "Requirement already satisfied: audioread>=2.1.9 in /opt/conda/lib/python3.10/site-packages (from librosa) (3.0.1)\n",
+ "Requirement already satisfied: numpy!=1.22.0,!=1.22.1,!=1.22.2,>=1.20.3 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.26.3)\n",
+ "Requirement already satisfied: scipy>=1.2.0 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.12.0)\n",
+ "Requirement already satisfied: scikit-learn>=0.20.0 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.5.2)\n",
+ "Requirement already satisfied: joblib>=0.14 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.4.2)\n",
+ "Requirement already satisfied: decorator>=4.3.0 in /opt/conda/lib/python3.10/site-packages (from librosa) (5.1.1)\n",
+ "Requirement already satisfied: numba>=0.51.0 in /opt/conda/lib/python3.10/site-packages (from librosa) (0.60.0)\n",
+ "Requirement already satisfied: pooch>=1.1 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.8.2)\n",
+ "Requirement already satisfied: soxr>=0.3.2 in /opt/conda/lib/python3.10/site-packages (from librosa) (0.5.0.post1)\n",
+ "Requirement already satisfied: typing-extensions>=4.1.1 in /opt/conda/lib/python3.10/site-packages (from librosa) (4.9.0)\n",
+ "Requirement already satisfied: lazy-loader>=0.1 in /opt/conda/lib/python3.10/site-packages (from librosa) (0.4)\n",
+ "Requirement already satisfied: msgpack>=1.0 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.1.0)\n",
+ "Requirement already satisfied: cffi>=1.0 in /opt/conda/lib/python3.10/site-packages (from soundfile) (1.16.0)\n",
+ "Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from torchaudio) (2.2.0)\n",
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (3.13.1)\n",
+ "Requirement already satisfied: pyarrow>=15.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (18.1.0)\n",
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (0.3.8)\n",
+ "Requirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (2.2.3)\n",
+ "Requirement already satisfied: requests>=2.32.2 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (2.32.3)\n",
+ "Requirement already satisfied: tqdm>=4.66.3 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (4.67.1)\n",
+ "Requirement already satisfied: xxhash in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (3.5.0)\n",
+ "Requirement already satisfied: multiprocess<0.70.17 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (0.70.16)\n",
+ "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /opt/conda/lib/python3.10/site-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets[audio]) (2023.12.2)\n",
+ "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (3.11.10)\n",
+ "Requirement already satisfied: huggingface-hub>=0.23.0 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (0.26.5)\n",
+ "Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (23.1)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (6.0.1)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers[torch]) (2024.11.6)\n",
+ "Requirement already satisfied: tokenizers<0.22,>=0.21 in /opt/conda/lib/python3.10/site-packages (from transformers[torch]) (0.21.0)\n",
+ "Requirement already satisfied: safetensors>=0.4.1 in /opt/conda/lib/python3.10/site-packages (from transformers[torch]) (0.4.5)\n",
+ "Requirement already satisfied: accelerate>=0.26.0 in /opt/conda/lib/python3.10/site-packages (from transformers[torch]) (1.2.0)\n",
+ "Requirement already satisfied: numpy-minmax<1,>=0.3.0 in /opt/conda/lib/python3.10/site-packages (from audiomentations) (0.3.1)\n",
+ "Requirement already satisfied: numpy-rms<1,>=0.4.2 in /opt/conda/lib/python3.10/site-packages (from audiomentations) (0.4.2)\n",
+ "Requirement already satisfied: comm>=0.1.3 in /opt/conda/lib/python3.10/site-packages (from ipywidgets) (0.2.2)\n",
+ "Requirement already satisfied: ipython>=6.1.0 in /opt/conda/lib/python3.10/site-packages (from ipywidgets) (8.20.0)\n",
+ "Requirement already satisfied: traitlets>=4.3.1 in /opt/conda/lib/python3.10/site-packages (from ipywidgets) (5.7.1)\n",
+ "Requirement already satisfied: widgetsnbextension~=4.0.12 in /opt/conda/lib/python3.10/site-packages (from ipywidgets) (4.0.13)\n",
+ "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /opt/conda/lib/python3.10/site-packages (from ipywidgets) (3.0.13)\n",
+ "Requirement already satisfied: contourpy>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (1.3.1)\n",
+ "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (0.12.1)\n",
+ "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (4.55.3)\n",
+ "Requirement already satisfied: kiwisolver>=1.3.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (1.4.7)\n",
+ "Requirement already satisfied: pillow>=8 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (10.0.1)\n",
+ "Requirement already satisfied: pyparsing>=2.3.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (3.2.0)\n",
+ "Requirement already satisfied: python-dateutil>=2.7 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (2.9.0.post0)\n",
+ "Requirement already satisfied: absl-py>=0.4 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (2.1.0)\n",
+ "Requirement already satisfied: grpcio>=1.48.2 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (1.68.1)\n",
+ "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (3.7)\n",
+ "Requirement already satisfied: protobuf!=4.24.0,>=3.19.6 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (5.29.1)\n",
+ "Requirement already satisfied: setuptools>=41.0.0 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (68.2.2)\n",
+ "Requirement already satisfied: six>1.9 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (1.16.0)\n",
+ "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (0.7.2)\n",
+ "Requirement already satisfied: werkzeug>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (3.1.3)\n",
+ "Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from accelerate>=0.26.0->transformers[torch]) (5.9.0)\n",
+ "Requirement already satisfied: pycparser in /opt/conda/lib/python3.10/site-packages (from cffi>=1.0->soundfile) (2.21)\n",
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (2.4.4)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (1.3.1)\n",
+ "Requirement already satisfied: async-timeout<6.0,>=4.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (5.0.1)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (23.1.0)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (1.5.0)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (6.1.0)\n",
+ "Requirement already satisfied: propcache>=0.2.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (0.2.1)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (1.18.3)\n",
+ "Requirement already satisfied: jedi>=0.16 in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (0.18.1)\n",
+ "Requirement already satisfied: matplotlib-inline in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (0.1.6)\n",
+ "Requirement already satisfied: prompt-toolkit<3.1.0,>=3.0.41 in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (3.0.43)\n",
+ "Requirement already satisfied: pygments>=2.4.0 in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (2.15.1)\n",
+ "Requirement already satisfied: stack-data in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (0.2.0)\n",
+ "Requirement already satisfied: exceptiongroup in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (1.2.0)\n",
+ "Requirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (4.8.0)\n",
+ "Requirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /opt/conda/lib/python3.10/site-packages (from numba>=0.51.0->librosa) (0.43.0)\n",
+ "Requirement already satisfied: platformdirs>=2.5.0 in /opt/conda/lib/python3.10/site-packages (from pooch>=1.1->librosa) (3.10.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets[audio]) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets[audio]) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets[audio]) (1.26.18)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets[audio]) (2023.11.17)\n",
+ "Requirement already satisfied: threadpoolctl>=3.1.0 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=0.20.0->librosa) (3.5.0)\n",
+ "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->torchaudio) (1.12)\n",
+ "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->torchaudio) (3.1)\n",
+ "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->torchaudio) (3.1.2)\n",
+ "Requirement already satisfied: MarkupSafe>=2.1.1 in /opt/conda/lib/python3.10/site-packages (from werkzeug>=1.0.1->tensorboard) (2.1.3)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets[audio]) (2023.3.post1)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets[audio]) (2024.2)\n",
+ "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /opt/conda/lib/python3.10/site-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets) (0.8.3)\n",
+ "Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.10/site-packages (from pexpect>4.3->ipython>=6.1.0->ipywidgets) (0.7.0)\n",
+ "Requirement already satisfied: wcwidth in /opt/conda/lib/python3.10/site-packages (from prompt-toolkit<3.1.0,>=3.0.41->ipython>=6.1.0->ipywidgets) (0.2.5)\n",
+ "Requirement already satisfied: executing in /opt/conda/lib/python3.10/site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (0.8.3)\n",
+ "Requirement already satisfied: asttokens in /opt/conda/lib/python3.10/site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (2.0.5)\n",
+ "Requirement already satisfied: pure-eval in /opt/conda/lib/python3.10/site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (0.2.2)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->torchaudio) (1.3.0)\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "%pip install librosa soundfile torchaudio datasets[audio] transformers[torch] audiomentations evaluate ipywidgets matplotlib tensorboard"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torchaudio\n",
+ "import librosa\n",
+ "import soundfile as sf\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import matplotlib.pyplot as plt\n",
+ "import IPython.display as ipd\n",
+ "import datasets\n",
+ "import evaluate\n",
+ "from concurrent.futures import ProcessPoolExecutor\n",
+ "from transformers import ASTForAudioClassification, ASTFeatureExtractor, ASTConfig, TrainingArguments, Trainer\n",
+ "from audiomentations import Compose, AddGaussianSNR, GainTransition, Gain, ClippingDistortion, TimeStretch, PitchShift\n",
+ "from tqdm import tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODEL_DIR = \".\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "64d53eaca2f340179220cdb30d7ad7b0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Resolving data files: 0%| | 0/147 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f80dd84995604737aa93a60a5eddeda5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Resolving data files: 0%| | 0/147 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bf86d3b8bac44dfbbd956b1871af75f0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading dataset shards: 0%| | 0/115 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['audio', 'source', 'ai_generated'],\n",
+ " num_rows: 20000\n",
+ " })\n",
+ "})\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Load the dataset\n",
+ "ds = datasets.load_dataset(\"SleepyJesse/ai_music_large\")\n",
+ "# Resample the audio files to 16kHz\n",
+ "ds = ds.cast_column(\"audio\", datasets.Audio(sampling_rate=16000, mono=True))\n",
+ "print(ds)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cast the \"ai_generated\" column (boolean) to class labels (\"ai_generated\" or \"human\")\n",
+ "class_labels = datasets.ClassLabel(names=[\"human\", \"ai_generated\"])\n",
+ "labels = [1 if x else 0 for x in ds['train']['ai_generated']]\n",
+ "ds['train'] = ds['train'].add_column(\"labels\", labels, feature=class_labels)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Remove the \"ai_generated\" and \"source\" columns\n",
+ "ds[\"train\"] = ds[\"train\"].remove_columns(\"ai_generated\")\n",
+ "ds[\"train\"] = ds[\"train\"].remove_columns(\"source\")\n",
+ "\n",
+ "# Rename the \"audio\" column to \"input_values\" to match the expected input key for the processor\n",
+ "ds = ds.rename_column(\"audio\", \"input_values\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['input_values', 'labels'],\n",
+ " num_rows: 20000\n",
+ " })\n",
+ "})\n",
+ "{'input_values': {'path': '030312.mp3', 'array': array([ 0. , 0. , 0. , ..., -0.00048378,\n",
+ " -0.00049008, 0. ]), 'sampling_rate': 16000}, 'labels': 0}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(ds)\n",
+ "print(ds[\"train\"][0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_name = \"MIT/ast-finetuned-audioset-10-10-0.4593\" # Pre-trained AST model\n",
+ "feature_extractor = ASTFeatureExtractor.from_pretrained(model_name)\n",
+ "feature_extractor.max_length = 6000\n",
+ "model_input_name = feature_extractor.model_input_names[0]\n",
+ "sampling_rate = feature_extractor.sampling_rate"
+ ]
+ },
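Note on the cell above: `max_length` counts 10 ms spectrogram frames, so 6000 frames covers roughly 60 s of audio (the checkpoint default is 1024, about 10 s, which is why the position embeddings are re-initialized further down). A minimal shape sanity check, assuming the `feature_extractor` configured above:

```python
# Sketch: confirm the spectrogram shape produced for a 60 s clip.
import numpy as np

one_minute = np.zeros(16000 * 60, dtype=np.float32)  # 60 s of silence at 16 kHz
features = feature_extractor(one_minute, sampling_rate=16000, return_tensors="pt")
print(features["input_values"].shape)  # expected: torch.Size([1, 6000, 128])
```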
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define a function to preprocess the audio data\n",
+ "def preprocess_audio(batch):\n",
+ " wavs = [audio[\"array\"] for audio in batch[\"input_values\"]]\n",
+ " # inputs are spectrograms as torch.tensors now\n",
+ " inputs = feature_extractor(wavs, sampling_rate=sampling_rate, return_tensors=\"pt\")\n",
+ " del wavs\n",
+ "\n",
+ " output_batch = {model_input_name: inputs.get(model_input_name), \"labels\": list(batch[\"labels\"])}\n",
+ " return output_batch\n",
+ "\n",
+ "# Apply the preprocessing function to the dataset\n",
+ "ds[\"train\"].set_transform(preprocess_audio, output_all_columns=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create audio augmentations\n",
+ "audio_augmentations = Compose([\n",
+ " AddGaussianSNR(min_snr_db=10, max_snr_db=20),\n",
+ " Gain(min_gain_db=-6, max_gain_db=6),\n",
+ " GainTransition(min_gain_db=-6, max_gain_db=6, min_duration=0.01, max_duration=0.3, duration_unit=\"fraction\"),\n",
+ " ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=30, p=0.5),\n",
+ " TimeStretch(min_rate=0.8, max_rate=1.2),\n",
+ " PitchShift(min_semitones=-4, max_semitones=4),\n",
+ "], p=0.8, shuffle=True)"
+ ]
+ },
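The repeated `input samples dtype is np.float64` warnings in the training output further down come from feeding float64 arrays into this chain; audiomentations converts them internally. A minimal usage sketch (hypothetical waveform), casting to float32 up front to avoid that conversion:

```python
# Sketch: apply the augmentation chain to one waveform.
# Compose-level p=0.8 gates the whole chain; shuffle=True re-orders transforms per call.
import numpy as np

wav = np.random.uniform(-0.5, 0.5, 16000 * 5).astype(np.float32)  # hypothetical 5 s clip
augmented = audio_augmentations(wav, sample_rate=16000)
print(augmented.dtype, augmented.shape)
```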
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define the preprocessing function for the audio augmentations\n",
+ "def preprocess_audio_with_transforms(batch):\n",
+ " # we apply augmentations on each waveform\n",
+ " wavs = [audio_augmentations(audio[\"array\"], sample_rate=sampling_rate) for audio in batch[\"input_values\"]]\n",
+ " inputs = feature_extractor(wavs, sampling_rate=sampling_rate, return_tensors=\"pt\")\n",
+ " del wavs\n",
+ "\n",
+ " output_batch = {model_input_name: inputs.get(model_input_name), \"labels\": list(batch[\"labels\"])}\n",
+ " return output_batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Calculate values for normalization (mean and std) for the dataset (Only need to run this once per dataset)\n",
+ "# feature_extractor.do_normalize = False # Disable normalization\n",
+ "\n",
+ "# means = []\n",
+ "# stds = []\n",
+ "\n",
+ "# def calculate_mean_std(index):\n",
+ "# try:\n",
+ "# audio_input = ds[\"train\"][index][\"input_values\"]\n",
+ "# except Exception as e:\n",
+ "# print(f\"Error processing index {index}: {e}\")\n",
+ "# return None, None\n",
+ "# cur_mean = torch.mean(audio_input)\n",
+ "# cur_std = torch.std(audio_input)\n",
+ "# return cur_mean, cur_std\n",
+ "\n",
+ "# with ProcessPoolExecutor() as executor:\n",
+ "# results = list(tqdm(executor.map(calculate_mean_std, range(len(ds[\"train\"]))), total=len(ds[\"train\"])))\n",
+ "\n",
+ "# means, stds = zip(*results)\n",
+ "# means = [x.item() for x in means if x is not None]\n",
+ "# stds = [x.item() for x in stds if x is not None]\n",
+ "# feature_extractor.mean = torch.tensor(means).mean().item()\n",
+ "# feature_extractor.std = torch.tensor(stds).mean().item()\n",
+ "# feature_extractor.do_normalize = True # Enable normalization\n",
+ "\n",
+ "# print(f\"Mean: {feature_extractor.mean}\")\n",
+ "# print(f\"Std: {feature_extractor.std}\")\n",
+ "# print(\"Save these values for normalization if you're using the same dataset in the future.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Remove corrupted audio files (4481, 8603 in ai_music_large)\n",
+ "corrupted_audio_indices = [4481, 8603]\n",
+ "keep_indices = [i for i in range(len(ds[\"train\"])) if i not in corrupted_audio_indices]\n",
+ "ds[\"train\"] = ds[\"train\"].select(keep_indices, writer_batch_size=50)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set the normalization values in the feature extractor (the following values are for the ai_music_large dataset)\n",
+ "feature_extractor.mean = -4.855465888977051\n",
+ "feature_extractor.std = 3.2848217487335205"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mean: -4.855465888977051\n",
+ "Std: 3.2848217487335205\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"Mean: {feature_extractor.mean}\")\n",
+ "print(f\"Std: {feature_extractor.std}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['input_values', 'labels'],\n",
+ " num_rows: 15998\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['input_values', 'labels'],\n",
+ " num_rows: 4000\n",
+ " })\n",
+ "})\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Split the dataset\n",
+ "if \"test\" not in ds:\n",
+ " ds = ds[\"train\"].train_test_split(test_size=0.2, shuffle=True, seed=42, stratify_by_column=\"labels\")\n",
+ "\n",
+ "print(ds)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set transforms for the train and test sets\n",
+ "ds[\"train\"].set_transform(preprocess_audio_with_transforms, output_all_columns=False)\n",
+ "ds[\"test\"].set_transform(preprocess_audio, output_all_columns=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:\n",
+ "- audio_spectrogram_transformer.embeddings.position_embeddings: found shape torch.Size([1, 1214, 768]) in the checkpoint and torch.Size([1, 7190, 768]) in the model instantiated\n",
+ "- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([2]) in the model instantiated\n",
+ "- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Load config from the pre-trained model\n",
+ "config = ASTConfig.from_pretrained(model_name)\n",
+ "config.max_length = 6000\n",
+ "\n",
+ "# Update the config with the labels we have in the dataset\n",
+ "config.num_labels = len(ds[\"train\"].features[\"labels\"].names)\n",
+ "config.label2id = {name: id for id, name in enumerate(ds[\"train\"].features[\"labels\"].names)}\n",
+ "config.id2label = {id: name for name, id in config.label2id.items()}\n",
+ "\n",
+ "# Initialize the model\n",
+ "model = ASTForAudioClassification.from_pretrained(model_name, config=config, ignore_mismatched_sizes=True)\n",
+ "model.init_weights()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "ASTConfig {\n",
+ " \"architectures\": [\n",
+ " \"ASTForAudioClassification\"\n",
+ " ],\n",
+ " \"attention_probs_dropout_prob\": 0.0,\n",
+ " \"frequency_stride\": 10,\n",
+ " \"hidden_act\": \"gelu\",\n",
+ " \"hidden_dropout_prob\": 0.0,\n",
+ " \"hidden_size\": 768,\n",
+ " \"id2label\": {\n",
+ " \"0\": \"human\",\n",
+ " \"1\": \"ai_generated\"\n",
+ " },\n",
+ " \"initializer_range\": 0.02,\n",
+ " \"intermediate_size\": 3072,\n",
+ " \"label2id\": {\n",
+ " \"ai_generated\": 1,\n",
+ " \"human\": 0\n",
+ " },\n",
+ " \"layer_norm_eps\": 1e-12,\n",
+ " \"max_length\": 6000,\n",
+ " \"model_type\": \"audio-spectrogram-transformer\",\n",
+ " \"num_attention_heads\": 12,\n",
+ " \"num_hidden_layers\": 12,\n",
+ " \"num_mel_bins\": 128,\n",
+ " \"patch_size\": 16,\n",
+ " \"qkv_bias\": true,\n",
+ " \"time_stride\": 10,\n",
+ " \"torch_dtype\": \"float32\",\n",
+ " \"transformers_version\": \"4.47.0\"\n",
+ "}\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Configure the training arguments\n",
+ "training_args = TrainingArguments(\n",
+ " output_dir=MODEL_DIR + \"/out/ast_classifier_large_60\",\n",
+ " logging_dir=MODEL_DIR + \"/logs/ast_classifier_large_60\",\n",
+ " report_to=\"tensorboard\",\n",
+ " learning_rate=5e-5,\n",
+ " push_to_hub=False,\n",
+ " num_train_epochs=10,\n",
+ " per_device_train_batch_size=4,\n",
+ " eval_strategy=\"epoch\",\n",
+ " save_strategy=\"steps\",\n",
+ " eval_steps=1,\n",
+ " save_steps=1300,\n",
+ " logging_steps=10,\n",
+ " metric_for_best_model=\"accuracy\",\n",
+ " dataloader_num_workers=32,\n",
+ " dataloader_prefetch_factor=4,\n",
+ " dataloader_persistent_workers=True,\n",
+ ")"
+ ]
+ },
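Two details of the arguments above are worth flagging: with `eval_strategy="epoch"` the `eval_steps=1` value is ignored, and `metric_for_best_model` has no effect unless `load_best_model_at_end=True` is also set (which in turn requires matching eval/save strategies). A hedged sketch of that variant, not the author's configuration:

```python
# Sketch (assumption): keep and reload the best checkpoint by validation accuracy.
best_model_args = TrainingArguments(
    output_dir=MODEL_DIR + "/out/ast_classifier_large_60",
    eval_strategy="epoch",
    save_strategy="epoch",            # must match eval_strategy
    load_best_model_at_end=True,      # reload the best checkpoint after training
    metric_for_best_model="accuracy",
)
```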
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define evaluation metrics\n",
+ "accuracy = evaluate.load(\"accuracy\")\n",
+ "recall = evaluate.load(\"recall\")\n",
+ "precision = evaluate.load(\"precision\")\n",
+ "f1 = evaluate.load(\"f1\")\n",
+ "\n",
+ "average = \"macro\" if config.num_labels > 2 else \"binary\"\n",
+ "\n",
+ "def compute_metrics(eval_pred):\n",
+ " logits = eval_pred.predictions\n",
+ " predictions = np.argmax(logits, axis=-1)\n",
+ " metrics = accuracy.compute(predictions=predictions, references=eval_pred.label_ids)\n",
+ " metrics.update(precision.compute(predictions=predictions, references=eval_pred.label_ids, average=average))\n",
+ " metrics.update(recall.compute(predictions=predictions, references=eval_pred.label_ids, average=average))\n",
+ " metrics.update(f1.compute(predictions=predictions, references=eval_pred.label_ids, average=average))\n",
+ " return metrics"
+ ]
+ },
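A quick way to see what `compute_metrics` returns, using dummy binary logits (values chosen arbitrarily for illustration):

```python
# Sketch: exercise compute_metrics outside the Trainer.
from transformers import EvalPrediction

dummy_logits = np.array([[2.0, -1.0], [-0.5, 1.5], [0.1, 0.3], [3.0, -2.0]])
dummy_labels = np.array([0, 1, 0, 0])
print(compute_metrics(EvalPrediction(predictions=dummy_logits, label_ids=dummy_labels)))
# -> {'accuracy': 0.75, 'precision': 0.5, 'recall': 1.0, 'f1': 0.666...}
```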
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Initialize the Trainer\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " args=training_args,\n",
+ " train_dataset=ds[\"train\"],\n",
+ " eval_dataset=ds[\"test\"],\n",
+ " compute_metrics=compute_metrics,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# trainer.evaluate()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: \n",
+ "\tsave_steps: 1300 (from args) != 1000 (from trainer_state.json)\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " <div>\n",
+ " \n",
+ " <progress value='13340' max='13340' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ " [13340/13340 3:28:00, Epoch 10/10]\n",
+ " </div>\n",
+ " <table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: left;\">\n",
+ " <th>Epoch</th>\n",
+ " <th>Training Loss</th>\n",
+ " <th>Validation Loss</th>\n",
+ " <th>Accuracy</th>\n",
+ " <th>Precision</th>\n",
+ " <th>Recall</th>\n",
+ " <th>F1</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <td>6</td>\n",
+ " <td>0.081500</td>\n",
+ " <td>0.024286</td>\n",
+ " <td>0.992500</td>\n",
+ " <td>0.995473</td>\n",
+ " <td>0.989500</td>\n",
+ " <td>0.992477</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>7</td>\n",
+ " <td>0.050800</td>\n",
+ " <td>0.044788</td>\n",
+ " <td>0.987500</td>\n",
+ " <td>0.998976</td>\n",
+ " <td>0.976000</td>\n",
+ " <td>0.987355</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>8</td>\n",
+ " <td>0.100100</td>\n",
+ " <td>0.013277</td>\n",
+ " <td>0.996250</td>\n",
+ " <td>0.994519</td>\n",
+ " <td>0.998000</td>\n",
+ " <td>0.996257</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>9</td>\n",
+ " <td>0.094400</td>\n",
+ " <td>0.012132</td>\n",
+ " <td>0.996500</td>\n",
+ " <td>0.998995</td>\n",
+ " <td>0.994000</td>\n",
+ " <td>0.996491</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>10</td>\n",
+ " <td>0.070700</td>\n",
+ " <td>0.008876</td>\n",
+ " <td>0.997500</td>\n",
+ " <td>0.997500</td>\n",
+ " <td>0.997500</td>\n",
+ " <td>0.997500</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table><p>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:108: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=13340, training_loss=0.03427337026900739, metrics={'train_runtime': 12497.1749, 'train_samples_per_second': 12.801, 'train_steps_per_second': 1.067, 'total_flos': 6.692184281954304e+19, 'train_loss': 0.03427337026900739, 'epoch': 10.0})"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Train the model\n",
+ "trainer.train(resume_from_checkpoint=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer.save_model(\"./model-large-60-10epochs\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dataset({\n",
+ " features: ['input_values', 'file_name', 'labels'],\n",
+ " num_rows: 29\n",
+ "})\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "PredictionOutput(predictions=array([[-4.8372397 , 5.6649 ],\n",
+ " [-5.072206 , 5.9363923 ],\n",
+ " [-1.4489449 , 1.6803514 ],\n",
+ " [ 1.0182031 , -1.5616423 ],\n",
+ " [-5.576309 , 6.689727 ],\n",
+ " [-5.258823 , 6.159854 ],\n",
+ " [-3.868336 , 4.4103193 ],\n",
+ " [ 0.77387625, -0.931359 ],\n",
+ " [-5.196196 , 6.2262516 ],\n",
+ " [-5.6262727 , 6.7158213 ],\n",
+ " [-4.468895 , 5.226064 ],\n",
+ " [-3.8541725 , 4.4350986 ],\n",
+ " [-4.7433844 , 5.537499 ],\n",
+ " [-4.663436 , 5.4584312 ],\n",
+ " [-4.8346834 , 5.6357465 ],\n",
+ " [-5.145316 , 6.0559726 ],\n",
+ " [-4.8361654 , 5.622805 ],\n",
+ " [ 0.42984092, -0.8111682 ],\n",
+ " [-2.6803234 , 3.189293 ],\n",
+ " [-4.888308 , 5.798734 ],\n",
+ " [-4.5558224 , 5.330449 ],\n",
+ " [-1.834788 , 1.8180133 ],\n",
+ " [ 4.3429956 , -5.615567 ],\n",
+ " [ 5.5567174 , -6.775386 ],\n",
+ " [ 4.1883554 , -5.287336 ],\n",
+ " [ 3.2760735 , -4.2713284 ],\n",
+ " [ 5.487632 , -6.6286483 ],\n",
+ " [ 1.6245338 , -2.1124713 ],\n",
+ " [ 3.0620215 , -3.8812313 ]], dtype=float32), label_ids=array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1]), metrics={'test_loss': 2.323230028152466, 'test_accuracy': 0.6551724137931034, 'test_precision': 1.0, 'test_recall': 0.6551724137931034, 'test_f1': 0.7916666666666666, 'test_runtime': 40.2239, 'test_samples_per_second': 0.721, 'test_steps_per_second': 0.05})\n",
+ "File: /workspace/ai/MakeBestMusic1.wav, Predicted class: ai_generated\n",
+ "File: /workspace/ai/MakeBestMusic2.wav, Predicted class: ai_generated\n",
+ "File: /workspace/ai/MakeBestMusic3.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/MakeBestMusic4.mp3, Predicted class: human\n",
+ "File: /workspace/ai/MakeBestMusic5.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/MakeBestMusic6.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI1.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI10.mp3, Predicted class: human\n",
+ "File: /workspace/ai/TopMediaAI11.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI12.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI13.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI14.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI15.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI16.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI2.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI3.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI4.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI5.mp3, Predicted class: human\n",
+ "File: /workspace/ai/TopMediaAI6.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI7.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI8.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/TopMediaAI9.mp3, Predicted class: ai_generated\n",
+ "File: /workspace/ai/ai续写《大石碎胸口》,出乎意料地好听?! [BV1xi421Y7oh_p1].mp3, Predicted class: human\n",
+ "File: /workspace/ai/当suno AI续写Felis [BV1Wieie9E9V_p1].mp3, Predicted class: human\n",
+ "File: /workspace/ai/ai续写《大石碎胸口》,出乎意料地好听?! [BV1xi421Y7oh_p1]-from60s.mp3, Predicted class: human\n",
+ "File: /workspace/ai/当suno AI续写Felis [BV1Wieie9E9V_p1]-from60s.mp3, Predicted class: human\n",
+ "File: /workspace/ai/ai续写《大石碎胸口》,出乎意料地好听?! [BV1xi421Y7oh_p1]-from120s.mp3, Predicted class: human\n",
+ "File: /workspace/ai/当suno AI续写Felis [BV1Wieie9E9V_p1]-from120s.mp3, Predicted class: human\n",
+ "File: /workspace/ai/当suno AI续写Felis [BV1Wieie9E9V_p1]-from180s.mp3, Predicted class: human\n"
+ ]
+ }
+ ],
+ "source": [
+ "import glob\n",
+ "unseen_files = glob.glob(\"/workspace/ai/*\")\n",
+ "unseen_set = datasets.Dataset.from_dict({\"input_values\": unseen_files, \"file_name\": unseen_files}).cast_column(\"input_values\", datasets.Audio(sampling_rate=16000, mono=True))\n",
+ "unseen_set = unseen_set.add_column(name=\"labels\", column=[1 for _ in range(len(unseen_set))])\n",
+ "unseen_set.set_transform(preprocess_audio, output_all_columns=False)\n",
+ "print(unseen_set)\n",
+ "unseen_set_predictions = trainer.predict(unseen_set)\n",
+ "\n",
+ "print(unseen_set_predictions)\n",
+ "\n",
+ "# Map the predicted labels to the corresponding class names and file names\n",
+ "unseen_set.reset_format()\n",
+ "class_names = ds[\"train\"].features[\"labels\"].names\n",
+ "file_names = unseen_set[\"file_name\"]\n",
+ "predicted_class_ids = np.argmax(unseen_set_predictions.predictions, axis=-1)\n",
+ "predicted_class_names = [class_names[class_id] for class_id in predicted_class_ids]\n",
+ "\n",
+ "# Print the predicted class names and file names\n",
+ "for file_name, class_name in zip(file_names, predicted_class_names):\n",
+ " print(f\"File: {file_name}, Predicted class: {class_name}\")"
+ ]
+ },
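Since `trainer.save_model` above writes only the model (not the tuned feature extractor with its custom `max_length`, `mean`, and `std`), standalone inference needs the extractor reconstructed the same way. A hedged sketch with a placeholder input path:

```python
# Sketch: classify a single file with the saved checkpoint.
# "example.mp3" is a placeholder; `feature_extractor` is the one configured above.
import librosa
import torch

clf = ASTForAudioClassification.from_pretrained("./model-large-60-10epochs")
wav, _ = librosa.load("example.mp3", sr=16000, mono=True)
inputs = feature_extractor(wav, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = clf(**inputs).logits
print(clf.config.id2label[int(logits.argmax(-1))])  # "human" or "ai_generated"
```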
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }