ctheodoris
commited on
Commit
·
f07bfd7
1
Parent(s):
933ca80
precommit formatting
Browse files- .gitattributes +1 -1
- examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py +3 -1
- geneformer/__init__.py +3 -2
- geneformer/classifier.py +67 -26
- geneformer/classifier_utils.py +115 -43
- geneformer/collator_for_classification.py +116 -51
- geneformer/emb_extractor.py +32 -22
- geneformer/evaluation_utils.py +1 -1
- geneformer/in_silico_perturber.py +153 -113
- geneformer/in_silico_perturber_stats.py +29 -19
- geneformer/mtl/collators.py +11 -4
- geneformer/mtl/data.py +56 -22
- geneformer/mtl/eval_utils.py +23 -16
- geneformer/mtl/imports.py +24 -27
- geneformer/mtl/model.py +52 -15
- geneformer/mtl/optuna_utils.py +13 -7
- geneformer/mtl/train.py +212 -69
- geneformer/mtl/train_utils.py +65 -30
- geneformer/mtl/utils.py +50 -27
- geneformer/mtl_classifier.py +82 -59
- geneformer/perturber_utils.py +48 -32
- geneformer/tokenizer.py +136 -63
.gitattributes
CHANGED
@@ -18,7 +18,7 @@
|
|
18 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
19 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
20 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
21 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
22 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
23 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
24 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
|
18 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
19 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
20 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
21 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
22 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
23 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
24 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py
CHANGED
@@ -138,7 +138,9 @@ training_args = {
|
|
138 |
"per_device_train_batch_size": geneformer_batch_size,
|
139 |
"num_train_epochs": epochs,
|
140 |
"save_strategy": "steps",
|
141 |
-
"save_steps": np.floor(
|
|
|
|
|
142 |
"logging_steps": 1000,
|
143 |
"output_dir": training_output_dir,
|
144 |
"logging_dir": logging_dir,
|
|
|
138 |
"per_device_train_batch_size": geneformer_batch_size,
|
139 |
"num_train_epochs": epochs,
|
140 |
"save_strategy": "steps",
|
141 |
+
"save_steps": np.floor(
|
142 |
+
num_examples / geneformer_batch_size / 8
|
143 |
+
), # 8 saves per epoch
|
144 |
"logging_steps": 1000,
|
145 |
"output_dir": training_output_dir,
|
146 |
"logging_dir": logging_dir,
|
geneformer/__init__.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
# ruff: noqa: F401
|
2 |
-
from pathlib import Path
|
3 |
import warnings
|
|
|
|
|
4 |
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
|
5 |
|
6 |
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
|
@@ -30,4 +31,4 @@ from . import classifier # noqa # isort:skip
|
|
30 |
from .classifier import Classifier # noqa # isort:skip
|
31 |
|
32 |
from . import mtl_classifier # noqa # isort:skip
|
33 |
-
from .mtl_classifier import MTLClassifier # noqa # isort:skip
|
|
|
1 |
# ruff: noqa: F401
|
|
|
2 |
import warnings
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
|
6 |
|
7 |
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
|
|
|
31 |
from .classifier import Classifier # noqa # isort:skip
|
32 |
|
33 |
from . import mtl_classifier # noqa # isort:skip
|
34 |
+
from .mtl_classifier import MTLClassifier # noqa # isort:skip
|
geneformer/classifier.py
CHANGED
@@ -57,11 +57,14 @@ from tqdm.auto import tqdm, trange
|
|
57 |
from transformers import Trainer
|
58 |
from transformers.training_args import TrainingArguments
|
59 |
|
60 |
-
from . import
|
|
|
|
|
|
|
|
|
61 |
from . import classifier_utils as cu
|
62 |
from . import evaluation_utils as eu
|
63 |
from . import perturber_utils as pu
|
64 |
-
from . import TOKEN_DICTIONARY_FILE
|
65 |
|
66 |
sns.set()
|
67 |
|
@@ -413,7 +416,7 @@ class Classifier:
|
|
413 |
"Column name 'labels' must be reserved for class IDs. Please rename column."
|
414 |
)
|
415 |
raise
|
416 |
-
|
417 |
if (attr_to_split is not None) and (attr_to_balance is None):
|
418 |
logger.error(
|
419 |
"Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
|
@@ -548,17 +551,19 @@ class Classifier:
|
|
548 |
gene_balance : None, bool
|
549 |
| Whether to automatically balance genes in training set.
|
550 |
| Only available for binary gene classifications.
|
551 |
-
|
552 |
**Output**
|
553 |
|
554 |
Returns trainer after fine-tuning with all data.
|
555 |
|
556 |
"""
|
557 |
|
558 |
-
if (gene_balance is True) and (len(self.gene_class_dict.values())!=2):
|
559 |
-
logger.error(
|
|
|
|
|
560 |
raise
|
561 |
-
|
562 |
##### Load data and prepare output directory #####
|
563 |
# load numerical id to class dictionary (id:class)
|
564 |
with open(id_class_dict_file, "rb") as f:
|
@@ -679,11 +684,13 @@ class Classifier:
|
|
679 |
if self.num_crossval_splits == 0:
|
680 |
logger.error("num_crossval_splits must be 1 or 5 to validate.")
|
681 |
raise
|
682 |
-
|
683 |
-
if (gene_balance is True) and (len(self.gene_class_dict.values())!=2):
|
684 |
-
logger.error(
|
|
|
|
|
685 |
raise
|
686 |
-
|
687 |
# ensure number of genes in each class is > 5 if validating model
|
688 |
if self.classifier == "gene":
|
689 |
insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
|
@@ -845,17 +852,18 @@ class Classifier:
|
|
845 |
self.nproc,
|
846 |
gene_balance,
|
847 |
)
|
848 |
-
|
849 |
if save_gene_split_datasets is True:
|
850 |
for split_name in ["train", "valid"]:
|
851 |
labeled_dataset_output_path = (
|
852 |
-
|
853 |
-
|
|
|
854 |
if split_name == "train":
|
855 |
train_data.save_to_disk(str(labeled_dataset_output_path))
|
856 |
elif split_name == "valid":
|
857 |
eval_data.save_to_disk(str(labeled_dataset_output_path))
|
858 |
-
|
859 |
if self.oos_test_size > 0:
|
860 |
test_data = cu.prep_gene_classifier_split(
|
861 |
data,
|
@@ -869,11 +877,14 @@ class Classifier:
|
|
869 |
)
|
870 |
if save_gene_split_datasets is True:
|
871 |
test_labeled_dataset_output_path = (
|
872 |
-
|
873 |
-
|
|
|
874 |
test_data.save_to_disk(str(test_labeled_dataset_output_path))
|
875 |
if debug_gene_split_datasets is True:
|
876 |
-
logger.error(
|
|
|
|
|
877 |
raise
|
878 |
if n_hyperopt_trials == 0:
|
879 |
trainer = self.train_classifier(
|
@@ -1023,7 +1034,13 @@ class Classifier:
|
|
1023 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1024 |
|
1025 |
##### Load model and training args #####
|
1026 |
-
model = pu.load_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
1027 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1028 |
model, self.classifier, train_data, output_directory
|
1029 |
)
|
@@ -1047,14 +1064,22 @@ class Classifier:
|
|
1047 |
##### Fine-tune the model #####
|
1048 |
# define the data collator
|
1049 |
if self.classifier == "cell":
|
1050 |
-
data_collator = DataCollatorForCellClassification(
|
|
|
|
|
1051 |
elif self.classifier == "gene":
|
1052 |
-
data_collator = DataCollatorForGeneClassification(
|
|
|
|
|
1053 |
|
1054 |
# define function to initiate model
|
1055 |
def model_init():
|
1056 |
model = pu.load_model(
|
1057 |
-
self.model_type,
|
|
|
|
|
|
|
|
|
1058 |
)
|
1059 |
|
1060 |
if self.freeze_layers is not None:
|
@@ -1180,7 +1205,13 @@ class Classifier:
|
|
1180 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1181 |
|
1182 |
##### Load model and training args #####
|
1183 |
-
model = pu.load_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
1184 |
|
1185 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1186 |
model, self.classifier, train_data, output_directory
|
@@ -1210,9 +1241,13 @@ class Classifier:
|
|
1210 |
##### Fine-tune the model #####
|
1211 |
# define the data collator
|
1212 |
if self.classifier == "cell":
|
1213 |
-
data_collator = DataCollatorForCellClassification(
|
|
|
|
|
1214 |
elif self.classifier == "gene":
|
1215 |
-
data_collator = DataCollatorForGeneClassification(
|
|
|
|
|
1216 |
|
1217 |
# create the trainer
|
1218 |
trainer = Trainer(
|
@@ -1334,7 +1369,13 @@ class Classifier:
|
|
1334 |
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
1335 |
|
1336 |
# load previously fine-tuned model
|
1337 |
-
model = pu.load_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
1338 |
|
1339 |
# evaluate the model
|
1340 |
result = self.evaluate_model(
|
|
|
57 |
from transformers import Trainer
|
58 |
from transformers.training_args import TrainingArguments
|
59 |
|
60 |
+
from . import (
|
61 |
+
TOKEN_DICTIONARY_FILE,
|
62 |
+
DataCollatorForCellClassification,
|
63 |
+
DataCollatorForGeneClassification,
|
64 |
+
)
|
65 |
from . import classifier_utils as cu
|
66 |
from . import evaluation_utils as eu
|
67 |
from . import perturber_utils as pu
|
|
|
68 |
|
69 |
sns.set()
|
70 |
|
|
|
416 |
"Column name 'labels' must be reserved for class IDs. Please rename column."
|
417 |
)
|
418 |
raise
|
419 |
+
|
420 |
if (attr_to_split is not None) and (attr_to_balance is None):
|
421 |
logger.error(
|
422 |
"Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
|
|
|
551 |
gene_balance : None, bool
|
552 |
| Whether to automatically balance genes in training set.
|
553 |
| Only available for binary gene classifications.
|
554 |
+
|
555 |
**Output**
|
556 |
|
557 |
Returns trainer after fine-tuning with all data.
|
558 |
|
559 |
"""
|
560 |
|
561 |
+
if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
|
562 |
+
logger.error(
|
563 |
+
"Automatically balancing gene sets for training is only available for binary gene classifications."
|
564 |
+
)
|
565 |
raise
|
566 |
+
|
567 |
##### Load data and prepare output directory #####
|
568 |
# load numerical id to class dictionary (id:class)
|
569 |
with open(id_class_dict_file, "rb") as f:
|
|
|
684 |
if self.num_crossval_splits == 0:
|
685 |
logger.error("num_crossval_splits must be 1 or 5 to validate.")
|
686 |
raise
|
687 |
+
|
688 |
+
if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
|
689 |
+
logger.error(
|
690 |
+
"Automatically balancing gene sets for training is only available for binary gene classifications."
|
691 |
+
)
|
692 |
raise
|
693 |
+
|
694 |
# ensure number of genes in each class is > 5 if validating model
|
695 |
if self.classifier == "gene":
|
696 |
insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
|
|
|
852 |
self.nproc,
|
853 |
gene_balance,
|
854 |
)
|
855 |
+
|
856 |
if save_gene_split_datasets is True:
|
857 |
for split_name in ["train", "valid"]:
|
858 |
labeled_dataset_output_path = (
|
859 |
+
Path(output_dir)
|
860 |
+
/ f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
|
861 |
+
).with_suffix(".dataset")
|
862 |
if split_name == "train":
|
863 |
train_data.save_to_disk(str(labeled_dataset_output_path))
|
864 |
elif split_name == "valid":
|
865 |
eval_data.save_to_disk(str(labeled_dataset_output_path))
|
866 |
+
|
867 |
if self.oos_test_size > 0:
|
868 |
test_data = cu.prep_gene_classifier_split(
|
869 |
data,
|
|
|
877 |
)
|
878 |
if save_gene_split_datasets is True:
|
879 |
test_labeled_dataset_output_path = (
|
880 |
+
Path(output_dir)
|
881 |
+
/ f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
|
882 |
+
).with_suffix(".dataset")
|
883 |
test_data.save_to_disk(str(test_labeled_dataset_output_path))
|
884 |
if debug_gene_split_datasets is True:
|
885 |
+
logger.error(
|
886 |
+
"Exiting after saving gene split datasets given debug_gene_split_datasets = True."
|
887 |
+
)
|
888 |
raise
|
889 |
if n_hyperopt_trials == 0:
|
890 |
trainer = self.train_classifier(
|
|
|
1034 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1035 |
|
1036 |
##### Load model and training args #####
|
1037 |
+
model = pu.load_model(
|
1038 |
+
self.model_type,
|
1039 |
+
num_classes,
|
1040 |
+
model_directory,
|
1041 |
+
"train",
|
1042 |
+
quantize=self.quantize,
|
1043 |
+
)
|
1044 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1045 |
model, self.classifier, train_data, output_directory
|
1046 |
)
|
|
|
1064 |
##### Fine-tune the model #####
|
1065 |
# define the data collator
|
1066 |
if self.classifier == "cell":
|
1067 |
+
data_collator = DataCollatorForCellClassification(
|
1068 |
+
token_dictionary=self.token_dictionary
|
1069 |
+
)
|
1070 |
elif self.classifier == "gene":
|
1071 |
+
data_collator = DataCollatorForGeneClassification(
|
1072 |
+
token_dictionary=self.token_dictionary
|
1073 |
+
)
|
1074 |
|
1075 |
# define function to initiate model
|
1076 |
def model_init():
|
1077 |
model = pu.load_model(
|
1078 |
+
self.model_type,
|
1079 |
+
num_classes,
|
1080 |
+
model_directory,
|
1081 |
+
"train",
|
1082 |
+
quantize=self.quantize,
|
1083 |
)
|
1084 |
|
1085 |
if self.freeze_layers is not None:
|
|
|
1205 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1206 |
|
1207 |
##### Load model and training args #####
|
1208 |
+
model = pu.load_model(
|
1209 |
+
self.model_type,
|
1210 |
+
num_classes,
|
1211 |
+
model_directory,
|
1212 |
+
"train",
|
1213 |
+
quantize=self.quantize,
|
1214 |
+
)
|
1215 |
|
1216 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1217 |
model, self.classifier, train_data, output_directory
|
|
|
1241 |
##### Fine-tune the model #####
|
1242 |
# define the data collator
|
1243 |
if self.classifier == "cell":
|
1244 |
+
data_collator = DataCollatorForCellClassification(
|
1245 |
+
token_dictionary=self.token_dictionary
|
1246 |
+
)
|
1247 |
elif self.classifier == "gene":
|
1248 |
+
data_collator = DataCollatorForGeneClassification(
|
1249 |
+
token_dictionary=self.token_dictionary
|
1250 |
+
)
|
1251 |
|
1252 |
# create the trainer
|
1253 |
trainer = Trainer(
|
|
|
1369 |
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
1370 |
|
1371 |
# load previously fine-tuned model
|
1372 |
+
model = pu.load_model(
|
1373 |
+
self.model_type,
|
1374 |
+
num_classes,
|
1375 |
+
model_directory,
|
1376 |
+
"eval",
|
1377 |
+
quantize=self.quantize,
|
1378 |
+
)
|
1379 |
|
1380 |
# evaluate the model
|
1381 |
result = self.evaluate_model(
|
geneformer/classifier_utils.py
CHANGED
@@ -137,22 +137,53 @@ def label_gene_classes(example, class_id_dict, gene_class_dict):
|
|
137 |
|
138 |
|
139 |
def prep_gene_classifier_train_eval_split(
|
140 |
-
data,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
):
|
142 |
# generate cross-validation splits
|
143 |
train_data = prep_gene_classifier_split(
|
144 |
-
data,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
)
|
146 |
eval_data = prep_gene_classifier_split(
|
147 |
-
data,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
)
|
149 |
return train_data, eval_data
|
150 |
|
151 |
|
152 |
def prep_gene_classifier_split(
|
153 |
-
data,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
):
|
155 |
-
|
156 |
# generate cross-validation splits
|
157 |
targets = np.array(targets)
|
158 |
labels = np.array(labels)
|
@@ -175,7 +206,9 @@ def prep_gene_classifier_split(
|
|
175 |
|
176 |
# balance gene subsets if train
|
177 |
if (subset_name == "train") and (balance is True):
|
178 |
-
subset_data, label_dict_subset = balance_gene_split(
|
|
|
|
|
179 |
|
180 |
# subsample to max_ncells
|
181 |
subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
|
@@ -192,7 +225,9 @@ def prep_gene_classifier_split(
|
|
192 |
return subset_data
|
193 |
|
194 |
|
195 |
-
def prep_gene_classifier_all_data(
|
|
|
|
|
196 |
targets = np.array(targets)
|
197 |
labels = np.array(labels)
|
198 |
label_dict_train = dict(zip(targets, labels))
|
@@ -211,8 +246,10 @@ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc, b
|
|
211 |
)
|
212 |
|
213 |
if balance is True:
|
214 |
-
train_data, label_dict_train = balance_gene_split(
|
215 |
-
|
|
|
|
|
216 |
# subsample to max_ncells
|
217 |
train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
|
218 |
|
@@ -230,62 +267,97 @@ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc, b
|
|
230 |
|
231 |
def balance_gene_split(subset_data, label_dict_subset, num_proc):
|
232 |
# count occurrence of genes in each label category
|
233 |
-
label0_counts, label1_counts = count_genes_for_balancing(
|
234 |
-
|
235 |
-
|
236 |
-
|
|
|
|
|
237 |
# gene sets already balanced
|
238 |
logger.info(
|
239 |
"Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
|
240 |
)
|
241 |
return subset_data, label_dict_subset
|
242 |
else:
|
243 |
-
label_ratio_0to1_orig = label_ratio_0to1+0
|
244 |
label_dict_subset_orig = label_dict_subset.copy()
|
245 |
# balance gene sets
|
246 |
max_ntrials = 25
|
247 |
boost = 1
|
248 |
-
if label_ratio_0to1 > 10/8:
|
249 |
# downsample label 0
|
250 |
for i in range(max_ntrials):
|
251 |
label0 = 0
|
252 |
-
label0_genes = [k for k,v in label_dict_subset.items() if v == label0]
|
253 |
label0_ngenes = len(label0_genes)
|
254 |
-
label0_nremove = max(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
random.seed(i)
|
256 |
label0_remove_genes = random.sample(label0_genes, label0_nremove)
|
257 |
-
label_dict_subset_new = {
|
258 |
-
|
259 |
-
|
260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
# if gene sets now balanced, return new filtered data and new label_dict_subset
|
262 |
-
return filter_data_balanced_genes(
|
263 |
-
|
264 |
-
|
265 |
-
elif label_ratio_0to1
|
266 |
-
boost = boost*
|
|
|
|
|
267 |
else:
|
268 |
# downsample label 1
|
269 |
for i in range(max_ntrials):
|
270 |
label1 = 1
|
271 |
-
label1_genes = [k for k,v in label_dict_subset.items() if v == label1]
|
272 |
label1_ngenes = len(label1_genes)
|
273 |
-
label1_nremove = max(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
random.seed(i)
|
275 |
label1_remove_genes = random.sample(label1_genes, label1_nremove)
|
276 |
-
label_dict_subset_new = {
|
277 |
-
|
278 |
-
|
279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
# if gene sets now balanced, return new filtered data and new label_dict_subset
|
281 |
-
return filter_data_balanced_genes(
|
282 |
-
|
283 |
-
|
284 |
-
elif label_ratio_0to1
|
285 |
-
boost = boost*
|
286 |
-
|
287 |
-
|
288 |
-
|
|
|
|
|
|
|
|
|
289 |
label_ratio_0to1 = label_ratio_0to1_orig
|
290 |
label_dict_subset_new = label_dict_subset_orig
|
291 |
logger.warning(
|
@@ -301,11 +373,11 @@ def count_genes_for_balancing(subset_data, label_dict_subset, num_proc):
|
|
301 |
]
|
302 |
counter_labels = Counter(labels)
|
303 |
# get count of labels 0 or 1, or if absent, return 0
|
304 |
-
example["labels_counts"] = [counter_labels.get(0,0),counter_labels.get(1,0)]
|
305 |
return example
|
306 |
-
|
307 |
subset_data = subset_data.map(count_targets, num_proc=num_proc)
|
308 |
-
|
309 |
label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
|
310 |
label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
|
311 |
|
|
|
137 |
|
138 |
|
139 |
def prep_gene_classifier_train_eval_split(
|
140 |
+
data,
|
141 |
+
targets,
|
142 |
+
labels,
|
143 |
+
train_index,
|
144 |
+
eval_index,
|
145 |
+
max_ncells,
|
146 |
+
iteration_num,
|
147 |
+
num_proc,
|
148 |
+
balance=False,
|
149 |
):
|
150 |
# generate cross-validation splits
|
151 |
train_data = prep_gene_classifier_split(
|
152 |
+
data,
|
153 |
+
targets,
|
154 |
+
labels,
|
155 |
+
train_index,
|
156 |
+
"train",
|
157 |
+
max_ncells,
|
158 |
+
iteration_num,
|
159 |
+
num_proc,
|
160 |
+
balance,
|
161 |
)
|
162 |
eval_data = prep_gene_classifier_split(
|
163 |
+
data,
|
164 |
+
targets,
|
165 |
+
labels,
|
166 |
+
eval_index,
|
167 |
+
"eval",
|
168 |
+
max_ncells,
|
169 |
+
iteration_num,
|
170 |
+
num_proc,
|
171 |
+
balance,
|
172 |
)
|
173 |
return train_data, eval_data
|
174 |
|
175 |
|
176 |
def prep_gene_classifier_split(
|
177 |
+
data,
|
178 |
+
targets,
|
179 |
+
labels,
|
180 |
+
index,
|
181 |
+
subset_name,
|
182 |
+
max_ncells,
|
183 |
+
iteration_num,
|
184 |
+
num_proc,
|
185 |
+
balance=False,
|
186 |
):
|
|
|
187 |
# generate cross-validation splits
|
188 |
targets = np.array(targets)
|
189 |
labels = np.array(labels)
|
|
|
206 |
|
207 |
# balance gene subsets if train
|
208 |
if (subset_name == "train") and (balance is True):
|
209 |
+
subset_data, label_dict_subset = balance_gene_split(
|
210 |
+
subset_data, label_dict_subset, num_proc
|
211 |
+
)
|
212 |
|
213 |
# subsample to max_ncells
|
214 |
subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
|
|
|
225 |
return subset_data
|
226 |
|
227 |
|
228 |
+
def prep_gene_classifier_all_data(
|
229 |
+
data, targets, labels, max_ncells, num_proc, balance=False
|
230 |
+
):
|
231 |
targets = np.array(targets)
|
232 |
labels = np.array(labels)
|
233 |
label_dict_train = dict(zip(targets, labels))
|
|
|
246 |
)
|
247 |
|
248 |
if balance is True:
|
249 |
+
train_data, label_dict_train = balance_gene_split(
|
250 |
+
train_data, label_dict_train, num_proc
|
251 |
+
)
|
252 |
+
|
253 |
# subsample to max_ncells
|
254 |
train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
|
255 |
|
|
|
267 |
|
268 |
def balance_gene_split(subset_data, label_dict_subset, num_proc):
|
269 |
# count occurrence of genes in each label category
|
270 |
+
label0_counts, label1_counts = count_genes_for_balancing(
|
271 |
+
subset_data, label_dict_subset, num_proc
|
272 |
+
)
|
273 |
+
label_ratio_0to1 = label0_counts / label1_counts
|
274 |
+
|
275 |
+
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
276 |
# gene sets already balanced
|
277 |
logger.info(
|
278 |
"Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
|
279 |
)
|
280 |
return subset_data, label_dict_subset
|
281 |
else:
|
282 |
+
label_ratio_0to1_orig = label_ratio_0to1 + 0
|
283 |
label_dict_subset_orig = label_dict_subset.copy()
|
284 |
# balance gene sets
|
285 |
max_ntrials = 25
|
286 |
boost = 1
|
287 |
+
if label_ratio_0to1 > 10 / 8:
|
288 |
# downsample label 0
|
289 |
for i in range(max_ntrials):
|
290 |
label0 = 0
|
291 |
+
label0_genes = [k for k, v in label_dict_subset.items() if v == label0]
|
292 |
label0_ngenes = len(label0_genes)
|
293 |
+
label0_nremove = max(
|
294 |
+
1,
|
295 |
+
int(
|
296 |
+
np.floor(
|
297 |
+
label0_ngenes - label0_ngenes / (label_ratio_0to1 * boost)
|
298 |
+
)
|
299 |
+
),
|
300 |
+
)
|
301 |
random.seed(i)
|
302 |
label0_remove_genes = random.sample(label0_genes, label0_nremove)
|
303 |
+
label_dict_subset_new = {
|
304 |
+
k: v
|
305 |
+
for k, v in label_dict_subset.items()
|
306 |
+
if k not in label0_remove_genes
|
307 |
+
}
|
308 |
+
label0_counts, label1_counts = count_genes_for_balancing(
|
309 |
+
subset_data, label_dict_subset_new, num_proc
|
310 |
+
)
|
311 |
+
label_ratio_0to1 = label0_counts / label1_counts
|
312 |
+
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
313 |
# if gene sets now balanced, return new filtered data and new label_dict_subset
|
314 |
+
return filter_data_balanced_genes(
|
315 |
+
subset_data, label_dict_subset_new, num_proc
|
316 |
+
)
|
317 |
+
elif label_ratio_0to1 > 10 / 8:
|
318 |
+
boost = boost * 1.1
|
319 |
+
elif label_ratio_0to1 < 8 / 10:
|
320 |
+
boost = boost * 0.9
|
321 |
else:
|
322 |
# downsample label 1
|
323 |
for i in range(max_ntrials):
|
324 |
label1 = 1
|
325 |
+
label1_genes = [k for k, v in label_dict_subset.items() if v == label1]
|
326 |
label1_ngenes = len(label1_genes)
|
327 |
+
label1_nremove = max(
|
328 |
+
1,
|
329 |
+
int(
|
330 |
+
np.floor(
|
331 |
+
label1_ngenes
|
332 |
+
- label1_ngenes / ((1 / label_ratio_0to1) * boost)
|
333 |
+
)
|
334 |
+
),
|
335 |
+
)
|
336 |
random.seed(i)
|
337 |
label1_remove_genes = random.sample(label1_genes, label1_nremove)
|
338 |
+
label_dict_subset_new = {
|
339 |
+
k: v
|
340 |
+
for k, v in label_dict_subset.items()
|
341 |
+
if k not in label1_remove_genes
|
342 |
+
}
|
343 |
+
label0_counts, label1_counts = count_genes_for_balancing(
|
344 |
+
subset_data, label_dict_subset_new, num_proc
|
345 |
+
)
|
346 |
+
label_ratio_0to1 = label0_counts / label1_counts
|
347 |
+
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
348 |
# if gene sets now balanced, return new filtered data and new label_dict_subset
|
349 |
+
return filter_data_balanced_genes(
|
350 |
+
subset_data, label_dict_subset_new, num_proc
|
351 |
+
)
|
352 |
+
elif label_ratio_0to1 < 8 / 10:
|
353 |
+
boost = boost * 1.1
|
354 |
+
elif label_ratio_0to1 > 10 / 8:
|
355 |
+
boost = boost * 0.9
|
356 |
+
|
357 |
+
assert i + 1 == max_ntrials
|
358 |
+
if (label_ratio_0to1 <= label_ratio_0to1_orig < 8 / 10) or (
|
359 |
+
10 / 8 > label_ratio_0to1_orig >= label_ratio_0to1
|
360 |
+
):
|
361 |
label_ratio_0to1 = label_ratio_0to1_orig
|
362 |
label_dict_subset_new = label_dict_subset_orig
|
363 |
logger.warning(
|
|
|
373 |
]
|
374 |
counter_labels = Counter(labels)
|
375 |
# get count of labels 0 or 1, or if absent, return 0
|
376 |
+
example["labels_counts"] = [counter_labels.get(0, 0), counter_labels.get(1, 0)]
|
377 |
return example
|
378 |
+
|
379 |
subset_data = subset_data.map(count_targets, num_proc=num_proc)
|
380 |
+
|
381 |
label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
|
382 |
label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
|
383 |
|
geneformer/collator_for_classification.py
CHANGED
@@ -3,17 +3,17 @@ Geneformer collator for gene and cell classification.
|
|
3 |
|
4 |
Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
|
5 |
"""
|
6 |
-
|
7 |
-
import pickle
|
8 |
-
import torch
|
9 |
import warnings
|
10 |
from enum import Enum
|
11 |
from typing import Dict, List, Optional, Union
|
12 |
|
|
|
|
|
13 |
from transformers import (
|
|
|
14 |
DataCollatorForTokenClassification,
|
15 |
SpecialTokensMixin,
|
16 |
-
BatchEncoding,
|
17 |
)
|
18 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
19 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
@@ -29,6 +29,7 @@ LARGE_INTEGER = int(
|
|
29 |
|
30 |
# precollator functions
|
31 |
|
|
|
32 |
class ExplicitEnum(Enum):
|
33 |
"""
|
34 |
Enum with more explicit error message for missing values.
|
@@ -41,6 +42,7 @@ class ExplicitEnum(Enum):
|
|
41 |
% (value, cls.__name__, str(list(cls._value2member_map_.keys())))
|
42 |
)
|
43 |
|
|
|
44 |
class TruncationStrategy(ExplicitEnum):
|
45 |
"""
|
46 |
Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
@@ -53,7 +55,6 @@ class TruncationStrategy(ExplicitEnum):
|
|
53 |
DO_NOT_TRUNCATE = "do_not_truncate"
|
54 |
|
55 |
|
56 |
-
|
57 |
class PaddingStrategy(ExplicitEnum):
|
58 |
"""
|
59 |
Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
|
@@ -65,7 +66,6 @@ class PaddingStrategy(ExplicitEnum):
|
|
65 |
DO_NOT_PAD = "do_not_pad"
|
66 |
|
67 |
|
68 |
-
|
69 |
class TensorType(ExplicitEnum):
|
70 |
"""
|
71 |
Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
@@ -77,7 +77,7 @@ class TensorType(ExplicitEnum):
|
|
77 |
NUMPY = "np"
|
78 |
JAX = "jax"
|
79 |
|
80 |
-
|
81 |
class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
82 |
def __init__(self, *args, **kwargs) -> None:
|
83 |
super().__init__(mask_token="<mask>", pad_token="<pad>")
|
@@ -89,11 +89,17 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
89 |
self.pad_token_id = self.token_dictionary.get("<pad>")
|
90 |
self.all_special_ids = [
|
91 |
self.token_dictionary.get("<mask>"),
|
92 |
-
self.token_dictionary.get("<pad>")
|
93 |
]
|
94 |
|
95 |
def _get_padding_truncation_strategies(
|
96 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
):
|
98 |
"""
|
99 |
Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
|
@@ -106,7 +112,9 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
106 |
# If you only set max_length, it activates truncation for max_length
|
107 |
if max_length is not None and padding is False and truncation is False:
|
108 |
if verbose:
|
109 |
-
if not self.deprecation_warnings.get(
|
|
|
|
|
110 |
logger.warning(
|
111 |
"Truncation was not explicitly activated but `max_length` is provided a specific value, "
|
112 |
"please use `truncation=True` to explicitly truncate examples to max length. "
|
@@ -134,7 +142,9 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
134 |
padding_strategy = PaddingStrategy.MAX_LENGTH
|
135 |
elif padding is not False:
|
136 |
if padding is True:
|
137 |
-
padding_strategy =
|
|
|
|
|
138 |
elif not isinstance(padding, PaddingStrategy):
|
139 |
padding_strategy = PaddingStrategy(padding)
|
140 |
elif isinstance(padding, PaddingStrategy):
|
@@ -174,7 +184,9 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
174 |
if padding_strategy == PaddingStrategy.MAX_LENGTH:
|
175 |
if self.model_max_length > LARGE_INTEGER:
|
176 |
if verbose:
|
177 |
-
if not self.deprecation_warnings.get(
|
|
|
|
|
178 |
logger.warning(
|
179 |
"Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
180 |
"Default to no padding."
|
@@ -187,18 +199,24 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
187 |
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
|
188 |
if self.model_max_length > LARGE_INTEGER:
|
189 |
if verbose:
|
190 |
-
if not self.deprecation_warnings.get(
|
|
|
|
|
191 |
logger.warning(
|
192 |
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
193 |
"Default to no truncation."
|
194 |
)
|
195 |
-
self.deprecation_warnings[
|
|
|
|
|
196 |
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
|
197 |
else:
|
198 |
max_length = self.model_max_length
|
199 |
|
200 |
# Test if we have a padding token
|
201 |
-
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
|
|
|
|
|
202 |
raise ValueError(
|
203 |
"Asking to pad but the tokenizer does not have a padding token. "
|
204 |
"Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
|
@@ -229,7 +247,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
229 |
Dict[str, List[EncodedInput]],
|
230 |
List[Dict[str, EncodedInput]],
|
231 |
],
|
232 |
-
class_type,
|
233 |
padding: Union[bool, str, PaddingStrategy] = True,
|
234 |
max_length: Optional[int] = None,
|
235 |
pad_to_multiple_of: Optional[int] = None,
|
@@ -292,8 +310,13 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
292 |
"""
|
293 |
# If we have a list of dicts, let's convert it in a dict of lists
|
294 |
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
295 |
-
if isinstance(encoded_inputs, (list, tuple)) and isinstance(
|
296 |
-
encoded_inputs
|
|
|
|
|
|
|
|
|
|
|
297 |
|
298 |
# The model's main input name, usually `input_ids`, has be passed for padding
|
299 |
if self.model_input_names[0] not in encoded_inputs:
|
@@ -387,7 +410,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
387 |
def _pad(
|
388 |
self,
|
389 |
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
390 |
-
class_type,
|
391 |
max_length: Optional[int] = None,
|
392 |
padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
|
393 |
pad_to_multiple_of: Optional[int] = None,
|
@@ -423,46 +446,73 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
423 |
if padding_strategy == PaddingStrategy.LONGEST:
|
424 |
max_length = len(required_input)
|
425 |
|
426 |
-
if
|
|
|
|
|
|
|
|
|
427 |
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
428 |
|
429 |
-
needs_to_be_padded =
|
|
|
|
|
|
|
430 |
|
431 |
if needs_to_be_padded:
|
432 |
difference = max_length - len(required_input)
|
433 |
if self.padding_side == "right":
|
434 |
if return_attention_mask:
|
435 |
-
encoded_inputs["attention_mask"] = [1] * len(required_input) + [
|
|
|
|
|
436 |
if "token_type_ids" in encoded_inputs:
|
437 |
encoded_inputs["token_type_ids"] = (
|
438 |
-
encoded_inputs["token_type_ids"]
|
|
|
439 |
)
|
440 |
if "special_tokens_mask" in encoded_inputs:
|
441 |
-
encoded_inputs["special_tokens_mask"] =
|
442 |
-
|
|
|
|
|
|
|
|
|
443 |
if class_type == "gene":
|
444 |
-
encoded_inputs["labels"] =
|
|
|
|
|
445 |
elif self.padding_side == "left":
|
446 |
if return_attention_mask:
|
447 |
-
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
|
|
|
|
|
448 |
if "token_type_ids" in encoded_inputs:
|
449 |
-
encoded_inputs["token_type_ids"] = [
|
450 |
-
|
451 |
-
]
|
452 |
if "special_tokens_mask" in encoded_inputs:
|
453 |
-
encoded_inputs["special_tokens_mask"] = [
|
454 |
-
|
|
|
|
|
|
|
|
|
455 |
if class_type == "gene":
|
456 |
-
encoded_inputs["labels"] = [-100] * difference + encoded_inputs[
|
|
|
|
|
457 |
else:
|
458 |
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
459 |
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
460 |
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
461 |
-
|
462 |
return encoded_inputs
|
463 |
|
464 |
def get_special_tokens_mask(
|
465 |
-
self,
|
|
|
|
|
|
|
466 |
) -> List[int]:
|
467 |
"""
|
468 |
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
@@ -486,11 +536,15 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
486 |
|
487 |
all_special_ids = self.all_special_ids # cache the property
|
488 |
|
489 |
-
special_tokens_mask = [
|
|
|
|
|
490 |
|
491 |
return special_tokens_mask
|
492 |
|
493 |
-
def convert_tokens_to_ids(
|
|
|
|
|
494 |
"""
|
495 |
Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
|
496 |
vocabulary.
|
@@ -514,14 +568,15 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
514 |
if token is None:
|
515 |
return None
|
516 |
|
517 |
-
return token_dictionary.get(token)
|
518 |
|
519 |
def __len__(self):
|
520 |
-
return len(token_dictionary)
|
521 |
|
522 |
|
523 |
# collator functions
|
524 |
|
|
|
525 |
class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
526 |
"""
|
527 |
Data collator that will dynamically pad the inputs received, as well as the labels.
|
@@ -546,26 +601,34 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
|
546 |
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
|
547 |
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
|
548 |
"""
|
549 |
-
|
550 |
class_type = "gene"
|
551 |
padding: Union[bool, str, PaddingStrategy] = True
|
552 |
max_length: Optional[int] = None
|
553 |
pad_to_multiple_of: Optional[int] = None
|
554 |
label_pad_token_id: int = -100
|
555 |
-
|
556 |
def __init__(self, *args, **kwargs) -> None:
|
557 |
self.token_dictionary = kwargs.pop("token_dictionary")
|
558 |
super().__init__(
|
559 |
-
tokenizer=PrecollatorForGeneAndCellClassification(
|
|
|
|
|
560 |
padding=self.padding,
|
561 |
max_length=self.max_length,
|
562 |
pad_to_multiple_of=self.pad_to_multiple_of,
|
563 |
label_pad_token_id=self.label_pad_token_id,
|
564 |
-
*args,
|
|
|
|
|
565 |
|
566 |
def _prepare_batch(self, features):
|
567 |
label_name = "label" if "label" in features[0].keys() else "labels"
|
568 |
-
labels =
|
|
|
|
|
|
|
|
|
569 |
batch = self.tokenizer.pad(
|
570 |
features,
|
571 |
class_type=self.class_type,
|
@@ -575,29 +638,31 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
|
575 |
return_tensors="pt",
|
576 |
)
|
577 |
return batch
|
578 |
-
|
579 |
def __call__(self, features):
|
580 |
batch = self._prepare_batch(features)
|
581 |
|
582 |
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
|
583 |
return batch
|
584 |
|
585 |
-
|
586 |
-
class DataCollatorForCellClassification(DataCollatorForGeneClassification):
|
587 |
|
|
|
588 |
class_type = "cell"
|
589 |
|
590 |
def _prepare_batch(self, features):
|
591 |
-
|
592 |
batch = super()._prepare_batch(features)
|
593 |
-
|
594 |
# Special handling for labels.
|
595 |
# Ensure that tensor is created with the correct type
|
596 |
# (it should be automatically the case, but let's make sure of it.)
|
597 |
first = features[0]
|
598 |
if "label" in first and first["label"] is not None:
|
599 |
-
label =
|
|
|
|
|
|
|
|
|
600 |
dtype = torch.long if isinstance(label, int) else torch.float
|
601 |
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
602 |
-
|
603 |
return batch
|
|
|
3 |
|
4 |
Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
|
5 |
"""
|
6 |
+
|
|
|
|
|
7 |
import warnings
|
8 |
from enum import Enum
|
9 |
from typing import Dict, List, Optional, Union
|
10 |
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
from transformers import (
|
14 |
+
BatchEncoding,
|
15 |
DataCollatorForTokenClassification,
|
16 |
SpecialTokensMixin,
|
|
|
17 |
)
|
18 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
19 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
|
|
29 |
|
30 |
# precollator functions
|
31 |
|
32 |
+
|
33 |
class ExplicitEnum(Enum):
|
34 |
"""
|
35 |
Enum with more explicit error message for missing values.
|
|
|
42 |
% (value, cls.__name__, str(list(cls._value2member_map_.keys())))
|
43 |
)
|
44 |
|
45 |
+
|
46 |
class TruncationStrategy(ExplicitEnum):
|
47 |
"""
|
48 |
Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
|
|
55 |
DO_NOT_TRUNCATE = "do_not_truncate"
|
56 |
|
57 |
|
|
|
58 |
class PaddingStrategy(ExplicitEnum):
|
59 |
"""
|
60 |
Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
|
|
|
66 |
DO_NOT_PAD = "do_not_pad"
|
67 |
|
68 |
|
|
|
69 |
class TensorType(ExplicitEnum):
|
70 |
"""
|
71 |
Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
|
|
77 |
NUMPY = "np"
|
78 |
JAX = "jax"
|
79 |
|
80 |
+
|
81 |
class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
82 |
def __init__(self, *args, **kwargs) -> None:
|
83 |
super().__init__(mask_token="<mask>", pad_token="<pad>")
|
|
|
89 |
self.pad_token_id = self.token_dictionary.get("<pad>")
|
90 |
self.all_special_ids = [
|
91 |
self.token_dictionary.get("<mask>"),
|
92 |
+
self.token_dictionary.get("<pad>"),
|
93 |
]
|
94 |
|
95 |
def _get_padding_truncation_strategies(
|
96 |
+
self,
|
97 |
+
padding=True,
|
98 |
+
truncation=False,
|
99 |
+
max_length=None,
|
100 |
+
pad_to_multiple_of=None,
|
101 |
+
verbose=True,
|
102 |
+
**kwargs,
|
103 |
):
|
104 |
"""
|
105 |
Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
|
|
|
112 |
# If you only set max_length, it activates truncation for max_length
|
113 |
if max_length is not None and padding is False and truncation is False:
|
114 |
if verbose:
|
115 |
+
if not self.deprecation_warnings.get(
|
116 |
+
"Truncation-not-explicitly-activated", False
|
117 |
+
):
|
118 |
logger.warning(
|
119 |
"Truncation was not explicitly activated but `max_length` is provided a specific value, "
|
120 |
"please use `truncation=True` to explicitly truncate examples to max length. "
|
|
|
142 |
padding_strategy = PaddingStrategy.MAX_LENGTH
|
143 |
elif padding is not False:
|
144 |
if padding is True:
|
145 |
+
padding_strategy = (
|
146 |
+
PaddingStrategy.LONGEST
|
147 |
+
) # Default to pad to the longest sequence in the batch
|
148 |
elif not isinstance(padding, PaddingStrategy):
|
149 |
padding_strategy = PaddingStrategy(padding)
|
150 |
elif isinstance(padding, PaddingStrategy):
|
|
|
184 |
if padding_strategy == PaddingStrategy.MAX_LENGTH:
|
185 |
if self.model_max_length > LARGE_INTEGER:
|
186 |
if verbose:
|
187 |
+
if not self.deprecation_warnings.get(
|
188 |
+
"Asking-to-pad-to-max_length", False
|
189 |
+
):
|
190 |
logger.warning(
|
191 |
"Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
192 |
"Default to no padding."
|
|
|
199 |
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
|
200 |
if self.model_max_length > LARGE_INTEGER:
|
201 |
if verbose:
|
202 |
+
if not self.deprecation_warnings.get(
|
203 |
+
"Asking-to-truncate-to-max_length", False
|
204 |
+
):
|
205 |
logger.warning(
|
206 |
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
207 |
"Default to no truncation."
|
208 |
)
|
209 |
+
self.deprecation_warnings[
|
210 |
+
"Asking-to-truncate-to-max_length"
|
211 |
+
] = True
|
212 |
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
|
213 |
else:
|
214 |
max_length = self.model_max_length
|
215 |
|
216 |
# Test if we have a padding token
|
217 |
+
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
|
218 |
+
not self.pad_token or self.pad_token_id < 0
|
219 |
+
):
|
220 |
raise ValueError(
|
221 |
"Asking to pad but the tokenizer does not have a padding token. "
|
222 |
"Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
|
|
|
247 |
Dict[str, List[EncodedInput]],
|
248 |
List[Dict[str, EncodedInput]],
|
249 |
],
|
250 |
+
class_type, # options: "gene" or "cell"
|
251 |
padding: Union[bool, str, PaddingStrategy] = True,
|
252 |
max_length: Optional[int] = None,
|
253 |
pad_to_multiple_of: Optional[int] = None,
|
|
|
310 |
"""
|
311 |
# If we have a list of dicts, let's convert it in a dict of lists
|
312 |
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
313 |
+
if isinstance(encoded_inputs, (list, tuple)) and isinstance(
|
314 |
+
encoded_inputs[0], (dict, BatchEncoding)
|
315 |
+
):
|
316 |
+
encoded_inputs = {
|
317 |
+
key: [example[key] for example in encoded_inputs]
|
318 |
+
for key in encoded_inputs[0].keys()
|
319 |
+
}
|
320 |
|
321 |
# The model's main input name, usually `input_ids`, has be passed for padding
|
322 |
if self.model_input_names[0] not in encoded_inputs:
|
|
|
410 |
def _pad(
|
411 |
self,
|
412 |
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
413 |
+
class_type, # options: "gene" or "cell"
|
414 |
max_length: Optional[int] = None,
|
415 |
padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
|
416 |
pad_to_multiple_of: Optional[int] = None,
|
|
|
446 |
if padding_strategy == PaddingStrategy.LONGEST:
|
447 |
max_length = len(required_input)
|
448 |
|
449 |
+
if (
|
450 |
+
max_length is not None
|
451 |
+
and pad_to_multiple_of is not None
|
452 |
+
and (max_length % pad_to_multiple_of != 0)
|
453 |
+
):
|
454 |
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
455 |
|
456 |
+
needs_to_be_padded = (
|
457 |
+
padding_strategy != PaddingStrategy.DO_NOT_PAD
|
458 |
+
and len(required_input) != max_length
|
459 |
+
)
|
460 |
|
461 |
if needs_to_be_padded:
|
462 |
difference = max_length - len(required_input)
|
463 |
if self.padding_side == "right":
|
464 |
if return_attention_mask:
|
465 |
+
encoded_inputs["attention_mask"] = [1] * len(required_input) + [
|
466 |
+
0
|
467 |
+
] * difference
|
468 |
if "token_type_ids" in encoded_inputs:
|
469 |
encoded_inputs["token_type_ids"] = (
|
470 |
+
encoded_inputs["token_type_ids"]
|
471 |
+
+ [self.pad_token_type_id] * difference
|
472 |
)
|
473 |
if "special_tokens_mask" in encoded_inputs:
|
474 |
+
encoded_inputs["special_tokens_mask"] = (
|
475 |
+
encoded_inputs["special_tokens_mask"] + [1] * difference
|
476 |
+
)
|
477 |
+
encoded_inputs[self.model_input_names[0]] = (
|
478 |
+
required_input + [self.pad_token_id] * difference
|
479 |
+
)
|
480 |
if class_type == "gene":
|
481 |
+
encoded_inputs["labels"] = (
|
482 |
+
encoded_inputs["labels"] + [-100] * difference
|
483 |
+
)
|
484 |
elif self.padding_side == "left":
|
485 |
if return_attention_mask:
|
486 |
+
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
|
487 |
+
required_input
|
488 |
+
)
|
489 |
if "token_type_ids" in encoded_inputs:
|
490 |
+
encoded_inputs["token_type_ids"] = [
|
491 |
+
self.pad_token_type_id
|
492 |
+
] * difference + encoded_inputs["token_type_ids"]
|
493 |
if "special_tokens_mask" in encoded_inputs:
|
494 |
+
encoded_inputs["special_tokens_mask"] = [
|
495 |
+
1
|
496 |
+
] * difference + encoded_inputs["special_tokens_mask"]
|
497 |
+
encoded_inputs[self.model_input_names[0]] = [
|
498 |
+
self.pad_token_id
|
499 |
+
] * difference + required_input
|
500 |
if class_type == "gene":
|
501 |
+
encoded_inputs["labels"] = [-100] * difference + encoded_inputs[
|
502 |
+
"labels"
|
503 |
+
]
|
504 |
else:
|
505 |
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
506 |
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
507 |
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
508 |
+
|
509 |
return encoded_inputs
|
510 |
|
511 |
def get_special_tokens_mask(
|
512 |
+
self,
|
513 |
+
token_ids_0: List[int],
|
514 |
+
token_ids_1: Optional[List[int]] = None,
|
515 |
+
already_has_special_tokens: bool = False,
|
516 |
) -> List[int]:
|
517 |
"""
|
518 |
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
|
|
536 |
|
537 |
all_special_ids = self.all_special_ids # cache the property
|
538 |
|
539 |
+
special_tokens_mask = [
|
540 |
+
1 if token in all_special_ids else 0 for token in token_ids_0
|
541 |
+
]
|
542 |
|
543 |
return special_tokens_mask
|
544 |
|
545 |
+
def convert_tokens_to_ids(
|
546 |
+
self, tokens: Union[str, List[str]]
|
547 |
+
) -> Union[int, List[int]]:
|
548 |
"""
|
549 |
Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
|
550 |
vocabulary.
|
|
|
568 |
if token is None:
|
569 |
return None
|
570 |
|
571 |
+
return self.token_dictionary.get(token)
|
572 |
|
573 |
def __len__(self):
|
574 |
+
return len(self.token_dictionary)
|
575 |
|
576 |
|
577 |
# collator functions
|
578 |
|
579 |
+
|
580 |
class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
581 |
"""
|
582 |
Data collator that will dynamically pad the inputs received, as well as the labels.
|
|
|
601 |
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
|
602 |
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
|
603 |
"""
|
604 |
+
|
605 |
class_type = "gene"
|
606 |
padding: Union[bool, str, PaddingStrategy] = True
|
607 |
max_length: Optional[int] = None
|
608 |
pad_to_multiple_of: Optional[int] = None
|
609 |
label_pad_token_id: int = -100
|
610 |
+
|
611 |
def __init__(self, *args, **kwargs) -> None:
|
612 |
self.token_dictionary = kwargs.pop("token_dictionary")
|
613 |
super().__init__(
|
614 |
+
tokenizer=PrecollatorForGeneAndCellClassification(
|
615 |
+
token_dictionary=self.token_dictionary
|
616 |
+
),
|
617 |
padding=self.padding,
|
618 |
max_length=self.max_length,
|
619 |
pad_to_multiple_of=self.pad_to_multiple_of,
|
620 |
label_pad_token_id=self.label_pad_token_id,
|
621 |
+
*args,
|
622 |
+
**kwargs,
|
623 |
+
)
|
624 |
|
625 |
def _prepare_batch(self, features):
|
626 |
label_name = "label" if "label" in features[0].keys() else "labels"
|
627 |
+
labels = (
|
628 |
+
[feature[label_name] for feature in features]
|
629 |
+
if label_name in features[0].keys()
|
630 |
+
else None
|
631 |
+
)
|
632 |
batch = self.tokenizer.pad(
|
633 |
features,
|
634 |
class_type=self.class_type,
|
|
|
638 |
return_tensors="pt",
|
639 |
)
|
640 |
return batch
|
641 |
+
|
642 |
def __call__(self, features):
|
643 |
batch = self._prepare_batch(features)
|
644 |
|
645 |
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
|
646 |
return batch
|
647 |
|
|
|
|
|
648 |
|
649 |
+
class DataCollatorForCellClassification(DataCollatorForGeneClassification):
|
650 |
class_type = "cell"
|
651 |
|
652 |
def _prepare_batch(self, features):
|
|
|
653 |
batch = super()._prepare_batch(features)
|
654 |
+
|
655 |
# Special handling for labels.
|
656 |
# Ensure that tensor is created with the correct type
|
657 |
# (it should be automatically the case, but let's make sure of it.)
|
658 |
first = features[0]
|
659 |
if "label" in first and first["label"] is not None:
|
660 |
+
label = (
|
661 |
+
first["label"].item()
|
662 |
+
if isinstance(first["label"], torch.Tensor)
|
663 |
+
else first["label"]
|
664 |
+
)
|
665 |
dtype = torch.long if isinstance(label, int) else torch.float
|
666 |
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
667 |
+
|
668 |
return batch
|
geneformer/emb_extractor.py
CHANGED
@@ -24,8 +24,8 @@ import torch
|
|
24 |
from tdigest import TDigest
|
25 |
from tqdm.auto import trange
|
26 |
|
27 |
-
from . import perturber_utils as pu
|
28 |
from . import TOKEN_DICTIONARY_FILE
|
|
|
29 |
|
30 |
logger = logging.getLogger(__name__)
|
31 |
|
@@ -45,7 +45,7 @@ def get_embs(
|
|
45 |
):
|
46 |
model_input_size = pu.get_model_input_size(model)
|
47 |
total_batch_length = len(filtered_input_data)
|
48 |
-
|
49 |
if summary_stat is None:
|
50 |
embs_list = []
|
51 |
elif summary_stat is not None:
|
@@ -73,17 +73,23 @@ def get_embs(
|
|
73 |
if emb_mode == "cls":
|
74 |
assert cls_present, "<cls> token missing in token dictionary"
|
75 |
# Check to make sure that the first token of the filtered input data is cls token
|
76 |
-
gene_token_dict = {v:k for k,v in token_gene_dict.items()}
|
77 |
cls_token_id = gene_token_dict["<cls>"]
|
78 |
-
assert
|
|
|
|
|
79 |
elif emb_mode == "cell":
|
80 |
if cls_present:
|
81 |
-
logger.warning(
|
|
|
|
|
82 |
if eos_present:
|
83 |
-
logger.warning(
|
84 |
-
|
|
|
|
|
85 |
overall_max_len = 0
|
86 |
-
|
87 |
for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
|
88 |
max_range = min(i + forward_batch_size, total_batch_length)
|
89 |
|
@@ -108,7 +114,7 @@ def get_embs(
|
|
108 |
|
109 |
if emb_mode == "cell":
|
110 |
if cls_present:
|
111 |
-
non_cls_embs = embs_i[:, 1:, :]
|
112 |
if eos_present:
|
113 |
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
|
114 |
else:
|
@@ -146,10 +152,10 @@ def get_embs(
|
|
146 |
del embs_h
|
147 |
del dict_h
|
148 |
elif emb_mode == "cls":
|
149 |
-
cls_embs = embs_i[:,0
|
150 |
embs_list.append(cls_embs)
|
151 |
del cls_embs
|
152 |
-
|
153 |
overall_max_len = max(overall_max_len, max_len)
|
154 |
del outputs
|
155 |
del minibatch
|
@@ -157,8 +163,7 @@ def get_embs(
|
|
157 |
del embs_i
|
158 |
|
159 |
torch.cuda.empty_cache()
|
160 |
-
|
161 |
-
|
162 |
if summary_stat is None:
|
163 |
if (emb_mode == "cell") or (emb_mode == "cls"):
|
164 |
embs_stack = torch.cat(embs_list, dim=0)
|
@@ -204,6 +209,7 @@ def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
|
|
204 |
for j in range(emb_dims)
|
205 |
]
|
206 |
|
|
|
207 |
def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
|
208 |
embs_tdigests_dict[gene] = accumulate_tdigests(
|
209 |
embs_tdigests_dict[gene], gene_embs, emb_dims
|
@@ -294,11 +300,13 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
|
|
294 |
|
295 |
with plt.rc_context():
|
296 |
ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
|
297 |
-
ax.legend(
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
|
|
|
|
302 |
plt.show()
|
303 |
plt.savefig(output_file, bbox_inches="tight")
|
304 |
|
@@ -394,7 +402,7 @@ class EmbExtractor:
|
|
394 |
"emb_label": {None, list},
|
395 |
"labels_to_plot": {None, list},
|
396 |
"forward_batch_size": {int},
|
397 |
-
"token_dictionary_file"
|
398 |
"nproc": {int},
|
399 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
400 |
}
|
@@ -631,13 +639,15 @@ class EmbExtractor:
|
|
631 |
embs = embs.mean(dim=0)
|
632 |
emb_dims = pu.get_model_emb_dims(model)
|
633 |
embs_df = pd.DataFrame(
|
634 |
-
embs_df[0:emb_dims-1].mean(axis="rows"),
|
|
|
635 |
).T
|
636 |
elif self.exact_summary_stat == "exact_median":
|
637 |
embs = torch.median(embs, dim=0)[0]
|
638 |
emb_dims = pu.get_model_emb_dims(model)
|
639 |
embs_df = pd.DataFrame(
|
640 |
-
embs_df[0:emb_dims-1].median(axis="rows"),
|
|
|
641 |
).T
|
642 |
|
643 |
if cell_state is not None:
|
@@ -838,4 +848,4 @@ class EmbExtractor:
|
|
838 |
output_file = (
|
839 |
Path(output_directory) / output_prefix_label
|
840 |
).with_suffix(".pdf")
|
841 |
-
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
|
|
24 |
from tdigest import TDigest
|
25 |
from tqdm.auto import trange
|
26 |
|
|
|
27 |
from . import TOKEN_DICTIONARY_FILE
|
28 |
+
from . import perturber_utils as pu
|
29 |
|
30 |
logger = logging.getLogger(__name__)
|
31 |
|
|
|
45 |
):
|
46 |
model_input_size = pu.get_model_input_size(model)
|
47 |
total_batch_length = len(filtered_input_data)
|
48 |
+
|
49 |
if summary_stat is None:
|
50 |
embs_list = []
|
51 |
elif summary_stat is not None:
|
|
|
73 |
if emb_mode == "cls":
|
74 |
assert cls_present, "<cls> token missing in token dictionary"
|
75 |
# Check to make sure that the first token of the filtered input data is cls token
|
76 |
+
gene_token_dict = {v: k for k, v in token_gene_dict.items()}
|
77 |
cls_token_id = gene_token_dict["<cls>"]
|
78 |
+
assert (
|
79 |
+
filtered_input_data["input_ids"][0][0] == cls_token_id
|
80 |
+
), "First token is not <cls> token value"
|
81 |
elif emb_mode == "cell":
|
82 |
if cls_present:
|
83 |
+
logger.warning(
|
84 |
+
"CLS token present in token dictionary, excluding from average."
|
85 |
+
)
|
86 |
if eos_present:
|
87 |
+
logger.warning(
|
88 |
+
"EOS token present in token dictionary, excluding from average."
|
89 |
+
)
|
90 |
+
|
91 |
overall_max_len = 0
|
92 |
+
|
93 |
for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
|
94 |
max_range = min(i + forward_batch_size, total_batch_length)
|
95 |
|
|
|
114 |
|
115 |
if emb_mode == "cell":
|
116 |
if cls_present:
|
117 |
+
non_cls_embs = embs_i[:, 1:, :] # Get all layers except the embs
|
118 |
if eos_present:
|
119 |
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
|
120 |
else:
|
|
|
152 |
del embs_h
|
153 |
del dict_h
|
154 |
elif emb_mode == "cls":
|
155 |
+
cls_embs = embs_i[:, 0, :].clone().detach() # CLS token layer
|
156 |
embs_list.append(cls_embs)
|
157 |
del cls_embs
|
158 |
+
|
159 |
overall_max_len = max(overall_max_len, max_len)
|
160 |
del outputs
|
161 |
del minibatch
|
|
|
163 |
del embs_i
|
164 |
|
165 |
torch.cuda.empty_cache()
|
166 |
+
|
|
|
167 |
if summary_stat is None:
|
168 |
if (emb_mode == "cell") or (emb_mode == "cls"):
|
169 |
embs_stack = torch.cat(embs_list, dim=0)
|
|
|
209 |
for j in range(emb_dims)
|
210 |
]
|
211 |
|
212 |
+
|
213 |
def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
|
214 |
embs_tdigests_dict[gene] = accumulate_tdigests(
|
215 |
embs_tdigests_dict[gene], gene_embs, emb_dims
|
|
|
300 |
|
301 |
with plt.rc_context():
|
302 |
ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
|
303 |
+
ax.legend(
|
304 |
+
markerscale=2,
|
305 |
+
frameon=False,
|
306 |
+
loc="center left",
|
307 |
+
bbox_to_anchor=(1, 0.5),
|
308 |
+
ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
|
309 |
+
)
|
310 |
plt.show()
|
311 |
plt.savefig(output_file, bbox_inches="tight")
|
312 |
|
|
|
402 |
"emb_label": {None, list},
|
403 |
"labels_to_plot": {None, list},
|
404 |
"forward_batch_size": {int},
|
405 |
+
"token_dictionary_file": {None, str},
|
406 |
"nproc": {int},
|
407 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
408 |
}
|
|
|
639 |
embs = embs.mean(dim=0)
|
640 |
emb_dims = pu.get_model_emb_dims(model)
|
641 |
embs_df = pd.DataFrame(
|
642 |
+
embs_df[0 : emb_dims - 1].mean(axis="rows"),
|
643 |
+
columns=[self.exact_summary_stat],
|
644 |
).T
|
645 |
elif self.exact_summary_stat == "exact_median":
|
646 |
embs = torch.median(embs, dim=0)[0]
|
647 |
emb_dims = pu.get_model_emb_dims(model)
|
648 |
embs_df = pd.DataFrame(
|
649 |
+
embs_df[0 : emb_dims - 1].median(axis="rows"),
|
650 |
+
columns=[self.exact_summary_stat],
|
651 |
).T
|
652 |
|
653 |
if cell_state is not None:
|
|
|
848 |
output_file = (
|
849 |
Path(output_directory) / output_prefix_label
|
850 |
).with_suffix(".pdf")
|
851 |
+
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
geneformer/evaluation_utils.py
CHANGED
@@ -20,8 +20,8 @@ from sklearn.metrics import (
|
|
20 |
)
|
21 |
from tqdm.auto import trange
|
22 |
|
23 |
-
from .emb_extractor import make_colorbar
|
24 |
from . import TOKEN_DICTIONARY_FILE
|
|
|
25 |
|
26 |
logger = logging.getLogger(__name__)
|
27 |
|
|
|
20 |
)
|
21 |
from tqdm.auto import trange
|
22 |
|
|
|
23 |
from . import TOKEN_DICTIONARY_FILE
|
24 |
+
from .emb_extractor import make_colorbar
|
25 |
|
26 |
logger = logging.getLogger(__name__)
|
27 |
|
geneformer/in_silico_perturber.py
CHANGED
@@ -38,16 +38,15 @@ import logging
|
|
38 |
import os
|
39 |
import pickle
|
40 |
from collections import defaultdict
|
41 |
-
from multiprocess import set_start_method
|
42 |
-
from typing import List
|
43 |
|
44 |
import torch
|
45 |
from datasets import Dataset, disable_progress_bars
|
|
|
46 |
from tqdm.auto import trange
|
47 |
|
|
|
48 |
from . import perturber_utils as pu
|
49 |
from .emb_extractor import get_embs
|
50 |
-
from . import TOKEN_DICTIONARY_FILE
|
51 |
|
52 |
disable_progress_bars()
|
53 |
|
@@ -71,7 +70,7 @@ class InSilicoPerturber:
|
|
71 |
"max_ncells": {None, int},
|
72 |
"cell_inds_to_perturb": {"all", dict},
|
73 |
"emb_layer": {-1, 0},
|
74 |
-
"token_dictionary_file"
|
75 |
"forward_batch_size": {int},
|
76 |
"nproc": {int},
|
77 |
}
|
@@ -239,13 +238,14 @@ class InSilicoPerturber:
|
|
239 |
self.cls_token_id = self.gene_token_dict.get("<cls>")
|
240 |
self.eos_token_id = self.gene_token_dict.get("<eos>")
|
241 |
|
242 |
-
|
243 |
# Identify if special token is present in the token dictionary
|
244 |
if (self.cls_token_id is not None) and (self.eos_token_id is not None):
|
245 |
self.special_token = True
|
246 |
else:
|
247 |
if "cls" in self.emb_mode:
|
248 |
-
logger.error(
|
|
|
|
|
249 |
raise
|
250 |
self.special_token = False
|
251 |
|
@@ -454,17 +454,21 @@ class InSilicoPerturber:
|
|
454 |
|
455 |
# Ensure emb_mode is cls if first token of the filtered input data is cls token
|
456 |
if self.special_token:
|
457 |
-
if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and (
|
|
|
|
|
458 |
logger.error(
|
459 |
-
|
460 |
-
|
461 |
raise
|
462 |
-
if
|
463 |
-
if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (
|
|
|
|
|
464 |
logger.error(
|
465 |
-
|
466 |
-
|
467 |
-
raise
|
468 |
|
469 |
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
470 |
|
@@ -530,7 +534,6 @@ class InSilicoPerturber:
|
|
530 |
layer_to_quant: int,
|
531 |
output_path_prefix: str,
|
532 |
):
|
533 |
-
|
534 |
def make_group_perturbation_batch(example):
|
535 |
example_input_ids = example["input_ids"]
|
536 |
example["tokens_to_perturb"] = self.tokens_to_perturb
|
@@ -549,7 +552,9 @@ class InSilicoPerturber:
|
|
549 |
if self.perturb_type == "delete":
|
550 |
example = pu.delete_indices(example)
|
551 |
elif self.perturb_type == "overexpress":
|
552 |
-
example = pu.overexpress_tokens(
|
|
|
|
|
553 |
example["n_overflow"] = pu.calc_n_overflow(
|
554 |
self.max_len,
|
555 |
example["length"],
|
@@ -570,7 +575,7 @@ class InSilicoPerturber:
|
|
570 |
perturbed_data = filtered_input_data.map(
|
571 |
make_group_perturbation_batch, num_proc=self.nproc
|
572 |
)
|
573 |
-
|
574 |
if self.perturb_type == "overexpress":
|
575 |
filtered_input_data = filtered_input_data.add_column(
|
576 |
"n_overflow", perturbed_data["n_overflow"]
|
@@ -752,7 +757,6 @@ class InSilicoPerturber:
|
|
752 |
f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
|
753 |
)
|
754 |
|
755 |
-
|
756 |
def isp_perturb_set_special(
|
757 |
self,
|
758 |
model,
|
@@ -760,7 +764,6 @@ class InSilicoPerturber:
|
|
760 |
layer_to_quant: int,
|
761 |
output_path_prefix: str,
|
762 |
):
|
763 |
-
|
764 |
def make_group_perturbation_batch(example):
|
765 |
example_input_ids = example["input_ids"]
|
766 |
example["tokens_to_perturb"] = self.tokens_to_perturb
|
@@ -779,7 +782,9 @@ class InSilicoPerturber:
|
|
779 |
if self.perturb_type == "delete":
|
780 |
example = pu.delete_indices(example)
|
781 |
elif self.perturb_type == "overexpress":
|
782 |
-
example = pu.overexpress_tokens(
|
|
|
|
|
783 |
example["n_overflow"] = pu.calc_n_overflow(
|
784 |
self.max_len,
|
785 |
example["length"],
|
@@ -808,7 +813,7 @@ class InSilicoPerturber:
|
|
808 |
filtered_input_data = filtered_input_data.map(
|
809 |
pu.truncate_by_n_overflow_special, num_proc=self.nproc
|
810 |
)
|
811 |
-
|
812 |
if self.emb_mode == "cls_and_gene":
|
813 |
stored_gene_embs_dict = defaultdict(list)
|
814 |
|
@@ -847,28 +852,29 @@ class InSilicoPerturber:
|
|
847 |
summary_stat=None,
|
848 |
silent=True,
|
849 |
)
|
850 |
-
|
851 |
-
# Calculate the cosine similarities
|
852 |
cls_cos_sims = pu.quant_cos_sims(
|
853 |
perturbation_cls_emb,
|
854 |
original_cls_emb,
|
855 |
self.cell_states_to_model,
|
856 |
self.state_embs_dict,
|
857 |
-
emb_mode="cell"
|
858 |
-
|
|
|
859 |
# Update perturbation dictionary
|
860 |
if self.cell_states_to_model is None:
|
861 |
cos_sims_dict = self.update_perturbation_dictionary(
|
862 |
cos_sims_dict,
|
863 |
cls_cos_sims,
|
864 |
-
gene_list
|
865 |
)
|
866 |
else:
|
867 |
for state in cos_sims_dict.keys():
|
868 |
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
869 |
cos_sims_dict[state],
|
870 |
cls_cos_sims[state],
|
871 |
-
gene_list
|
872 |
)
|
873 |
|
874 |
##### CLS and Gene Embedding Mode #####
|
@@ -908,9 +914,13 @@ class InSilicoPerturber:
|
|
908 |
# remove special tokens and padding
|
909 |
original_emb = original_emb[:, 1:-1, :]
|
910 |
if self.perturb_type == "overexpress":
|
911 |
-
perturbation_emb = full_perturbation_emb[
|
|
|
|
|
912 |
elif self.perturb_type == "delete":
|
913 |
-
perturbation_emb = full_perturbation_emb[
|
|
|
|
|
914 |
|
915 |
n_perturbation_genes = perturbation_emb.size()[1]
|
916 |
|
@@ -923,8 +933,8 @@ class InSilicoPerturber:
|
|
923 |
)
|
924 |
|
925 |
# get cls emb
|
926 |
-
original_cls_emb = full_original_emb[:,0
|
927 |
-
perturbation_cls_emb = full_perturbation_emb[:,0
|
928 |
|
929 |
cls_cos_sims = pu.quant_cos_sims(
|
930 |
perturbation_cls_emb,
|
@@ -939,7 +949,10 @@ class InSilicoPerturber:
|
|
939 |
|
940 |
gene_list = minibatch["input_ids"]
|
941 |
# need to truncate gene_list
|
942 |
-
genes_to_exclude = self.tokens_to_perturb + [
|
|
|
|
|
|
|
943 |
gene_list = [
|
944 |
[g for g in genes if g not in genes_to_exclude][
|
945 |
:n_perturbation_genes
|
@@ -968,20 +981,20 @@ class InSilicoPerturber:
|
|
968 |
cos_sims_dict = self.update_perturbation_dictionary(
|
969 |
cos_sims_dict,
|
970 |
cls_cos_sims,
|
971 |
-
gene_list
|
972 |
)
|
973 |
else:
|
974 |
for state in cos_sims_dict.keys():
|
975 |
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
976 |
cos_sims_dict[state],
|
977 |
cls_cos_sims[state],
|
978 |
-
gene_list
|
979 |
)
|
980 |
del full_original_emb
|
981 |
del original_emb
|
982 |
del full_perturbation_emb
|
983 |
-
del perturbation_emb
|
984 |
-
del gene_cos_sims
|
985 |
|
986 |
del original_cls_emb
|
987 |
del perturbation_cls_emb
|
@@ -1035,12 +1048,12 @@ class InSilicoPerturber:
|
|
1035 |
summary_stat=None,
|
1036 |
silent=True,
|
1037 |
)
|
1038 |
-
|
1039 |
if self.cell_states_to_model is not None:
|
1040 |
original_cell_emb = pu.compute_nonpadded_cell_embedding(
|
1041 |
full_original_emb, "mean_pool"
|
1042 |
)
|
1043 |
-
|
1044 |
# gene_list is used to assign cos sims back to genes
|
1045 |
gene_list = example_cell["input_ids"][0][:]
|
1046 |
# need to remove the anchor gene
|
@@ -1049,15 +1062,13 @@ class InSilicoPerturber:
|
|
1049 |
gene_list.remove(token)
|
1050 |
# index 0 is not overexpressed so remove
|
1051 |
if self.perturb_type == "overexpress":
|
1052 |
-
gene_list = gene_list[
|
1053 |
-
num_inds_perturbed:
|
1054 |
-
]
|
1055 |
# remove perturbed index for gene list dict
|
1056 |
perturbed_gene_dict = {
|
1057 |
gene: gene_list[:i] + gene_list[i + 1 :]
|
1058 |
for i, gene in enumerate(gene_list)
|
1059 |
}
|
1060 |
-
|
1061 |
perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
|
1062 |
example_cell,
|
1063 |
self.perturb_type,
|
@@ -1068,11 +1079,19 @@ class InSilicoPerturber:
|
|
1068 |
)
|
1069 |
|
1070 |
ispall_total_batch_length = len(perturbation_batch)
|
1071 |
-
for i in trange(
|
1072 |
-
|
1073 |
-
|
1074 |
-
|
1075 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1076 |
|
1077 |
full_perturbation_emb = get_embs(
|
1078 |
model,
|
@@ -1085,18 +1104,20 @@ class InSilicoPerturber:
|
|
1085 |
summary_stat=None,
|
1086 |
silent=True,
|
1087 |
)
|
1088 |
-
|
1089 |
del perturbation_minibatch
|
1090 |
-
|
1091 |
# need to remove overexpressed gene to quantify cosine shifts
|
1092 |
if self.perturb_type == "overexpress":
|
1093 |
perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
|
1094 |
-
|
1095 |
elif self.perturb_type == "delete":
|
1096 |
perturbation_emb = full_perturbation_emb
|
1097 |
-
|
1098 |
-
|
1099 |
-
|
|
|
|
|
1100 |
original_emb_minibatch = pu.make_comparison_batch(
|
1101 |
full_original_emb, indices_to_perturb_mini, perturb_group=False
|
1102 |
)
|
@@ -1108,12 +1129,12 @@ class InSilicoPerturber:
|
|
1108 |
emb_mode="gene",
|
1109 |
)
|
1110 |
del original_emb_minibatch
|
1111 |
-
|
1112 |
if self.cell_states_to_model is not None:
|
1113 |
perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
|
1114 |
full_perturbation_emb, "mean_pool"
|
1115 |
)
|
1116 |
-
|
1117 |
cell_cos_sims = pu.quant_cos_sims(
|
1118 |
perturbation_cell_emb,
|
1119 |
original_cell_emb,
|
@@ -1122,9 +1143,8 @@ class InSilicoPerturber:
|
|
1122 |
emb_mode="cell",
|
1123 |
)
|
1124 |
del perturbation_cell_emb
|
1125 |
-
|
1126 |
-
if self.emb_mode == "cell_and_gene":
|
1127 |
|
|
|
1128 |
for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
|
1129 |
for gene_j, affected_gene in enumerate(
|
1130 |
perturbed_gene_dict[perturbed_gene]
|
@@ -1137,9 +1157,9 @@ class InSilicoPerturber:
|
|
1137 |
stored_gene_embs_dict[
|
1138 |
(perturbed_gene, affected_gene)
|
1139 |
] = gene_cos_sims[perturbation_i, gene_j].item()
|
1140 |
-
|
1141 |
del full_perturbation_emb
|
1142 |
-
|
1143 |
if self.cell_states_to_model is None:
|
1144 |
cos_sims_data = torch.mean(gene_cos_sims, dim=1)
|
1145 |
cos_sims_dict = self.update_perturbation_dictionary(
|
@@ -1155,9 +1175,9 @@ class InSilicoPerturber:
|
|
1155 |
cos_sims_data[state],
|
1156 |
gene_list_mini,
|
1157 |
)
|
1158 |
-
|
1159 |
# save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
|
1160 |
-
if i % self.clear_mem_ncells/10 == 0:
|
1161 |
pu.write_perturbation_dictionary(
|
1162 |
cos_sims_dict,
|
1163 |
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
@@ -1167,7 +1187,7 @@ class InSilicoPerturber:
|
|
1167 |
stored_gene_embs_dict,
|
1168 |
f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
|
1169 |
)
|
1170 |
-
|
1171 |
# reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
|
1172 |
if i % self.clear_mem_ncells == 0:
|
1173 |
pickle_batch += 1
|
@@ -1176,18 +1196,21 @@ class InSilicoPerturber:
|
|
1176 |
else:
|
1177 |
cos_sims_dict = {
|
1178 |
state: defaultdict(list)
|
1179 |
-
for state in pu.get_possible_states(
|
|
|
|
|
1180 |
}
|
1181 |
-
|
1182 |
if self.emb_mode == "cell_and_gene":
|
1183 |
stored_gene_embs_dict = defaultdict(list)
|
1184 |
-
|
1185 |
torch.cuda.empty_cache()
|
1186 |
-
|
1187 |
pu.write_perturbation_dictionary(
|
1188 |
-
cos_sims_dict,
|
|
|
1189 |
)
|
1190 |
-
|
1191 |
if self.emb_mode == "cell_and_gene":
|
1192 |
pu.write_perturbation_dictionary(
|
1193 |
stored_gene_embs_dict,
|
@@ -1212,7 +1235,7 @@ class InSilicoPerturber:
|
|
1212 |
if self.cell_states_to_model is not None:
|
1213 |
del original_cell_emb
|
1214 |
torch.cuda.empty_cache()
|
1215 |
-
|
1216 |
def isp_perturb_all_special(
|
1217 |
self,
|
1218 |
model,
|
@@ -1248,7 +1271,7 @@ class InSilicoPerturber:
|
|
1248 |
self.token_gene_dict,
|
1249 |
summary_stat=None,
|
1250 |
silent=True,
|
1251 |
-
)
|
1252 |
elif self.emb_mode == "cls_and_gene":
|
1253 |
full_original_emb = get_embs(
|
1254 |
model,
|
@@ -1261,8 +1284,8 @@ class InSilicoPerturber:
|
|
1261 |
summary_stat=None,
|
1262 |
silent=True,
|
1263 |
)
|
1264 |
-
original_cls_emb = full_original_emb[:,0
|
1265 |
-
|
1266 |
# gene_list is used to assign cos sims back to genes
|
1267 |
gene_list = example_cell["input_ids"][0][:]
|
1268 |
|
@@ -1275,9 +1298,7 @@ class InSilicoPerturber:
|
|
1275 |
gene_list.remove(token)
|
1276 |
# index 0 is not overexpressed so remove
|
1277 |
if self.perturb_type == "overexpress":
|
1278 |
-
gene_list = gene_list[
|
1279 |
-
num_inds_perturbed:
|
1280 |
-
]
|
1281 |
# remove perturbed index for gene list dict
|
1282 |
perturbed_gene_dict = {
|
1283 |
gene: gene_list[:i] + gene_list[i + 1 :]
|
@@ -1294,12 +1315,20 @@ class InSilicoPerturber:
|
|
1294 |
)
|
1295 |
|
1296 |
ispall_total_batch_length = len(perturbation_batch)
|
1297 |
-
for i in trange(
|
1298 |
-
|
1299 |
-
|
1300 |
-
|
1301 |
-
|
1302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1303 |
##### CLS Embedding Mode #####
|
1304 |
if self.emb_mode == "cls":
|
1305 |
# Extract cls embeddings from perturbed cells
|
@@ -1314,7 +1343,7 @@ class InSilicoPerturber:
|
|
1314 |
summary_stat=None,
|
1315 |
silent=True,
|
1316 |
)
|
1317 |
-
|
1318 |
# Calculate cosine similarities
|
1319 |
cls_cos_sims = pu.quant_cos_sims(
|
1320 |
perturbation_cls_emb,
|
@@ -1322,8 +1351,8 @@ class InSilicoPerturber:
|
|
1322 |
self.cell_states_to_model,
|
1323 |
self.state_embs_dict,
|
1324 |
emb_mode="cell",
|
1325 |
-
)
|
1326 |
-
|
1327 |
if self.cell_states_to_model is None:
|
1328 |
cos_sims_dict = self.update_perturbation_dictionary(
|
1329 |
cos_sims_dict,
|
@@ -1331,20 +1360,19 @@ class InSilicoPerturber:
|
|
1331 |
gene_list_mini,
|
1332 |
)
|
1333 |
else:
|
1334 |
-
|
1335 |
for state in cos_sims_dict.keys():
|
1336 |
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1337 |
cos_sims_dict[state],
|
1338 |
cls_cos_sims[state],
|
1339 |
gene_list_mini,
|
1340 |
)
|
1341 |
-
|
1342 |
del perturbation_minibatch
|
1343 |
del perturbation_cls_emb
|
1344 |
del cls_cos_sims
|
1345 |
-
|
1346 |
##### CLS and Gene Embedding Mode #####
|
1347 |
-
elif self.emb_mode == "cls_and_gene":
|
1348 |
full_perturbation_emb = get_embs(
|
1349 |
model,
|
1350 |
perturbation_minibatch,
|
@@ -1356,18 +1384,26 @@ class InSilicoPerturber:
|
|
1356 |
summary_stat=None,
|
1357 |
silent=True,
|
1358 |
)
|
1359 |
-
|
1360 |
# need to remove overexpressed gene and cls/eos to quantify cosine shifts
|
1361 |
if self.perturb_type == "overexpress":
|
1362 |
-
perturbation_emb =
|
|
|
|
|
|
|
|
|
1363 |
elif self.perturb_type == "delete":
|
1364 |
-
perturbation_emb =
|
1365 |
-
|
|
|
|
|
1366 |
original_emb_minibatch = pu.make_comparison_batch(
|
1367 |
full_original_emb, indices_to_perturb_mini, perturb_group=False
|
1368 |
)
|
1369 |
-
|
1370 |
-
original_emb_minibatch =
|
|
|
|
|
1371 |
gene_cos_sims = pu.quant_cos_sims(
|
1372 |
perturbation_emb,
|
1373 |
original_emb_minibatch,
|
@@ -1375,7 +1411,7 @@ class InSilicoPerturber:
|
|
1375 |
self.state_embs_dict,
|
1376 |
emb_mode="gene",
|
1377 |
)
|
1378 |
-
|
1379 |
for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
|
1380 |
for gene_j, affected_gene in enumerate(
|
1381 |
perturbed_gene_dict[perturbed_gene]
|
@@ -1388,10 +1424,12 @@ class InSilicoPerturber:
|
|
1388 |
stored_gene_embs_dict[
|
1389 |
(perturbed_gene, affected_gene)
|
1390 |
] = gene_cos_sims[perturbation_i, gene_j].item()
|
1391 |
-
|
1392 |
# get cls emb
|
1393 |
-
perturbation_cls_emb =
|
1394 |
-
|
|
|
|
|
1395 |
cls_cos_sims = pu.quant_cos_sims(
|
1396 |
perturbation_cls_emb,
|
1397 |
original_cls_emb,
|
@@ -1399,7 +1437,7 @@ class InSilicoPerturber:
|
|
1399 |
self.state_embs_dict,
|
1400 |
emb_mode="cell",
|
1401 |
)
|
1402 |
-
|
1403 |
if self.cell_states_to_model is None:
|
1404 |
cos_sims_dict = self.update_perturbation_dictionary(
|
1405 |
cos_sims_dict,
|
@@ -1413,7 +1451,7 @@ class InSilicoPerturber:
|
|
1413 |
cls_cos_sims[state],
|
1414 |
gene_list_mini,
|
1415 |
)
|
1416 |
-
|
1417 |
del perturbation_minibatch
|
1418 |
del original_emb_minibatch
|
1419 |
del full_perturbation_emb
|
@@ -1421,9 +1459,9 @@ class InSilicoPerturber:
|
|
1421 |
del perturbation_cls_emb
|
1422 |
del cls_cos_sims
|
1423 |
del gene_cos_sims
|
1424 |
-
|
1425 |
# save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
|
1426 |
-
if i % max(1,self.clear_mem_ncells/10) == 0:
|
1427 |
pu.write_perturbation_dictionary(
|
1428 |
cos_sims_dict,
|
1429 |
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
@@ -1433,7 +1471,7 @@ class InSilicoPerturber:
|
|
1433 |
stored_gene_embs_dict,
|
1434 |
f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
|
1435 |
)
|
1436 |
-
|
1437 |
# reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
|
1438 |
if i % self.clear_mem_ncells == 0:
|
1439 |
pickle_batch += 1
|
@@ -1442,18 +1480,21 @@ class InSilicoPerturber:
|
|
1442 |
else:
|
1443 |
cos_sims_dict = {
|
1444 |
state: defaultdict(list)
|
1445 |
-
for state in pu.get_possible_states(
|
|
|
|
|
1446 |
}
|
1447 |
-
|
1448 |
if self.emb_mode == "cls_and_gene":
|
1449 |
stored_gene_embs_dict = defaultdict(list)
|
1450 |
-
|
1451 |
torch.cuda.empty_cache()
|
1452 |
-
|
1453 |
pu.write_perturbation_dictionary(
|
1454 |
-
cos_sims_dict,
|
|
|
1455 |
)
|
1456 |
-
|
1457 |
if self.emb_mode == "cls_and_gene":
|
1458 |
pu.write_perturbation_dictionary(
|
1459 |
stored_gene_embs_dict,
|
@@ -1470,8 +1511,8 @@ class InSilicoPerturber:
|
|
1470 |
}
|
1471 |
|
1472 |
if self.emb_mode == "cls_and_gene":
|
1473 |
-
stored_gene_embs_dict = defaultdict(list)
|
1474 |
-
|
1475 |
# clear memory between cells
|
1476 |
del perturbation_batch
|
1477 |
del original_cls_emb
|
@@ -1479,7 +1520,6 @@ class InSilicoPerturber:
|
|
1479 |
del full_original_emb
|
1480 |
torch.cuda.empty_cache()
|
1481 |
|
1482 |
-
|
1483 |
def update_perturbation_dictionary(
|
1484 |
self,
|
1485 |
cos_sims_dict: defaultdict,
|
@@ -1514,4 +1554,4 @@ class InSilicoPerturber:
|
|
1514 |
for i, cos in enumerate(cos_sims_data.tolist()):
|
1515 |
cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
|
1516 |
|
1517 |
-
return cos_sims_dict
|
|
|
38 |
import os
|
39 |
import pickle
|
40 |
from collections import defaultdict
|
|
|
|
|
41 |
|
42 |
import torch
|
43 |
from datasets import Dataset, disable_progress_bars
|
44 |
+
from multiprocess import set_start_method
|
45 |
from tqdm.auto import trange
|
46 |
|
47 |
+
from . import TOKEN_DICTIONARY_FILE
|
48 |
from . import perturber_utils as pu
|
49 |
from .emb_extractor import get_embs
|
|
|
50 |
|
51 |
disable_progress_bars()
|
52 |
|
|
|
70 |
"max_ncells": {None, int},
|
71 |
"cell_inds_to_perturb": {"all", dict},
|
72 |
"emb_layer": {-1, 0},
|
73 |
+
"token_dictionary_file": {None, str},
|
74 |
"forward_batch_size": {int},
|
75 |
"nproc": {int},
|
76 |
}
|
|
|
238 |
self.cls_token_id = self.gene_token_dict.get("<cls>")
|
239 |
self.eos_token_id = self.gene_token_dict.get("<eos>")
|
240 |
|
|
|
241 |
# Identify if special token is present in the token dictionary
|
242 |
if (self.cls_token_id is not None) and (self.eos_token_id is not None):
|
243 |
self.special_token = True
|
244 |
else:
|
245 |
if "cls" in self.emb_mode:
|
246 |
+
logger.error(
|
247 |
+
f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary."
|
248 |
+
)
|
249 |
raise
|
250 |
self.special_token = False
|
251 |
|
|
|
454 |
|
455 |
# Ensure emb_mode is cls if first token of the filtered input data is cls token
|
456 |
if self.special_token:
|
457 |
+
if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and (
|
458 |
+
"cls" not in self.emb_mode
|
459 |
+
):
|
460 |
logger.error(
|
461 |
+
"Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
|
462 |
+
)
|
463 |
raise
|
464 |
+
if "cls" in self.emb_mode:
|
465 |
+
if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (
|
466 |
+
filtered_input_data["input_ids"][0][-1] != self.eos_token_id
|
467 |
+
):
|
468 |
logger.error(
|
469 |
+
"Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
|
470 |
+
)
|
471 |
+
raise
|
472 |
|
473 |
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
474 |
|
|
|
534 |
layer_to_quant: int,
|
535 |
output_path_prefix: str,
|
536 |
):
|
|
|
537 |
def make_group_perturbation_batch(example):
|
538 |
example_input_ids = example["input_ids"]
|
539 |
example["tokens_to_perturb"] = self.tokens_to_perturb
|
|
|
552 |
if self.perturb_type == "delete":
|
553 |
example = pu.delete_indices(example)
|
554 |
elif self.perturb_type == "overexpress":
|
555 |
+
example = pu.overexpress_tokens(
|
556 |
+
example, self.max_len, self.special_token
|
557 |
+
)
|
558 |
example["n_overflow"] = pu.calc_n_overflow(
|
559 |
self.max_len,
|
560 |
example["length"],
|
|
|
575 |
perturbed_data = filtered_input_data.map(
|
576 |
make_group_perturbation_batch, num_proc=self.nproc
|
577 |
)
|
578 |
+
|
579 |
if self.perturb_type == "overexpress":
|
580 |
filtered_input_data = filtered_input_data.add_column(
|
581 |
"n_overflow", perturbed_data["n_overflow"]
|
|
|
757 |
f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
|
758 |
)
|
759 |
|
|
|
760 |
def isp_perturb_set_special(
|
761 |
self,
|
762 |
model,
|
|
|
764 |
layer_to_quant: int,
|
765 |
output_path_prefix: str,
|
766 |
):
|
|
|
767 |
def make_group_perturbation_batch(example):
|
768 |
example_input_ids = example["input_ids"]
|
769 |
example["tokens_to_perturb"] = self.tokens_to_perturb
|
|
|
782 |
if self.perturb_type == "delete":
|
783 |
example = pu.delete_indices(example)
|
784 |
elif self.perturb_type == "overexpress":
|
785 |
+
example = pu.overexpress_tokens(
|
786 |
+
example, self.max_len, self.special_token
|
787 |
+
)
|
788 |
example["n_overflow"] = pu.calc_n_overflow(
|
789 |
self.max_len,
|
790 |
example["length"],
|
|
|
813 |
filtered_input_data = filtered_input_data.map(
|
814 |
pu.truncate_by_n_overflow_special, num_proc=self.nproc
|
815 |
)
|
816 |
+
|
817 |
if self.emb_mode == "cls_and_gene":
|
818 |
stored_gene_embs_dict = defaultdict(list)
|
819 |
|
|
|
852 |
summary_stat=None,
|
853 |
silent=True,
|
854 |
)
|
855 |
+
|
856 |
+
# Calculate the cosine similarities
|
857 |
cls_cos_sims = pu.quant_cos_sims(
|
858 |
perturbation_cls_emb,
|
859 |
original_cls_emb,
|
860 |
self.cell_states_to_model,
|
861 |
self.state_embs_dict,
|
862 |
+
emb_mode="cell",
|
863 |
+
)
|
864 |
+
|
865 |
# Update perturbation dictionary
|
866 |
if self.cell_states_to_model is None:
|
867 |
cos_sims_dict = self.update_perturbation_dictionary(
|
868 |
cos_sims_dict,
|
869 |
cls_cos_sims,
|
870 |
+
gene_list=None,
|
871 |
)
|
872 |
else:
|
873 |
for state in cos_sims_dict.keys():
|
874 |
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
875 |
cos_sims_dict[state],
|
876 |
cls_cos_sims[state],
|
877 |
+
gene_list=None,
|
878 |
)
|
879 |
|
880 |
##### CLS and Gene Embedding Mode #####
|
|
|
914 |
# remove special tokens and padding
|
915 |
original_emb = original_emb[:, 1:-1, :]
|
916 |
if self.perturb_type == "overexpress":
|
917 |
+
perturbation_emb = full_perturbation_emb[
|
918 |
+
:, 1 + len(self.tokens_to_perturb) : -1, :
|
919 |
+
]
|
920 |
elif self.perturb_type == "delete":
|
921 |
+
perturbation_emb = full_perturbation_emb[
|
922 |
+
:, 1 : max(perturbation_batch["length"]) - 1, :
|
923 |
+
]
|
924 |
|
925 |
n_perturbation_genes = perturbation_emb.size()[1]
|
926 |
|
|
|
933 |
)
|
934 |
|
935 |
# get cls emb
|
936 |
+
original_cls_emb = full_original_emb[:, 0, :]
|
937 |
+
perturbation_cls_emb = full_perturbation_emb[:, 0, :]
|
938 |
|
939 |
cls_cos_sims = pu.quant_cos_sims(
|
940 |
perturbation_cls_emb,
|
|
|
949 |
|
950 |
gene_list = minibatch["input_ids"]
|
951 |
# need to truncate gene_list
|
952 |
+
genes_to_exclude = self.tokens_to_perturb + [
|
953 |
+
self.cls_token_id,
|
954 |
+
self.eos_token_id,
|
955 |
+
]
|
956 |
gene_list = [
|
957 |
[g for g in genes if g not in genes_to_exclude][
|
958 |
:n_perturbation_genes
|
|
|
981 |
cos_sims_dict = self.update_perturbation_dictionary(
|
982 |
cos_sims_dict,
|
983 |
cls_cos_sims,
|
984 |
+
gene_list=None,
|
985 |
)
|
986 |
else:
|
987 |
for state in cos_sims_dict.keys():
|
988 |
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
989 |
cos_sims_dict[state],
|
990 |
cls_cos_sims[state],
|
991 |
+
gene_list=None,
|
992 |
)
|
993 |
del full_original_emb
|
994 |
del original_emb
|
995 |
del full_perturbation_emb
|
996 |
+
del perturbation_emb
|
997 |
+
del gene_cos_sims
|
998 |
|
999 |
del original_cls_emb
|
1000 |
del perturbation_cls_emb
|
|
|
1048 |
summary_stat=None,
|
1049 |
silent=True,
|
1050 |
)
|
1051 |
+
|
1052 |
if self.cell_states_to_model is not None:
|
1053 |
original_cell_emb = pu.compute_nonpadded_cell_embedding(
|
1054 |
full_original_emb, "mean_pool"
|
1055 |
)
|
1056 |
+
|
1057 |
# gene_list is used to assign cos sims back to genes
|
1058 |
gene_list = example_cell["input_ids"][0][:]
|
1059 |
# need to remove the anchor gene
|
|
|
1062 |
gene_list.remove(token)
|
1063 |
# index 0 is not overexpressed so remove
|
1064 |
if self.perturb_type == "overexpress":
|
1065 |
+
gene_list = gene_list[num_inds_perturbed:]
|
|
|
|
|
1066 |
# remove perturbed index for gene list dict
|
1067 |
perturbed_gene_dict = {
|
1068 |
gene: gene_list[:i] + gene_list[i + 1 :]
|
1069 |
for i, gene in enumerate(gene_list)
|
1070 |
}
|
1071 |
+
|
1072 |
perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
|
1073 |
example_cell,
|
1074 |
self.perturb_type,
|
|
|
1079 |
)
|
1080 |
|
1081 |
ispall_total_batch_length = len(perturbation_batch)
|
1082 |
+
for i in trange(
|
1083 |
+
0, ispall_total_batch_length, self.forward_batch_size, leave=False
|
1084 |
+
):
|
1085 |
+
ispall_max_range = min(
|
1086 |
+
i + self.forward_batch_size, ispall_total_batch_length
|
1087 |
+
)
|
1088 |
+
perturbation_minibatch = perturbation_batch.select(
|
1089 |
+
[i for i in range(i, ispall_max_range)]
|
1090 |
+
)
|
1091 |
+
indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
|
1092 |
+
gene_list_mini = gene_list[
|
1093 |
+
i:ispall_max_range
|
1094 |
+
] # only perturbed genes from this minibatch
|
1095 |
|
1096 |
full_perturbation_emb = get_embs(
|
1097 |
model,
|
|
|
1104 |
summary_stat=None,
|
1105 |
silent=True,
|
1106 |
)
|
1107 |
+
|
1108 |
del perturbation_minibatch
|
1109 |
+
|
1110 |
# need to remove overexpressed gene to quantify cosine shifts
|
1111 |
if self.perturb_type == "overexpress":
|
1112 |
perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
|
1113 |
+
|
1114 |
elif self.perturb_type == "delete":
|
1115 |
perturbation_emb = full_perturbation_emb
|
1116 |
+
|
1117 |
+
if (
|
1118 |
+
self.cell_states_to_model is None
|
1119 |
+
or self.emb_mode == "cell_and_gene"
|
1120 |
+
):
|
1121 |
original_emb_minibatch = pu.make_comparison_batch(
|
1122 |
full_original_emb, indices_to_perturb_mini, perturb_group=False
|
1123 |
)
|
|
|
1129 |
emb_mode="gene",
|
1130 |
)
|
1131 |
del original_emb_minibatch
|
1132 |
+
|
1133 |
if self.cell_states_to_model is not None:
|
1134 |
perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
|
1135 |
full_perturbation_emb, "mean_pool"
|
1136 |
)
|
1137 |
+
|
1138 |
cell_cos_sims = pu.quant_cos_sims(
|
1139 |
perturbation_cell_emb,
|
1140 |
original_cell_emb,
|
|
|
1143 |
emb_mode="cell",
|
1144 |
)
|
1145 |
del perturbation_cell_emb
|
|
|
|
|
1146 |
|
1147 |
+
if self.emb_mode == "cell_and_gene":
|
1148 |
for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
|
1149 |
for gene_j, affected_gene in enumerate(
|
1150 |
perturbed_gene_dict[perturbed_gene]
|
|
|
1157 |
stored_gene_embs_dict[
|
1158 |
(perturbed_gene, affected_gene)
|
1159 |
] = gene_cos_sims[perturbation_i, gene_j].item()
|
1160 |
+
|
1161 |
del full_perturbation_emb
|
1162 |
+
|
1163 |
if self.cell_states_to_model is None:
|
1164 |
cos_sims_data = torch.mean(gene_cos_sims, dim=1)
|
1165 |
cos_sims_dict = self.update_perturbation_dictionary(
|
|
|
1175 |
cos_sims_data[state],
|
1176 |
gene_list_mini,
|
1177 |
)
|
1178 |
+
|
1179 |
# save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
|
1180 |
+
if i % self.clear_mem_ncells / 10 == 0:
|
1181 |
pu.write_perturbation_dictionary(
|
1182 |
cos_sims_dict,
|
1183 |
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
|
|
1187 |
stored_gene_embs_dict,
|
1188 |
f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
|
1189 |
)
|
1190 |
+
|
1191 |
# reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
|
1192 |
if i % self.clear_mem_ncells == 0:
|
1193 |
pickle_batch += 1
|
|
|
1196 |
else:
|
1197 |
cos_sims_dict = {
|
1198 |
state: defaultdict(list)
|
1199 |
+
for state in pu.get_possible_states(
|
1200 |
+
self.cell_states_to_model
|
1201 |
+
)
|
1202 |
}
|
1203 |
+
|
1204 |
if self.emb_mode == "cell_and_gene":
|
1205 |
stored_gene_embs_dict = defaultdict(list)
|
1206 |
+
|
1207 |
torch.cuda.empty_cache()
|
1208 |
+
|
1209 |
pu.write_perturbation_dictionary(
|
1210 |
+
cos_sims_dict,
|
1211 |
+
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
1212 |
)
|
1213 |
+
|
1214 |
if self.emb_mode == "cell_and_gene":
|
1215 |
pu.write_perturbation_dictionary(
|
1216 |
stored_gene_embs_dict,
|
|
|
1235 |
if self.cell_states_to_model is not None:
|
1236 |
del original_cell_emb
|
1237 |
torch.cuda.empty_cache()
|
1238 |
+
|
1239 |
def isp_perturb_all_special(
|
1240 |
self,
|
1241 |
model,
|
|
|
1271 |
self.token_gene_dict,
|
1272 |
summary_stat=None,
|
1273 |
silent=True,
|
1274 |
+
)
|
1275 |
elif self.emb_mode == "cls_and_gene":
|
1276 |
full_original_emb = get_embs(
|
1277 |
model,
|
|
|
1284 |
summary_stat=None,
|
1285 |
silent=True,
|
1286 |
)
|
1287 |
+
original_cls_emb = full_original_emb[:, 0, :].clone().detach()
|
1288 |
+
|
1289 |
# gene_list is used to assign cos sims back to genes
|
1290 |
gene_list = example_cell["input_ids"][0][:]
|
1291 |
|
|
|
1298 |
gene_list.remove(token)
|
1299 |
# index 0 is not overexpressed so remove
|
1300 |
if self.perturb_type == "overexpress":
|
1301 |
+
gene_list = gene_list[num_inds_perturbed:]
|
|
|
|
|
1302 |
# remove perturbed index for gene list dict
|
1303 |
perturbed_gene_dict = {
|
1304 |
gene: gene_list[:i] + gene_list[i + 1 :]
|
|
|
1315 |
)
|
1316 |
|
1317 |
ispall_total_batch_length = len(perturbation_batch)
|
1318 |
+
for i in trange(
|
1319 |
+
0, ispall_total_batch_length, self.forward_batch_size, leave=False
|
1320 |
+
):
|
1321 |
+
ispall_max_range = min(
|
1322 |
+
i + self.forward_batch_size, ispall_total_batch_length
|
1323 |
+
)
|
1324 |
+
perturbation_minibatch = perturbation_batch.select(
|
1325 |
+
[i for i in range(i, ispall_max_range)]
|
1326 |
+
)
|
1327 |
+
indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
|
1328 |
+
gene_list_mini = gene_list[
|
1329 |
+
i:ispall_max_range
|
1330 |
+
] # only perturbed genes from this minibatch
|
1331 |
+
|
1332 |
##### CLS Embedding Mode #####
|
1333 |
if self.emb_mode == "cls":
|
1334 |
# Extract cls embeddings from perturbed cells
|
|
|
1343 |
summary_stat=None,
|
1344 |
silent=True,
|
1345 |
)
|
1346 |
+
|
1347 |
# Calculate cosine similarities
|
1348 |
cls_cos_sims = pu.quant_cos_sims(
|
1349 |
perturbation_cls_emb,
|
|
|
1351 |
self.cell_states_to_model,
|
1352 |
self.state_embs_dict,
|
1353 |
emb_mode="cell",
|
1354 |
+
)
|
1355 |
+
|
1356 |
if self.cell_states_to_model is None:
|
1357 |
cos_sims_dict = self.update_perturbation_dictionary(
|
1358 |
cos_sims_dict,
|
|
|
1360 |
gene_list_mini,
|
1361 |
)
|
1362 |
else:
|
|
|
1363 |
for state in cos_sims_dict.keys():
|
1364 |
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1365 |
cos_sims_dict[state],
|
1366 |
cls_cos_sims[state],
|
1367 |
gene_list_mini,
|
1368 |
)
|
1369 |
+
|
1370 |
del perturbation_minibatch
|
1371 |
del perturbation_cls_emb
|
1372 |
del cls_cos_sims
|
1373 |
+
|
1374 |
##### CLS and Gene Embedding Mode #####
|
1375 |
+
elif self.emb_mode == "cls_and_gene":
|
1376 |
full_perturbation_emb = get_embs(
|
1377 |
model,
|
1378 |
perturbation_minibatch,
|
|
|
1384 |
summary_stat=None,
|
1385 |
silent=True,
|
1386 |
)
|
1387 |
+
|
1388 |
# need to remove overexpressed gene and cls/eos to quantify cosine shifts
|
1389 |
if self.perturb_type == "overexpress":
|
1390 |
+
perturbation_emb = (
|
1391 |
+
full_perturbation_emb[:, 1 + num_inds_perturbed : -1, :]
|
1392 |
+
.clone()
|
1393 |
+
.detach()
|
1394 |
+
)
|
1395 |
elif self.perturb_type == "delete":
|
1396 |
+
perturbation_emb = (
|
1397 |
+
full_perturbation_emb[:, 1:-1, :].clone().detach()
|
1398 |
+
)
|
1399 |
+
|
1400 |
original_emb_minibatch = pu.make_comparison_batch(
|
1401 |
full_original_emb, indices_to_perturb_mini, perturb_group=False
|
1402 |
)
|
1403 |
+
|
1404 |
+
original_emb_minibatch = (
|
1405 |
+
original_emb_minibatch[:, 1:-1, :].clone().detach()
|
1406 |
+
)
|
1407 |
gene_cos_sims = pu.quant_cos_sims(
|
1408 |
perturbation_emb,
|
1409 |
original_emb_minibatch,
|
|
|
1411 |
self.state_embs_dict,
|
1412 |
emb_mode="gene",
|
1413 |
)
|
1414 |
+
|
1415 |
for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
|
1416 |
for gene_j, affected_gene in enumerate(
|
1417 |
perturbed_gene_dict[perturbed_gene]
|
|
|
1424 |
stored_gene_embs_dict[
|
1425 |
(perturbed_gene, affected_gene)
|
1426 |
] = gene_cos_sims[perturbation_i, gene_j].item()
|
1427 |
+
|
1428 |
# get cls emb
|
1429 |
+
perturbation_cls_emb = (
|
1430 |
+
full_perturbation_emb[:, 0, :].clone().detach()
|
1431 |
+
)
|
1432 |
+
|
1433 |
cls_cos_sims = pu.quant_cos_sims(
|
1434 |
perturbation_cls_emb,
|
1435 |
original_cls_emb,
|
|
|
1437 |
self.state_embs_dict,
|
1438 |
emb_mode="cell",
|
1439 |
)
|
1440 |
+
|
1441 |
if self.cell_states_to_model is None:
|
1442 |
cos_sims_dict = self.update_perturbation_dictionary(
|
1443 |
cos_sims_dict,
|
|
|
1451 |
cls_cos_sims[state],
|
1452 |
gene_list_mini,
|
1453 |
)
|
1454 |
+
|
1455 |
del perturbation_minibatch
|
1456 |
del original_emb_minibatch
|
1457 |
del full_perturbation_emb
|
|
|
1459 |
del perturbation_cls_emb
|
1460 |
del cls_cos_sims
|
1461 |
del gene_cos_sims
|
1462 |
+
|
1463 |
# save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
|
1464 |
+
if i % max(1, self.clear_mem_ncells / 10) == 0:
|
1465 |
pu.write_perturbation_dictionary(
|
1466 |
cos_sims_dict,
|
1467 |
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
|
|
1471 |
stored_gene_embs_dict,
|
1472 |
f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
|
1473 |
)
|
1474 |
+
|
1475 |
# reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
|
1476 |
if i % self.clear_mem_ncells == 0:
|
1477 |
pickle_batch += 1
|
|
|
1480 |
else:
|
1481 |
cos_sims_dict = {
|
1482 |
state: defaultdict(list)
|
1483 |
+
for state in pu.get_possible_states(
|
1484 |
+
self.cell_states_to_model
|
1485 |
+
)
|
1486 |
}
|
1487 |
+
|
1488 |
if self.emb_mode == "cls_and_gene":
|
1489 |
stored_gene_embs_dict = defaultdict(list)
|
1490 |
+
|
1491 |
torch.cuda.empty_cache()
|
1492 |
+
|
1493 |
pu.write_perturbation_dictionary(
|
1494 |
+
cos_sims_dict,
|
1495 |
+
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
1496 |
)
|
1497 |
+
|
1498 |
if self.emb_mode == "cls_and_gene":
|
1499 |
pu.write_perturbation_dictionary(
|
1500 |
stored_gene_embs_dict,
|
|
|
1511 |
}
|
1512 |
|
1513 |
if self.emb_mode == "cls_and_gene":
|
1514 |
+
stored_gene_embs_dict = defaultdict(list)
|
1515 |
+
|
1516 |
# clear memory between cells
|
1517 |
del perturbation_batch
|
1518 |
del original_cls_emb
|
|
|
1520 |
del full_original_emb
|
1521 |
torch.cuda.empty_cache()
|
1522 |
|
|
|
1523 |
def update_perturbation_dictionary(
|
1524 |
self,
|
1525 |
cos_sims_dict: defaultdict,
|
|
|
1554 |
for i, cos in enumerate(cos_sims_data.tolist()):
|
1555 |
cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
|
1556 |
|
1557 |
+
return cos_sims_dict
|
geneformer/in_silico_perturber_stats.py
CHANGED
@@ -37,8 +37,8 @@ from scipy.stats import ranksums
|
|
37 |
from sklearn.mixture import GaussianMixture
|
38 |
from tqdm.auto import tqdm, trange
|
39 |
|
|
|
40 |
from .perturber_utils import flatten_list, validate_cell_states_to_model
|
41 |
-
from . import TOKEN_DICTIONARY_FILE, ENSEMBL_DICTIONARY_FILE
|
42 |
|
43 |
logger = logging.getLogger(__name__)
|
44 |
|
@@ -194,23 +194,29 @@ def get_impact_component(test_value, gaussian_mixture_model):
|
|
194 |
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
|
195 |
names = ["Cosine_sim", "Gene"]
|
196 |
cos_sims_full_dfs = []
|
197 |
-
if isinstance(genes_perturbed,list):
|
198 |
-
if len(genes_perturbed)>1:
|
199 |
-
gene_ids_df = cos_sims_df.loc[
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
else:
|
201 |
-
gene_ids_df = cos_sims_df.loc[
|
|
|
|
|
202 |
else:
|
203 |
logger.error(
|
204 |
-
|
205 |
-
|
206 |
-
raise
|
207 |
|
208 |
if gene_ids_df.empty:
|
209 |
-
logger.error(
|
210 |
-
"genes_to_perturb not found in data."
|
211 |
-
)
|
212 |
raise
|
213 |
-
|
214 |
tokens = gene_ids_df["Gene"]
|
215 |
symbols = gene_ids_df["Gene_name"]
|
216 |
|
@@ -223,7 +229,7 @@ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
|
|
223 |
df["Cosine_sim"] = cos_shift_data
|
224 |
df["Gene"] = symbol
|
225 |
cos_sims_full_dfs.append(df)
|
226 |
-
|
227 |
return pd.concat(cos_sims_full_dfs)
|
228 |
|
229 |
|
@@ -1018,7 +1024,7 @@ class InSilicoPerturberStats:
|
|
1018 |
},
|
1019 |
index=[i for i in range(len(gene_list))],
|
1020 |
)
|
1021 |
-
|
1022 |
if self.mode == "goal_state_shift":
|
1023 |
cos_sims_df = isp_stats_to_goal_state(
|
1024 |
cos_sims_df_initial,
|
@@ -1045,12 +1051,16 @@ class InSilicoPerturberStats:
|
|
1045 |
cos_sims_df_initial, dict_list, self.combos, self.anchor_token
|
1046 |
)
|
1047 |
|
1048 |
-
elif self.mode == "aggregate_data":
|
1049 |
-
cos_sims_df = isp_aggregate_grouped_perturb(
|
|
|
|
|
1050 |
|
1051 |
elif self.mode == "aggregate_gene_shifts":
|
1052 |
if (self.genes_perturbed == "all") and (self.combos == 0):
|
1053 |
-
tuple_types = [
|
|
|
|
|
1054 |
if all(tuple_types):
|
1055 |
token_dtype = "tuple"
|
1056 |
elif not any(tuple_types):
|
@@ -1059,13 +1069,13 @@ class InSilicoPerturberStats:
|
|
1059 |
token_dtype = "mix"
|
1060 |
else:
|
1061 |
token_dtype = "mix"
|
1062 |
-
|
1063 |
cos_sims_df = isp_aggregate_gene_shifts(
|
1064 |
cos_sims_df_initial,
|
1065 |
dict_list,
|
1066 |
self.gene_token_id_dict,
|
1067 |
self.gene_id_name_dict,
|
1068 |
-
token_dtype
|
1069 |
)
|
1070 |
|
1071 |
# save perturbation stats to output_path
|
|
|
37 |
from sklearn.mixture import GaussianMixture
|
38 |
from tqdm.auto import tqdm, trange
|
39 |
|
40 |
+
from . import ENSEMBL_DICTIONARY_FILE, TOKEN_DICTIONARY_FILE
|
41 |
from .perturber_utils import flatten_list, validate_cell_states_to_model
|
|
|
42 |
|
43 |
logger = logging.getLogger(__name__)
|
44 |
|
|
|
194 |
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
|
195 |
names = ["Cosine_sim", "Gene"]
|
196 |
cos_sims_full_dfs = []
|
197 |
+
if isinstance(genes_perturbed, list):
|
198 |
+
if len(genes_perturbed) > 1:
|
199 |
+
gene_ids_df = cos_sims_df.loc[
|
200 |
+
np.isin(
|
201 |
+
[set(idx) for idx in cos_sims_df["Ensembl_ID"]],
|
202 |
+
set(genes_perturbed),
|
203 |
+
),
|
204 |
+
:,
|
205 |
+
]
|
206 |
else:
|
207 |
+
gene_ids_df = cos_sims_df.loc[
|
208 |
+
np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :
|
209 |
+
]
|
210 |
else:
|
211 |
logger.error(
|
212 |
+
"aggregate_data is for perturbation of single gene or single group of genes. genes_to_perturb should be formatted as list."
|
213 |
+
)
|
214 |
+
raise
|
215 |
|
216 |
if gene_ids_df.empty:
|
217 |
+
logger.error("genes_to_perturb not found in data.")
|
|
|
|
|
218 |
raise
|
219 |
+
|
220 |
tokens = gene_ids_df["Gene"]
|
221 |
symbols = gene_ids_df["Gene_name"]
|
222 |
|
|
|
229 |
df["Cosine_sim"] = cos_shift_data
|
230 |
df["Gene"] = symbol
|
231 |
cos_sims_full_dfs.append(df)
|
232 |
+
|
233 |
return pd.concat(cos_sims_full_dfs)
|
234 |
|
235 |
|
|
|
1024 |
},
|
1025 |
index=[i for i in range(len(gene_list))],
|
1026 |
)
|
1027 |
+
|
1028 |
if self.mode == "goal_state_shift":
|
1029 |
cos_sims_df = isp_stats_to_goal_state(
|
1030 |
cos_sims_df_initial,
|
|
|
1051 |
cos_sims_df_initial, dict_list, self.combos, self.anchor_token
|
1052 |
)
|
1053 |
|
1054 |
+
elif self.mode == "aggregate_data":
|
1055 |
+
cos_sims_df = isp_aggregate_grouped_perturb(
|
1056 |
+
cos_sims_df_initial, dict_list, self.genes_perturbed
|
1057 |
+
)
|
1058 |
|
1059 |
elif self.mode == "aggregate_gene_shifts":
|
1060 |
if (self.genes_perturbed == "all") and (self.combos == 0):
|
1061 |
+
tuple_types = [
|
1062 |
+
True if isinstance(genes, tuple) else False for genes in gene_list
|
1063 |
+
]
|
1064 |
if all(tuple_types):
|
1065 |
token_dtype = "tuple"
|
1066 |
elif not any(tuple_types):
|
|
|
1069 |
token_dtype = "mix"
|
1070 |
else:
|
1071 |
token_dtype = "mix"
|
1072 |
+
|
1073 |
cos_sims_df = isp_aggregate_gene_shifts(
|
1074 |
cos_sims_df_initial,
|
1075 |
dict_list,
|
1076 |
self.gene_token_id_dict,
|
1077 |
self.gene_id_name_dict,
|
1078 |
+
token_dtype,
|
1079 |
)
|
1080 |
|
1081 |
# save perturbation stats to output_path
|
geneformer/mtl/collators.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
#imports
|
2 |
import torch
|
3 |
|
4 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
@@ -7,6 +7,7 @@ from ..collator_for_classification import DataCollatorForGeneClassification
|
|
7 |
Geneformer collator for multi-task cell classification.
|
8 |
"""
|
9 |
|
|
|
10 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
11 |
class_type = "cell"
|
12 |
|
@@ -47,7 +48,10 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
|
|
47 |
batch["labels"] = labels
|
48 |
else:
|
49 |
# If no labels are present, create empty labels for all tasks
|
50 |
-
batch["labels"] = {
|
|
|
|
|
|
|
51 |
|
52 |
return batch
|
53 |
|
@@ -59,8 +63,11 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
|
|
59 |
batch[k] = v.clone().detach()
|
60 |
elif isinstance(v, dict):
|
61 |
# Assuming nested structure needs conversion
|
62 |
-
batch[k] = {
|
|
|
|
|
|
|
63 |
else:
|
64 |
batch[k] = torch.tensor(v, dtype=torch.int64)
|
65 |
|
66 |
-
return batch
|
|
|
1 |
+
# imports
|
2 |
import torch
|
3 |
|
4 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
|
|
7 |
Geneformer collator for multi-task cell classification.
|
8 |
"""
|
9 |
|
10 |
+
|
11 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
12 |
class_type = "cell"
|
13 |
|
|
|
48 |
batch["labels"] = labels
|
49 |
else:
|
50 |
# If no labels are present, create empty labels for all tasks
|
51 |
+
batch["labels"] = {
|
52 |
+
task: torch.tensor([], dtype=torch.long)
|
53 |
+
for task in features[0]["input_ids"].keys()
|
54 |
+
}
|
55 |
|
56 |
return batch
|
57 |
|
|
|
63 |
batch[k] = v.clone().detach()
|
64 |
elif isinstance(v, dict):
|
65 |
# Assuming nested structure needs conversion
|
66 |
+
batch[k] = {
|
67 |
+
task: torch.tensor(labels, dtype=torch.int64)
|
68 |
+
for task, labels in v.items()
|
69 |
+
}
|
70 |
else:
|
71 |
batch[k] = torch.tensor(v, dtype=torch.int64)
|
72 |
|
73 |
+
return batch
|
geneformer/mtl/data.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
-
from .imports import *
|
2 |
import os
|
|
|
3 |
from .collators import DataCollatorForMultitaskCellClassification
|
|
|
|
|
4 |
|
5 |
def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
|
6 |
try:
|
@@ -14,7 +16,9 @@ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type="
|
|
14 |
available_columns = set(dataset.column_names)
|
15 |
for column in task_to_column.values():
|
16 |
if column not in available_columns:
|
17 |
-
raise KeyError(
|
|
|
|
|
18 |
|
19 |
label_mappings = {}
|
20 |
task_label_mappings = {}
|
@@ -25,13 +29,17 @@ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type="
|
|
25 |
if not is_test:
|
26 |
for task, column in task_to_column.items():
|
27 |
unique_values = sorted(set(dataset[column])) # Ensure consistency
|
28 |
-
label_mappings[column] = {
|
|
|
|
|
29 |
task_label_mappings[task] = label_mappings[column]
|
30 |
num_labels_list.append(len(unique_values))
|
31 |
|
32 |
# Print the mappings for each task with dataset type prefix
|
33 |
for task, mapping in task_label_mappings.items():
|
34 |
-
print(
|
|
|
|
|
35 |
|
36 |
# Save the task label mappings as a pickle file
|
37 |
with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
|
@@ -40,24 +48,26 @@ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type="
|
|
40 |
# Load task label mappings from pickle file for test data
|
41 |
with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
|
42 |
task_label_mappings = pickle.load(f)
|
43 |
-
|
44 |
# Infer num_labels_list from task_label_mappings
|
45 |
for task, mapping in task_label_mappings.items():
|
46 |
num_labels_list.append(len(mapping))
|
47 |
|
48 |
# Store unique cell IDs in a separate dictionary
|
49 |
for idx, record in enumerate(dataset):
|
50 |
-
cell_id = record.get(
|
51 |
cell_id_mapping[idx] = cell_id
|
52 |
|
53 |
# Transform records to the desired format
|
54 |
transformed_dataset = []
|
55 |
for idx, record in enumerate(dataset):
|
56 |
transformed_record = {}
|
57 |
-
transformed_record[
|
58 |
-
|
|
|
|
|
59 |
# Use index-based cell ID for internal tracking
|
60 |
-
transformed_record[
|
61 |
|
62 |
if not is_test:
|
63 |
# Prepare labels
|
@@ -66,11 +76,11 @@ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type="
|
|
66 |
label_value = record[column]
|
67 |
label_index = task_label_mappings[task][label_value]
|
68 |
label_dict[task] = label_index
|
69 |
-
transformed_record[
|
70 |
else:
|
71 |
# Create dummy labels for test data
|
72 |
label_dict = {task: -1 for task in config["task_names"]}
|
73 |
-
transformed_record[
|
74 |
|
75 |
transformed_dataset.append(transformed_record)
|
76 |
|
@@ -81,36 +91,60 @@ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type="
|
|
81 |
print(f"An error occurred while loading or preprocessing data: {e}")
|
82 |
return None, None, None
|
83 |
|
|
|
84 |
def preload_and_process_data(config):
|
85 |
# Load and preprocess data once
|
86 |
-
train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
def get_data_loader(preprocessed_dataset, batch_size):
|
91 |
-
nproc = os.cpu_count()
|
92 |
-
|
93 |
data_collator = DataCollatorForMultitaskCellClassification()
|
94 |
-
|
95 |
-
loader = DataLoader(
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
return loader
|
|
|
|
|
98 |
def preload_data(config):
|
99 |
# Preprocessing the data before the Optuna trials start
|
100 |
train_loader = get_data_loader("train", config)
|
101 |
val_loader = get_data_loader("val", config)
|
102 |
return train_loader, val_loader
|
103 |
|
|
|
104 |
def load_and_preprocess_test_data(config):
|
105 |
"""
|
106 |
Load and preprocess test data, treating it as unlabeled.
|
107 |
"""
|
108 |
return load_and_preprocess_data(config["test_path"], config, is_test=True)
|
109 |
|
|
|
110 |
def prepare_test_loader(config):
|
111 |
"""
|
112 |
Prepare DataLoader for the test dataset.
|
113 |
"""
|
114 |
-
test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
|
3 |
from .collators import DataCollatorForMultitaskCellClassification
|
4 |
+
from .imports import *
|
5 |
+
|
6 |
|
7 |
def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
|
8 |
try:
|
|
|
16 |
available_columns = set(dataset.column_names)
|
17 |
for column in task_to_column.values():
|
18 |
if column not in available_columns:
|
19 |
+
raise KeyError(
|
20 |
+
f"Column {column} not found in the dataset. Available columns: {list(available_columns)}"
|
21 |
+
)
|
22 |
|
23 |
label_mappings = {}
|
24 |
task_label_mappings = {}
|
|
|
29 |
if not is_test:
|
30 |
for task, column in task_to_column.items():
|
31 |
unique_values = sorted(set(dataset[column])) # Ensure consistency
|
32 |
+
label_mappings[column] = {
|
33 |
+
label: idx for idx, label in enumerate(unique_values)
|
34 |
+
}
|
35 |
task_label_mappings[task] = label_mappings[column]
|
36 |
num_labels_list.append(len(unique_values))
|
37 |
|
38 |
# Print the mappings for each task with dataset type prefix
|
39 |
for task, mapping in task_label_mappings.items():
|
40 |
+
print(
|
41 |
+
f"{dataset_type.capitalize()} mapping for {task}: {mapping}"
|
42 |
+
) # sanity check, for train/validation splits
|
43 |
|
44 |
# Save the task label mappings as a pickle file
|
45 |
with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
|
|
|
48 |
# Load task label mappings from pickle file for test data
|
49 |
with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
|
50 |
task_label_mappings = pickle.load(f)
|
51 |
+
|
52 |
# Infer num_labels_list from task_label_mappings
|
53 |
for task, mapping in task_label_mappings.items():
|
54 |
num_labels_list.append(len(mapping))
|
55 |
|
56 |
# Store unique cell IDs in a separate dictionary
|
57 |
for idx, record in enumerate(dataset):
|
58 |
+
cell_id = record.get("unique_cell_id", idx)
|
59 |
cell_id_mapping[idx] = cell_id
|
60 |
|
61 |
# Transform records to the desired format
|
62 |
transformed_dataset = []
|
63 |
for idx, record in enumerate(dataset):
|
64 |
transformed_record = {}
|
65 |
+
transformed_record["input_ids"] = torch.tensor(
|
66 |
+
record["input_ids"], dtype=torch.long
|
67 |
+
)
|
68 |
+
|
69 |
# Use index-based cell ID for internal tracking
|
70 |
+
transformed_record["cell_id"] = idx
|
71 |
|
72 |
if not is_test:
|
73 |
# Prepare labels
|
|
|
76 |
label_value = record[column]
|
77 |
label_index = task_label_mappings[task][label_value]
|
78 |
label_dict[task] = label_index
|
79 |
+
transformed_record["label"] = label_dict
|
80 |
else:
|
81 |
# Create dummy labels for test data
|
82 |
label_dict = {task: -1 for task in config["task_names"]}
|
83 |
+
transformed_record["label"] = label_dict
|
84 |
|
85 |
transformed_dataset.append(transformed_record)
|
86 |
|
|
|
91 |
print(f"An error occurred while loading or preprocessing data: {e}")
|
92 |
return None, None, None
|
93 |
|
94 |
+
|
95 |
def preload_and_process_data(config):
|
96 |
# Load and preprocess data once
|
97 |
+
train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
|
98 |
+
config["train_path"], config, dataset_type="train"
|
99 |
+
)
|
100 |
+
val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(
|
101 |
+
config["val_path"], config, dataset_type="validation"
|
102 |
+
)
|
103 |
+
return (
|
104 |
+
train_dataset,
|
105 |
+
train_cell_id_mapping,
|
106 |
+
val_dataset,
|
107 |
+
val_cell_id_mapping,
|
108 |
+
num_labels_list,
|
109 |
+
)
|
110 |
+
|
111 |
|
112 |
def get_data_loader(preprocessed_dataset, batch_size):
|
113 |
+
nproc = os.cpu_count() ### I/O operations
|
114 |
+
|
115 |
data_collator = DataCollatorForMultitaskCellClassification()
|
116 |
+
|
117 |
+
loader = DataLoader(
|
118 |
+
preprocessed_dataset,
|
119 |
+
batch_size=batch_size,
|
120 |
+
shuffle=True,
|
121 |
+
collate_fn=data_collator,
|
122 |
+
num_workers=nproc,
|
123 |
+
pin_memory=True,
|
124 |
+
)
|
125 |
return loader
|
126 |
+
|
127 |
+
|
128 |
def preload_data(config):
|
129 |
# Preprocessing the data before the Optuna trials start
|
130 |
train_loader = get_data_loader("train", config)
|
131 |
val_loader = get_data_loader("val", config)
|
132 |
return train_loader, val_loader
|
133 |
|
134 |
+
|
135 |
def load_and_preprocess_test_data(config):
|
136 |
"""
|
137 |
Load and preprocess test data, treating it as unlabeled.
|
138 |
"""
|
139 |
return load_and_preprocess_data(config["test_path"], config, is_test=True)
|
140 |
|
141 |
+
|
142 |
def prepare_test_loader(config):
|
143 |
"""
|
144 |
Prepare DataLoader for the test dataset.
|
145 |
"""
|
146 |
+
test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(
|
147 |
+
config
|
148 |
+
)
|
149 |
+
test_loader = get_data_loader(test_dataset, config["batch_size"])
|
150 |
+
return test_loader, cell_id_mapping, num_labels_list
|
geneformer/mtl/eval_utils.py
CHANGED
@@ -1,29 +1,33 @@
|
|
1 |
-
from .imports import *
|
2 |
import pandas as pd
|
|
|
3 |
from .data import prepare_test_loader
|
|
|
4 |
from .model import GeneformerMultiTask
|
5 |
|
|
|
6 |
def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
|
7 |
task_pred_labels = {task_name: [] for task_name in config["task_names"]}
|
8 |
task_pred_probs = {task_name: [] for task_name in config["task_names"]}
|
9 |
cell_ids = []
|
10 |
|
11 |
-
# Load task label mappings from pickle file
|
12 |
-
with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
|
13 |
-
|
14 |
|
15 |
model.eval()
|
16 |
with torch.no_grad():
|
17 |
for batch in test_loader:
|
18 |
-
input_ids = batch[
|
19 |
-
attention_mask = batch[
|
20 |
_, logits, _ = model(input_ids, attention_mask)
|
21 |
-
for sample_idx in range(len(batch[
|
22 |
-
cell_id = cell_id_mapping[batch[
|
23 |
cell_ids.append(cell_id)
|
24 |
for i, task_name in enumerate(config["task_names"]):
|
25 |
pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
|
26 |
-
pred_prob =
|
|
|
|
|
27 |
task_pred_labels[task_name].append(pred_label)
|
28 |
task_pred_probs[task_name].append(pred_prob)
|
29 |
|
@@ -31,19 +35,22 @@ def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
|
|
31 |
test_results_dir = config["results_dir"]
|
32 |
os.makedirs(test_results_dir, exist_ok=True)
|
33 |
test_preds_file = os.path.join(test_results_dir, "test_preds.csv")
|
34 |
-
|
35 |
rows = []
|
36 |
for sample_idx in range(len(cell_ids)):
|
37 |
-
row = {
|
38 |
for task_name in config["task_names"]:
|
39 |
-
row[f
|
40 |
-
row[f
|
|
|
|
|
41 |
rows.append(row)
|
42 |
-
|
43 |
df = pd.DataFrame(rows)
|
44 |
df.to_csv(test_preds_file, index=False)
|
45 |
print(f"Test predictions saved to {test_preds_file}")
|
46 |
|
|
|
47 |
def load_and_evaluate_test_model(config):
|
48 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
49 |
test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
|
@@ -51,7 +58,7 @@ def load_and_evaluate_test_model(config):
|
|
51 |
hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
|
52 |
|
53 |
# Load the saved best hyperparameters
|
54 |
-
with open(hyperparams_path,
|
55 |
best_hyperparams = json.load(f)
|
56 |
|
57 |
# Extract the task weights if present, otherwise set to None
|
@@ -72,7 +79,7 @@ def load_and_evaluate_test_model(config):
|
|
72 |
num_labels_list,
|
73 |
dropout_rate=best_hyperparams["dropout_rate"],
|
74 |
use_task_weights=config["use_task_weights"],
|
75 |
-
task_weights=normalized_task_weights
|
76 |
)
|
77 |
best_model.load_state_dict(torch.load(best_model_path))
|
78 |
best_model.to(device)
|
|
|
|
|
1 |
import pandas as pd
|
2 |
+
|
3 |
from .data import prepare_test_loader
|
4 |
+
from .imports import *
|
5 |
from .model import GeneformerMultiTask
|
6 |
|
7 |
+
|
8 |
def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
|
9 |
task_pred_labels = {task_name: [] for task_name in config["task_names"]}
|
10 |
task_pred_probs = {task_name: [] for task_name in config["task_names"]}
|
11 |
cell_ids = []
|
12 |
|
13 |
+
# # Load task label mappings from pickle file
|
14 |
+
# with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
|
15 |
+
# task_label_mappings = pickle.load(f)
|
16 |
|
17 |
model.eval()
|
18 |
with torch.no_grad():
|
19 |
for batch in test_loader:
|
20 |
+
input_ids = batch["input_ids"].to(device)
|
21 |
+
attention_mask = batch["attention_mask"].to(device)
|
22 |
_, logits, _ = model(input_ids, attention_mask)
|
23 |
+
for sample_idx in range(len(batch["input_ids"])):
|
24 |
+
cell_id = cell_id_mapping[batch["cell_id"][sample_idx].item()]
|
25 |
cell_ids.append(cell_id)
|
26 |
for i, task_name in enumerate(config["task_names"]):
|
27 |
pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
|
28 |
+
pred_prob = (
|
29 |
+
torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
|
30 |
+
)
|
31 |
task_pred_labels[task_name].append(pred_label)
|
32 |
task_pred_probs[task_name].append(pred_prob)
|
33 |
|
|
|
35 |
test_results_dir = config["results_dir"]
|
36 |
os.makedirs(test_results_dir, exist_ok=True)
|
37 |
test_preds_file = os.path.join(test_results_dir, "test_preds.csv")
|
38 |
+
|
39 |
rows = []
|
40 |
for sample_idx in range(len(cell_ids)):
|
41 |
+
row = {"Cell ID": cell_ids[sample_idx]}
|
42 |
for task_name in config["task_names"]:
|
43 |
+
row[f"{task_name} Prediction"] = task_pred_labels[task_name][sample_idx]
|
44 |
+
row[f"{task_name} Probabilities"] = ",".join(
|
45 |
+
map(str, task_pred_probs[task_name][sample_idx])
|
46 |
+
)
|
47 |
rows.append(row)
|
48 |
+
|
49 |
df = pd.DataFrame(rows)
|
50 |
df.to_csv(test_preds_file, index=False)
|
51 |
print(f"Test predictions saved to {test_preds_file}")
|
52 |
|
53 |
+
|
54 |
def load_and_evaluate_test_model(config):
|
55 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
56 |
test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
|
|
|
58 |
hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
|
59 |
|
60 |
# Load the saved best hyperparameters
|
61 |
+
with open(hyperparams_path, "r") as f:
|
62 |
best_hyperparams = json.load(f)
|
63 |
|
64 |
# Extract the task weights if present, otherwise set to None
|
|
|
79 |
num_labels_list,
|
80 |
dropout_rate=best_hyperparams["dropout_rate"],
|
81 |
use_task_weights=config["use_task_weights"],
|
82 |
+
task_weights=normalized_task_weights,
|
83 |
)
|
84 |
best_model.load_state_dict(torch.load(best_model_path))
|
85 |
best_model.to(device)
|
geneformer/mtl/imports.py
CHANGED
@@ -1,46 +1,43 @@
|
|
1 |
-
import
|
|
|
|
|
|
|
2 |
import pickle
|
3 |
-
import
|
4 |
-
import torch.nn as nn
|
5 |
-
import torch.optim as optim
|
6 |
-
import torch.nn.functional as F
|
7 |
-
from torch.utils.data import DataLoader
|
8 |
-
|
9 |
-
from itertools import chain
|
10 |
import warnings
|
11 |
from enum import Enum
|
|
|
12 |
from typing import Dict, List, Optional, Union
|
13 |
-
import sys
|
14 |
-
import os
|
15 |
-
import json
|
16 |
-
import gc
|
17 |
-
import functools
|
18 |
-
import pandas as pd
|
19 |
-
|
20 |
-
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, roc_curve
|
21 |
-
from sklearn.preprocessing import LabelEncoder
|
22 |
-
from sklearn.model_selection import train_test_split
|
23 |
|
|
|
24 |
import optuna
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
from transformers import (
|
|
|
|
|
27 |
BertConfig,
|
28 |
BertModel,
|
29 |
-
AdamW,
|
30 |
-
get_linear_schedule_with_warmup,
|
31 |
-
get_cosine_schedule_with_warmup,
|
32 |
DataCollatorForTokenClassification,
|
33 |
SpecialTokensMixin,
|
34 |
-
|
|
|
35 |
get_scheduler,
|
36 |
)
|
37 |
from transformers.utils import logging, to_py_obj
|
38 |
|
39 |
-
from
|
40 |
|
41 |
# local modules
|
42 |
-
from .data import
|
43 |
from .model import GeneformerMultiTask
|
44 |
-
from .utils import save_model
|
45 |
from .optuna_utils import create_optuna_study
|
46 |
-
from .
|
|
|
1 |
+
import functools
|
2 |
+
import gc
|
3 |
+
import json
|
4 |
+
import os
|
5 |
import pickle
|
6 |
+
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import warnings
|
8 |
from enum import Enum
|
9 |
+
from itertools import chain
|
10 |
from typing import Dict, List, Optional, Union
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
import numpy as np
|
13 |
import optuna
|
14 |
+
import pandas as pd
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import torch.optim as optim
|
19 |
+
from datasets import load_from_disk
|
20 |
+
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
|
21 |
+
from sklearn.model_selection import train_test_split
|
22 |
+
from sklearn.preprocessing import LabelEncoder
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
from transformers import (
|
25 |
+
AdamW,
|
26 |
+
BatchEncoding,
|
27 |
BertConfig,
|
28 |
BertModel,
|
|
|
|
|
|
|
29 |
DataCollatorForTokenClassification,
|
30 |
SpecialTokensMixin,
|
31 |
+
get_cosine_schedule_with_warmup,
|
32 |
+
get_linear_schedule_with_warmup,
|
33 |
get_scheduler,
|
34 |
)
|
35 |
from transformers.utils import logging, to_py_obj
|
36 |
|
37 |
+
from .collators import DataCollatorForMultitaskCellClassification
|
38 |
|
39 |
# local modules
|
40 |
+
from .data import get_data_loader, preload_and_process_data
|
41 |
from .model import GeneformerMultiTask
|
|
|
42 |
from .optuna_utils import create_optuna_study
|
43 |
+
from .utils import save_model
|
geneformer/mtl/model.py
CHANGED
@@ -1,13 +1,17 @@
|
|
1 |
-
from transformers import BertModel, BertConfig
|
2 |
import torch
|
3 |
import torch.nn as nn
|
|
|
|
|
4 |
|
5 |
class AttentionPool(nn.Module):
|
6 |
"""Attention-based pooling layer."""
|
|
|
7 |
def __init__(self, hidden_size):
|
8 |
super(AttentionPool, self).__init__()
|
9 |
self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
|
10 |
-
nn.init.xavier_uniform_(
|
|
|
|
|
11 |
|
12 |
def forward(self, hidden_states):
|
13 |
attention_scores = torch.matmul(hidden_states, self.attention_weights)
|
@@ -15,8 +19,18 @@ class AttentionPool(nn.Module):
|
|
15 |
pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
|
16 |
return pooled_output
|
17 |
|
|
|
18 |
class GeneformerMultiTask(nn.Module):
|
19 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
super(GeneformerMultiTask, self).__init__()
|
21 |
self.config = BertConfig.from_pretrained(pretrained_path)
|
22 |
self.bert = BertModel(self.config)
|
@@ -25,20 +39,31 @@ class GeneformerMultiTask(nn.Module):
|
|
25 |
self.dropout = nn.Dropout(dropout_rate)
|
26 |
self.use_attention_pooling = use_attention_pooling
|
27 |
|
28 |
-
if use_task_weights and (
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# Freeze the specified initial layers
|
33 |
for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
|
34 |
for param in layer.parameters():
|
35 |
param.requires_grad = False
|
36 |
|
37 |
-
self.attention_pool =
|
|
|
|
|
38 |
|
39 |
-
self.classification_heads = nn.ModuleList(
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
42 |
# initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
|
43 |
for head in self.classification_heads:
|
44 |
nn.init.xavier_uniform_(head.weight)
|
@@ -53,7 +78,11 @@ class GeneformerMultiTask(nn.Module):
|
|
53 |
sequence_output = outputs.last_hidden_state
|
54 |
|
55 |
try:
|
56 |
-
pooled_output =
|
|
|
|
|
|
|
|
|
57 |
pooled_output = self.dropout(pooled_output)
|
58 |
except Exception as e:
|
59 |
raise RuntimeError(f"Error during pooling and dropout: {e}")
|
@@ -62,23 +91,31 @@ class GeneformerMultiTask(nn.Module):
|
|
62 |
logits = []
|
63 |
losses = []
|
64 |
|
65 |
-
for task_id, (head, num_labels) in enumerate(
|
|
|
|
|
66 |
try:
|
67 |
task_logits = head(pooled_output)
|
68 |
except Exception as e:
|
69 |
-
raise RuntimeError(
|
|
|
|
|
70 |
|
71 |
logits.append(task_logits)
|
72 |
|
73 |
if labels is not None:
|
74 |
try:
|
75 |
loss_fct = nn.CrossEntropyLoss()
|
76 |
-
task_loss = loss_fct(
|
|
|
|
|
77 |
if self.use_task_weights:
|
78 |
task_loss *= self.task_weights[task_id]
|
79 |
total_loss += task_loss
|
80 |
losses.append(task_loss.item())
|
81 |
except Exception as e:
|
82 |
-
raise RuntimeError(
|
|
|
|
|
83 |
|
84 |
return total_loss, logits, losses if labels is not None else logits
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
+
from transformers import BertConfig, BertModel
|
4 |
+
|
5 |
|
6 |
class AttentionPool(nn.Module):
|
7 |
"""Attention-based pooling layer."""
|
8 |
+
|
9 |
def __init__(self, hidden_size):
|
10 |
super(AttentionPool, self).__init__()
|
11 |
self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
|
12 |
+
nn.init.xavier_uniform_(
|
13 |
+
self.attention_weights
|
14 |
+
) # https://pytorch.org/docs/stable/nn.init.html
|
15 |
|
16 |
def forward(self, hidden_states):
|
17 |
attention_scores = torch.matmul(hidden_states, self.attention_weights)
|
|
|
19 |
pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
|
20 |
return pooled_output
|
21 |
|
22 |
+
|
23 |
class GeneformerMultiTask(nn.Module):
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
pretrained_path,
|
27 |
+
num_labels_list,
|
28 |
+
dropout_rate=0.1,
|
29 |
+
use_task_weights=False,
|
30 |
+
task_weights=None,
|
31 |
+
max_layers_to_freeze=0,
|
32 |
+
use_attention_pooling=False,
|
33 |
+
):
|
34 |
super(GeneformerMultiTask, self).__init__()
|
35 |
self.config = BertConfig.from_pretrained(pretrained_path)
|
36 |
self.bert = BertModel(self.config)
|
|
|
39 |
self.dropout = nn.Dropout(dropout_rate)
|
40 |
self.use_attention_pooling = use_attention_pooling
|
41 |
|
42 |
+
if use_task_weights and (
|
43 |
+
task_weights is None or len(task_weights) != len(num_labels_list)
|
44 |
+
):
|
45 |
+
raise ValueError(
|
46 |
+
"Task weights must be defined and match the number of tasks when 'use_task_weights' is True."
|
47 |
+
)
|
48 |
+
self.task_weights = (
|
49 |
+
task_weights if use_task_weights else [1.0] * len(num_labels_list)
|
50 |
+
)
|
51 |
|
52 |
# Freeze the specified initial layers
|
53 |
for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
|
54 |
for param in layer.parameters():
|
55 |
param.requires_grad = False
|
56 |
|
57 |
+
self.attention_pool = (
|
58 |
+
AttentionPool(self.config.hidden_size) if use_attention_pooling else None
|
59 |
+
)
|
60 |
|
61 |
+
self.classification_heads = nn.ModuleList(
|
62 |
+
[
|
63 |
+
nn.Linear(self.config.hidden_size, num_labels)
|
64 |
+
for num_labels in num_labels_list
|
65 |
+
]
|
66 |
+
)
|
67 |
# initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
|
68 |
for head in self.classification_heads:
|
69 |
nn.init.xavier_uniform_(head.weight)
|
|
|
78 |
sequence_output = outputs.last_hidden_state
|
79 |
|
80 |
try:
|
81 |
+
pooled_output = (
|
82 |
+
self.attention_pool(sequence_output)
|
83 |
+
if self.use_attention_pooling
|
84 |
+
else sequence_output[:, 0, :]
|
85 |
+
)
|
86 |
pooled_output = self.dropout(pooled_output)
|
87 |
except Exception as e:
|
88 |
raise RuntimeError(f"Error during pooling and dropout: {e}")
|
|
|
91 |
logits = []
|
92 |
losses = []
|
93 |
|
94 |
+
for task_id, (head, num_labels) in enumerate(
|
95 |
+
zip(self.classification_heads, self.num_labels_list)
|
96 |
+
):
|
97 |
try:
|
98 |
task_logits = head(pooled_output)
|
99 |
except Exception as e:
|
100 |
+
raise RuntimeError(
|
101 |
+
f"Error during forward pass of classification head {task_id}: {e}"
|
102 |
+
)
|
103 |
|
104 |
logits.append(task_logits)
|
105 |
|
106 |
if labels is not None:
|
107 |
try:
|
108 |
loss_fct = nn.CrossEntropyLoss()
|
109 |
+
task_loss = loss_fct(
|
110 |
+
task_logits.view(-1, num_labels), labels[task_id].view(-1)
|
111 |
+
)
|
112 |
if self.use_task_weights:
|
113 |
task_loss *= self.task_weights[task_id]
|
114 |
total_loss += task_loss
|
115 |
losses.append(task_loss.item())
|
116 |
except Exception as e:
|
117 |
+
raise RuntimeError(
|
118 |
+
f"Error during loss computation for task {task_id}: {e}"
|
119 |
+
)
|
120 |
|
121 |
return total_loss, logits, losses if labels is not None else logits
|
geneformer/mtl/optuna_utils.py
CHANGED
@@ -1,21 +1,27 @@
|
|
1 |
import optuna
|
2 |
from optuna.integration import TensorBoardCallback
|
3 |
|
|
|
4 |
def save_trial_callback(study, trial, trials_result_path):
|
5 |
with open(trials_result_path, "a") as f:
|
6 |
-
f.write(
|
|
|
|
|
|
|
7 |
|
8 |
def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
|
9 |
study = optuna.create_study(direction="maximize")
|
10 |
-
|
11 |
# init TensorBoard callback
|
12 |
-
tensorboard_callback = TensorBoardCallback(
|
13 |
-
|
|
|
|
|
14 |
# callback and TensorBoard callback
|
15 |
callbacks = [
|
16 |
lambda study, trial: save_trial_callback(study, trial, trials_result_path),
|
17 |
-
tensorboard_callback
|
18 |
]
|
19 |
-
|
20 |
study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
|
21 |
-
return study
|
|
|
1 |
import optuna
|
2 |
from optuna.integration import TensorBoardCallback
|
3 |
|
4 |
+
|
5 |
def save_trial_callback(study, trial, trials_result_path):
|
6 |
with open(trials_result_path, "a") as f:
|
7 |
+
f.write(
|
8 |
+
f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
|
9 |
+
)
|
10 |
+
|
11 |
|
12 |
def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
|
13 |
study = optuna.create_study(direction="maximize")
|
14 |
+
|
15 |
# init TensorBoard callback
|
16 |
+
tensorboard_callback = TensorBoardCallback(
|
17 |
+
dirname=tensorboard_log_dir, metric_name="F1 Macro"
|
18 |
+
)
|
19 |
+
|
20 |
# callback and TensorBoard callback
|
21 |
callbacks = [
|
22 |
lambda study, trial: save_trial_callback(study, trial, trials_result_path),
|
23 |
+
tensorboard_callback,
|
24 |
]
|
25 |
+
|
26 |
study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
|
27 |
+
return study
|
geneformer/mtl/train.py
CHANGED
@@ -1,14 +1,15 @@
|
|
1 |
-
from .imports import *
|
2 |
-
from .data import preload_and_process_data, get_data_loader
|
3 |
-
from .model import GeneformerMultiTask
|
4 |
-
from .utils import calculate_task_specific_metrics
|
5 |
-
from torch.utils.tensorboard import SummaryWriter
|
6 |
-
import pandas as pd
|
7 |
import os
|
8 |
-
from tqdm import tqdm
|
9 |
import random
|
|
|
10 |
import numpy as np
|
|
|
11 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
def set_seed(seed):
|
@@ -19,13 +20,18 @@ def set_seed(seed):
|
|
19 |
torch.backends.cudnn.deterministic = True
|
20 |
torch.backends.cudnn.benchmark = False
|
21 |
|
|
|
22 |
def initialize_wandb(config):
|
23 |
if config.get("use_wandb", False):
|
24 |
import wandb
|
|
|
25 |
wandb.init(project=config["wandb_project"], config=config)
|
26 |
print("Weights & Biases (wandb) initialized and will be used for logging.")
|
27 |
else:
|
28 |
-
print(
|
|
|
|
|
|
|
29 |
|
30 |
def create_model(config, num_labels_list, device):
|
31 |
model = GeneformerMultiTask(
|
@@ -35,31 +41,48 @@ def create_model(config, num_labels_list, device):
|
|
35 |
use_task_weights=config["use_task_weights"],
|
36 |
task_weights=config["task_weights"],
|
37 |
max_layers_to_freeze=config["max_layers_to_freeze"],
|
38 |
-
use_attention_pooling=config["use_attention_pooling"]
|
39 |
)
|
40 |
if config["use_data_parallel"]:
|
41 |
model = nn.DataParallel(model)
|
42 |
return model.to(device)
|
43 |
|
|
|
44 |
def setup_optimizer_and_scheduler(model, config, total_steps):
|
45 |
-
optimizer = AdamW(
|
|
|
|
|
|
|
|
|
46 |
warmup_steps = int(config["warmup_ratio"] * total_steps)
|
47 |
|
48 |
if config["lr_scheduler_type"] == "linear":
|
49 |
-
scheduler = get_linear_schedule_with_warmup(
|
|
|
|
|
50 |
elif config["lr_scheduler_type"] == "cosine":
|
51 |
-
scheduler = get_cosine_schedule_with_warmup(
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
return optimizer, scheduler
|
54 |
|
55 |
-
|
|
|
|
|
|
|
56 |
model.train()
|
57 |
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
|
58 |
for batch_idx, batch in enumerate(progress_bar):
|
59 |
optimizer.zero_grad()
|
60 |
-
input_ids = batch[
|
61 |
-
attention_mask = batch[
|
62 |
-
labels = [
|
|
|
|
|
63 |
|
64 |
loss, _, _ = model(input_ids, attention_mask, labels)
|
65 |
loss.backward()
|
@@ -70,15 +93,20 @@ def train_epoch(model, train_loader, optimizer, scheduler, device, config, write
|
|
70 |
optimizer.step()
|
71 |
scheduler.step()
|
72 |
|
73 |
-
writer.add_scalar(
|
|
|
|
|
74 |
if config.get("use_wandb", False):
|
75 |
-
wandb
|
|
|
|
|
76 |
|
77 |
# Update progress bar
|
78 |
-
progress_bar.set_postfix({
|
79 |
|
80 |
return loss.item() # Return the last batch loss
|
81 |
|
|
|
82 |
def validate_model(model, val_loader, device, config):
|
83 |
model.eval()
|
84 |
val_loss = 0.0
|
@@ -88,17 +116,22 @@ def validate_model(model, val_loader, device, config):
|
|
88 |
|
89 |
with torch.no_grad():
|
90 |
for batch in val_loader:
|
91 |
-
input_ids = batch[
|
92 |
-
attention_mask = batch[
|
93 |
-
labels = [
|
|
|
|
|
|
|
94 |
loss, logits, _ = model(input_ids, attention_mask, labels)
|
95 |
val_loss += loss.item()
|
96 |
|
97 |
-
for sample_idx in range(len(batch[
|
98 |
for i, task_name in enumerate(config["task_names"]):
|
99 |
-
true_label = batch[
|
100 |
pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
|
101 |
-
pred_prob =
|
|
|
|
|
102 |
task_true_labels[task_name].append(true_label)
|
103 |
task_pred_labels[task_name].append(pred_label)
|
104 |
task_pred_probs[task_name].append(pred_prob)
|
@@ -106,44 +139,70 @@ def validate_model(model, val_loader, device, config):
|
|
106 |
val_loss /= len(val_loader)
|
107 |
return val_loss, task_true_labels, task_pred_labels, task_pred_probs
|
108 |
|
|
|
109 |
def log_metrics(task_metrics, val_loss, config, writer, epochs):
|
110 |
for task_name, metrics in task_metrics.items():
|
111 |
-
print(
|
|
|
|
|
112 |
if config.get("use_wandb", False):
|
113 |
import wandb
|
114 |
-
wandb.log({
|
115 |
-
f'{task_name} Validation F1 Macro': metrics['f1'],
|
116 |
-
f'{task_name} Validation Accuracy': metrics['accuracy']
|
117 |
-
})
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
123 |
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
if trial_number is not None:
|
126 |
trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
|
127 |
os.makedirs(trial_results_dir, exist_ok=True)
|
128 |
val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
|
129 |
else:
|
130 |
val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
|
131 |
-
|
132 |
rows = []
|
133 |
for sample_idx in range(len(val_cell_id_mapping)):
|
134 |
-
row = {
|
135 |
for task_name in config["task_names"]:
|
136 |
-
row[f
|
137 |
-
row[f
|
138 |
-
row[f
|
|
|
|
|
139 |
rows.append(row)
|
140 |
-
|
141 |
df = pd.DataFrame(rows)
|
142 |
df.to_csv(val_preds_file, index=False)
|
143 |
print(f"Validation predictions saved to {val_preds_file}")
|
144 |
|
145 |
|
146 |
-
def train_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
set_seed(config["seed"])
|
148 |
initialize_wandb(config)
|
149 |
|
@@ -156,46 +215,97 @@ def train_model(config, device, train_loader, val_loader, train_cell_id_mapping,
|
|
156 |
|
157 |
epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
|
158 |
for epoch in epoch_progress:
|
159 |
-
last_loss = train_epoch(
|
160 |
-
|
|
|
|
|
161 |
|
162 |
-
val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
|
|
|
|
|
163 |
task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
|
164 |
-
|
165 |
log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
|
166 |
writer.close()
|
167 |
|
168 |
-
save_validation_predictions(
|
|
|
|
|
169 |
|
170 |
if config.get("use_wandb", False):
|
171 |
import wandb
|
|
|
172 |
wandb.finish()
|
173 |
|
174 |
print(f"\nFinal Validation Loss: {val_loss:.4f}")
|
175 |
return val_loss, model # Return both the validation loss and the trained model
|
176 |
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
set_seed(config["seed"]) # Set the seed before each trial
|
179 |
initialize_wandb(config)
|
180 |
|
181 |
# Hyperparameters
|
182 |
-
config["learning_rate"] = trial.suggest_float(
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
if config["use_task_weights"]:
|
190 |
-
config["task_weights"] = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
weight_sum = sum(config["task_weights"])
|
192 |
-
config["task_weights"] = [
|
|
|
|
|
193 |
else:
|
194 |
config["task_weights"] = None
|
195 |
|
196 |
# Fix for max_layers_to_freeze
|
197 |
if isinstance(config["max_layers_to_freeze"], dict):
|
198 |
-
config["max_layers_to_freeze"] = trial.suggest_int(
|
|
|
|
|
|
|
|
|
199 |
elif isinstance(config["max_layers_to_freeze"], int):
|
200 |
# If it's already an int, we don't need to suggest it
|
201 |
pass
|
@@ -210,15 +320,26 @@ def objective(trial, train_loader, val_loader, train_cell_id_mapping, val_cell_i
|
|
210 |
writer = SummaryWriter(log_dir=log_dir)
|
211 |
|
212 |
for epoch in range(config["epochs"]):
|
213 |
-
train_epoch(
|
|
|
|
|
214 |
|
215 |
-
val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
|
|
|
|
|
216 |
task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
|
217 |
-
|
218 |
log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
|
219 |
writer.close()
|
220 |
|
221 |
-
save_validation_predictions(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
|
223 |
trial.set_user_attr("model_state_dict", model.state_dict())
|
224 |
trial.set_user_attr("task_weights", config["task_weights"])
|
@@ -230,13 +351,35 @@ def objective(trial, train_loader, val_loader, train_cell_id_mapping, val_cell_i
|
|
230 |
|
231 |
if config.get("use_wandb", False):
|
232 |
import wandb
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
wandb.finish()
|
241 |
|
242 |
-
return val_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
|
|
2 |
import random
|
3 |
+
|
4 |
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
import torch
|
7 |
+
from torch.utils.tensorboard import SummaryWriter
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from .imports import *
|
11 |
+
from .model import GeneformerMultiTask
|
12 |
+
from .utils import calculate_task_specific_metrics
|
13 |
|
14 |
|
15 |
def set_seed(seed):
|
|
|
20 |
torch.backends.cudnn.deterministic = True
|
21 |
torch.backends.cudnn.benchmark = False
|
22 |
|
23 |
+
|
24 |
def initialize_wandb(config):
|
25 |
if config.get("use_wandb", False):
|
26 |
import wandb
|
27 |
+
|
28 |
wandb.init(project=config["wandb_project"], config=config)
|
29 |
print("Weights & Biases (wandb) initialized and will be used for logging.")
|
30 |
else:
|
31 |
+
print(
|
32 |
+
"Weights & Biases (wandb) is not enabled. Logging will use other methods."
|
33 |
+
)
|
34 |
+
|
35 |
|
36 |
def create_model(config, num_labels_list, device):
|
37 |
model = GeneformerMultiTask(
|
|
|
41 |
use_task_weights=config["use_task_weights"],
|
42 |
task_weights=config["task_weights"],
|
43 |
max_layers_to_freeze=config["max_layers_to_freeze"],
|
44 |
+
use_attention_pooling=config["use_attention_pooling"],
|
45 |
)
|
46 |
if config["use_data_parallel"]:
|
47 |
model = nn.DataParallel(model)
|
48 |
return model.to(device)
|
49 |
|
50 |
+
|
51 |
def setup_optimizer_and_scheduler(model, config, total_steps):
|
52 |
+
optimizer = AdamW(
|
53 |
+
model.parameters(),
|
54 |
+
lr=config["learning_rate"],
|
55 |
+
weight_decay=config["weight_decay"],
|
56 |
+
)
|
57 |
warmup_steps = int(config["warmup_ratio"] * total_steps)
|
58 |
|
59 |
if config["lr_scheduler_type"] == "linear":
|
60 |
+
scheduler = get_linear_schedule_with_warmup(
|
61 |
+
optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
|
62 |
+
)
|
63 |
elif config["lr_scheduler_type"] == "cosine":
|
64 |
+
scheduler = get_cosine_schedule_with_warmup(
|
65 |
+
optimizer,
|
66 |
+
num_warmup_steps=warmup_steps,
|
67 |
+
num_training_steps=total_steps,
|
68 |
+
num_cycles=0.5,
|
69 |
+
)
|
70 |
|
71 |
return optimizer, scheduler
|
72 |
|
73 |
+
|
74 |
+
def train_epoch(
|
75 |
+
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
76 |
+
):
|
77 |
model.train()
|
78 |
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
|
79 |
for batch_idx, batch in enumerate(progress_bar):
|
80 |
optimizer.zero_grad()
|
81 |
+
input_ids = batch["input_ids"].to(device)
|
82 |
+
attention_mask = batch["attention_mask"].to(device)
|
83 |
+
labels = [
|
84 |
+
batch["labels"][task_name].to(device) for task_name in config["task_names"]
|
85 |
+
]
|
86 |
|
87 |
loss, _, _ = model(input_ids, attention_mask, labels)
|
88 |
loss.backward()
|
|
|
93 |
optimizer.step()
|
94 |
scheduler.step()
|
95 |
|
96 |
+
writer.add_scalar(
|
97 |
+
"Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
|
98 |
+
)
|
99 |
if config.get("use_wandb", False):
|
100 |
+
import wandb
|
101 |
+
|
102 |
+
wandb.log({"Training Loss": loss.item()})
|
103 |
|
104 |
# Update progress bar
|
105 |
+
progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
|
106 |
|
107 |
return loss.item() # Return the last batch loss
|
108 |
|
109 |
+
|
110 |
def validate_model(model, val_loader, device, config):
|
111 |
model.eval()
|
112 |
val_loss = 0.0
|
|
|
116 |
|
117 |
with torch.no_grad():
|
118 |
for batch in val_loader:
|
119 |
+
input_ids = batch["input_ids"].to(device)
|
120 |
+
attention_mask = batch["attention_mask"].to(device)
|
121 |
+
labels = [
|
122 |
+
batch["labels"][task_name].to(device)
|
123 |
+
for task_name in config["task_names"]
|
124 |
+
]
|
125 |
loss, logits, _ = model(input_ids, attention_mask, labels)
|
126 |
val_loss += loss.item()
|
127 |
|
128 |
+
for sample_idx in range(len(batch["input_ids"])):
|
129 |
for i, task_name in enumerate(config["task_names"]):
|
130 |
+
true_label = batch["labels"][task_name][sample_idx].item()
|
131 |
pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
|
132 |
+
pred_prob = (
|
133 |
+
torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
|
134 |
+
)
|
135 |
task_true_labels[task_name].append(true_label)
|
136 |
task_pred_labels[task_name].append(pred_label)
|
137 |
task_pred_probs[task_name].append(pred_prob)
|
|
|
139 |
val_loss /= len(val_loader)
|
140 |
return val_loss, task_true_labels, task_pred_labels, task_pred_probs
|
141 |
|
142 |
+
|
143 |
def log_metrics(task_metrics, val_loss, config, writer, epochs):
|
144 |
for task_name, metrics in task_metrics.items():
|
145 |
+
print(
|
146 |
+
f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
|
147 |
+
)
|
148 |
if config.get("use_wandb", False):
|
149 |
import wandb
|
|
|
|
|
|
|
|
|
150 |
|
151 |
+
wandb.log(
|
152 |
+
{
|
153 |
+
f"{task_name} Validation F1 Macro": metrics["f1"],
|
154 |
+
f"{task_name} Validation Accuracy": metrics["accuracy"],
|
155 |
+
}
|
156 |
+
)
|
157 |
|
158 |
+
writer.add_scalar("Validation Loss", val_loss, epochs)
|
159 |
+
for task_name, metrics in task_metrics.items():
|
160 |
+
writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
|
161 |
+
writer.add_scalar(
|
162 |
+
f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
|
163 |
+
)
|
164 |
+
|
165 |
+
|
166 |
+
def save_validation_predictions(
|
167 |
+
val_cell_id_mapping,
|
168 |
+
task_true_labels,
|
169 |
+
task_pred_labels,
|
170 |
+
task_pred_probs,
|
171 |
+
config,
|
172 |
+
trial_number=None,
|
173 |
+
):
|
174 |
if trial_number is not None:
|
175 |
trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
|
176 |
os.makedirs(trial_results_dir, exist_ok=True)
|
177 |
val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
|
178 |
else:
|
179 |
val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
|
180 |
+
|
181 |
rows = []
|
182 |
for sample_idx in range(len(val_cell_id_mapping)):
|
183 |
+
row = {"Cell ID": val_cell_id_mapping[sample_idx]}
|
184 |
for task_name in config["task_names"]:
|
185 |
+
row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
|
186 |
+
row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
|
187 |
+
row[f"{task_name} Probabilities"] = ",".join(
|
188 |
+
map(str, task_pred_probs[task_name][sample_idx])
|
189 |
+
)
|
190 |
rows.append(row)
|
191 |
+
|
192 |
df = pd.DataFrame(rows)
|
193 |
df.to_csv(val_preds_file, index=False)
|
194 |
print(f"Validation predictions saved to {val_preds_file}")
|
195 |
|
196 |
|
197 |
+
def train_model(
|
198 |
+
config,
|
199 |
+
device,
|
200 |
+
train_loader,
|
201 |
+
val_loader,
|
202 |
+
train_cell_id_mapping,
|
203 |
+
val_cell_id_mapping,
|
204 |
+
num_labels_list,
|
205 |
+
):
|
206 |
set_seed(config["seed"])
|
207 |
initialize_wandb(config)
|
208 |
|
|
|
215 |
|
216 |
epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
|
217 |
for epoch in epoch_progress:
|
218 |
+
last_loss = train_epoch(
|
219 |
+
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
220 |
+
)
|
221 |
+
epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
|
222 |
|
223 |
+
val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
|
224 |
+
model, val_loader, device, config
|
225 |
+
)
|
226 |
task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
|
227 |
+
|
228 |
log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
|
229 |
writer.close()
|
230 |
|
231 |
+
save_validation_predictions(
|
232 |
+
val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
|
233 |
+
)
|
234 |
|
235 |
if config.get("use_wandb", False):
|
236 |
import wandb
|
237 |
+
|
238 |
wandb.finish()
|
239 |
|
240 |
print(f"\nFinal Validation Loss: {val_loss:.4f}")
|
241 |
return val_loss, model # Return both the validation loss and the trained model
|
242 |
|
243 |
+
|
244 |
+
def objective(
|
245 |
+
trial,
|
246 |
+
train_loader,
|
247 |
+
val_loader,
|
248 |
+
train_cell_id_mapping,
|
249 |
+
val_cell_id_mapping,
|
250 |
+
num_labels_list,
|
251 |
+
config,
|
252 |
+
device,
|
253 |
+
):
|
254 |
set_seed(config["seed"]) # Set the seed before each trial
|
255 |
initialize_wandb(config)
|
256 |
|
257 |
# Hyperparameters
|
258 |
+
config["learning_rate"] = trial.suggest_float(
|
259 |
+
"learning_rate",
|
260 |
+
config["hyperparameters"]["learning_rate"]["low"],
|
261 |
+
config["hyperparameters"]["learning_rate"]["high"],
|
262 |
+
log=config["hyperparameters"]["learning_rate"]["log"],
|
263 |
+
)
|
264 |
+
config["warmup_ratio"] = trial.suggest_float(
|
265 |
+
"warmup_ratio",
|
266 |
+
config["hyperparameters"]["warmup_ratio"]["low"],
|
267 |
+
config["hyperparameters"]["warmup_ratio"]["high"],
|
268 |
+
)
|
269 |
+
config["weight_decay"] = trial.suggest_float(
|
270 |
+
"weight_decay",
|
271 |
+
config["hyperparameters"]["weight_decay"]["low"],
|
272 |
+
config["hyperparameters"]["weight_decay"]["high"],
|
273 |
+
)
|
274 |
+
config["dropout_rate"] = trial.suggest_float(
|
275 |
+
"dropout_rate",
|
276 |
+
config["hyperparameters"]["dropout_rate"]["low"],
|
277 |
+
config["hyperparameters"]["dropout_rate"]["high"],
|
278 |
+
)
|
279 |
+
config["lr_scheduler_type"] = trial.suggest_categorical(
|
280 |
+
"lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
|
281 |
+
)
|
282 |
+
config["use_attention_pooling"] = trial.suggest_categorical(
|
283 |
+
"use_attention_pooling", [True, False]
|
284 |
+
)
|
285 |
|
286 |
if config["use_task_weights"]:
|
287 |
+
config["task_weights"] = [
|
288 |
+
trial.suggest_float(
|
289 |
+
f"task_weight_{i}",
|
290 |
+
config["hyperparameters"]["task_weights"]["low"],
|
291 |
+
config["hyperparameters"]["task_weights"]["high"],
|
292 |
+
)
|
293 |
+
for i in range(len(num_labels_list))
|
294 |
+
]
|
295 |
weight_sum = sum(config["task_weights"])
|
296 |
+
config["task_weights"] = [
|
297 |
+
weight / weight_sum for weight in config["task_weights"]
|
298 |
+
]
|
299 |
else:
|
300 |
config["task_weights"] = None
|
301 |
|
302 |
# Fix for max_layers_to_freeze
|
303 |
if isinstance(config["max_layers_to_freeze"], dict):
|
304 |
+
config["max_layers_to_freeze"] = trial.suggest_int(
|
305 |
+
"max_layers_to_freeze",
|
306 |
+
config["max_layers_to_freeze"]["min"],
|
307 |
+
config["max_layers_to_freeze"]["max"],
|
308 |
+
)
|
309 |
elif isinstance(config["max_layers_to_freeze"], int):
|
310 |
# If it's already an int, we don't need to suggest it
|
311 |
pass
|
|
|
320 |
writer = SummaryWriter(log_dir=log_dir)
|
321 |
|
322 |
for epoch in range(config["epochs"]):
|
323 |
+
train_epoch(
|
324 |
+
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
325 |
+
)
|
326 |
|
327 |
+
val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
|
328 |
+
model, val_loader, device, config
|
329 |
+
)
|
330 |
task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
|
331 |
+
|
332 |
log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
|
333 |
writer.close()
|
334 |
|
335 |
+
save_validation_predictions(
|
336 |
+
val_cell_id_mapping,
|
337 |
+
task_true_labels,
|
338 |
+
task_pred_labels,
|
339 |
+
task_pred_probs,
|
340 |
+
config,
|
341 |
+
trial.number,
|
342 |
+
)
|
343 |
|
344 |
trial.set_user_attr("model_state_dict", model.state_dict())
|
345 |
trial.set_user_attr("task_weights", config["task_weights"])
|
|
|
351 |
|
352 |
if config.get("use_wandb", False):
|
353 |
import wandb
|
354 |
+
|
355 |
+
wandb.log(
|
356 |
+
{
|
357 |
+
"trial_number": trial.number,
|
358 |
+
"val_loss": val_loss,
|
359 |
+
**{
|
360 |
+
f"{task_name}_f1": metrics["f1"]
|
361 |
+
for task_name, metrics in task_metrics.items()
|
362 |
+
},
|
363 |
+
**{
|
364 |
+
f"{task_name}_accuracy": metrics["accuracy"]
|
365 |
+
for task_name, metrics in task_metrics.items()
|
366 |
+
},
|
367 |
+
**{
|
368 |
+
k: v
|
369 |
+
for k, v in config.items()
|
370 |
+
if k
|
371 |
+
in [
|
372 |
+
"learning_rate",
|
373 |
+
"warmup_ratio",
|
374 |
+
"weight_decay",
|
375 |
+
"dropout_rate",
|
376 |
+
"lr_scheduler_type",
|
377 |
+
"use_attention_pooling",
|
378 |
+
"max_layers_to_freeze",
|
379 |
+
]
|
380 |
+
},
|
381 |
+
}
|
382 |
+
)
|
383 |
wandb.finish()
|
384 |
|
385 |
+
return val_loss
|
geneformer/mtl/train_utils.py
CHANGED
@@ -1,9 +1,11 @@
|
|
|
|
|
|
|
|
1 |
from .imports import *
|
2 |
-
from .data import preload_and_process_data, get_data_loader
|
3 |
-
from .train import objective, train_model
|
4 |
from .model import GeneformerMultiTask
|
|
|
5 |
from .utils import save_model
|
6 |
-
|
7 |
|
8 |
def set_seed(seed):
|
9 |
random.seed(seed)
|
@@ -12,15 +14,22 @@ def set_seed(seed):
|
|
12 |
torch.cuda.manual_seed_all(seed)
|
13 |
torch.backends.cudnn.deterministic = True
|
14 |
torch.backends.cudnn.benchmark = False
|
15 |
-
|
|
|
16 |
def run_manual_tuning(config):
|
17 |
# Set seed for reproducibility
|
18 |
set_seed(config["seed"])
|
19 |
|
20 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
# Print the manual hyperparameters being used
|
26 |
print("\nManual hyperparameters being used:")
|
@@ -33,12 +42,22 @@ def run_manual_tuning(config):
|
|
33 |
config[key] = value
|
34 |
|
35 |
# Train the model
|
36 |
-
val_loss, trained_model = train_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
print(f"\nValidation loss with manual hyperparameters: {val_loss}")
|
39 |
|
40 |
# Save the trained model
|
41 |
-
model_save_directory = os.path.join(
|
|
|
|
|
42 |
save_model(trained_model, model_save_directory)
|
43 |
|
44 |
# Save the hyperparameters
|
@@ -48,26 +67,41 @@ def run_manual_tuning(config):
|
|
48 |
"use_task_weights": config["use_task_weights"],
|
49 |
"task_weights": config["task_weights"],
|
50 |
"max_layers_to_freeze": config["max_layers_to_freeze"],
|
51 |
-
"use_attention_pooling": config["use_attention_pooling"]
|
52 |
}
|
53 |
hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
|
54 |
-
with open(hyperparams_path,
|
55 |
json.dump(hyperparams_to_save, f)
|
56 |
print(f"Manual hyperparameters saved to {hyperparams_path}")
|
57 |
|
58 |
return val_loss
|
59 |
|
|
|
60 |
def run_optuna_study(config):
|
61 |
# Set seed for reproducibility
|
62 |
set_seed(config["seed"])
|
63 |
|
64 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
if config["use_manual_hyperparameters"]:
|
70 |
-
train_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
else:
|
72 |
objective_with_config_and_data = functools.partial(
|
73 |
objective,
|
@@ -77,20 +111,17 @@ def run_optuna_study(config):
|
|
77 |
val_cell_id_mapping=val_cell_id_mapping,
|
78 |
num_labels_list=num_labels_list,
|
79 |
config=config,
|
80 |
-
device=device
|
81 |
)
|
82 |
|
83 |
study = optuna.create_study(
|
84 |
-
direction=
|
85 |
study_name=config["study_name"],
|
86 |
-
#storage=config["storage"],
|
87 |
-
load_if_exists=True
|
88 |
)
|
89 |
|
90 |
-
study.optimize(
|
91 |
-
objective_with_config_and_data,
|
92 |
-
n_trials=config["n_trials"]
|
93 |
-
)
|
94 |
|
95 |
# After finding the best trial
|
96 |
best_params = study.best_trial.params
|
@@ -103,24 +134,28 @@ def run_optuna_study(config):
|
|
103 |
num_labels_list,
|
104 |
dropout_rate=best_params["dropout_rate"],
|
105 |
use_task_weights=config["use_task_weights"],
|
106 |
-
task_weights=best_task_weights
|
107 |
)
|
108 |
-
|
109 |
# Get the best model state dictionary
|
110 |
best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
|
111 |
|
112 |
# Remove the "module." prefix from the state dictionary keys if present
|
113 |
-
best_model_state_dict = {
|
|
|
|
|
114 |
|
115 |
# Load the modified state dictionary into the model, skipping unexpected keys
|
116 |
best_model.load_state_dict(best_model_state_dict, strict=False)
|
117 |
|
118 |
-
model_save_directory = os.path.join(
|
|
|
|
|
119 |
save_model(best_model, model_save_directory)
|
120 |
|
121 |
# Additionally, save the best hyperparameters and task weights
|
122 |
hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
|
123 |
-
|
124 |
-
with open(hyperparams_path,
|
125 |
json.dump({**best_params, "task_weights": best_task_weights}, f)
|
126 |
-
print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
from .data import get_data_loader, preload_and_process_data
|
4 |
from .imports import *
|
|
|
|
|
5 |
from .model import GeneformerMultiTask
|
6 |
+
from .train import objective, train_model
|
7 |
from .utils import save_model
|
8 |
+
|
9 |
|
10 |
def set_seed(seed):
|
11 |
random.seed(seed)
|
|
|
14 |
torch.cuda.manual_seed_all(seed)
|
15 |
torch.backends.cudnn.deterministic = True
|
16 |
torch.backends.cudnn.benchmark = False
|
17 |
+
|
18 |
+
|
19 |
def run_manual_tuning(config):
|
20 |
# Set seed for reproducibility
|
21 |
set_seed(config["seed"])
|
22 |
|
23 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
+
(
|
25 |
+
train_dataset,
|
26 |
+
train_cell_id_mapping,
|
27 |
+
val_dataset,
|
28 |
+
val_cell_id_mapping,
|
29 |
+
num_labels_list,
|
30 |
+
) = preload_and_process_data(config)
|
31 |
+
train_loader = get_data_loader(train_dataset, config["batch_size"])
|
32 |
+
val_loader = get_data_loader(val_dataset, config["batch_size"])
|
33 |
|
34 |
# Print the manual hyperparameters being used
|
35 |
print("\nManual hyperparameters being used:")
|
|
|
42 |
config[key] = value
|
43 |
|
44 |
# Train the model
|
45 |
+
val_loss, trained_model = train_model(
|
46 |
+
config,
|
47 |
+
device,
|
48 |
+
train_loader,
|
49 |
+
val_loader,
|
50 |
+
train_cell_id_mapping,
|
51 |
+
val_cell_id_mapping,
|
52 |
+
num_labels_list,
|
53 |
+
)
|
54 |
|
55 |
print(f"\nValidation loss with manual hyperparameters: {val_loss}")
|
56 |
|
57 |
# Save the trained model
|
58 |
+
model_save_directory = os.path.join(
|
59 |
+
config["model_save_path"], "GeneformerMultiTask"
|
60 |
+
)
|
61 |
save_model(trained_model, model_save_directory)
|
62 |
|
63 |
# Save the hyperparameters
|
|
|
67 |
"use_task_weights": config["use_task_weights"],
|
68 |
"task_weights": config["task_weights"],
|
69 |
"max_layers_to_freeze": config["max_layers_to_freeze"],
|
70 |
+
"use_attention_pooling": config["use_attention_pooling"],
|
71 |
}
|
72 |
hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
|
73 |
+
with open(hyperparams_path, "w") as f:
|
74 |
json.dump(hyperparams_to_save, f)
|
75 |
print(f"Manual hyperparameters saved to {hyperparams_path}")
|
76 |
|
77 |
return val_loss
|
78 |
|
79 |
+
|
80 |
def run_optuna_study(config):
|
81 |
# Set seed for reproducibility
|
82 |
set_seed(config["seed"])
|
83 |
|
84 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
85 |
+
(
|
86 |
+
train_dataset,
|
87 |
+
train_cell_id_mapping,
|
88 |
+
val_dataset,
|
89 |
+
val_cell_id_mapping,
|
90 |
+
num_labels_list,
|
91 |
+
) = preload_and_process_data(config)
|
92 |
+
train_loader = get_data_loader(train_dataset, config["batch_size"])
|
93 |
+
val_loader = get_data_loader(val_dataset, config["batch_size"])
|
94 |
|
95 |
if config["use_manual_hyperparameters"]:
|
96 |
+
train_model(
|
97 |
+
config,
|
98 |
+
device,
|
99 |
+
train_loader,
|
100 |
+
val_loader,
|
101 |
+
train_cell_id_mapping,
|
102 |
+
val_cell_id_mapping,
|
103 |
+
num_labels_list,
|
104 |
+
)
|
105 |
else:
|
106 |
objective_with_config_and_data = functools.partial(
|
107 |
objective,
|
|
|
111 |
val_cell_id_mapping=val_cell_id_mapping,
|
112 |
num_labels_list=num_labels_list,
|
113 |
config=config,
|
114 |
+
device=device,
|
115 |
)
|
116 |
|
117 |
study = optuna.create_study(
|
118 |
+
direction="minimize", # Minimize validation loss
|
119 |
study_name=config["study_name"],
|
120 |
+
# storage=config["storage"],
|
121 |
+
load_if_exists=True,
|
122 |
)
|
123 |
|
124 |
+
study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])
|
|
|
|
|
|
|
125 |
|
126 |
# After finding the best trial
|
127 |
best_params = study.best_trial.params
|
|
|
134 |
num_labels_list,
|
135 |
dropout_rate=best_params["dropout_rate"],
|
136 |
use_task_weights=config["use_task_weights"],
|
137 |
+
task_weights=best_task_weights,
|
138 |
)
|
139 |
+
|
140 |
# Get the best model state dictionary
|
141 |
best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
|
142 |
|
143 |
# Remove the "module." prefix from the state dictionary keys if present
|
144 |
+
best_model_state_dict = {
|
145 |
+
k.replace("module.", ""): v for k, v in best_model_state_dict.items()
|
146 |
+
}
|
147 |
|
148 |
# Load the modified state dictionary into the model, skipping unexpected keys
|
149 |
best_model.load_state_dict(best_model_state_dict, strict=False)
|
150 |
|
151 |
+
model_save_directory = os.path.join(
|
152 |
+
config["model_save_path"], "GeneformerMultiTask"
|
153 |
+
)
|
154 |
save_model(best_model, model_save_directory)
|
155 |
|
156 |
# Additionally, save the best hyperparameters and task weights
|
157 |
hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
|
158 |
+
|
159 |
+
with open(hyperparams_path, "w") as f:
|
160 |
json.dump({**best_params, "task_weights": best_task_weights}, f)
|
161 |
+
print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
|
geneformer/mtl/utils.py
CHANGED
@@ -1,44 +1,55 @@
|
|
1 |
-
from .imports import *
|
2 |
-
from sklearn.metrics import f1_score, accuracy_score
|
3 |
-
from sklearn.preprocessing import LabelEncoder
|
4 |
-
from transformers import BertModel, BertConfig, AutoConfig
|
5 |
import os
|
6 |
import shutil
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
def save_model(model, model_save_directory):
|
9 |
if not os.path.exists(model_save_directory):
|
10 |
os.makedirs(model_save_directory)
|
11 |
-
|
12 |
# Get the state dict
|
13 |
if isinstance(model, nn.DataParallel):
|
14 |
-
model_state_dict =
|
|
|
|
|
15 |
else:
|
16 |
model_state_dict = model.state_dict()
|
17 |
-
|
18 |
# Remove the "module." prefix from the keys if present
|
19 |
-
model_state_dict = {
|
20 |
-
|
|
|
|
|
21 |
model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
|
22 |
torch.save(model_state_dict, model_save_path)
|
23 |
-
|
24 |
# Save the model configuration
|
25 |
if isinstance(model, nn.DataParallel):
|
26 |
-
model.module.config.to_json_file(
|
|
|
|
|
27 |
else:
|
28 |
model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
|
29 |
-
|
30 |
print(f"Model and configuration saved to {model_save_directory}")
|
31 |
|
|
|
32 |
def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
|
33 |
task_metrics = {}
|
34 |
for task_name in task_true_labels.keys():
|
35 |
true_labels = task_true_labels[task_name]
|
36 |
pred_labels = task_pred_labels[task_name]
|
37 |
-
f1 = f1_score(true_labels, pred_labels, average=
|
38 |
accuracy = accuracy_score(true_labels, pred_labels)
|
39 |
-
task_metrics[task_name] = {
|
40 |
return task_metrics
|
41 |
|
|
|
42 |
def calculate_combined_f1(combined_labels, combined_preds):
|
43 |
# Initialize the LabelEncoder
|
44 |
le = LabelEncoder()
|
@@ -57,10 +68,11 @@ def calculate_combined_f1(combined_labels, combined_preds):
|
|
57 |
accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
|
58 |
|
59 |
# Calculate F1 Macro score
|
60 |
-
f1 = f1_score(encoded_true_labels, encoded_pred_labels, average=
|
61 |
|
62 |
return f1, accuracy
|
63 |
|
|
|
64 |
def save_model_without_heads(original_model_save_directory):
|
65 |
# Create a new directory for the model without heads
|
66 |
new_model_save_directory = original_model_save_directory + "_No_Heads"
|
@@ -68,25 +80,36 @@ def save_model_without_heads(original_model_save_directory):
|
|
68 |
os.makedirs(new_model_save_directory)
|
69 |
|
70 |
# Load the model state dictionary
|
71 |
-
model_state_dict = torch.load(
|
|
|
|
|
72 |
|
73 |
# Initialize a new BERT model without the classification heads
|
74 |
-
config = BertConfig.from_pretrained(
|
|
|
|
|
75 |
model_without_heads = BertModel(config)
|
76 |
-
|
77 |
# Filter the state dict to exclude classification heads
|
78 |
-
model_without_heads_state_dict = {
|
79 |
-
|
|
|
|
|
|
|
|
|
80 |
# Load the filtered state dict into the model
|
81 |
model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
|
82 |
-
|
83 |
# Save the model without heads
|
84 |
model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
|
85 |
torch.save(model_without_heads.state_dict(), model_save_path)
|
86 |
-
|
87 |
# Copy the configuration file
|
88 |
-
shutil.copy(
|
89 |
-
|
|
|
|
|
|
|
90 |
print(f"Model without classification heads saved to {new_model_save_directory}")
|
91 |
|
92 |
|
@@ -100,7 +123,7 @@ def get_layer_freeze_range(pretrained_path):
|
|
100 |
"""
|
101 |
if pretrained_path:
|
102 |
config = AutoConfig.from_pretrained(pretrained_path)
|
103 |
-
total_layers = config.num_hidden_layers
|
104 |
-
return {"min": 0, "max": total_layers - 1}
|
105 |
else:
|
106 |
-
return {"min": 0, "max": 0}
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import shutil
|
3 |
|
4 |
+
from sklearn.metrics import accuracy_score, f1_score
|
5 |
+
from sklearn.preprocessing import LabelEncoder
|
6 |
+
from transformers import AutoConfig, BertConfig, BertModel
|
7 |
+
|
8 |
+
from .imports import *
|
9 |
+
|
10 |
+
|
11 |
def save_model(model, model_save_directory):
|
12 |
if not os.path.exists(model_save_directory):
|
13 |
os.makedirs(model_save_directory)
|
14 |
+
|
15 |
# Get the state dict
|
16 |
if isinstance(model, nn.DataParallel):
|
17 |
+
model_state_dict = (
|
18 |
+
model.module.state_dict()
|
19 |
+
) # Use model.module to access the underlying model
|
20 |
else:
|
21 |
model_state_dict = model.state_dict()
|
22 |
+
|
23 |
# Remove the "module." prefix from the keys if present
|
24 |
+
model_state_dict = {
|
25 |
+
k.replace("module.", ""): v for k, v in model_state_dict.items()
|
26 |
+
}
|
27 |
+
|
28 |
model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
|
29 |
torch.save(model_state_dict, model_save_path)
|
30 |
+
|
31 |
# Save the model configuration
|
32 |
if isinstance(model, nn.DataParallel):
|
33 |
+
model.module.config.to_json_file(
|
34 |
+
os.path.join(model_save_directory, "config.json")
|
35 |
+
)
|
36 |
else:
|
37 |
model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
|
38 |
+
|
39 |
print(f"Model and configuration saved to {model_save_directory}")
|
40 |
|
41 |
+
|
42 |
def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
|
43 |
task_metrics = {}
|
44 |
for task_name in task_true_labels.keys():
|
45 |
true_labels = task_true_labels[task_name]
|
46 |
pred_labels = task_pred_labels[task_name]
|
47 |
+
f1 = f1_score(true_labels, pred_labels, average="macro")
|
48 |
accuracy = accuracy_score(true_labels, pred_labels)
|
49 |
+
task_metrics[task_name] = {"f1": f1, "accuracy": accuracy}
|
50 |
return task_metrics
|
51 |
|
52 |
+
|
53 |
def calculate_combined_f1(combined_labels, combined_preds):
|
54 |
# Initialize the LabelEncoder
|
55 |
le = LabelEncoder()
|
|
|
68 |
accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
|
69 |
|
70 |
# Calculate F1 Macro score
|
71 |
+
f1 = f1_score(encoded_true_labels, encoded_pred_labels, average="macro")
|
72 |
|
73 |
return f1, accuracy
|
74 |
|
75 |
+
|
76 |
def save_model_without_heads(original_model_save_directory):
|
77 |
# Create a new directory for the model without heads
|
78 |
new_model_save_directory = original_model_save_directory + "_No_Heads"
|
|
|
80 |
os.makedirs(new_model_save_directory)
|
81 |
|
82 |
# Load the model state dictionary
|
83 |
+
model_state_dict = torch.load(
|
84 |
+
os.path.join(original_model_save_directory, "pytorch_model.bin")
|
85 |
+
)
|
86 |
|
87 |
# Initialize a new BERT model without the classification heads
|
88 |
+
config = BertConfig.from_pretrained(
|
89 |
+
os.path.join(original_model_save_directory, "config.json")
|
90 |
+
)
|
91 |
model_without_heads = BertModel(config)
|
92 |
+
|
93 |
# Filter the state dict to exclude classification heads
|
94 |
+
model_without_heads_state_dict = {
|
95 |
+
k: v
|
96 |
+
for k, v in model_state_dict.items()
|
97 |
+
if not k.startswith("classification_heads")
|
98 |
+
}
|
99 |
+
|
100 |
# Load the filtered state dict into the model
|
101 |
model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
|
102 |
+
|
103 |
# Save the model without heads
|
104 |
model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
|
105 |
torch.save(model_without_heads.state_dict(), model_save_path)
|
106 |
+
|
107 |
# Copy the configuration file
|
108 |
+
shutil.copy(
|
109 |
+
os.path.join(original_model_save_directory, "config.json"),
|
110 |
+
new_model_save_directory,
|
111 |
+
)
|
112 |
+
|
113 |
print(f"Model without classification heads saved to {new_model_save_directory}")
|
114 |
|
115 |
|
|
|
123 |
"""
|
124 |
if pretrained_path:
|
125 |
config = AutoConfig.from_pretrained(pretrained_path)
|
126 |
+
total_layers = config.num_hidden_layers
|
127 |
+
return {"min": 0, "max": total_layers - 1}
|
128 |
else:
|
129 |
+
return {"min": 0, "max": 0}
|
geneformer/mtl_classifier.py
CHANGED
@@ -28,9 +28,8 @@ Geneformer multi-task cell classifier.
|
|
28 |
|
29 |
import logging
|
30 |
import os
|
31 |
-
|
32 |
-
from .mtl import utils
|
33 |
-
from .mtl import eval_utils
|
34 |
|
35 |
logger = logging.getLogger(__name__)
|
36 |
|
@@ -90,9 +89,8 @@ class MTLClassifier:
|
|
90 |
wandb_project=None,
|
91 |
gradient_clipping=False,
|
92 |
max_grad_norm=None,
|
93 |
-
seed=42 # Default seed value
|
94 |
):
|
95 |
-
|
96 |
"""
|
97 |
Initialize Geneformer multi-task classifier.
|
98 |
**Parameters:**
|
@@ -165,11 +163,11 @@ class MTLClassifier:
|
|
165 |
self.batch_size = batch_size
|
166 |
self.n_trials = n_trials
|
167 |
self.study_name = study_name
|
168 |
-
|
169 |
if max_layers_to_freeze is None:
|
170 |
# Dynamically determine the range of layers to freeze
|
171 |
layer_freeze_range = utils.get_layer_freeze_range(pretrained_path)
|
172 |
-
self.max_layers_to_freeze = {"min": 1, "max": layer_freeze_range[
|
173 |
else:
|
174 |
self.max_layers_to_freeze = max_layers_to_freeze
|
175 |
|
@@ -178,48 +176,37 @@ class MTLClassifier:
|
|
178 |
self.use_data_parallel = use_data_parallel
|
179 |
self.use_attention_pooling = use_attention_pooling
|
180 |
self.use_task_weights = use_task_weights
|
181 |
-
self.hyperparameters =
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
"
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
"high": 0.01
|
192 |
-
|
193 |
-
|
194 |
-
"type": "
|
195 |
-
"low": 0.
|
196 |
-
"high": 0.1
|
197 |
-
},
|
198 |
-
"dropout_rate": {
|
199 |
-
"type": "float",
|
200 |
-
"low": 0.0,
|
201 |
-
"high": 0.7
|
202 |
-
},
|
203 |
-
"lr_scheduler_type": {
|
204 |
-
"type": "categorical",
|
205 |
-
"choices": ["cosine"]
|
206 |
-
},
|
207 |
-
"task_weights": {
|
208 |
-
"type": "float",
|
209 |
-
"low": 0.1,
|
210 |
-
"high": 2.0
|
211 |
}
|
212 |
-
|
213 |
-
self.manual_hyperparameters =
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
|
|
|
|
|
|
|
|
223 |
self.use_manual_hyperparameters = use_manual_hyperparameters
|
224 |
self.use_wandb = use_wandb
|
225 |
self.wandb_project = wandb_project
|
@@ -236,13 +223,19 @@ class MTLClassifier:
|
|
236 |
|
237 |
# set up output directories
|
238 |
if self.results_dir is not None:
|
239 |
-
self.trials_results_path = f"{self.results_dir}/results.txt".replace(
|
240 |
-
|
|
|
|
|
241 |
for output_dir in [self.model_save_path, self.results_dir]:
|
242 |
if not os.path.exists(output_dir):
|
243 |
os.makedirs(output_dir)
|
244 |
|
245 |
-
self.config = {
|
|
|
|
|
|
|
|
|
246 |
|
247 |
def validate_options(self):
|
248 |
# confirm arguments are within valid options and compatible with each other
|
@@ -264,19 +257,35 @@ class MTLClassifier:
|
|
264 |
f"Invalid option for {attr_name}. "
|
265 |
f"Valid options for {attr_name}: {valid_options}"
|
266 |
)
|
267 |
-
raise ValueError(
|
|
|
|
|
268 |
|
269 |
def run_manual_tuning(self):
|
270 |
"""
|
271 |
Manual hyperparameter tuning and multi-task fine-tuning of pretrained model.
|
272 |
"""
|
273 |
-
required_variable_names = [
|
274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
req_var_dict = dict(zip(required_variable_names, required_variables))
|
276 |
self.validate_additional_options(req_var_dict)
|
277 |
|
278 |
if not self.use_manual_hyperparameters:
|
279 |
-
raise ValueError(
|
|
|
|
|
280 |
|
281 |
# Ensure manual_hyperparameters are set in the config
|
282 |
self.config["manual_hyperparameters"] = self.manual_hyperparameters
|
@@ -302,8 +311,20 @@ class MTLClassifier:
|
|
302 |
Hyperparameter optimization and/or multi-task fine-tuning of pretrained model.
|
303 |
"""
|
304 |
|
305 |
-
required_variable_names = [
|
306 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
req_var_dict = dict(zip(required_variable_names, required_variables))
|
308 |
self.validate_additional_options(req_var_dict)
|
309 |
|
@@ -322,7 +343,7 @@ class MTLClassifier:
|
|
322 |
self.validate_additional_options(req_var_dict)
|
323 |
|
324 |
eval_utils.load_and_evaluate_test_model(self.config)
|
325 |
-
|
326 |
def save_model_without_heads(
|
327 |
self,
|
328 |
):
|
@@ -335,4 +356,6 @@ class MTLClassifier:
|
|
335 |
req_var_dict = dict(zip(required_variable_names, required_variables))
|
336 |
self.validate_additional_options(req_var_dict)
|
337 |
|
338 |
-
utils.save_model_without_heads(
|
|
|
|
|
|
28 |
|
29 |
import logging
|
30 |
import os
|
31 |
+
|
32 |
+
from .mtl import eval_utils, train_utils, utils
|
|
|
33 |
|
34 |
logger = logging.getLogger(__name__)
|
35 |
|
|
|
89 |
wandb_project=None,
|
90 |
gradient_clipping=False,
|
91 |
max_grad_norm=None,
|
92 |
+
seed=42, # Default seed value
|
93 |
):
|
|
|
94 |
"""
|
95 |
Initialize Geneformer multi-task classifier.
|
96 |
**Parameters:**
|
|
|
163 |
self.batch_size = batch_size
|
164 |
self.n_trials = n_trials
|
165 |
self.study_name = study_name
|
166 |
+
|
167 |
if max_layers_to_freeze is None:
|
168 |
# Dynamically determine the range of layers to freeze
|
169 |
layer_freeze_range = utils.get_layer_freeze_range(pretrained_path)
|
170 |
+
self.max_layers_to_freeze = {"min": 1, "max": layer_freeze_range["max"]}
|
171 |
else:
|
172 |
self.max_layers_to_freeze = max_layers_to_freeze
|
173 |
|
|
|
176 |
self.use_data_parallel = use_data_parallel
|
177 |
self.use_attention_pooling = use_attention_pooling
|
178 |
self.use_task_weights = use_task_weights
|
179 |
+
self.hyperparameters = (
|
180 |
+
hyperparameters
|
181 |
+
if hyperparameters is not None
|
182 |
+
else {
|
183 |
+
"learning_rate": {
|
184 |
+
"type": "float",
|
185 |
+
"low": 1e-5,
|
186 |
+
"high": 1e-3,
|
187 |
+
"log": True,
|
188 |
+
},
|
189 |
+
"warmup_ratio": {"type": "float", "low": 0.005, "high": 0.01},
|
190 |
+
"weight_decay": {"type": "float", "low": 0.01, "high": 0.1},
|
191 |
+
"dropout_rate": {"type": "float", "low": 0.0, "high": 0.7},
|
192 |
+
"lr_scheduler_type": {"type": "categorical", "choices": ["cosine"]},
|
193 |
+
"task_weights": {"type": "float", "low": 0.1, "high": 2.0},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
}
|
195 |
+
)
|
196 |
+
self.manual_hyperparameters = (
|
197 |
+
manual_hyperparameters
|
198 |
+
if manual_hyperparameters is not None
|
199 |
+
else {
|
200 |
+
"learning_rate": 0.001,
|
201 |
+
"warmup_ratio": 0.01,
|
202 |
+
"weight_decay": 0.1,
|
203 |
+
"dropout_rate": 0.1,
|
204 |
+
"lr_scheduler_type": "cosine",
|
205 |
+
"use_attention_pooling": False,
|
206 |
+
"task_weights": [1, 1],
|
207 |
+
"max_layers_to_freeze": 2,
|
208 |
+
}
|
209 |
+
)
|
210 |
self.use_manual_hyperparameters = use_manual_hyperparameters
|
211 |
self.use_wandb = use_wandb
|
212 |
self.wandb_project = wandb_project
|
|
|
223 |
|
224 |
# set up output directories
|
225 |
if self.results_dir is not None:
|
226 |
+
self.trials_results_path = f"{self.results_dir}/results.txt".replace(
|
227 |
+
"//", "/"
|
228 |
+
)
|
229 |
+
|
230 |
for output_dir in [self.model_save_path, self.results_dir]:
|
231 |
if not os.path.exists(output_dir):
|
232 |
os.makedirs(output_dir)
|
233 |
|
234 |
+
self.config = {
|
235 |
+
key: value
|
236 |
+
for key, value in self.__dict__.items()
|
237 |
+
if key in self.valid_option_dict
|
238 |
+
}
|
239 |
|
240 |
def validate_options(self):
|
241 |
# confirm arguments are within valid options and compatible with each other
|
|
|
257 |
f"Invalid option for {attr_name}. "
|
258 |
f"Valid options for {attr_name}: {valid_options}"
|
259 |
)
|
260 |
+
raise ValueError(
|
261 |
+
f"Invalid option for {attr_name}. Valid options for {attr_name}: {valid_options}"
|
262 |
+
)
|
263 |
|
264 |
def run_manual_tuning(self):
|
265 |
"""
|
266 |
Manual hyperparameter tuning and multi-task fine-tuning of pretrained model.
|
267 |
"""
|
268 |
+
required_variable_names = [
|
269 |
+
"train_path",
|
270 |
+
"val_path",
|
271 |
+
"pretrained_path",
|
272 |
+
"model_save_path",
|
273 |
+
"results_dir",
|
274 |
+
]
|
275 |
+
required_variables = [
|
276 |
+
self.train_path,
|
277 |
+
self.val_path,
|
278 |
+
self.pretrained_path,
|
279 |
+
self.model_save_path,
|
280 |
+
self.results_dir,
|
281 |
+
]
|
282 |
req_var_dict = dict(zip(required_variable_names, required_variables))
|
283 |
self.validate_additional_options(req_var_dict)
|
284 |
|
285 |
if not self.use_manual_hyperparameters:
|
286 |
+
raise ValueError(
|
287 |
+
"Manual hyperparameters are not enabled. Set use_manual_hyperparameters to True."
|
288 |
+
)
|
289 |
|
290 |
# Ensure manual_hyperparameters are set in the config
|
291 |
self.config["manual_hyperparameters"] = self.manual_hyperparameters
|
|
|
311 |
Hyperparameter optimization and/or multi-task fine-tuning of pretrained model.
|
312 |
"""
|
313 |
|
314 |
+
required_variable_names = [
|
315 |
+
"train_path",
|
316 |
+
"val_path",
|
317 |
+
"pretrained_path",
|
318 |
+
"model_save_path",
|
319 |
+
"results_dir",
|
320 |
+
]
|
321 |
+
required_variables = [
|
322 |
+
self.train_path,
|
323 |
+
self.val_path,
|
324 |
+
self.pretrained_path,
|
325 |
+
self.model_save_path,
|
326 |
+
self.results_dir,
|
327 |
+
]
|
328 |
req_var_dict = dict(zip(required_variable_names, required_variables))
|
329 |
self.validate_additional_options(req_var_dict)
|
330 |
|
|
|
343 |
self.validate_additional_options(req_var_dict)
|
344 |
|
345 |
eval_utils.load_and_evaluate_test_model(self.config)
|
346 |
+
|
347 |
def save_model_without_heads(
|
348 |
self,
|
349 |
):
|
|
|
356 |
req_var_dict = dict(zip(required_variable_names, required_variables))
|
357 |
self.validate_additional_options(req_var_dict)
|
358 |
|
359 |
+
utils.save_model_without_heads(
|
360 |
+
os.path.join(self.model_save_path, "GeneformerMultiTask")
|
361 |
+
)
|
geneformer/perturber_utils.py
CHANGED
@@ -1,18 +1,15 @@
|
|
1 |
import itertools as it
|
2 |
import logging
|
3 |
import pickle
|
4 |
-
import re
|
5 |
from collections import defaultdict
|
6 |
-
from typing import List
|
7 |
from pathlib import Path
|
8 |
-
|
9 |
|
10 |
import numpy as np
|
11 |
import pandas as pd
|
12 |
-
import seaborn as sns
|
13 |
import torch
|
14 |
from datasets import Dataset, load_from_disk
|
15 |
-
from peft import LoraConfig, get_peft_model
|
16 |
from transformers import (
|
17 |
BertForMaskedLM,
|
18 |
BertForSequenceClassification,
|
@@ -119,7 +116,7 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
119 |
if model_type == "MTLCellClassifier-Quantized":
|
120 |
model_type = "MTLCellClassifier"
|
121 |
quantize = True
|
122 |
-
|
123 |
if mode == "eval":
|
124 |
output_hidden_states = True
|
125 |
elif mode == "train":
|
@@ -131,7 +128,7 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
131 |
"peft_config": None,
|
132 |
"bnb_config": BitsAndBytesConfig(
|
133 |
load_in_8bit=True,
|
134 |
-
)
|
135 |
}
|
136 |
else:
|
137 |
quantize = {
|
@@ -141,13 +138,13 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
141 |
r=64,
|
142 |
bias="none",
|
143 |
task_type="TokenClassification",
|
144 |
-
|
145 |
"bnb_config": BitsAndBytesConfig(
|
146 |
load_in_4bit=True,
|
147 |
bnb_4bit_use_double_quant=True,
|
148 |
bnb_4bit_quant_type="nf4",
|
149 |
-
bnb_4bit_compute_dtype=torch.bfloat16
|
150 |
-
|
151 |
}
|
152 |
elif quantize is False:
|
153 |
quantize = {"bnb_config": None}
|
@@ -186,7 +183,11 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
186 |
# if eval mode, put the model in eval mode for fwd pass
|
187 |
if mode == "eval":
|
188 |
model.eval()
|
189 |
-
if (
|
|
|
|
|
|
|
|
|
190 |
model = model.to("cuda")
|
191 |
else:
|
192 |
model.enable_input_require_grads()
|
@@ -279,18 +280,20 @@ def overexpress_indices(example):
|
|
279 |
example["length"] = len(example["input_ids"])
|
280 |
return example
|
281 |
|
|
|
282 |
# if CLS token present, move to 1st rather than 0th position
|
283 |
def overexpress_indices_special(example):
|
284 |
indices = example["perturb_index"]
|
285 |
if any(isinstance(el, list) for el in indices):
|
286 |
indices = flatten_list(indices)
|
287 |
-
insert_pos = 1
|
288 |
for index in sorted(indices, reverse=False):
|
289 |
example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
|
290 |
insert_pos += 1
|
291 |
example["length"] = len(example["input_ids"])
|
292 |
return example
|
293 |
|
|
|
294 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
295 |
def overexpress_tokens(example, max_len, special_token):
|
296 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
@@ -310,7 +313,9 @@ def overexpress_tokens(example, max_len, special_token):
|
|
310 |
# truncate to max input size, must also truncate original emb to be comparable
|
311 |
if len(example["input_ids"]) > max_len:
|
312 |
if special_token:
|
313 |
-
example["input_ids"] = example["input_ids"][0:max_len-1]+[
|
|
|
|
|
314 |
else:
|
315 |
example["input_ids"] = example["input_ids"][0:max_len]
|
316 |
example["length"] = len(example["input_ids"])
|
@@ -329,10 +334,13 @@ def truncate_by_n_overflow(example):
|
|
329 |
example["length"] = len(example["input_ids"])
|
330 |
return example
|
331 |
|
|
|
332 |
def truncate_by_n_overflow_special(example):
|
333 |
if example["n_overflow"] > 0:
|
334 |
new_max_len = example["length"] - example["n_overflow"]
|
335 |
-
example["input_ids"] = example["input_ids"][0:new_max_len-1]+[
|
|
|
|
|
336 |
example["length"] = len(example["input_ids"])
|
337 |
return example
|
338 |
|
@@ -477,19 +485,24 @@ def make_perturbation_batch_special(
|
|
477 |
range_start = 1
|
478 |
elif perturb_type in ["delete", "inhibit"]:
|
479 |
range_start = 0
|
480 |
-
range_start += 1
|
481 |
indices_to_perturb = [
|
482 |
-
[i]
|
|
|
|
|
|
|
483 |
]
|
484 |
|
485 |
# elif combo_lvl > 0 and anchor_token is None:
|
486 |
## to implement
|
487 |
-
elif combo_lvl > 0 and (anchor_token is not None):
|
488 |
example_input_ids = example_cell["input_ids"][0]
|
489 |
anchor_index = example_input_ids.index(anchor_token[0])
|
490 |
indices_to_perturb = [
|
491 |
sorted([anchor_index, i]) if i != anchor_index else None
|
492 |
-
for i in range(
|
|
|
|
|
493 |
]
|
494 |
indices_to_perturb = [item for item in indices_to_perturb if item is not None]
|
495 |
else:
|
@@ -508,7 +521,9 @@ def make_perturbation_batch_special(
|
|
508 |
list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
|
509 |
]
|
510 |
else:
|
511 |
-
all_indices = [
|
|
|
|
|
512 |
all_indices = [
|
513 |
index for index in all_indices if index not in indices_to_perturb
|
514 |
]
|
@@ -535,7 +550,7 @@ def make_perturbation_batch_special(
|
|
535 |
)
|
536 |
elif perturb_type == "overexpress":
|
537 |
perturbation_dataset = perturbation_dataset.map(
|
538 |
-
|
539 |
)
|
540 |
|
541 |
perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
|
@@ -743,7 +758,7 @@ def quant_cos_sims(
|
|
743 |
# against original cell
|
744 |
if cell_states_to_model is None or emb_mode == "gene":
|
745 |
cos_sims = cos(perturbation_emb, original_emb).to("cuda")
|
746 |
-
|
747 |
elif cell_states_to_model is not None and emb_mode == "cell":
|
748 |
possible_states = get_possible_states(cell_states_to_model)
|
749 |
cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
|
@@ -867,27 +882,28 @@ def validate_cell_states_to_model(cell_states_to_model):
|
|
867 |
)
|
868 |
raise
|
869 |
|
|
|
870 |
class GeneIdHandler:
|
871 |
def __init__(self, raise_errors=False):
|
872 |
def invert_dict(dict_obj):
|
873 |
-
return {v:k for k,v in dict_obj.items()}
|
874 |
-
|
875 |
self.raise_errors = raise_errors
|
876 |
-
|
877 |
-
with open(TOKEN_DICTIONARY_FILE,
|
878 |
self.gene_token_dict = pickle.load(f)
|
879 |
self.token_gene_dict = invert_dict(self.gene_token_dict)
|
880 |
|
881 |
-
with open(ENSEMBL_DICTIONARY_FILE,
|
882 |
self.id_gene_dict = pickle.load(f)
|
883 |
self.gene_id_dict = invert_dict(self.id_gene_dict)
|
884 |
-
|
885 |
def ens_to_token(self, ens_id):
|
886 |
if not self.raise_errors:
|
887 |
return self.gene_token_dict.get(ens_id, ens_id)
|
888 |
else:
|
889 |
return self.gene_token_dict[ens_id]
|
890 |
-
|
891 |
def token_to_ens(self, token):
|
892 |
if not self.raise_errors:
|
893 |
return self.token_gene_dict.get(token, token)
|
@@ -899,15 +915,15 @@ class GeneIdHandler:
|
|
899 |
return self.gene_id_dict.get(ens_id, ens_id)
|
900 |
else:
|
901 |
return self.gene_id_dict[ens_id]
|
902 |
-
|
903 |
def symbol_to_ens(self, symbol):
|
904 |
if not self.raise_errors:
|
905 |
return self.id_gene_dict.get(symbol, symbol)
|
906 |
else:
|
907 |
return self.id_gene_dict[symbol]
|
908 |
-
|
909 |
def token_to_symbol(self, token):
|
910 |
return self.ens_to_symbol(self.token_to_ens(token))
|
911 |
-
|
912 |
def symbol_to_token(self, symbol):
|
913 |
-
return self.ens_to_token(self.symbol_to_ens(symbol))
|
|
|
1 |
import itertools as it
|
2 |
import logging
|
3 |
import pickle
|
|
|
4 |
from collections import defaultdict
|
|
|
5 |
from pathlib import Path
|
6 |
+
from typing import List
|
7 |
|
8 |
import numpy as np
|
9 |
import pandas as pd
|
|
|
10 |
import torch
|
11 |
from datasets import Dataset, load_from_disk
|
12 |
+
from peft import LoraConfig, get_peft_model
|
13 |
from transformers import (
|
14 |
BertForMaskedLM,
|
15 |
BertForSequenceClassification,
|
|
|
116 |
if model_type == "MTLCellClassifier-Quantized":
|
117 |
model_type = "MTLCellClassifier"
|
118 |
quantize = True
|
119 |
+
|
120 |
if mode == "eval":
|
121 |
output_hidden_states = True
|
122 |
elif mode == "train":
|
|
|
128 |
"peft_config": None,
|
129 |
"bnb_config": BitsAndBytesConfig(
|
130 |
load_in_8bit=True,
|
131 |
+
),
|
132 |
}
|
133 |
else:
|
134 |
quantize = {
|
|
|
138 |
r=64,
|
139 |
bias="none",
|
140 |
task_type="TokenClassification",
|
141 |
+
),
|
142 |
"bnb_config": BitsAndBytesConfig(
|
143 |
load_in_4bit=True,
|
144 |
bnb_4bit_use_double_quant=True,
|
145 |
bnb_4bit_quant_type="nf4",
|
146 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
147 |
+
),
|
148 |
}
|
149 |
elif quantize is False:
|
150 |
quantize = {"bnb_config": None}
|
|
|
183 |
# if eval mode, put the model in eval mode for fwd pass
|
184 |
if mode == "eval":
|
185 |
model.eval()
|
186 |
+
if (
|
187 |
+
(quantize is False)
|
188 |
+
or (quantize == {"bnb_config": None})
|
189 |
+
or (model_type == "MTLCellClassifier")
|
190 |
+
):
|
191 |
model = model.to("cuda")
|
192 |
else:
|
193 |
model.enable_input_require_grads()
|
|
|
280 |
example["length"] = len(example["input_ids"])
|
281 |
return example
|
282 |
|
283 |
+
|
284 |
# if CLS token present, move to 1st rather than 0th position
|
285 |
def overexpress_indices_special(example):
|
286 |
indices = example["perturb_index"]
|
287 |
if any(isinstance(el, list) for el in indices):
|
288 |
indices = flatten_list(indices)
|
289 |
+
insert_pos = 1 # Insert starting after CLS token
|
290 |
for index in sorted(indices, reverse=False):
|
291 |
example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
|
292 |
insert_pos += 1
|
293 |
example["length"] = len(example["input_ids"])
|
294 |
return example
|
295 |
|
296 |
+
|
297 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
298 |
def overexpress_tokens(example, max_len, special_token):
|
299 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
|
|
313 |
# truncate to max input size, must also truncate original emb to be comparable
|
314 |
if len(example["input_ids"]) > max_len:
|
315 |
if special_token:
|
316 |
+
example["input_ids"] = example["input_ids"][0 : max_len - 1] + [
|
317 |
+
example["input_ids"][-1]
|
318 |
+
]
|
319 |
else:
|
320 |
example["input_ids"] = example["input_ids"][0:max_len]
|
321 |
example["length"] = len(example["input_ids"])
|
|
|
334 |
example["length"] = len(example["input_ids"])
|
335 |
return example
|
336 |
|
337 |
+
|
338 |
def truncate_by_n_overflow_special(example):
|
339 |
if example["n_overflow"] > 0:
|
340 |
new_max_len = example["length"] - example["n_overflow"]
|
341 |
+
example["input_ids"] = example["input_ids"][0 : new_max_len - 1] + [
|
342 |
+
example["input_ids"][-1]
|
343 |
+
]
|
344 |
example["length"] = len(example["input_ids"])
|
345 |
return example
|
346 |
|
|
|
485 |
range_start = 1
|
486 |
elif perturb_type in ["delete", "inhibit"]:
|
487 |
range_start = 0
|
488 |
+
range_start += 1 # Starting after the CLS token
|
489 |
indices_to_perturb = [
|
490 |
+
[i]
|
491 |
+
for i in range(
|
492 |
+
range_start, example_cell["length"][0] - 1
|
493 |
+
) # And excluding the EOS token
|
494 |
]
|
495 |
|
496 |
# elif combo_lvl > 0 and anchor_token is None:
|
497 |
## to implement
|
498 |
+
elif combo_lvl > 0 and (anchor_token is not None):
|
499 |
example_input_ids = example_cell["input_ids"][0]
|
500 |
anchor_index = example_input_ids.index(anchor_token[0])
|
501 |
indices_to_perturb = [
|
502 |
sorted([anchor_index, i]) if i != anchor_index else None
|
503 |
+
for i in range(
|
504 |
+
1, example_cell["length"][0] - 1
|
505 |
+
) # Exclude CLS and EOS tokens
|
506 |
]
|
507 |
indices_to_perturb = [item for item in indices_to_perturb if item is not None]
|
508 |
else:
|
|
|
521 |
list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
|
522 |
]
|
523 |
else:
|
524 |
+
all_indices = [
|
525 |
+
[i] for i in range(1, example_cell["length"][0] - 1)
|
526 |
+
] # Exclude CLS and EOS tokens
|
527 |
all_indices = [
|
528 |
index for index in all_indices if index not in indices_to_perturb
|
529 |
]
|
|
|
550 |
)
|
551 |
elif perturb_type == "overexpress":
|
552 |
perturbation_dataset = perturbation_dataset.map(
|
553 |
+
overexpress_indices_special, num_proc=num_proc_i
|
554 |
)
|
555 |
|
556 |
perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
|
|
|
758 |
# against original cell
|
759 |
if cell_states_to_model is None or emb_mode == "gene":
|
760 |
cos_sims = cos(perturbation_emb, original_emb).to("cuda")
|
761 |
+
|
762 |
elif cell_states_to_model is not None and emb_mode == "cell":
|
763 |
possible_states = get_possible_states(cell_states_to_model)
|
764 |
cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
|
|
|
882 |
)
|
883 |
raise
|
884 |
|
885 |
+
|
886 |
class GeneIdHandler:
|
887 |
def __init__(self, raise_errors=False):
|
888 |
def invert_dict(dict_obj):
|
889 |
+
return {v: k for k, v in dict_obj.items()}
|
890 |
+
|
891 |
self.raise_errors = raise_errors
|
892 |
+
|
893 |
+
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
|
894 |
self.gene_token_dict = pickle.load(f)
|
895 |
self.token_gene_dict = invert_dict(self.gene_token_dict)
|
896 |
|
897 |
+
with open(ENSEMBL_DICTIONARY_FILE, "rb") as f:
|
898 |
self.id_gene_dict = pickle.load(f)
|
899 |
self.gene_id_dict = invert_dict(self.id_gene_dict)
|
900 |
+
|
901 |
def ens_to_token(self, ens_id):
|
902 |
if not self.raise_errors:
|
903 |
return self.gene_token_dict.get(ens_id, ens_id)
|
904 |
else:
|
905 |
return self.gene_token_dict[ens_id]
|
906 |
+
|
907 |
def token_to_ens(self, token):
|
908 |
if not self.raise_errors:
|
909 |
return self.token_gene_dict.get(token, token)
|
|
|
915 |
return self.gene_id_dict.get(ens_id, ens_id)
|
916 |
else:
|
917 |
return self.gene_id_dict[ens_id]
|
918 |
+
|
919 |
def symbol_to_ens(self, symbol):
|
920 |
if not self.raise_errors:
|
921 |
return self.id_gene_dict.get(symbol, symbol)
|
922 |
else:
|
923 |
return self.id_gene_dict[symbol]
|
924 |
+
|
925 |
def token_to_symbol(self, token):
|
926 |
return self.ens_to_symbol(self.token_to_ens(token))
|
927 |
+
|
928 |
def symbol_to_token(self, symbol):
|
929 |
+
return self.ens_to_token(self.symbol_to_ens(symbol))
|
geneformer/tokenizer.py
CHANGED
@@ -22,28 +22,28 @@ Geneformer tokenizer.
|
|
22 |
|
23 |
from __future__ import annotations
|
24 |
|
25 |
-
import os
|
26 |
import logging
|
|
|
27 |
import pickle
|
28 |
import warnings
|
|
|
29 |
from pathlib import Path
|
30 |
from typing import Literal
|
31 |
-
from tqdm import tqdm
|
32 |
-
from collections import Counter
|
33 |
|
34 |
-
import numpy as np
|
35 |
-
import scanpy as sc
|
36 |
import loompy as lp
|
|
|
37 |
import pandas as pd
|
|
|
38 |
import scipy.sparse as sp
|
39 |
from datasets import Dataset
|
|
|
40 |
|
41 |
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa
|
42 |
import loompy as lp # noqa
|
43 |
|
44 |
logger = logging.getLogger(__name__)
|
45 |
|
46 |
-
from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
|
47 |
|
48 |
|
49 |
def rank_genes(gene_vector, gene_tokens):
|
@@ -65,20 +65,26 @@ def tokenize_cell(gene_vector, gene_tokens):
|
|
65 |
# rank by median-scaled gene values
|
66 |
return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
|
67 |
|
68 |
-
def sum_ensembl_ids(data_directory,
|
69 |
-
collapse_gene_ids,
|
70 |
-
gene_mapping_dict,
|
71 |
-
gene_token_dict,
|
72 |
-
file_format = "loom",
|
73 |
-
chunk_size = 512):
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
if file_format == "loom":
|
76 |
-
"""
|
77 |
Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
|
78 |
"""
|
79 |
with lp.connect(data_directory) as data:
|
80 |
-
assert
|
81 |
-
|
|
|
|
|
|
|
|
|
82 |
if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
|
83 |
token_genes_unique = True
|
84 |
else:
|
@@ -89,45 +95,80 @@ def sum_ensembl_ids(data_directory,
|
|
89 |
else:
|
90 |
raise ValueError("Error: data Ensembl IDs non-unique.")
|
91 |
|
92 |
-
gene_ids_collapsed = [
|
93 |
-
|
|
|
|
|
|
|
|
|
94 |
|
95 |
-
if (
|
|
|
|
|
96 |
return data_directory
|
97 |
else:
|
98 |
-
dedup_filename = data_directory.with_name(
|
|
|
|
|
99 |
data.ra["gene_ids_collapsed"] = gene_ids_collapsed
|
100 |
-
dup_genes = [
|
|
|
|
|
|
|
|
|
101 |
num_chunks = int(np.ceil(data.shape[1] / chunk_size))
|
102 |
first_chunk = True
|
103 |
-
for _, _, view in tqdm(
|
|
|
|
|
|
|
104 |
def process_chunk(view, duplic_genes):
|
105 |
-
data_count_view = pd.DataFrame(
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
summed_data = dup_data_df.groupby(dup_data_df.index).sum()
|
109 |
if not summed_data.index.is_unique:
|
110 |
-
raise ValueError(
|
111 |
-
|
|
|
|
|
|
|
|
|
112 |
if not data_count_view.index.is_unique:
|
113 |
-
raise ValueError(
|
|
|
|
|
114 |
return data_count_view
|
|
|
115 |
processed_chunk = process_chunk(view[:, :], dup_genes)
|
116 |
processed_array = processed_chunk.to_numpy()
|
117 |
new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
|
118 |
|
119 |
if "n_counts" not in view.ca.keys():
|
120 |
-
total_count_view = np.sum(view[
|
121 |
view.ca["n_counts"] = total_count_view
|
122 |
|
123 |
-
if first_chunk:
|
124 |
-
lp.create(
|
|
|
|
|
|
|
|
|
|
|
125 |
first_chunk = False
|
126 |
-
else:
|
127 |
-
with lp.connect(dedup_filename, mode=
|
128 |
dsout.add_columns(processed_array, col_attrs=view.ca)
|
129 |
return dedup_filename
|
130 |
-
|
131 |
elif file_format == "h5ad":
|
132 |
"""
|
133 |
Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
|
@@ -136,8 +177,12 @@ def sum_ensembl_ids(data_directory,
|
|
136 |
|
137 |
data = sc.read_h5ad(str(data_directory))
|
138 |
|
139 |
-
assert
|
140 |
-
|
|
|
|
|
|
|
|
|
141 |
if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
|
142 |
token_genes_unique = True
|
143 |
else:
|
@@ -148,22 +193,29 @@ def sum_ensembl_ids(data_directory,
|
|
148 |
else:
|
149 |
raise ValueError("Error: data Ensembl IDs non-unique.")
|
150 |
|
151 |
-
gene_ids_collapsed = [
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
return data
|
155 |
-
|
156 |
else:
|
157 |
data.var["gene_ids_collapsed"] = gene_ids_collapsed
|
158 |
data.var_names = gene_ids_collapsed
|
159 |
data = data[:, ~data.var.index.isna()]
|
160 |
-
dup_genes = [
|
|
|
|
|
161 |
|
162 |
num_chunks = int(np.ceil(data.shape[0] / chunk_size))
|
163 |
|
164 |
processed_genes = []
|
165 |
for i in tqdm(range(num_chunks)):
|
166 |
-
|
167 |
start_idx = i * chunk_size
|
168 |
end_idx = min((i + 1) * chunk_size, data.shape[0])
|
169 |
data_chunk = data[start_idx:end_idx, :]
|
@@ -171,29 +223,32 @@ def sum_ensembl_ids(data_directory,
|
|
171 |
processed_chunks = []
|
172 |
for dup_gene in dup_genes:
|
173 |
data_dup_gene = data_chunk[:, data_chunk.var_names == dup_gene]
|
174 |
-
df = pd.DataFrame.sparse.from_spmatrix(
|
175 |
-
|
176 |
-
|
|
|
|
|
177 |
df_sum = pd.DataFrame(df.sum(axis=1))
|
178 |
-
df_sum.columns = [dup_gene]
|
179 |
df_sum.index = data_dup_gene.obs.index
|
180 |
processed_chunks.append(df_sum)
|
181 |
|
182 |
-
processed_chunks = pd.concat(processed_chunks, axis=1)
|
183 |
processed_genes.append(processed_chunks)
|
184 |
-
processed_genes = pd.concat(processed_genes, axis
|
185 |
-
var_df = pd.DataFrame({"gene_ids_collapsed"
|
186 |
var_df.index = processed_genes.columns
|
187 |
-
processed_genes = sc.AnnData(X =
|
188 |
-
obs = data.obs,
|
189 |
-
var = var_df)
|
190 |
|
191 |
-
data_dedup = data[:, ~data.var.index.isin(dup_genes)]
|
192 |
-
data_dedup = sc.concat([data_dedup, processed_genes], axis
|
193 |
data_dedup.obs = data.obs
|
194 |
-
data_dedup.var = data_dedup.var.rename(
|
|
|
|
|
195 |
return data_dedup
|
196 |
|
|
|
197 |
class TranscriptomeTokenizer:
|
198 |
def __init__(
|
199 |
self,
|
@@ -258,10 +313,12 @@ class TranscriptomeTokenizer:
|
|
258 |
|
259 |
# check for special token in gene_token_dict
|
260 |
if self.special_token:
|
261 |
-
if ("<cls>" not in self.gene_token_dict.keys()) and (
|
|
|
|
|
262 |
logger.error(
|
263 |
-
|
264 |
-
|
265 |
raise
|
266 |
|
267 |
# if collapsing duplicate gene IDs
|
@@ -272,14 +329,16 @@ class TranscriptomeTokenizer:
|
|
272 |
with open(gene_mapping_file, "rb") as f:
|
273 |
self.gene_mapping_dict = pickle.load(f)
|
274 |
else:
|
275 |
-
self.gene_mapping_dict = {k:k for k,_ in self.gene_token_dict.items()}
|
276 |
|
277 |
# gene keys for full vocabulary
|
278 |
self.gene_keys = list(self.gene_token_dict.keys())
|
279 |
|
280 |
# Filter gene mapping dict for items that exist in gene_token_dict
|
281 |
gene_keys_set = set(self.gene_token_dict.keys())
|
282 |
-
self.gene_mapping_dict = {
|
|
|
|
|
283 |
|
284 |
# protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
|
285 |
self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
|
@@ -355,7 +414,14 @@ class TranscriptomeTokenizer:
|
|
355 |
return tokenized_cells, cell_metadata
|
356 |
|
357 |
def tokenize_anndata(self, adata_file_path, target_sum=10_000):
|
358 |
-
adata = sum_ensembl_ids(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
|
360 |
if self.custom_attr_name_dict is not None:
|
361 |
file_cell_metadata = {
|
@@ -397,7 +463,7 @@ class TranscriptomeTokenizer:
|
|
397 |
idx = filter_pass_loc[i : i + self.chunk_size]
|
398 |
|
399 |
n_counts = adata[idx].obs["n_counts"].values[:, None]
|
400 |
-
X_view0 = adata[idx
|
401 |
X_view = X_view0[:, coding_miRNA_loc]
|
402 |
X_norm = X_view / n_counts * target_sum / norm_factor_vector
|
403 |
X_norm = sp.csr_matrix(X_norm)
|
@@ -423,7 +489,14 @@ class TranscriptomeTokenizer:
|
|
423 |
}
|
424 |
|
425 |
dedup_filename = loom_file_path.with_name(loom_file_path.stem + "__dedup.loom")
|
426 |
-
loom_file_path = sum_ensembl_ids(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
427 |
|
428 |
with lp.connect(str(loom_file_path)) as data:
|
429 |
# define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
|
@@ -544,4 +617,4 @@ class TranscriptomeTokenizer:
|
|
544 |
output_dataset_truncated = output_dataset.map(
|
545 |
format_cell_features, num_proc=self.nproc
|
546 |
)
|
547 |
-
return output_dataset_truncated
|
|
|
22 |
|
23 |
from __future__ import annotations
|
24 |
|
|
|
25 |
import logging
|
26 |
+
import os
|
27 |
import pickle
|
28 |
import warnings
|
29 |
+
from collections import Counter
|
30 |
from pathlib import Path
|
31 |
from typing import Literal
|
|
|
|
|
32 |
|
|
|
|
|
33 |
import loompy as lp
|
34 |
+
import numpy as np
|
35 |
import pandas as pd
|
36 |
+
import scanpy as sc
|
37 |
import scipy.sparse as sp
|
38 |
from datasets import Dataset
|
39 |
+
from tqdm import tqdm
|
40 |
|
41 |
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa
|
42 |
import loompy as lp # noqa
|
43 |
|
44 |
logger = logging.getLogger(__name__)
|
45 |
|
46 |
+
from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
|
47 |
|
48 |
|
49 |
def rank_genes(gene_vector, gene_tokens):
|
|
|
65 |
# rank by median-scaled gene values
|
66 |
return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
+
def sum_ensembl_ids(
|
70 |
+
data_directory,
|
71 |
+
collapse_gene_ids,
|
72 |
+
gene_mapping_dict,
|
73 |
+
gene_token_dict,
|
74 |
+
file_format="loom",
|
75 |
+
chunk_size=512,
|
76 |
+
):
|
77 |
if file_format == "loom":
|
78 |
+
"""
|
79 |
Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
|
80 |
"""
|
81 |
with lp.connect(data_directory) as data:
|
82 |
+
assert (
|
83 |
+
"ensembl_id" in data.ra.keys()
|
84 |
+
), "'ensembl_id' column missing from data.ra.keys()"
|
85 |
+
gene_ids_in_dict = [
|
86 |
+
gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
|
87 |
+
]
|
88 |
if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
|
89 |
token_genes_unique = True
|
90 |
else:
|
|
|
95 |
else:
|
96 |
raise ValueError("Error: data Ensembl IDs non-unique.")
|
97 |
|
98 |
+
gene_ids_collapsed = [
|
99 |
+
gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id
|
100 |
+
]
|
101 |
+
gene_ids_collapsed_in_dict = [
|
102 |
+
gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
|
103 |
+
]
|
104 |
|
105 |
+
if (
|
106 |
+
len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))
|
107 |
+
) and token_genes_unique:
|
108 |
return data_directory
|
109 |
else:
|
110 |
+
dedup_filename = data_directory.with_name(
|
111 |
+
data_directory.stem + "__dedup.loom"
|
112 |
+
)
|
113 |
data.ra["gene_ids_collapsed"] = gene_ids_collapsed
|
114 |
+
dup_genes = [
|
115 |
+
idx
|
116 |
+
for idx, count in Counter(data.ra["gene_ids_collapsed"]).items()
|
117 |
+
if count > 1
|
118 |
+
]
|
119 |
num_chunks = int(np.ceil(data.shape[1] / chunk_size))
|
120 |
first_chunk = True
|
121 |
+
for _, _, view in tqdm(
|
122 |
+
data.scan(axis=1, batch_size=chunk_size), total=num_chunks
|
123 |
+
):
|
124 |
+
|
125 |
def process_chunk(view, duplic_genes):
|
126 |
+
data_count_view = pd.DataFrame(
|
127 |
+
view, index=data.ra["gene_ids_collapsed"]
|
128 |
+
)
|
129 |
+
unique_data_df = data_count_view.loc[
|
130 |
+
~data_count_view.index.isin(duplic_genes)
|
131 |
+
]
|
132 |
+
dup_data_df = data_count_view.loc[
|
133 |
+
data_count_view.index.isin(
|
134 |
+
[i for i in duplic_genes if "None" not in i]
|
135 |
+
)
|
136 |
+
]
|
137 |
summed_data = dup_data_df.groupby(dup_data_df.index).sum()
|
138 |
if not summed_data.index.is_unique:
|
139 |
+
raise ValueError(
|
140 |
+
"Error: Ensembl IDs in summed data frame non-unique."
|
141 |
+
)
|
142 |
+
data_count_view = pd.concat(
|
143 |
+
[unique_data_df, summed_data], axis=0
|
144 |
+
)
|
145 |
if not data_count_view.index.is_unique:
|
146 |
+
raise ValueError(
|
147 |
+
"Error: Ensembl IDs in final data frame non-unique."
|
148 |
+
)
|
149 |
return data_count_view
|
150 |
+
|
151 |
processed_chunk = process_chunk(view[:, :], dup_genes)
|
152 |
processed_array = processed_chunk.to_numpy()
|
153 |
new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
|
154 |
|
155 |
if "n_counts" not in view.ca.keys():
|
156 |
+
total_count_view = np.sum(view[:, :], axis=0).astype(int)
|
157 |
view.ca["n_counts"] = total_count_view
|
158 |
|
159 |
+
if first_chunk: # Create the Loom file with the first chunk
|
160 |
+
lp.create(
|
161 |
+
f"{dedup_filename}",
|
162 |
+
processed_array,
|
163 |
+
row_attrs=new_row_attrs,
|
164 |
+
col_attrs=view.ca,
|
165 |
+
)
|
166 |
first_chunk = False
|
167 |
+
else: # Append subsequent chunks
|
168 |
+
with lp.connect(dedup_filename, mode="r+") as dsout:
|
169 |
dsout.add_columns(processed_array, col_attrs=view.ca)
|
170 |
return dedup_filename
|
171 |
+
|
172 |
elif file_format == "h5ad":
|
173 |
"""
|
174 |
Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
|
|
|
177 |
|
178 |
data = sc.read_h5ad(str(data_directory))
|
179 |
|
180 |
+
assert (
|
181 |
+
"ensembl_id" in data.var.columns
|
182 |
+
), "'ensembl_id' column missing from data.var"
|
183 |
+
gene_ids_in_dict = [
|
184 |
+
gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
|
185 |
+
]
|
186 |
if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
|
187 |
token_genes_unique = True
|
188 |
else:
|
|
|
193 |
else:
|
194 |
raise ValueError("Error: data Ensembl IDs non-unique.")
|
195 |
|
196 |
+
gene_ids_collapsed = [
|
197 |
+
gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
|
198 |
+
]
|
199 |
+
gene_ids_collapsed_in_dict = [
|
200 |
+
gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
|
201 |
+
]
|
202 |
+
if (
|
203 |
+
len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))
|
204 |
+
) and token_genes_unique:
|
205 |
return data
|
206 |
+
|
207 |
else:
|
208 |
data.var["gene_ids_collapsed"] = gene_ids_collapsed
|
209 |
data.var_names = gene_ids_collapsed
|
210 |
data = data[:, ~data.var.index.isna()]
|
211 |
+
dup_genes = [
|
212 |
+
idx for idx, count in Counter(data.var_names).items() if count > 1
|
213 |
+
]
|
214 |
|
215 |
num_chunks = int(np.ceil(data.shape[0] / chunk_size))
|
216 |
|
217 |
processed_genes = []
|
218 |
for i in tqdm(range(num_chunks)):
|
|
|
219 |
start_idx = i * chunk_size
|
220 |
end_idx = min((i + 1) * chunk_size, data.shape[0])
|
221 |
data_chunk = data[start_idx:end_idx, :]
|
|
|
223 |
processed_chunks = []
|
224 |
for dup_gene in dup_genes:
|
225 |
data_dup_gene = data_chunk[:, data_chunk.var_names == dup_gene]
|
226 |
+
df = pd.DataFrame.sparse.from_spmatrix(
|
227 |
+
data_dup_gene.X,
|
228 |
+
index=data_dup_gene.obs_names,
|
229 |
+
columns=data_dup_gene.var_names,
|
230 |
+
)
|
231 |
df_sum = pd.DataFrame(df.sum(axis=1))
|
232 |
+
df_sum.columns = [dup_gene]
|
233 |
df_sum.index = data_dup_gene.obs.index
|
234 |
processed_chunks.append(df_sum)
|
235 |
|
236 |
+
processed_chunks = pd.concat(processed_chunks, axis=1)
|
237 |
processed_genes.append(processed_chunks)
|
238 |
+
processed_genes = pd.concat(processed_genes, axis=0)
|
239 |
+
var_df = pd.DataFrame({"gene_ids_collapsed": processed_genes.columns})
|
240 |
var_df.index = processed_genes.columns
|
241 |
+
processed_genes = sc.AnnData(X=processed_genes, obs=data.obs, var=var_df)
|
|
|
|
|
242 |
|
243 |
+
data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
|
244 |
+
data_dedup = sc.concat([data_dedup, processed_genes], axis=1)
|
245 |
data_dedup.obs = data.obs
|
246 |
+
data_dedup.var = data_dedup.var.rename(
|
247 |
+
columns={"gene_ids_collapsed": "ensembl_id"}
|
248 |
+
)
|
249 |
return data_dedup
|
250 |
|
251 |
+
|
252 |
class TranscriptomeTokenizer:
|
253 |
def __init__(
|
254 |
self,
|
|
|
313 |
|
314 |
# check for special token in gene_token_dict
|
315 |
if self.special_token:
|
316 |
+
if ("<cls>" not in self.gene_token_dict.keys()) and (
|
317 |
+
"<eos>" not in self.gene_token_dict.keys()
|
318 |
+
):
|
319 |
logger.error(
|
320 |
+
"<cls> and <eos> required in gene_token_dict when special_token = True."
|
321 |
+
)
|
322 |
raise
|
323 |
|
324 |
# if collapsing duplicate gene IDs
|
|
|
329 |
with open(gene_mapping_file, "rb") as f:
|
330 |
self.gene_mapping_dict = pickle.load(f)
|
331 |
else:
|
332 |
+
self.gene_mapping_dict = {k: k for k, _ in self.gene_token_dict.items()}
|
333 |
|
334 |
# gene keys for full vocabulary
|
335 |
self.gene_keys = list(self.gene_token_dict.keys())
|
336 |
|
337 |
# Filter gene mapping dict for items that exist in gene_token_dict
|
338 |
gene_keys_set = set(self.gene_token_dict.keys())
|
339 |
+
self.gene_mapping_dict = {
|
340 |
+
k: v for k, v in self.gene_mapping_dict.items() if v in gene_keys_set
|
341 |
+
}
|
342 |
|
343 |
# protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
|
344 |
self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
|
|
|
414 |
return tokenized_cells, cell_metadata
|
415 |
|
416 |
def tokenize_anndata(self, adata_file_path, target_sum=10_000):
|
417 |
+
adata = sum_ensembl_ids(
|
418 |
+
adata_file_path,
|
419 |
+
self.collapse_gene_ids,
|
420 |
+
self.gene_mapping_dict,
|
421 |
+
self.gene_token_dict,
|
422 |
+
file_format="h5ad",
|
423 |
+
chunk_size=self.chunk_size,
|
424 |
+
)
|
425 |
|
426 |
if self.custom_attr_name_dict is not None:
|
427 |
file_cell_metadata = {
|
|
|
463 |
idx = filter_pass_loc[i : i + self.chunk_size]
|
464 |
|
465 |
n_counts = adata[idx].obs["n_counts"].values[:, None]
|
466 |
+
X_view0 = adata[idx, :].X
|
467 |
X_view = X_view0[:, coding_miRNA_loc]
|
468 |
X_norm = X_view / n_counts * target_sum / norm_factor_vector
|
469 |
X_norm = sp.csr_matrix(X_norm)
|
|
|
489 |
}
|
490 |
|
491 |
dedup_filename = loom_file_path.with_name(loom_file_path.stem + "__dedup.loom")
|
492 |
+
loom_file_path = sum_ensembl_ids(
|
493 |
+
loom_file_path,
|
494 |
+
self.collapse_gene_ids,
|
495 |
+
self.gene_mapping_dict,
|
496 |
+
self.gene_token_dict,
|
497 |
+
file_format="loom",
|
498 |
+
chunk_size=self.chunk_size,
|
499 |
+
)
|
500 |
|
501 |
with lp.connect(str(loom_file_path)) as data:
|
502 |
# define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
|
|
|
617 |
output_dataset_truncated = output_dataset.map(
|
618 |
format_cell_features, num_proc=self.nproc
|
619 |
)
|
620 |
+
return output_dataset_truncated
|