Christina Theodoris commited on
Commit
088ea6e
·
1 Parent(s): f0de016

Add data collator for cell classification and example for cell classification

Browse files
examples/cell_classification.ipynb ADDED
@@ -0,0 +1,1954 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "234afff3",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Geneformer Fine-Tuning for Cell Annotation Application"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 2,
14
+ "id": "1cbe6178-ea4d-478a-80a8-65ffaa4c1820",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import os\n",
19
+ "GPU_NUMBER = [0]\n",
20
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([str(s) for s in GPU_NUMBER])\n",
21
+ "os.environ[\"NCCL_DEBUG\"] = \"INFO\""
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 3,
27
+ "id": "a9885d9f-00ac-4c84-b6a3-b7b648a90f0f",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "# imports\n",
32
+ "from collections import Counter\n",
33
+ "import datetime\n",
34
+ "import pickle\n",
35
+ "import subprocess\n",
36
+ "import seaborn as sns; sns.set()\n",
37
+ "from datasets import load_from_disk\n",
38
+ "from sklearn.metrics import accuracy_score, f1_score\n",
39
+ "from transformers import BertForSequenceClassification\n",
40
+ "from transformers import Trainer\n",
41
+ "from transformers.training_args import TrainingArguments\n",
42
+ "\n",
43
+ "from geneformer import DataCollatorForCellClassification"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "id": "68bd3b98-5409-4105-b7af-f1ff64ea6a72",
49
+ "metadata": {},
50
+ "source": [
51
+ "## Prepare training and evaluation datasets"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 15,
57
+ "id": "5735f1b7-7595-4a02-be17-2c5b970ad81a",
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "# load train dataset (includes all tissues)\n",
62
+ "train_dataset=load_from_disk(\"/path/to/cell_type_train_data.dataset\")"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": 17,
68
+ "id": "60eb8b0b-03ba-4065-98e3-0e424a9174ad",
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "# load evaluation dataset (includes all tissues)\n",
73
+ "eval_dataset=load_from_disk(\"/path/to/cell_type_test_data.dataset\")"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "a4297a02-4c4c-434c-ae55-3387a0b239b5",
80
+ "metadata": {
81
+ "collapsed": true,
82
+ "jupyter": {
83
+ "outputs_hidden": true
84
+ },
85
+ "tags": []
86
+ },
87
+ "outputs": [],
88
+ "source": [
89
+ "dataset_list = []\n",
90
+ "evalset_list = []\n",
91
+ "organ_list = []\n",
92
+ "target_dict_list = []\n",
93
+ "\n",
94
+ "for organ in Counter(train_dataset[\"organ_major\"]).keys():\n",
95
+ " # collect list of tissues for fine-tuning (immune and bone marrow are included together)\n",
96
+ " if organ in [\"bone_marrow\"]: \n",
97
+ " continue\n",
98
+ " elif organ==\"immune\":\n",
99
+ " organ_ids = [\"immune\",\"bone_marrow\"]\n",
100
+ " organ_list += [\"immune\"]\n",
101
+ " else:\n",
102
+ " organ_ids = [organ]\n",
103
+ " organ_list += [organ]\n",
104
+ " \n",
105
+ " print(organ)\n",
106
+ " \n",
107
+ " # filter datasets for given organ\n",
108
+ " def if_organ(example):\n",
109
+ " return example[\"organ_major\"] in organ_ids\n",
110
+ " trainset_organ = train_dataset.filter(if_organ, num_proc=16)\n",
111
+ " \n",
112
+ " # per scDeepsort published method, drop cell types representing <0.5% of cells\n",
113
+ " celltype_counter = Counter(trainset_organ[\"cell_type\"])\n",
114
+ " total_cells = sum(celltype_counter.values())\n",
115
+ " cells_to_keep = [k for k,v in celltype_counter.items() if v>(0.005*total_cells)]\n",
116
+ " def if_not_rare_celltype(example):\n",
117
+ " return example[\"cell_type\"] in cells_to_keep\n",
118
+ " trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)\n",
119
+ " \n",
120
+ " # shuffle datasets and rename columns\n",
121
+ " trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)\n",
122
+ " trainset_organ_shuffled = trainset_organ_shuffled.rename_column(\"cell_type\",\"label\")\n",
123
+ " trainset_organ_shuffled = trainset_organ_shuffled.remove_columns(\"organ_major\")\n",
124
+ " \n",
125
+ " # create dictionary of cell types : label ids\n",
126
+ " target_names = list(Counter(trainset_organ_shuffled[\"label\"]).keys())\n",
127
+ " target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))\n",
128
+ " target_dict_list += [target_name_id_dict]\n",
129
+ " \n",
130
+ " # change labels to numerical ids\n",
131
+ " def classes_to_ids(example):\n",
132
+ " example[\"label\"] = target_name_id_dict[example[\"label\"]]\n",
133
+ " return example\n",
134
+ " labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)\n",
135
+ " \n",
136
+ " # create 80/20 train/eval splits\n",
137
+ " labeled_train_split = labeled_trainset.select([i for i in range(0,round(len(labeled_trainset)*0.8))])\n",
138
+ " labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])\n",
139
+ " \n",
140
+ " # filter dataset for cell types in corresponding training set\n",
141
+ " trained_labels = list(Counter(labeled_train_split[\"label\"]).keys())\n",
142
+ " def if_trained_label(example):\n",
143
+ " return example[\"label\"] in trained_labels\n",
144
+ " labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)\n",
145
+ "\n",
146
+ " dataset_list += [labeled_train_split]\n",
147
+ " evalset_list += [labeled_eval_split_subset]"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 20,
153
+ "id": "83e20521-597a-4c54-897b-c4d42ea622c2",
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "trainset_dict = dict(zip(organ_list,dataset_list))\n",
158
+ "traintargetdict_dict = dict(zip(organ_list,target_dict_list))\n",
159
+ "\n",
160
+ "evalset_dict = dict(zip(organ_list,evalset_list))"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "id": "10eb110d-ba43-4efc-bc43-1815d6912647",
166
+ "metadata": {},
167
+ "source": [
168
+ "## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": 18,
174
+ "id": "cd7b1cfb-f5cb-460e-ae77-769522ece054",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "def compute_metrics(pred):\n",
179
+ " labels = pred.label_ids\n",
180
+ " preds = pred.predictions.argmax(-1)\n",
181
+ " # calculate accuracy and macro f1 using sklearn's function\n",
182
+ " acc = accuracy_score(labels, preds)\n",
183
+ " macro_f1 = f1_score(labels, preds, average='macro')\n",
184
+ " return {\n",
185
+ " 'accuracy': acc,\n",
186
+ " 'macro_f1': macro_f1\n",
187
+ " }"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": 19,
193
+ "id": "d24e1ab7-0131-44bd-b458-1ce5ba31853e",
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "# set model parameters\n",
198
+ "# max input size\n",
199
+ "max_input_size = 2 ** 11 # 2048\n",
200
+ "\n",
201
+ "# set training parameters\n",
202
+ "# max learning rate\n",
203
+ "max_lr = 5e-5\n",
204
+ "# how many pretrained layers to freeze\n",
205
+ "freeze_layers = 0\n",
206
+ "# number gpus\n",
207
+ "num_gpus = 1\n",
208
+ "# number cpu cores\n",
209
+ "num_proc = 16\n",
210
+ "# batch size for training and eval\n",
211
+ "geneformer_batch_size = 12\n",
212
+ "# learning schedule\n",
213
+ "lr_schedule_fn = \"linear\"\n",
214
+ "# warmup steps\n",
215
+ "warmup_steps = 500\n",
216
+ "# number of epochs\n",
217
+ "epochs = 10\n",
218
+ "# optimizer\n",
219
+ "optimizer = \"adamw\""
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 20,
225
+ "id": "05164c24-5fbf-4372-b26c-a43f3777a88d",
226
+ "metadata": {},
227
+ "outputs": [
228
+ {
229
+ "name": "stderr",
230
+ "output_type": "stream",
231
+ "text": [
232
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
233
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
234
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
235
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
236
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
237
+ ]
238
+ },
239
+ {
240
+ "name": "stdout",
241
+ "output_type": "stream",
242
+ "text": [
243
+ "spleen\n"
244
+ ]
245
+ },
246
+ {
247
+ "name": "stderr",
248
+ "output_type": "stream",
249
+ "text": [
250
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
251
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
252
+ ]
253
+ },
254
+ {
255
+ "data": {
256
+ "text/html": [
257
+ "\n",
258
+ " <div>\n",
259
+ " \n",
260
+ " <progress value='10280' max='10280' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
261
+ " [10280/10280 13:33, Epoch 10/10]\n",
262
+ " </div>\n",
263
+ " <table border=\"1\" class=\"dataframe\">\n",
264
+ " <thead>\n",
265
+ " <tr style=\"text-align: left;\">\n",
266
+ " <th>Epoch</th>\n",
267
+ " <th>Training Loss</th>\n",
268
+ " <th>Validation Loss</th>\n",
269
+ " <th>Accuracy</th>\n",
270
+ " <th>Macro F1</th>\n",
271
+ " <th>Weighted F1</th>\n",
272
+ " </tr>\n",
273
+ " </thead>\n",
274
+ " <tbody>\n",
275
+ " <tr>\n",
276
+ " <td>1</td>\n",
277
+ " <td>0.087000</td>\n",
278
+ " <td>0.068067</td>\n",
279
+ " <td>0.985404</td>\n",
280
+ " <td>0.956839</td>\n",
281
+ " <td>0.985483</td>\n",
282
+ " </tr>\n",
283
+ " <tr>\n",
284
+ " <td>2</td>\n",
285
+ " <td>0.044400</td>\n",
286
+ " <td>0.075289</td>\n",
287
+ " <td>0.985079</td>\n",
288
+ " <td>0.955069</td>\n",
289
+ " <td>0.984898</td>\n",
290
+ " </tr>\n",
291
+ " <tr>\n",
292
+ " <td>3</td>\n",
293
+ " <td>0.066700</td>\n",
294
+ " <td>0.078703</td>\n",
295
+ " <td>0.983782</td>\n",
296
+ " <td>0.953240</td>\n",
297
+ " <td>0.983959</td>\n",
298
+ " </tr>\n",
299
+ " <tr>\n",
300
+ " <td>4</td>\n",
301
+ " <td>0.037400</td>\n",
302
+ " <td>0.057132</td>\n",
303
+ " <td>0.989945</td>\n",
304
+ " <td>0.970619</td>\n",
305
+ " <td>0.989883</td>\n",
306
+ " </tr>\n",
307
+ " <tr>\n",
308
+ " <td>5</td>\n",
309
+ " <td>0.025000</td>\n",
310
+ " <td>0.061644</td>\n",
311
+ " <td>0.988323</td>\n",
312
+ " <td>0.961126</td>\n",
313
+ " <td>0.988211</td>\n",
314
+ " </tr>\n",
315
+ " <tr>\n",
316
+ " <td>6</td>\n",
317
+ " <td>0.022400</td>\n",
318
+ " <td>0.065323</td>\n",
319
+ " <td>0.989296</td>\n",
320
+ " <td>0.969737</td>\n",
321
+ " <td>0.989362</td>\n",
322
+ " </tr>\n",
323
+ " <tr>\n",
324
+ " <td>7</td>\n",
325
+ " <td>0.018600</td>\n",
326
+ " <td>0.063710</td>\n",
327
+ " <td>0.989620</td>\n",
328
+ " <td>0.969436</td>\n",
329
+ " <td>0.989579</td>\n",
330
+ " </tr>\n",
331
+ " <tr>\n",
332
+ " <td>8</td>\n",
333
+ " <td>0.039800</td>\n",
334
+ " <td>0.065919</td>\n",
335
+ " <td>0.989945</td>\n",
336
+ " <td>0.968065</td>\n",
337
+ " <td>0.989802</td>\n",
338
+ " </tr>\n",
339
+ " <tr>\n",
340
+ " <td>9</td>\n",
341
+ " <td>0.030200</td>\n",
342
+ " <td>0.061359</td>\n",
343
+ " <td>0.990269</td>\n",
344
+ " <td>0.971700</td>\n",
345
+ " <td>0.990314</td>\n",
346
+ " </tr>\n",
347
+ " <tr>\n",
348
+ " <td>10</td>\n",
349
+ " <td>0.013400</td>\n",
350
+ " <td>0.059181</td>\n",
351
+ " <td>0.991567</td>\n",
352
+ " <td>0.974599</td>\n",
353
+ " <td>0.991552</td>\n",
354
+ " </tr>\n",
355
+ " </tbody>\n",
356
+ "</table><p>"
357
+ ],
358
+ "text/plain": [
359
+ "<IPython.core.display.HTML object>"
360
+ ]
361
+ },
362
+ "metadata": {},
363
+ "output_type": "display_data"
364
+ },
365
+ {
366
+ "name": "stderr",
367
+ "output_type": "stream",
368
+ "text": [
369
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
370
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
371
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
372
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
373
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
374
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
375
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
376
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
377
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
378
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
379
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
380
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
381
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
382
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
383
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
384
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
385
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
386
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
387
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
388
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
389
+ ]
390
+ },
391
+ {
392
+ "data": {
393
+ "text/html": [
394
+ "\n",
395
+ " <div>\n",
396
+ " \n",
397
+ " <progress value='257' max='257' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
398
+ " [257/257 00:07]\n",
399
+ " </div>\n",
400
+ " "
401
+ ],
402
+ "text/plain": [
403
+ "<IPython.core.display.HTML object>"
404
+ ]
405
+ },
406
+ "metadata": {},
407
+ "output_type": "display_data"
408
+ },
409
+ {
410
+ "name": "stderr",
411
+ "output_type": "stream",
412
+ "text": [
413
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
414
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
415
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
416
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
417
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
418
+ ]
419
+ },
420
+ {
421
+ "name": "stdout",
422
+ "output_type": "stream",
423
+ "text": [
424
+ "kidney\n"
425
+ ]
426
+ },
427
+ {
428
+ "name": "stderr",
429
+ "output_type": "stream",
430
+ "text": [
431
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
432
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
433
+ ]
434
+ },
435
+ {
436
+ "data": {
437
+ "text/html": [
438
+ "\n",
439
+ " <div>\n",
440
+ " \n",
441
+ " <progress value='29340' max='29340' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
442
+ " [29340/29340 45:43, Epoch 10/10]\n",
443
+ " </div>\n",
444
+ " <table border=\"1\" class=\"dataframe\">\n",
445
+ " <thead>\n",
446
+ " <tr style=\"text-align: left;\">\n",
447
+ " <th>Epoch</th>\n",
448
+ " <th>Training Loss</th>\n",
449
+ " <th>Validation Loss</th>\n",
450
+ " <th>Accuracy</th>\n",
451
+ " <th>Macro F1</th>\n",
452
+ " <th>Weighted F1</th>\n",
453
+ " </tr>\n",
454
+ " </thead>\n",
455
+ " <tbody>\n",
456
+ " <tr>\n",
457
+ " <td>1</td>\n",
458
+ " <td>0.326900</td>\n",
459
+ " <td>0.299193</td>\n",
460
+ " <td>0.912500</td>\n",
461
+ " <td>0.823067</td>\n",
462
+ " <td>0.909627</td>\n",
463
+ " </tr>\n",
464
+ " <tr>\n",
465
+ " <td>2</td>\n",
466
+ " <td>0.224200</td>\n",
467
+ " <td>0.239580</td>\n",
468
+ " <td>0.926477</td>\n",
469
+ " <td>0.850237</td>\n",
470
+ " <td>0.923902</td>\n",
471
+ " </tr>\n",
472
+ " <tr>\n",
473
+ " <td>3</td>\n",
474
+ " <td>0.221600</td>\n",
475
+ " <td>0.242810</td>\n",
476
+ " <td>0.930227</td>\n",
477
+ " <td>0.878553</td>\n",
478
+ " <td>0.930349</td>\n",
479
+ " </tr>\n",
480
+ " <tr>\n",
481
+ " <td>4</td>\n",
482
+ " <td>0.166100</td>\n",
483
+ " <td>0.264178</td>\n",
484
+ " <td>0.933409</td>\n",
485
+ " <td>0.884759</td>\n",
486
+ " <td>0.933031</td>\n",
487
+ " </tr>\n",
488
+ " <tr>\n",
489
+ " <td>5</td>\n",
490
+ " <td>0.144100</td>\n",
491
+ " <td>0.279282</td>\n",
492
+ " <td>0.935000</td>\n",
493
+ " <td>0.887659</td>\n",
494
+ " <td>0.934987</td>\n",
495
+ " </tr>\n",
496
+ " <tr>\n",
497
+ " <td>6</td>\n",
498
+ " <td>0.112800</td>\n",
499
+ " <td>0.307647</td>\n",
500
+ " <td>0.935909</td>\n",
501
+ " <td>0.889239</td>\n",
502
+ " <td>0.935365</td>\n",
503
+ " </tr>\n",
504
+ " <tr>\n",
505
+ " <td>7</td>\n",
506
+ " <td>0.084600</td>\n",
507
+ " <td>0.326399</td>\n",
508
+ " <td>0.932841</td>\n",
509
+ " <td>0.892447</td>\n",
510
+ " <td>0.933191</td>\n",
511
+ " </tr>\n",
512
+ " <tr>\n",
513
+ " <td>8</td>\n",
514
+ " <td>0.068300</td>\n",
515
+ " <td>0.332626</td>\n",
516
+ " <td>0.936591</td>\n",
517
+ " <td>0.891629</td>\n",
518
+ " <td>0.936354</td>\n",
519
+ " </tr>\n",
520
+ " <tr>\n",
521
+ " <td>9</td>\n",
522
+ " <td>0.065500</td>\n",
523
+ " <td>0.348174</td>\n",
524
+ " <td>0.935227</td>\n",
525
+ " <td>0.889484</td>\n",
526
+ " <td>0.935040</td>\n",
527
+ " </tr>\n",
528
+ " <tr>\n",
529
+ " <td>10</td>\n",
530
+ " <td>0.046100</td>\n",
531
+ " <td>0.355350</td>\n",
532
+ " <td>0.935000</td>\n",
533
+ " <td>0.894578</td>\n",
534
+ " <td>0.934971</td>\n",
535
+ " </tr>\n",
536
+ " </tbody>\n",
537
+ "</table><p>"
538
+ ],
539
+ "text/plain": [
540
+ "<IPython.core.display.HTML object>"
541
+ ]
542
+ },
543
+ "metadata": {},
544
+ "output_type": "display_data"
545
+ },
546
+ {
547
+ "name": "stderr",
548
+ "output_type": "stream",
549
+ "text": [
550
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
551
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
552
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
553
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
554
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
555
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
556
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
557
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
558
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
559
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
560
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
561
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
562
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
563
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
564
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
565
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
566
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
567
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
568
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
569
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
570
+ ]
571
+ },
572
+ {
573
+ "data": {
574
+ "text/html": [
575
+ "\n",
576
+ " <div>\n",
577
+ " \n",
578
+ " <progress value='734' max='734' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
579
+ " [734/734 00:27]\n",
580
+ " </div>\n",
581
+ " "
582
+ ],
583
+ "text/plain": [
584
+ "<IPython.core.display.HTML object>"
585
+ ]
586
+ },
587
+ "metadata": {},
588
+ "output_type": "display_data"
589
+ },
590
+ {
591
+ "name": "stderr",
592
+ "output_type": "stream",
593
+ "text": [
594
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
595
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
596
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
597
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
598
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
599
+ ]
600
+ },
601
+ {
602
+ "name": "stdout",
603
+ "output_type": "stream",
604
+ "text": [
605
+ "lung\n"
606
+ ]
607
+ },
608
+ {
609
+ "name": "stderr",
610
+ "output_type": "stream",
611
+ "text": [
612
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
613
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
614
+ ]
615
+ },
616
+ {
617
+ "data": {
618
+ "text/html": [
619
+ "\n",
620
+ " <div>\n",
621
+ " \n",
622
+ " <progress value='21750' max='21750' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
623
+ " [21750/21750 30:32, Epoch 10/10]\n",
624
+ " </div>\n",
625
+ " <table border=\"1\" class=\"dataframe\">\n",
626
+ " <thead>\n",
627
+ " <tr style=\"text-align: left;\">\n",
628
+ " <th>Epoch</th>\n",
629
+ " <th>Training Loss</th>\n",
630
+ " <th>Validation Loss</th>\n",
631
+ " <th>Accuracy</th>\n",
632
+ " <th>Macro F1</th>\n",
633
+ " <th>Weighted F1</th>\n",
634
+ " </tr>\n",
635
+ " </thead>\n",
636
+ " <tbody>\n",
637
+ " <tr>\n",
638
+ " <td>1</td>\n",
639
+ " <td>0.337600</td>\n",
640
+ " <td>0.341523</td>\n",
641
+ " <td>0.906360</td>\n",
642
+ " <td>0.759979</td>\n",
643
+ " <td>0.899310</td>\n",
644
+ " </tr>\n",
645
+ " <tr>\n",
646
+ " <td>2</td>\n",
647
+ " <td>0.211900</td>\n",
648
+ " <td>0.258954</td>\n",
649
+ " <td>0.928429</td>\n",
650
+ " <td>0.835534</td>\n",
651
+ " <td>0.925903</td>\n",
652
+ " </tr>\n",
653
+ " <tr>\n",
654
+ " <td>3</td>\n",
655
+ " <td>0.208600</td>\n",
656
+ " <td>0.282081</td>\n",
657
+ " <td>0.930421</td>\n",
658
+ " <td>0.842786</td>\n",
659
+ " <td>0.928013</td>\n",
660
+ " </tr>\n",
661
+ " <tr>\n",
662
+ " <td>4</td>\n",
663
+ " <td>0.144400</td>\n",
664
+ " <td>0.253047</td>\n",
665
+ " <td>0.935479</td>\n",
666
+ " <td>0.871712</td>\n",
667
+ " <td>0.935234</td>\n",
668
+ " </tr>\n",
669
+ " <tr>\n",
670
+ " <td>5</td>\n",
671
+ " <td>0.109200</td>\n",
672
+ " <td>0.268833</td>\n",
673
+ " <td>0.939464</td>\n",
674
+ " <td>0.876173</td>\n",
675
+ " <td>0.938870</td>\n",
676
+ " </tr>\n",
677
+ " <tr>\n",
678
+ " <td>6</td>\n",
679
+ " <td>0.132700</td>\n",
680
+ " <td>0.282697</td>\n",
681
+ " <td>0.940536</td>\n",
682
+ " <td>0.883271</td>\n",
683
+ " <td>0.940191</td>\n",
684
+ " </tr>\n",
685
+ " <tr>\n",
686
+ " <td>7</td>\n",
687
+ " <td>0.081800</td>\n",
688
+ " <td>0.295864</td>\n",
689
+ " <td>0.940843</td>\n",
690
+ " <td>0.884201</td>\n",
691
+ " <td>0.940170</td>\n",
692
+ " </tr>\n",
693
+ " <tr>\n",
694
+ " <td>8</td>\n",
695
+ " <td>0.035900</td>\n",
696
+ " <td>0.306600</td>\n",
697
+ " <td>0.941916</td>\n",
698
+ " <td>0.884777</td>\n",
699
+ " <td>0.941578</td>\n",
700
+ " </tr>\n",
701
+ " <tr>\n",
702
+ " <td>9</td>\n",
703
+ " <td>0.050800</td>\n",
704
+ " <td>0.311677</td>\n",
705
+ " <td>0.940536</td>\n",
706
+ " <td>0.883437</td>\n",
707
+ " <td>0.940294</td>\n",
708
+ " </tr>\n",
709
+ " <tr>\n",
710
+ " <td>10</td>\n",
711
+ " <td>0.035800</td>\n",
712
+ " <td>0.315360</td>\n",
713
+ " <td>0.940843</td>\n",
714
+ " <td>0.883551</td>\n",
715
+ " <td>0.940612</td>\n",
716
+ " </tr>\n",
717
+ " </tbody>\n",
718
+ "</table><p>"
719
+ ],
720
+ "text/plain": [
721
+ "<IPython.core.display.HTML object>"
722
+ ]
723
+ },
724
+ "metadata": {},
725
+ "output_type": "display_data"
726
+ },
727
+ {
728
+ "name": "stderr",
729
+ "output_type": "stream",
730
+ "text": [
731
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
732
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
733
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
734
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
735
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
736
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
737
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
738
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
739
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
740
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
741
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
742
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
743
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
744
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
745
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
746
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
747
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
748
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
749
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
750
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
751
+ ]
752
+ },
753
+ {
754
+ "data": {
755
+ "text/html": [
756
+ "\n",
757
+ " <div>\n",
758
+ " \n",
759
+ " <progress value='544' max='544' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
760
+ " [544/544 00:19]\n",
761
+ " </div>\n",
762
+ " "
763
+ ],
764
+ "text/plain": [
765
+ "<IPython.core.display.HTML object>"
766
+ ]
767
+ },
768
+ "metadata": {},
769
+ "output_type": "display_data"
770
+ },
771
+ {
772
+ "name": "stderr",
773
+ "output_type": "stream",
774
+ "text": [
775
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
776
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
777
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
778
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
779
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
780
+ ]
781
+ },
782
+ {
783
+ "name": "stdout",
784
+ "output_type": "stream",
785
+ "text": [
786
+ "brain\n"
787
+ ]
788
+ },
789
+ {
790
+ "name": "stderr",
791
+ "output_type": "stream",
792
+ "text": [
793
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
794
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
795
+ ]
796
+ },
797
+ {
798
+ "data": {
799
+ "text/html": [
800
+ "\n",
801
+ " <div>\n",
802
+ " \n",
803
+ " <progress value='8880' max='8880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
804
+ " [8880/8880 11:14, Epoch 10/10]\n",
805
+ " </div>\n",
806
+ " <table border=\"1\" class=\"dataframe\">\n",
807
+ " <thead>\n",
808
+ " <tr style=\"text-align: left;\">\n",
809
+ " <th>Epoch</th>\n",
810
+ " <th>Training Loss</th>\n",
811
+ " <th>Validation Loss</th>\n",
812
+ " <th>Accuracy</th>\n",
813
+ " <th>Macro F1</th>\n",
814
+ " <th>Weighted F1</th>\n",
815
+ " </tr>\n",
816
+ " </thead>\n",
817
+ " <tbody>\n",
818
+ " <tr>\n",
819
+ " <td>1</td>\n",
820
+ " <td>0.163100</td>\n",
821
+ " <td>0.156640</td>\n",
822
+ " <td>0.970345</td>\n",
823
+ " <td>0.736455</td>\n",
824
+ " <td>0.960714</td>\n",
825
+ " </tr>\n",
826
+ " <tr>\n",
827
+ " <td>2</td>\n",
828
+ " <td>0.149800</td>\n",
829
+ " <td>0.134897</td>\n",
830
+ " <td>0.968844</td>\n",
831
+ " <td>0.747114</td>\n",
832
+ " <td>0.960726</td>\n",
833
+ " </tr>\n",
834
+ " <tr>\n",
835
+ " <td>3</td>\n",
836
+ " <td>0.105600</td>\n",
837
+ " <td>0.115354</td>\n",
838
+ " <td>0.972222</td>\n",
839
+ " <td>0.775271</td>\n",
840
+ " <td>0.964932</td>\n",
841
+ " </tr>\n",
842
+ " <tr>\n",
843
+ " <td>4</td>\n",
844
+ " <td>0.086900</td>\n",
845
+ " <td>0.207918</td>\n",
846
+ " <td>0.968844</td>\n",
847
+ " <td>0.707927</td>\n",
848
+ " <td>0.958257</td>\n",
849
+ " </tr>\n",
850
+ " <tr>\n",
851
+ " <td>5</td>\n",
852
+ " <td>0.056400</td>\n",
853
+ " <td>0.106548</td>\n",
854
+ " <td>0.974099</td>\n",
855
+ " <td>0.839838</td>\n",
856
+ " <td>0.971611</td>\n",
857
+ " </tr>\n",
858
+ " <tr>\n",
859
+ " <td>6</td>\n",
860
+ " <td>0.037600</td>\n",
861
+ " <td>0.117437</td>\n",
862
+ " <td>0.978228</td>\n",
863
+ " <td>0.856578</td>\n",
864
+ " <td>0.975665</td>\n",
865
+ " </tr>\n",
866
+ " <tr>\n",
867
+ " <td>7</td>\n",
868
+ " <td>0.030500</td>\n",
869
+ " <td>0.127885</td>\n",
870
+ " <td>0.974474</td>\n",
871
+ " <td>0.856296</td>\n",
872
+ " <td>0.973531</td>\n",
873
+ " </tr>\n",
874
+ " <tr>\n",
875
+ " <td>8</td>\n",
876
+ " <td>0.019300</td>\n",
877
+ " <td>0.143203</td>\n",
878
+ " <td>0.977853</td>\n",
879
+ " <td>0.859362</td>\n",
880
+ " <td>0.975776</td>\n",
881
+ " </tr>\n",
882
+ " <tr>\n",
883
+ " <td>9</td>\n",
884
+ " <td>0.007400</td>\n",
885
+ " <td>0.153758</td>\n",
886
+ " <td>0.972598</td>\n",
887
+ " <td>0.852835</td>\n",
888
+ " <td>0.972314</td>\n",
889
+ " </tr>\n",
890
+ " <tr>\n",
891
+ " <td>10</td>\n",
892
+ " <td>0.017200</td>\n",
893
+ " <td>0.153911</td>\n",
894
+ " <td>0.975976</td>\n",
895
+ " <td>0.858196</td>\n",
896
+ " <td>0.974498</td>\n",
897
+ " </tr>\n",
898
+ " </tbody>\n",
899
+ "</table><p>"
900
+ ],
901
+ "text/plain": [
902
+ "<IPython.core.display.HTML object>"
903
+ ]
904
+ },
905
+ "metadata": {},
906
+ "output_type": "display_data"
907
+ },
908
+ {
909
+ "name": "stderr",
910
+ "output_type": "stream",
911
+ "text": [
912
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
913
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
914
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
915
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
916
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
917
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
918
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
919
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
920
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
921
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
922
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
923
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
924
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
925
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
926
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
927
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
928
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
929
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
930
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
931
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
932
+ ]
933
+ },
934
+ {
935
+ "data": {
936
+ "text/html": [
937
+ "\n",
938
+ " <div>\n",
939
+ " \n",
940
+ " <progress value='222' max='222' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
941
+ " [222/222 00:04]\n",
942
+ " </div>\n",
943
+ " "
944
+ ],
945
+ "text/plain": [
946
+ "<IPython.core.display.HTML object>"
947
+ ]
948
+ },
949
+ "metadata": {},
950
+ "output_type": "display_data"
951
+ },
952
+ {
953
+ "name": "stderr",
954
+ "output_type": "stream",
955
+ "text": [
956
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
957
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
958
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
959
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
960
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
961
+ ]
962
+ },
963
+ {
964
+ "name": "stdout",
965
+ "output_type": "stream",
966
+ "text": [
967
+ "placenta\n"
968
+ ]
969
+ },
970
+ {
971
+ "name": "stderr",
972
+ "output_type": "stream",
973
+ "text": [
974
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
975
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
976
+ ]
977
+ },
978
+ {
979
+ "data": {
980
+ "text/html": [
981
+ "\n",
982
+ " <div>\n",
983
+ " \n",
984
+ " <progress value='6180' max='6180' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
985
+ " [6180/6180 10:28, Epoch 10/10]\n",
986
+ " </div>\n",
987
+ " <table border=\"1\" class=\"dataframe\">\n",
988
+ " <thead>\n",
989
+ " <tr style=\"text-align: left;\">\n",
990
+ " <th>Epoch</th>\n",
991
+ " <th>Training Loss</th>\n",
992
+ " <th>Validation Loss</th>\n",
993
+ " <th>Accuracy</th>\n",
994
+ " <th>Macro F1</th>\n",
995
+ " <th>Weighted F1</th>\n",
996
+ " </tr>\n",
997
+ " </thead>\n",
998
+ " <tbody>\n",
999
+ " <tr>\n",
1000
+ " <td>1</td>\n",
1001
+ " <td>0.128700</td>\n",
1002
+ " <td>0.125175</td>\n",
1003
+ " <td>0.960626</td>\n",
1004
+ " <td>0.935752</td>\n",
1005
+ " <td>0.959463</td>\n",
1006
+ " </tr>\n",
1007
+ " <tr>\n",
1008
+ " <td>2</td>\n",
1009
+ " <td>0.064000</td>\n",
1010
+ " <td>0.215607</td>\n",
1011
+ " <td>0.951456</td>\n",
1012
+ " <td>0.920579</td>\n",
1013
+ " <td>0.949828</td>\n",
1014
+ " </tr>\n",
1015
+ " <tr>\n",
1016
+ " <td>3</td>\n",
1017
+ " <td>0.051300</td>\n",
1018
+ " <td>0.203044</td>\n",
1019
+ " <td>0.961165</td>\n",
1020
+ " <td>0.934195</td>\n",
1021
+ " <td>0.959470</td>\n",
1022
+ " </tr>\n",
1023
+ " <tr>\n",
1024
+ " <td>4</td>\n",
1025
+ " <td>0.045300</td>\n",
1026
+ " <td>0.115701</td>\n",
1027
+ " <td>0.978964</td>\n",
1028
+ " <td>0.966387</td>\n",
1029
+ " <td>0.978788</td>\n",
1030
+ " </tr>\n",
1031
+ " <tr>\n",
1032
+ " <td>5</td>\n",
1033
+ " <td>0.048200</td>\n",
1034
+ " <td>0.149484</td>\n",
1035
+ " <td>0.973571</td>\n",
1036
+ " <td>0.958927</td>\n",
1037
+ " <td>0.973305</td>\n",
1038
+ " </tr>\n",
1039
+ " <tr>\n",
1040
+ " <td>6</td>\n",
1041
+ " <td>0.040900</td>\n",
1042
+ " <td>0.134339</td>\n",
1043
+ " <td>0.978964</td>\n",
1044
+ " <td>0.967466</td>\n",
1045
+ " <td>0.978899</td>\n",
1046
+ " </tr>\n",
1047
+ " <tr>\n",
1048
+ " <td>7</td>\n",
1049
+ " <td>0.001600</td>\n",
1050
+ " <td>0.159900</td>\n",
1051
+ " <td>0.978425</td>\n",
1052
+ " <td>0.966713</td>\n",
1053
+ " <td>0.978211</td>\n",
1054
+ " </tr>\n",
1055
+ " <tr>\n",
1056
+ " <td>8</td>\n",
1057
+ " <td>0.002400</td>\n",
1058
+ " <td>0.125351</td>\n",
1059
+ " <td>0.979504</td>\n",
1060
+ " <td>0.968064</td>\n",
1061
+ " <td>0.979428</td>\n",
1062
+ " </tr>\n",
1063
+ " <tr>\n",
1064
+ " <td>9</td>\n",
1065
+ " <td>0.009400</td>\n",
1066
+ " <td>0.120132</td>\n",
1067
+ " <td>0.980583</td>\n",
1068
+ " <td>0.969631</td>\n",
1069
+ " <td>0.980506</td>\n",
1070
+ " </tr>\n",
1071
+ " <tr>\n",
1072
+ " <td>10</td>\n",
1073
+ " <td>0.001500</td>\n",
1074
+ " <td>0.137864</td>\n",
1075
+ " <td>0.978964</td>\n",
1076
+ " <td>0.967180</td>\n",
1077
+ " <td>0.978825</td>\n",
1078
+ " </tr>\n",
1079
+ " </tbody>\n",
1080
+ "</table><p>"
1081
+ ],
1082
+ "text/plain": [
1083
+ "<IPython.core.display.HTML object>"
1084
+ ]
1085
+ },
1086
+ "metadata": {},
1087
+ "output_type": "display_data"
1088
+ },
1089
+ {
1090
+ "name": "stderr",
1091
+ "output_type": "stream",
1092
+ "text": [
1093
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1094
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1095
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1096
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1097
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1098
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1099
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1100
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1101
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1102
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1103
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1104
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1105
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1106
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1107
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1108
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1109
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1110
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1111
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1112
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1113
+ ]
1114
+ },
1115
+ {
1116
+ "data": {
1117
+ "text/html": [
1118
+ "\n",
1119
+ " <div>\n",
1120
+ " \n",
1121
+ " <progress value='155' max='155' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1122
+ " [155/155 00:05]\n",
1123
+ " </div>\n",
1124
+ " "
1125
+ ],
1126
+ "text/plain": [
1127
+ "<IPython.core.display.HTML object>"
1128
+ ]
1129
+ },
1130
+ "metadata": {},
1131
+ "output_type": "display_data"
1132
+ },
1133
+ {
1134
+ "name": "stderr",
1135
+ "output_type": "stream",
1136
+ "text": [
1137
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
1138
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1139
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
1140
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
1141
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1142
+ ]
1143
+ },
1144
+ {
1145
+ "name": "stdout",
1146
+ "output_type": "stream",
1147
+ "text": [
1148
+ "immune\n"
1149
+ ]
1150
+ },
1151
+ {
1152
+ "name": "stderr",
1153
+ "output_type": "stream",
1154
+ "text": [
1155
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1156
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1157
+ ]
1158
+ },
1159
+ {
1160
+ "data": {
1161
+ "text/html": [
1162
+ "\n",
1163
+ " <div>\n",
1164
+ " \n",
1165
+ " <progress value='17140' max='17140' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1166
+ " [17140/17140 22:02, Epoch 10/10]\n",
1167
+ " </div>\n",
1168
+ " <table border=\"1\" class=\"dataframe\">\n",
1169
+ " <thead>\n",
1170
+ " <tr style=\"text-align: left;\">\n",
1171
+ " <th>Epoch</th>\n",
1172
+ " <th>Training Loss</th>\n",
1173
+ " <th>Validation Loss</th>\n",
1174
+ " <th>Accuracy</th>\n",
1175
+ " <th>Macro F1</th>\n",
1176
+ " <th>Weighted F1</th>\n",
1177
+ " </tr>\n",
1178
+ " </thead>\n",
1179
+ " <tbody>\n",
1180
+ " <tr>\n",
1181
+ " <td>1</td>\n",
1182
+ " <td>0.288900</td>\n",
1183
+ " <td>0.231582</td>\n",
1184
+ " <td>0.936770</td>\n",
1185
+ " <td>0.868405</td>\n",
1186
+ " <td>0.934816</td>\n",
1187
+ " </tr>\n",
1188
+ " <tr>\n",
1189
+ " <td>2</td>\n",
1190
+ " <td>0.203200</td>\n",
1191
+ " <td>0.206292</td>\n",
1192
+ " <td>0.937354</td>\n",
1193
+ " <td>0.888661</td>\n",
1194
+ " <td>0.939555</td>\n",
1195
+ " </tr>\n",
1196
+ " <tr>\n",
1197
+ " <td>3</td>\n",
1198
+ " <td>0.183500</td>\n",
1199
+ " <td>0.195811</td>\n",
1200
+ " <td>0.944942</td>\n",
1201
+ " <td>0.891149</td>\n",
1202
+ " <td>0.944008</td>\n",
1203
+ " </tr>\n",
1204
+ " <tr>\n",
1205
+ " <td>4</td>\n",
1206
+ " <td>0.151000</td>\n",
1207
+ " <td>0.219581</td>\n",
1208
+ " <td>0.947665</td>\n",
1209
+ " <td>0.906578</td>\n",
1210
+ " <td>0.947093</td>\n",
1211
+ " </tr>\n",
1212
+ " <tr>\n",
1213
+ " <td>5</td>\n",
1214
+ " <td>0.090000</td>\n",
1215
+ " <td>0.247120</td>\n",
1216
+ " <td>0.946693</td>\n",
1217
+ " <td>0.898812</td>\n",
1218
+ " <td>0.945808</td>\n",
1219
+ " </tr>\n",
1220
+ " <tr>\n",
1221
+ " <td>6</td>\n",
1222
+ " <td>0.060400</td>\n",
1223
+ " <td>0.249662</td>\n",
1224
+ " <td>0.948444</td>\n",
1225
+ " <td>0.905014</td>\n",
1226
+ " <td>0.947975</td>\n",
1227
+ " </tr>\n",
1228
+ " <tr>\n",
1229
+ " <td>7</td>\n",
1230
+ " <td>0.071300</td>\n",
1231
+ " <td>0.272767</td>\n",
1232
+ " <td>0.949416</td>\n",
1233
+ " <td>0.911514</td>\n",
1234
+ " <td>0.949748</td>\n",
1235
+ " </tr>\n",
1236
+ " <tr>\n",
1237
+ " <td>8</td>\n",
1238
+ " <td>0.052600</td>\n",
1239
+ " <td>0.305051</td>\n",
1240
+ " <td>0.945331</td>\n",
1241
+ " <td>0.902348</td>\n",
1242
+ " <td>0.944987</td>\n",
1243
+ " </tr>\n",
1244
+ " <tr>\n",
1245
+ " <td>9</td>\n",
1246
+ " <td>0.026900</td>\n",
1247
+ " <td>0.294135</td>\n",
1248
+ " <td>0.948638</td>\n",
1249
+ " <td>0.904058</td>\n",
1250
+ " <td>0.948296</td>\n",
1251
+ " </tr>\n",
1252
+ " <tr>\n",
1253
+ " <td>10</td>\n",
1254
+ " <td>0.034500</td>\n",
1255
+ " <td>0.292029</td>\n",
1256
+ " <td>0.950195</td>\n",
1257
+ " <td>0.908547</td>\n",
1258
+ " <td>0.949753</td>\n",
1259
+ " </tr>\n",
1260
+ " </tbody>\n",
1261
+ "</table><p>"
1262
+ ],
1263
+ "text/plain": [
1264
+ "<IPython.core.display.HTML object>"
1265
+ ]
1266
+ },
1267
+ "metadata": {},
1268
+ "output_type": "display_data"
1269
+ },
1270
+ {
1271
+ "name": "stderr",
1272
+ "output_type": "stream",
1273
+ "text": [
1274
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1275
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1276
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1277
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1278
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1279
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1280
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1281
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1282
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1283
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1284
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1285
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1286
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1287
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1288
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1289
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1290
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1291
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1292
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1293
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1294
+ ]
1295
+ },
1296
+ {
1297
+ "data": {
1298
+ "text/html": [
1299
+ "\n",
1300
+ " <div>\n",
1301
+ " \n",
1302
+ " <progress value='429' max='429' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1303
+ " [429/429 00:13]\n",
1304
+ " </div>\n",
1305
+ " "
1306
+ ],
1307
+ "text/plain": [
1308
+ "<IPython.core.display.HTML object>"
1309
+ ]
1310
+ },
1311
+ "metadata": {},
1312
+ "output_type": "display_data"
1313
+ },
1314
+ {
1315
+ "name": "stderr",
1316
+ "output_type": "stream",
1317
+ "text": [
1318
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
1319
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1320
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
1321
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
1322
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1323
+ ]
1324
+ },
1325
+ {
1326
+ "name": "stdout",
1327
+ "output_type": "stream",
1328
+ "text": [
1329
+ "large_intestine\n"
1330
+ ]
1331
+ },
1332
+ {
1333
+ "name": "stderr",
1334
+ "output_type": "stream",
1335
+ "text": [
1336
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1337
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1338
+ ]
1339
+ },
1340
+ {
1341
+ "data": {
1342
+ "text/html": [
1343
+ "\n",
1344
+ " <div>\n",
1345
+ " \n",
1346
+ " <progress value='33070' max='33070' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1347
+ " [33070/33070 43:02, Epoch 10/10]\n",
1348
+ " </div>\n",
1349
+ " <table border=\"1\" class=\"dataframe\">\n",
1350
+ " <thead>\n",
1351
+ " <tr style=\"text-align: left;\">\n",
1352
+ " <th>Epoch</th>\n",
1353
+ " <th>Training Loss</th>\n",
1354
+ " <th>Validation Loss</th>\n",
1355
+ " <th>Accuracy</th>\n",
1356
+ " <th>Macro F1</th>\n",
1357
+ " <th>Weighted F1</th>\n",
1358
+ " </tr>\n",
1359
+ " </thead>\n",
1360
+ " <tbody>\n",
1361
+ " <tr>\n",
1362
+ " <td>1</td>\n",
1363
+ " <td>0.306200</td>\n",
1364
+ " <td>0.312431</td>\n",
1365
+ " <td>0.908266</td>\n",
1366
+ " <td>0.786242</td>\n",
1367
+ " <td>0.900768</td>\n",
1368
+ " </tr>\n",
1369
+ " <tr>\n",
1370
+ " <td>2</td>\n",
1371
+ " <td>0.223900</td>\n",
1372
+ " <td>0.248096</td>\n",
1373
+ " <td>0.925101</td>\n",
1374
+ " <td>0.841251</td>\n",
1375
+ " <td>0.920987</td>\n",
1376
+ " </tr>\n",
1377
+ " <tr>\n",
1378
+ " <td>3</td>\n",
1379
+ " <td>0.173600</td>\n",
1380
+ " <td>0.259997</td>\n",
1381
+ " <td>0.925907</td>\n",
1382
+ " <td>0.850348</td>\n",
1383
+ " <td>0.926290</td>\n",
1384
+ " </tr>\n",
1385
+ " <tr>\n",
1386
+ " <td>4</td>\n",
1387
+ " <td>0.162900</td>\n",
1388
+ " <td>0.282306</td>\n",
1389
+ " <td>0.925000</td>\n",
1390
+ " <td>0.873669</td>\n",
1391
+ " <td>0.925531</td>\n",
1392
+ " </tr>\n",
1393
+ " <tr>\n",
1394
+ " <td>5</td>\n",
1395
+ " <td>0.143400</td>\n",
1396
+ " <td>0.254494</td>\n",
1397
+ " <td>0.937903</td>\n",
1398
+ " <td>0.876749</td>\n",
1399
+ " <td>0.937836</td>\n",
1400
+ " </tr>\n",
1401
+ " <tr>\n",
1402
+ " <td>6</td>\n",
1403
+ " <td>0.104500</td>\n",
1404
+ " <td>0.289942</td>\n",
1405
+ " <td>0.934677</td>\n",
1406
+ " <td>0.875333</td>\n",
1407
+ " <td>0.934339</td>\n",
1408
+ " </tr>\n",
1409
+ " <tr>\n",
1410
+ " <td>7</td>\n",
1411
+ " <td>0.080300</td>\n",
1412
+ " <td>0.313914</td>\n",
1413
+ " <td>0.935484</td>\n",
1414
+ " <td>0.877271</td>\n",
1415
+ " <td>0.934986</td>\n",
1416
+ " </tr>\n",
1417
+ " <tr>\n",
1418
+ " <td>8</td>\n",
1419
+ " <td>0.063500</td>\n",
1420
+ " <td>0.339868</td>\n",
1421
+ " <td>0.936290</td>\n",
1422
+ " <td>0.882267</td>\n",
1423
+ " <td>0.936187</td>\n",
1424
+ " </tr>\n",
1425
+ " <tr>\n",
1426
+ " <td>9</td>\n",
1427
+ " <td>0.042500</td>\n",
1428
+ " <td>0.345784</td>\n",
1429
+ " <td>0.938911</td>\n",
1430
+ " <td>0.882963</td>\n",
1431
+ " <td>0.938682</td>\n",
1432
+ " </tr>\n",
1433
+ " <tr>\n",
1434
+ " <td>10</td>\n",
1435
+ " <td>0.038900</td>\n",
1436
+ " <td>0.352199</td>\n",
1437
+ " <td>0.939516</td>\n",
1438
+ " <td>0.885509</td>\n",
1439
+ " <td>0.939497</td>\n",
1440
+ " </tr>\n",
1441
+ " </tbody>\n",
1442
+ "</table><p>"
1443
+ ],
1444
+ "text/plain": [
1445
+ "<IPython.core.display.HTML object>"
1446
+ ]
1447
+ },
1448
+ "metadata": {},
1449
+ "output_type": "display_data"
1450
+ },
1451
+ {
1452
+ "name": "stderr",
1453
+ "output_type": "stream",
1454
+ "text": [
1455
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1456
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1457
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1458
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1459
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1460
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1461
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1462
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1463
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1464
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1465
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1466
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1467
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1468
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1469
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1470
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1471
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1472
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1473
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1474
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1475
+ ]
1476
+ },
1477
+ {
1478
+ "data": {
1479
+ "text/html": [
1480
+ "\n",
1481
+ " <div>\n",
1482
+ " \n",
1483
+ " <progress value='827' max='827' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1484
+ " [827/827 00:26]\n",
1485
+ " </div>\n",
1486
+ " "
1487
+ ],
1488
+ "text/plain": [
1489
+ "<IPython.core.display.HTML object>"
1490
+ ]
1491
+ },
1492
+ "metadata": {},
1493
+ "output_type": "display_data"
1494
+ },
1495
+ {
1496
+ "name": "stderr",
1497
+ "output_type": "stream",
1498
+ "text": [
1499
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
1500
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1501
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
1502
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
1503
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1504
+ ]
1505
+ },
1506
+ {
1507
+ "name": "stdout",
1508
+ "output_type": "stream",
1509
+ "text": [
1510
+ "pancreas\n"
1511
+ ]
1512
+ },
1513
+ {
1514
+ "name": "stderr",
1515
+ "output_type": "stream",
1516
+ "text": [
1517
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1518
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1519
+ ]
1520
+ },
1521
+ {
1522
+ "data": {
1523
+ "text/html": [
1524
+ "\n",
1525
+ " <div>\n",
1526
+ " \n",
1527
+ " <progress value='18280' max='18280' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1528
+ " [18280/18280 23:32, Epoch 10/10]\n",
1529
+ " </div>\n",
1530
+ " <table border=\"1\" class=\"dataframe\">\n",
1531
+ " <thead>\n",
1532
+ " <tr style=\"text-align: left;\">\n",
1533
+ " <th>Epoch</th>\n",
1534
+ " <th>Training Loss</th>\n",
1535
+ " <th>Validation Loss</th>\n",
1536
+ " <th>Accuracy</th>\n",
1537
+ " <th>Macro F1</th>\n",
1538
+ " <th>Weighted F1</th>\n",
1539
+ " </tr>\n",
1540
+ " </thead>\n",
1541
+ " <tbody>\n",
1542
+ " <tr>\n",
1543
+ " <td>1</td>\n",
1544
+ " <td>0.340100</td>\n",
1545
+ " <td>0.343200</td>\n",
1546
+ " <td>0.896244</td>\n",
1547
+ " <td>0.655661</td>\n",
1548
+ " <td>0.879469</td>\n",
1549
+ " </tr>\n",
1550
+ " <tr>\n",
1551
+ " <td>2</td>\n",
1552
+ " <td>0.178300</td>\n",
1553
+ " <td>0.224033</td>\n",
1554
+ " <td>0.930890</td>\n",
1555
+ " <td>0.859772</td>\n",
1556
+ " <td>0.925342</td>\n",
1557
+ " </tr>\n",
1558
+ " <tr>\n",
1559
+ " <td>3</td>\n",
1560
+ " <td>0.154200</td>\n",
1561
+ " <td>0.208034</td>\n",
1562
+ " <td>0.941284</td>\n",
1563
+ " <td>0.887012</td>\n",
1564
+ " <td>0.939485</td>\n",
1565
+ " </tr>\n",
1566
+ " <tr>\n",
1567
+ " <td>4</td>\n",
1568
+ " <td>0.121200</td>\n",
1569
+ " <td>0.216660</td>\n",
1570
+ " <td>0.940372</td>\n",
1571
+ " <td>0.880716</td>\n",
1572
+ " <td>0.939431</td>\n",
1573
+ " </tr>\n",
1574
+ " <tr>\n",
1575
+ " <td>5</td>\n",
1576
+ " <td>0.099900</td>\n",
1577
+ " <td>0.254255</td>\n",
1578
+ " <td>0.940554</td>\n",
1579
+ " <td>0.889088</td>\n",
1580
+ " <td>0.938300</td>\n",
1581
+ " </tr>\n",
1582
+ " <tr>\n",
1583
+ " <td>6</td>\n",
1584
+ " <td>0.065800</td>\n",
1585
+ " <td>0.267429</td>\n",
1586
+ " <td>0.942743</td>\n",
1587
+ " <td>0.897682</td>\n",
1588
+ " <td>0.942815</td>\n",
1589
+ " </tr>\n",
1590
+ " <tr>\n",
1591
+ " <td>7</td>\n",
1592
+ " <td>0.061200</td>\n",
1593
+ " <td>0.282509</td>\n",
1594
+ " <td>0.945478</td>\n",
1595
+ " <td>0.898797</td>\n",
1596
+ " <td>0.943881</td>\n",
1597
+ " </tr>\n",
1598
+ " <tr>\n",
1599
+ " <td>8</td>\n",
1600
+ " <td>0.036800</td>\n",
1601
+ " <td>0.301781</td>\n",
1602
+ " <td>0.943837</td>\n",
1603
+ " <td>0.903816</td>\n",
1604
+ " <td>0.944163</td>\n",
1605
+ " </tr>\n",
1606
+ " <tr>\n",
1607
+ " <td>9</td>\n",
1608
+ " <td>0.035400</td>\n",
1609
+ " <td>0.317026</td>\n",
1610
+ " <td>0.942560</td>\n",
1611
+ " <td>0.902241</td>\n",
1612
+ " <td>0.942071</td>\n",
1613
+ " </tr>\n",
1614
+ " <tr>\n",
1615
+ " <td>10</td>\n",
1616
+ " <td>0.014200</td>\n",
1617
+ " <td>0.313259</td>\n",
1618
+ " <td>0.946754</td>\n",
1619
+ " <td>0.904955</td>\n",
1620
+ " <td>0.946129</td>\n",
1621
+ " </tr>\n",
1622
+ " </tbody>\n",
1623
+ "</table><p>"
1624
+ ],
1625
+ "text/plain": [
1626
+ "<IPython.core.display.HTML object>"
1627
+ ]
1628
+ },
1629
+ "metadata": {},
1630
+ "output_type": "display_data"
1631
+ },
1632
+ {
1633
+ "name": "stderr",
1634
+ "output_type": "stream",
1635
+ "text": [
1636
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1637
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1638
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1639
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1640
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1641
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1642
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1643
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1644
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1645
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1646
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1647
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1648
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1649
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1650
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1651
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1652
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1653
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1654
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1655
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1656
+ ]
1657
+ },
1658
+ {
1659
+ "data": {
1660
+ "text/html": [
1661
+ "\n",
1662
+ " <div>\n",
1663
+ " \n",
1664
+ " <progress value='457' max='457' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1665
+ " [457/457 00:11]\n",
1666
+ " </div>\n",
1667
+ " "
1668
+ ],
1669
+ "text/plain": [
1670
+ "<IPython.core.display.HTML object>"
1671
+ ]
1672
+ },
1673
+ "metadata": {},
1674
+ "output_type": "display_data"
1675
+ },
1676
+ {
1677
+ "name": "stderr",
1678
+ "output_type": "stream",
1679
+ "text": [
1680
+ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
1681
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
1682
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
1683
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
1684
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1685
+ ]
1686
+ },
1687
+ {
1688
+ "name": "stdout",
1689
+ "output_type": "stream",
1690
+ "text": [
1691
+ "liver\n"
1692
+ ]
1693
+ },
1694
+ {
1695
+ "name": "stderr",
1696
+ "output_type": "stream",
1697
+ "text": [
1698
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1699
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1700
+ ]
1701
+ },
1702
+ {
1703
+ "data": {
1704
+ "text/html": [
1705
+ "\n",
1706
+ " <div>\n",
1707
+ " \n",
1708
+ " <progress value='18690' max='18690' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1709
+ " [18690/18690 26:56, Epoch 10/10]\n",
1710
+ " </div>\n",
1711
+ " <table border=\"1\" class=\"dataframe\">\n",
1712
+ " <thead>\n",
1713
+ " <tr style=\"text-align: left;\">\n",
1714
+ " <th>Epoch</th>\n",
1715
+ " <th>Training Loss</th>\n",
1716
+ " <th>Validation Loss</th>\n",
1717
+ " <th>Accuracy</th>\n",
1718
+ " <th>Macro F1</th>\n",
1719
+ " <th>Weighted F1</th>\n",
1720
+ " </tr>\n",
1721
+ " </thead>\n",
1722
+ " <tbody>\n",
1723
+ " <tr>\n",
1724
+ " <td>1</td>\n",
1725
+ " <td>0.388500</td>\n",
1726
+ " <td>0.385503</td>\n",
1727
+ " <td>0.878188</td>\n",
1728
+ " <td>0.673887</td>\n",
1729
+ " <td>0.871348</td>\n",
1730
+ " </tr>\n",
1731
+ " <tr>\n",
1732
+ " <td>2</td>\n",
1733
+ " <td>0.315900</td>\n",
1734
+ " <td>0.302775</td>\n",
1735
+ " <td>0.907437</td>\n",
1736
+ " <td>0.754182</td>\n",
1737
+ " <td>0.903474</td>\n",
1738
+ " </tr>\n",
1739
+ " <tr>\n",
1740
+ " <td>3</td>\n",
1741
+ " <td>0.242600</td>\n",
1742
+ " <td>0.321844</td>\n",
1743
+ " <td>0.907972</td>\n",
1744
+ " <td>0.779504</td>\n",
1745
+ " <td>0.905881</td>\n",
1746
+ " </tr>\n",
1747
+ " <tr>\n",
1748
+ " <td>4</td>\n",
1749
+ " <td>0.238600</td>\n",
1750
+ " <td>0.323119</td>\n",
1751
+ " <td>0.911539</td>\n",
1752
+ " <td>0.790922</td>\n",
1753
+ " <td>0.910299</td>\n",
1754
+ " </tr>\n",
1755
+ " <tr>\n",
1756
+ " <td>5</td>\n",
1757
+ " <td>0.160100</td>\n",
1758
+ " <td>0.328203</td>\n",
1759
+ " <td>0.915641</td>\n",
1760
+ " <td>0.793490</td>\n",
1761
+ " <td>0.913836</td>\n",
1762
+ " </tr>\n",
1763
+ " <tr>\n",
1764
+ " <td>6</td>\n",
1765
+ " <td>0.163100</td>\n",
1766
+ " <td>0.348942</td>\n",
1767
+ " <td>0.917425</td>\n",
1768
+ " <td>0.813604</td>\n",
1769
+ " <td>0.916911</td>\n",
1770
+ " </tr>\n",
1771
+ " <tr>\n",
1772
+ " <td>7</td>\n",
1773
+ " <td>0.124100</td>\n",
1774
+ " <td>0.373799</td>\n",
1775
+ " <td>0.916890</td>\n",
1776
+ " <td>0.820355</td>\n",
1777
+ " <td>0.916688</td>\n",
1778
+ " </tr>\n",
1779
+ " <tr>\n",
1780
+ " <td>8</td>\n",
1781
+ " <td>0.118700</td>\n",
1782
+ " <td>0.399474</td>\n",
1783
+ " <td>0.916890</td>\n",
1784
+ " <td>0.818839</td>\n",
1785
+ " <td>0.916640</td>\n",
1786
+ " </tr>\n",
1787
+ " <tr>\n",
1788
+ " <td>9</td>\n",
1789
+ " <td>0.066800</td>\n",
1790
+ " <td>0.414363</td>\n",
1791
+ " <td>0.917603</td>\n",
1792
+ " <td>0.830703</td>\n",
1793
+ " <td>0.917226</td>\n",
1794
+ " </tr>\n",
1795
+ " <tr>\n",
1796
+ " <td>10</td>\n",
1797
+ " <td>0.075800</td>\n",
1798
+ " <td>0.413828</td>\n",
1799
+ " <td>0.919030</td>\n",
1800
+ " <td>0.828149</td>\n",
1801
+ " <td>0.918506</td>\n",
1802
+ " </tr>\n",
1803
+ " </tbody>\n",
1804
+ "</table><p>"
1805
+ ],
1806
+ "text/plain": [
1807
+ "<IPython.core.display.HTML object>"
1808
+ ]
1809
+ },
1810
+ "metadata": {},
1811
+ "output_type": "display_data"
1812
+ },
1813
+ {
1814
+ "name": "stderr",
1815
+ "output_type": "stream",
1816
+ "text": [
1817
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1818
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1819
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1820
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1821
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1822
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1823
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1824
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1825
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1826
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1827
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1828
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1829
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1830
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1831
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1832
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1833
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1834
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
1835
+ "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
1836
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
1837
+ ]
1838
+ },
1839
+ {
1840
+ "data": {
1841
+ "text/html": [
1842
+ "\n",
1843
+ " <div>\n",
1844
+ " \n",
1845
+ " <progress value='936' max='468' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1846
+ " [468/468 00:39]\n",
1847
+ " </div>\n",
1848
+ " "
1849
+ ],
1850
+ "text/plain": [
1851
+ "<IPython.core.display.HTML object>"
1852
+ ]
1853
+ },
1854
+ "metadata": {},
1855
+ "output_type": "display_data"
1856
+ }
1857
+ ],
1858
+ "source": [
1859
+ "for organ in organ_list:\n",
1860
+ " print(organ)\n",
1861
+ " organ_trainset = trainset_dict[organ]\n",
1862
+ " organ_evalset = evalset_dict[organ]\n",
1863
+ " organ_label_dict = traintargetdict_dict[organ]\n",
1864
+ " \n",
1865
+ " # set logging steps\n",
1866
+ " logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)\n",
1867
+ " \n",
1868
+ " # reload pretrained model\n",
1869
+ " model = BertForSequenceClassification.from_pretrained(\"/path/to/pretrained_model/\", \n",
1870
+ " num_labels=len(organ_label_dict.keys()),\n",
1871
+ " output_attentions = False,\n",
1872
+ " output_hidden_states = False).to(\"cuda\")\n",
1873
+ " \n",
1874
+ " # create output directory\n",
1875
+ " current_date = datetime.datetime.now()\n",
1876
+ " datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n",
1877
+ " output_dir = f\"/path/to/models/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/\"\n",
1878
+ " \n",
1879
+ " # ensure not overwriting previously saved model\n",
1880
+ " saved_model_test = os.path.join(output_dir, f\"pytorch_model.bin\")\n",
1881
+ " if os.path.isfile(saved_model_test) == True:\n",
1882
+ " raise Exception(\"Model already saved to this directory.\")\n",
1883
+ "\n",
1884
+ " # make output directories\n",
1885
+ " subprocess.call(f'mkdir {output_dir}', shell=True)\n",
1886
+ " \n",
1887
+ " # set training arguments\n",
1888
+ " training_args = {\n",
1889
+ " \"learning_rate\": max_lr,\n",
1890
+ " \"do_train\": True,\n",
1891
+ " \"do_eval\": True,\n",
1892
+ " \"evaluation_strategy\": \"epoch\",\n",
1893
+ " \"logging_steps\": logging_steps,\n",
1894
+ " \"group_by_length\": True,\n",
1895
+ " \"length_column_name\": \"length\",\n",
1896
+ " \"disable_tqdm\": False,\n",
1897
+ " \"lr_scheduler_type\": lr_schedule_fn,\n",
1898
+ " \"warmup_steps\": warmup_steps,\n",
1899
+ " \"weight_decay\": 0.001,\n",
1900
+ " \"per_device_train_batch_size\": geneformer_batch_size,\n",
1901
+ " \"per_device_eval_batch_size\": geneformer_batch_size,\n",
1902
+ " \"num_train_epochs\": epochs,\n",
1903
+ " \"load_best_model_at_end\": True,\n",
1904
+ " \"output_dir\": output_dir,\n",
1905
+ " }\n",
1906
+ " \n",
1907
+ " training_args_init = TrainingArguments(**training_args)\n",
1908
+ "\n",
1909
+ " # create the trainer\n",
1910
+ " trainer = Trainer(\n",
1911
+ " model=model,\n",
1912
+ " args=training_args_init,\n",
1913
+ " data_collator=DataCollatorForCellClassification(),\n",
1914
+ " train_dataset=organ_trainset,\n",
1915
+ " eval_dataset=organ_evalset,\n",
1916
+ " compute_metrics=compute_metrics\n",
1917
+ " )\n",
1918
+ " # train the cell type classifier\n",
1919
+ " trainer.train()\n",
1920
+ " predictions = trainer.predict(organ_evalset)\n",
1921
+ " with open(f\"{output_dir}predictions.pickle\", \"wb\") as fp:\n",
1922
+ " pickle.dump(predictions, fp)\n",
1923
+ " trainer.save_metrics(\"eval\",predictions.metrics)\n",
1924
+ " trainer.save_model(output_dir)"
1925
+ ]
1926
+ }
1927
+ ],
1928
+ "metadata": {
1929
+ "kernelspec": {
1930
+ "display_name": "Python 3.8.6 64-bit ('3.8.6')",
1931
+ "language": "python",
1932
+ "name": "python3"
1933
+ },
1934
+ "language_info": {
1935
+ "codemirror_mode": {
1936
+ "name": "ipython",
1937
+ "version": 3
1938
+ },
1939
+ "file_extension": ".py",
1940
+ "mimetype": "text/x-python",
1941
+ "name": "python",
1942
+ "nbconvert_exporter": "python",
1943
+ "pygments_lexer": "ipython3",
1944
+ "version": "3.8.6"
1945
+ },
1946
+ "vscode": {
1947
+ "interpreter": {
1948
+ "hash": "eba1599a1f7e611c14c87ccff6793920aa63510b01fc0e229d6dd014149b8829"
1949
+ }
1950
+ }
1951
+ },
1952
+ "nbformat": 4,
1953
+ "nbformat_minor": 5
1954
+ }
examples/pretrain_geneformer_w_deepspeed.py CHANGED
@@ -23,7 +23,7 @@ import torch
23
  from datasets import load_from_disk
