Christina Theodoris commited on
Commit
79a0c41
·
1 Parent(s): b2aee1b

Add example for hyperparameter optimization for disease classifier

Browse files
examples/hyperparam_optimiz_for_disease_classifier.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # hyperparameter optimization with raytune for disease classification
5
+
6
+ # imports
7
+ import os
8
+ import subprocess
9
+ GPU_NUMBER = [0,1,2,3]
10
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
11
+ os.environ["NCCL_DEBUG"] = "INFO"
12
+ os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
13
+ os.environ["LD_LIBRARY_PATH"] = "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"
14
+
15
+ # initiate runtime environment for raytune
16
+ import pyarrow # must occur prior to ray import
17
+ import ray
18
+ from ray import tune
19
+ from ray.tune import ExperimentAnalysis
20
+ from ray.tune.suggest.hyperopt import HyperOptSearch
21
+ runtime_env = {"conda": "base",
22
+ "env_vars": {"LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"}}
23
+ ray.init(runtime_env=runtime_env)
24
+
25
+ import datetime
26
+ import numpy as np
27
+ import pandas as pd
28
+ import random
29
+ import seaborn as sns; sns.set()
30
+ from collections import Counter
31
+ from datasets import load_from_disk
32
+ from scipy.stats import ranksums
33
+ from sklearn.metrics import accuracy_score
34
+ from transformers import BertForSequenceClassification
35
+ from transformers import Trainer
36
+ from transformers.training_args import TrainingArguments
37
+
38
+ from geneformer import DataCollatorForCellClassification
39
+
40
+ # number of CPU cores
41
+ num_proc=30
42
+
43
+ # load train dataset with columns:
44
+ # cell_type (annotation of each cell's type)
45
+ # disease (healthy or disease state)
46
+ # individual (unique ID for each patient)
47
+ # length (length of that cell's rank value encoding)
48
+ train_dataset=load_from_disk("/path/to/disease_train_data.dataset")
49
+
50
+ # filter dataset for given cell_type
51
+ def if_cell_type(example):
52
+ return example["cell_type"].startswith("Cardiomyocyte")
53
+
54
+ trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)
55
+
56
+ # create dictionary of disease states : label ids
57
+ target_names = ["healthy", "disease1", "disease2"]
58
+ target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
59
+
60
+ trainset_v3 = trainset_v2.rename_column("disease","label")
61
+
62
+ # change labels to numerical ids
63
+ def classes_to_ids(example):
64
+ example["label"] = target_name_id_dict[example["label"]]
65
+ return example
66
+
67
+ trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)
68
+
69
+ # separate into train, validation, test sets
70
+ indiv_list = trainset_v4["individual"]
71
+ random.seed(42)
72
+ train_indiv = random.sample(indiv_list,round(0.7*len(indiv_list)))
73
+ eval_indiv = [indiv for indiv in indiv_list if indiv not in train_indiv]
74
+ valid_indiv = random.sample(eval_indiv,round(0.5*len(eval_indiv)))
75
+ test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv]
76
+
77
+ def if_train(example):
78
+ return example["individual"] in train_indiv
79
+
80
+ classifier_trainset = trainset_v4.filter(if_train,num_proc=num_proc).shuffle(seed=42)
81
+
82
+ def if_valid(example):
83
+ return example["individual"] in valid_indiv
84
+
85
+ classifier_validset = trainset_v4.filter(if_valid,num_proc=num_proc).shuffle(seed=42)
86
+
87
+ # define output directory path
88
+ current_date = datetime.datetime.now()
89
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
90
+ output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"
91
+
92
+ # ensure not overwriting previously saved model
93
+ saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
94
+ if os.path.isfile(saved_model_test) == True:
95
+ raise Exception("Model already saved to this directory.")
96
+
97
+ # make output directory
98
+ subprocess.call(f'mkdir {output_dir}', shell=True)
99
+
100
+ # set training parameters
101
+ # how many pretrained layers to freeze
102
+ freeze_layers = 2
103
+ # batch size for training and eval
104
+ geneformer_batch_size = 12
105
+ # number of epochs
106
+ epochs = 1
107
+ # logging steps
108
+ logging_steps = round(len(classifier_trainset)/geneformer_batch_size/10)
109
+
110
+ # define function to initiate model
111
+ def model_init():
112
+ model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/",
113
+ num_labels=len(target_names),
114
+ output_attentions = False,
115
+ output_hidden_states = False)
116
+ if freeze_layers is not None:
117
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
118
+ for module in modules_to_freeze:
119
+ for param in module.parameters():
120
+ param.requires_grad = False
121
+
122
+ model = model.to("cuda:0")
123
+ return model
124
+
125
+ # define metrics
126
+ def compute_metrics(pred):
127
+ labels = pred.label_ids
128
+ preds = pred.predictions.argmax(-1)
129
+ # calculate accuracy using sklearn's function
130
+ acc = accuracy_score(labels, preds)
131
+ return {
132
+ 'accuracy': acc,
133
+ }
134
+
135
+ # set training arguments
136
+ training_args = {
137
+ "do_train": True,
138
+ "do_eval": True,
139
+ "evaluation_strategy": "steps",
140
+ "eval_steps": logging_steps,
141
+ "logging_steps": logging_steps,
142
+ "group_by_length": True,
143
+ "length_column_name": "length",
144
+ "disable_tqdm": True,
145
+ "skip_memory_metrics": True, # memory tracker causes errors in raytune
146
+ "per_device_train_batch_size": geneformer_batch_size,
147
+ "per_device_eval_batch_size": geneformer_batch_size,
148
+ "num_train_epochs": epochs,
149
+ "load_best_model_at_end": True,
150
+ "output_dir": output_dir,
151
+ }
152
+
153
+ training_args_init = TrainingArguments(**training_args)
154
+
155
+ # create the trainer
156
+ trainer = Trainer(
157
+ model_init=model_init,
158
+ args=training_args_init,
159
+ data_collator=DataCollatorForCellClassification(),
160
+ train_dataset=classifier_trainset,
161
+ eval_dataset=classifier_validset,
162
+ compute_metrics=compute_metrics,
163
+ )
164
+
165
+ # specify raytune hyperparameter search space
166
+ ray_config = {
167
+ "num_train_epochs": tune.choice([epochs]),
168
+ "learning_rate": tune.loguniform(1e-6, 1e-3),
169
+ "weight_decay": tune.uniform(0.0, 0.3),
170
+ "lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
171
+ "warmup_steps": tune.uniform(100, 2000),
172
+ "seed": tune.uniform(0,100),
173
+ "per_device_train_batch_size": tune.choice([geneformer_batch_size])
174
+ }
175
+
176
+ hyperopt_search = HyperOptSearch(
177
+ metric="eval_accuracy", mode="max")
178
+
179
+ # optimize hyperparameters
180
+ trainer.hyperparameter_search(
181
+ direction="maximize",
182
+ backend="ray",
183
+ resources_per_trial={"cpu":8,"gpu":1},
184
+ hp_space=lambda _: ray_config,
185
+ search_alg=hyperopt_search,
186
+ n_trials=100, # number of trials
187
+ progress_reporter=tune.CLIReporter(max_report_frequency=600,
188
+ sort_by_metric=True,
189
+ max_progress_rows=100,
190
+ mode="max",
191
+ metric="eval_accuracy",
192
+ metric_columns=["loss", "eval_loss", "eval_accuracy"])
193
+ )