Andrey Moskalenko commited on
Commit
3d38624
·
1 Parent(s): c94270d

Upload Train_fakenews_detector.ipynb

Browse files
Files changed (1) hide show
  1. Train_fakenews_detector.ipynb +1465 -0
Train_fakenews_detector.ipynb ADDED
@@ -0,0 +1,1465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Data Preparation"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "Я нашел три датасета на kaggle по классификации фейков. Они все на английском, поэтому для поддержки русскуязычных статей будем использовать специально обученную для перевода новостей модель wmt19-ru-en. \n",
15
+ "\n",
16
+ "Выбранные датасеты:\n",
17
+ "* https://www.kaggle.com/c/fake-news/data\n",
18
+ "* https://www.kaggle.com/c/fakenewskdd2020/data\n",
19
+ "* https://www.kaggle.com/c/classifying-the-fake-news/data"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 95,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import pandas as pd\n",
29
+ "\n",
30
+ "df1_train = pd.read_csv('./data1/train.csv')"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 96,
36
+ "metadata": {},
37
+ "outputs": [
38
+ {
39
+ "data": {
40
+ "text/html": [
41
+ "<div>\n",
42
+ "<style scoped>\n",
43
+ " .dataframe tbody tr th:only-of-type {\n",
44
+ " vertical-align: middle;\n",
45
+ " }\n",
46
+ "\n",
47
+ " .dataframe tbody tr th {\n",
48
+ " vertical-align: top;\n",
49
+ " }\n",
50
+ "\n",
51
+ " .dataframe thead th {\n",
52
+ " text-align: right;\n",
53
+ " }\n",
54
+ "</style>\n",
55
+ "<table border=\"1\" class=\"dataframe\">\n",
56
+ " <thead>\n",
57
+ " <tr style=\"text-align: right;\">\n",
58
+ " <th></th>\n",
59
+ " <th>id</th>\n",
60
+ " <th>title</th>\n",
61
+ " <th>author</th>\n",
62
+ " <th>text</th>\n",
63
+ " <th>label</th>\n",
64
+ " </tr>\n",
65
+ " </thead>\n",
66
+ " <tbody>\n",
67
+ " <tr>\n",
68
+ " <th>0</th>\n",
69
+ " <td>0</td>\n",
70
+ " <td>House Dem Aide: We Didn’t Even See Comey’s Let...</td>\n",
71
+ " <td>Darrell Lucus</td>\n",
72
+ " <td>House Dem Aide: We Didn’t Even See Comey’s Let...</td>\n",
73
+ " <td>1</td>\n",
74
+ " </tr>\n",
75
+ " <tr>\n",
76
+ " <th>1</th>\n",
77
+ " <td>1</td>\n",
78
+ " <td>FLYNN: Hillary Clinton, Big Woman on Campus - ...</td>\n",
79
+ " <td>Daniel J. Flynn</td>\n",
80
+ " <td>Ever get the feeling your life circles the rou...</td>\n",
81
+ " <td>0</td>\n",
82
+ " </tr>\n",
83
+ " <tr>\n",
84
+ " <th>2</th>\n",
85
+ " <td>2</td>\n",
86
+ " <td>Why the Truth Might Get You Fired</td>\n",
87
+ " <td>Consortiumnews.com</td>\n",
88
+ " <td>Why the Truth Might Get You Fired October 29, ...</td>\n",
89
+ " <td>1</td>\n",
90
+ " </tr>\n",
91
+ " <tr>\n",
92
+ " <th>3</th>\n",
93
+ " <td>3</td>\n",
94
+ " <td>15 Civilians Killed In Single US Airstrike Hav...</td>\n",
95
+ " <td>Jessica Purkiss</td>\n",
96
+ " <td>Videos 15 Civilians Killed In Single US Airstr...</td>\n",
97
+ " <td>1</td>\n",
98
+ " </tr>\n",
99
+ " <tr>\n",
100
+ " <th>4</th>\n",
101
+ " <td>4</td>\n",
102
+ " <td>Iranian woman jailed for fictional unpublished...</td>\n",
103
+ " <td>Howard Portnoy</td>\n",
104
+ " <td>Print \\nAn Iranian woman has been sentenced to...</td>\n",
105
+ " <td>1</td>\n",
106
+ " </tr>\n",
107
+ " <tr>\n",
108
+ " <th>...</th>\n",
109
+ " <td>...</td>\n",
110
+ " <td>...</td>\n",
111
+ " <td>...</td>\n",
112
+ " <td>...</td>\n",
113
+ " <td>...</td>\n",
114
+ " </tr>\n",
115
+ " <tr>\n",
116
+ " <th>20795</th>\n",
117
+ " <td>20795</td>\n",
118
+ " <td>Rapper T.I.: Trump a ’Poster Child For White S...</td>\n",
119
+ " <td>Jerome Hudson</td>\n",
120
+ " <td>Rapper T. I. unloaded on black celebrities who...</td>\n",
121
+ " <td>0</td>\n",
122
+ " </tr>\n",
123
+ " <tr>\n",
124
+ " <th>20796</th>\n",
125
+ " <td>20796</td>\n",
126
+ " <td>N.F.L. Playoffs: Schedule, Matchups and Odds -...</td>\n",
127
+ " <td>Benjamin Hoffman</td>\n",
128
+ " <td>When the Green Bay Packers lost to the Washing...</td>\n",
129
+ " <td>0</td>\n",
130
+ " </tr>\n",
131
+ " <tr>\n",
132
+ " <th>20797</th>\n",
133
+ " <td>20797</td>\n",
134
+ " <td>Macy’s Is Said to Receive Takeover Approach by...</td>\n",
135
+ " <td>Michael J. de la Merced and Rachel Abrams</td>\n",
136
+ " <td>The Macy’s of today grew from the union of sev...</td>\n",
137
+ " <td>0</td>\n",
138
+ " </tr>\n",
139
+ " <tr>\n",
140
+ " <th>20798</th>\n",
141
+ " <td>20798</td>\n",
142
+ " <td>NATO, Russia To Hold Parallel Exercises In Bal...</td>\n",
143
+ " <td>Alex Ansary</td>\n",
144
+ " <td>NATO, Russia To Hold Parallel Exercises In Bal...</td>\n",
145
+ " <td>1</td>\n",
146
+ " </tr>\n",
147
+ " <tr>\n",
148
+ " <th>20799</th>\n",
149
+ " <td>20799</td>\n",
150
+ " <td>What Keeps the F-35 Alive</td>\n",
151
+ " <td>David Swanson</td>\n",
152
+ " <td>David Swanson is an author, activist, journa...</td>\n",
153
+ " <td>1</td>\n",
154
+ " </tr>\n",
155
+ " </tbody>\n",
156
+ "</table>\n",
157
+ "<p>20800 rows × 5 columns</p>\n",
158
+ "</div>"
159
+ ],
160
+ "text/plain": [
161
+ " id title \\\n",
162
+ "0 0 House Dem Aide: We Didn’t Even See Comey’s Let... \n",
163
+ "1 1 FLYNN: Hillary Clinton, Big Woman on Campus - ... \n",
164
+ "2 2 Why the Truth Might Get You Fired \n",
165
+ "3 3 15 Civilians Killed In Single US Airstrike Hav... \n",
166
+ "4 4 Iranian woman jailed for fictional unpublished... \n",
167
+ "... ... ... \n",
168
+ "20795 20795 Rapper T.I.: Trump a ’Poster Child For White S... \n",
169
+ "20796 20796 N.F.L. Playoffs: Schedule, Matchups and Odds -... \n",
170
+ "20797 20797 Macy’s Is Said to Receive Takeover Approach by... \n",
171
+ "20798 20798 NATO, Russia To Hold Parallel Exercises In Bal... \n",
172
+ "20799 20799 What Keeps the F-35 Alive \n",
173
+ "\n",
174
+ " author \\\n",
175
+ "0 Darrell Lucus \n",
176
+ "1 Daniel J. Flynn \n",
177
+ "2 Consortiumnews.com \n",
178
+ "3 Jessica Purkiss \n",
179
+ "4 Howard Portnoy \n",
180
+ "... ... \n",
181
+ "20795 Jerome Hudson \n",
182
+ "20796 Benjamin Hoffman \n",
183
+ "20797 Michael J. de la Merced and Rachel Abrams \n",
184
+ "20798 Alex Ansary \n",
185
+ "20799 David Swanson \n",
186
+ "\n",
187
+ " text label \n",
188
+ "0 House Dem Aide: We Didn’t Even See Comey’s Let... 1 \n",
189
+ "1 Ever get the feeling your life circles the rou... 0 \n",
190
+ "2 Why the Truth Might Get You Fired October 29, ... 1 \n",
191
+ "3 Videos 15 Civilians Killed In Single US Airstr... 1 \n",
192
+ "4 Print \\nAn Iranian woman has been sentenced to... 1 \n",
193
+ "... ... ... \n",
194
+ "20795 Rapper T. I. unloaded on black celebrities who... 0 \n",
195
+ "20796 When the Green Bay Packers lost to the Washing... 0 \n",
196
+ "20797 The Macy’s of today grew from the union of sev... 0 \n",
197
+ "20798 NATO, Russia To Hold Parallel Exercises In Bal... 1 \n",
198
+ "20799 David Swanson is an author, activist, journa... 1 \n",
199
+ "\n",
200
+ "[20800 rows x 5 columns]"
201
+ ]
202
+ },
203
+ "execution_count": 96,
204
+ "metadata": {},
205
+ "output_type": "execute_result"
206
+ }
207
+ ],
208
+ "source": [
209
+ "df1_train"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": 97,
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": [
218
+ "df1_train['text'] = df1_train.apply(lambda x: str(x.title) + '. ' + str(x.text), axis=1)\n",
219
+ "df1_train = df1_train[['text', 'label']]"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 98,
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "df2_train = pd.read_csv('./data2/train.csv', sep='\\t')"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": 99,
234
+ "metadata": {},
235
+ "outputs": [],
236
+ "source": [
237
+ "# Битая строка\n",
238
+ "df2_train = df2_train.drop([1615])"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": 100,
244
+ "metadata": {},
245
+ "outputs": [
246
+ {
247
+ "data": {
248
+ "text/html": [
249
+ "<div>\n",
250
+ "<style scoped>\n",
251
+ " .dataframe tbody tr th:only-of-type {\n",
252
+ " vertical-align: middle;\n",
253
+ " }\n",
254
+ "\n",
255
+ " .dataframe tbody tr th {\n",
256
+ " vertical-align: top;\n",
257
+ " }\n",
258
+ "\n",
259
+ " .dataframe thead th {\n",
260
+ " text-align: right;\n",
261
+ " }\n",
262
+ "</style>\n",
263
+ "<table border=\"1\" class=\"dataframe\">\n",
264
+ " <thead>\n",
265
+ " <tr style=\"text-align: right;\">\n",
266
+ " <th></th>\n",
267
+ " <th>text</th>\n",
268
+ " <th>label</th>\n",
269
+ " </tr>\n",
270
+ " </thead>\n",
271
+ " <tbody>\n",
272
+ " <tr>\n",
273
+ " <th>0</th>\n",
274
+ " <td>Get the latest from TODAY Sign up for our news...</td>\n",
275
+ " <td>1</td>\n",
276
+ " </tr>\n",
277
+ " <tr>\n",
278
+ " <th>1</th>\n",
279
+ " <td>2d Conan On The Funeral Trump Will Be Invited...</td>\n",
280
+ " <td>1</td>\n",
281
+ " </tr>\n",
282
+ " <tr>\n",
283
+ " <th>2</th>\n",
284
+ " <td>It’s safe to say that Instagram Stories has fa...</td>\n",
285
+ " <td>0</td>\n",
286
+ " </tr>\n",
287
+ " <tr>\n",
288
+ " <th>3</th>\n",
289
+ " <td>Much like a certain Amazon goddess with a lass...</td>\n",
290
+ " <td>0</td>\n",
291
+ " </tr>\n",
292
+ " <tr>\n",
293
+ " <th>4</th>\n",
294
+ " <td>At a time when the perfect outfit is just one ...</td>\n",
295
+ " <td>0</td>\n",
296
+ " </tr>\n",
297
+ " <tr>\n",
298
+ " <th>...</th>\n",
299
+ " <td>...</td>\n",
300
+ " <td>...</td>\n",
301
+ " </tr>\n",
302
+ " <tr>\n",
303
+ " <th>4982</th>\n",
304
+ " <td>The storybook romance of WWE stars John Cena a...</td>\n",
305
+ " <td>0</td>\n",
306
+ " </tr>\n",
307
+ " <tr>\n",
308
+ " <th>4983</th>\n",
309
+ " <td>The actor told friends he’s responsible for en...</td>\n",
310
+ " <td>0</td>\n",
311
+ " </tr>\n",
312
+ " <tr>\n",
313
+ " <th>4984</th>\n",
314
+ " <td>Sarah Hyland is getting real. The Modern Fami...</td>\n",
315
+ " <td>0</td>\n",
316
+ " </tr>\n",
317
+ " <tr>\n",
318
+ " <th>4985</th>\n",
319
+ " <td>Production has been suspended on the sixth and...</td>\n",
320
+ " <td>0</td>\n",
321
+ " </tr>\n",
322
+ " <tr>\n",
323
+ " <th>4986</th>\n",
324
+ " <td>A jury ruled against Bill Cosby in his sexual ...</td>\n",
325
+ " <td>0</td>\n",
326
+ " </tr>\n",
327
+ " </tbody>\n",
328
+ "</table>\n",
329
+ "<p>4986 rows × 2 columns</p>\n",
330
+ "</div>"
331
+ ],
332
+ "text/plain": [
333
+ " text label\n",
334
+ "0 Get the latest from TODAY Sign up for our news... 1\n",
335
+ "1 2d Conan On The Funeral Trump Will Be Invited... 1\n",
336
+ "2 It’s safe to say that Instagram Stories has fa... 0\n",
337
+ "3 Much like a certain Amazon goddess with a lass... 0\n",
338
+ "4 At a time when the perfect outfit is just one ... 0\n",
339
+ "... ... ...\n",
340
+ "4982 The storybook romance of WWE stars John Cena a... 0\n",
341
+ "4983 The actor told friends he’s responsible for en... 0\n",
342
+ "4984 Sarah Hyland is getting real. The Modern Fami... 0\n",
343
+ "4985 Production has been suspended on the sixth and... 0\n",
344
+ "4986 A jury ruled against Bill Cosby in his sexual ... 0\n",
345
+ "\n",
346
+ "[4986 rows x 2 columns]"
347
+ ]
348
+ },
349
+ "execution_count": 100,
350
+ "metadata": {},
351
+ "output_type": "execute_result"
352
+ }
353
+ ],
354
+ "source": [
355
+ "df2_train"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": 104,
361
+ "metadata": {},
362
+ "outputs": [],
363
+ "source": [
364
+ "df3_train = pd.read_csv('./data3/training.csv')"
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "execution_count": 105,
370
+ "metadata": {},
371
+ "outputs": [],
372
+ "source": [
373
+ "df3_train['text'] = df3_train.apply(lambda x: str(x.title) + '. ' + str(x.text), axis=1)\n",
374
+ "df3_train = df3_train[['text', 'label']]"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": 106,
380
+ "metadata": {},
381
+ "outputs": [],
382
+ "source": [
383
+ "all_data_train = df1_train.append(df2_train).append(df3_train)\n",
384
+ "all_data_train.to_csv('./train.csv', index=False)"
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "markdown",
389
+ "metadata": {},
390
+ "source": [
391
+ "# Training"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": 1,
397
+ "metadata": {
398
+ "id": "zriTdjauH8iQ"
399
+ },
400
+ "outputs": [],
401
+ "source": [
402
+ "#!pip install transformers\n",
403
+ "import transformers"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": 2,
409
+ "metadata": {
410
+ "id": "TFh3upySL3XG"
411
+ },
412
+ "outputs": [],
413
+ "source": [
414
+ "from transformers import Trainer, TrainingArguments, LineByLineTextDataset"
415
+ ]
416
+ },
417
+ {
418
+ "cell_type": "code",
419
+ "execution_count": 3,
420
+ "metadata": {
421
+ "id": "H2Ym6YhyNfON"
422
+ },
423
+ "outputs": [],
424
+ "source": [
425
+ "import pandas as pd"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": 4,
431
+ "metadata": {
432
+ "id": "ueRyDnvgNgpW"
433
+ },
434
+ "outputs": [],
435
+ "source": [
436
+ "from datasets import Dataset"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 5,
442
+ "metadata": {
443
+ "id": "HVBCtqyjNhLn"
444
+ },
445
+ "outputs": [],
446
+ "source": [
447
+ "df = pd.read_csv('./train.csv')"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": 6,
453
+ "metadata": {
454
+ "colab": {
455
+ "base_uri": "https://localhost:8080/",
456
+ "height": 424
457
+ },
458
+ "id": "f7j8fEl1Nogb",
459
+ "outputId": "3b5b13a0-4c34-412c-9718-5b0decb855cc"
460
+ },
461
+ "outputs": [
462
+ {
463
+ "data": {
464
+ "text/html": [
465
+ "<div>\n",
466
+ "<style scoped>\n",
467
+ " .dataframe tbody tr th:only-of-type {\n",
468
+ " vertical-align: middle;\n",
469
+ " }\n",
470
+ "\n",
471
+ " .dataframe tbody tr th {\n",
472
+ " vertical-align: top;\n",
473
+ " }\n",
474
+ "\n",
475
+ " .dataframe thead th {\n",
476
+ " text-align: right;\n",
477
+ " }\n",
478
+ "</style>\n",
479
+ "<table border=\"1\" class=\"dataframe\">\n",
480
+ " <thead>\n",
481
+ " <tr style=\"text-align: right;\">\n",
482
+ " <th></th>\n",
483
+ " <th>text</th>\n",
484
+ " <th>label</th>\n",
485
+ " </tr>\n",
486
+ " </thead>\n",
487
+ " <tbody>\n",
488
+ " <tr>\n",
489
+ " <th>0</th>\n",
490
+ " <td>House Dem Aide: We Didn’t Even See Comey’s Let...</td>\n",
491
+ " <td>1</td>\n",
492
+ " </tr>\n",
493
+ " <tr>\n",
494
+ " <th>1</th>\n",
495
+ " <td>FLYNN: Hillary Clinton, Big Woman on Campus - ...</td>\n",
496
+ " <td>0</td>\n",
497
+ " </tr>\n",
498
+ " <tr>\n",
499
+ " <th>2</th>\n",
500
+ " <td>Why the Truth Might Get You Fired.Why the Trut...</td>\n",
501
+ " <td>1</td>\n",
502
+ " </tr>\n",
503
+ " <tr>\n",
504
+ " <th>3</th>\n",
505
+ " <td>15 Civilians Killed In Single US Airstrike Hav...</td>\n",
506
+ " <td>1</td>\n",
507
+ " </tr>\n",
508
+ " <tr>\n",
509
+ " <th>4</th>\n",
510
+ " <td>Iranian woman jailed for fictional unpublished...</td>\n",
511
+ " <td>1</td>\n",
512
+ " </tr>\n",
513
+ " <tr>\n",
514
+ " <th>...</th>\n",
515
+ " <td>...</td>\n",
516
+ " <td>...</td>\n",
517
+ " </tr>\n",
518
+ " <tr>\n",
519
+ " <th>57209</th>\n",
520
+ " <td>CHICAGO TRUMP RALLY CANCELLED: Radicals And BL...</td>\n",
521
+ " <td>1</td>\n",
522
+ " </tr>\n",
523
+ " <tr>\n",
524
+ " <th>57210</th>\n",
525
+ " <td>Trump supports completion of Dakota Access Pip...</td>\n",
526
+ " <td>0</td>\n",
527
+ " </tr>\n",
528
+ " <tr>\n",
529
+ " <th>57211</th>\n",
530
+ " <td>Obama Can’t Stop Winning As New Jobs Report S...</td>\n",
531
+ " <td>1</td>\n",
532
+ " </tr>\n",
533
+ " <tr>\n",
534
+ " <th>57212</th>\n",
535
+ " <td>Turkey bank regulator dismisses 'rumors' after...</td>\n",
536
+ " <td>0</td>\n",
537
+ " </tr>\n",
538
+ " <tr>\n",
539
+ " <th>57213</th>\n",
540
+ " <td>California mayors ask for governor's support f...</td>\n",
541
+ " <td>0</td>\n",
542
+ " </tr>\n",
543
+ " </tbody>\n",
544
+ "</table>\n",
545
+ "<p>57214 rows × 2 columns</p>\n",
546
+ "</div>"
547
+ ],
548
+ "text/plain": [
549
+ " text label\n",
550
+ "0 House Dem Aide: We Didn’t Even See Comey’s Let... 1\n",
551
+ "1 FLYNN: Hillary Clinton, Big Woman on Campus - ... 0\n",
552
+ "2 Why the Truth Might Get You Fired.Why the Trut... 1\n",
553
+ "3 15 Civilians Killed In Single US Airstrike Hav... 1\n",
554
+ "4 Iranian woman jailed for fictional unpublished... 1\n",
555
+ "... ... ...\n",
556
+ "57209 CHICAGO TRUMP RALLY CANCELLED: Radicals And BL... 1\n",
557
+ "57210 Trump supports completion of Dakota Access Pip... 0\n",
558
+ "57211 Obama Can’t Stop Winning As New Jobs Report S... 1\n",
559
+ "57212 Turkey bank regulator dismisses 'rumors' after... 0\n",
560
+ "57213 California mayors ask for governor's support f... 0\n",
561
+ "\n",
562
+ "[57214 rows x 2 columns]"
563
+ ]
564
+ },
565
+ "execution_count": 6,
566
+ "metadata": {},
567
+ "output_type": "execute_result"
568
+ }
569
+ ],
570
+ "source": [
571
+ "df"
572
+ ]
573
+ },
574
+ {
575
+ "cell_type": "code",
576
+ "execution_count": 7,
577
+ "metadata": {
578
+ "id": "L0ET6Z83Pcxu"
579
+ },
580
+ "outputs": [],
581
+ "source": [
582
+ "df['labels'] = df['label']"
583
+ ]
584
+ },
585
+ {
586
+ "cell_type": "code",
587
+ "execution_count": 8,
588
+ "metadata": {
589
+ "id": "39Zv6HBJPgEt"
590
+ },
591
+ "outputs": [],
592
+ "source": [
593
+ "df = df[['text', 'labels']]"
594
+ ]
595
+ },
596
+ {
597
+ "cell_type": "code",
598
+ "execution_count": 9,
599
+ "metadata": {
600
+ "id": "bPGVPY17NI7x"
601
+ },
602
+ "outputs": [],
603
+ "source": [
604
+ "dataset = Dataset.from_pandas(df)"
605
+ ]
606
+ },
607
+ {
608
+ "cell_type": "code",
609
+ "execution_count": 10,
610
+ "metadata": {
611
+ "colab": {
612
+ "base_uri": "https://localhost:8080/"
613
+ },
614
+ "id": "3LTGwWrINmZq",
615
+ "outputId": "177d8749-68cf-4f81-a91b-1097bf155478"
616
+ },
617
+ "outputs": [
618
+ {
619
+ "data": {
620
+ "text/plain": [
621
+ "Dataset({\n",
622
+ " features: ['text', 'labels'],\n",
623
+ " num_rows: 57214\n",
624
+ "})"
625
+ ]
626
+ },
627
+ "execution_count": 10,
628
+ "metadata": {},
629
+ "output_type": "execute_result"
630
+ }
631
+ ],
632
+ "source": [
633
+ "dataset"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "code",
638
+ "execution_count": 11,
639
+ "metadata": {
640
+ "colab": {
641
+ "base_uri": "https://localhost:8080/"
642
+ },
643
+ "id": "3DrWrMiDd7e-",
644
+ "outputId": "d331ebe6-5ed4-4fef-8a8d-41d25ed4b638"
645
+ },
646
+ "outputs": [],
647
+ "source": [
648
+ "import torch\n",
649
+ "from transformers import AutoTokenizer, AutoModel, pipeline\n",
650
+ "\n",
651
+ "model_name = 'distilbert-base-uncased-finetuned-sst-2-english'\n",
652
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)"
653
+ ]
654
+ },
655
+ {
656
+ "cell_type": "code",
657
+ "execution_count": 12,
658
+ "metadata": {
659
+ "id": "dRJOO2c5PT3V"
660
+ },
661
+ "outputs": [],
662
+ "source": [
663
+ "def preprocess_function(examples):\n",
664
+ " return tokenizer(examples[\"text\"], padding=True, truncation=True)"
665
+ ]
666
+ },
667
+ {
668
+ "cell_type": "code",
669
+ "execution_count": 13,
670
+ "metadata": {
671
+ "colab": {
672
+ "base_uri": "https://localhost:8080/",
673
+ "height": 49,
674
+ "referenced_widgets": [
675
+ "5b49dc833234406da3da7435b9045fd2",
676
+ "300b70ed57dd493997afb0b3f25f4245",
677
+ "c03cc68b079c4e23b339e9de5ba38d29",
678
+ "57c3794731c84c42bb49618482b6b8cc",
679
+ "e306828f6d7444ddafce604e9a170467",
680
+ "9e11898bc51e483d91301387099368a4",
681
+ "a43574fa5fdf47ba9d5598b2b31f2082",
682
+ "482bae742d2a461cad525888e6ee8b91",
683
+ "e9c56275d73545a6961efe5704308ede",
684
+ "d604380b5e444f62ad36c4598230c561",
685
+ "c52ad745acb3423494b4ea5af5a934c7"
686
+ ]
687
+ },
688
+ "id": "hCxs-HasPQ7s",
689
+ "outputId": "be4f8483-316c-4677-f804-12c78f358fac"
690
+ },
691
+ "outputs": [
692
+ {
693
+ "data": {
694
+ "application/vnd.jupyter.widget-view+json": {
695
+ "model_id": "67689f0c8fb842b2969c4fc584fa3a4b",
696
+ "version_major": 2,
697
+ "version_minor": 0
698
+ },
699
+ "text/plain": [
700
+ " 0%| | 0/58 [00:00<?, ?ba/s]"
701
+ ]
702
+ },
703
+ "metadata": {},
704
+ "output_type": "display_data"
705
+ }
706
+ ],
707
+ "source": [
708
+ "dataset = dataset.map(preprocess_function, batched=True)"
709
+ ]
710
+ },
711
+ {
712
+ "cell_type": "code",
713
+ "execution_count": 14,
714
+ "metadata": {},
715
+ "outputs": [],
716
+ "source": [
717
+ "dataset_splitted = dataset.shuffle(1337).train_test_split(0.1)"
718
+ ]
719
+ },
720
+ {
721
+ "cell_type": "code",
722
+ "execution_count": 15,
723
+ "metadata": {},
724
+ "outputs": [
725
+ {
726
+ "data": {
727
+ "text/plain": [
728
+ "DatasetDict({\n",
729
+ " train: Dataset({\n",
730
+ " features: ['text', 'labels', 'input_ids', 'attention_mask'],\n",
731
+ " num_rows: 51492\n",
732
+ " })\n",
733
+ " test: Dataset({\n",
734
+ " features: ['text', 'labels', 'input_ids', 'attention_mask'],\n",
735
+ " num_rows: 5722\n",
736
+ " })\n",
737
+ "})"
738
+ ]
739
+ },
740
+ "execution_count": 15,
741
+ "metadata": {},
742
+ "output_type": "execute_result"
743
+ }
744
+ ],
745
+ "source": [
746
+ "dataset_splitted"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "code",
751
+ "execution_count": 16,
752
+ "metadata": {
753
+ "id": "NyHknkwcYi6L"
754
+ },
755
+ "outputs": [],
756
+ "source": [
757
+ "from transformers import AutoModelForSequenceClassification"
758
+ ]
759
+ },
760
+ {
761
+ "cell_type": "code",
762
+ "execution_count": 23,
763
+ "metadata": {
764
+ "colab": {
765
+ "base_uri": "https://localhost:8080/"
766
+ },
767
+ "id": "gv_fYzmEYlUm",
768
+ "outputId": "7a97df03-8f7b-4d54-f8d7-6a6b71d4c8c4"
769
+ },
770
+ "outputs": [
771
+ {
772
+ "name": "stderr",
773
+ "output_type": "stream",
774
+ "text": [
775
+ "loading configuration file https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json from cache at C:\\Users\\andry/.cache\\huggingface\\transformers\\4e60bb8efad3d4b7dc9969bf204947c185166a0a3cf37ddb6f481a876a3777b5.9f8326d0b7697c7fd57366cdde57032f46bc10e37ae81cb7eb564d66d23ec96b\n",
776
+ "Model config DistilBertConfig {\n",
777
+ " \"_name_or_path\": \"distilbert-base-uncased-finetuned-sst-2-english\",\n",
778
+ " \"activation\": \"gelu\",\n",
779
+ " \"architectures\": [\n",
780
+ " \"DistilBertForSequenceClassification\"\n",
781
+ " ],\n",
782
+ " \"attention_dropout\": 0.1,\n",
783
+ " \"dim\": 768,\n",
784
+ " \"dropout\": 0.1,\n",
785
+ " \"finetuning_task\": \"sst-2\",\n",
786
+ " \"hidden_dim\": 3072,\n",
787
+ " \"id2label\": {\n",
788
+ " \"0\": \"NEGATIVE\",\n",
789
+ " \"1\": \"POSITIVE\"\n",
790
+ " },\n",
791
+ " \"initializer_range\": 0.02,\n",
792
+ " \"label2id\": {\n",
793
+ " \"NEGATIVE\": 0,\n",
794
+ " \"POSITIVE\": 1\n",
795
+ " },\n",
796
+ " \"max_position_embeddings\": 512,\n",
797
+ " \"model_type\": \"distilbert\",\n",
798
+ " \"n_heads\": 12,\n",
799
+ " \"n_layers\": 6,\n",
800
+ " \"output_past\": true,\n",
801
+ " \"pad_token_id\": 0,\n",
802
+ " \"qa_dropout\": 0.1,\n",
803
+ " \"seq_classif_dropout\": 0.2,\n",
804
+ " \"sinusoidal_pos_embds\": false,\n",
805
+ " \"tie_weights_\": true,\n",
806
+ " \"transformers_version\": \"4.17.0\",\n",
807
+ " \"vocab_size\": 30522\n",
808
+ "}\n",
809
+ "\n",
810
+ "loading weights file https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/pytorch_model.bin from cache at C:\\Users\\andry/.cache\\huggingface\\transformers\\8d04c767d9d4c14d929ce7ad8e067b80c74dbdb212ef4c3fb743db4ee109fae0.9d268a35da669ead745c44d369dc9948b408da5010c6bac414414a7e33d5748c\n",
811
+ "All model checkpoint weights were used when initializing DistilBertForSequenceClassification.\n",
812
+ "\n",
813
+ "All the weights of DistilBertForSequenceClassification were initialized from the model checkpoint at distilbert-base-uncased-finetuned-sst-2-english.\n",
814
+ "If your task is similar to the task the model of the checkpoint was trained on, you can already use DistilBertForSequenceClassification for predictions without further training.\n"
815
+ ]
816
+ }
817
+ ],
818
+ "source": [
819
+ "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)"
820
+ ]
821
+ },
822
+ {
823
+ "cell_type": "code",
824
+ "execution_count": 24,
825
+ "metadata": {
826
+ "id": "YqcdtMXZelbm"
827
+ },
828
+ "outputs": [],
829
+ "source": [
830
+ "for name, param in model.named_parameters():\n",
831
+ " if name in ['classifier.weight', 'classifier.bias']:\n",
832
+ " param.requires_grad = True\n",
833
+ " else:\n",
834
+ " param.requires_grad = False"
835
+ ]
836
+ },
837
+ {
838
+ "cell_type": "code",
839
+ "execution_count": 25,
840
+ "metadata": {},
841
+ "outputs": [],
842
+ "source": [
843
+ "from sklearn.metrics import accuracy_score\n",
844
+ "\n",
845
+ "def compute_metrics(pred):\n",
846
+ " labels = pred.label_ids\n",
847
+ " preds = pred.predictions.argmax(-1)\n",
848
+ " acc = accuracy_score(labels, preds)\n",
849
+ " return {'accuracy': acc}"
850
+ ]
851
+ },
852
+ {
853
+ "cell_type": "code",
854
+ "execution_count": 26,
855
+ "metadata": {
856
+ "colab": {
857
+ "base_uri": "https://localhost:8080/",
858
+ "height": 608
859
+ },
860
+ "id": "DkBWiEiyIgnV",
861
+ "outputId": "07f58180-8005-4f7e-fd72-62a5d2c78717",
862
+ "scrolled": false
863
+ },
864
+ "outputs": [
865
+ {
866
+ "name": "stderr",
867
+ "output_type": "stream",
868
+ "text": [
869
+ "PyTorch: setting up devices\n",
870
+ "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n",
871
+ "The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
872
+ "***** Running training *****\n",
873
+ " Num examples = 51492\n",
874
+ " Num Epochs = 10\n",
875
+ " Instantaneous batch size per device = 64\n",
876
+ " Total train batch size (w. parallel, distributed & accumulation) = 64\n",
877
+ " Gradient Accumulation steps = 1\n",
878
+ " Total optimization steps = 8050\n"
879
+ ]
880
+ },
881
+ {
882
+ "data": {
883
+ "text/html": [
884
+ "\n",
885
+ " <div>\n",
886
+ " \n",
887
+ " <progress value='8050' max='8050' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
888
+ " [8050/8050 1:31:55, Epoch 10/10]\n",
889
+ " </div>\n",
890
+ " <table border=\"1\" class=\"dataframe\">\n",
891
+ " <thead>\n",
892
+ " <tr style=\"text-align: left;\">\n",
893
+ " <th>Epoch</th>\n",
894
+ " <th>Training Loss</th>\n",
895
+ " <th>Validation Loss</th>\n",
896
+ " <th>Accuracy</th>\n",
897
+ " </tr>\n",
898
+ " </thead>\n",
899
+ " <tbody>\n",
900
+ " <tr>\n",
901
+ " <td>1</td>\n",
902
+ " <td>1.124500</td>\n",
903
+ " <td>0.655170</td>\n",
904
+ " <td>0.631423</td>\n",
905
+ " </tr>\n",
906
+ " <tr>\n",
907
+ " <td>2</td>\n",
908
+ " <td>0.635900</td>\n",
909
+ " <td>0.616928</td>\n",
910
+ " <td>0.696435</td>\n",
911
+ " </tr>\n",
912
+ " <tr>\n",
913
+ " <td>3</td>\n",
914
+ " <td>0.617400</td>\n",
915
+ " <td>0.592879</td>\n",
916
+ " <td>0.727019</td>\n",
917
+ " </tr>\n",
918
+ " <tr>\n",
919
+ " <td>4</td>\n",
920
+ " <td>0.591200</td>\n",
921
+ " <td>0.577941</td>\n",
922
+ " <td>0.734533</td>\n",
923
+ " </tr>\n",
924
+ " <tr>\n",
925
+ " <td>5</td>\n",
926
+ " <td>0.577100</td>\n",
927
+ " <td>0.564665</td>\n",
928
+ " <td>0.747466</td>\n",
929
+ " </tr>\n",
930
+ " <tr>\n",
931
+ " <td>6</td>\n",
932
+ " <td>0.569300</td>\n",
933
+ " <td>0.556096</td>\n",
934
+ " <td>0.749913</td>\n",
935
+ " </tr>\n",
936
+ " <tr>\n",
937
+ " <td>7</td>\n",
938
+ " <td>0.563200</td>\n",
939
+ " <td>0.551389</td>\n",
940
+ " <td>0.755330</td>\n",
941
+ " </tr>\n",
942
+ " <tr>\n",
943
+ " <td>8</td>\n",
944
+ " <td>0.559900</td>\n",
945
+ " <td>0.546756</td>\n",
946
+ " <td>0.754981</td>\n",
947
+ " </tr>\n",
948
+ " <tr>\n",
949
+ " <td>9</td>\n",
950
+ " <td>0.554800</td>\n",
951
+ " <td>0.544496</td>\n",
952
+ " <td>0.759000</td>\n",
953
+ " </tr>\n",
954
+ " <tr>\n",
955
+ " <td>10</td>\n",
956
+ " <td>0.554000</td>\n",
957
+ " <td>0.543604</td>\n",
958
+ " <td>0.760398</td>\n",
959
+ " </tr>\n",
960
+ " </tbody>\n",
961
+ "</table><p>"
962
+ ],
963
+ "text/plain": [
964
+ "<IPython.core.display.HTML object>"
965
+ ]
966
+ },
967
+ "metadata": {},
968
+ "output_type": "display_data"
969
+ },
970
+ {
971
+ "name": "stderr",
972
+ "output_type": "stream",
973
+ "text": [
974
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
975
+ "***** Running Evaluation *****\n",
976
+ " Num examples = 5722\n",
977
+ " Batch size = 64\n",
978
+ "Saving model checkpoint to ./my_saved_model\\checkpoint-805\n",
979
+ "Configuration saved in ./my_saved_model\\checkpoint-805\\config.json\n",
980
+ "Model weights saved in ./my_saved_model\\checkpoint-805\\pytorch_model.bin\n",
981
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
982
+ "***** Running Evaluation *****\n",
983
+ " Num examples = 5722\n",
984
+ " Batch size = 64\n",
985
+ "Saving model checkpoint to ./my_saved_model\\checkpoint-1610\n",
986
+ "Configuration saved in ./my_saved_model\\checkpoint-1610\\config.json\n",
987
+ "Model weights saved in ./my_saved_model\\checkpoint-1610\\pytorch_model.bin\n",
988
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
989
+ "***** Running Evaluation *****\n",
990
+ " Num examples = 5722\n",
991
+ " Batch size = 64\n",
992
+ "Saving model checkpoint to ./my_saved_model\\checkpoint-2415\n",
993
+ "Configuration saved in ./my_saved_model\\checkpoint-2415\\config.json\n",
994
+ "Model weights saved in ./my_saved_model\\checkpoint-2415\\pytorch_model.bin\n",
995
+ "Deleting older checkpoint [my_saved_model\\checkpoint-805] due to args.save_total_limit\n",
996
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
997
+ "***** Running Evaluation *****\n",
998
+ " Num examples = 5722\n",
999
+ " Batch size = 64\n",
1000
+ "Saving model checkpoint to ./my_saved_model\\checkpoint-3220\n",
1001
+ "Configuration saved in ./my_saved_model\\checkpoint-3220\\config.json\n",
1002
+ "Model weights saved in ./my_saved_model\\checkpoint-3220\\pytorch_model.bin\n",
1003
+ "Deleting older checkpoint [my_saved_model\\checkpoint-1610] due to args.save_total_limit\n",
1004
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
1005
+ "***** Running Evaluation *****\n",
1006
+ " Num examples = 5722\n",
1007
+ " Batch size = 64\n",
1008
+ "Saving model checkpoint to ./my_saved_model\\checkpoint-4025\n",
1009
+ "Configuration saved in ./my_saved_model\\checkpoint-4025\\config.json\n",
1010
+ "Model weights saved in ./my_saved_model\\checkpoint-4025\\pytorch_model.bin\n",
1011
+ "Deleting older checkpoint [my_saved_model\\checkpoint-2415] due to args.save_total_limit\n",
1012
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
1013
+ "***** Running Evaluation *****\n",
1014
+ " Num examples = 5722\n",
1015
+ " Batch size = 64\n",
1016
+ "Saving model checkpoint to ./my_saved_model\\checkpoint-4830\n",
1017
+ "Configuration saved in ./my_saved_model\\checkpoint-4830\\config.json\n",
1018
+ "Model weights saved in ./my_saved_model\\checkpoint-4830\\pytorch_model.bin\n",
1019
+ "Deleting older checkpoint [my_saved_model\\checkpoint-3220] due to args.save_total_limit\n",
1020
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
1021
+ "***** Running Evaluation *****\n",
1022
+ " Num examples = 5722\n",
1023
+ " Batch size = 64\n",
1024
+ "Saving model checkpoint to ./my_saved_model\\checkpoint-5635\n",
1025
+ "Configuration saved in ./my_saved_model\\checkpoint-5635\\config.json\n",
1026
+ "Model weights saved in ./my_saved_model\\checkpoint-5635\\pytorch_model.bin\n",
1027
+ "Deleting older checkpoint [my_saved_model\\checkpoint-4025] due to args.save_total_limit\n",
1028
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
1029
+ "***** Running Evaluation *****\n",
1030
+ " Num examples = 5722\n",
1031
+ " Batch size = 64\n",
1032
+ "Saving model checkpoint to ./my_saved_model\\checkpoint-6440\n",
1033
+ "Configuration saved in ./my_saved_model\\checkpoint-6440\\config.json\n",
1034
+ "Model weights saved in ./my_saved_model\\checkpoint-6440\\pytorch_model.bin\n",
1035
+ "Deleting older checkpoint [my_saved_model\\checkpoint-4830] due to args.save_total_limit\n",
1036
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
1037
+ "***** Running Evaluation *****\n",
1038
+ " Num examples = 5722\n",
1039
+ " Batch size = 64\n",
1040
+ "Saving model checkpoint to ./my_saved_model\\checkpoint-7245\n",
1041
+ "Configuration saved in ./my_saved_model\\checkpoint-7245\\config.json\n",
1042
+ "Model weights saved in ./my_saved_model\\checkpoint-7245\\pytorch_model.bin\n",
1043
+ "Deleting older checkpoint [my_saved_model\\checkpoint-5635] due to args.save_total_limit\n",
1044
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
1045
+ "***** Running Evaluation *****\n",
1046
+ " Num examples = 5722\n",
1047
+ " Batch size = 64\n",
1048
+ "Saving model checkpoint to ./my_saved_model\\checkpoint-8050\n",
1049
+ "Configuration saved in ./my_saved_model\\checkpoint-8050\\config.json\n",
1050
+ "Model weights saved in ./my_saved_model\\checkpoint-8050\\pytorch_model.bin\n",
1051
+ "Deleting older checkpoint [my_saved_model\\checkpoint-6440] due to args.save_total_limit\n",
1052
+ "\n",
1053
+ "\n",
1054
+ "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
1055
+ "\n",
1056
+ "\n",
1057
+ "Loading best model from ./my_saved_model\\checkpoint-8050 (score: 0.543603777885437).\n"
1058
+ ]
1059
+ },
1060
+ {
1061
+ "data": {
1062
+ "text/plain": [
1063
+ "TrainOutput(global_step=8050, training_loss=0.6166538418598057, metrics={'train_runtime': 5516.6092, 'train_samples_per_second': 93.34, 'train_steps_per_second': 1.459, 'total_flos': 6.821011291594752e+16, 'train_loss': 0.6166538418598057, 'epoch': 10.0})"
1064
+ ]
1065
+ },
1066
+ "execution_count": 26,
1067
+ "metadata": {},
1068
+ "output_type": "execute_result"
1069
+ }
1070
+ ],
1071
+ "source": [
1072
+ "from transformers import Trainer, TrainingArguments\n",
1073
+ "\n",
1074
+ "trainer = Trainer(\n",
1075
+ " model=model, train_dataset=dataset_splitted['train'], \n",
1076
+ " eval_dataset=dataset_splitted['test'],\n",
1077
+ " compute_metrics=compute_metrics,\n",
1078
+ " args=TrainingArguments(\n",
1079
+ " load_best_model_at_end=True,\n",
1080
+ " output_dir=\"./my_saved_model\", overwrite_output_dir=True,\n",
1081
+ " num_train_epochs=10, per_device_train_batch_size=64, \n",
1082
+ " per_device_eval_batch_size=64,\n",
1083
+ " evaluation_strategy = \"epoch\",\n",
1084
+ " save_strategy = \"epoch\",\n",
1085
+ " save_steps=10_000, save_total_limit=2),\n",
1086
+ ")\n",
1087
+ "\n",
1088
+ "trainer.train()"
1089
+ ]
1090
+ }
1091
+ ],
1092
+ "metadata": {
1093
+ "accelerator": "GPU",
1094
+ "colab": {
1095
+ "collapsed_sections": [],
1096
+ "name": "Копия блокнота \"ysda_2022.03.07.ipynb\"",
1097
+ "provenance": []
1098
+ },
1099
+ "kernelspec": {
1100
+ "display_name": "Python 3 (ipykernel)",
1101
+ "language": "python",
1102
+ "name": "python3"
1103
+ },
1104
+ "language_info": {
1105
+ "codemirror_mode": {
1106
+ "name": "ipython",
1107
+ "version": 3
1108
+ },
1109
+ "file_extension": ".py",
1110
+ "mimetype": "text/x-python",
1111
+ "name": "python",
1112
+ "nbconvert_exporter": "python",
1113
+ "pygments_lexer": "ipython3",
1114
+ "version": "3.7.12"
1115
+ },
1116
+ "widgets": {
1117
+ "application/vnd.jupyter.widget-state+json": {
1118
+ "300b70ed57dd493997afb0b3f25f4245": {
1119
+ "model_module": "@jupyter-widgets/controls",
1120
+ "model_module_version": "1.5.0",
1121
+ "model_name": "HTMLModel",
1122
+ "state": {
1123
+ "_dom_classes": [],
1124
+ "_model_module": "@jupyter-widgets/controls",
1125
+ "_model_module_version": "1.5.0",
1126
+ "_model_name": "HTMLModel",
1127
+ "_view_count": null,
1128
+ "_view_module": "@jupyter-widgets/controls",
1129
+ "_view_module_version": "1.5.0",
1130
+ "_view_name": "HTMLView",
1131
+ "description": "",
1132
+ "description_tooltip": null,
1133
+ "layout": "IPY_MODEL_9e11898bc51e483d91301387099368a4",
1134
+ "placeholder": "​",
1135
+ "style": "IPY_MODEL_a43574fa5fdf47ba9d5598b2b31f2082",
1136
+ "value": "100%"
1137
+ }
1138
+ },
1139
+ "482bae742d2a461cad525888e6ee8b91": {
1140
+ "model_module": "@jupyter-widgets/base",
1141
+ "model_module_version": "1.2.0",
1142
+ "model_name": "LayoutModel",
1143
+ "state": {
1144
+ "_model_module": "@jupyter-widgets/base",
1145
+ "_model_module_version": "1.2.0",
1146
+ "_model_name": "LayoutModel",
1147
+ "_view_count": null,
1148
+ "_view_module": "@jupyter-widgets/base",
1149
+ "_view_module_version": "1.2.0",
1150
+ "_view_name": "LayoutView",
1151
+ "align_content": null,
1152
+ "align_items": null,
1153
+ "align_self": null,
1154
+ "border": null,
1155
+ "bottom": null,
1156
+ "display": null,
1157
+ "flex": null,
1158
+ "flex_flow": null,
1159
+ "grid_area": null,
1160
+ "grid_auto_columns": null,
1161
+ "grid_auto_flow": null,
1162
+ "grid_auto_rows": null,
1163
+ "grid_column": null,
1164
+ "grid_gap": null,
1165
+ "grid_row": null,
1166
+ "grid_template_areas": null,
1167
+ "grid_template_columns": null,
1168
+ "grid_template_rows": null,
1169
+ "height": null,
1170
+ "justify_content": null,
1171
+ "justify_items": null,
1172
+ "left": null,
1173
+ "margin": null,
1174
+ "max_height": null,
1175
+ "max_width": null,
1176
+ "min_height": null,
1177
+ "min_width": null,
1178
+ "object_fit": null,
1179
+ "object_position": null,
1180
+ "order": null,
1181
+ "overflow": null,
1182
+ "overflow_x": null,
1183
+ "overflow_y": null,
1184
+ "padding": null,
1185
+ "right": null,
1186
+ "top": null,
1187
+ "visibility": null,
1188
+ "width": null
1189
+ }
1190
+ },
1191
+ "57c3794731c84c42bb49618482b6b8cc": {
1192
+ "model_module": "@jupyter-widgets/controls",
1193
+ "model_module_version": "1.5.0",
1194
+ "model_name": "HTMLModel",
1195
+ "state": {
1196
+ "_dom_classes": [],
1197
+ "_model_module": "@jupyter-widgets/controls",
1198
+ "_model_module_version": "1.5.0",
1199
+ "_model_name": "HTMLModel",
1200
+ "_view_count": null,
1201
+ "_view_module": "@jupyter-widgets/controls",
1202
+ "_view_module_version": "1.5.0",
1203
+ "_view_name": "HTMLView",
1204
+ "description": "",
1205
+ "description_tooltip": null,
1206
+ "layout": "IPY_MODEL_d604380b5e444f62ad36c4598230c561",
1207
+ "placeholder": "​",
1208
+ "style": "IPY_MODEL_c52ad745acb3423494b4ea5af5a934c7",
1209
+ "value": " 58/58 [02:02&lt;00:00, 1.83s/ba]"
1210
+ }
1211
+ },
1212
+ "5b49dc833234406da3da7435b9045fd2": {
1213
+ "model_module": "@jupyter-widgets/controls",
1214
+ "model_module_version": "1.5.0",
1215
+ "model_name": "HBoxModel",
1216
+ "state": {
1217
+ "_dom_classes": [],
1218
+ "_model_module": "@jupyter-widgets/controls",
1219
+ "_model_module_version": "1.5.0",
1220
+ "_model_name": "HBoxModel",
1221
+ "_view_count": null,
1222
+ "_view_module": "@jupyter-widgets/controls",
1223
+ "_view_module_version": "1.5.0",
1224
+ "_view_name": "HBoxView",
1225
+ "box_style": "",
1226
+ "children": [
1227
+ "IPY_MODEL_300b70ed57dd493997afb0b3f25f4245",
1228
+ "IPY_MODEL_c03cc68b079c4e23b339e9de5ba38d29",
1229
+ "IPY_MODEL_57c3794731c84c42bb49618482b6b8cc"
1230
+ ],
1231
+ "layout": "IPY_MODEL_e306828f6d7444ddafce604e9a170467"
1232
+ }
1233
+ },
1234
+ "9e11898bc51e483d91301387099368a4": {
1235
+ "model_module": "@jupyter-widgets/base",
1236
+ "model_module_version": "1.2.0",
1237
+ "model_name": "LayoutModel",
1238
+ "state": {
1239
+ "_model_module": "@jupyter-widgets/base",
1240
+ "_model_module_version": "1.2.0",
1241
+ "_model_name": "LayoutModel",
1242
+ "_view_count": null,
1243
+ "_view_module": "@jupyter-widgets/base",
1244
+ "_view_module_version": "1.2.0",
1245
+ "_view_name": "LayoutView",
1246
+ "align_content": null,
1247
+ "align_items": null,
1248
+ "align_self": null,
1249
+ "border": null,
1250
+ "bottom": null,
1251
+ "display": null,
1252
+ "flex": null,
1253
+ "flex_flow": null,
1254
+ "grid_area": null,
1255
+ "grid_auto_columns": null,
1256
+ "grid_auto_flow": null,
1257
+ "grid_auto_rows": null,
1258
+ "grid_column": null,
1259
+ "grid_gap": null,
1260
+ "grid_row": null,
1261
+ "grid_template_areas": null,
1262
+ "grid_template_columns": null,
1263
+ "grid_template_rows": null,
1264
+ "height": null,
1265
+ "justify_content": null,
1266
+ "justify_items": null,
1267
+ "left": null,
1268
+ "margin": null,
1269
+ "max_height": null,
1270
+ "max_width": null,
1271
+ "min_height": null,
1272
+ "min_width": null,
1273
+ "object_fit": null,
1274
+ "object_position": null,
1275
+ "order": null,
1276
+ "overflow": null,
1277
+ "overflow_x": null,
1278
+ "overflow_y": null,
1279
+ "padding": null,
1280
+ "right": null,
1281
+ "top": null,
1282
+ "visibility": null,
1283
+ "width": null
1284
+ }
1285
+ },
1286
+ "a43574fa5fdf47ba9d5598b2b31f2082": {
1287
+ "model_module": "@jupyter-widgets/controls",
1288
+ "model_module_version": "1.5.0",
1289
+ "model_name": "DescriptionStyleModel",
1290
+ "state": {
1291
+ "_model_module": "@jupyter-widgets/controls",
1292
+ "_model_module_version": "1.5.0",
1293
+ "_model_name": "DescriptionStyleModel",
1294
+ "_view_count": null,
1295
+ "_view_module": "@jupyter-widgets/base",
1296
+ "_view_module_version": "1.2.0",
1297
+ "_view_name": "StyleView",
1298
+ "description_width": ""
1299
+ }
1300
+ },
1301
+ "c03cc68b079c4e23b339e9de5ba38d29": {
1302
+ "model_module": "@jupyter-widgets/controls",
1303
+ "model_module_version": "1.5.0",
1304
+ "model_name": "FloatProgressModel",
1305
+ "state": {
1306
+ "_dom_classes": [],
1307
+ "_model_module": "@jupyter-widgets/controls",
1308
+ "_model_module_version": "1.5.0",
1309
+ "_model_name": "FloatProgressModel",
1310
+ "_view_count": null,
1311
+ "_view_module": "@jupyter-widgets/controls",
1312
+ "_view_module_version": "1.5.0",
1313
+ "_view_name": "ProgressView",
1314
+ "bar_style": "success",
1315
+ "description": "",
1316
+ "description_tooltip": null,
1317
+ "layout": "IPY_MODEL_482bae742d2a461cad525888e6ee8b91",
1318
+ "max": 58,
1319
+ "min": 0,
1320
+ "orientation": "horizontal",
1321
+ "style": "IPY_MODEL_e9c56275d73545a6961efe5704308ede",
1322
+ "value": 58
1323
+ }
1324
+ },
1325
+ "c52ad745acb3423494b4ea5af5a934c7": {
1326
+ "model_module": "@jupyter-widgets/controls",
1327
+ "model_module_version": "1.5.0",
1328
+ "model_name": "DescriptionStyleModel",
1329
+ "state": {
1330
+ "_model_module": "@jupyter-widgets/controls",
1331
+ "_model_module_version": "1.5.0",
1332
+ "_model_name": "DescriptionStyleModel",
1333
+ "_view_count": null,
1334
+ "_view_module": "@jupyter-widgets/base",
1335
+ "_view_module_version": "1.2.0",
1336
+ "_view_name": "StyleView",
1337
+ "description_width": ""
1338
+ }
1339
+ },
1340
+ "d604380b5e444f62ad36c4598230c561": {
1341
+ "model_module": "@jupyter-widgets/base",
1342
+ "model_module_version": "1.2.0",
1343
+ "model_name": "LayoutModel",
1344
+ "state": {
1345
+ "_model_module": "@jupyter-widgets/base",
1346
+ "_model_module_version": "1.2.0",
1347
+ "_model_name": "LayoutModel",
1348
+ "_view_count": null,
1349
+ "_view_module": "@jupyter-widgets/base",
1350
+ "_view_module_version": "1.2.0",
1351
+ "_view_name": "LayoutView",
1352
+ "align_content": null,
1353
+ "align_items": null,
1354
+ "align_self": null,
1355
+ "border": null,
1356
+ "bottom": null,
1357
+ "display": null,
1358
+ "flex": null,
1359
+ "flex_flow": null,
1360
+ "grid_area": null,
1361
+ "grid_auto_columns": null,
1362
+ "grid_auto_flow": null,
1363
+ "grid_auto_rows": null,
1364
+ "grid_column": null,
1365
+ "grid_gap": null,
1366
+ "grid_row": null,
1367
+ "grid_template_areas": null,
1368
+ "grid_template_columns": null,
1369
+ "grid_template_rows": null,
1370
+ "height": null,
1371
+ "justify_content": null,
1372
+ "justify_items": null,
1373
+ "left": null,
1374
+ "margin": null,
1375
+ "max_height": null,
1376
+ "max_width": null,
1377
+ "min_height": null,
1378
+ "min_width": null,
1379
+ "object_fit": null,
1380
+ "object_position": null,
1381
+ "order": null,
1382
+ "overflow": null,
1383
+ "overflow_x": null,
1384
+ "overflow_y": null,
1385
+ "padding": null,
1386
+ "right": null,
1387
+ "top": null,
1388
+ "visibility": null,
1389
+ "width": null
1390
+ }
1391
+ },
1392
+ "e306828f6d7444ddafce604e9a170467": {
1393
+ "model_module": "@jupyter-widgets/base",
1394
+ "model_module_version": "1.2.0",
1395
+ "model_name": "LayoutModel",
1396
+ "state": {
1397
+ "_model_module": "@jupyter-widgets/base",
1398
+ "_model_module_version": "1.2.0",
1399
+ "_model_name": "LayoutModel",
1400
+ "_view_count": null,
1401
+ "_view_module": "@jupyter-widgets/base",
1402
+ "_view_module_version": "1.2.0",
1403
+ "_view_name": "LayoutView",
1404
+ "align_content": null,
1405
+ "align_items": null,
1406
+ "align_self": null,
1407
+ "border": null,
1408
+ "bottom": null,
1409
+ "display": null,
1410
+ "flex": null,
1411
+ "flex_flow": null,
1412
+ "grid_area": null,
1413
+ "grid_auto_columns": null,
1414
+ "grid_auto_flow": null,
1415
+ "grid_auto_rows": null,
1416
+ "grid_column": null,
1417
+ "grid_gap": null,
1418
+ "grid_row": null,
1419
+ "grid_template_areas": null,
1420
+ "grid_template_columns": null,
1421
+ "grid_template_rows": null,
1422
+ "height": null,
1423
+ "justify_content": null,
1424
+ "justify_items": null,
1425
+ "left": null,
1426
+ "margin": null,
1427
+ "max_height": null,
1428
+ "max_width": null,
1429
+ "min_height": null,
1430
+ "min_width": null,
1431
+ "object_fit": null,
1432
+ "object_position": null,
1433
+ "order": null,
1434
+ "overflow": null,
1435
+ "overflow_x": null,
1436
+ "overflow_y": null,
1437
+ "padding": null,
1438
+ "right": null,
1439
+ "top": null,
1440
+ "visibility": null,
1441
+ "width": null
1442
+ }
1443
+ },
1444
+ "e9c56275d73545a6961efe5704308ede": {
1445
+ "model_module": "@jupyter-widgets/controls",
1446
+ "model_module_version": "1.5.0",
1447
+ "model_name": "ProgressStyleModel",
1448
+ "state": {
1449
+ "_model_module": "@jupyter-widgets/controls",
1450
+ "_model_module_version": "1.5.0",
1451
+ "_model_name": "ProgressStyleModel",
1452
+ "_view_count": null,
1453
+ "_view_module": "@jupyter-widgets/base",
1454
+ "_view_module_version": "1.2.0",
1455
+ "_view_name": "StyleView",
1456
+ "bar_color": null,
1457
+ "description_width": ""
1458
+ }
1459
+ }
1460
+ }
1461
+ }
1462
+ },
1463
+ "nbformat": 4,
1464
+ "nbformat_minor": 1
1465
+ }