Xuuuuuuuu committed on
Commit 3e66b7c · verified · 1 Parent(s): 3c99da8

Upload ft-test.ipynb

Files changed (1)
  1. ft-test.ipynb +931 -0
ft-test.ipynb ADDED
@@ -0,0 +1,931 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "be94e6d6-4096-4d1a-aa58-5afd89f33bff",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Fine-tuning Sandbox\n",
9
+ "\n",
10
+ "Code authored by: Shawhin Talebi <br>\n",
11
+ "Blog link: https://medium.com/towards-data-science/fine-tuning-large-language-models-llms-23473d763b91"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 1,
17
+ "id": "4ef8ea85-d04d-4217-99a3-21c446bf2ffa",
18
+ "metadata": {},
19
+ "outputs": [
20
+ {
21
+ "name": "stdout",
22
+ "output_type": "stream",
23
+ "text": [
24
+ "WARNING:tensorflow:From C:\\Users\\Administrator\\AppData\\Roaming\\Python\\Python39\\site-packages\\keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n",
25
+ "\n"
26
+ ]
27
+ }
28
+ ],
29
+ "source": [
30
+ "from datasets import load_dataset, DatasetDict, Dataset\n",
31
+ "\n",
32
+ "from transformers import (\n",
33
+ " AutoTokenizer,\n",
34
+ " AutoConfig, \n",
35
+ " AutoModelForSequenceClassification,\n",
36
+ " DataCollatorWithPadding,\n",
37
+ " TrainingArguments,\n",
38
+ " Trainer)\n",
39
+ "# PEFT的全称是Parameter-Efficient Fine-Tuning,是transform开发的一个参数高效微调的库\n",
40
+ "from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig\n",
41
+ "import evaluate\n",
42
+ "import torch\n",
43
+ "import numpy as np"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "id": "aa6a4484-07d8-49dd-81ef-672105f53ebe",
49
+ "metadata": {},
50
+ "source": [
51
+ "### dataset"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 2,
57
+ "id": "fa9722d3-0609-4aea-9585-9aa2cfc1fc9a",
58
+ "metadata": {
59
+ "jupyter": {
60
+ "source_hidden": true
61
+ }
62
+ },
63
+ "outputs": [],
64
+ "source": [
65
+ "# # how dataset was generated\n",
66
+ "\n",
67
+ "# # load imdb data\n",
68
+ "# imdb_dataset = load_dataset(\"imdb\")\n",
69
+ "\n",
70
+ "# # define subsample size\n",
71
+ "# N = 1000 \n",
72
+ "# # generate indexes for random subsample\n",
73
+ "# rand_idx = np.random.randint(24999, size=N) \n",
74
+ "\n",
75
+ "# # extract train and test data\n",
76
+ "# x_train = imdb_dataset['train'][rand_idx]['text']\n",
77
+ "# y_train = imdb_dataset['train'][rand_idx]['label']\n",
78
+ "\n",
79
+ "# x_test = imdb_dataset['test'][rand_idx]['text']\n",
80
+ "# y_test = imdb_dataset['test'][rand_idx]['label']\n",
81
+ "\n",
82
+ "# # create new dataset\n",
83
+ "# dataset = DatasetDict({'train':Dataset.from_dict({'label':y_train,'text':x_train}),\n",
84
+ "# 'validation':Dataset.from_dict({'label':y_test,'text':x_test})})"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 3,
90
+ "id": "de226234-c521-4577-802c-0e7079ef4364",
91
+ "metadata": {},
92
+ "outputs": [
93
+ {
94
+ "data": {
95
+ "text/plain": [
96
+ "DatasetDict({\n",
97
+ " train: Dataset({\n",
98
+ " features: ['label', 'text'],\n",
99
+ " num_rows: 1000\n",
100
+ " })\n",
101
+ " validation: Dataset({\n",
102
+ " features: ['label', 'text'],\n",
103
+ " num_rows: 1000\n",
104
+ " })\n",
105
+ " test: Dataset({\n",
106
+ " features: ['label', 'text'],\n",
107
+ " num_rows: 1000\n",
108
+ " })\n",
109
+ "})"
110
+ ]
111
+ },
112
+ "execution_count": 3,
113
+ "metadata": {},
114
+ "output_type": "execute_result"
115
+ }
116
+ ],
117
+ "source": [
118
+ "# 加载数据集 训练 验证 测试\n",
119
+ "dataset = load_dataset('shawhin/imdb-truncated')\n",
120
+ "dataset"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 4,
126
+ "id": "d5625faa-5fea-4334-bd38-b77de983d8a8",
127
+ "metadata": {},
128
+ "outputs": [
129
+ {
130
+ "data": {
131
+ "text/plain": [
132
+ "0.5"
133
+ ]
134
+ },
135
+ "execution_count": 4,
136
+ "metadata": {},
137
+ "output_type": "execute_result"
138
+ }
139
+ ],
140
+ "source": [
141
+ "# 得出训练集标签的平均值\n",
142
+ "np.array(dataset['train']['label']).sum()/len(dataset['train']['label'])"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "markdown",
147
+ "id": "3644c68d-9adf-48a4-90a2-8fd89555a302",
148
+ "metadata": {},
149
+ "source": [
150
+ "### model"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 5,
156
+ "id": "a60dd1fe-8144-4678-b018-20891e49237a",
157
+ "metadata": {},
158
+ "outputs": [
159
+ {
160
+ "name": "stderr",
161
+ "output_type": "stream",
162
+ "text": [
163
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
164
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
165
+ ]
166
+ }
167
+ ],
168
+ "source": [
169
+ "model_checkpoint = 'distilbert-base-uncased'\n",
170
+ "\n",
171
+ "# 类别的映射关系\n",
172
+ "id2label = {0: \"Negative\", 1: \"Positive\"}\n",
173
+ "label2id = {\"Negative\":0, \"Positive\":1}\n",
174
+ "\n",
175
+ "# 加载预训练的权重 num_labels指明是二分类任务 model_checkpoint 预训练模型的名称\n",
176
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
177
+ " model_checkpoint, num_labels=2, id2label=id2label, label2id=label2id)"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": 6,
183
+ "id": "853002f8-d39c-4bc4-8d07-e44a47de3b47",
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": [
187
+ "# display architecture\n",
188
+ "model = model.cuda()"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "markdown",
193
+ "id": "4bc98609-873d-455c-bac4-155632cda484",
194
+ "metadata": {},
195
+ "source": [
196
+ "### 预处理数据"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "raw",
201
+ "id": "93e728f3-9e12-400d-950e-f7f2e29fe19e",
202
+ "metadata": {},
203
+ "source": [
204
+ "add_prefix_space参数告诉 tokenizer 在处理单词和标点符号之间添加一个前缀空格 前缀空格(表示为 Ġ)\n",
205
+ "# 原始句子\n",
206
+ "sentence = \"Hello, world!\"\n",
207
+ "['ĠHello', ',', 'Ġworld', '!']"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": 7,
213
+ "id": "7fe08707-657f-4e66-aa72-84899c54bf8d",
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": [
217
+ "# 创建分词器\n",
218
+ "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)\n",
219
+ "\n",
220
+ "# 判断是否有填充标记 通过 resize_token_embeddings 方法调整模型的 token embeddings,以包含新添加的 pad token。\n",
221
+ "if tokenizer.pad_token is None:\n",
222
+ " tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
223
+ " model.resize_token_embeddings(len(tokenizer))"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": 8,
229
+ "id": "20f4adb9-ce8f-4f54-9b94-300c9daae1b8",
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "# 创建分词器函数\n",
234
+ "def tokenize_function(examples):\n",
235
+ " # 提取文本\n",
236
+ " text = examples[\"text\"]\n",
237
+ "\n",
238
+ " # 设置 tokenizer 的截断位置为左侧。这意味着如果文本超过指定的 max_length,则在左侧截断。这是为了确保重要的文本内容被保留下来。\n",
239
+ " tokenizer.truncation_side = \"left\"\n",
240
+ " tokenized_inputs = tokenizer(\n",
241
+ " text,\n",
242
+ " # 返回numpy 类型\n",
243
+ " return_tensors=\"np\",\n",
244
+ " # 是否进行文本截断\n",
245
+ " truncation=True,\n",
246
+ " max_length=512\n",
247
+ " )\n",
248
+ "\n",
249
+ " return tokenized_inputs"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": 9,
255
+ "id": "b7600bcd-7e93-4fb4-bd8d-ffc76bed1ac2",
256
+ "metadata": {},
257
+ "outputs": [
258
+ {
259
+ "data": {
260
+ "application/vnd.jupyter.widget-view+json": {
261
+ "model_id": "c029f605df0e4e3c9484aa97af255052",
262
+ "version_major": 2,
263
+ "version_minor": 0
264
+ },
265
+ "text/plain": [
266
+ "Map: 0%| | 0/1000 [00:00<?, ? examples/s]"
267
+ ]
268
+ },
269
+ "metadata": {},
270
+ "output_type": "display_data"
271
+ },
272
+ {
273
+ "data": {
274
+ "text/plain": [
275
+ "DatasetDict({\n",
276
+ " train: Dataset({\n",
277
+ " features: ['label', 'text', 'input_ids', 'attention_mask'],\n",
278
+ " num_rows: 1000\n",
279
+ " })\n",
280
+ " validation: Dataset({\n",
281
+ " features: ['label', 'text', 'input_ids', 'attention_mask'],\n",
282
+ " num_rows: 1000\n",
283
+ " })\n",
284
+ " test: Dataset({\n",
285
+ " features: ['label', 'text', 'input_ids', 'attention_mask'],\n",
286
+ " num_rows: 1000\n",
287
+ " })\n",
288
+ "})"
289
+ ]
290
+ },
291
+ "execution_count": 9,
292
+ "metadata": {},
293
+ "output_type": "execute_result"
294
+ }
295
+ ],
296
+ "source": [
297
+ "# tokenize training and validation datasets\n",
298
+ "tokenized_dataset = dataset.map(tokenize_function, batched=True)\n",
299
+ "tokenized_dataset"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": 10,
305
+ "id": "3f8e85f9-1804-4f49-a783-4da59580ea1e",
306
+ "metadata": {},
307
+ "outputs": [],
308
+ "source": [
309
+ "# 创建数据收集器\n",
310
+ "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "markdown",
315
+ "id": "3cd9a120-580d-470c-a981-7c7e22604865",
316
+ "metadata": {},
317
+ "source": [
318
+ "### evaluation"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": 11,
324
+ "id": "2a894819-2e9c-4a53-9790-32130c182bca",
325
+ "metadata": {},
326
+ "outputs": [
327
+ {
328
+ "name": "stderr",
329
+ "output_type": "stream",
330
+ "text": [
331
+ "Using the latest cached version of the module from C:\\Users\\Administrator\\.cache\\huggingface\\modules\\evaluate_modules\\metrics\\evaluate-metric--accuracy\\f887c0aab52c2d38e1f8a215681126379eca617f96c447638f751434e8e65b14 (last modified on Fri Mar 15 09:54:33 2024) since it couldn't be found locally at evaluate-metric--accuracy, or remotely on the Hugging Face Hub.\n"
332
+ ]
333
+ }
334
+ ],
335
+ "source": [
336
+ "# import accuracy evaluation metric\n",
337
+ "accuracy = evaluate.load(\"accuracy\")"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": 12,
343
+ "id": "c07b9be2-a3f6-4b38-b9e8-6a2bc8aa945a",
344
+ "metadata": {},
345
+ "outputs": [],
346
+ "source": [
347
+ "# define an evaluation function to pass into trainer later\n",
348
+ "def compute_metrics(p):\n",
349
+ " predictions, labels = p\n",
350
+ " predictions = np.argmax(predictions, axis=1)\n",
351
+ " # 计算预测结果和真实标签 返回准确率\n",
352
+ " return {\"accuracy\": accuracy.compute(predictions=predictions, references=labels)}"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "markdown",
357
+ "id": "47500035-a555-46e0-83dc-440586d96b7e",
358
+ "metadata": {},
359
+ "source": [
360
+ "### Apply untrained model to text"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": 13,
366
+ "id": "8f3761c1-a297-45c8-882e-d74856259810",
367
+ "metadata": {},
368
+ "outputs": [
369
+ {
370
+ "name": "stdout",
371
+ "output_type": "stream",
372
+ "text": [
373
+ "Untrained model predictions:\n",
374
+ "----------------------------\n",
375
+ "I'm sorry. - Negative\n",
376
+ "You areedespicable person - Negative\n",
377
+ "Better than the first one. - Negative\n",
378
+ "This is not worth watching even once. - Negative\n",
379
+ "This one is a pass. - Negative\n"
380
+ ]
381
+ }
382
+ ],
383
+ "source": [
384
+ "# define list of examples\n",
385
+ "text_list = [\"I'm sorry.\", \"You areedespicable person\", \"Better than the first one.\", \"This is not worth watching even once.\", \"This one is a pass.\"]\n",
386
+ "\n",
387
+ "print(\"Untrained model predictions:\")\n",
388
+ "print(\"----------------------------\")\n",
389
+ "for text in text_list:\n",
390
+ " # 将文本转化为可以理解的编码 并返回pytorch张量\n",
391
+ " inputs = tokenizer.encode(text, return_tensors=\"pt\")\n",
392
+ " # 计算对数\n",
393
+ " logits = model(inputs.cuda()).logits\n",
394
+ " # convert logits to label\n",
395
+ " predictions = torch.argmax(logits)\n",
396
+ "\n",
397
+ " print(text + \" - \" + id2label[predictions.tolist()])"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "markdown",
402
+ "id": "ff356f78-c9fd-4f2b-8f5b-097cf29c1c08",
403
+ "metadata": {},
404
+ "source": [
405
+ "### Train model"
406
+ ]
407
+ },
408
+ {
409
+ "cell_type": "code",
410
+ "execution_count": 14,
411
+ "id": "e4dde538-cd7f-4ab5-a96d-c30f3003822e",
412
+ "metadata": {},
413
+ "outputs": [],
414
+ "source": [
415
+ "peft_config = LoraConfig(task_type=\"SEQ_CLS\", # 序列分类任务\n",
416
+ " r = 4, # 递归深度\n",
417
+ " lora_alpha = 32, # alpha 值表示 LORA 模块的影响更大。\n",
418
+ " lora_dropout = 0.01,\n",
419
+ " target_modules = ['q_lin'])"
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "execution_count": 15,
425
+ "id": "f1391303-1e16-4d5c-b2b4-799997eff9f8",
426
+ "metadata": {},
427
+ "outputs": [
428
+ {
429
+ "data": {
430
+ "text/plain": [
431
+ "LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type='SEQ_CLS', inference_mode=False, r=4, target_modules={'q_lin'}, lora_alpha=32, lora_dropout=0.01, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False)"
432
+ ]
433
+ },
434
+ "execution_count": 15,
435
+ "metadata": {},
436
+ "output_type": "execute_result"
437
+ }
438
+ ],
439
+ "source": [
440
+ "peft_config"
441
+ ]
442
+ },
443
+ {
444
+ "cell_type": "code",
445
+ "execution_count": 16,
446
+ "id": "3e0d9408-9fc4-4bd3-8d35-4d8217fe01e2",
447
+ "metadata": {},
448
+ "outputs": [
449
+ {
450
+ "name": "stdout",
451
+ "output_type": "stream",
452
+ "text": [
453
+ "trainable params: 628,994 || all params: 67,584,004 || trainable%: 0.9306847223789819\n"
454
+ ]
455
+ }
456
+ ],
457
+ "source": [
458
+ "# 对模型进行配置\n",
459
+ "model = get_peft_model(model, peft_config)\n",
460
+ "model.print_trainable_parameters()"
461
+ ]
462
+ },
463
+ {
464
+ "cell_type": "code",
465
+ "execution_count": 17,
466
+ "id": "5db78059-e5ae-4807-89db-b58ef6abedd1",
467
+ "metadata": {},
468
+ "outputs": [],
469
+ "source": [
470
+ "# hyperparameters\n",
471
+ "lr = 1e-3\n",
472
+ "batch_size = 4\n",
473
+ "num_epochs = 10"
474
+ ]
475
+ },
476
+ {
477
+ "cell_type": "code",
478
+ "execution_count": 18,
479
+ "id": "9244ed55-65a4-4c66-8388-55efd87bceb8",
480
+ "metadata": {},
481
+ "outputs": [],
482
+ "source": [
483
+ "# define training arguments\n",
484
+ "training_args = TrainingArguments(\n",
485
+ " output_dir= model_checkpoint + \"-lora-text-classification\",\n",
486
+ " learning_rate=lr,\n",
487
+ " per_device_train_batch_size=batch_size,\n",
488
+ " per_device_eval_batch_size=batch_size,\n",
489
+ " num_train_epochs=num_epochs,\n",
490
+ " weight_decay=0.01, # 权重衰减,一种正则化技术,用于控制模型参数的大小。\n",
491
+ " evaluation_strategy=\"epoch\",\n",
492
+ " save_strategy=\"epoch\",\n",
493
+ " load_best_model_at_end=True, # 是否在训练结束加载最佳模型\n",
494
+ ")"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "markdown",
499
+ "id": "6e21aa23-a366-4606-b13b-ad22e4639272",
500
+ "metadata": {},
501
+ "source": [
502
+ "### "
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "code",
507
+ "execution_count": 19,
508
+ "id": "fc8bc705-5dd7-4305-a797-399b2b0fa2c7",
509
+ "metadata": {},
510
+ "outputs": [
511
+ {
512
+ "name": "stderr",
513
+ "output_type": "stream",
514
+ "text": [
515
+ "D:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\accelerate\\accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
516
+ "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
517
+ " warnings.warn(\n",
518
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m1321416285\u001b[0m (\u001b[33mxuuuu\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
519
+ ]
520
+ },
521
+ {
522
+ "data": {
523
+ "text/html": [
524
+ "wandb version 0.16.4 is available! To upgrade, please run:\n",
525
+ " $ pip install wandb --upgrade"
526
+ ],
527
+ "text/plain": [
528
+ "<IPython.core.display.HTML object>"
529
+ ]
530
+ },
531
+ "metadata": {},
532
+ "output_type": "display_data"
533
+ },
534
+ {
535
+ "data": {
536
+ "text/html": [
537
+ "Tracking run with wandb version 0.15.12"
538
+ ],
539
+ "text/plain": [
540
+ "<IPython.core.display.HTML object>"
541
+ ]
542
+ },
543
+ "metadata": {},
544
+ "output_type": "display_data"
545
+ },
546
+ {
547
+ "data": {
548
+ "text/html": [
549
+ "Run data is saved locally in <code>D:\\software\\Anaconda\\jupyterfile\\AIfinetuning\\wandb\\run-20240315_211852-07azjtzv</code>"
550
+ ],
551
+ "text/plain": [
552
+ "<IPython.core.display.HTML object>"
553
+ ]
554
+ },
555
+ "metadata": {},
556
+ "output_type": "display_data"
557
+ },
558
+ {
559
+ "data": {
560
+ "text/html": [
561
+ "Syncing run <strong><a href='https://wandb.ai/xuuuu/huggingface/runs/07azjtzv' target=\"_blank\">fast-firefly-2</a></strong> to <a href='https://wandb.ai/xuuuu/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
562
+ ],
563
+ "text/plain": [
564
+ "<IPython.core.display.HTML object>"
565
+ ]
566
+ },
567
+ "metadata": {},
568
+ "output_type": "display_data"
569
+ },
570
+ {
571
+ "data": {
572
+ "text/html": [
573
+ " View project at <a href='https://wandb.ai/xuuuu/huggingface' target=\"_blank\">https://wandb.ai/xuuuu/huggingface</a>"
574
+ ],
575
+ "text/plain": [
576
+ "<IPython.core.display.HTML object>"
577
+ ]
578
+ },
579
+ "metadata": {},
580
+ "output_type": "display_data"
581
+ },
582
+ {
583
+ "data": {
584
+ "text/html": [
585
+ " View run at <a href='https://wandb.ai/xuuuu/huggingface/runs/07azjtzv' target=\"_blank\">https://wandb.ai/xuuuu/huggingface/runs/07azjtzv</a>"
586
+ ],
587
+ "text/plain": [
588
+ "<IPython.core.display.HTML object>"
589
+ ]
590
+ },
591
+ "metadata": {},
592
+ "output_type": "display_data"
593
+ },
594
+ {
595
+ "data": {
596
+ "text/html": [
597
+ "\n",
598
+ " <div>\n",
599
+ " \n",
600
+ " <progress value='2500' max='2500' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
601
+ " [2500/2500 02:44, Epoch 10/10]\n",
602
+ " </div>\n",
603
+ " <table border=\"1\" class=\"dataframe\">\n",
604
+ " <thead>\n",
605
+ " <tr style=\"text-align: left;\">\n",
606
+ " <th>Epoch</th>\n",
607
+ " <th>Training Loss</th>\n",
608
+ " <th>Validation Loss</th>\n",
609
+ " <th>Accuracy</th>\n",
610
+ " </tr>\n",
611
+ " </thead>\n",
612
+ " <tbody>\n",
613
+ " <tr>\n",
614
+ " <td>1</td>\n",
615
+ " <td>No log</td>\n",
616
+ " <td>0.438809</td>\n",
617
+ " <td>{'accuracy': 0.855}</td>\n",
618
+ " </tr>\n",
619
+ " <tr>\n",
620
+ " <td>2</td>\n",
621
+ " <td>0.427600</td>\n",
622
+ " <td>0.648398</td>\n",
623
+ " <td>{'accuracy': 0.859}</td>\n",
624
+ " </tr>\n",
625
+ " <tr>\n",
626
+ " <td>3</td>\n",
627
+ " <td>0.427600</td>\n",
628
+ " <td>0.637398</td>\n",
629
+ " <td>{'accuracy': 0.877}</td>\n",
630
+ " </tr>\n",
631
+ " <tr>\n",
632
+ " <td>4</td>\n",
633
+ " <td>0.218100</td>\n",
634
+ " <td>0.689158</td>\n",
635
+ " <td>{'accuracy': 0.889}</td>\n",
636
+ " </tr>\n",
637
+ " <tr>\n",
638
+ " <td>5</td>\n",
639
+ " <td>0.218100</td>\n",
640
+ " <td>0.774748</td>\n",
641
+ " <td>{'accuracy': 0.897}</td>\n",
642
+ " </tr>\n",
643
+ " <tr>\n",
644
+ " <td>6</td>\n",
645
+ " <td>0.073100</td>\n",
646
+ " <td>0.846054</td>\n",
647
+ " <td>{'accuracy': 0.887}</td>\n",
648
+ " </tr>\n",
649
+ " <tr>\n",
650
+ " <td>7</td>\n",
651
+ " <td>0.073100</td>\n",
652
+ " <td>0.946100</td>\n",
653
+ " <td>{'accuracy': 0.894}</td>\n",
654
+ " </tr>\n",
655
+ " <tr>\n",
656
+ " <td>8</td>\n",
657
+ " <td>0.015500</td>\n",
658
+ " <td>0.941895</td>\n",
659
+ " <td>{'accuracy': 0.901}</td>\n",
660
+ " </tr>\n",
661
+ " <tr>\n",
662
+ " <td>9</td>\n",
663
+ " <td>0.015500</td>\n",
664
+ " <td>0.994161</td>\n",
665
+ " <td>{'accuracy': 0.898}</td>\n",
666
+ " </tr>\n",
667
+ " <tr>\n",
668
+ " <td>10</td>\n",
669
+ " <td>0.006700</td>\n",
670
+ " <td>0.999837</td>\n",
671
+ " <td>{'accuracy': 0.897}</td>\n",
672
+ " </tr>\n",
673
+ " </tbody>\n",
674
+ "</table><p>"
675
+ ],
676
+ "text/plain": [
677
+ "<IPython.core.display.HTML object>"
678
+ ]
679
+ },
680
+ "metadata": {},
681
+ "output_type": "display_data"
682
+ },
683
+ {
684
+ "name": "stderr",
685
+ "output_type": "stream",
686
+ "text": [
687
+ "Trainer is attempting to log a value of \"{'accuracy': 0.855}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
688
+ "Checkpoint destination directory distilbert-base-uncased-lora-text-classification\\checkpoint-250 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n",
689
+ "Trainer is attempting to log a value of \"{'accuracy': 0.859}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
690
+ "Checkpoint destination directory distilbert-base-uncased-lora-text-classification\\checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n",
691
+ "Trainer is attempting to log a value of \"{'accuracy': 0.877}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
692
+ "Checkpoint destination directory distilbert-base-uncased-lora-text-classification\\checkpoint-750 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n",
693
+ "Trainer is attempting to log a value of \"{'accuracy': 0.889}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
694
+ "Checkpoint destination directory distilbert-base-uncased-lora-text-classification\\checkpoint-1000 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n",
695
+ "Trainer is attempting to log a value of \"{'accuracy': 0.897}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
696
+ "Checkpoint destination directory distilbert-base-uncased-lora-text-classification\\checkpoint-1250 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n",
697
+ "Trainer is attempting to log a value of \"{'accuracy': 0.887}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
698
+ "Checkpoint destination directory distilbert-base-uncased-lora-text-classification\\checkpoint-1500 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n",
699
+ "Trainer is attempting to log a value of \"{'accuracy': 0.894}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
700
+ "Checkpoint destination directory distilbert-base-uncased-lora-text-classification\\checkpoint-1750 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n",
701
+ "Trainer is attempting to log a value of \"{'accuracy': 0.901}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
702
+ "Checkpoint destination directory distilbert-base-uncased-lora-text-classification\\checkpoint-2000 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n",
703
+ "Trainer is attempting to log a value of \"{'accuracy': 0.898}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
704
+ "Checkpoint destination directory distilbert-base-uncased-lora-text-classification\\checkpoint-2250 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n",
705
+ "Trainer is attempting to log a value of \"{'accuracy': 0.897}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
706
+ "Checkpoint destination directory distilbert-base-uncased-lora-text-classification\\checkpoint-2500 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n"
707
+ ]
708
+ },
709
+ {
710
+ "data": {
711
+ "text/plain": [
712
+ "TrainOutput(global_step=2500, training_loss=0.14819346437454223, metrics={'train_runtime': 174.6372, 'train_samples_per_second': 57.262, 'train_steps_per_second': 14.315, 'total_flos': 1112883852759936.0, 'train_loss': 0.14819346437454223, 'epoch': 10.0})"
713
+ ]
714
+ },
715
+ "execution_count": 19,
716
+ "metadata": {},
717
+ "output_type": "execute_result"
718
+ }
719
+ ],
720
+ "source": [
721
+ "# creater trainer object\n",
722
+ "trainer = Trainer(\n",
723
+ " model=model,\n",
724
+ " args=training_args,\n",
725
+ " train_dataset=tokenized_dataset[\"train\"],\n",
726
+ " eval_dataset=tokenized_dataset[\"validation\"],\n",
727
+ " tokenizer=tokenizer,\n",
728
+ " data_collator=data_collator, # this will dynamically pad examples in each batch to be equal length\n",
729
+ " compute_metrics=compute_metrics, \n",
730
+ ")\n",
731
+ "\n",
732
+ "# train model\n",
733
+ "trainer.train()"
734
+ ]
735
+ },
736
+ {
737
+ "cell_type": "markdown",
738
+ "id": "6f5664d1-9bd2-4ce1-bc24-cab5adf80f49",
739
+ "metadata": {},
740
+ "source": [
741
+ "### Generate prediction"
742
+ ]
743
+ },
744
+ {
745
+ "cell_type": "code",
746
+ "execution_count": 20,
747
+ "id": "e5dc029e-1c16-491d-a3f1-715f9e0adf52",
748
+ "metadata": {},
749
+ "outputs": [
750
+ {
751
+ "name": "stdout",
752
+ "output_type": "stream",
753
+ "text": [
754
+ "Trained model predictions:\n",
755
+ "--------------------------\n",
756
+ "I'm sorry. - Negative\n",
757
+ "You areedespicable person - Positive\n",
758
+ "Better than the first one. - Positive\n",
759
+ "This is not worth watching even once. - Negative\n",
760
+ "This one is a pass. - Negative\n"
761
+ ]
762
+ }
763
+ ],
764
+ "source": [
765
+ "model.to('cuda') # moving to mps for Mac (can alternatively do 'cpu')\n",
766
+ "\n",
767
+ "print(\"Trained model predictions:\")\n",
768
+ "print(\"--------------------------\")\n",
769
+ "for text in text_list:\n",
770
+ " inputs = tokenizer.encode(text, return_tensors=\"pt\").to(\"cuda\") # moving to mps for Mac (can alternatively do 'cpu')\n",
771
+ "\n",
772
+ " logits = model(inputs).logits\n",
773
+ " predictions = torch.max(logits,1).indices\n",
774
+ "\n",
775
+ " print(text + \" - \" + id2label[predictions.tolist()[0]])"
776
+ ]
777
+ },
778
+ {
779
+ "cell_type": "markdown",
780
+ "id": "c084bd9e-f7b1-4979-b753-73335ee0cede",
781
+ "metadata": {},
782
+ "source": [
783
+ "### Optional: push model to hub"
784
+ ]
785
+ },
786
+ {
787
+ "cell_type": "code",
788
+ "execution_count": 21,
789
+ "id": "159eb49a-dd0d-4c9e-b9ab-27e06585fd84",
790
+ "metadata": {},
791
+ "outputs": [
792
+ {
793
+ "data": {
794
+ "application/vnd.jupyter.widget-view+json": {
795
+ "model_id": "a0e23e8a27634de78c21c18041cd010f",
796
+ "version_major": 2,
797
+ "version_minor": 0
798
+ },
799
+ "text/plain": [
800
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
801
+ ]
802
+ },
803
+ "metadata": {},
804
+ "output_type": "display_data"
805
+ }
806
+ ],
807
+ "source": [
808
+ "# option 1: notebook login\n",
809
+ "from huggingface_hub import notebook_login\n",
810
+ "notebook_login() # ensure token gives write access\n",
811
+ "\n",
812
+ "# # option 2: key login\n",
813
+ "# from huggingface_hub import login\n",
814
+ "# write_key = 'hf_' # paste token here\n",
815
+ "# login(write_key)"
816
+ ]
817
+ },
818
+ {
819
+ "cell_type": "code",
820
+ "execution_count": 22,
821
+ "id": "09496307-e253-47e3-a46f-3f28a84c89a7",
822
+ "metadata": {},
823
+ "outputs": [],
824
+ "source": [
825
+ "hf_name = 'shawhin' # your hf username or org name\n",
826
+ "model_id = hf_name + \"/\" + model_checkpoint + \"-lora-text-classification\" # you can name the model whatever you want"
827
+ ]
828
+ },
829
+ {
830
+ "cell_type": "code",
831
+ "execution_count": 23,
832
+ "id": "c56ea581-0ea3-45f3-af21-362e9093ee37",
833
+ "metadata": {},
834
+ "outputs": [
835
+ {
836
+ "ename": "HfHubHTTPError",
837
+ "evalue": "403 Client Error: Forbidden for url: https://huggingface.co/shawhin/distilbert-base-uncased-lora-text-classification.git/info/lfs/objects/batch (Request ID: Root=1-65f44b6d-3a7059390bd0f46b3618a6e6;b93e4a6f-c6a2-4179-8d62-ec4b3235048e)\n\nAuthorization error.",
838
+ "output_type": "error",
839
+ "traceback": [
840
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
841
+ "\u001b[1;31mHTTPError\u001b[0m Traceback (most recent call last)",
842
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\huggingface_hub\\utils\\_errors.py:304\u001b[0m, in \u001b[0;36mhf_raise_for_status\u001b[1;34m(response, endpoint_name)\u001b[0m\n\u001b[0;32m 303\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 304\u001b[0m \u001b[43mresponse\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 305\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m HTTPError \u001b[38;5;28;01mas\u001b[39;00m e:\n",
843
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\requests\\models.py:943\u001b[0m, in \u001b[0;36mResponse.raise_for_status\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 942\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m http_error_msg:\n\u001b[1;32m--> 943\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m HTTPError(http_error_msg, response\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m)\n",
844
+ "\u001b[1;31mHTTPError\u001b[0m: 403 Client Error: Forbidden for url: https://huggingface.co/shawhin/distilbert-base-uncased-lora-text-classification.git/info/lfs/objects/batch",
845
+ "\nThe above exception was the direct cause of the following exception:\n",
846
+ "\u001b[1;31mHfHubHTTPError\u001b[0m Traceback (most recent call last)",
847
+ "Cell \u001b[1;32mIn[23], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpush_to_hub\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel_id\u001b[49m\u001b[43m)\u001b[49m\n",
848
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\transformers\\utils\\hub.py:894\u001b[0m, in \u001b[0;36mPushToHubMixin.push_to_hub\u001b[1;34m(self, repo_id, use_temp_dir, commit_message, private, token, max_shard_size, create_pr, safe_serialization, revision, commit_description, tags, **deprecated_kwargs)\u001b[0m\n\u001b[0;32m 891\u001b[0m \u001b[38;5;66;03m# Update model card if needed:\u001b[39;00m\n\u001b[0;32m 892\u001b[0m model_card\u001b[38;5;241m.\u001b[39msave(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(work_dir, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mREADME.md\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m--> 894\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_upload_modified_files\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 895\u001b[0m \u001b[43m \u001b[49m\u001b[43mwork_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 896\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 897\u001b[0m \u001b[43m \u001b[49m\u001b[43mfiles_timestamps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 898\u001b[0m \u001b[43m \u001b[49m\u001b[43mcommit_message\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_message\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 899\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 900\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_pr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcreate_pr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 901\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 902\u001b[0m \u001b[43m \u001b[49m\u001b[43mcommit_description\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_description\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 903\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
849
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\transformers\\utils\\hub.py:758\u001b[0m, in \u001b[0;36mPushToHubMixin._upload_modified_files\u001b[1;34m(self, working_dir, repo_id, files_timestamps, commit_message, token, create_pr, revision, commit_description)\u001b[0m\n\u001b[0;32m 755\u001b[0m create_branch(repo_id\u001b[38;5;241m=\u001b[39mrepo_id, branch\u001b[38;5;241m=\u001b[39mrevision, token\u001b[38;5;241m=\u001b[39mtoken, exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m 757\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUploading the following files to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrepo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m,\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(modified_files)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m--> 758\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcreate_commit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 759\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 760\u001b[0m \u001b[43m \u001b[49m\u001b[43moperations\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moperations\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 761\u001b[0m \u001b[43m \u001b[49m\u001b[43mcommit_message\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_message\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 762\u001b[0m \u001b[43m \u001b[49m\u001b[43mcommit_description\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_description\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 763\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 764\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_pr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcreate_pr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 765\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 766\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
850
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\huggingface_hub\\utils\\_validators.py:118\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 115\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[0;32m 116\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[1;32m--> 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
851
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\huggingface_hub\\hf_api.py:1227\u001b[0m, in \u001b[0;36mfuture_compatible.<locals>._inner\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1224\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_as_future(fn, \u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1226\u001b[0m \u001b[38;5;66;03m# Otherwise, call the function normally\u001b[39;00m\n\u001b[1;32m-> 1227\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
852
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\huggingface_hub\\hf_api.py:3762\u001b[0m, in \u001b[0;36mHfApi.create_commit\u001b[1;34m(self, repo_id, operations, commit_message, commit_description, token, repo_type, revision, create_pr, num_threads, parent_commit, run_as_future)\u001b[0m\n\u001b[0;32m 3759\u001b[0m \u001b[38;5;66;03m# If updating twice the same file or update then delete a file in a single commit\u001b[39;00m\n\u001b[0;32m 3760\u001b[0m _warn_on_overwriting_operations(operations)\n\u001b[1;32m-> 3762\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpreupload_lfs_files\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 3763\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 3764\u001b[0m \u001b[43m \u001b[49m\u001b[43madditions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43madditions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 3765\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 3766\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 3767\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43munquoted_revision\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# first-class methods take unquoted revision\u001b[39;49;00m\n\u001b[0;32m 3768\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_pr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcreate_pr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 3769\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_threads\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_threads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 3770\u001b[0m \u001b[43m \u001b[49m\u001b[43mfree_memory\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# do not remove `CommitOperationAdd.path_or_fileobj` on LFS files for \"normal\" users\u001b[39;49;00m\n\u001b[0;32m 3771\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 3772\u001b[0m files_to_copy \u001b[38;5;241m=\u001b[39m _fetch_files_to_copy(\n\u001b[0;32m 3773\u001b[0m copies\u001b[38;5;241m=\u001b[39mcopies,\n\u001b[0;32m 3774\u001b[0m repo_type\u001b[38;5;241m=\u001b[39mrepo_type,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 3778\u001b[0m endpoint\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mendpoint,\n\u001b[0;32m 3779\u001b[0m )\n\u001b[0;32m 3780\u001b[0m commit_payload \u001b[38;5;241m=\u001b[39m _prepare_commit_payload(\n\u001b[0;32m 3781\u001b[0m operations\u001b[38;5;241m=\u001b[39moperations,\n\u001b[0;32m 3782\u001b[0m files_to_copy\u001b[38;5;241m=\u001b[39mfiles_to_copy,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 3785\u001b[0m parent_commit\u001b[38;5;241m=\u001b[39mparent_commit,\n\u001b[0;32m 3786\u001b[0m )\n",
853
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\huggingface_hub\\hf_api.py:4262\u001b[0m, in \u001b[0;36mHfApi.preupload_lfs_files\u001b[1;34m(self, repo_id, additions, token, repo_type, revision, create_pr, num_threads, free_memory, gitignore_content)\u001b[0m\n\u001b[0;32m 4256\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\n\u001b[0;32m 4257\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSkipped upload for \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(new_lfs_additions)\u001b[38;5;250m \u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;28mlen\u001b[39m(new_lfs_additions_to_upload)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m LFS file(s) \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 4258\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m(ignored by gitignore file).\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 4259\u001b[0m )\n\u001b[0;32m 4261\u001b[0m \u001b[38;5;66;03m# Upload new LFS files\u001b[39;00m\n\u001b[1;32m-> 4262\u001b[0m \u001b[43m_upload_lfs_files\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 4263\u001b[0m \u001b[43m \u001b[49m\u001b[43madditions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnew_lfs_additions_to_upload\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 4264\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 4265\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 4266\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 4267\u001b[0m \u001b[43m \u001b[49m\u001b[43mendpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mendpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 4268\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_threads\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_threads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 4269\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# If `create_pr`, we don't want to check user permission on the revision as users with read permission\u001b[39;49;00m\n\u001b[0;32m 4270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# should still be able to create PRs even if they don't have write permission on the target branch of the\u001b[39;49;00m\n\u001b[0;32m 4271\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# PR (i.e. 
`revision`).\u001b[39;49;00m\n\u001b[0;32m 4272\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mcreate_pr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 4273\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 4274\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m addition \u001b[38;5;129;01min\u001b[39;00m new_lfs_additions_to_upload:\n\u001b[0;32m 4275\u001b[0m addition\u001b[38;5;241m.\u001b[39m_is_uploaded \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
854
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\huggingface_hub\\utils\\_validators.py:118\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 115\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[0;32m 116\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[1;32m--> 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
855
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\huggingface_hub\\_commit_api.py:360\u001b[0m, in \u001b[0;36m_upload_lfs_files\u001b[1;34m(additions, repo_type, repo_id, token, endpoint, num_threads, revision)\u001b[0m\n\u001b[0;32m 358\u001b[0m batch_actions: List[Dict] \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m 359\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m chunk \u001b[38;5;129;01min\u001b[39;00m chunk_iterable(additions, chunk_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m256\u001b[39m):\n\u001b[1;32m--> 360\u001b[0m batch_actions_chunk, batch_errors_chunk \u001b[38;5;241m=\u001b[39m \u001b[43mpost_lfs_batch_info\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 361\u001b[0m \u001b[43m \u001b[49m\u001b[43mupload_infos\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupload_info\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mop\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mchunk\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 362\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 363\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 364\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 365\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 366\u001b[0m \u001b[43m \u001b[49m\u001b[43mendpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mendpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 367\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 369\u001b[0m \u001b[38;5;66;03m# If at least 1 error, we do not retrieve information for other chunks\u001b[39;00m\n\u001b[0;32m 370\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_errors_chunk:\n",
856
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\huggingface_hub\\utils\\_validators.py:118\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 115\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[0;32m 116\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[1;32m--> 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
857
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\huggingface_hub\\lfs.py:159\u001b[0m, in \u001b[0;36mpost_lfs_batch_info\u001b[1;34m(upload_infos, token, repo_type, repo_id, revision, endpoint)\u001b[0m\n\u001b[0;32m 157\u001b[0m headers \u001b[38;5;241m=\u001b[39m {\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mLFS_HEADERS, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mbuild_hf_headers(token\u001b[38;5;241m=\u001b[39mtoken \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m)} \u001b[38;5;66;03m# Token must be provided or retrieved\u001b[39;00m\n\u001b[0;32m 158\u001b[0m resp \u001b[38;5;241m=\u001b[39m get_session()\u001b[38;5;241m.\u001b[39mpost(batch_url, headers\u001b[38;5;241m=\u001b[39mheaders, json\u001b[38;5;241m=\u001b[39mpayload)\n\u001b[1;32m--> 159\u001b[0m \u001b[43mhf_raise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 160\u001b[0m batch_info \u001b[38;5;241m=\u001b[39m resp\u001b[38;5;241m.\u001b[39mjson()\n\u001b[0;32m 162\u001b[0m objects \u001b[38;5;241m=\u001b[39m batch_info\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobjects\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n",
858
+ "File \u001b[1;32mD:\\software\\Anaconda\\envs\\Work1\\lib\\site-packages\\huggingface_hub\\utils\\_errors.py:362\u001b[0m, in \u001b[0;36mhf_raise_for_status\u001b[1;34m(response, endpoint_name)\u001b[0m\n\u001b[0;32m 358\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m BadRequestError(message, response\u001b[38;5;241m=\u001b[39mresponse) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[0;32m 360\u001b[0m \u001b[38;5;66;03m# Convert `HTTPError` into a `HfHubHTTPError` to display request information\u001b[39;00m\n\u001b[0;32m 361\u001b[0m \u001b[38;5;66;03m# as well (request id and/or server error message)\u001b[39;00m\n\u001b[1;32m--> 362\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m HfHubHTTPError(\u001b[38;5;28mstr\u001b[39m(e), response\u001b[38;5;241m=\u001b[39mresponse) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n",
859
+ "\u001b[1;31mHfHubHTTPError\u001b[0m: 403 Client Error: Forbidden for url: https://huggingface.co/shawhin/distilbert-base-uncased-lora-text-classification.git/info/lfs/objects/batch (Request ID: Root=1-65f44b6d-3a7059390bd0f46b3618a6e6;b93e4a6f-c6a2-4179-8d62-ec4b3235048e)\n\nAuthorization error."
860
+ ]
861
+ }
862
+ ],
863
+ "source": [
864
+ "model.push_to_hub(model_id) # save model"
865
+ ]
866
+ },
867
+ {
868
+ "cell_type": "code",
869
+ "execution_count": null,
870
+ "id": "f487331a-8552-4fb2-867f-985b8fe1d1ab",
871
+ "metadata": {},
872
+ "outputs": [],
873
+ "source": [
874
+ "trainer.push_to_hub(model_id) # save trainer"
875
+ ]
876
+ },
877
+ {
878
+ "cell_type": "markdown",
879
+ "id": "00e7feaa-b70e-4b1d-a118-23c616d14639",
880
+ "metadata": {},
881
+ "source": [
882
+ "### Optional: load peft model"
883
+ ]
884
+ },
885
+ {
886
+ "cell_type": "code",
887
+ "execution_count": null,
888
+ "id": "19cffa01-25a4-4c86-a7fa-a84353b8caae",
889
+ "metadata": {},
890
+ "outputs": [],
891
+ "source": [
892
+ "# how to load peft model from hub for inference\n",
893
+ "config = PeftConfig.from_pretrained(model_id)\n",
894
+ "inference_model = AutoModelForSequenceClassification.from_pretrained(\n",
895
+ " config.base_model_name_or_path, num_labels=2, id2label=id2label, label2id=label2id\n",
896
+ ")\n",
897
+ "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
898
+ "model = PeftModel.from_pretrained(inference_model, model_id)"
899
+ ]
900
+ },
901
+ {
902
+ "cell_type": "code",
903
+ "execution_count": null,
904
+ "id": "77c6ed42-8ec3-4343-9e42-405feac052ba",
905
+ "metadata": {},
906
+ "outputs": [],
907
+ "source": []
908
+ }
909
+ ],
910
+ "metadata": {
911
+ "kernelspec": {
912
+ "display_name": "Work1",
913
+ "language": "python",
914
+ "name": "work1"
915
+ },
916
+ "language_info": {
917
+ "codemirror_mode": {
918
+ "name": "ipython",
919
+ "version": 3
920
+ },
921
+ "file_extension": ".py",
922
+ "mimetype": "text/x-python",
923
+ "name": "python",
924
+ "nbconvert_exporter": "python",
925
+ "pygments_lexer": "ipython3",
926
+ "version": "3.9.18"
927
+ }
928
+ },
929
+ "nbformat": 4,
930
+ "nbformat_minor": 5
931
+ }