stevhliu HF staff commited on
Commit
7236db0
·
1 Parent(s): d6e5d16

Upload prefix-tuning-clm.ipynb

Browse files
Files changed (1) hide show
  1. prefix-tuning-clm.ipynb +1389 -0
prefix-tuning-clm.ipynb ADDED
@@ -0,0 +1,1389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "71fbfca2",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from transformers import AutoModelForCausalLM\n",
11
+ "from peft import get_peft_config, get_peft_model, PrefixTuningConfig, TaskType, PeftType\n",
12
+ "import torch\n",
13
+ "from datasets import load_dataset\n",
14
+ "import os\n",
15
+ "from transformers import AutoTokenizer\n",
16
+ "from torch.utils.data import DataLoader\n",
17
+ "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
18
+ "from tqdm import tqdm\n",
19
+ "from datasets import load_dataset\n",
20
+ "\n",
21
+ "device = \"cuda\"\n",
22
+ "model_name_or_path = \"bigscience/bloomz-560m\"\n",
23
+ "tokenizer_name_or_path = \"bigscience/bloomz-560m\"\n",
24
+ "peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=30)\n",
25
+ "\n",
26
+ "dataset_name = \"twitter_complaints\"\n",
27
+ "checkpoint_name = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt\".replace(\n",
28
+ " \"/\", \"_\"\n",
29
+ ")\n",
30
+ "text_column = \"Tweet text\"\n",
31
+ "label_column = \"text_label\"\n",
32
+ "max_length = 64\n",
33
+ "lr = 3e-2\n",
34
+ "num_epochs = 50\n",
35
+ "batch_size = 8"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 3,
41
+ "id": "e1a3648b",
42
+ "metadata": {},
43
+ "outputs": [
44
+ {
45
+ "name": "stderr",
46
+ "output_type": "stream",
47
+ "text": [
48
+ "Found cached dataset raft (/home/sourab/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84)\n"
49
+ ]
50
+ },
51
+ {
52
+ "data": {
53
+ "application/vnd.jupyter.widget-view+json": {
54
+ "model_id": "56d9908a2c8944b484348cc46b16a261",
55
+ "version_major": 2,
56
+ "version_minor": 0
57
+ },
58
+ "text/plain": [
59
+ " 0%| | 0/2 [00:00<?, ?it/s]"
60
+ ]
61
+ },
62
+ "metadata": {},
63
+ "output_type": "display_data"
64
+ },
65
+ {
66
+ "name": "stderr",
67
+ "output_type": "stream",
68
+ "text": [
69
+ "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84/cache-20a7622c86d80cdf.arrow\n",
70
+ "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84/cache-5f1431311da05803.arrow\n"
71
+ ]
72
+ },
73
+ {
74
+ "name": "stdout",
75
+ "output_type": "stream",
76
+ "text": [
77
+ "['Unlabeled', 'complaint', 'no complaint']\n",
78
+ "DatasetDict({\n",
79
+ " train: Dataset({\n",
80
+ " features: ['Tweet text', 'ID', 'Label', 'text_label'],\n",
81
+ " num_rows: 50\n",
82
+ " })\n",
83
+ " test: Dataset({\n",
84
+ " features: ['Tweet text', 'ID', 'Label', 'text_label'],\n",
85
+ " num_rows: 3399\n",
86
+ " })\n",
87
+ "})\n"
88
+ ]
89
+ },
90
+ {
91
+ "data": {
92
+ "text/plain": [
93
+ "{'Tweet text': '@HMRCcustomers No this is my first job',\n",
94
+ " 'ID': 0,\n",
95
+ " 'Label': 2,\n",
96
+ " 'text_label': 'no complaint'}"
97
+ ]
98
+ },
99
+ "execution_count": 3,
100
+ "metadata": {},
101
+ "output_type": "execute_result"
102
+ }
103
+ ],
104
+ "source": [
105
+ "from datasets import load_dataset\n",
106
+ "\n",
107
+ "dataset = load_dataset(\"ought/raft\", dataset_name)\n",
108
+ "\n",
109
+ "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
110
+ "print(classes)\n",
111
+ "dataset = dataset.map(\n",
112
+ " lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
113
+ " batched=True,\n",
114
+ " num_proc=1,\n",
115
+ ")\n",
116
+ "print(dataset)\n",
117
+ "dataset[\"train\"][0]"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": 4,
123
+ "id": "fe12d4d3",
124
+ "metadata": {},
125
+ "outputs": [
126
+ {
127
+ "name": "stdout",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "3\n"
131
+ ]
132
+ },
133
+ {
134
+ "data": {
135
+ "application/vnd.jupyter.widget-view+json": {
136
+ "model_id": "5a0e3242324842fb941950df38b459fe",
137
+ "version_major": 2,
138
+ "version_minor": 0
139
+ },
140
+ "text/plain": [
141
+ "Running tokenizer on dataset: 0%| | 0/1 [00:00<?, ?ba/s]"
142
+ ]
143
+ },
144
+ "metadata": {},
145
+ "output_type": "display_data"
146
+ },
147
+ {
148
+ "data": {
149
+ "application/vnd.jupyter.widget-view+json": {
150
+ "model_id": "133df817b7b9468cabd5353d4d2b675b",
151
+ "version_major": 2,
152
+ "version_minor": 0
153
+ },
154
+ "text/plain": [
155
+ "Running tokenizer on dataset: 0%| | 0/4 [00:00<?, ?ba/s]"
156
+ ]
157
+ },
158
+ "metadata": {},
159
+ "output_type": "display_data"
160
+ }
161
+ ],
162
+ "source": [
163
+ "# data preprocessing\n",
164
+ "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
165
+ "if tokenizer.pad_token_id is None:\n",
166
+ " tokenizer.pad_token_id = tokenizer.eos_token_id\n",
167
+ "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
168
+ "print(target_max_length)\n",
169
+ "\n",
170
+ "\n",
171
+ "def preprocess_function(examples):\n",
172
+ " batch_size = len(examples[text_column])\n",
173
+ " inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
174
+ " targets = [str(x) for x in examples[label_column]]\n",
175
+ " model_inputs = tokenizer(inputs)\n",
176
+ " labels = tokenizer(targets, add_special_tokens=False) # don't add bos token because we concatenate with inputs\n",
177
+ " for i in range(batch_size):\n",
178
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
179
+ " label_input_ids = labels[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
180
+ " # print(i, sample_input_ids, label_input_ids)\n",
181
+ " model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
182
+ " labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
183
+ " model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
184
+ " # print(model_inputs)\n",
185
+ " for i in range(batch_size):\n",
186
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
187
+ " label_input_ids = labels[\"input_ids\"][i]\n",
188
+ " model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
189
+ " max_length - len(sample_input_ids)\n",
190
+ " ) + sample_input_ids\n",
191
+ " model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
192
+ " \"attention_mask\"\n",
193
+ " ][i]\n",
194
+ " labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
195
+ " model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
196
+ " model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
197
+ " labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:max_length])\n",
198
+ " model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
199
+ " return model_inputs\n",
200
+ "\n",
201
+ "\n",
202
+ "processed_datasets = dataset.map(\n",
203
+ " preprocess_function,\n",
204
+ " batched=True,\n",
205
+ " num_proc=1,\n",
206
+ " remove_columns=dataset[\"train\"].column_names,\n",
207
+ " load_from_cache_file=False,\n",
208
+ " desc=\"Running tokenizer on dataset\",\n",
209
+ ")\n",
210
+ "\n",
211
+ "train_dataset = processed_datasets[\"train\"]\n",
212
+ "eval_dataset = processed_datasets[\"train\"]\n",
213
+ "\n",
214
+ "\n",
215
+ "train_dataloader = DataLoader(\n",
216
+ " train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
217
+ ")\n",
218
+ "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "id": "641b21fe",
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "def test_preprocess_function(examples):\n",
229
+ " batch_size = len(examples[text_column])\n",
230
+ " inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
231
+ " model_inputs = tokenizer(inputs)\n",
232
+ " # print(model_inputs)\n",
233
+ " for i in range(batch_size):\n",
234
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
235
+ " model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
236
+ " max_length - len(sample_input_ids)\n",
237
+ " ) + sample_input_ids\n",
238
+ " model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
239
+ " \"attention_mask\"\n",
240
+ " ][i]\n",
241
+ " model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
242
+ " model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
243
+ " return model_inputs\n",
244
+ "\n",
245
+ "\n",
246
+ "test_dataset = dataset[\"test\"].map(\n",
247
+ " test_preprocess_function,\n",
248
+ " batched=True,\n",
249
+ " num_proc=1,\n",
250
+ " remove_columns=dataset[\"train\"].column_names,\n",
251
+ " load_from_cache_file=False,\n",
252
+ " desc=\"Running tokenizer on dataset\",\n",
253
+ ")\n",
254
+ "\n",
255
+ "test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
256
+ "next(iter(test_dataloader))"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "id": "accc5012",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "next(iter(train_dataloader))"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": 7,
272
+ "id": "218df807",
273
+ "metadata": {},
274
+ "outputs": [
275
+ {
276
+ "data": {
277
+ "text/plain": [
278
+ "425"
279
+ ]
280
+ },
281
+ "execution_count": 7,
282
+ "metadata": {},
283
+ "output_type": "execute_result"
284
+ }
285
+ ],
286
+ "source": [
287
+ "len(test_dataloader)"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": null,
293
+ "id": "47d1fedf",
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "next(iter(test_dataloader))"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "code",
302
+ "execution_count": 9,
303
+ "id": "a773e092",
304
+ "metadata": {},
305
+ "outputs": [
306
+ {
307
+ "name": "stdout",
308
+ "output_type": "stream",
309
+ "text": [
310
+ "trainable params: 1474560 || all params: 560689152 || trainable%: 0.26299064191632515\n"
311
+ ]
312
+ }
313
+ ],
314
+ "source": [
315
+ "# creating model\n",
316
+ "model = AutoModelForCausalLM.from_pretrained(model_name_or_path)\n",
317
+ "model = get_peft_model(model, peft_config)\n",
318
+ "model.print_trainable_parameters()"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": 10,
324
+ "id": "bd419634",
325
+ "metadata": {},
326
+ "outputs": [
327
+ {
328
+ "name": "stdout",
329
+ "output_type": "stream",
330
+ "text": [
331
+ "trainable params: 1474560 || all params: 560689152 || trainable%: 0.26299064191632515\n"
332
+ ]
333
+ }
334
+ ],
335
+ "source": [
336
+ "model.print_trainable_parameters()"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": null,
342
+ "id": "22822901",
343
+ "metadata": {},
344
+ "outputs": [],
345
+ "source": [
346
+ "model"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": 12,
352
+ "id": "023cb942",
353
+ "metadata": {},
354
+ "outputs": [
355
+ {
356
+ "data": {
357
+ "text/plain": [
358
+ "PrefixTuningConfig(peft_type=<PeftType.PREFIX_TUNING: 'PREFIX_TUNING'>, base_model_name_or_path='bigscience/bloomz-560m', task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, num_virtual_tokens=30, token_dim=1024, num_transformer_submodules=1, num_attention_heads=16, num_layers=24, encoder_hidden_size=1024, prefix_projection=False)"
359
+ ]
360
+ },
361
+ "execution_count": 12,
362
+ "metadata": {},
363
+ "output_type": "execute_result"
364
+ }
365
+ ],
366
+ "source": [
367
+ "model.peft_config"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": 13,
373
+ "id": "b2f91568",
374
+ "metadata": {},
375
+ "outputs": [],
376
+ "source": [
377
+ "# model\n",
378
+ "# optimizer and lr scheduler\n",
379
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
380
+ "lr_scheduler = get_linear_schedule_with_warmup(\n",
381
+ " optimizer=optimizer,\n",
382
+ " num_warmup_steps=0,\n",
383
+ " num_training_steps=(len(train_dataloader) * num_epochs),\n",
384
+ ")"
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": 14,
390
+ "id": "e4fb69fc",
391
+ "metadata": {},
392
+ "outputs": [
393
+ {
394
+ "name": "stderr",
395
+ "output_type": "stream",
396
+ "text": [
397
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 5.79it/s]\n",
398
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.51it/s]\n"
399
+ ]
400
+ },
401
+ {
402
+ "name": "stdout",
403
+ "output_type": "stream",
404
+ "text": [
405
+ "epoch=0: train_ppl=tensor(1.8325e+09, device='cuda:0') train_epoch_loss=tensor(21.3289, device='cuda:0') eval_ppl=tensor(2713.4180, device='cuda:0') eval_epoch_loss=tensor(7.9060, device='cuda:0')\n"
406
+ ]
407
+ },
408
+ {
409
+ "name": "stderr",
410
+ "output_type": "stream",
411
+ "text": [
412
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
413
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.53it/s]\n"
414
+ ]
415
+ },
416
+ {
417
+ "name": "stdout",
418
+ "output_type": "stream",
419
+ "text": [
420
+ "epoch=1: train_ppl=tensor(341.0600, device='cuda:0') train_epoch_loss=tensor(5.8321, device='cuda:0') eval_ppl=tensor(80.8206, device='cuda:0') eval_epoch_loss=tensor(4.3922, device='cuda:0')\n"
421
+ ]
422
+ },
423
+ {
424
+ "name": "stderr",
425
+ "output_type": "stream",
426
+ "text": [
427
+ "100%|██████████████████████████████████████████████████████████████████████████████████��█████████| 7/7 [00:00<00:00, 11.44it/s]\n",
428
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.55it/s]\n"
429
+ ]
430
+ },
431
+ {
432
+ "name": "stdout",
433
+ "output_type": "stream",
434
+ "text": [
435
+ "epoch=2: train_ppl=tensor(59.8778, device='cuda:0') train_epoch_loss=tensor(4.0923, device='cuda:0') eval_ppl=tensor(34.4593, device='cuda:0') eval_epoch_loss=tensor(3.5398, device='cuda:0')\n"
436
+ ]
437
+ },
438
+ {
439
+ "name": "stderr",
440
+ "output_type": "stream",
441
+ "text": [
442
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
443
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.55it/s]\n"
444
+ ]
445
+ },
446
+ {
447
+ "name": "stdout",
448
+ "output_type": "stream",
449
+ "text": [
450
+ "epoch=3: train_ppl=tensor(22.3307, device='cuda:0') train_epoch_loss=tensor(3.1060, device='cuda:0') eval_ppl=tensor(12.5947, device='cuda:0') eval_epoch_loss=tensor(2.5333, device='cuda:0')\n"
451
+ ]
452
+ },
453
+ {
454
+ "name": "stderr",
455
+ "output_type": "stream",
456
+ "text": [
457
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
458
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.56it/s]\n"
459
+ ]
460
+ },
461
+ {
462
+ "name": "stdout",
463
+ "output_type": "stream",
464
+ "text": [
465
+ "epoch=4: train_ppl=tensor(9.1697, device='cuda:0') train_epoch_loss=tensor(2.2159, device='cuda:0') eval_ppl=tensor(4.5289, device='cuda:0') eval_epoch_loss=tensor(1.5105, device='cuda:0')\n"
466
+ ]
467
+ },
468
+ {
469
+ "name": "stderr",
470
+ "output_type": "stream",
471
+ "text": [
472
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
473
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.52it/s]\n"
474
+ ]
475
+ },
476
+ {
477
+ "name": "stdout",
478
+ "output_type": "stream",
479
+ "text": [
480
+ "epoch=5: train_ppl=tensor(3.0172, device='cuda:0') train_epoch_loss=tensor(1.1043, device='cuda:0') eval_ppl=tensor(1.8092, device='cuda:0') eval_epoch_loss=tensor(0.5929, device='cuda:0')\n"
481
+ ]
482
+ },
483
+ {
484
+ "name": "stderr",
485
+ "output_type": "stream",
486
+ "text": [
487
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
488
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.45it/s]\n"
489
+ ]
490
+ },
491
+ {
492
+ "name": "stdout",
493
+ "output_type": "stream",
494
+ "text": [
495
+ "epoch=6: train_ppl=tensor(1.4885, device='cuda:0') train_epoch_loss=tensor(0.3978, device='cuda:0') eval_ppl=tensor(1.4449, device='cuda:0') eval_epoch_loss=tensor(0.3680, device='cuda:0')\n"
496
+ ]
497
+ },
498
+ {
499
+ "name": "stderr",
500
+ "output_type": "stream",
501
+ "text": [
502
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
503
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.48it/s]\n"
504
+ ]
505
+ },
506
+ {
507
+ "name": "stdout",
508
+ "output_type": "stream",
509
+ "text": [
510
+ "epoch=7: train_ppl=tensor(1.2967, device='cuda:0') train_epoch_loss=tensor(0.2598, device='cuda:0') eval_ppl=tensor(1.1587, device='cuda:0') eval_epoch_loss=tensor(0.1473, device='cuda:0')\n"
511
+ ]
512
+ },
513
+ {
514
+ "name": "stderr",
515
+ "output_type": "stream",
516
+ "text": [
517
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
518
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
519
+ ]
520
+ },
521
+ {
522
+ "name": "stdout",
523
+ "output_type": "stream",
524
+ "text": [
525
+ "epoch=8: train_ppl=tensor(1.1305, device='cuda:0') train_epoch_loss=tensor(0.1227, device='cuda:0') eval_ppl=tensor(1.0874, device='cuda:0') eval_epoch_loss=tensor(0.0838, device='cuda:0')\n"
526
+ ]
527
+ },
528
+ {
529
+ "name": "stderr",
530
+ "output_type": "stream",
531
+ "text": [
532
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
533
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.46it/s]\n"
534
+ ]
535
+ },
536
+ {
537
+ "name": "stdout",
538
+ "output_type": "stream",
539
+ "text": [
540
+ "epoch=9: train_ppl=tensor(1.1608, device='cuda:0') train_epoch_loss=tensor(0.1491, device='cuda:0') eval_ppl=tensor(1.1461, device='cuda:0') eval_epoch_loss=tensor(0.1364, device='cuda:0')\n"
541
+ ]
542
+ },
543
+ {
544
+ "name": "stderr",
545
+ "output_type": "stream",
546
+ "text": [
547
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
548
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.45it/s]\n"
549
+ ]
550
+ },
551
+ {
552
+ "name": "stdout",
553
+ "output_type": "stream",
554
+ "text": [
555
+ "epoch=10: train_ppl=tensor(1.3172, device='cuda:0') train_epoch_loss=tensor(0.2755, device='cuda:0') eval_ppl=tensor(1.1320, device='cuda:0') eval_epoch_loss=tensor(0.1240, device='cuda:0')\n"
556
+ ]
557
+ },
558
+ {
559
+ "name": "stderr",
560
+ "output_type": "stream",
561
+ "text": [
562
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
563
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.46it/s]\n"
564
+ ]
565
+ },
566
+ {
567
+ "name": "stdout",
568
+ "output_type": "stream",
569
+ "text": [
570
+ "epoch=11: train_ppl=tensor(1.1437, device='cuda:0') train_epoch_loss=tensor(0.1343, device='cuda:0') eval_ppl=tensor(1.0676, device='cuda:0') eval_epoch_loss=tensor(0.0654, device='cuda:0')\n"
571
+ ]
572
+ },
573
+ {
574
+ "name": "stderr",
575
+ "output_type": "stream",
576
+ "text": [
577
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
578
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.43it/s]\n"
579
+ ]
580
+ },
581
+ {
582
+ "name": "stdout",
583
+ "output_type": "stream",
584
+ "text": [
585
+ "epoch=12: train_ppl=tensor(1.0651, device='cuda:0') train_epoch_loss=tensor(0.0630, device='cuda:0') eval_ppl=tensor(1.0735, device='cuda:0') eval_epoch_loss=tensor(0.0710, device='cuda:0')\n"
586
+ ]
587
+ },
588
+ {
589
+ "name": "stderr",
590
+ "output_type": "stream",
591
+ "text": [
592
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.46it/s]\n",
593
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
594
+ ]
595
+ },
596
+ {
597
+ "name": "stdout",
598
+ "output_type": "stream",
599
+ "text": [
600
+ "epoch=13: train_ppl=tensor(1.0607, device='cuda:0') train_epoch_loss=tensor(0.0589, device='cuda:0') eval_ppl=tensor(1.0399, device='cuda:0') eval_epoch_loss=tensor(0.0391, device='cuda:0')\n"
601
+ ]
602
+ },
603
+ {
604
+ "name": "stderr",
605
+ "output_type": "stream",
606
+ "text": [
607
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
608
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.44it/s]\n"
609
+ ]
610
+ },
611
+ {
612
+ "name": "stdout",
613
+ "output_type": "stream",
614
+ "text": [
615
+ "epoch=14: train_ppl=tensor(1.0351, device='cuda:0') train_epoch_loss=tensor(0.0345, device='cuda:0') eval_ppl=tensor(1.0260, device='cuda:0') eval_epoch_loss=tensor(0.0257, device='cuda:0')\n"
616
+ ]
617
+ },
618
+ {
619
+ "name": "stderr",
620
+ "output_type": "stream",
621
+ "text": [
622
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
623
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.43it/s]\n"
624
+ ]
625
+ },
626
+ {
627
+ "name": "stdout",
628
+ "output_type": "stream",
629
+ "text": [
630
+ "epoch=15: train_ppl=tensor(1.0217, device='cuda:0') train_epoch_loss=tensor(0.0215, device='cuda:0') eval_ppl=tensor(1.0168, device='cuda:0') eval_epoch_loss=tensor(0.0167, device='cuda:0')\n"
631
+ ]
632
+ },
633
+ {
634
+ "name": "stderr",
635
+ "output_type": "stream",
636
+ "text": [
637
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
638
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.28it/s]\n"
639
+ ]
640
+ },
641
+ {
642
+ "name": "stdout",
643
+ "output_type": "stream",
644
+ "text": [
645
+ "epoch=16: train_ppl=tensor(1.0152, device='cuda:0') train_epoch_loss=tensor(0.0151, device='cuda:0') eval_ppl=tensor(1.0117, device='cuda:0') eval_epoch_loss=tensor(0.0116, device='cuda:0')\n"
646
+ ]
647
+ },
648
+ {
649
+ "name": "stderr",
650
+ "output_type": "stream",
651
+ "text": [
652
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
653
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.41it/s]\n"
654
+ ]
655
+ },
656
+ {
657
+ "name": "stdout",
658
+ "output_type": "stream",
659
+ "text": [
660
+ "epoch=17: train_ppl=tensor(1.0102, device='cuda:0') train_epoch_loss=tensor(0.0101, device='cuda:0') eval_ppl=tensor(1.0088, device='cuda:0') eval_epoch_loss=tensor(0.0088, device='cuda:0')\n"
661
+ ]
662
+ },
663
+ {
664
+ "name": "stderr",
665
+ "output_type": "stream",
666
+ "text": [
667
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.29it/s]\n",
668
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.25it/s]\n"
669
+ ]
670
+ },
671
+ {
672
+ "name": "stdout",
673
+ "output_type": "stream",
674
+ "text": [
675
+ "epoch=18: train_ppl=tensor(1.0083, device='cuda:0') train_epoch_loss=tensor(0.0083, device='cuda:0') eval_ppl=tensor(1.0073, device='cuda:0') eval_epoch_loss=tensor(0.0073, device='cuda:0')\n"
676
+ ]
677
+ },
678
+ {
679
+ "name": "stderr",
680
+ "output_type": "stream",
681
+ "text": [
682
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
683
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.46it/s]\n"
684
+ ]
685
+ },
686
+ {
687
+ "name": "stdout",
688
+ "output_type": "stream",
689
+ "text": [
690
+ "epoch=19: train_ppl=tensor(1.0070, device='cuda:0') train_epoch_loss=tensor(0.0070, device='cuda:0') eval_ppl=tensor(1.0064, device='cuda:0') eval_epoch_loss=tensor(0.0063, device='cuda:0')\n"
691
+ ]
692
+ },
693
+ {
694
+ "name": "stderr",
695
+ "output_type": "stream",
696
+ "text": [
697
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
698
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.51it/s]\n"
699
+ ]
700
+ },
701
+ {
702
+ "name": "stdout",
703
+ "output_type": "stream",
704
+ "text": [
705
+ "epoch=20: train_ppl=tensor(1.0059, device='cuda:0') train_epoch_loss=tensor(0.0059, device='cuda:0') eval_ppl=tensor(1.0057, device='cuda:0') eval_epoch_loss=tensor(0.0057, device='cuda:0')\n"
706
+ ]
707
+ },
708
+ {
709
+ "name": "stderr",
710
+ "output_type": "stream",
711
+ "text": [
712
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
713
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
714
+ ]
715
+ },
716
+ {
717
+ "name": "stdout",
718
+ "output_type": "stream",
719
+ "text": [
720
+ "epoch=21: train_ppl=tensor(1.0056, device='cuda:0') train_epoch_loss=tensor(0.0056, device='cuda:0') eval_ppl=tensor(1.0052, device='cuda:0') eval_epoch_loss=tensor(0.0052, device='cuda:0')\n"
721
+ ]
722
+ },
723
+ {
724
+ "name": "stderr",
725
+ "output_type": "stream",
726
+ "text": [
727
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.41it/s]\n",
728
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.33it/s]\n"
729
+ ]
730
+ },
731
+ {
732
+ "name": "stdout",
733
+ "output_type": "stream",
734
+ "text": [
735
+ "epoch=22: train_ppl=tensor(1.0050, device='cuda:0') train_epoch_loss=tensor(0.0050, device='cuda:0') eval_ppl=tensor(1.0049, device='cuda:0') eval_epoch_loss=tensor(0.0049, device='cuda:0')\n"
736
+ ]
737
+ },
738
+ {
739
+ "name": "stderr",
740
+ "output_type": "stream",
741
+ "text": [
742
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.39it/s]\n",
743
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.44it/s]\n"
744
+ ]
745
+ },
746
+ {
747
+ "name": "stdout",
748
+ "output_type": "stream",
749
+ "text": [
750
+ "epoch=23: train_ppl=tensor(1.0049, device='cuda:0') train_epoch_loss=tensor(0.0049, device='cuda:0') eval_ppl=tensor(1.0045, device='cuda:0') eval_epoch_loss=tensor(0.0045, device='cuda:0')\n"
751
+ ]
752
+ },
753
+ {
754
+ "name": "stderr",
755
+ "output_type": "stream",
756
+ "text": [
757
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.42it/s]\n",
758
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.49it/s]\n"
759
+ ]
760
+ },
761
+ {
762
+ "name": "stdout",
763
+ "output_type": "stream",
764
+ "text": [
765
+ "epoch=24: train_ppl=tensor(1.0043, device='cuda:0') train_epoch_loss=tensor(0.0043, device='cuda:0') eval_ppl=tensor(1.0043, device='cuda:0') eval_epoch_loss=tensor(0.0043, device='cuda:0')\n"
766
+ ]
767
+ },
768
+ {
769
+ "name": "stderr",
770
+ "output_type": "stream",
771
+ "text": [
772
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.46it/s]\n",
773
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
774
+ ]
775
+ },
776
+ {
777
+ "name": "stdout",
778
+ "output_type": "stream",
779
+ "text": [
780
+ "epoch=25: train_ppl=tensor(1.0042, device='cuda:0') train_epoch_loss=tensor(0.0042, device='cuda:0') eval_ppl=tensor(1.0040, device='cuda:0') eval_epoch_loss=tensor(0.0040, device='cuda:0')\n"
781
+ ]
782
+ },
783
+ {
784
+ "name": "stderr",
785
+ "output_type": "stream",
786
+ "text": [
787
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
788
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.52it/s]\n"
789
+ ]
790
+ },
791
+ {
792
+ "name": "stdout",
793
+ "output_type": "stream",
794
+ "text": [
795
+ "epoch=26: train_ppl=tensor(1.0039, device='cuda:0') train_epoch_loss=tensor(0.0039, device='cuda:0') eval_ppl=tensor(1.0039, device='cuda:0') eval_epoch_loss=tensor(0.0039, device='cuda:0')\n"
796
+ ]
797
+ },
798
+ {
799
+ "name": "stderr",
800
+ "output_type": "stream",
801
+ "text": [
802
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
803
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.48it/s]\n"
804
+ ]
805
+ },
806
+ {
807
+ "name": "stdout",
808
+ "output_type": "stream",
809
+ "text": [
810
+ "epoch=27: train_ppl=tensor(1.0038, device='cuda:0') train_epoch_loss=tensor(0.0038, device='cuda:0') eval_ppl=tensor(1.0037, device='cuda:0') eval_epoch_loss=tensor(0.0037, device='cuda:0')\n"
811
+ ]
812
+ },
813
+ {
814
+ "name": "stderr",
815
+ "output_type": "stream",
816
+ "text": [
817
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.46it/s]\n",
818
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.54it/s]\n"
819
+ ]
820
+ },
821
+ {
822
+ "name": "stdout",
823
+ "output_type": "stream",
824
+ "text": [
825
+ "epoch=28: train_ppl=tensor(1.0036, device='cuda:0') train_epoch_loss=tensor(0.0036, device='cuda:0') eval_ppl=tensor(1.0035, device='cuda:0') eval_epoch_loss=tensor(0.0035, device='cuda:0')\n"
826
+ ]
827
+ },
828
+ {
829
+ "name": "stderr",
830
+ "output_type": "stream",
831
+ "text": [
832
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
833
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.53it/s]\n"
834
+ ]
835
+ },
836
+ {
837
+ "name": "stdout",
838
+ "output_type": "stream",
839
+ "text": [
840
+ "epoch=29: train_ppl=tensor(1.0034, device='cuda:0') train_epoch_loss=tensor(0.0034, device='cuda:0') eval_ppl=tensor(1.0034, device='cuda:0') eval_epoch_loss=tensor(0.0034, device='cuda:0')\n"
841
+ ]
842
+ },
843
+ {
844
+ "name": "stderr",
845
+ "output_type": "stream",
846
+ "text": [
847
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
848
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
849
+ ]
850
+ },
851
+ {
852
+ "name": "stdout",
853
+ "output_type": "stream",
854
+ "text": [
855
+ "epoch=30: train_ppl=tensor(1.0034, device='cuda:0') train_epoch_loss=tensor(0.0034, device='cuda:0') eval_ppl=tensor(1.0033, device='cuda:0') eval_epoch_loss=tensor(0.0033, device='cuda:0')\n"
856
+ ]
857
+ },
858
+ {
859
+ "name": "stderr",
860
+ "output_type": "stream",
861
+ "text": [
862
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
863
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
864
+ ]
865
+ },
866
+ {
867
+ "name": "stdout",
868
+ "output_type": "stream",
869
+ "text": [
870
+ "epoch=31: train_ppl=tensor(1.0033, device='cuda:0') train_epoch_loss=tensor(0.0033, device='cuda:0') eval_ppl=tensor(1.0032, device='cuda:0') eval_epoch_loss=tensor(0.0032, device='cuda:0')\n"
871
+ ]
872
+ },
873
+ {
874
+ "name": "stderr",
875
+ "output_type": "stream",
876
+ "text": [
877
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.46it/s]\n",
878
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.51it/s]\n"
879
+ ]
880
+ },
881
+ {
882
+ "name": "stdout",
883
+ "output_type": "stream",
884
+ "text": [
885
+ "epoch=32: train_ppl=tensor(1.0031, device='cuda:0') train_epoch_loss=tensor(0.0031, device='cuda:0') eval_ppl=tensor(1.0031, device='cuda:0') eval_epoch_loss=tensor(0.0031, device='cuda:0')\n"
886
+ ]
887
+ },
888
+ {
889
+ "name": "stderr",
890
+ "output_type": "stream",
891
+ "text": [
892
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
893
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.43it/s]\n"
894
+ ]
895
+ },
896
+ {
897
+ "name": "stdout",
898
+ "output_type": "stream",
899
+ "text": [
900
+ "epoch=33: train_ppl=tensor(1.0030, device='cuda:0') train_epoch_loss=tensor(0.0030, device='cuda:0') eval_ppl=tensor(1.0030, device='cuda:0') eval_epoch_loss=tensor(0.0030, device='cuda:0')\n"
901
+ ]
902
+ },
903
+ {
904
+ "name": "stderr",
905
+ "output_type": "stream",
906
+ "text": [
907
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
908
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.46it/s]\n"
909
+ ]
910
+ },
911
+ {
912
+ "name": "stdout",
913
+ "output_type": "stream",
914
+ "text": [
915
+ "epoch=34: train_ppl=tensor(1.0029, device='cuda:0') train_epoch_loss=tensor(0.0029, device='cuda:0') eval_ppl=tensor(1.0029, device='cuda:0') eval_epoch_loss=tensor(0.0029, device='cuda:0')\n"
916
+ ]
917
+ },
918
+ {
919
+ "name": "stderr",
920
+ "output_type": "stream",
921
+ "text": [
922
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
923
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
924
+ ]
925
+ },
926
+ {
927
+ "name": "stdout",
928
+ "output_type": "stream",
929
+ "text": [
930
+ "epoch=35: train_ppl=tensor(1.0028, device='cuda:0') train_epoch_loss=tensor(0.0028, device='cuda:0') eval_ppl=tensor(1.0029, device='cuda:0') eval_epoch_loss=tensor(0.0029, device='cuda:0')\n"
931
+ ]
932
+ },
933
+ {
934
+ "name": "stderr",
935
+ "output_type": "stream",
936
+ "text": [
937
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
938
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.45it/s]\n"
939
+ ]
940
+ },
941
+ {
942
+ "name": "stdout",
943
+ "output_type": "stream",
944
+ "text": [
945
+ "epoch=36: train_ppl=tensor(1.0027, device='cuda:0') train_epoch_loss=tensor(0.0027, device='cuda:0') eval_ppl=tensor(1.0028, device='cuda:0') eval_epoch_loss=tensor(0.0028, device='cuda:0')\n"
946
+ ]
947
+ },
948
+ {
949
+ "name": "stderr",
950
+ "output_type": "stream",
951
+ "text": [
952
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
953
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
954
+ ]
955
+ },
956
+ {
957
+ "name": "stdout",
958
+ "output_type": "stream",
959
+ "text": [
960
+ "epoch=37: train_ppl=tensor(1.0027, device='cuda:0') train_epoch_loss=tensor(0.0027, device='cuda:0') eval_ppl=tensor(1.0027, device='cuda:0') eval_epoch_loss=tensor(0.0027, device='cuda:0')\n"
961
+ ]
962
+ },
963
+ {
964
+ "name": "stderr",
965
+ "output_type": "stream",
966
+ "text": [
967
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
968
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.46it/s]\n"
969
+ ]
970
+ },
971
+ {
972
+ "name": "stdout",
973
+ "output_type": "stream",
974
+ "text": [
975
+ "epoch=38: train_ppl=tensor(1.0027, device='cuda:0') train_epoch_loss=tensor(0.0027, device='cuda:0') eval_ppl=tensor(1.0027, device='cuda:0') eval_epoch_loss=tensor(0.0027, device='cuda:0')\n"
976
+ ]
977
+ },
978
+ {
979
+ "name": "stderr",
980
+ "output_type": "stream",
981
+ "text": [
982
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
983
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.43it/s]\n"
984
+ ]
985
+ },
986
+ {
987
+ "name": "stdout",
988
+ "output_type": "stream",
989
+ "text": [
990
+ "epoch=39: train_ppl=tensor(1.0025, device='cuda:0') train_epoch_loss=tensor(0.0025, device='cuda:0') eval_ppl=tensor(1.0026, device='cuda:0') eval_epoch_loss=tensor(0.0026, device='cuda:0')\n"
991
+ ]
992
+ },
993
+ {
994
+ "name": "stderr",
995
+ "output_type": "stream",
996
+ "text": [
997
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
998
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
999
+ ]
1000
+ },
1001
+ {
1002
+ "name": "stdout",
1003
+ "output_type": "stream",
1004
+ "text": [
1005
+ "epoch=40: train_ppl=tensor(1.0026, device='cuda:0') train_epoch_loss=tensor(0.0026, device='cuda:0') eval_ppl=tensor(1.0026, device='cuda:0') eval_epoch_loss=tensor(0.0026, device='cuda:0')\n"
1006
+ ]
1007
+ },
1008
+ {
1009
+ "name": "stderr",
1010
+ "output_type": "stream",
1011
+ "text": [
1012
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
1013
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.33it/s]\n"
1014
+ ]
1015
+ },
1016
+ {
1017
+ "name": "stdout",
1018
+ "output_type": "stream",
1019
+ "text": [
1020
+ "epoch=41: train_ppl=tensor(1.0025, device='cuda:0') train_epoch_loss=tensor(0.0025, device='cuda:0') eval_ppl=tensor(1.0025, device='cuda:0') eval_epoch_loss=tensor(0.0025, device='cuda:0')\n"
1021
+ ]
1022
+ },
1023
+ {
1024
+ "name": "stderr",
1025
+ "output_type": "stream",
1026
+ "text": [
1027
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.42it/s]\n",
1028
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.49it/s]\n"
1029
+ ]
1030
+ },
1031
+ {
1032
+ "name": "stdout",
1033
+ "output_type": "stream",
1034
+ "text": [
1035
+ "epoch=42: train_ppl=tensor(1.0024, device='cuda:0') train_epoch_loss=tensor(0.0024, device='cuda:0') eval_ppl=tensor(1.0025, device='cuda:0') eval_epoch_loss=tensor(0.0025, device='cuda:0')\n"
1036
+ ]
1037
+ },
1038
+ {
1039
+ "name": "stderr",
1040
+ "output_type": "stream",
1041
+ "text": [
1042
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
1043
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
1044
+ ]
1045
+ },
1046
+ {
1047
+ "name": "stdout",
1048
+ "output_type": "stream",
1049
+ "text": [
1050
+ "epoch=43: train_ppl=tensor(1.0024, device='cuda:0') train_epoch_loss=tensor(0.0024, device='cuda:0') eval_ppl=tensor(1.0025, device='cuda:0') eval_epoch_loss=tensor(0.0025, device='cuda:0')\n"
1051
+ ]
1052
+ },
1053
+ {
1054
+ "name": "stderr",
1055
+ "output_type": "stream",
1056
+ "text": [
1057
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
1058
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.43it/s]\n"
1059
+ ]
1060
+ },
1061
+ {
1062
+ "name": "stdout",
1063
+ "output_type": "stream",
1064
+ "text": [
1065
+ "epoch=44: train_ppl=tensor(1.0025, device='cuda:0') train_epoch_loss=tensor(0.0024, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
1066
+ ]
1067
+ },
1068
+ {
1069
+ "name": "stderr",
1070
+ "output_type": "stream",
1071
+ "text": [
1072
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
1073
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.50it/s]\n"
1074
+ ]
1075
+ },
1076
+ {
1077
+ "name": "stdout",
1078
+ "output_type": "stream",
1079
+ "text": [
1080
+ "epoch=45: train_ppl=tensor(1.0024, device='cuda:0') train_epoch_loss=tensor(0.0024, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
1081
+ ]
1082
+ },
1083
+ {
1084
+ "name": "stderr",
1085
+ "output_type": "stream",
1086
+ "text": [
1087
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
1088
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.49it/s]\n"
1089
+ ]
1090
+ },
1091
+ {
1092
+ "name": "stdout",
1093
+ "output_type": "stream",
1094
+ "text": [
1095
+ "epoch=46: train_ppl=tensor(1.0024, device='cuda:0') train_epoch_loss=tensor(0.0024, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
1096
+ ]
1097
+ },
1098
+ {
1099
+ "name": "stderr",
1100
+ "output_type": "stream",
1101
+ "text": [
1102
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.42it/s]\n",
1103
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.39it/s]\n"
1104
+ ]
1105
+ },
1106
+ {
1107
+ "name": "stdout",
1108
+ "output_type": "stream",
1109
+ "text": [
1110
+ "epoch=47: train_ppl=tensor(1.0023, device='cuda:0') train_epoch_loss=tensor(0.0023, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
1111
+ ]
1112
+ },
1113
+ {
1114
+ "name": "stderr",
1115
+ "output_type": "stream",
1116
+ "text": [
1117
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.40it/s]\n",
1118
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.40it/s]\n"
1119
+ ]
1120
+ },
1121
+ {
1122
+ "name": "stdout",
1123
+ "output_type": "stream",
1124
+ "text": [
1125
+ "epoch=48: train_ppl=tensor(1.0023, device='cuda:0') train_epoch_loss=tensor(0.0023, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
1126
+ ]
1127
+ },
1128
+ {
1129
+ "name": "stderr",
1130
+ "output_type": "stream",
1131
+ "text": [
1132
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.41it/s]\n",
1133
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.87it/s]"
1134
+ ]
1135
+ },
1136
+ {
1137
+ "name": "stdout",
1138
+ "output_type": "stream",
1139
+ "text": [
1140
+ "epoch=49: train_ppl=tensor(1.0023, device='cuda:0') train_epoch_loss=tensor(0.0023, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
1141
+ ]
1142
+ },
1143
+ {
1144
+ "name": "stderr",
1145
+ "output_type": "stream",
1146
+ "text": [
1147
+ "\n"
1148
+ ]
1149
+ }
1150
+ ],
1151
+ "source": [
1152
+ "# training and evaluation\n",
1153
+ "model = model.to(device)\n",
1154
+ "\n",
1155
+ "for epoch in range(num_epochs):\n",
1156
+ " model.train()\n",
1157
+ " total_loss = 0\n",
1158
+ " for step, batch in enumerate(tqdm(train_dataloader)):\n",
1159
+ " batch = {k: v.to(device) for k, v in batch.items()}\n",
1160
+ " # print(batch)\n",
1161
+ " # print(batch[\"input_ids\"].shape)\n",
1162
+ " outputs = model(**batch)\n",
1163
+ " loss = outputs.loss\n",
1164
+ " total_loss += loss.detach().float()\n",
1165
+ " loss.backward()\n",
1166
+ " optimizer.step()\n",
1167
+ " lr_scheduler.step()\n",
1168
+ " optimizer.zero_grad()\n",
1169
+ "\n",
1170
+ " model.eval()\n",
1171
+ " eval_loss = 0\n",
1172
+ " eval_preds = []\n",
1173
+ " for step, batch in enumerate(tqdm(eval_dataloader)):\n",
1174
+ " batch = {k: v.to(device) for k, v in batch.items()}\n",
1175
+ " with torch.no_grad():\n",
1176
+ " outputs = model(**batch)\n",
1177
+ " loss = outputs.loss\n",
1178
+ " eval_loss += loss.detach().float()\n",
1179
+ " eval_preds.extend(\n",
1180
+ " tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)\n",
1181
+ " )\n",
1182
+ "\n",
1183
+ " eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
1184
+ " eval_ppl = torch.exp(eval_epoch_loss)\n",
1185
+ " train_epoch_loss = total_loss / len(train_dataloader)\n",
1186
+ " train_ppl = torch.exp(train_epoch_loss)\n",
1187
+ " print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
1188
+ ]
1189
+ },
1190
+ {
1191
+ "cell_type": "code",
1192
+ "execution_count": 36,
1193
+ "id": "53752a7b",
1194
+ "metadata": {},
1195
+ "outputs": [
1196
+ {
1197
+ "name": "stdout",
1198
+ "output_type": "stream",
1199
+ "text": [
1200
+ "Hey @nytimes your link to cancel my subscription isn't working and nobody is answering the chat. Please don't play that kind of stupid game.\n",
1201
+ "{'input_ids': tensor([[227985, 5484, 915, 54078, 2566, 7782, 24502, 2632, 8989,\n",
1202
+ " 427, 36992, 2670, 140711, 21994, 10789, 530, 88399, 632,\n",
1203
+ " 183542, 368, 44799, 17, 29901, 5926, 7229, 861, 11596,\n",
1204
+ " 461, 78851, 14775, 17, 77658, 915, 210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
1205
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
1206
+ "tensor([[227985, 5484, 915, 54078, 2566, 7782, 24502, 2632, 8989,\n",
1207
+ " 427, 36992, 2670, 140711, 21994, 10789, 530, 88399, 632,\n",
1208
+ " 183542, 368, 44799, 17, 29901, 5926, 7229, 861, 11596,\n",
1209
+ " 461, 78851, 14775, 17, 77658, 915, 210, 16449, 5952,\n",
1210
+ " 3]], device='cuda:0')\n",
1211
+ "[\"Tweet text : Hey @nytimes your link to cancel my subscription isn't working and nobody is answering the chat. Please don't play that kind of stupid game. Label : complaint\"]\n"
1212
+ ]
1213
+ }
1214
+ ],
1215
+ "source": [
1216
+ "model.eval()\n",
1217
+ "i = 16\n",
1218
+ "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
1219
+ "print(dataset[\"test\"][i][\"Tweet text\"])\n",
1220
+ "print(inputs)\n",
1221
+ "\n",
1222
+ "with torch.no_grad():\n",
1223
+ " inputs = {k: v.to(device) for k, v in inputs.items()}\n",
1224
+ " outputs = model.generate(\n",
1225
+ " input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
1226
+ " )\n",
1227
+ " print(outputs)\n",
1228
+ " print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
1229
+ ]
1230
+ },
1231
+ {
1232
+ "cell_type": "markdown",
1233
+ "id": "0e21c49b",
1234
+ "metadata": {},
1235
+ "source": [
1236
+ "You can push model to hub or save model locally. \n",
1237
+ "\n",
1238
+ "- Option1: Pushing the model to Hugging Face Hub\n",
1239
+ "```python\n",
1240
+ "model.push_to_hub(\n",
1241
+ " f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\"),\n",
1242
+ " token = \"hf_...\"\n",
1243
+ ")\n",
1244
+ "```\n",
1245
+ "token (`bool` or `str`, *optional*):\n",
1246
+ " `token` is to be used for HTTP Bearer authorization when accessing remote files. If `True`, will use the token generated\n",
1247
+ " when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`\n",
1248
+ " is not specified.\n",
1249
+ " Or you can get your token from https://huggingface.co/settings/token\n",
1250
+ "```\n",
1251
+ "- Or save model locally\n",
1252
+ "```python\n",
1253
+ "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\")\n",
1254
+ "model.save_pretrained(peft_model_id)\n",
1255
+ "```"
1256
+ ]
1257
+ },
1258
+ {
1259
+ "cell_type": "code",
1260
+ "execution_count": 16,
1261
+ "id": "24041ee1",
1262
+ "metadata": {},
1263
+ "outputs": [],
1264
+ "source": [
1265
+ "# saving model\n",
1266
+ "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\n",
1267
+ " \"/\", \"_\"\n",
1268
+ ")\n",
1269
+ "model.save_pretrained(peft_model_id)"
1270
+ ]
1271
+ },
1272
+ {
1273
+ "cell_type": "code",
1274
+ "execution_count": null,
1275
+ "id": "527eeaa4",
1276
+ "metadata": {},
1277
+ "outputs": [],
1278
+ "source": [
1279
+ "ckpt = f\"{peft_model_id}/adapter_model.bin\"\n",
1280
+ "!du -h $ckpt"
1281
+ ]
1282
+ },
1283
+ {
1284
+ "cell_type": "code",
1285
+ "execution_count": 18,
1286
+ "id": "b19f5a90",
1287
+ "metadata": {},
1288
+ "outputs": [],
1289
+ "source": [
1290
+ "from peft import PeftModel, PeftConfig\n",
1291
+ "\n",
1292
+ "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\n",
1293
+ " \"/\", \"_\"\n",
1294
+ ")\n",
1295
+ "\n",
1296
+ "config = PeftConfig.from_pretrained(peft_model_id)\n",
1297
+ "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
1298
+ "model = PeftModel.from_pretrained(model, peft_model_id)"
1299
+ ]
1300
+ },
1301
+ {
1302
+ "cell_type": "code",
1303
+ "execution_count": 21,
1304
+ "id": "a11a3768",
1305
+ "metadata": {},
1306
+ "outputs": [
1307
+ {
1308
+ "name": "stdout",
1309
+ "output_type": "stream",
1310
+ "text": [
1311
+ "@greateranglia Ok thanks...\n",
1312
+ "{'input_ids': tensor([[227985, 5484, 915, 2566, 14173, 2960, 29906, 387, 20706,\n",
1313
+ " 49337, 1369, 77658, 915, 210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
1314
+ "tensor([[227985, 5484, 915, 2566, 14173, 2960, 29906, 387, 20706,\n",
1315
+ " 49337, 1369, 77658, 915, 210, 1936, 106863, 3]],\n",
1316
+ " device='cuda:0')\n",
1317
+ "['Tweet text : @greateranglia Ok thanks... Label : no complaint']\n"
1318
+ ]
1319
+ }
1320
+ ],
1321
+ "source": [
1322
+ "model.to(device)\n",
1323
+ "model.eval()\n",
1324
+ "i = 4\n",
1325
+ "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
1326
+ "print(dataset[\"test\"][i][\"Tweet text\"])\n",
1327
+ "print(inputs)\n",
1328
+ "\n",
1329
+ "with torch.no_grad():\n",
1330
+ " inputs = {k: v.to(device) for k, v in inputs.items()}\n",
1331
+ " outputs = model.generate(\n",
1332
+ " input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
1333
+ " )\n",
1334
+ " print(outputs)\n",
1335
+ " print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
1336
+ ]
1337
+ },
1338
+ {
1339
+ "cell_type": "code",
1340
+ "execution_count": null,
1341
+ "id": "f890c951",
1342
+ "metadata": {},
1343
+ "outputs": [],
1344
+ "source": []
1345
+ },
1346
+ {
1347
+ "cell_type": "code",
1348
+ "execution_count": null,
1349
+ "id": "463a41a2",
1350
+ "metadata": {},
1351
+ "outputs": [],
1352
+ "source": []
1353
+ },
1354
+ {
1355
+ "cell_type": "code",
1356
+ "execution_count": null,
1357
+ "id": "5c60c7a9",
1358
+ "metadata": {},
1359
+ "outputs": [],
1360
+ "source": []
1361
+ }
1362
+ ],
1363
+ "metadata": {
1364
+ "kernelspec": {
1365
+ "display_name": "Python 3 (ipykernel)",
1366
+ "language": "python",
1367
+ "name": "python3"
1368
+ },
1369
+ "language_info": {
1370
+ "codemirror_mode": {
1371
+ "name": "ipython",
1372
+ "version": 3
1373
+ },
1374
+ "file_extension": ".py",
1375
+ "mimetype": "text/x-python",
1376
+ "name": "python",
1377
+ "nbconvert_exporter": "python",
1378
+ "pygments_lexer": "ipython3",
1379
+ "version": "3.10.5"
1380
+ },
1381
+ "vscode": {
1382
+ "interpreter": {
1383
+ "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
1384
+ }
1385
+ }
1386
+ },
1387
+ "nbformat": 4,
1388
+ "nbformat_minor": 5
1389
+ }