24
  from transformers import BertConfig, BertForMaskedLM, TrainingArguments
25
 
26
- from .trainer import GeneformerTrainer
27
 
28
  seed_num = 0
29
  random.seed(seed_num)
@@ -149,7 +149,7 @@ training_args = TrainingArguments(**training_args)
149
  print("Starting training.")
150
 
151
  # define the trainer
152
- trainer = GeneformerTrainer(
153
  model=model,
154
  args=training_args,
155
  # pretraining corpus (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048.dataset)
 
23
  from datasets import load_from_disk
24
  from transformers import BertConfig, BertForMaskedLM, TrainingArguments
25
 
26
+ from geneformer import GeneformerPretrainer
27
 
28
  seed_num = 0
29
  random.seed(seed_num)
 
149
  print("Starting training.")
150
 
151
  # define the trainer
152
+ trainer = GeneformerPretrainer(
153
  model=model,
154
  args=training_args,
155
  # pretraining corpus (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048.dataset)
geneformer/__init__.py CHANGED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from . import tokenizer
2
+ from . import pretrainer
3
+ from . import collator_for_cell_classification
4
+ from . import collator_for_gene_classification
5
+ from .tokenizer import TranscriptomeTokenizer
6
+ from .pretrainer import GeneformerPretrainer
7
+ from .collator_for_gene_classification import DataCollatorForGeneClassification
8
+ from .collator_for_cell_classification import DataCollatorForCellClassification
geneformer/collator_for_cell_classification.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer collator for cell classification.
3
+
4
+ Huggingface data collator modified to accommodate single-cell transcriptomics data for cell classification.
5
+ """
6
+ import numpy as np
7
+ import torch
8
+ import warnings
9
+ from enum import Enum
10
+ from typing import Dict, List, Optional, Union
11
+
12
+ from transformers import (
13
+ DataCollatorForTokenClassification,
14
+ SpecialTokensMixin,
15
+ BatchEncoding,
16
+ )
17
+ from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
+ from transformers.utils.generic import _is_tensorflow, _is_torch
19
+
20
+ from .pretrainer import token_dictionary
21
+
22
+ EncodedInput = List[int]
23
+ logger = logging.get_logger(__name__)
24
+ VERY_LARGE_INTEGER = int(
25
+ 1e30
26
+ ) # This is used to set the max input length for a model with infinite size input
27
+ LARGE_INTEGER = int(
28
+ 1e20
29
+ ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
30
+
31
+ # precollator functions
32
+
33
+ def run_once(f):
34
+ def wrapper(*args, **kwargs):
35
+ if not wrapper.has_run:
36
+ wrapper.has_run = True
37
+ return f(*args, **kwargs)
38
+ wrapper.has_run = False
39
+ return wrapper
40
+
41
+ @run_once
42
+ def check_output_once(output):
43
+ return print(output)
44
+
45
+ class ExplicitEnum(Enum):
46
+ """
47
+ Enum with more explicit error message for missing values.
48
+ """
49
+
50
+ @classmethod
51
+ def _missing_(cls, value):
52
+ raise ValueError(
53
+ "%r is not a valid %s, please select one of %s"
54
+ % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
55
+ )
56
+
57
+ class TruncationStrategy(ExplicitEnum):
58
+ """
59
+ Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
60
+ tab-completion in an IDE.
61
+ """
62
+
63
+ ONLY_FIRST = "only_first"
64
+ ONLY_SECOND = "only_second"
65
+ LONGEST_FIRST = "longest_first"
66
+ DO_NOT_TRUNCATE = "do_not_truncate"
67
+
68
+
69
+
70
+ class PaddingStrategy(ExplicitEnum):
71
+ """
72
+ Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
73
+ in an IDE.
74
+ """
75
+
76
+ LONGEST = "longest"
77
+ MAX_LENGTH = "max_length"
78
+ DO_NOT_PAD = "do_not_pad"
79
+
80
+
81
+
82
+ class TensorType(ExplicitEnum):
83
+ """
84
+ Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
85
+ tab-completion in an IDE.
86
+ """
87
+
88
+ PYTORCH = "pt"
89
+ TENSORFLOW = "tf"
90
+ NUMPY = "np"
91
+ JAX = "jax"
92
+
93
+
94
+ class PrecollatorForCellClassification(SpecialTokensMixin):
95
+ mask_token = "<mask>"
96
+ mask_token_id = token_dictionary.get("<mask>")
97
+ pad_token = "<pad>"
98
+ pad_token_id = token_dictionary.get("<pad>")
99
+ padding_side = "right"
100
+ all_special_ids = [
101
+ token_dictionary.get("<mask>"),
102
+ token_dictionary.get("<pad>")
103
+ ]
104
+ model_input_names = ["input_ids"]
105
+
106
+ def _get_padding_truncation_strategies(
107
+ self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
108
+ ):
109
+ """
110
+ Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
111
+ and pad_to_max_length) and behaviors.
112
+ """
113
+ old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
114
+ old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
115
+
116
+ # Backward compatibility for previous behavior, maybe we should deprecate it:
117
+ # If you only set max_length, it activates truncation for max_length
118
+ if max_length is not None and padding is False and truncation is False:
119
+ if verbose:
120
+ if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
121
+ logger.warning(
122
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, "
123
+ "please use `truncation=True` to explicitly truncate examples to max length. "
124
+ "Defaulting to 'longest_first' truncation strategy. "
125
+ "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
126
+ "more precisely by providing a specific strategy to `truncation`."
127
+ )
128
+ self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
129
+ truncation = "longest_first"
130
+
131
+ # Get padding strategy
132
+ if padding is False and old_pad_to_max_length:
133
+ if verbose:
134
+ warnings.warn(
135
+ "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
136
+ "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
137
+ "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
138
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
139
+ "maximal input size of the model (e.g. 512 for Bert).",
140
+ FutureWarning,
141
+ )
142
+ if max_length is None:
143
+ padding_strategy = PaddingStrategy.LONGEST
144
+ else:
145
+ padding_strategy = PaddingStrategy.MAX_LENGTH
146
+ elif padding is not False:
147
+ if padding is True:
148
+ padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
149
+ elif not isinstance(padding, PaddingStrategy):
150
+ padding_strategy = PaddingStrategy(padding)
151
+ elif isinstance(padding, PaddingStrategy):
152
+ padding_strategy = padding
153
+ else:
154
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
155
+
156
+ # Get truncation strategy
157
+ if truncation is False and old_truncation_strategy != "do_not_truncate":
158
+ if verbose:
159
+ warnings.warn(
160
+ "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
161
+ "use `truncation=True` to truncate examples to a max length. You can give a specific "
162
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
163
+ "maximal input size of the model (e.g. 512 for Bert). "
164
+ " If you have pairs of inputs, you can give a specific truncation strategy selected among "
165
+ "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
166
+ "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
167
+ "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
168
+ FutureWarning,
169
+ )
170
+ truncation_strategy = TruncationStrategy(old_truncation_strategy)
171
+ elif truncation is not False:
172
+ if truncation is True:
173
+ truncation_strategy = (
174
+ TruncationStrategy.LONGEST_FIRST
175
+ ) # Default to truncate the longest sequences in pairs of inputs
176
+ elif not isinstance(truncation, TruncationStrategy):
177
+ truncation_strategy = TruncationStrategy(truncation)
178
+ elif isinstance(truncation, TruncationStrategy):
179
+ truncation_strategy = truncation
180
+ else:
181
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
182
+
183
+ # Set max length if needed
184
+ if max_length is None:
185
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
186
+ if self.model_max_length > LARGE_INTEGER:
187
+ if verbose:
188
+ if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
189
+ logger.warning(
190
+ "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
191
+ "Default to no padding."
192
+ )
193
+ self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
194
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
195
+ else:
196
+ max_length = self.model_max_length
197
+
198
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
199
+ if self.model_max_length > LARGE_INTEGER:
200
+ if verbose:
201
+ if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
202
+ logger.warning(
203
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
204
+ "Default to no truncation."
205
+ )
206
+ self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
207
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
208
+ else:
209
+ max_length = self.model_max_length
210
+
211
+ # Test if we have a padding token
212
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
213
+ raise ValueError(
214
+ "Asking to pad but the tokenizer does not have a padding token. "
215
+ "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
216
+ "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
217
+ )
218
+
219
+ # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
220
+ if (
221
+ truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
222
+ and padding_strategy != PaddingStrategy.DO_NOT_PAD
223
+ and pad_to_multiple_of is not None
224
+ and max_length is not None
225
+ and (max_length % pad_to_multiple_of != 0)
226
+ ):
227
+ raise ValueError(
228
+ f"Truncation and padding are both activated but "
229
+ f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
230
+ )
231
+
232
+ return padding_strategy, truncation_strategy, max_length, kwargs
233
+
234
+ def pad(
235
+ self,
236
+ encoded_inputs: Union[
237
+ BatchEncoding,
238
+ List[BatchEncoding],
239
+ Dict[str, EncodedInput],
240
+ Dict[str, List[EncodedInput]],
241
+ List[Dict[str, EncodedInput]],
242
+ ],
243
+ padding: Union[bool, str, PaddingStrategy] = True,
244
+ max_length: Optional[int] = None,
245
+ pad_to_multiple_of: Optional[int] = None,
246
+ return_attention_mask: Optional[bool] = True,
247
+ return_tensors: Optional[Union[str, TensorType]] = None,
248
+ verbose: bool = True,
249
+ ) -> BatchEncoding:
250
+ """
251
+ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
252
+ in the batch.
253
+
254
+ Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
255
+ ``self.pad_token_id`` and ``self.pad_token_type_id``)
256
+
257
+ .. note::
258
+
259
+ If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
260
+ result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
261
+ case of PyTorch tensors, you will lose the specific device of your tensors however.
262
+
263
+ Args:
264
+ encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
265
+ Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
266
+ List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
267
+ List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
268
+ well as in a PyTorch Dataloader collate function.
269
+
270
+ Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
271
+ see the note above for the return type.
272
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
273
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
274
+ index) among:
275
+
276
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
277
+ single sequence if provided).
278
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
279
+ maximum acceptable input length for the model if that argument is not provided.
280
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
281
+ different lengths).
282
+ max_length (:obj:`int`, `optional`):
283
+ Maximum length of the returned list and optionally padding length (see above).
284
+ pad_to_multiple_of (:obj:`int`, `optional`):
285
+ If set will pad the sequence to a multiple of the provided value.
286
+
287
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
288
+ >= 7.5 (Volta).
289
+ return_attention_mask (:obj:`bool`, `optional`):
290
+ Whether to return the attention mask. If left to the default, will return the attention mask according
291
+ to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
292
+
293
+ `What are attention masks? <../glossary.html#attention-mask>`__
294
+ return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
295
+ If set, will return tensors instead of list of python integers. Acceptable values are:
296
+
297
+ * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
298
+ * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
299
+ * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
300
+ verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
301
+ Whether or not to print more information and warnings.
302
+ """
303
+ # If we have a list of dicts, let's convert it in a dict of lists
304
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
305
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
306
+ encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
307
+
308
+ # The model's main input name, usually `input_ids`, has be passed for padding
309
+ if self.model_input_names[0] not in encoded_inputs:
310
+ raise ValueError(
311
+ "You should supply an encoding or a list of encodings to this method"
312
+ f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
313
+ )
314
+
315
+ required_input = encoded_inputs[self.model_input_names[0]]
316
+
317
+ if not required_input:
318
+ if return_attention_mask:
319
+ encoded_inputs["attention_mask"] = []
320
+ return encoded_inputs
321
+
322
+ # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
323
+ # and rebuild them afterwards if no return_tensors is specified
324
+ # Note that we lose the specific device the tensor may be on for PyTorch
325
+
326
+ first_element = required_input[0]
327
+ if isinstance(first_element, (list, tuple)):
328
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
329
+ index = 0
330
+ while len(required_input[index]) == 0:
331
+ index += 1
332
+ if index < len(required_input):
333
+ first_element = required_input[index][0]
334
+ # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
335
+ if not isinstance(first_element, (int, list, tuple)):
336
+ if is_tf_available() and _is_tensorflow(first_element):
337
+ return_tensors = "tf" if return_tensors is None else return_tensors
338
+ elif is_torch_available() and _is_torch(first_element):
339
+ return_tensors = "pt" if return_tensors is None else return_tensors
340
+ elif isinstance(first_element, np.ndarray):
341
+ return_tensors = "np" if return_tensors is None else return_tensors
342
+ else:
343
+ raise ValueError(
344
+ f"type of {first_element} unknown: {type(first_element)}. "
345
+ f"Should be one of a python, numpy, pytorch or tensorflow object."
346
+ )
347
+
348
+ for key, value in encoded_inputs.items():
349
+ encoded_inputs[key] = to_py_obj(value)
350
+
351
+ # Convert padding_strategy in PaddingStrategy
352
+ padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
353
+ padding=padding, max_length=max_length, verbose=verbose
354
+ )
355
+
356
+ required_input = encoded_inputs[self.model_input_names[0]]
357
+ if required_input and not isinstance(required_input[0], (list, tuple)):
358
+ encoded_inputs = self._pad(
359
+ encoded_inputs,
360
+ max_length=max_length,
361
+ padding_strategy=padding_strategy,
362
+ pad_to_multiple_of=pad_to_multiple_of,
363
+ return_attention_mask=return_attention_mask,
364
+ )
365
+ return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
366
+
367
+ batch_size = len(required_input)
368
+ assert all(
369
+ len(v) == batch_size for v in encoded_inputs.values()
370
+ ), "Some items in the output dictionary have a different batch size than others."
371
+
372
+ if padding_strategy == PaddingStrategy.LONGEST:
373
+ max_length = max(len(inputs) for inputs in required_input)
374
+ padding_strategy = PaddingStrategy.MAX_LENGTH
375
+
376
+ batch_outputs = {}
377
+ for i in range(batch_size):
378
+ inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
379
+ outputs = self._pad(
380
+ inputs,
381
+ max_length=max_length,
382
+ padding_strategy=padding_strategy,
383
+ pad_to_multiple_of=pad_to_multiple_of,
384
+ return_attention_mask=return_attention_mask,
385
+ )
386
+
387
+ for key, value in outputs.items():
388
+ if key not in batch_outputs:
389
+ batch_outputs[key] = []
390
+ batch_outputs[key].append(value)
391
+ del batch_outputs["label"]
392
+ return BatchEncoding(batch_outputs, tensor_type=return_tensors)
393
+
394
+ def _pad(
395
+ self,
396
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
397
+ max_length: Optional[int] = None,
398
+ padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
399
+ pad_to_multiple_of: Optional[int] = None,
400
+ return_attention_mask: Optional[bool] = True,
401
+ ) -> dict:
402
+ """
403
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
404
+
405
+ Args:
406
+ encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
407
+ max_length: maximum length of the returned list and optionally padding length (see below).
408
+ Will truncate by taking into account the special tokens.
409
+ padding_strategy: PaddingStrategy to use for padding.
410
+
411
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
412
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
413
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
414
+ The tokenizer padding sides are defined in self.padding_side:
415
+
416
+ - 'left': pads on the left of the sequences
417
+ - 'right': pads on the right of the sequences
418
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
419
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
420
+ >= 7.5 (Volta).
421
+ return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
422
+ """
423
+ # Load from model defaults
424
+ if return_attention_mask is None:
425
+ return_attention_mask = "attention_mask" in self.model_input_names
426
+
427
+ required_input = encoded_inputs[self.model_input_names[0]]
428
+
429
+ if padding_strategy == PaddingStrategy.LONGEST:
430
+ max_length = len(required_input)
431
+
432
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
433
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
434
+
435
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
436
+
437
+ if needs_to_be_padded:
438
+ difference = max_length - len(required_input)
439
+ if self.padding_side == "right":
440
+ if return_attention_mask:
441
+ encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
442
+ if "token_type_ids" in encoded_inputs:
443
+ encoded_inputs["token_type_ids"] = (
444
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
445
+ )
446
+ if "special_tokens_mask" in encoded_inputs:
447
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
448
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
449
+ elif self.padding_side == "left":
450
+ if return_attention_mask:
451
+ encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
452
+ if "token_type_ids" in encoded_inputs:
453
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
454
+ "token_type_ids"
455
+ ]
456
+ if "special_tokens_mask" in encoded_inputs:
457
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
458
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
459
+ else:
460
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
461
+ elif return_attention_mask and "attention_mask" not in encoded_inputs:
462
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
463
+
464
+ # check_output_once(encoded_inputs)
465
+
466
+ return encoded_inputs
467
+
468
+ def get_special_tokens_mask(
469
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
470
+ ) -> List[int]:
471
+ """
472
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
473
+ special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
474
+ Args:
475
+ token_ids_0 (:obj:`List[int]`):
476
+ List of ids of the first sequence.
477
+ token_ids_1 (:obj:`List[int]`, `optional`):
478
+ List of ids of the second sequence.
479
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
480
+ Whether or not the token list is already formatted with special tokens for the model.
481
+ Returns:
482
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
483
+ """
484
+ assert already_has_special_tokens and token_ids_1 is None, (
485
+ "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
486
+ "Please use a slow (full python) tokenizer to activate this argument."
487
+ "Or set `return_special_tokens_mask=True` when calling the encoding method "
488
+ "to get the special tokens mask in any tokenizer. "
489
+ )
490
+
491
+ all_special_ids = self.all_special_ids # cache the property
492
+
493
+ special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
494
+
495
+ return special_tokens_mask
496
+
497
+ def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
498
+ """
499
+ Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
500
+ vocabulary.
501
+ Args:
502
+ tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
503
+ Returns:
504
+ :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
505
+ """
506
+ if tokens is None:
507
+ return None
508
+
509
+ if isinstance(tokens, str):
510
+ return self._convert_token_to_id_with_added_voc(tokens)
511
+
512
+ ids = []
513
+ for token in tokens:
514
+ ids.append(self._convert_token_to_id_with_added_voc(token))
515
+ return ids
516
+
517
+ def _convert_token_to_id_with_added_voc(self, token):
518
+ if token is None:
519
+ return None
520
+
521
+ return token_dictionary.get(token)
522
+
523
+ def __len__(self):
524
+ return len(token_dictionary)
525
+
526
+
527
+ # collator functions
528
+
529
+ class DataCollatorForCellClassification(DataCollatorForTokenClassification):
530
+ """
531
+ Data collator that will dynamically pad the inputs received, as well as the labels.
532
+ Args:
533
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
534
+ The tokenizer used for encoding the data.
535
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
536
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
537
+ among:
538
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
539
+ sequence if provided).
540
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
541
+ maximum acceptable input length for the model if that argument is not provided.
542
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
543
+ different lengths).
544
+ max_length (:obj:`int`, `optional`):
545
+ Maximum length of the returned list and optionally padding length (see above).
546
+ pad_to_multiple_of (:obj:`int`, `optional`):
547
+ If set will pad the sequence to a multiple of the provided value.
548
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
549
+ 7.5 (Volta).
550
+ label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
551
+ The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
552
+ """
553
+
554
+ tokenizer: PrecollatorForCellClassification()
555
+ padding: Union[bool, str, PaddingStrategy] = True
556
+ max_length: Optional[int] = None
557
+ pad_to_multiple_of: Optional[int] = None
558
+ label_pad_token_id: int = -100
559
+
560
+ def __call__(self, features):
561
+ label_name = "label" if "label" in features[0].keys() else "labels"
562
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
563
+ batch = self.tokenizer.pad(
564
+ features,
565
+ padding=self.padding,
566
+ max_length=self.max_length,
567
+ pad_to_multiple_of=self.pad_to_multiple_of,
568
+ return_tensors="pt",
569
+ )
570
+
571
+ # Special handling for labels.
572
+ # Ensure that tensor is created with the correct type
573
+ # (it should be automatically the case, but let's make sure of it.)
574
+ first = features[0]
575
+ if "label" in first and first["label"] is not None:
576
+ label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
577
+ dtype = torch.long if isinstance(label, int) else torch.float
578
+ batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
579
+
580
+ batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
581
+ return batch
geneformer/{trainer.py → pretrainer.py} RENAMED
@@ -1,7 +1,7 @@
1
  """
2
- Geneformer trainer and collator.
3
 
4
- Huggingface trainer and data collator modified to accommodate single-cell transcriptomics data.
5
  """
6
  import collections
7
  import math
@@ -589,7 +589,7 @@ class GeneformerPreCollator(SpecialTokensMixin):
589
  return len(self.token_dictionary)
590
 
591
 
592
- class GeneformerTrainer(Trainer):
593
  def __init__(self, *args, **kwargs):
594
  data_collator = kwargs.get("data_collator")
595
  token_dictionary = kwargs.get("token_dictionary")
 
1
  """
2
+ Geneformer precollator and pretrainer.
3
 
4
+ Huggingface data collator and trainer modified to accommodate single-cell transcriptomics data.
5
  """
6
  import collections
7
  import math
 
589
  return len(self.token_dictionary)
590
 
591
 
592
+ class GeneformerPretrainer(Trainer):
593
  def __init__(self, *args, **kwargs):
594
  data_collator = kwargs.get("data_collator")
595
  token_dictionary = kwargs.get("token_dictionary")
geneformer/tokenizer.py CHANGED
@@ -2,8 +2,8 @@
2
  Geneformer tokenizer.
3
 
4
  Usage:
5
- from geneformer.tokenizer import Tokenizer
6
- tk = Tokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
7
  tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
8
  """
9
 
@@ -32,7 +32,7 @@ def tokenize_cell(gene_vector, gene_tokens):
32
  return sentence_tokens
33
 
34
 
35
- class Tokenizer:
36
  def __init__(
37
  self,
38
  custom_attr_name_dict,
 
2
  Geneformer tokenizer.
3
 
4
  Usage:
5
+ from geneformer import TranscriptomeTokenizer
6
+ tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
7
  tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
8
  """
9
 
 
32
  return sentence_tokens
33
 
34
 
35
+ class TranscriptomeTokenizer:
36
  def __init__(
37
  self,
38
  custom_attr_name_dict,