ctheodoris commited on
Commit
f07bfd7
·
1 Parent(s): 933ca80

precommit formatting

Browse files
.gitattributes CHANGED
@@ -18,7 +18,7 @@
18
  *.pt filter=lfs diff=lfs merge=lfs -text
19
  *.pth filter=lfs diff=lfs merge=lfs -text
20
  *.rar filter=lfs diff=lfs merge=lfs -text
21
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
22
  *.tar.* filter=lfs diff=lfs merge=lfs -text
23
  *.tflite filter=lfs diff=lfs merge=lfs -text
24
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
18
  *.pt filter=lfs diff=lfs merge=lfs -text
19
  *.pth filter=lfs diff=lfs merge=lfs -text
20
  *.rar filter=lfs diff=lfs merge=lfs -text
21
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
22
  *.tar.* filter=lfs diff=lfs merge=lfs -text
23
  *.tflite filter=lfs diff=lfs merge=lfs -text
24
  *.tgz filter=lfs diff=lfs merge=lfs -text
examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py CHANGED
@@ -138,7 +138,9 @@ training_args = {
138
  "per_device_train_batch_size": geneformer_batch_size,
139
  "num_train_epochs": epochs,
140
  "save_strategy": "steps",
141
- "save_steps": np.floor(num_examples / geneformer_batch_size / 8), # 8 saves per epoch
 
 
142
  "logging_steps": 1000,
143
  "output_dir": training_output_dir,
144
  "logging_dir": logging_dir,
 
138
  "per_device_train_batch_size": geneformer_batch_size,
139
  "num_train_epochs": epochs,
140
  "save_strategy": "steps",
141
+ "save_steps": np.floor(
142
+ num_examples / geneformer_batch_size / 8
143
+ ), # 8 saves per epoch
144
  "logging_steps": 1000,
145
  "output_dir": training_output_dir,
146
  "logging_dir": logging_dir,
geneformer/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
  # ruff: noqa: F401
2
- from pathlib import Path
3
  import warnings
 
 
4
  warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
5
 
6
  GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
@@ -30,4 +31,4 @@ from . import classifier # noqa # isort:skip
30
  from .classifier import Classifier # noqa # isort:skip
31
 
32
  from . import mtl_classifier # noqa # isort:skip
33
- from .mtl_classifier import MTLClassifier # noqa # isort:skip
 
1
  # ruff: noqa: F401
 
2
  import warnings
3
+ from pathlib import Path
4
+
5
  warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
6
 
7
  GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
 
31
  from .classifier import Classifier # noqa # isort:skip
32
 
33
  from . import mtl_classifier # noqa # isort:skip
34
+ from .mtl_classifier import MTLClassifier # noqa # isort:skip
geneformer/classifier.py CHANGED
@@ -57,11 +57,14 @@ from tqdm.auto import tqdm, trange
57
  from transformers import Trainer
58
  from transformers.training_args import TrainingArguments
59
 
60
- from . import DataCollatorForCellClassification, DataCollatorForGeneClassification
 
 
 
 
61
  from . import classifier_utils as cu
62
  from . import evaluation_utils as eu
63
  from . import perturber_utils as pu
64
- from . import TOKEN_DICTIONARY_FILE
65
 
66
  sns.set()
67
 
@@ -413,7 +416,7 @@ class Classifier:
413
  "Column name 'labels' must be reserved for class IDs. Please rename column."
414
  )
415
  raise
416
-
417
  if (attr_to_split is not None) and (attr_to_balance is None):
418
  logger.error(
419
  "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
@@ -548,17 +551,19 @@ class Classifier:
548
  gene_balance : None, bool
549
  | Whether to automatically balance genes in training set.
550
  | Only available for binary gene classifications.
551
-
552
  **Output**
553
 
554
  Returns trainer after fine-tuning with all data.
555
 
556
  """
557
 
558
- if (gene_balance is True) and (len(self.gene_class_dict.values())!=2):
559
- logger.error("Automatically balancing gene sets for training is only available for binary gene classifications.")
 
 
560
  raise
561
-
562
  ##### Load data and prepare output directory #####
563
  # load numerical id to class dictionary (id:class)
564
  with open(id_class_dict_file, "rb") as f:
@@ -679,11 +684,13 @@ class Classifier:
679
  if self.num_crossval_splits == 0:
680
  logger.error("num_crossval_splits must be 1 or 5 to validate.")
681
  raise
682
-
683
- if (gene_balance is True) and (len(self.gene_class_dict.values())!=2):
684
- logger.error("Automatically balancing gene sets for training is only available for binary gene classifications.")
 
 
685
  raise
686
-
687
  # ensure number of genes in each class is > 5 if validating model
688
  if self.classifier == "gene":
689
  insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
@@ -845,17 +852,18 @@ class Classifier:
845
  self.nproc,
846
  gene_balance,
847
  )
848
-
849
  if save_gene_split_datasets is True:
850
  for split_name in ["train", "valid"]:
851
  labeled_dataset_output_path = (
852
- Path(output_dir) / f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
853
- ).with_suffix(".dataset")
 
854
  if split_name == "train":
855
  train_data.save_to_disk(str(labeled_dataset_output_path))
856
  elif split_name == "valid":
857
  eval_data.save_to_disk(str(labeled_dataset_output_path))
858
-
859
  if self.oos_test_size > 0:
860
  test_data = cu.prep_gene_classifier_split(
861
  data,
@@ -869,11 +877,14 @@ class Classifier:
869
  )
870
  if save_gene_split_datasets is True:
871
  test_labeled_dataset_output_path = (
872
- Path(output_dir) / f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
873
- ).with_suffix(".dataset")
 
874
  test_data.save_to_disk(str(test_labeled_dataset_output_path))
875
  if debug_gene_split_datasets is True:
876
- logger.error("Exiting after saving gene split datasets given debug_gene_split_datasets = True.")
 
 
877
  raise
878
  if n_hyperopt_trials == 0:
879
  trainer = self.train_classifier(
@@ -1023,7 +1034,13 @@ class Classifier:
1023
  subprocess.call(f"mkdir {output_directory}", shell=True)
1024
 
1025
  ##### Load model and training args #####
1026
- model = pu.load_model(self.model_type, num_classes, model_directory, "train", quantize=self.quantize)
 
 
 
 
 
 
1027
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1028
  model, self.classifier, train_data, output_directory
1029
  )
@@ -1047,14 +1064,22 @@ class Classifier:
1047
  ##### Fine-tune the model #####
1048
  # define the data collator
1049
  if self.classifier == "cell":
1050
- data_collator = DataCollatorForCellClassification(token_dictionary=self.token_dictionary)
 
 
1051
  elif self.classifier == "gene":
1052
- data_collator = DataCollatorForGeneClassification(token_dictionary=self.token_dictionary)
 
 
1053
 
1054
  # define function to initiate model
1055
  def model_init():
1056
  model = pu.load_model(
1057
- self.model_type, num_classes, model_directory, "train", quantize=self.quantize
 
 
 
 
1058
  )
1059
 
1060
  if self.freeze_layers is not None:
@@ -1180,7 +1205,13 @@ class Classifier:
1180
  subprocess.call(f"mkdir {output_directory}", shell=True)
1181
 
1182
  ##### Load model and training args #####
1183
- model = pu.load_model(self.model_type, num_classes, model_directory, "train", quantize=self.quantize)
 
 
 
 
 
 
1184
 
1185
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1186
  model, self.classifier, train_data, output_directory
@@ -1210,9 +1241,13 @@ class Classifier:
1210
  ##### Fine-tune the model #####
1211
  # define the data collator
1212
  if self.classifier == "cell":
1213
- data_collator = DataCollatorForCellClassification(token_dictionary=self.token_dictionary)
 
 
1214
  elif self.classifier == "gene":
1215
- data_collator = DataCollatorForGeneClassification(token_dictionary=self.token_dictionary)
 
 
1216
 
1217
  # create the trainer
1218
  trainer = Trainer(
@@ -1334,7 +1369,13 @@ class Classifier:
1334
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1335
 
1336
  # load previously fine-tuned model
1337
- model = pu.load_model(self.model_type, num_classes, model_directory, "eval", quantize=self.quantize)
 
 
 
 
 
 
1338
 
1339
  # evaluate the model
1340
  result = self.evaluate_model(
 
57
  from transformers import Trainer
58
  from transformers.training_args import TrainingArguments
59
 
60
+ from . import (
61
+ TOKEN_DICTIONARY_FILE,
62
+ DataCollatorForCellClassification,
63
+ DataCollatorForGeneClassification,
64
+ )
65
  from . import classifier_utils as cu
66
  from . import evaluation_utils as eu
67
  from . import perturber_utils as pu
 
68
 
69
  sns.set()
70
 
 
416
  "Column name 'labels' must be reserved for class IDs. Please rename column."
417
  )
418
  raise
419
+
420
  if (attr_to_split is not None) and (attr_to_balance is None):
421
  logger.error(
422
  "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
 
551
  gene_balance : None, bool
552
  | Whether to automatically balance genes in training set.
553
  | Only available for binary gene classifications.
554
+
555
  **Output**
556
 
557
  Returns trainer after fine-tuning with all data.
558
 
559
  """
560
 
561
+ if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
562
+ logger.error(
563
+ "Automatically balancing gene sets for training is only available for binary gene classifications."
564
+ )
565
  raise
566
+
567
  ##### Load data and prepare output directory #####
568
  # load numerical id to class dictionary (id:class)
569
  with open(id_class_dict_file, "rb") as f:
 
684
  if self.num_crossval_splits == 0:
685
  logger.error("num_crossval_splits must be 1 or 5 to validate.")
686
  raise
687
+
688
+ if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
689
+ logger.error(
690
+ "Automatically balancing gene sets for training is only available for binary gene classifications."
691
+ )
692
  raise
693
+
694
  # ensure number of genes in each class is > 5 if validating model
695
  if self.classifier == "gene":
696
  insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
 
852
  self.nproc,
853
  gene_balance,
854
  )
855
+
856
  if save_gene_split_datasets is True:
857
  for split_name in ["train", "valid"]:
858
  labeled_dataset_output_path = (
859
+ Path(output_dir)
860
+ / f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
861
+ ).with_suffix(".dataset")
862
  if split_name == "train":
863
  train_data.save_to_disk(str(labeled_dataset_output_path))
864
  elif split_name == "valid":
865
  eval_data.save_to_disk(str(labeled_dataset_output_path))
866
+
867
  if self.oos_test_size > 0:
868
  test_data = cu.prep_gene_classifier_split(
869
  data,
 
877
  )
878
  if save_gene_split_datasets is True:
879
  test_labeled_dataset_output_path = (
880
+ Path(output_dir)
881
+ / f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
882
+ ).with_suffix(".dataset")
883
  test_data.save_to_disk(str(test_labeled_dataset_output_path))
884
  if debug_gene_split_datasets is True:
885
+ logger.error(
886
+ "Exiting after saving gene split datasets given debug_gene_split_datasets = True."
887
+ )
888
  raise
889
  if n_hyperopt_trials == 0:
890
  trainer = self.train_classifier(
 
1034
  subprocess.call(f"mkdir {output_directory}", shell=True)
1035
 
1036
  ##### Load model and training args #####
1037
+ model = pu.load_model(
1038
+ self.model_type,
1039
+ num_classes,
1040
+ model_directory,
1041
+ "train",
1042
+ quantize=self.quantize,
1043
+ )
1044
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1045
  model, self.classifier, train_data, output_directory
1046
  )
 
1064
  ##### Fine-tune the model #####
1065
  # define the data collator
1066
  if self.classifier == "cell":
1067
+ data_collator = DataCollatorForCellClassification(
1068
+ token_dictionary=self.token_dictionary
1069
+ )
1070
  elif self.classifier == "gene":
1071
+ data_collator = DataCollatorForGeneClassification(
1072
+ token_dictionary=self.token_dictionary
1073
+ )
1074
 
1075
  # define function to initiate model
1076
  def model_init():
1077
  model = pu.load_model(
1078
+ self.model_type,
1079
+ num_classes,
1080
+ model_directory,
1081
+ "train",
1082
+ quantize=self.quantize,
1083
  )
1084
 
1085
  if self.freeze_layers is not None:
 
1205
  subprocess.call(f"mkdir {output_directory}", shell=True)
1206
 
1207
  ##### Load model and training args #####
1208
+ model = pu.load_model(
1209
+ self.model_type,
1210
+ num_classes,
1211
+ model_directory,
1212
+ "train",
1213
+ quantize=self.quantize,
1214
+ )
1215
 
1216
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1217
  model, self.classifier, train_data, output_directory
 
1241
  ##### Fine-tune the model #####
1242
  # define the data collator
1243
  if self.classifier == "cell":
1244
+ data_collator = DataCollatorForCellClassification(
1245
+ token_dictionary=self.token_dictionary
1246
+ )
1247
  elif self.classifier == "gene":
1248
+ data_collator = DataCollatorForGeneClassification(
1249
+ token_dictionary=self.token_dictionary
1250
+ )
1251
 
1252
  # create the trainer
1253
  trainer = Trainer(
 
1369
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1370
 
1371
  # load previously fine-tuned model
1372
+ model = pu.load_model(
1373
+ self.model_type,
1374
+ num_classes,
1375
+ model_directory,
1376
+ "eval",
1377
+ quantize=self.quantize,
1378
+ )
1379
 
1380
  # evaluate the model
1381
  result = self.evaluate_model(
geneformer/classifier_utils.py CHANGED
@@ -137,22 +137,53 @@ def label_gene_classes(example, class_id_dict, gene_class_dict):
137
 
138
 
139
  def prep_gene_classifier_train_eval_split(
140
- data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc, balance=False
 
 
 
 
 
 
 
 
141
  ):
142
  # generate cross-validation splits
143
  train_data = prep_gene_classifier_split(
144
- data, targets, labels, train_index, "train", max_ncells, iteration_num, num_proc, balance
 
 
 
 
 
 
 
 
145
  )
146
  eval_data = prep_gene_classifier_split(
147
- data, targets, labels, eval_index, "eval", max_ncells, iteration_num, num_proc, balance
 
 
 
 
 
 
 
 
148
  )
149
  return train_data, eval_data
150
 
151
 
152
  def prep_gene_classifier_split(
153
- data, targets, labels, index, subset_name, max_ncells, iteration_num, num_proc, balance=False
 
 
 
 
 
 
 
 
154
  ):
155
-
156
  # generate cross-validation splits
157
  targets = np.array(targets)
158
  labels = np.array(labels)
@@ -175,7 +206,9 @@ def prep_gene_classifier_split(
175
 
176
  # balance gene subsets if train
177
  if (subset_name == "train") and (balance is True):
178
- subset_data, label_dict_subset = balance_gene_split(subset_data, label_dict_subset, num_proc)
 
 
179
 
180
  # subsample to max_ncells
181
  subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
@@ -192,7 +225,9 @@ def prep_gene_classifier_split(
192
  return subset_data
193
 
194
 
195
- def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc, balance=False):
 
 
196
  targets = np.array(targets)
197
  labels = np.array(labels)
198
  label_dict_train = dict(zip(targets, labels))
@@ -211,8 +246,10 @@ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc, b
211
  )
212
 
213
  if balance is True:
214
- train_data, label_dict_train = balance_gene_split(train_data, label_dict_train, num_proc)
215
-
 
 
216
  # subsample to max_ncells
217
  train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
218
 
@@ -230,62 +267,97 @@ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc, b
230
 
231
  def balance_gene_split(subset_data, label_dict_subset, num_proc):
232
  # count occurrence of genes in each label category
233
- label0_counts, label1_counts = count_genes_for_balancing(subset_data, label_dict_subset, num_proc)
234
- label_ratio_0to1 = label0_counts/label1_counts
235
-
236
- if 8/10 <= label_ratio_0to1 <= 10/8:
 
 
237
  # gene sets already balanced
238
  logger.info(
239
  "Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
240
  )
241
  return subset_data, label_dict_subset
242
  else:
243
- label_ratio_0to1_orig = label_ratio_0to1+0
244
  label_dict_subset_orig = label_dict_subset.copy()
245
  # balance gene sets
246
  max_ntrials = 25
247
  boost = 1
248
- if label_ratio_0to1 > 10/8:
249
  # downsample label 0
250
  for i in range(max_ntrials):
251
  label0 = 0
252
- label0_genes = [k for k,v in label_dict_subset.items() if v == label0]
253
  label0_ngenes = len(label0_genes)
254
- label0_nremove = max(1,int(np.floor(label0_ngenes - label0_ngenes/(label_ratio_0to1*boost))))
 
 
 
 
 
 
 
255
  random.seed(i)
256
  label0_remove_genes = random.sample(label0_genes, label0_nremove)
257
- label_dict_subset_new = {k:v for k,v in label_dict_subset.items() if k not in label0_remove_genes}
258
- label0_counts, label1_counts = count_genes_for_balancing(subset_data, label_dict_subset_new, num_proc)
259
- label_ratio_0to1 = label0_counts/label1_counts
260
- if 8/10 <= label_ratio_0to1 <= 10/8:
 
 
 
 
 
 
261
  # if gene sets now balanced, return new filtered data and new label_dict_subset
262
- return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
263
- elif label_ratio_0to1 > 10/8:
264
- boost = boost*1.1
265
- elif label_ratio_0to1 < 8/10:
266
- boost = boost*0.9
 
 
267
  else:
268
  # downsample label 1
269
  for i in range(max_ntrials):
270
  label1 = 1
271
- label1_genes = [k for k,v in label_dict_subset.items() if v == label1]
272
  label1_ngenes = len(label1_genes)
273
- label1_nremove = max(1,int(np.floor(label1_ngenes - label1_ngenes/((1/label_ratio_0to1)*boost))))
 
 
 
 
 
 
 
 
274
  random.seed(i)
275
  label1_remove_genes = random.sample(label1_genes, label1_nremove)
276
- label_dict_subset_new = {k:v for k,v in label_dict_subset.items() if k not in label1_remove_genes}
277
- label0_counts, label1_counts = count_genes_for_balancing(subset_data, label_dict_subset_new, num_proc)
278
- label_ratio_0to1 = label0_counts/label1_counts
279
- if 8/10 <= label_ratio_0to1 <= 10/8:
 
 
 
 
 
 
280
  # if gene sets now balanced, return new filtered data and new label_dict_subset
281
- return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
282
- elif label_ratio_0to1 < 8/10:
283
- boost = boost*1.1
284
- elif label_ratio_0to1 > 10/8:
285
- boost = boost*0.9
286
-
287
- assert i+1 == max_ntrials
288
- if (label_ratio_0to1 <= label_ratio_0to1_orig < 8/10) or (10/8 > label_ratio_0to1_orig >= label_ratio_0to1):
 
 
 
 
289
  label_ratio_0to1 = label_ratio_0to1_orig
290
  label_dict_subset_new = label_dict_subset_orig
291
  logger.warning(
@@ -301,11 +373,11 @@ def count_genes_for_balancing(subset_data, label_dict_subset, num_proc):
301
  ]
302
  counter_labels = Counter(labels)
303
  # get count of labels 0 or 1, or if absent, return 0
304
- example["labels_counts"] = [counter_labels.get(0,0),counter_labels.get(1,0)]
305
  return example
306
-
307
  subset_data = subset_data.map(count_targets, num_proc=num_proc)
308
-
309
  label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
310
  label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
311
 
 
137
 
138
 
139
  def prep_gene_classifier_train_eval_split(
140
+ data,
141
+ targets,
142
+ labels,
143
+ train_index,
144
+ eval_index,
145
+ max_ncells,
146
+ iteration_num,
147
+ num_proc,
148
+ balance=False,
149
  ):
150
  # generate cross-validation splits
151
  train_data = prep_gene_classifier_split(
152
+ data,
153
+ targets,
154
+ labels,
155
+ train_index,
156
+ "train",
157
+ max_ncells,
158
+ iteration_num,
159
+ num_proc,
160
+ balance,
161
  )
162
  eval_data = prep_gene_classifier_split(
163
+ data,
164
+ targets,
165
+ labels,
166
+ eval_index,
167
+ "eval",
168
+ max_ncells,
169
+ iteration_num,
170
+ num_proc,
171
+ balance,
172
  )
173
  return train_data, eval_data
174
 
175
 
176
  def prep_gene_classifier_split(
177
+ data,
178
+ targets,
179
+ labels,
180
+ index,
181
+ subset_name,
182
+ max_ncells,
183
+ iteration_num,
184
+ num_proc,
185
+ balance=False,
186
  ):
 
187
  # generate cross-validation splits
188
  targets = np.array(targets)
189
  labels = np.array(labels)
 
206
 
207
  # balance gene subsets if train
208
  if (subset_name == "train") and (balance is True):
209
+ subset_data, label_dict_subset = balance_gene_split(
210
+ subset_data, label_dict_subset, num_proc
211
+ )
212
 
213
  # subsample to max_ncells
214
  subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
 
225
  return subset_data
226
 
227
 
228
+ def prep_gene_classifier_all_data(
229
+ data, targets, labels, max_ncells, num_proc, balance=False
230
+ ):
231
  targets = np.array(targets)
232
  labels = np.array(labels)
233
  label_dict_train = dict(zip(targets, labels))
 
246
  )
247
 
248
  if balance is True:
249
+ train_data, label_dict_train = balance_gene_split(
250
+ train_data, label_dict_train, num_proc
251
+ )
252
+
253
  # subsample to max_ncells
254
  train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
255
 
 
267
 
268
  def balance_gene_split(subset_data, label_dict_subset, num_proc):
269
  # count occurrence of genes in each label category
270
+ label0_counts, label1_counts = count_genes_for_balancing(
271
+ subset_data, label_dict_subset, num_proc
272
+ )
273
+ label_ratio_0to1 = label0_counts / label1_counts
274
+
275
+ if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
276
  # gene sets already balanced
277
  logger.info(
278
  "Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
279
  )
280
  return subset_data, label_dict_subset
281
  else:
282
+ label_ratio_0to1_orig = label_ratio_0to1 + 0
283
  label_dict_subset_orig = label_dict_subset.copy()
284
  # balance gene sets
285
  max_ntrials = 25
286
  boost = 1
287
+ if label_ratio_0to1 > 10 / 8:
288
  # downsample label 0
289
  for i in range(max_ntrials):
290
  label0 = 0
291
+ label0_genes = [k for k, v in label_dict_subset.items() if v == label0]
292
  label0_ngenes = len(label0_genes)
293
+ label0_nremove = max(
294
+ 1,
295
+ int(
296
+ np.floor(
297
+ label0_ngenes - label0_ngenes / (label_ratio_0to1 * boost)
298
+ )
299
+ ),
300
+ )
301
  random.seed(i)
302
  label0_remove_genes = random.sample(label0_genes, label0_nremove)
303
+ label_dict_subset_new = {
304
+ k: v
305
+ for k, v in label_dict_subset.items()
306
+ if k not in label0_remove_genes
307
+ }
308
+ label0_counts, label1_counts = count_genes_for_balancing(
309
+ subset_data, label_dict_subset_new, num_proc
310
+ )
311
+ label_ratio_0to1 = label0_counts / label1_counts
312
+ if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
313
  # if gene sets now balanced, return new filtered data and new label_dict_subset
314
+ return filter_data_balanced_genes(
315
+ subset_data, label_dict_subset_new, num_proc
316
+ )
317
+ elif label_ratio_0to1 > 10 / 8:
318
+ boost = boost * 1.1
319
+ elif label_ratio_0to1 < 8 / 10:
320
+ boost = boost * 0.9
321
  else:
322
  # downsample label 1
323
  for i in range(max_ntrials):
324
  label1 = 1
325
+ label1_genes = [k for k, v in label_dict_subset.items() if v == label1]
326
  label1_ngenes = len(label1_genes)
327
+ label1_nremove = max(
328
+ 1,
329
+ int(
330
+ np.floor(
331
+ label1_ngenes
332
+ - label1_ngenes / ((1 / label_ratio_0to1) * boost)
333
+ )
334
+ ),
335
+ )
336
  random.seed(i)
337
  label1_remove_genes = random.sample(label1_genes, label1_nremove)
338
+ label_dict_subset_new = {
339
+ k: v
340
+ for k, v in label_dict_subset.items()
341
+ if k not in label1_remove_genes
342
+ }
343
+ label0_counts, label1_counts = count_genes_for_balancing(
344
+ subset_data, label_dict_subset_new, num_proc
345
+ )
346
+ label_ratio_0to1 = label0_counts / label1_counts
347
+ if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
348
  # if gene sets now balanced, return new filtered data and new label_dict_subset
349
+ return filter_data_balanced_genes(
350
+ subset_data, label_dict_subset_new, num_proc
351
+ )
352
+ elif label_ratio_0to1 < 8 / 10:
353
+ boost = boost * 1.1
354
+ elif label_ratio_0to1 > 10 / 8:
355
+ boost = boost * 0.9
356
+
357
+ assert i + 1 == max_ntrials
358
+ if (label_ratio_0to1 <= label_ratio_0to1_orig < 8 / 10) or (
359
+ 10 / 8 > label_ratio_0to1_orig >= label_ratio_0to1
360
+ ):
361
  label_ratio_0to1 = label_ratio_0to1_orig
362
  label_dict_subset_new = label_dict_subset_orig
363
  logger.warning(
 
373
  ]
374
  counter_labels = Counter(labels)
375
  # get count of labels 0 or 1, or if absent, return 0
376
+ example["labels_counts"] = [counter_labels.get(0, 0), counter_labels.get(1, 0)]
377
  return example
378
+
379
  subset_data = subset_data.map(count_targets, num_proc=num_proc)
380
+
381
  label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
382
  label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
383
 
geneformer/collator_for_classification.py CHANGED
@@ -3,17 +3,17 @@ Geneformer collator for gene and cell classification.
3
 
4
  Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5
  """
6
- import numpy as np
7
- import pickle
8
- import torch
9
  import warnings
10
  from enum import Enum
11
  from typing import Dict, List, Optional, Union
12
 
 
 
13
  from transformers import (
 
14
  DataCollatorForTokenClassification,
15
  SpecialTokensMixin,
16
- BatchEncoding,
17
  )
18
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
19
  from transformers.utils.generic import _is_tensorflow, _is_torch
@@ -29,6 +29,7 @@ LARGE_INTEGER = int(
29
 
30
  # precollator functions
31
 
 
32
  class ExplicitEnum(Enum):
33
  """
34
  Enum with more explicit error message for missing values.
@@ -41,6 +42,7 @@ class ExplicitEnum(Enum):
41
  % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
42
  )
43
 
 
44
  class TruncationStrategy(ExplicitEnum):
45
  """
46
  Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
@@ -53,7 +55,6 @@ class TruncationStrategy(ExplicitEnum):
53
  DO_NOT_TRUNCATE = "do_not_truncate"
54
 
55
 
56
-
57
  class PaddingStrategy(ExplicitEnum):
58
  """
59
  Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
@@ -65,7 +66,6 @@ class PaddingStrategy(ExplicitEnum):
65
  DO_NOT_PAD = "do_not_pad"
66
 
67
 
68
-
69
  class TensorType(ExplicitEnum):
70
  """
71
  Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
@@ -77,7 +77,7 @@ class TensorType(ExplicitEnum):
77
  NUMPY = "np"
78
  JAX = "jax"
79
 
80
-
81
  class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
82
  def __init__(self, *args, **kwargs) -> None:
83
  super().__init__(mask_token="<mask>", pad_token="<pad>")
@@ -89,11 +89,17 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
89
  self.pad_token_id = self.token_dictionary.get("<pad>")
90
  self.all_special_ids = [
91
  self.token_dictionary.get("<mask>"),
92
- self.token_dictionary.get("<pad>")
93
  ]
94
 
95
  def _get_padding_truncation_strategies(
96
- self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
 
 
 
 
 
 
97
  ):
98
  """
99
  Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
@@ -106,7 +112,9 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
106
  # If you only set max_length, it activates truncation for max_length
107
  if max_length is not None and padding is False and truncation is False:
108
  if verbose:
109
- if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
 
 
110
  logger.warning(
111
  "Truncation was not explicitly activated but `max_length` is provided a specific value, "
112
  "please use `truncation=True` to explicitly truncate examples to max length. "
@@ -134,7 +142,9 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
134
  padding_strategy = PaddingStrategy.MAX_LENGTH
135
  elif padding is not False:
136
  if padding is True:
137
- padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
 
 
138
  elif not isinstance(padding, PaddingStrategy):
139
  padding_strategy = PaddingStrategy(padding)
140
  elif isinstance(padding, PaddingStrategy):
@@ -174,7 +184,9 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
174
  if padding_strategy == PaddingStrategy.MAX_LENGTH:
175
  if self.model_max_length > LARGE_INTEGER:
176
  if verbose:
177
- if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
 
 
178
  logger.warning(
179
  "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
180
  "Default to no padding."
@@ -187,18 +199,24 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
187
  if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
188
  if self.model_max_length > LARGE_INTEGER:
189
  if verbose:
190
- if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
 
 
191
  logger.warning(
192
  "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
193
  "Default to no truncation."
194
  )
195
- self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
 
 
196
  truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
197
  else:
198
  max_length = self.model_max_length
199
 
200
  # Test if we have a padding token
201
- if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
 
 
202
  raise ValueError(
203
  "Asking to pad but the tokenizer does not have a padding token. "
204
  "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
@@ -229,7 +247,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
229
  Dict[str, List[EncodedInput]],
230
  List[Dict[str, EncodedInput]],
231
  ],
232
- class_type, # options: "gene" or "cell"
233
  padding: Union[bool, str, PaddingStrategy] = True,
234
  max_length: Optional[int] = None,
235
  pad_to_multiple_of: Optional[int] = None,
@@ -292,8 +310,13 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
292
  """
293
  # If we have a list of dicts, let's convert it in a dict of lists
294
  # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
295
- if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
296
- encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
 
 
 
 
 
297
 
298
  # The model's main input name, usually `input_ids`, has be passed for padding
299
  if self.model_input_names[0] not in encoded_inputs:
@@ -387,7 +410,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
387
  def _pad(
388
  self,
389
  encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
390
- class_type, # options: "gene" or "cell"
391
  max_length: Optional[int] = None,
392
  padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
393
  pad_to_multiple_of: Optional[int] = None,
@@ -423,46 +446,73 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
423
  if padding_strategy == PaddingStrategy.LONGEST:
424
  max_length = len(required_input)
425
 
426
- if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
 
 
 
 
427
  max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
428
 
429
- needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
 
 
 
430
 
431
  if needs_to_be_padded:
432
  difference = max_length - len(required_input)
433
  if self.padding_side == "right":
434
  if return_attention_mask:
435
- encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
 
 
436
  if "token_type_ids" in encoded_inputs:
437
  encoded_inputs["token_type_ids"] = (
438
- encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
 
439
  )
440
  if "special_tokens_mask" in encoded_inputs:
441
- encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
442
- encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
 
 
 
 
443
  if class_type == "gene":
444
- encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference
 
 
445
  elif self.padding_side == "left":
446
  if return_attention_mask:
447
- encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
 
 
448
  if "token_type_ids" in encoded_inputs:
449
- encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
450
- "token_type_ids"
451
- ]
452
  if "special_tokens_mask" in encoded_inputs:
453
- encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
454
- encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
 
 
 
 
455
  if class_type == "gene":
456
- encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"]
 
 
457
  else:
458
  raise ValueError("Invalid padding strategy:" + str(self.padding_side))
459
  elif return_attention_mask and "attention_mask" not in encoded_inputs:
460
  encoded_inputs["attention_mask"] = [1] * len(required_input)
461
-
462
  return encoded_inputs
463
 
464
  def get_special_tokens_mask(
465
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
 
 
 
466
  ) -> List[int]:
467
  """
468
  Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
@@ -486,11 +536,15 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
486
 
487
  all_special_ids = self.all_special_ids # cache the property
488
 
489
- special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
 
 
490
 
491
  return special_tokens_mask
492
 
493
- def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
 
 
494
  """
495
  Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
496
  vocabulary.
@@ -514,14 +568,15 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
514
  if token is None:
515
  return None
516
 
517
- return token_dictionary.get(token)
518
 
519
  def __len__(self):
520
- return len(token_dictionary)
521
 
522
 
523
  # collator functions
524
 
 
525
  class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
526
  """
527
  Data collator that will dynamically pad the inputs received, as well as the labels.
@@ -546,26 +601,34 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
546
  label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
547
  The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
548
  """
549
-
550
  class_type = "gene"
551
  padding: Union[bool, str, PaddingStrategy] = True
552
  max_length: Optional[int] = None
553
  pad_to_multiple_of: Optional[int] = None
554
  label_pad_token_id: int = -100
555
-
556
  def __init__(self, *args, **kwargs) -> None:
557
  self.token_dictionary = kwargs.pop("token_dictionary")
558
  super().__init__(
559
- tokenizer=PrecollatorForGeneAndCellClassification(token_dictionary=self.token_dictionary),
 
 
560
  padding=self.padding,
561
  max_length=self.max_length,
562
  pad_to_multiple_of=self.pad_to_multiple_of,
563
  label_pad_token_id=self.label_pad_token_id,
564
- *args, **kwargs)
 
 
565
 
566
  def _prepare_batch(self, features):
567
  label_name = "label" if "label" in features[0].keys() else "labels"
568
- labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
 
 
 
 
569
  batch = self.tokenizer.pad(
570
  features,
571
  class_type=self.class_type,
@@ -575,29 +638,31 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
575
  return_tensors="pt",
576
  )
577
  return batch
578
-
579
  def __call__(self, features):
580
  batch = self._prepare_batch(features)
581
 
582
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
583
  return batch
584
 
585
-
586
- class DataCollatorForCellClassification(DataCollatorForGeneClassification):
587
 
 
588
  class_type = "cell"
589
 
590
  def _prepare_batch(self, features):
591
-
592
  batch = super()._prepare_batch(features)
593
-
594
  # Special handling for labels.
595
  # Ensure that tensor is created with the correct type
596
  # (it should be automatically the case, but let's make sure of it.)
597
  first = features[0]
598
  if "label" in first and first["label"] is not None:
599
- label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
 
 
 
 
600
  dtype = torch.long if isinstance(label, int) else torch.float
601
  batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
602
-
603
  return batch
 
3
 
4
  Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5
  """
6
+
 
 
7
  import warnings
8
  from enum import Enum
9
  from typing import Dict, List, Optional, Union
10
 
11
+ import numpy as np
12
+ import torch
13
  from transformers import (
14
+ BatchEncoding,
15
  DataCollatorForTokenClassification,
16
  SpecialTokensMixin,
 
17
  )
18
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
19
  from transformers.utils.generic import _is_tensorflow, _is_torch
 
29
 
30
  # precollator functions
31
 
32
+
33
  class ExplicitEnum(Enum):
34
  """
35
  Enum with more explicit error message for missing values.
 
42
  % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
43
  )
44
 
45
+
46
  class TruncationStrategy(ExplicitEnum):
47
  """
48
  Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
 
55
  DO_NOT_TRUNCATE = "do_not_truncate"
56
 
57
 
 
58
  class PaddingStrategy(ExplicitEnum):
59
  """
60
  Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
 
66
  DO_NOT_PAD = "do_not_pad"
67
 
68
 
 
69
  class TensorType(ExplicitEnum):
70
  """
71
  Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
 
77
  NUMPY = "np"
78
  JAX = "jax"
79
 
80
+
81
  class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
82
  def __init__(self, *args, **kwargs) -> None:
83
  super().__init__(mask_token="<mask>", pad_token="<pad>")
 
89
  self.pad_token_id = self.token_dictionary.get("<pad>")
90
  self.all_special_ids = [
91
  self.token_dictionary.get("<mask>"),
92
+ self.token_dictionary.get("<pad>"),
93
  ]
94
 
95
  def _get_padding_truncation_strategies(
96
+ self,
97
+ padding=True,
98
+ truncation=False,
99
+ max_length=None,
100
+ pad_to_multiple_of=None,
101
+ verbose=True,
102
+ **kwargs,
103
  ):
104
  """
105
  Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
 
112
  # If you only set max_length, it activates truncation for max_length
113
  if max_length is not None and padding is False and truncation is False:
114
  if verbose:
115
+ if not self.deprecation_warnings.get(
116
+ "Truncation-not-explicitly-activated", False
117
+ ):
118
  logger.warning(
119
  "Truncation was not explicitly activated but `max_length` is provided a specific value, "
120
  "please use `truncation=True` to explicitly truncate examples to max length. "
 
142
  padding_strategy = PaddingStrategy.MAX_LENGTH
143
  elif padding is not False:
144
  if padding is True:
145
+ padding_strategy = (
146
+ PaddingStrategy.LONGEST
147
+ ) # Default to pad to the longest sequence in the batch
148
  elif not isinstance(padding, PaddingStrategy):
149
  padding_strategy = PaddingStrategy(padding)
150
  elif isinstance(padding, PaddingStrategy):
 
184
  if padding_strategy == PaddingStrategy.MAX_LENGTH:
185
  if self.model_max_length > LARGE_INTEGER:
186
  if verbose:
187
+ if not self.deprecation_warnings.get(
188
+ "Asking-to-pad-to-max_length", False
189
+ ):
190
  logger.warning(
191
  "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
192
  "Default to no padding."
 
199
  if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
200
  if self.model_max_length > LARGE_INTEGER:
201
  if verbose:
202
+ if not self.deprecation_warnings.get(
203
+ "Asking-to-truncate-to-max_length", False
204
+ ):
205
  logger.warning(
206
  "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
207
  "Default to no truncation."
208
  )
209
+ self.deprecation_warnings[
210
+ "Asking-to-truncate-to-max_length"
211
+ ] = True
212
  truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
213
  else:
214
  max_length = self.model_max_length
215
 
216
  # Test if we have a padding token
217
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
218
+ not self.pad_token or self.pad_token_id < 0
219
+ ):
220
  raise ValueError(
221
  "Asking to pad but the tokenizer does not have a padding token. "
222
  "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
 
247
  Dict[str, List[EncodedInput]],
248
  List[Dict[str, EncodedInput]],
249
  ],
250
+ class_type, # options: "gene" or "cell"
251
  padding: Union[bool, str, PaddingStrategy] = True,
252
  max_length: Optional[int] = None,
253
  pad_to_multiple_of: Optional[int] = None,
 
310
  """
311
  # If we have a list of dicts, let's convert it in a dict of lists
312
  # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
313
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(
314
+ encoded_inputs[0], (dict, BatchEncoding)
315
+ ):
316
+ encoded_inputs = {
317
+ key: [example[key] for example in encoded_inputs]
318
+ for key in encoded_inputs[0].keys()
319
+ }
320
 
321
  # The model's main input name, usually `input_ids`, has be passed for padding
322
  if self.model_input_names[0] not in encoded_inputs:
 
410
  def _pad(
411
  self,
412
  encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
413
+ class_type, # options: "gene" or "cell"
414
  max_length: Optional[int] = None,
415
  padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
416
  pad_to_multiple_of: Optional[int] = None,
 
446
  if padding_strategy == PaddingStrategy.LONGEST:
447
  max_length = len(required_input)
448
 
449
+ if (
450
+ max_length is not None
451
+ and pad_to_multiple_of is not None
452
+ and (max_length % pad_to_multiple_of != 0)
453
+ ):
454
  max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
455
 
456
+ needs_to_be_padded = (
457
+ padding_strategy != PaddingStrategy.DO_NOT_PAD
458
+ and len(required_input) != max_length
459
+ )
460
 
461
  if needs_to_be_padded:
462
  difference = max_length - len(required_input)
463
  if self.padding_side == "right":
464
  if return_attention_mask:
465
+ encoded_inputs["attention_mask"] = [1] * len(required_input) + [
466
+ 0
467
+ ] * difference
468
  if "token_type_ids" in encoded_inputs:
469
  encoded_inputs["token_type_ids"] = (
470
+ encoded_inputs["token_type_ids"]
471
+ + [self.pad_token_type_id] * difference
472
  )
473
  if "special_tokens_mask" in encoded_inputs:
474
+ encoded_inputs["special_tokens_mask"] = (
475
+ encoded_inputs["special_tokens_mask"] + [1] * difference
476
+ )
477
+ encoded_inputs[self.model_input_names[0]] = (
478
+ required_input + [self.pad_token_id] * difference
479
+ )
480
  if class_type == "gene":
481
+ encoded_inputs["labels"] = (
482
+ encoded_inputs["labels"] + [-100] * difference
483
+ )
484
  elif self.padding_side == "left":
485
  if return_attention_mask:
486
+ encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
487
+ required_input
488
+ )
489
  if "token_type_ids" in encoded_inputs:
490
+ encoded_inputs["token_type_ids"] = [
491
+ self.pad_token_type_id
492
+ ] * difference + encoded_inputs["token_type_ids"]
493
  if "special_tokens_mask" in encoded_inputs:
494
+ encoded_inputs["special_tokens_mask"] = [
495
+ 1
496
+ ] * difference + encoded_inputs["special_tokens_mask"]
497
+ encoded_inputs[self.model_input_names[0]] = [
498
+ self.pad_token_id
499
+ ] * difference + required_input
500
  if class_type == "gene":
501
+ encoded_inputs["labels"] = [-100] * difference + encoded_inputs[
502
+ "labels"
503
+ ]
504
  else:
505
  raise ValueError("Invalid padding strategy:" + str(self.padding_side))
506
  elif return_attention_mask and "attention_mask" not in encoded_inputs:
507
  encoded_inputs["attention_mask"] = [1] * len(required_input)
508
+
509
  return encoded_inputs
510
 
511
  def get_special_tokens_mask(
512
+ self,
513
+ token_ids_0: List[int],
514
+ token_ids_1: Optional[List[int]] = None,
515
+ already_has_special_tokens: bool = False,
516
  ) -> List[int]:
517
  """
518
  Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
 
536
 
537
  all_special_ids = self.all_special_ids # cache the property
538
 
539
+ special_tokens_mask = [
540
+ 1 if token in all_special_ids else 0 for token in token_ids_0
541
+ ]
542
 
543
  return special_tokens_mask
544
 
545
+ def convert_tokens_to_ids(
546
+ self, tokens: Union[str, List[str]]
547
+ ) -> Union[int, List[int]]:
548
  """
549
  Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
550
  vocabulary.
 
568
  if token is None:
569
  return None
570
 
571
+ return self.token_dictionary.get(token)
572
 
573
  def __len__(self):
574
+ return len(self.token_dictionary)
575
 
576
 
577
  # collator functions
578
 
579
+
580
  class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
581
  """
582
  Data collator that will dynamically pad the inputs received, as well as the labels.
 
601
  label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
602
  The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
603
  """
604
+
605
  class_type = "gene"
606
  padding: Union[bool, str, PaddingStrategy] = True
607
  max_length: Optional[int] = None
608
  pad_to_multiple_of: Optional[int] = None
609
  label_pad_token_id: int = -100
610
+
611
  def __init__(self, *args, **kwargs) -> None:
612
  self.token_dictionary = kwargs.pop("token_dictionary")
613
  super().__init__(
614
+ tokenizer=PrecollatorForGeneAndCellClassification(
615
+ token_dictionary=self.token_dictionary
616
+ ),
617
  padding=self.padding,
618
  max_length=self.max_length,
619
  pad_to_multiple_of=self.pad_to_multiple_of,
620
  label_pad_token_id=self.label_pad_token_id,
621
+ *args,
622
+ **kwargs,
623
+ )
624
 
625
  def _prepare_batch(self, features):
626
  label_name = "label" if "label" in features[0].keys() else "labels"
627
+ labels = (
628
+ [feature[label_name] for feature in features]
629
+ if label_name in features[0].keys()
630
+ else None
631
+ )
632
  batch = self.tokenizer.pad(
633
  features,
634
  class_type=self.class_type,
 
638
  return_tensors="pt",
639
  )
640
  return batch
641
+
642
  def __call__(self, features):
643
  batch = self._prepare_batch(features)
644
 
645
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
646
  return batch
647
 
 
 
648
 
649
+ class DataCollatorForCellClassification(DataCollatorForGeneClassification):
650
  class_type = "cell"
651
 
652
  def _prepare_batch(self, features):
 
653
  batch = super()._prepare_batch(features)
654
+
655
  # Special handling for labels.
656
  # Ensure that tensor is created with the correct type
657
  # (it should be automatically the case, but let's make sure of it.)
658
  first = features[0]
659
  if "label" in first and first["label"] is not None:
660
+ label = (
661
+ first["label"].item()
662
+ if isinstance(first["label"], torch.Tensor)
663
+ else first["label"]
664
+ )
665
  dtype = torch.long if isinstance(label, int) else torch.float
666
  batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
667
+
668
  return batch
geneformer/emb_extractor.py CHANGED
@@ -24,8 +24,8 @@ import torch
24
  from tdigest import TDigest
25
  from tqdm.auto import trange
26
 
27
- from . import perturber_utils as pu
28
  from . import TOKEN_DICTIONARY_FILE
 
29
 
30
  logger = logging.getLogger(__name__)
31
 
@@ -45,7 +45,7 @@ def get_embs(
45
  ):
46
  model_input_size = pu.get_model_input_size(model)
47
  total_batch_length = len(filtered_input_data)
48
-
49
  if summary_stat is None:
50
  embs_list = []
51
  elif summary_stat is not None:
@@ -73,17 +73,23 @@ def get_embs(
73
  if emb_mode == "cls":
74
  assert cls_present, "<cls> token missing in token dictionary"
75
  # Check to make sure that the first token of the filtered input data is cls token
76
- gene_token_dict = {v:k for k,v in token_gene_dict.items()}
77
  cls_token_id = gene_token_dict["<cls>"]
78
- assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
 
 
79
  elif emb_mode == "cell":
80
  if cls_present:
81
- logger.warning("CLS token present in token dictionary, excluding from average.")
 
 
82
  if eos_present:
83
- logger.warning("EOS token present in token dictionary, excluding from average.")
84
-
 
 
85
  overall_max_len = 0
86
-
87
  for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
88
  max_range = min(i + forward_batch_size, total_batch_length)
89
 
@@ -108,7 +114,7 @@ def get_embs(
108
 
109
  if emb_mode == "cell":
110
  if cls_present:
111
- non_cls_embs = embs_i[:, 1:, :] # Get all layers except the embs
112
  if eos_present:
113
  mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
114
  else:
@@ -146,10 +152,10 @@ def get_embs(
146
  del embs_h
147
  del dict_h
148
  elif emb_mode == "cls":
149
- cls_embs = embs_i[:,0,:].clone().detach() # CLS token layer
150
  embs_list.append(cls_embs)
151
  del cls_embs
152
-
153
  overall_max_len = max(overall_max_len, max_len)
154
  del outputs
155
  del minibatch
@@ -157,8 +163,7 @@ def get_embs(
157
  del embs_i
158
 
159
  torch.cuda.empty_cache()
160
-
161
-
162
  if summary_stat is None:
163
  if (emb_mode == "cell") or (emb_mode == "cls"):
164
  embs_stack = torch.cat(embs_list, dim=0)
@@ -204,6 +209,7 @@ def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
204
  for j in range(emb_dims)
205
  ]
206
 
 
207
  def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
208
  embs_tdigests_dict[gene] = accumulate_tdigests(
209
  embs_tdigests_dict[gene], gene_embs, emb_dims
@@ -294,11 +300,13 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
294
 
295
  with plt.rc_context():
296
  ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
297
- ax.legend(markerscale=2,
298
- frameon=False,
299
- loc="center left",
300
- bbox_to_anchor=(1, 0.5),
301
- ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3))
 
 
302
  plt.show()
303
  plt.savefig(output_file, bbox_inches="tight")
304
 
@@ -394,7 +402,7 @@ class EmbExtractor:
394
  "emb_label": {None, list},
395
  "labels_to_plot": {None, list},
396
  "forward_batch_size": {int},
397
- "token_dictionary_file" : {None, str},
398
  "nproc": {int},
399
  "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
400
  }
@@ -631,13 +639,15 @@ class EmbExtractor:
631
  embs = embs.mean(dim=0)
632
  emb_dims = pu.get_model_emb_dims(model)
633
  embs_df = pd.DataFrame(
634
- embs_df[0:emb_dims-1].mean(axis="rows"), columns=[self.exact_summary_stat]
 
635
  ).T
636
  elif self.exact_summary_stat == "exact_median":
637
  embs = torch.median(embs, dim=0)[0]
638
  emb_dims = pu.get_model_emb_dims(model)
639
  embs_df = pd.DataFrame(
640
- embs_df[0:emb_dims-1].median(axis="rows"), columns=[self.exact_summary_stat]
 
641
  ).T
642
 
643
  if cell_state is not None:
@@ -838,4 +848,4 @@ class EmbExtractor:
838
  output_file = (
839
  Path(output_directory) / output_prefix_label
840
  ).with_suffix(".pdf")
841
- plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
 
24
  from tdigest import TDigest
25
  from tqdm.auto import trange
26
 
 
27
  from . import TOKEN_DICTIONARY_FILE
28
+ from . import perturber_utils as pu
29
 
30
  logger = logging.getLogger(__name__)
31
 
 
45
  ):
46
  model_input_size = pu.get_model_input_size(model)
47
  total_batch_length = len(filtered_input_data)
48
+
49
  if summary_stat is None:
50
  embs_list = []
51
  elif summary_stat is not None:
 
73
  if emb_mode == "cls":
74
  assert cls_present, "<cls> token missing in token dictionary"
75
  # Check to make sure that the first token of the filtered input data is cls token
76
+ gene_token_dict = {v: k for k, v in token_gene_dict.items()}
77
  cls_token_id = gene_token_dict["<cls>"]
78
+ assert (
79
+ filtered_input_data["input_ids"][0][0] == cls_token_id
80
+ ), "First token is not <cls> token value"
81
  elif emb_mode == "cell":
82
  if cls_present:
83
+ logger.warning(
84
+ "CLS token present in token dictionary, excluding from average."
85
+ )
86
  if eos_present:
87
+ logger.warning(
88
+ "EOS token present in token dictionary, excluding from average."
89
+ )
90
+
91
  overall_max_len = 0
92
+
93
  for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
94
  max_range = min(i + forward_batch_size, total_batch_length)
95
 
 
114
 
115
  if emb_mode == "cell":
116
  if cls_present:
117
+ non_cls_embs = embs_i[:, 1:, :] # Get all layers except the embs
118
  if eos_present:
119
  mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
120
  else:
 
152
  del embs_h
153
  del dict_h
154
  elif emb_mode == "cls":
155
+ cls_embs = embs_i[:, 0, :].clone().detach() # CLS token layer
156
  embs_list.append(cls_embs)
157
  del cls_embs
158
+
159
  overall_max_len = max(overall_max_len, max_len)
160
  del outputs
161
  del minibatch
 
163
  del embs_i
164
 
165
  torch.cuda.empty_cache()
166
+
 
167
  if summary_stat is None:
168
  if (emb_mode == "cell") or (emb_mode == "cls"):
169
  embs_stack = torch.cat(embs_list, dim=0)
 
209
  for j in range(emb_dims)
210
  ]
211
 
212
+
213
  def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
214
  embs_tdigests_dict[gene] = accumulate_tdigests(
215
  embs_tdigests_dict[gene], gene_embs, emb_dims
 
300
 
301
  with plt.rc_context():
302
  ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
303
+ ax.legend(
304
+ markerscale=2,
305
+ frameon=False,
306
+ loc="center left",
307
+ bbox_to_anchor=(1, 0.5),
308
+ ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
309
+ )
310
  plt.show()
311
  plt.savefig(output_file, bbox_inches="tight")
312
 
 
402
  "emb_label": {None, list},
403
  "labels_to_plot": {None, list},
404
  "forward_batch_size": {int},
405
+ "token_dictionary_file": {None, str},
406
  "nproc": {int},
407
  "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
408
  }
 
639
  embs = embs.mean(dim=0)
640
  emb_dims = pu.get_model_emb_dims(model)
641
  embs_df = pd.DataFrame(
642
+ embs_df[0 : emb_dims - 1].mean(axis="rows"),
643
+ columns=[self.exact_summary_stat],
644
  ).T
645
  elif self.exact_summary_stat == "exact_median":
646
  embs = torch.median(embs, dim=0)[0]
647
  emb_dims = pu.get_model_emb_dims(model)
648
  embs_df = pd.DataFrame(
649
+ embs_df[0 : emb_dims - 1].median(axis="rows"),
650
+ columns=[self.exact_summary_stat],
651
  ).T
652
 
653
  if cell_state is not None:
 
848
  output_file = (
849
  Path(output_directory) / output_prefix_label
850
  ).with_suffix(".pdf")
851
+ plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
geneformer/evaluation_utils.py CHANGED
@@ -20,8 +20,8 @@ from sklearn.metrics import (
20
  )
21
  from tqdm.auto import trange
22
 
23
- from .emb_extractor import make_colorbar
24
  from . import TOKEN_DICTIONARY_FILE
 
25
 
26
  logger = logging.getLogger(__name__)
27
 
 
20
  )
21
  from tqdm.auto import trange
22
 
 
23
  from . import TOKEN_DICTIONARY_FILE
24
+ from .emb_extractor import make_colorbar
25
 
26
  logger = logging.getLogger(__name__)
27
 
geneformer/in_silico_perturber.py CHANGED
@@ -38,16 +38,15 @@ import logging
38
  import os
39
  import pickle
40
  from collections import defaultdict
41
- from multiprocess import set_start_method
42
- from typing import List
43
 
44
  import torch
45
  from datasets import Dataset, disable_progress_bars
 
46
  from tqdm.auto import trange
47
 
 
48
  from . import perturber_utils as pu
49
  from .emb_extractor import get_embs
50
- from . import TOKEN_DICTIONARY_FILE
51
 
52
  disable_progress_bars()
53
 
@@ -71,7 +70,7 @@ class InSilicoPerturber:
71
  "max_ncells": {None, int},
72
  "cell_inds_to_perturb": {"all", dict},
73
  "emb_layer": {-1, 0},
74
- "token_dictionary_file" : {None, str},
75
  "forward_batch_size": {int},
76
  "nproc": {int},
77
  }
@@ -239,13 +238,14 @@ class InSilicoPerturber:
239
  self.cls_token_id = self.gene_token_dict.get("<cls>")
240
  self.eos_token_id = self.gene_token_dict.get("<eos>")
241
 
242
-
243
  # Identify if special token is present in the token dictionary
244
  if (self.cls_token_id is not None) and (self.eos_token_id is not None):
245
  self.special_token = True
246
  else:
247
  if "cls" in self.emb_mode:
248
- logger.error(f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary.")
 
 
249
  raise
250
  self.special_token = False
251
 
@@ -454,17 +454,21 @@ class InSilicoPerturber:
454
 
455
  # Ensure emb_mode is cls if first token of the filtered input data is cls token
456
  if self.special_token:
457
- if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and ("cls" not in self.emb_mode):
 
 
458
  logger.error(
459
- "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
460
- )
461
  raise
462
- if ("cls" in self.emb_mode):
463
- if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (filtered_input_data["input_ids"][0][-1] != self.eos_token_id):
 
 
464
  logger.error(
465
- "Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
466
- )
467
- raise
468
 
469
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
470
 
@@ -530,7 +534,6 @@ class InSilicoPerturber:
530
  layer_to_quant: int,
531
  output_path_prefix: str,
532
  ):
533
-
534
  def make_group_perturbation_batch(example):
535
  example_input_ids = example["input_ids"]
536
  example["tokens_to_perturb"] = self.tokens_to_perturb
@@ -549,7 +552,9 @@ class InSilicoPerturber:
549
  if self.perturb_type == "delete":
550
  example = pu.delete_indices(example)
551
  elif self.perturb_type == "overexpress":
552
- example = pu.overexpress_tokens(example, self.max_len, self.special_token)
 
 
553
  example["n_overflow"] = pu.calc_n_overflow(
554
  self.max_len,
555
  example["length"],
@@ -570,7 +575,7 @@ class InSilicoPerturber:
570
  perturbed_data = filtered_input_data.map(
571
  make_group_perturbation_batch, num_proc=self.nproc
572
  )
573
-
574
  if self.perturb_type == "overexpress":
575
  filtered_input_data = filtered_input_data.add_column(
576
  "n_overflow", perturbed_data["n_overflow"]
@@ -752,7 +757,6 @@ class InSilicoPerturber:
752
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
753
  )
754
 
755
-
756
  def isp_perturb_set_special(
757
  self,
758
  model,
@@ -760,7 +764,6 @@ class InSilicoPerturber:
760
  layer_to_quant: int,
761
  output_path_prefix: str,
762
  ):
763
-
764
  def make_group_perturbation_batch(example):
765
  example_input_ids = example["input_ids"]
766
  example["tokens_to_perturb"] = self.tokens_to_perturb
@@ -779,7 +782,9 @@ class InSilicoPerturber:
779
  if self.perturb_type == "delete":
780
  example = pu.delete_indices(example)
781
  elif self.perturb_type == "overexpress":
782
- example = pu.overexpress_tokens(example, self.max_len, self.special_token)
 
 
783
  example["n_overflow"] = pu.calc_n_overflow(
784
  self.max_len,
785
  example["length"],
@@ -808,7 +813,7 @@ class InSilicoPerturber:
808
  filtered_input_data = filtered_input_data.map(
809
  pu.truncate_by_n_overflow_special, num_proc=self.nproc
810
  )
811
-
812
  if self.emb_mode == "cls_and_gene":
813
  stored_gene_embs_dict = defaultdict(list)
814
 
@@ -847,28 +852,29 @@ class InSilicoPerturber:
847
  summary_stat=None,
848
  silent=True,
849
  )
850
-
851
- # Calculate the cosine similarities
852
  cls_cos_sims = pu.quant_cos_sims(
853
  perturbation_cls_emb,
854
  original_cls_emb,
855
  self.cell_states_to_model,
856
  self.state_embs_dict,
857
- emb_mode="cell")
858
-
 
859
  # Update perturbation dictionary
860
  if self.cell_states_to_model is None:
861
  cos_sims_dict = self.update_perturbation_dictionary(
862
  cos_sims_dict,
863
  cls_cos_sims,
864
- gene_list = None,
865
  )
866
  else:
867
  for state in cos_sims_dict.keys():
868
  cos_sims_dict[state] = self.update_perturbation_dictionary(
869
  cos_sims_dict[state],
870
  cls_cos_sims[state],
871
- gene_list = None,
872
  )
873
 
874
  ##### CLS and Gene Embedding Mode #####
@@ -908,9 +914,13 @@ class InSilicoPerturber:
908
  # remove special tokens and padding
909
  original_emb = original_emb[:, 1:-1, :]
910
  if self.perturb_type == "overexpress":
911
- perturbation_emb = full_perturbation_emb[:,1+len(self.tokens_to_perturb):-1,:]
 
 
912
  elif self.perturb_type == "delete":
913
- perturbation_emb = full_perturbation_emb[:,1:max(perturbation_batch["length"])-1,:]
 
 
914
 
915
  n_perturbation_genes = perturbation_emb.size()[1]
916
 
@@ -923,8 +933,8 @@ class InSilicoPerturber:
923
  )
924
 
925
  # get cls emb
926
- original_cls_emb = full_original_emb[:,0,:]
927
- perturbation_cls_emb = full_perturbation_emb[:,0,:]
928
 
929
  cls_cos_sims = pu.quant_cos_sims(
930
  perturbation_cls_emb,
@@ -939,7 +949,10 @@ class InSilicoPerturber:
939
 
940
  gene_list = minibatch["input_ids"]
941
  # need to truncate gene_list
942
- genes_to_exclude = self.tokens_to_perturb + [self.cls_token_id, self.eos_token_id]
 
 
 
943
  gene_list = [
944
  [g for g in genes if g not in genes_to_exclude][
945
  :n_perturbation_genes
@@ -968,20 +981,20 @@ class InSilicoPerturber:
968
  cos_sims_dict = self.update_perturbation_dictionary(
969
  cos_sims_dict,
970
  cls_cos_sims,
971
- gene_list = None,
972
  )
973
  else:
974
  for state in cos_sims_dict.keys():
975
  cos_sims_dict[state] = self.update_perturbation_dictionary(
976
  cos_sims_dict[state],
977
  cls_cos_sims[state],
978
- gene_list = None,
979
  )
980
  del full_original_emb
981
  del original_emb
982
  del full_perturbation_emb
983
- del perturbation_emb
984
- del gene_cos_sims
985
 
986
  del original_cls_emb
987
  del perturbation_cls_emb
@@ -1035,12 +1048,12 @@ class InSilicoPerturber:
1035
  summary_stat=None,
1036
  silent=True,
1037
  )
1038
-
1039
  if self.cell_states_to_model is not None:
1040
  original_cell_emb = pu.compute_nonpadded_cell_embedding(
1041
  full_original_emb, "mean_pool"
1042
  )
1043
-
1044
  # gene_list is used to assign cos sims back to genes
1045
  gene_list = example_cell["input_ids"][0][:]
1046
  # need to remove the anchor gene
@@ -1049,15 +1062,13 @@ class InSilicoPerturber:
1049
  gene_list.remove(token)
1050
  # index 0 is not overexpressed so remove
1051
  if self.perturb_type == "overexpress":
1052
- gene_list = gene_list[
1053
- num_inds_perturbed:
1054
- ]
1055
  # remove perturbed index for gene list dict
1056
  perturbed_gene_dict = {
1057
  gene: gene_list[:i] + gene_list[i + 1 :]
1058
  for i, gene in enumerate(gene_list)
1059
  }
1060
-
1061
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
1062
  example_cell,
1063
  self.perturb_type,
@@ -1068,11 +1079,19 @@ class InSilicoPerturber:
1068
  )
1069
 
1070
  ispall_total_batch_length = len(perturbation_batch)
1071
- for i in trange(0, ispall_total_batch_length, self.forward_batch_size, leave=False):
1072
- ispall_max_range = min(i + self.forward_batch_size, ispall_total_batch_length)
1073
- perturbation_minibatch = perturbation_batch.select([i for i in range(i, ispall_max_range)])
1074
- indices_to_perturb_mini = indices_to_perturb[i : ispall_max_range]
1075
- gene_list_mini = gene_list[i : ispall_max_range] # only perturbed genes from this minibatch
 
 
 
 
 
 
 
 
1076
 
1077
  full_perturbation_emb = get_embs(
1078
  model,
@@ -1085,18 +1104,20 @@ class InSilicoPerturber:
1085
  summary_stat=None,
1086
  silent=True,
1087
  )
1088
-
1089
  del perturbation_minibatch
1090
-
1091
  # need to remove overexpressed gene to quantify cosine shifts
1092
  if self.perturb_type == "overexpress":
1093
  perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
1094
-
1095
  elif self.perturb_type == "delete":
1096
  perturbation_emb = full_perturbation_emb
1097
-
1098
-
1099
- if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
 
 
1100
  original_emb_minibatch = pu.make_comparison_batch(
1101
  full_original_emb, indices_to_perturb_mini, perturb_group=False
1102
  )
@@ -1108,12 +1129,12 @@ class InSilicoPerturber:
1108
  emb_mode="gene",
1109
  )
1110
  del original_emb_minibatch
1111
-
1112
  if self.cell_states_to_model is not None:
1113
  perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
1114
  full_perturbation_emb, "mean_pool"
1115
  )
1116
-
1117
  cell_cos_sims = pu.quant_cos_sims(
1118
  perturbation_cell_emb,
1119
  original_cell_emb,
@@ -1122,9 +1143,8 @@ class InSilicoPerturber:
1122
  emb_mode="cell",
1123
  )
1124
  del perturbation_cell_emb
1125
-
1126
- if self.emb_mode == "cell_and_gene":
1127
 
 
1128
  for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1129
  for gene_j, affected_gene in enumerate(
1130
  perturbed_gene_dict[perturbed_gene]
@@ -1137,9 +1157,9 @@ class InSilicoPerturber:
1137
  stored_gene_embs_dict[
1138
  (perturbed_gene, affected_gene)
1139
  ] = gene_cos_sims[perturbation_i, gene_j].item()
1140
-
1141
  del full_perturbation_emb
1142
-
1143
  if self.cell_states_to_model is None:
1144
  cos_sims_data = torch.mean(gene_cos_sims, dim=1)
1145
  cos_sims_dict = self.update_perturbation_dictionary(
@@ -1155,9 +1175,9 @@ class InSilicoPerturber:
1155
  cos_sims_data[state],
1156
  gene_list_mini,
1157
  )
1158
-
1159
  # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1160
- if i % self.clear_mem_ncells/10 == 0:
1161
  pu.write_perturbation_dictionary(
1162
  cos_sims_dict,
1163
  f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
@@ -1167,7 +1187,7 @@ class InSilicoPerturber:
1167
  stored_gene_embs_dict,
1168
  f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1169
  )
1170
-
1171
  # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1172
  if i % self.clear_mem_ncells == 0:
1173
  pickle_batch += 1
@@ -1176,18 +1196,21 @@ class InSilicoPerturber:
1176
  else:
1177
  cos_sims_dict = {
1178
  state: defaultdict(list)
1179
- for state in pu.get_possible_states(self.cell_states_to_model)
 
 
1180
  }
1181
-
1182
  if self.emb_mode == "cell_and_gene":
1183
  stored_gene_embs_dict = defaultdict(list)
1184
-
1185
  torch.cuda.empty_cache()
1186
-
1187
  pu.write_perturbation_dictionary(
1188
- cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}"
 
1189
  )
1190
-
1191
  if self.emb_mode == "cell_and_gene":
1192
  pu.write_perturbation_dictionary(
1193
  stored_gene_embs_dict,
@@ -1212,7 +1235,7 @@ class InSilicoPerturber:
1212
  if self.cell_states_to_model is not None:
1213
  del original_cell_emb
1214
  torch.cuda.empty_cache()
1215
-
1216
  def isp_perturb_all_special(
1217
  self,
1218
  model,
@@ -1248,7 +1271,7 @@ class InSilicoPerturber:
1248
  self.token_gene_dict,
1249
  summary_stat=None,
1250
  silent=True,
1251
- )
1252
  elif self.emb_mode == "cls_and_gene":
1253
  full_original_emb = get_embs(
1254
  model,
@@ -1261,8 +1284,8 @@ class InSilicoPerturber:
1261
  summary_stat=None,
1262
  silent=True,
1263
  )
1264
- original_cls_emb = full_original_emb[:,0,:].clone().detach()
1265
-
1266
  # gene_list is used to assign cos sims back to genes
1267
  gene_list = example_cell["input_ids"][0][:]
1268
 
@@ -1275,9 +1298,7 @@ class InSilicoPerturber:
1275
  gene_list.remove(token)
1276
  # index 0 is not overexpressed so remove
1277
  if self.perturb_type == "overexpress":
1278
- gene_list = gene_list[
1279
- num_inds_perturbed:
1280
- ]
1281
  # remove perturbed index for gene list dict
1282
  perturbed_gene_dict = {
1283
  gene: gene_list[:i] + gene_list[i + 1 :]
@@ -1294,12 +1315,20 @@ class InSilicoPerturber:
1294
  )
1295
 
1296
  ispall_total_batch_length = len(perturbation_batch)
1297
- for i in trange(0, ispall_total_batch_length, self.forward_batch_size, leave=False):
1298
- ispall_max_range = min(i + self.forward_batch_size, ispall_total_batch_length)
1299
- perturbation_minibatch = perturbation_batch.select([i for i in range(i, ispall_max_range)])
1300
- indices_to_perturb_mini = indices_to_perturb[i : ispall_max_range]
1301
- gene_list_mini = gene_list[i : ispall_max_range] # only perturbed genes from this minibatch
1302
-
 
 
 
 
 
 
 
 
1303
  ##### CLS Embedding Mode #####
1304
  if self.emb_mode == "cls":
1305
  # Extract cls embeddings from perturbed cells
@@ -1314,7 +1343,7 @@ class InSilicoPerturber:
1314
  summary_stat=None,
1315
  silent=True,
1316
  )
1317
-
1318
  # Calculate cosine similarities
1319
  cls_cos_sims = pu.quant_cos_sims(
1320
  perturbation_cls_emb,
@@ -1322,8 +1351,8 @@ class InSilicoPerturber:
1322
  self.cell_states_to_model,
1323
  self.state_embs_dict,
1324
  emb_mode="cell",
1325
- )
1326
-
1327
  if self.cell_states_to_model is None:
1328
  cos_sims_dict = self.update_perturbation_dictionary(
1329
  cos_sims_dict,
@@ -1331,20 +1360,19 @@ class InSilicoPerturber:
1331
  gene_list_mini,
1332
  )
1333
  else:
1334
-
1335
  for state in cos_sims_dict.keys():
1336
  cos_sims_dict[state] = self.update_perturbation_dictionary(
1337
  cos_sims_dict[state],
1338
  cls_cos_sims[state],
1339
  gene_list_mini,
1340
  )
1341
-
1342
  del perturbation_minibatch
1343
  del perturbation_cls_emb
1344
  del cls_cos_sims
1345
-
1346
  ##### CLS and Gene Embedding Mode #####
1347
- elif self.emb_mode == "cls_and_gene":
1348
  full_perturbation_emb = get_embs(
1349
  model,
1350
  perturbation_minibatch,
@@ -1356,18 +1384,26 @@ class InSilicoPerturber:
1356
  summary_stat=None,
1357
  silent=True,
1358
  )
1359
-
1360
  # need to remove overexpressed gene and cls/eos to quantify cosine shifts
1361
  if self.perturb_type == "overexpress":
1362
- perturbation_emb = full_perturbation_emb[:, 1+num_inds_perturbed:-1, :].clone().detach()
 
 
 
 
1363
  elif self.perturb_type == "delete":
1364
- perturbation_emb = full_perturbation_emb[:, 1:-1, :].clone().detach()
1365
-
 
 
1366
  original_emb_minibatch = pu.make_comparison_batch(
1367
  full_original_emb, indices_to_perturb_mini, perturb_group=False
1368
  )
1369
-
1370
- original_emb_minibatch = original_emb_minibatch[:, 1:-1, :].clone().detach()
 
 
1371
  gene_cos_sims = pu.quant_cos_sims(
1372
  perturbation_emb,
1373
  original_emb_minibatch,
@@ -1375,7 +1411,7 @@ class InSilicoPerturber:
1375
  self.state_embs_dict,
1376
  emb_mode="gene",
1377
  )
1378
-
1379
  for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1380
  for gene_j, affected_gene in enumerate(
1381
  perturbed_gene_dict[perturbed_gene]
@@ -1388,10 +1424,12 @@ class InSilicoPerturber:
1388
  stored_gene_embs_dict[
1389
  (perturbed_gene, affected_gene)
1390
  ] = gene_cos_sims[perturbation_i, gene_j].item()
1391
-
1392
  # get cls emb
1393
- perturbation_cls_emb = full_perturbation_emb[:,0,:].clone().detach()
1394
-
 
 
1395
  cls_cos_sims = pu.quant_cos_sims(
1396
  perturbation_cls_emb,
1397
  original_cls_emb,
@@ -1399,7 +1437,7 @@ class InSilicoPerturber:
1399
  self.state_embs_dict,
1400
  emb_mode="cell",
1401
  )
1402
-
1403
  if self.cell_states_to_model is None:
1404
  cos_sims_dict = self.update_perturbation_dictionary(
1405
  cos_sims_dict,
@@ -1413,7 +1451,7 @@ class InSilicoPerturber:
1413
  cls_cos_sims[state],
1414
  gene_list_mini,
1415
  )
1416
-
1417
  del perturbation_minibatch
1418
  del original_emb_minibatch
1419
  del full_perturbation_emb
@@ -1421,9 +1459,9 @@ class InSilicoPerturber:
1421
  del perturbation_cls_emb
1422
  del cls_cos_sims
1423
  del gene_cos_sims
1424
-
1425
  # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1426
- if i % max(1,self.clear_mem_ncells/10) == 0:
1427
  pu.write_perturbation_dictionary(
1428
  cos_sims_dict,
1429
  f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
@@ -1433,7 +1471,7 @@ class InSilicoPerturber:
1433
  stored_gene_embs_dict,
1434
  f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1435
  )
1436
-
1437
  # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1438
  if i % self.clear_mem_ncells == 0:
1439
  pickle_batch += 1
@@ -1442,18 +1480,21 @@ class InSilicoPerturber:
1442
  else:
1443
  cos_sims_dict = {
1444
  state: defaultdict(list)
1445
- for state in pu.get_possible_states(self.cell_states_to_model)
 
 
1446
  }
1447
-
1448
  if self.emb_mode == "cls_and_gene":
1449
  stored_gene_embs_dict = defaultdict(list)
1450
-
1451
  torch.cuda.empty_cache()
1452
-
1453
  pu.write_perturbation_dictionary(
1454
- cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}"
 
1455
  )
1456
-
1457
  if self.emb_mode == "cls_and_gene":
1458
  pu.write_perturbation_dictionary(
1459
  stored_gene_embs_dict,
@@ -1470,8 +1511,8 @@ class InSilicoPerturber:
1470
  }
1471
 
1472
  if self.emb_mode == "cls_and_gene":
1473
- stored_gene_embs_dict = defaultdict(list)
1474
-
1475
  # clear memory between cells
1476
  del perturbation_batch
1477
  del original_cls_emb
@@ -1479,7 +1520,6 @@ class InSilicoPerturber:
1479
  del full_original_emb
1480
  torch.cuda.empty_cache()
1481
 
1482
-
1483
  def update_perturbation_dictionary(
1484
  self,
1485
  cos_sims_dict: defaultdict,
@@ -1514,4 +1554,4 @@ class InSilicoPerturber:
1514
  for i, cos in enumerate(cos_sims_data.tolist()):
1515
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
1516
 
1517
- return cos_sims_dict
 
38
  import os
39
  import pickle
40
  from collections import defaultdict
 
 
41
 
42
  import torch
43
  from datasets import Dataset, disable_progress_bars
44
+ from multiprocess import set_start_method
45
  from tqdm.auto import trange
46
 
47
+ from . import TOKEN_DICTIONARY_FILE
48
  from . import perturber_utils as pu
49
  from .emb_extractor import get_embs
 
50
 
51
  disable_progress_bars()
52
 
 
70
  "max_ncells": {None, int},
71
  "cell_inds_to_perturb": {"all", dict},
72
  "emb_layer": {-1, 0},
73
+ "token_dictionary_file": {None, str},
74
  "forward_batch_size": {int},
75
  "nproc": {int},
76
  }
 
238
  self.cls_token_id = self.gene_token_dict.get("<cls>")
239
  self.eos_token_id = self.gene_token_dict.get("<eos>")
240
 
 
241
  # Identify if special token is present in the token dictionary
242
  if (self.cls_token_id is not None) and (self.eos_token_id is not None):
243
  self.special_token = True
244
  else:
245
  if "cls" in self.emb_mode:
246
+ logger.error(
247
+ f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary."
248
+ )
249
  raise
250
  self.special_token = False
251
 
 
454
 
455
  # Ensure emb_mode is cls if first token of the filtered input data is cls token
456
  if self.special_token:
457
+ if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and (
458
+ "cls" not in self.emb_mode
459
+ ):
460
  logger.error(
461
+ "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
462
+ )
463
  raise
464
+ if "cls" in self.emb_mode:
465
+ if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (
466
+ filtered_input_data["input_ids"][0][-1] != self.eos_token_id
467
+ ):
468
  logger.error(
469
+ "Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
470
+ )
471
+ raise
472
 
473
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
474
 
 
534
  layer_to_quant: int,
535
  output_path_prefix: str,
536
  ):
 
537
  def make_group_perturbation_batch(example):
538
  example_input_ids = example["input_ids"]
539
  example["tokens_to_perturb"] = self.tokens_to_perturb
 
552
  if self.perturb_type == "delete":
553
  example = pu.delete_indices(example)
554
  elif self.perturb_type == "overexpress":
555
+ example = pu.overexpress_tokens(
556
+ example, self.max_len, self.special_token
557
+ )
558
  example["n_overflow"] = pu.calc_n_overflow(
559
  self.max_len,
560
  example["length"],
 
575
  perturbed_data = filtered_input_data.map(
576
  make_group_perturbation_batch, num_proc=self.nproc
577
  )
578
+
579
  if self.perturb_type == "overexpress":
580
  filtered_input_data = filtered_input_data.add_column(
581
  "n_overflow", perturbed_data["n_overflow"]
 
757
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
758
  )
759
 
 
760
  def isp_perturb_set_special(
761
  self,
762
  model,
 
764
  layer_to_quant: int,
765
  output_path_prefix: str,
766
  ):
 
767
  def make_group_perturbation_batch(example):
768
  example_input_ids = example["input_ids"]
769
  example["tokens_to_perturb"] = self.tokens_to_perturb
 
782
  if self.perturb_type == "delete":
783
  example = pu.delete_indices(example)
784
  elif self.perturb_type == "overexpress":
785
+ example = pu.overexpress_tokens(
786
+ example, self.max_len, self.special_token
787
+ )
788
  example["n_overflow"] = pu.calc_n_overflow(
789
  self.max_len,
790
  example["length"],
 
813
  filtered_input_data = filtered_input_data.map(
814
  pu.truncate_by_n_overflow_special, num_proc=self.nproc
815
  )
816
+
817
  if self.emb_mode == "cls_and_gene":
818
  stored_gene_embs_dict = defaultdict(list)
819
 
 
852
  summary_stat=None,
853
  silent=True,
854
  )
855
+
856
+ # Calculate the cosine similarities
857
  cls_cos_sims = pu.quant_cos_sims(
858
  perturbation_cls_emb,
859
  original_cls_emb,
860
  self.cell_states_to_model,
861
  self.state_embs_dict,
862
+ emb_mode="cell",
863
+ )
864
+
865
  # Update perturbation dictionary
866
  if self.cell_states_to_model is None:
867
  cos_sims_dict = self.update_perturbation_dictionary(
868
  cos_sims_dict,
869
  cls_cos_sims,
870
+ gene_list=None,
871
  )
872
  else:
873
  for state in cos_sims_dict.keys():
874
  cos_sims_dict[state] = self.update_perturbation_dictionary(
875
  cos_sims_dict[state],
876
  cls_cos_sims[state],
877
+ gene_list=None,
878
  )
879
 
880
  ##### CLS and Gene Embedding Mode #####
 
914
  # remove special tokens and padding
915
  original_emb = original_emb[:, 1:-1, :]
916
  if self.perturb_type == "overexpress":
917
+ perturbation_emb = full_perturbation_emb[
918
+ :, 1 + len(self.tokens_to_perturb) : -1, :
919
+ ]
920
  elif self.perturb_type == "delete":
921
+ perturbation_emb = full_perturbation_emb[
922
+ :, 1 : max(perturbation_batch["length"]) - 1, :
923
+ ]
924
 
925
  n_perturbation_genes = perturbation_emb.size()[1]
926
 
 
933
  )
934
 
935
  # get cls emb
936
+ original_cls_emb = full_original_emb[:, 0, :]
937
+ perturbation_cls_emb = full_perturbation_emb[:, 0, :]
938
 
939
  cls_cos_sims = pu.quant_cos_sims(
940
  perturbation_cls_emb,
 
949
 
950
  gene_list = minibatch["input_ids"]
951
  # need to truncate gene_list
952
+ genes_to_exclude = self.tokens_to_perturb + [
953
+ self.cls_token_id,
954
+ self.eos_token_id,
955
+ ]
956
  gene_list = [
957
  [g for g in genes if g not in genes_to_exclude][
958
  :n_perturbation_genes
 
981
  cos_sims_dict = self.update_perturbation_dictionary(
982
  cos_sims_dict,
983
  cls_cos_sims,
984
+ gene_list=None,
985
  )
986
  else:
987
  for state in cos_sims_dict.keys():
988
  cos_sims_dict[state] = self.update_perturbation_dictionary(
989
  cos_sims_dict[state],
990
  cls_cos_sims[state],
991
+ gene_list=None,
992
  )
993
  del full_original_emb
994
  del original_emb
995
  del full_perturbation_emb
996
+ del perturbation_emb
997
+ del gene_cos_sims
998
 
999
  del original_cls_emb
1000
  del perturbation_cls_emb
 
1048
  summary_stat=None,
1049
  silent=True,
1050
  )
1051
+
1052
  if self.cell_states_to_model is not None:
1053
  original_cell_emb = pu.compute_nonpadded_cell_embedding(
1054
  full_original_emb, "mean_pool"
1055
  )
1056
+
1057
  # gene_list is used to assign cos sims back to genes
1058
  gene_list = example_cell["input_ids"][0][:]
1059
  # need to remove the anchor gene
 
1062
  gene_list.remove(token)
1063
  # index 0 is not overexpressed so remove
1064
  if self.perturb_type == "overexpress":
1065
+ gene_list = gene_list[num_inds_perturbed:]
 
 
1066
  # remove perturbed index for gene list dict
1067
  perturbed_gene_dict = {
1068
  gene: gene_list[:i] + gene_list[i + 1 :]
1069
  for i, gene in enumerate(gene_list)
1070
  }
1071
+
1072
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
1073
  example_cell,
1074
  self.perturb_type,
 
1079
  )
1080
 
1081
  ispall_total_batch_length = len(perturbation_batch)
1082
+ for i in trange(
1083
+ 0, ispall_total_batch_length, self.forward_batch_size, leave=False
1084
+ ):
1085
+ ispall_max_range = min(
1086
+ i + self.forward_batch_size, ispall_total_batch_length
1087
+ )
1088
+ perturbation_minibatch = perturbation_batch.select(
1089
+ [i for i in range(i, ispall_max_range)]
1090
+ )
1091
+ indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
1092
+ gene_list_mini = gene_list[
1093
+ i:ispall_max_range
1094
+ ] # only perturbed genes from this minibatch
1095
 
1096
  full_perturbation_emb = get_embs(
1097
  model,
 
1104
  summary_stat=None,
1105
  silent=True,
1106
  )
1107
+
1108
  del perturbation_minibatch
1109
+
1110
  # need to remove overexpressed gene to quantify cosine shifts
1111
  if self.perturb_type == "overexpress":
1112
  perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
1113
+
1114
  elif self.perturb_type == "delete":
1115
  perturbation_emb = full_perturbation_emb
1116
+
1117
+ if (
1118
+ self.cell_states_to_model is None
1119
+ or self.emb_mode == "cell_and_gene"
1120
+ ):
1121
  original_emb_minibatch = pu.make_comparison_batch(
1122
  full_original_emb, indices_to_perturb_mini, perturb_group=False
1123
  )
 
1129
  emb_mode="gene",
1130
  )
1131
  del original_emb_minibatch
1132
+
1133
  if self.cell_states_to_model is not None:
1134
  perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
1135
  full_perturbation_emb, "mean_pool"
1136
  )
1137
+
1138
  cell_cos_sims = pu.quant_cos_sims(
1139
  perturbation_cell_emb,
1140
  original_cell_emb,
 
1143
  emb_mode="cell",
1144
  )
1145
  del perturbation_cell_emb
 
 
1146
 
1147
+ if self.emb_mode == "cell_and_gene":
1148
  for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1149
  for gene_j, affected_gene in enumerate(
1150
  perturbed_gene_dict[perturbed_gene]
 
1157
  stored_gene_embs_dict[
1158
  (perturbed_gene, affected_gene)
1159
  ] = gene_cos_sims[perturbation_i, gene_j].item()
1160
+
1161
  del full_perturbation_emb
1162
+
1163
  if self.cell_states_to_model is None:
1164
  cos_sims_data = torch.mean(gene_cos_sims, dim=1)
1165
  cos_sims_dict = self.update_perturbation_dictionary(
 
1175
  cos_sims_data[state],
1176
  gene_list_mini,
1177
  )
1178
+
1179
  # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1180
+ if i % self.clear_mem_ncells / 10 == 0:
1181
  pu.write_perturbation_dictionary(
1182
  cos_sims_dict,
1183
  f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
 
1187
  stored_gene_embs_dict,
1188
  f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1189
  )
1190
+
1191
  # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1192
  if i % self.clear_mem_ncells == 0:
1193
  pickle_batch += 1
 
1196
  else:
1197
  cos_sims_dict = {
1198
  state: defaultdict(list)
1199
+ for state in pu.get_possible_states(
1200
+ self.cell_states_to_model
1201
+ )
1202
  }
1203
+
1204
  if self.emb_mode == "cell_and_gene":
1205
  stored_gene_embs_dict = defaultdict(list)
1206
+
1207
  torch.cuda.empty_cache()
1208
+
1209
  pu.write_perturbation_dictionary(
1210
+ cos_sims_dict,
1211
+ f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1212
  )
1213
+
1214
  if self.emb_mode == "cell_and_gene":
1215
  pu.write_perturbation_dictionary(
1216
  stored_gene_embs_dict,
 
1235
  if self.cell_states_to_model is not None:
1236
  del original_cell_emb
1237
  torch.cuda.empty_cache()
1238
+
1239
  def isp_perturb_all_special(
1240
  self,
1241
  model,
 
1271
  self.token_gene_dict,
1272
  summary_stat=None,
1273
  silent=True,
1274
+ )
1275
  elif self.emb_mode == "cls_and_gene":
1276
  full_original_emb = get_embs(
1277
  model,
 
1284
  summary_stat=None,
1285
  silent=True,
1286
  )
1287
+ original_cls_emb = full_original_emb[:, 0, :].clone().detach()
1288
+
1289
  # gene_list is used to assign cos sims back to genes
1290
  gene_list = example_cell["input_ids"][0][:]
1291
 
 
1298
  gene_list.remove(token)
1299
  # index 0 is not overexpressed so remove
1300
  if self.perturb_type == "overexpress":
1301
+ gene_list = gene_list[num_inds_perturbed:]
 
 
1302
  # remove perturbed index for gene list dict
1303
  perturbed_gene_dict = {
1304
  gene: gene_list[:i] + gene_list[i + 1 :]
 
1315
  )
1316
 
1317
  ispall_total_batch_length = len(perturbation_batch)
1318
+ for i in trange(
1319
+ 0, ispall_total_batch_length, self.forward_batch_size, leave=False
1320
+ ):
1321
+ ispall_max_range = min(
1322
+ i + self.forward_batch_size, ispall_total_batch_length
1323
+ )
1324
+ perturbation_minibatch = perturbation_batch.select(
1325
+ [i for i in range(i, ispall_max_range)]
1326
+ )
1327
+ indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
1328
+ gene_list_mini = gene_list[
1329
+ i:ispall_max_range
1330
+ ] # only perturbed genes from this minibatch
1331
+
1332
  ##### CLS Embedding Mode #####
1333
  if self.emb_mode == "cls":
1334
  # Extract cls embeddings from perturbed cells
 
1343
  summary_stat=None,
1344
  silent=True,
1345
  )
1346
+
1347
  # Calculate cosine similarities
1348
  cls_cos_sims = pu.quant_cos_sims(
1349
  perturbation_cls_emb,
 
1351
  self.cell_states_to_model,
1352
  self.state_embs_dict,
1353
  emb_mode="cell",
1354
+ )
1355
+
1356
  if self.cell_states_to_model is None:
1357
  cos_sims_dict = self.update_perturbation_dictionary(
1358
  cos_sims_dict,
 
1360
  gene_list_mini,
1361
  )
1362
  else:
 
1363
  for state in cos_sims_dict.keys():
1364
  cos_sims_dict[state] = self.update_perturbation_dictionary(
1365
  cos_sims_dict[state],
1366
  cls_cos_sims[state],
1367
  gene_list_mini,
1368
  )
1369
+
1370
  del perturbation_minibatch
1371
  del perturbation_cls_emb
1372
  del cls_cos_sims
1373
+
1374
  ##### CLS and Gene Embedding Mode #####
1375
+ elif self.emb_mode == "cls_and_gene":
1376
  full_perturbation_emb = get_embs(
1377
  model,
1378
  perturbation_minibatch,
 
1384
  summary_stat=None,
1385
  silent=True,
1386
  )
1387
+
1388
  # need to remove overexpressed gene and cls/eos to quantify cosine shifts
1389
  if self.perturb_type == "overexpress":
1390
+ perturbation_emb = (
1391
+ full_perturbation_emb[:, 1 + num_inds_perturbed : -1, :]
1392
+ .clone()
1393
+ .detach()
1394
+ )
1395
  elif self.perturb_type == "delete":
1396
+ perturbation_emb = (
1397
+ full_perturbation_emb[:, 1:-1, :].clone().detach()
1398
+ )
1399
+
1400
  original_emb_minibatch = pu.make_comparison_batch(
1401
  full_original_emb, indices_to_perturb_mini, perturb_group=False
1402
  )
1403
+
1404
+ original_emb_minibatch = (
1405
+ original_emb_minibatch[:, 1:-1, :].clone().detach()
1406
+ )
1407
  gene_cos_sims = pu.quant_cos_sims(
1408
  perturbation_emb,
1409
  original_emb_minibatch,
 
1411
  self.state_embs_dict,
1412
  emb_mode="gene",
1413
  )
1414
+
1415
  for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1416
  for gene_j, affected_gene in enumerate(
1417
  perturbed_gene_dict[perturbed_gene]
 
1424
  stored_gene_embs_dict[
1425
  (perturbed_gene, affected_gene)
1426
  ] = gene_cos_sims[perturbation_i, gene_j].item()
1427
+
1428
  # get cls emb
1429
+ perturbation_cls_emb = (
1430
+ full_perturbation_emb[:, 0, :].clone().detach()
1431
+ )
1432
+
1433
  cls_cos_sims = pu.quant_cos_sims(
1434
  perturbation_cls_emb,
1435
  original_cls_emb,
 
1437
  self.state_embs_dict,
1438
  emb_mode="cell",
1439
  )
1440
+
1441
  if self.cell_states_to_model is None:
1442
  cos_sims_dict = self.update_perturbation_dictionary(
1443
  cos_sims_dict,
 
1451
  cls_cos_sims[state],
1452
  gene_list_mini,
1453
  )
1454
+
1455
  del perturbation_minibatch
1456
  del original_emb_minibatch
1457
  del full_perturbation_emb
 
1459
  del perturbation_cls_emb
1460
  del cls_cos_sims
1461
  del gene_cos_sims
1462
+
1463
  # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1464
+ if i % max(1, self.clear_mem_ncells / 10) == 0:
1465
  pu.write_perturbation_dictionary(
1466
  cos_sims_dict,
1467
  f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
 
1471
  stored_gene_embs_dict,
1472
  f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1473
  )
1474
+
1475
  # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1476
  if i % self.clear_mem_ncells == 0:
1477
  pickle_batch += 1
 
1480
  else:
1481
  cos_sims_dict = {
1482
  state: defaultdict(list)
1483
+ for state in pu.get_possible_states(
1484
+ self.cell_states_to_model
1485
+ )
1486
  }
1487
+
1488
  if self.emb_mode == "cls_and_gene":
1489
  stored_gene_embs_dict = defaultdict(list)
1490
+
1491
  torch.cuda.empty_cache()
1492
+
1493
  pu.write_perturbation_dictionary(
1494
+ cos_sims_dict,
1495
+ f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1496
  )
1497
+
1498
  if self.emb_mode == "cls_and_gene":
1499
  pu.write_perturbation_dictionary(
1500
  stored_gene_embs_dict,
 
1511
  }
1512
 
1513
  if self.emb_mode == "cls_and_gene":
1514
+ stored_gene_embs_dict = defaultdict(list)
1515
+
1516
  # clear memory between cells
1517
  del perturbation_batch
1518
  del original_cls_emb
 
1520
  del full_original_emb
1521
  torch.cuda.empty_cache()
1522
 
 
1523
  def update_perturbation_dictionary(
1524
  self,
1525
  cos_sims_dict: defaultdict,
 
1554
  for i, cos in enumerate(cos_sims_data.tolist()):
1555
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
1556
 
1557
+ return cos_sims_dict
geneformer/in_silico_perturber_stats.py CHANGED
@@ -37,8 +37,8 @@ from scipy.stats import ranksums
37
  from sklearn.mixture import GaussianMixture
38
  from tqdm.auto import tqdm, trange
39
 
 
40
  from .perturber_utils import flatten_list, validate_cell_states_to_model
41
- from . import TOKEN_DICTIONARY_FILE, ENSEMBL_DICTIONARY_FILE
42
 
43
  logger = logging.getLogger(__name__)
44
 
@@ -194,23 +194,29 @@ def get_impact_component(test_value, gaussian_mixture_model):
194
  def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
195
  names = ["Cosine_sim", "Gene"]
196
  cos_sims_full_dfs = []
197
- if isinstance(genes_perturbed,list):
198
- if len(genes_perturbed)>1:
199
- gene_ids_df = cos_sims_df.loc[np.isin([set(idx) for idx in cos_sims_df["Ensembl_ID"]], set(genes_perturbed)), :]
 
 
 
 
 
 
200
  else:
201
- gene_ids_df = cos_sims_df.loc[np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :]
 
 
202
  else:
203
  logger.error(
204
- "aggregate_data is for perturbation of single gene or single group of genes. genes_to_perturb should be formatted as list."
205
- )
206
- raise
207
 
208
  if gene_ids_df.empty:
209
- logger.error(
210
- "genes_to_perturb not found in data."
211
- )
212
  raise
213
-
214
  tokens = gene_ids_df["Gene"]
215
  symbols = gene_ids_df["Gene_name"]
216
 
@@ -223,7 +229,7 @@ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
223
  df["Cosine_sim"] = cos_shift_data
224
  df["Gene"] = symbol
225
  cos_sims_full_dfs.append(df)
226
-
227
  return pd.concat(cos_sims_full_dfs)
228
 
229
 
@@ -1018,7 +1024,7 @@ class InSilicoPerturberStats:
1018
  },
1019
  index=[i for i in range(len(gene_list))],
1020
  )
1021
-
1022
  if self.mode == "goal_state_shift":
1023
  cos_sims_df = isp_stats_to_goal_state(
1024
  cos_sims_df_initial,
@@ -1045,12 +1051,16 @@ class InSilicoPerturberStats:
1045
  cos_sims_df_initial, dict_list, self.combos, self.anchor_token
1046
  )
1047
 
1048
- elif self.mode == "aggregate_data":
1049
- cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list, self.genes_perturbed)
 
 
1050
 
1051
  elif self.mode == "aggregate_gene_shifts":
1052
  if (self.genes_perturbed == "all") and (self.combos == 0):
1053
- tuple_types = [True if isinstance(genes, tuple) else False for genes in gene_list]
 
 
1054
  if all(tuple_types):
1055
  token_dtype = "tuple"
1056
  elif not any(tuple_types):
@@ -1059,13 +1069,13 @@ class InSilicoPerturberStats:
1059
  token_dtype = "mix"
1060
  else:
1061
  token_dtype = "mix"
1062
-
1063
  cos_sims_df = isp_aggregate_gene_shifts(
1064
  cos_sims_df_initial,
1065
  dict_list,
1066
  self.gene_token_id_dict,
1067
  self.gene_id_name_dict,
1068
- token_dtype
1069
  )
1070
 
1071
  # save perturbation stats to output_path
 
37
  from sklearn.mixture import GaussianMixture
38
  from tqdm.auto import tqdm, trange
39
 
40
+ from . import ENSEMBL_DICTIONARY_FILE, TOKEN_DICTIONARY_FILE
41
  from .perturber_utils import flatten_list, validate_cell_states_to_model
 
42
 
43
  logger = logging.getLogger(__name__)
44
 
 
194
  def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
195
  names = ["Cosine_sim", "Gene"]
196
  cos_sims_full_dfs = []
197
+ if isinstance(genes_perturbed, list):
198
+ if len(genes_perturbed) > 1:
199
+ gene_ids_df = cos_sims_df.loc[
200
+ np.isin(
201
+ [set(idx) for idx in cos_sims_df["Ensembl_ID"]],
202
+ set(genes_perturbed),
203
+ ),
204
+ :,
205
+ ]
206
  else:
207
+ gene_ids_df = cos_sims_df.loc[
208
+ np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :
209
+ ]
210
  else:
211
  logger.error(
212
+ "aggregate_data is for perturbation of single gene or single group of genes. genes_to_perturb should be formatted as list."
213
+ )
214
+ raise
215
 
216
  if gene_ids_df.empty:
217
+ logger.error("genes_to_perturb not found in data.")
 
 
218
  raise
219
+
220
  tokens = gene_ids_df["Gene"]
221
  symbols = gene_ids_df["Gene_name"]
222
 
 
229
  df["Cosine_sim"] = cos_shift_data
230
  df["Gene"] = symbol
231
  cos_sims_full_dfs.append(df)
232
+
233
  return pd.concat(cos_sims_full_dfs)
234
 
235
 
 
1024
  },
1025
  index=[i for i in range(len(gene_list))],
1026
  )
1027
+
1028
  if self.mode == "goal_state_shift":
1029
  cos_sims_df = isp_stats_to_goal_state(
1030
  cos_sims_df_initial,
 
1051
  cos_sims_df_initial, dict_list, self.combos, self.anchor_token
1052
  )
1053
 
1054
+ elif self.mode == "aggregate_data":
1055
+ cos_sims_df = isp_aggregate_grouped_perturb(
1056
+ cos_sims_df_initial, dict_list, self.genes_perturbed
1057
+ )
1058
 
1059
  elif self.mode == "aggregate_gene_shifts":
1060
  if (self.genes_perturbed == "all") and (self.combos == 0):
1061
+ tuple_types = [
1062
+ True if isinstance(genes, tuple) else False for genes in gene_list
1063
+ ]
1064
  if all(tuple_types):
1065
  token_dtype = "tuple"
1066
  elif not any(tuple_types):
 
1069
  token_dtype = "mix"
1070
  else:
1071
  token_dtype = "mix"
1072
+
1073
  cos_sims_df = isp_aggregate_gene_shifts(
1074
  cos_sims_df_initial,
1075
  dict_list,
1076
  self.gene_token_id_dict,
1077
  self.gene_id_name_dict,
1078
+ token_dtype,
1079
  )
1080
 
1081
  # save perturbation stats to output_path
geneformer/mtl/collators.py CHANGED
@@ -1,4 +1,4 @@
1
- #imports
2
  import torch
3
 
4
  from ..collator_for_classification import DataCollatorForGeneClassification
@@ -7,6 +7,7 @@ from ..collator_for_classification import DataCollatorForGeneClassification
7
  Geneformer collator for multi-task cell classification.
8
  """
9
 
 
10
  class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
11
  class_type = "cell"
12
 
@@ -47,7 +48,10 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
47
  batch["labels"] = labels
48
  else:
49
  # If no labels are present, create empty labels for all tasks
50
- batch["labels"] = {task: torch.tensor([], dtype=torch.long) for task in features[0]["input_ids"].keys()}
 
 
 
51
 
52
  return batch
53
 
@@ -59,8 +63,11 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
59
  batch[k] = v.clone().detach()
60
  elif isinstance(v, dict):
61
  # Assuming nested structure needs conversion
62
- batch[k] = {task: torch.tensor(labels, dtype=torch.int64) for task, labels in v.items()}
 
 
 
63
  else:
64
  batch[k] = torch.tensor(v, dtype=torch.int64)
65
 
66
- return batch
 
1
+ # imports
2
  import torch
3
 
4
  from ..collator_for_classification import DataCollatorForGeneClassification
 
7
  Geneformer collator for multi-task cell classification.
8
  """
9
 
10
+
11
  class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
12
  class_type = "cell"
13
 
 
48
  batch["labels"] = labels
49
  else:
50
  # If no labels are present, create empty labels for all tasks
51
+ batch["labels"] = {
52
+ task: torch.tensor([], dtype=torch.long)
53
+ for task in features[0]["input_ids"].keys()
54
+ }
55
 
56
  return batch
57
 
 
63
  batch[k] = v.clone().detach()
64
  elif isinstance(v, dict):
65
  # Assuming nested structure needs conversion
66
+ batch[k] = {
67
+ task: torch.tensor(labels, dtype=torch.int64)
68
+ for task, labels in v.items()
69
+ }
70
  else:
71
  batch[k] = torch.tensor(v, dtype=torch.int64)
72
 
73
+ return batch
geneformer/mtl/data.py CHANGED
@@ -1,6 +1,8 @@
1
- from .imports import *
2
  import os
 
3
  from .collators import DataCollatorForMultitaskCellClassification
 
 
4
 
5
  def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
6
  try:
@@ -14,7 +16,9 @@ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type="
14
  available_columns = set(dataset.column_names)
15
  for column in task_to_column.values():
16
  if column not in available_columns:
17
- raise KeyError(f"Column {column} not found in the dataset. Available columns: {list(available_columns)}")
 
 
18
 
19
  label_mappings = {}
20
  task_label_mappings = {}
@@ -25,13 +29,17 @@ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type="
25
  if not is_test:
26
  for task, column in task_to_column.items():
27
  unique_values = sorted(set(dataset[column])) # Ensure consistency
28
- label_mappings[column] = {label: idx for idx, label in enumerate(unique_values)}
 
 
29
  task_label_mappings[task] = label_mappings[column]
30
  num_labels_list.append(len(unique_values))
31
 
32
  # Print the mappings for each task with dataset type prefix
33
  for task, mapping in task_label_mappings.items():
34
- print(f"{dataset_type.capitalize()} mapping for {task}: {mapping}") # sanity check, for train/validation splits
 
 
35
 
36
  # Save the task label mappings as a pickle file
37
  with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
@@ -40,24 +48,26 @@ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type="
40
  # Load task label mappings from pickle file for test data
41
  with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
42
  task_label_mappings = pickle.load(f)
43
-
44
  # Infer num_labels_list from task_label_mappings
45
  for task, mapping in task_label_mappings.items():
46
  num_labels_list.append(len(mapping))
47
 
48
  # Store unique cell IDs in a separate dictionary
49
  for idx, record in enumerate(dataset):
50
- cell_id = record.get('unique_cell_id', idx)
51
  cell_id_mapping[idx] = cell_id
52
 
53
  # Transform records to the desired format
54
  transformed_dataset = []
55
  for idx, record in enumerate(dataset):
56
  transformed_record = {}
57
- transformed_record['input_ids'] = torch.tensor(record['input_ids'], dtype=torch.long)
58
-
 
 
59
  # Use index-based cell ID for internal tracking
60
- transformed_record['cell_id'] = idx
61
 
62
  if not is_test:
63
  # Prepare labels
@@ -66,11 +76,11 @@ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type="
66
  label_value = record[column]
67
  label_index = task_label_mappings[task][label_value]
68
  label_dict[task] = label_index
69
- transformed_record['label'] = label_dict
70
  else:
71
  # Create dummy labels for test data
72
  label_dict = {task: -1 for task in config["task_names"]}
73
- transformed_record['label'] = label_dict
74
 
75
  transformed_dataset.append(transformed_record)
76
 
@@ -81,36 +91,60 @@ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type="
81
  print(f"An error occurred while loading or preprocessing data: {e}")
82
  return None, None, None
83
 
 
84
  def preload_and_process_data(config):
85
  # Load and preprocess data once
86
- train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
87
- val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
88
- return train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  def get_data_loader(preprocessed_dataset, batch_size):
91
- nproc = os.cpu_count() ### I/O operations
92
-
93
  data_collator = DataCollatorForMultitaskCellClassification()
94
-
95
- loader = DataLoader(preprocessed_dataset, batch_size=batch_size, shuffle=True,
96
- collate_fn=data_collator, num_workers=nproc, pin_memory=True)
 
 
 
 
 
 
97
  return loader
 
 
98
  def preload_data(config):
99
  # Preprocessing the data before the Optuna trials start
100
  train_loader = get_data_loader("train", config)
101
  val_loader = get_data_loader("val", config)
102
  return train_loader, val_loader
103
 
 
104
  def load_and_preprocess_test_data(config):
105
  """
106
  Load and preprocess test data, treating it as unlabeled.
107
  """
108
  return load_and_preprocess_data(config["test_path"], config, is_test=True)
109
 
 
110
  def prepare_test_loader(config):
111
  """
112
  Prepare DataLoader for the test dataset.
113
  """
114
- test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
115
- test_loader = get_data_loader(test_dataset, config['batch_size'])
116
- return test_loader, cell_id_mapping, num_labels_list
 
 
 
 
1
  import os
2
+
3
  from .collators import DataCollatorForMultitaskCellClassification
4
+ from .imports import *
5
+
6
 
7
  def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
8
  try:
 
16
  available_columns = set(dataset.column_names)
17
  for column in task_to_column.values():
18
  if column not in available_columns:
19
+ raise KeyError(
20
+ f"Column {column} not found in the dataset. Available columns: {list(available_columns)}"
21
+ )
22
 
23
  label_mappings = {}
24
  task_label_mappings = {}
 
29
  if not is_test:
30
  for task, column in task_to_column.items():
31
  unique_values = sorted(set(dataset[column])) # Ensure consistency
32
+ label_mappings[column] = {
33
+ label: idx for idx, label in enumerate(unique_values)
34
+ }
35
  task_label_mappings[task] = label_mappings[column]
36
  num_labels_list.append(len(unique_values))
37
 
38
  # Print the mappings for each task with dataset type prefix
39
  for task, mapping in task_label_mappings.items():
40
+ print(
41
+ f"{dataset_type.capitalize()} mapping for {task}: {mapping}"
42
+ ) # sanity check, for train/validation splits
43
 
44
  # Save the task label mappings as a pickle file
45
  with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
 
48
  # Load task label mappings from pickle file for test data
49
  with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
50
  task_label_mappings = pickle.load(f)
51
+
52
  # Infer num_labels_list from task_label_mappings
53
  for task, mapping in task_label_mappings.items():
54
  num_labels_list.append(len(mapping))
55
 
56
  # Store unique cell IDs in a separate dictionary
57
  for idx, record in enumerate(dataset):
58
+ cell_id = record.get("unique_cell_id", idx)
59
  cell_id_mapping[idx] = cell_id
60
 
61
  # Transform records to the desired format
62
  transformed_dataset = []
63
  for idx, record in enumerate(dataset):
64
  transformed_record = {}
65
+ transformed_record["input_ids"] = torch.tensor(
66
+ record["input_ids"], dtype=torch.long
67
+ )
68
+
69
  # Use index-based cell ID for internal tracking
70
+ transformed_record["cell_id"] = idx
71
 
72
  if not is_test:
73
  # Prepare labels
 
76
  label_value = record[column]
77
  label_index = task_label_mappings[task][label_value]
78
  label_dict[task] = label_index
79
+ transformed_record["label"] = label_dict
80
  else:
81
  # Create dummy labels for test data
82
  label_dict = {task: -1 for task in config["task_names"]}
83
+ transformed_record["label"] = label_dict
84
 
85
  transformed_dataset.append(transformed_record)
86
 
 
91
  print(f"An error occurred while loading or preprocessing data: {e}")
92
  return None, None, None
93
 
94
+
95
  def preload_and_process_data(config):
96
  # Load and preprocess data once
97
+ train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
98
+ config["train_path"], config, dataset_type="train"
99
+ )
100
+ val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(
101
+ config["val_path"], config, dataset_type="validation"
102
+ )
103
+ return (
104
+ train_dataset,
105
+ train_cell_id_mapping,
106
+ val_dataset,
107
+ val_cell_id_mapping,
108
+ num_labels_list,
109
+ )
110
+
111
 
112
  def get_data_loader(preprocessed_dataset, batch_size):
113
+ nproc = os.cpu_count() ### I/O operations
114
+
115
  data_collator = DataCollatorForMultitaskCellClassification()
116
+
117
+ loader = DataLoader(
118
+ preprocessed_dataset,
119
+ batch_size=batch_size,
120
+ shuffle=True,
121
+ collate_fn=data_collator,
122
+ num_workers=nproc,
123
+ pin_memory=True,
124
+ )
125
  return loader
126
+
127
+
128
  def preload_data(config):
129
  # Preprocessing the data before the Optuna trials start
130
  train_loader = get_data_loader("train", config)
131
  val_loader = get_data_loader("val", config)
132
  return train_loader, val_loader
133
 
134
+
135
  def load_and_preprocess_test_data(config):
136
  """
137
  Load and preprocess test data, treating it as unlabeled.
138
  """
139
  return load_and_preprocess_data(config["test_path"], config, is_test=True)
140
 
141
+
142
  def prepare_test_loader(config):
143
  """
144
  Prepare DataLoader for the test dataset.
145
  """
146
+ test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(
147
+ config
148
+ )
149
+ test_loader = get_data_loader(test_dataset, config["batch_size"])
150
+ return test_loader, cell_id_mapping, num_labels_list
geneformer/mtl/eval_utils.py CHANGED
@@ -1,29 +1,33 @@
1
- from .imports import *
2
  import pandas as pd
 
3
  from .data import prepare_test_loader
 
4
  from .model import GeneformerMultiTask
5
 
 
6
  def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
7
  task_pred_labels = {task_name: [] for task_name in config["task_names"]}
8
  task_pred_probs = {task_name: [] for task_name in config["task_names"]}
9
  cell_ids = []
10
 
11
- # Load task label mappings from pickle file
12
- with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
13
- task_label_mappings = pickle.load(f)
14
 
15
  model.eval()
16
  with torch.no_grad():
17
  for batch in test_loader:
18
- input_ids = batch['input_ids'].to(device)
19
- attention_mask = batch['attention_mask'].to(device)
20
  _, logits, _ = model(input_ids, attention_mask)
21
- for sample_idx in range(len(batch['input_ids'])):
22
- cell_id = cell_id_mapping[batch['cell_id'][sample_idx].item()]
23
  cell_ids.append(cell_id)
24
  for i, task_name in enumerate(config["task_names"]):
25
  pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
26
- pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
 
 
27
  task_pred_labels[task_name].append(pred_label)
28
  task_pred_probs[task_name].append(pred_prob)
29
 
@@ -31,19 +35,22 @@ def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
31
  test_results_dir = config["results_dir"]
32
  os.makedirs(test_results_dir, exist_ok=True)
33
  test_preds_file = os.path.join(test_results_dir, "test_preds.csv")
34
-
35
  rows = []
36
  for sample_idx in range(len(cell_ids)):
37
- row = {'Cell ID': cell_ids[sample_idx]}
38
  for task_name in config["task_names"]:
39
- row[f'{task_name} Prediction'] = task_pred_labels[task_name][sample_idx]
40
- row[f'{task_name} Probabilities'] = ','.join(map(str, task_pred_probs[task_name][sample_idx]))
 
 
41
  rows.append(row)
42
-
43
  df = pd.DataFrame(rows)
44
  df.to_csv(test_preds_file, index=False)
45
  print(f"Test predictions saved to {test_preds_file}")
46
 
 
47
  def load_and_evaluate_test_model(config):
48
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
  test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
@@ -51,7 +58,7 @@ def load_and_evaluate_test_model(config):
51
  hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
52
 
53
  # Load the saved best hyperparameters
54
- with open(hyperparams_path, 'r') as f:
55
  best_hyperparams = json.load(f)
56
 
57
  # Extract the task weights if present, otherwise set to None
@@ -72,7 +79,7 @@ def load_and_evaluate_test_model(config):
72
  num_labels_list,
73
  dropout_rate=best_hyperparams["dropout_rate"],
74
  use_task_weights=config["use_task_weights"],
75
- task_weights=normalized_task_weights
76
  )
77
  best_model.load_state_dict(torch.load(best_model_path))
78
  best_model.to(device)
 
 
1
  import pandas as pd
2
+
3
  from .data import prepare_test_loader
4
+ from .imports import *
5
  from .model import GeneformerMultiTask
6
 
7
+
8
  def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
9
  task_pred_labels = {task_name: [] for task_name in config["task_names"]}
10
  task_pred_probs = {task_name: [] for task_name in config["task_names"]}
11
  cell_ids = []
12
 
13
+ # # Load task label mappings from pickle file
14
+ # with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
15
+ # task_label_mappings = pickle.load(f)
16
 
17
  model.eval()
18
  with torch.no_grad():
19
  for batch in test_loader:
20
+ input_ids = batch["input_ids"].to(device)
21
+ attention_mask = batch["attention_mask"].to(device)
22
  _, logits, _ = model(input_ids, attention_mask)
23
+ for sample_idx in range(len(batch["input_ids"])):
24
+ cell_id = cell_id_mapping[batch["cell_id"][sample_idx].item()]
25
  cell_ids.append(cell_id)
26
  for i, task_name in enumerate(config["task_names"]):
27
  pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
28
+ pred_prob = (
29
+ torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
30
+ )
31
  task_pred_labels[task_name].append(pred_label)
32
  task_pred_probs[task_name].append(pred_prob)
33
 
 
35
  test_results_dir = config["results_dir"]
36
  os.makedirs(test_results_dir, exist_ok=True)
37
  test_preds_file = os.path.join(test_results_dir, "test_preds.csv")
38
+
39
  rows = []
40
  for sample_idx in range(len(cell_ids)):
41
+ row = {"Cell ID": cell_ids[sample_idx]}
42
  for task_name in config["task_names"]:
43
+ row[f"{task_name} Prediction"] = task_pred_labels[task_name][sample_idx]
44
+ row[f"{task_name} Probabilities"] = ",".join(
45
+ map(str, task_pred_probs[task_name][sample_idx])
46
+ )
47
  rows.append(row)
48
+
49
  df = pd.DataFrame(rows)
50
  df.to_csv(test_preds_file, index=False)
51
  print(f"Test predictions saved to {test_preds_file}")
52
 
53
+
54
  def load_and_evaluate_test_model(config):
55
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
  test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
 
58
  hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
59
 
60
  # Load the saved best hyperparameters
61
+ with open(hyperparams_path, "r") as f:
62
  best_hyperparams = json.load(f)
63
 
64
  # Extract the task weights if present, otherwise set to None
 
79
  num_labels_list,
80
  dropout_rate=best_hyperparams["dropout_rate"],
81
  use_task_weights=config["use_task_weights"],
82
+ task_weights=normalized_task_weights,
83
  )
84
  best_model.load_state_dict(torch.load(best_model_path))
85
  best_model.to(device)
geneformer/mtl/imports.py CHANGED
@@ -1,46 +1,43 @@
1
- import numpy as np
 
 
 
2
  import pickle
3
- import torch
4
- import torch.nn as nn
5
- import torch.optim as optim
6
- import torch.nn.functional as F
7
- from torch.utils.data import DataLoader
8
-
9
- from itertools import chain
10
  import warnings
11
  from enum import Enum
 
12
  from typing import Dict, List, Optional, Union
13
- import sys
14
- import os
15
- import json
16
- import gc
17
- import functools
18
- import pandas as pd
19
-
20
- from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, roc_curve
21
- from sklearn.preprocessing import LabelEncoder
22
- from sklearn.model_selection import train_test_split
23
 
 
24
  import optuna
25
-
 
 
 
 
 
 
 
 
 
26
  from transformers import (
 
 
27
  BertConfig,
28
  BertModel,
29
- AdamW,
30
- get_linear_schedule_with_warmup,
31
- get_cosine_schedule_with_warmup,
32
  DataCollatorForTokenClassification,
33
  SpecialTokensMixin,
34
- BatchEncoding,
 
35
  get_scheduler,
36
  )
37
  from transformers.utils import logging, to_py_obj
38
 
39
- from datasets import load_from_disk
40
 
41
  # local modules
42
- from .data import preload_and_process_data, get_data_loader
43
  from .model import GeneformerMultiTask
44
- from .utils import save_model
45
  from .optuna_utils import create_optuna_study
46
- from .collators import DataCollatorForMultitaskCellClassification
 
1
+ import functools
2
+ import gc
3
+ import json
4
+ import os
5
  import pickle
6
+ import sys
 
 
 
 
 
 
7
  import warnings
8
  from enum import Enum
9
+ from itertools import chain
10
  from typing import Dict, List, Optional, Union
 
 
 
 
 
 
 
 
 
 
11
 
12
+ import numpy as np
13
  import optuna
14
+ import pandas as pd
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.optim as optim
19
+ from datasets import load_from_disk
20
+ from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
21
+ from sklearn.model_selection import train_test_split
22
+ from sklearn.preprocessing import LabelEncoder
23
+ from torch.utils.data import DataLoader
24
  from transformers import (
25
+ AdamW,
26
+ BatchEncoding,
27
  BertConfig,
28
  BertModel,
 
 
 
29
  DataCollatorForTokenClassification,
30
  SpecialTokensMixin,
31
+ get_cosine_schedule_with_warmup,
32
+ get_linear_schedule_with_warmup,
33
  get_scheduler,
34
  )
35
  from transformers.utils import logging, to_py_obj
36
 
37
+ from .collators import DataCollatorForMultitaskCellClassification
38
 
39
  # local modules
40
+ from .data import get_data_loader, preload_and_process_data
41
  from .model import GeneformerMultiTask
 
42
  from .optuna_utils import create_optuna_study
43
+ from .utils import save_model
geneformer/mtl/model.py CHANGED
@@ -1,13 +1,17 @@
1
- from transformers import BertModel, BertConfig
2
  import torch
3
  import torch.nn as nn
 
 
4
 
5
  class AttentionPool(nn.Module):
6
  """Attention-based pooling layer."""
 
7
  def __init__(self, hidden_size):
8
  super(AttentionPool, self).__init__()
9
  self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
10
- nn.init.xavier_uniform_(self.attention_weights) # https://pytorch.org/docs/stable/nn.init.html
 
 
11
 
12
  def forward(self, hidden_states):
13
  attention_scores = torch.matmul(hidden_states, self.attention_weights)
@@ -15,8 +19,18 @@ class AttentionPool(nn.Module):
15
  pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
16
  return pooled_output
17
 
 
18
  class GeneformerMultiTask(nn.Module):
19
- def __init__(self, pretrained_path, num_labels_list, dropout_rate=0.1, use_task_weights=False, task_weights=None, max_layers_to_freeze=0, use_attention_pooling=False):
 
 
 
 
 
 
 
 
 
20
  super(GeneformerMultiTask, self).__init__()
21
  self.config = BertConfig.from_pretrained(pretrained_path)
22
  self.bert = BertModel(self.config)
@@ -25,20 +39,31 @@ class GeneformerMultiTask(nn.Module):
25
  self.dropout = nn.Dropout(dropout_rate)
26
  self.use_attention_pooling = use_attention_pooling
27
 
28
- if use_task_weights and (task_weights is None or len(task_weights) != len(num_labels_list)):
29
- raise ValueError("Task weights must be defined and match the number of tasks when 'use_task_weights' is True.")
30
- self.task_weights = task_weights if use_task_weights else [1.0] * len(num_labels_list)
 
 
 
 
 
 
31
 
32
  # Freeze the specified initial layers
33
  for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
34
  for param in layer.parameters():
35
  param.requires_grad = False
36
 
37
- self.attention_pool = AttentionPool(self.config.hidden_size) if use_attention_pooling else None
 
 
38
 
39
- self.classification_heads = nn.ModuleList([
40
- nn.Linear(self.config.hidden_size, num_labels) for num_labels in num_labels_list
41
- ])
 
 
 
42
  # initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
43
  for head in self.classification_heads:
44
  nn.init.xavier_uniform_(head.weight)
@@ -53,7 +78,11 @@ class GeneformerMultiTask(nn.Module):
53
  sequence_output = outputs.last_hidden_state
54
 
55
  try:
56
- pooled_output = self.attention_pool(sequence_output) if self.use_attention_pooling else sequence_output[:, 0, :]
 
 
 
 
57
  pooled_output = self.dropout(pooled_output)
58
  except Exception as e:
59
  raise RuntimeError(f"Error during pooling and dropout: {e}")
@@ -62,23 +91,31 @@ class GeneformerMultiTask(nn.Module):
62
  logits = []
63
  losses = []
64
 
65
- for task_id, (head, num_labels) in enumerate(zip(self.classification_heads, self.num_labels_list)):
 
 
66
  try:
67
  task_logits = head(pooled_output)
68
  except Exception as e:
69
- raise RuntimeError(f"Error during forward pass of classification head {task_id}: {e}")
 
 
70
 
71
  logits.append(task_logits)
72
 
73
  if labels is not None:
74
  try:
75
  loss_fct = nn.CrossEntropyLoss()
76
- task_loss = loss_fct(task_logits.view(-1, num_labels), labels[task_id].view(-1))
 
 
77
  if self.use_task_weights:
78
  task_loss *= self.task_weights[task_id]
79
  total_loss += task_loss
80
  losses.append(task_loss.item())
81
  except Exception as e:
82
- raise RuntimeError(f"Error during loss computation for task {task_id}: {e}")
 
 
83
 
84
  return total_loss, logits, losses if labels is not None else logits
 
 
1
  import torch
2
  import torch.nn as nn
3
+ from transformers import BertConfig, BertModel
4
+
5
 
6
  class AttentionPool(nn.Module):
7
  """Attention-based pooling layer."""
8
+
9
  def __init__(self, hidden_size):
10
  super(AttentionPool, self).__init__()
11
  self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
12
+ nn.init.xavier_uniform_(
13
+ self.attention_weights
14
+ ) # https://pytorch.org/docs/stable/nn.init.html
15
 
16
  def forward(self, hidden_states):
17
  attention_scores = torch.matmul(hidden_states, self.attention_weights)
 
19
  pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
20
  return pooled_output
21
 
22
+
23
  class GeneformerMultiTask(nn.Module):
24
+ def __init__(
25
+ self,
26
+ pretrained_path,
27
+ num_labels_list,
28
+ dropout_rate=0.1,
29
+ use_task_weights=False,
30
+ task_weights=None,
31
+ max_layers_to_freeze=0,
32
+ use_attention_pooling=False,
33
+ ):
34
  super(GeneformerMultiTask, self).__init__()
35
  self.config = BertConfig.from_pretrained(pretrained_path)
36
  self.bert = BertModel(self.config)
 
39
  self.dropout = nn.Dropout(dropout_rate)
40
  self.use_attention_pooling = use_attention_pooling
41
 
42
+ if use_task_weights and (
43
+ task_weights is None or len(task_weights) != len(num_labels_list)
44
+ ):
45
+ raise ValueError(
46
+ "Task weights must be defined and match the number of tasks when 'use_task_weights' is True."
47
+ )
48
+ self.task_weights = (
49
+ task_weights if use_task_weights else [1.0] * len(num_labels_list)
50
+ )
51
 
52
  # Freeze the specified initial layers
53
  for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
54
  for param in layer.parameters():
55
  param.requires_grad = False
56
 
57
+ self.attention_pool = (
58
+ AttentionPool(self.config.hidden_size) if use_attention_pooling else None
59
+ )
60
 
61
+ self.classification_heads = nn.ModuleList(
62
+ [
63
+ nn.Linear(self.config.hidden_size, num_labels)
64
+ for num_labels in num_labels_list
65
+ ]
66
+ )
67
  # initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
68
  for head in self.classification_heads:
69
  nn.init.xavier_uniform_(head.weight)
 
78
  sequence_output = outputs.last_hidden_state
79
 
80
  try:
81
+ pooled_output = (
82
+ self.attention_pool(sequence_output)
83
+ if self.use_attention_pooling
84
+ else sequence_output[:, 0, :]
85
+ )
86
  pooled_output = self.dropout(pooled_output)
87
  except Exception as e:
88
  raise RuntimeError(f"Error during pooling and dropout: {e}")
 
91
  logits = []
92
  losses = []
93
 
94
+ for task_id, (head, num_labels) in enumerate(
95
+ zip(self.classification_heads, self.num_labels_list)
96
+ ):
97
  try:
98
  task_logits = head(pooled_output)
99
  except Exception as e:
100
+ raise RuntimeError(
101
+ f"Error during forward pass of classification head {task_id}: {e}"
102
+ )
103
 
104
  logits.append(task_logits)
105
 
106
  if labels is not None:
107
  try:
108
  loss_fct = nn.CrossEntropyLoss()
109
+ task_loss = loss_fct(
110
+ task_logits.view(-1, num_labels), labels[task_id].view(-1)
111
+ )
112
  if self.use_task_weights:
113
  task_loss *= self.task_weights[task_id]
114
  total_loss += task_loss
115
  losses.append(task_loss.item())
116
  except Exception as e:
117
+ raise RuntimeError(
118
+ f"Error during loss computation for task {task_id}: {e}"
119
+ )
120
 
121
  return total_loss, logits, losses if labels is not None else logits
geneformer/mtl/optuna_utils.py CHANGED
@@ -1,21 +1,27 @@
1
  import optuna
2
  from optuna.integration import TensorBoardCallback
3
 
 
4
  def save_trial_callback(study, trial, trials_result_path):
5
  with open(trials_result_path, "a") as f:
6
- f.write(f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n")
 
 
 
7
 
8
  def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
9
  study = optuna.create_study(direction="maximize")
10
-
11
  # init TensorBoard callback
12
- tensorboard_callback = TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro")
13
-
 
 
14
  # callback and TensorBoard callback
15
  callbacks = [
16
  lambda study, trial: save_trial_callback(study, trial, trials_result_path),
17
- tensorboard_callback
18
  ]
19
-
20
  study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
21
- return study
 
1
  import optuna
2
  from optuna.integration import TensorBoardCallback
3
 
4
+
5
  def save_trial_callback(study, trial, trials_result_path):
6
  with open(trials_result_path, "a") as f:
7
+ f.write(
8
+ f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
9
+ )
10
+
11
 
12
  def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
13
  study = optuna.create_study(direction="maximize")
14
+
15
  # init TensorBoard callback
16
+ tensorboard_callback = TensorBoardCallback(
17
+ dirname=tensorboard_log_dir, metric_name="F1 Macro"
18
+ )
19
+
20
  # callback and TensorBoard callback
21
  callbacks = [
22
  lambda study, trial: save_trial_callback(study, trial, trials_result_path),
23
+ tensorboard_callback,
24
  ]
25
+
26
  study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
27
+ return study
geneformer/mtl/train.py CHANGED
@@ -1,14 +1,15 @@
1
- from .imports import *
2
- from .data import preload_and_process_data, get_data_loader
3
- from .model import GeneformerMultiTask
4
- from .utils import calculate_task_specific_metrics
5
- from torch.utils.tensorboard import SummaryWriter
6
- import pandas as pd
7
  import os
8
- from tqdm import tqdm
9
  import random
 
10
  import numpy as np
 
11
  import torch
 
 
 
 
 
 
12
 
13
 
14
  def set_seed(seed):
@@ -19,13 +20,18 @@ def set_seed(seed):
19
  torch.backends.cudnn.deterministic = True
20
  torch.backends.cudnn.benchmark = False
21
 
 
22
  def initialize_wandb(config):
23
  if config.get("use_wandb", False):
24
  import wandb
 
25
  wandb.init(project=config["wandb_project"], config=config)
26
  print("Weights & Biases (wandb) initialized and will be used for logging.")
27
  else:
28
- print("Weights & Biases (wandb) is not enabled. Logging will use other methods.")
 
 
 
29
 
30
  def create_model(config, num_labels_list, device):
31
  model = GeneformerMultiTask(
@@ -35,31 +41,48 @@ def create_model(config, num_labels_list, device):
35
  use_task_weights=config["use_task_weights"],
36
  task_weights=config["task_weights"],
37
  max_layers_to_freeze=config["max_layers_to_freeze"],
38
- use_attention_pooling=config["use_attention_pooling"]
39
  )
40
  if config["use_data_parallel"]:
41
  model = nn.DataParallel(model)
42
  return model.to(device)
43
 
 
44
  def setup_optimizer_and_scheduler(model, config, total_steps):
45
- optimizer = AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
 
 
 
 
46
  warmup_steps = int(config["warmup_ratio"] * total_steps)
47
 
48
  if config["lr_scheduler_type"] == "linear":
49
- scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
 
 
50
  elif config["lr_scheduler_type"] == "cosine":
51
- scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, num_cycles=0.5)
 
 
 
 
 
52
 
53
  return optimizer, scheduler
54
 
55
- def train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch):
 
 
 
56
  model.train()
57
  progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
58
  for batch_idx, batch in enumerate(progress_bar):
59
  optimizer.zero_grad()
60
- input_ids = batch['input_ids'].to(device)
61
- attention_mask = batch['attention_mask'].to(device)
62
- labels = [batch['labels'][task_name].to(device) for task_name in config["task_names"]]
 
 
63
 
64
  loss, _, _ = model(input_ids, attention_mask, labels)
65
  loss.backward()
@@ -70,15 +93,20 @@ def train_epoch(model, train_loader, optimizer, scheduler, device, config, write
70
  optimizer.step()
71
  scheduler.step()
72
 
73
- writer.add_scalar('Training Loss', loss.item(), epoch * len(train_loader) + batch_idx)
 
 
74
  if config.get("use_wandb", False):
75
- wandb.log({'Training Loss': loss.item()})
 
 
76
 
77
  # Update progress bar
78
- progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})
79
 
80
  return loss.item() # Return the last batch loss
81
 
 
82
  def validate_model(model, val_loader, device, config):
83
  model.eval()
84
  val_loss = 0.0
@@ -88,17 +116,22 @@ def validate_model(model, val_loader, device, config):
88
 
89
  with torch.no_grad():
90
  for batch in val_loader:
91
- input_ids = batch['input_ids'].to(device)
92
- attention_mask = batch['attention_mask'].to(device)
93
- labels = [batch['labels'][task_name].to(device) for task_name in config["task_names"]]
 
 
 
94
  loss, logits, _ = model(input_ids, attention_mask, labels)
95
  val_loss += loss.item()
96
 
97
- for sample_idx in range(len(batch['input_ids'])):
98
  for i, task_name in enumerate(config["task_names"]):
99
- true_label = batch['labels'][task_name][sample_idx].item()
100
  pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
101
- pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
 
 
102
  task_true_labels[task_name].append(true_label)
103
  task_pred_labels[task_name].append(pred_label)
104
  task_pred_probs[task_name].append(pred_prob)
@@ -106,44 +139,70 @@ def validate_model(model, val_loader, device, config):
106
  val_loss /= len(val_loader)
107
  return val_loss, task_true_labels, task_pred_labels, task_pred_probs
108
 
 
109
  def log_metrics(task_metrics, val_loss, config, writer, epochs):
110
  for task_name, metrics in task_metrics.items():
111
- print(f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}")
 
 
112
  if config.get("use_wandb", False):
113
  import wandb
114
- wandb.log({
115
- f'{task_name} Validation F1 Macro': metrics['f1'],
116
- f'{task_name} Validation Accuracy': metrics['accuracy']
117
- })
118
 
119
- writer.add_scalar('Validation Loss', val_loss, epochs)
120
- for task_name, metrics in task_metrics.items():
121
- writer.add_scalar(f'{task_name} - Validation F1 Macro', metrics['f1'], epochs)
122
- writer.add_scalar(f'{task_name} - Validation Accuracy', metrics['accuracy'], epochs)
 
 
123
 
124
- def save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config, trial_number=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  if trial_number is not None:
126
  trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
127
  os.makedirs(trial_results_dir, exist_ok=True)
128
  val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
129
  else:
130
  val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
131
-
132
  rows = []
133
  for sample_idx in range(len(val_cell_id_mapping)):
134
- row = {'Cell ID': val_cell_id_mapping[sample_idx]}
135
  for task_name in config["task_names"]:
136
- row[f'{task_name} True'] = task_true_labels[task_name][sample_idx]
137
- row[f'{task_name} Pred'] = task_pred_labels[task_name][sample_idx]
138
- row[f'{task_name} Probabilities'] = ','.join(map(str, task_pred_probs[task_name][sample_idx]))
 
 
139
  rows.append(row)
140
-
141
  df = pd.DataFrame(rows)
142
  df.to_csv(val_preds_file, index=False)
143
  print(f"Validation predictions saved to {val_preds_file}")
144
 
145
 
146
- def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
 
 
 
 
 
 
 
 
147
  set_seed(config["seed"])
148
  initialize_wandb(config)
149
 
@@ -156,46 +215,97 @@ def train_model(config, device, train_loader, val_loader, train_cell_id_mapping,
156
 
157
  epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
158
  for epoch in epoch_progress:
159
- last_loss = train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch)
160
- epoch_progress.set_postfix({'last_loss': f"{last_loss:.4f}"})
 
 
161
 
162
- val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(model, val_loader, device, config)
 
 
163
  task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
164
-
165
  log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
166
  writer.close()
167
 
168
- save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config)
 
 
169
 
170
  if config.get("use_wandb", False):
171
  import wandb
 
172
  wandb.finish()
173
 
174
  print(f"\nFinal Validation Loss: {val_loss:.4f}")
175
  return val_loss, model # Return both the validation loss and the trained model
176
 
177
- def objective(trial, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, config, device):
 
 
 
 
 
 
 
 
 
 
178
  set_seed(config["seed"]) # Set the seed before each trial
179
  initialize_wandb(config)
180
 
181
  # Hyperparameters
182
- config["learning_rate"] = trial.suggest_float("learning_rate", config["hyperparameters"]["learning_rate"]["low"], config["hyperparameters"]["learning_rate"]["high"], log=config["hyperparameters"]["learning_rate"]["log"])
183
- config["warmup_ratio"] = trial.suggest_float("warmup_ratio", config["hyperparameters"]["warmup_ratio"]["low"], config["hyperparameters"]["warmup_ratio"]["high"])
184
- config["weight_decay"] = trial.suggest_float("weight_decay", config["hyperparameters"]["weight_decay"]["low"], config["hyperparameters"]["weight_decay"]["high"])
185
- config["dropout_rate"] = trial.suggest_float("dropout_rate", config["hyperparameters"]["dropout_rate"]["low"], config["hyperparameters"]["dropout_rate"]["high"])
186
- config["lr_scheduler_type"] = trial.suggest_categorical("lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"])
187
- config["use_attention_pooling"] = trial.suggest_categorical("use_attention_pooling", [True, False])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  if config["use_task_weights"]:
190
- config["task_weights"] = [trial.suggest_float(f"task_weight_{i}", config["hyperparameters"]["task_weights"]["low"], config["hyperparameters"]["task_weights"]["high"]) for i in range(len(num_labels_list))]
 
 
 
 
 
 
 
191
  weight_sum = sum(config["task_weights"])
192
- config["task_weights"] = [weight / weight_sum for weight in config["task_weights"]]
 
 
193
  else:
194
  config["task_weights"] = None
195
 
196
  # Fix for max_layers_to_freeze
197
  if isinstance(config["max_layers_to_freeze"], dict):
198
- config["max_layers_to_freeze"] = trial.suggest_int("max_layers_to_freeze", config["max_layers_to_freeze"]["min"], config["max_layers_to_freeze"]["max"])
 
 
 
 
199
  elif isinstance(config["max_layers_to_freeze"], int):
200
  # If it's already an int, we don't need to suggest it
201
  pass
@@ -210,15 +320,26 @@ def objective(trial, train_loader, val_loader, train_cell_id_mapping, val_cell_i
210
  writer = SummaryWriter(log_dir=log_dir)
211
 
212
  for epoch in range(config["epochs"]):
213
- train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch)
 
 
214
 
215
- val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(model, val_loader, device, config)
 
 
216
  task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
217
-
218
  log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
219
  writer.close()
220
 
221
- save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config, trial.number)
 
 
 
 
 
 
 
222
 
223
  trial.set_user_attr("model_state_dict", model.state_dict())
224
  trial.set_user_attr("task_weights", config["task_weights"])
@@ -230,13 +351,35 @@ def objective(trial, train_loader, val_loader, train_cell_id_mapping, val_cell_i
230
 
231
  if config.get("use_wandb", False):
232
  import wandb
233
- wandb.log({
234
- "trial_number": trial.number,
235
- "val_loss": val_loss,
236
- **{f"{task_name}_f1": metrics['f1'] for task_name, metrics in task_metrics.items()},
237
- **{f"{task_name}_accuracy": metrics['accuracy'] for task_name, metrics in task_metrics.items()},
238
- **{k: v for k, v in config.items() if k in ["learning_rate", "warmup_ratio", "weight_decay", "dropout_rate", "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze"]}
239
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  wandb.finish()
241
 
242
- return val_loss
 
 
 
 
 
 
 
1
  import os
 
2
  import random
3
+
4
  import numpy as np
5
+ import pandas as pd
6
  import torch
7
+ from torch.utils.tensorboard import SummaryWriter
8
+ from tqdm import tqdm
9
+
10
+ from .imports import *
11
+ from .model import GeneformerMultiTask
12
+ from .utils import calculate_task_specific_metrics
13
 
14
 
15
  def set_seed(seed):
 
20
  torch.backends.cudnn.deterministic = True
21
  torch.backends.cudnn.benchmark = False
22
 
23
+
24
  def initialize_wandb(config):
25
  if config.get("use_wandb", False):
26
  import wandb
27
+
28
  wandb.init(project=config["wandb_project"], config=config)
29
  print("Weights & Biases (wandb) initialized and will be used for logging.")
30
  else:
31
+ print(
32
+ "Weights & Biases (wandb) is not enabled. Logging will use other methods."
33
+ )
34
+
35
 
36
  def create_model(config, num_labels_list, device):
37
  model = GeneformerMultiTask(
 
41
  use_task_weights=config["use_task_weights"],
42
  task_weights=config["task_weights"],
43
  max_layers_to_freeze=config["max_layers_to_freeze"],
44
+ use_attention_pooling=config["use_attention_pooling"],
45
  )
46
  if config["use_data_parallel"]:
47
  model = nn.DataParallel(model)
48
  return model.to(device)
49
 
50
+
51
  def setup_optimizer_and_scheduler(model, config, total_steps):
52
+ optimizer = AdamW(
53
+ model.parameters(),
54
+ lr=config["learning_rate"],
55
+ weight_decay=config["weight_decay"],
56
+ )
57
  warmup_steps = int(config["warmup_ratio"] * total_steps)
58
 
59
  if config["lr_scheduler_type"] == "linear":
60
+ scheduler = get_linear_schedule_with_warmup(
61
+ optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
62
+ )
63
  elif config["lr_scheduler_type"] == "cosine":
64
+ scheduler = get_cosine_schedule_with_warmup(
65
+ optimizer,
66
+ num_warmup_steps=warmup_steps,
67
+ num_training_steps=total_steps,
68
+ num_cycles=0.5,
69
+ )
70
 
71
  return optimizer, scheduler
72
 
73
+
74
+ def train_epoch(
75
+ model, train_loader, optimizer, scheduler, device, config, writer, epoch
76
+ ):
77
  model.train()
78
  progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
79
  for batch_idx, batch in enumerate(progress_bar):
80
  optimizer.zero_grad()
81
+ input_ids = batch["input_ids"].to(device)
82
+ attention_mask = batch["attention_mask"].to(device)
83
+ labels = [
84
+ batch["labels"][task_name].to(device) for task_name in config["task_names"]
85
+ ]
86
 
87
  loss, _, _ = model(input_ids, attention_mask, labels)
88
  loss.backward()
 
93
  optimizer.step()
94
  scheduler.step()
95
 
96
+ writer.add_scalar(
97
+ "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
98
+ )
99
  if config.get("use_wandb", False):
100
+ import wandb
101
+
102
+ wandb.log({"Training Loss": loss.item()})
103
 
104
  # Update progress bar
105
+ progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
106
 
107
  return loss.item() # Return the last batch loss
108
 
109
+
110
  def validate_model(model, val_loader, device, config):
111
  model.eval()
112
  val_loss = 0.0
 
116
 
117
  with torch.no_grad():
118
  for batch in val_loader:
119
+ input_ids = batch["input_ids"].to(device)
120
+ attention_mask = batch["attention_mask"].to(device)
121
+ labels = [
122
+ batch["labels"][task_name].to(device)
123
+ for task_name in config["task_names"]
124
+ ]
125
  loss, logits, _ = model(input_ids, attention_mask, labels)
126
  val_loss += loss.item()
127
 
128
+ for sample_idx in range(len(batch["input_ids"])):
129
  for i, task_name in enumerate(config["task_names"]):
130
+ true_label = batch["labels"][task_name][sample_idx].item()
131
  pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
132
+ pred_prob = (
133
+ torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
134
+ )
135
  task_true_labels[task_name].append(true_label)
136
  task_pred_labels[task_name].append(pred_label)
137
  task_pred_probs[task_name].append(pred_prob)
 
139
  val_loss /= len(val_loader)
140
  return val_loss, task_true_labels, task_pred_labels, task_pred_probs
141
 
142
+
143
  def log_metrics(task_metrics, val_loss, config, writer, epochs):
144
  for task_name, metrics in task_metrics.items():
145
+ print(
146
+ f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
147
+ )
148
  if config.get("use_wandb", False):
149
  import wandb
 
 
 
 
150
 
151
+ wandb.log(
152
+ {
153
+ f"{task_name} Validation F1 Macro": metrics["f1"],
154
+ f"{task_name} Validation Accuracy": metrics["accuracy"],
155
+ }
156
+ )
157
 
158
+ writer.add_scalar("Validation Loss", val_loss, epochs)
159
+ for task_name, metrics in task_metrics.items():
160
+ writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
161
+ writer.add_scalar(
162
+ f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
163
+ )
164
+
165
+
166
+ def save_validation_predictions(
167
+ val_cell_id_mapping,
168
+ task_true_labels,
169
+ task_pred_labels,
170
+ task_pred_probs,
171
+ config,
172
+ trial_number=None,
173
+ ):
174
  if trial_number is not None:
175
  trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
176
  os.makedirs(trial_results_dir, exist_ok=True)
177
  val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
178
  else:
179
  val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
180
+
181
  rows = []
182
  for sample_idx in range(len(val_cell_id_mapping)):
183
+ row = {"Cell ID": val_cell_id_mapping[sample_idx]}
184
  for task_name in config["task_names"]:
185
+ row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
186
+ row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
187
+ row[f"{task_name} Probabilities"] = ",".join(
188
+ map(str, task_pred_probs[task_name][sample_idx])
189
+ )
190
  rows.append(row)
191
+
192
  df = pd.DataFrame(rows)
193
  df.to_csv(val_preds_file, index=False)
194
  print(f"Validation predictions saved to {val_preds_file}")
195
 
196
 
197
+ def train_model(
198
+ config,
199
+ device,
200
+ train_loader,
201
+ val_loader,
202
+ train_cell_id_mapping,
203
+ val_cell_id_mapping,
204
+ num_labels_list,
205
+ ):
206
  set_seed(config["seed"])
207
  initialize_wandb(config)
208
 
 
215
 
216
  epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
217
  for epoch in epoch_progress:
218
+ last_loss = train_epoch(
219
+ model, train_loader, optimizer, scheduler, device, config, writer, epoch
220
+ )
221
+ epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
222
 
223
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
224
+ model, val_loader, device, config
225
+ )
226
  task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
227
+
228
  log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
229
  writer.close()
230
 
231
+ save_validation_predictions(
232
+ val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
233
+ )
234
 
235
  if config.get("use_wandb", False):
236
  import wandb
237
+
238
  wandb.finish()
239
 
240
  print(f"\nFinal Validation Loss: {val_loss:.4f}")
241
  return val_loss, model # Return both the validation loss and the trained model
242
 
243
+
244
+ def objective(
245
+ trial,
246
+ train_loader,
247
+ val_loader,
248
+ train_cell_id_mapping,
249
+ val_cell_id_mapping,
250
+ num_labels_list,
251
+ config,
252
+ device,
253
+ ):
254
  set_seed(config["seed"]) # Set the seed before each trial
255
  initialize_wandb(config)
256
 
257
  # Hyperparameters
258
+ config["learning_rate"] = trial.suggest_float(
259
+ "learning_rate",
260
+ config["hyperparameters"]["learning_rate"]["low"],
261
+ config["hyperparameters"]["learning_rate"]["high"],
262
+ log=config["hyperparameters"]["learning_rate"]["log"],
263
+ )
264
+ config["warmup_ratio"] = trial.suggest_float(
265
+ "warmup_ratio",
266
+ config["hyperparameters"]["warmup_ratio"]["low"],
267
+ config["hyperparameters"]["warmup_ratio"]["high"],
268
+ )
269
+ config["weight_decay"] = trial.suggest_float(
270
+ "weight_decay",
271
+ config["hyperparameters"]["weight_decay"]["low"],
272
+ config["hyperparameters"]["weight_decay"]["high"],
273
+ )
274
+ config["dropout_rate"] = trial.suggest_float(
275
+ "dropout_rate",
276
+ config["hyperparameters"]["dropout_rate"]["low"],
277
+ config["hyperparameters"]["dropout_rate"]["high"],
278
+ )
279
+ config["lr_scheduler_type"] = trial.suggest_categorical(
280
+ "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
281
+ )
282
+ config["use_attention_pooling"] = trial.suggest_categorical(
283
+ "use_attention_pooling", [True, False]
284
+ )
285
 
286
  if config["use_task_weights"]:
287
+ config["task_weights"] = [
288
+ trial.suggest_float(
289
+ f"task_weight_{i}",
290
+ config["hyperparameters"]["task_weights"]["low"],
291
+ config["hyperparameters"]["task_weights"]["high"],
292
+ )
293
+ for i in range(len(num_labels_list))
294
+ ]
295
  weight_sum = sum(config["task_weights"])
296
+ config["task_weights"] = [
297
+ weight / weight_sum for weight in config["task_weights"]
298
+ ]
299
  else:
300
  config["task_weights"] = None
301
 
302
  # Fix for max_layers_to_freeze
303
  if isinstance(config["max_layers_to_freeze"], dict):
304
+ config["max_layers_to_freeze"] = trial.suggest_int(
305
+ "max_layers_to_freeze",
306
+ config["max_layers_to_freeze"]["min"],
307
+ config["max_layers_to_freeze"]["max"],
308
+ )
309
  elif isinstance(config["max_layers_to_freeze"], int):
310
  # If it's already an int, we don't need to suggest it
311
  pass
 
320
  writer = SummaryWriter(log_dir=log_dir)
321
 
322
  for epoch in range(config["epochs"]):
323
+ train_epoch(
324
+ model, train_loader, optimizer, scheduler, device, config, writer, epoch
325
+ )
326
 
327
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
328
+ model, val_loader, device, config
329
+ )
330
  task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
331
+
332
  log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
333
  writer.close()
334
 
335
+ save_validation_predictions(
336
+ val_cell_id_mapping,
337
+ task_true_labels,
338
+ task_pred_labels,
339
+ task_pred_probs,
340
+ config,
341
+ trial.number,
342
+ )
343
 
344
  trial.set_user_attr("model_state_dict", model.state_dict())
345
  trial.set_user_attr("task_weights", config["task_weights"])
 
351
 
352
  if config.get("use_wandb", False):
353
  import wandb
354
+
355
+ wandb.log(
356
+ {
357
+ "trial_number": trial.number,
358
+ "val_loss": val_loss,
359
+ **{
360
+ f"{task_name}_f1": metrics["f1"]
361
+ for task_name, metrics in task_metrics.items()
362
+ },
363
+ **{
364
+ f"{task_name}_accuracy": metrics["accuracy"]
365
+ for task_name, metrics in task_metrics.items()
366
+ },
367
+ **{
368
+ k: v
369
+ for k, v in config.items()
370
+ if k
371
+ in [
372
+ "learning_rate",
373
+ "warmup_ratio",
374
+ "weight_decay",
375
+ "dropout_rate",
376
+ "lr_scheduler_type",
377
+ "use_attention_pooling",
378
+ "max_layers_to_freeze",
379
+ ]
380
+ },
381
+ }
382
+ )
383
  wandb.finish()
384
 
385
+ return val_loss
geneformer/mtl/train_utils.py CHANGED
@@ -1,9 +1,11 @@
 
 
 
1
  from .imports import *
2
- from .data import preload_and_process_data, get_data_loader
3
- from .train import objective, train_model
4
  from .model import GeneformerMultiTask
 
5
  from .utils import save_model
6
- import random
7
 
8
  def set_seed(seed):
9
  random.seed(seed)
@@ -12,15 +14,22 @@ def set_seed(seed):
12
  torch.cuda.manual_seed_all(seed)
13
  torch.backends.cudnn.deterministic = True
14
  torch.backends.cudnn.benchmark = False
15
-
 
16
  def run_manual_tuning(config):
17
  # Set seed for reproducibility
18
  set_seed(config["seed"])
19
 
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
- train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list = preload_and_process_data(config)
22
- train_loader = get_data_loader(train_dataset, config['batch_size'])
23
- val_loader = get_data_loader(val_dataset, config['batch_size'])
 
 
 
 
 
 
24
 
25
  # Print the manual hyperparameters being used
26
  print("\nManual hyperparameters being used:")
@@ -33,12 +42,22 @@ def run_manual_tuning(config):
33
  config[key] = value
34
 
35
  # Train the model
36
- val_loss, trained_model = train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list)
 
 
 
 
 
 
 
 
37
 
38
  print(f"\nValidation loss with manual hyperparameters: {val_loss}")
39
 
40
  # Save the trained model
41
- model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
 
 
42
  save_model(trained_model, model_save_directory)
43
 
44
  # Save the hyperparameters
@@ -48,26 +67,41 @@ def run_manual_tuning(config):
48
  "use_task_weights": config["use_task_weights"],
49
  "task_weights": config["task_weights"],
50
  "max_layers_to_freeze": config["max_layers_to_freeze"],
51
- "use_attention_pooling": config["use_attention_pooling"]
52
  }
53
  hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
54
- with open(hyperparams_path, 'w') as f:
55
  json.dump(hyperparams_to_save, f)
56
  print(f"Manual hyperparameters saved to {hyperparams_path}")
57
 
58
  return val_loss
59
 
 
60
  def run_optuna_study(config):
61
  # Set seed for reproducibility
62
  set_seed(config["seed"])
63
 
64
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
- train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list = preload_and_process_data(config)
66
- train_loader = get_data_loader(train_dataset, config['batch_size'])
67
- val_loader = get_data_loader(val_dataset, config['batch_size'])
 
 
 
 
 
 
68
 
69
  if config["use_manual_hyperparameters"]:
70
- train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list)
 
 
 
 
 
 
 
 
71
  else:
72
  objective_with_config_and_data = functools.partial(
73
  objective,
@@ -77,20 +111,17 @@ def run_optuna_study(config):
77
  val_cell_id_mapping=val_cell_id_mapping,
78
  num_labels_list=num_labels_list,
79
  config=config,
80
- device=device
81
  )
82
 
83
  study = optuna.create_study(
84
- direction='minimize', # Minimize validation loss
85
  study_name=config["study_name"],
86
- #storage=config["storage"],
87
- load_if_exists=True
88
  )
89
 
90
- study.optimize(
91
- objective_with_config_and_data,
92
- n_trials=config["n_trials"]
93
- )
94
 
95
  # After finding the best trial
96
  best_params = study.best_trial.params
@@ -103,24 +134,28 @@ def run_optuna_study(config):
103
  num_labels_list,
104
  dropout_rate=best_params["dropout_rate"],
105
  use_task_weights=config["use_task_weights"],
106
- task_weights=best_task_weights
107
  )
108
-
109
  # Get the best model state dictionary
110
  best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
111
 
112
  # Remove the "module." prefix from the state dictionary keys if present
113
- best_model_state_dict = {k.replace("module.", ""): v for k, v in best_model_state_dict.items()}
 
 
114
 
115
  # Load the modified state dictionary into the model, skipping unexpected keys
116
  best_model.load_state_dict(best_model_state_dict, strict=False)
117
 
118
- model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
 
 
119
  save_model(best_model, model_save_directory)
120
 
121
  # Additionally, save the best hyperparameters and task weights
122
  hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
123
-
124
- with open(hyperparams_path, 'w') as f:
125
  json.dump({**best_params, "task_weights": best_task_weights}, f)
126
- print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
 
1
+ import random
2
+
3
+ from .data import get_data_loader, preload_and_process_data
4
  from .imports import *
 
 
5
  from .model import GeneformerMultiTask
6
+ from .train import objective, train_model
7
  from .utils import save_model
8
+
9
 
10
  def set_seed(seed):
11
  random.seed(seed)
 
14
  torch.cuda.manual_seed_all(seed)
15
  torch.backends.cudnn.deterministic = True
16
  torch.backends.cudnn.benchmark = False
17
+
18
+
19
  def run_manual_tuning(config):
20
  # Set seed for reproducibility
21
  set_seed(config["seed"])
22
 
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ (
25
+ train_dataset,
26
+ train_cell_id_mapping,
27
+ val_dataset,
28
+ val_cell_id_mapping,
29
+ num_labels_list,
30
+ ) = preload_and_process_data(config)
31
+ train_loader = get_data_loader(train_dataset, config["batch_size"])
32
+ val_loader = get_data_loader(val_dataset, config["batch_size"])
33
 
34
  # Print the manual hyperparameters being used
35
  print("\nManual hyperparameters being used:")
 
42
  config[key] = value
43
 
44
  # Train the model
45
+ val_loss, trained_model = train_model(
46
+ config,
47
+ device,
48
+ train_loader,
49
+ val_loader,
50
+ train_cell_id_mapping,
51
+ val_cell_id_mapping,
52
+ num_labels_list,
53
+ )
54
 
55
  print(f"\nValidation loss with manual hyperparameters: {val_loss}")
56
 
57
  # Save the trained model
58
+ model_save_directory = os.path.join(
59
+ config["model_save_path"], "GeneformerMultiTask"
60
+ )
61
  save_model(trained_model, model_save_directory)
62
 
63
  # Save the hyperparameters
 
67
  "use_task_weights": config["use_task_weights"],
68
  "task_weights": config["task_weights"],
69
  "max_layers_to_freeze": config["max_layers_to_freeze"],
70
+ "use_attention_pooling": config["use_attention_pooling"],
71
  }
72
  hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
73
+ with open(hyperparams_path, "w") as f:
74
  json.dump(hyperparams_to_save, f)
75
  print(f"Manual hyperparameters saved to {hyperparams_path}")
76
 
77
  return val_loss
78
 
79
+
80
  def run_optuna_study(config):
81
  # Set seed for reproducibility
82
  set_seed(config["seed"])
83
 
84
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
+ (
86
+ train_dataset,
87
+ train_cell_id_mapping,
88
+ val_dataset,
89
+ val_cell_id_mapping,
90
+ num_labels_list,
91
+ ) = preload_and_process_data(config)
92
+ train_loader = get_data_loader(train_dataset, config["batch_size"])
93
+ val_loader = get_data_loader(val_dataset, config["batch_size"])
94
 
95
  if config["use_manual_hyperparameters"]:
96
+ train_model(
97
+ config,
98
+ device,
99
+ train_loader,
100
+ val_loader,
101
+ train_cell_id_mapping,
102
+ val_cell_id_mapping,
103
+ num_labels_list,
104
+ )
105
  else:
106
  objective_with_config_and_data = functools.partial(
107
  objective,
 
111
  val_cell_id_mapping=val_cell_id_mapping,
112
  num_labels_list=num_labels_list,
113
  config=config,
114
+ device=device,
115
  )
116
 
117
  study = optuna.create_study(
118
+ direction="minimize", # Minimize validation loss
119
  study_name=config["study_name"],
120
+ # storage=config["storage"],
121
+ load_if_exists=True,
122
  )
123
 
124
+ study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])
 
 
 
125
 
126
  # After finding the best trial
127
  best_params = study.best_trial.params
 
134
  num_labels_list,
135
  dropout_rate=best_params["dropout_rate"],
136
  use_task_weights=config["use_task_weights"],
137
+ task_weights=best_task_weights,
138
  )
139
+
140
  # Get the best model state dictionary
141
  best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
142
 
143
  # Remove the "module." prefix from the state dictionary keys if present
144
+ best_model_state_dict = {
145
+ k.replace("module.", ""): v for k, v in best_model_state_dict.items()
146
+ }
147
 
148
  # Load the modified state dictionary into the model, skipping unexpected keys
149
  best_model.load_state_dict(best_model_state_dict, strict=False)
150
 
151
+ model_save_directory = os.path.join(
152
+ config["model_save_path"], "GeneformerMultiTask"
153
+ )
154
  save_model(best_model, model_save_directory)
155
 
156
  # Additionally, save the best hyperparameters and task weights
157
  hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
158
+
159
+ with open(hyperparams_path, "w") as f:
160
  json.dump({**best_params, "task_weights": best_task_weights}, f)
161
+ print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
geneformer/mtl/utils.py CHANGED
@@ -1,44 +1,55 @@
1
- from .imports import *
2
- from sklearn.metrics import f1_score, accuracy_score
3
- from sklearn.preprocessing import LabelEncoder
4
- from transformers import BertModel, BertConfig, AutoConfig
5
  import os
6
  import shutil
7
 
 
 
 
 
 
 
 
8
  def save_model(model, model_save_directory):
9
  if not os.path.exists(model_save_directory):
10
  os.makedirs(model_save_directory)
11
-
12
  # Get the state dict
13
  if isinstance(model, nn.DataParallel):
14
- model_state_dict = model.module.state_dict() # Use model.module to access the underlying model
 
 
15
  else:
16
  model_state_dict = model.state_dict()
17
-
18
  # Remove the "module." prefix from the keys if present
19
- model_state_dict = {k.replace("module.", ""): v for k, v in model_state_dict.items()}
20
-
 
 
21
  model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
22
  torch.save(model_state_dict, model_save_path)
23
-
24
  # Save the model configuration
25
  if isinstance(model, nn.DataParallel):
26
- model.module.config.to_json_file(os.path.join(model_save_directory, "config.json"))
 
 
27
  else:
28
  model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
29
-
30
  print(f"Model and configuration saved to {model_save_directory}")
31
 
 
32
  def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
33
  task_metrics = {}
34
  for task_name in task_true_labels.keys():
35
  true_labels = task_true_labels[task_name]
36
  pred_labels = task_pred_labels[task_name]
37
- f1 = f1_score(true_labels, pred_labels, average='macro')
38
  accuracy = accuracy_score(true_labels, pred_labels)
39
- task_metrics[task_name] = {'f1': f1, 'accuracy': accuracy}
40
  return task_metrics
41
 
 
42
  def calculate_combined_f1(combined_labels, combined_preds):
43
  # Initialize the LabelEncoder
44
  le = LabelEncoder()
@@ -57,10 +68,11 @@ def calculate_combined_f1(combined_labels, combined_preds):
57
  accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
58
 
59
  # Calculate F1 Macro score
60
- f1 = f1_score(encoded_true_labels, encoded_pred_labels, average='macro')
61
 
62
  return f1, accuracy
63
 
 
64
  def save_model_without_heads(original_model_save_directory):
65
  # Create a new directory for the model without heads
66
  new_model_save_directory = original_model_save_directory + "_No_Heads"
@@ -68,25 +80,36 @@ def save_model_without_heads(original_model_save_directory):
68
  os.makedirs(new_model_save_directory)
69
 
70
  # Load the model state dictionary
71
- model_state_dict = torch.load(os.path.join(original_model_save_directory, "pytorch_model.bin"))
 
 
72
 
73
  # Initialize a new BERT model without the classification heads
74
- config = BertConfig.from_pretrained(os.path.join(original_model_save_directory, "config.json"))
 
 
75
  model_without_heads = BertModel(config)
76
-
77
  # Filter the state dict to exclude classification heads
78
- model_without_heads_state_dict = {k: v for k, v in model_state_dict.items() if not k.startswith("classification_heads")}
79
-
 
 
 
 
80
  # Load the filtered state dict into the model
81
  model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
82
-
83
  # Save the model without heads
84
  model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
85
  torch.save(model_without_heads.state_dict(), model_save_path)
86
-
87
  # Copy the configuration file
88
- shutil.copy(os.path.join(original_model_save_directory, "config.json"), new_model_save_directory)
89
-
 
 
 
90
  print(f"Model without classification heads saved to {new_model_save_directory}")
91
 
92
 
@@ -100,7 +123,7 @@ def get_layer_freeze_range(pretrained_path):
100
  """
101
  if pretrained_path:
102
  config = AutoConfig.from_pretrained(pretrained_path)
103
- total_layers = config.num_hidden_layers
104
- return {"min": 0, "max": total_layers - 1}
105
  else:
106
- return {"min": 0, "max": 0}
 
 
 
 
 
1
  import os
2
  import shutil
3
 
4
+ from sklearn.metrics import accuracy_score, f1_score
5
+ from sklearn.preprocessing import LabelEncoder
6
+ from transformers import AutoConfig, BertConfig, BertModel
7
+
8
+ from .imports import *
9
+
10
+
11
  def save_model(model, model_save_directory):
12
  if not os.path.exists(model_save_directory):
13
  os.makedirs(model_save_directory)
14
+
15
  # Get the state dict
16
  if isinstance(model, nn.DataParallel):
17
+ model_state_dict = (
18
+ model.module.state_dict()
19
+ ) # Use model.module to access the underlying model
20
  else:
21
  model_state_dict = model.state_dict()
22
+
23
  # Remove the "module." prefix from the keys if present
24
+ model_state_dict = {
25
+ k.replace("module.", ""): v for k, v in model_state_dict.items()
26
+ }
27
+
28
  model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
29
  torch.save(model_state_dict, model_save_path)
30
+
31
  # Save the model configuration
32
  if isinstance(model, nn.DataParallel):
33
+ model.module.config.to_json_file(
34
+ os.path.join(model_save_directory, "config.json")
35
+ )
36
  else:
37
  model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
38
+
39
  print(f"Model and configuration saved to {model_save_directory}")
40
 
41
+
42
  def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
43
  task_metrics = {}
44
  for task_name in task_true_labels.keys():
45
  true_labels = task_true_labels[task_name]
46
  pred_labels = task_pred_labels[task_name]
47
+ f1 = f1_score(true_labels, pred_labels, average="macro")
48
  accuracy = accuracy_score(true_labels, pred_labels)
49
+ task_metrics[task_name] = {"f1": f1, "accuracy": accuracy}
50
  return task_metrics
51
 
52
+
53
  def calculate_combined_f1(combined_labels, combined_preds):
54
  # Initialize the LabelEncoder
55
  le = LabelEncoder()
 
68
  accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
69
 
70
  # Calculate F1 Macro score
71
+ f1 = f1_score(encoded_true_labels, encoded_pred_labels, average="macro")
72
 
73
  return f1, accuracy
74
 
75
+
76
  def save_model_without_heads(original_model_save_directory):
77
  # Create a new directory for the model without heads
78
  new_model_save_directory = original_model_save_directory + "_No_Heads"
 
80
  os.makedirs(new_model_save_directory)
81
 
82
  # Load the model state dictionary
83
+ model_state_dict = torch.load(
84
+ os.path.join(original_model_save_directory, "pytorch_model.bin")
85
+ )
86
 
87
  # Initialize a new BERT model without the classification heads
88
+ config = BertConfig.from_pretrained(
89
+ os.path.join(original_model_save_directory, "config.json")
90
+ )
91
  model_without_heads = BertModel(config)
92
+
93
  # Filter the state dict to exclude classification heads
94
+ model_without_heads_state_dict = {
95
+ k: v
96
+ for k, v in model_state_dict.items()
97
+ if not k.startswith("classification_heads")
98
+ }
99
+
100
  # Load the filtered state dict into the model
101
  model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
102
+
103
  # Save the model without heads
104
  model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
105
  torch.save(model_without_heads.state_dict(), model_save_path)
106
+
107
  # Copy the configuration file
108
+ shutil.copy(
109
+ os.path.join(original_model_save_directory, "config.json"),
110
+ new_model_save_directory,
111
+ )
112
+
113
  print(f"Model without classification heads saved to {new_model_save_directory}")
114
 
115
 
 
123
  """
124
  if pretrained_path:
125
  config = AutoConfig.from_pretrained(pretrained_path)
126
+ total_layers = config.num_hidden_layers
127
+ return {"min": 0, "max": total_layers - 1}
128
  else:
129
+ return {"min": 0, "max": 0}
geneformer/mtl_classifier.py CHANGED
@@ -28,9 +28,8 @@ Geneformer multi-task cell classifier.
28
 
29
  import logging
30
  import os
31
- from .mtl import train_utils
32
- from .mtl import utils
33
- from .mtl import eval_utils
34
 
35
  logger = logging.getLogger(__name__)
36
 
@@ -90,9 +89,8 @@ class MTLClassifier:
90
  wandb_project=None,
91
  gradient_clipping=False,
92
  max_grad_norm=None,
93
- seed=42 # Default seed value
94
  ):
95
-
96
  """
97
  Initialize Geneformer multi-task classifier.
98
  **Parameters:**
@@ -165,11 +163,11 @@ class MTLClassifier:
165
  self.batch_size = batch_size
166
  self.n_trials = n_trials
167
  self.study_name = study_name
168
-
169
  if max_layers_to_freeze is None:
170
  # Dynamically determine the range of layers to freeze
171
  layer_freeze_range = utils.get_layer_freeze_range(pretrained_path)
172
- self.max_layers_to_freeze = {"min": 1, "max": layer_freeze_range['max']}
173
  else:
174
  self.max_layers_to_freeze = max_layers_to_freeze
175
 
@@ -178,48 +176,37 @@ class MTLClassifier:
178
  self.use_data_parallel = use_data_parallel
179
  self.use_attention_pooling = use_attention_pooling
180
  self.use_task_weights = use_task_weights
181
- self.hyperparameters = hyperparameters if hyperparameters is not None else {
182
- "learning_rate": {
183
- "type": "float",
184
- "low": 1e-5,
185
- "high": 1e-3,
186
- "log": True
187
- },
188
- "warmup_ratio": {
189
- "type": "float",
190
- "low": 0.005,
191
- "high": 0.01
192
- },
193
- "weight_decay": {
194
- "type": "float",
195
- "low": 0.01,
196
- "high": 0.1
197
- },
198
- "dropout_rate": {
199
- "type": "float",
200
- "low": 0.0,
201
- "high": 0.7
202
- },
203
- "lr_scheduler_type": {
204
- "type": "categorical",
205
- "choices": ["cosine"]
206
- },
207
- "task_weights": {
208
- "type": "float",
209
- "low": 0.1,
210
- "high": 2.0
211
  }
212
- }
213
- self.manual_hyperparameters = manual_hyperparameters if manual_hyperparameters is not None else {
214
- "learning_rate": 0.001,
215
- "warmup_ratio": 0.01,
216
- "weight_decay": 0.1,
217
- "dropout_rate": 0.1,
218
- "lr_scheduler_type": "cosine",
219
- "use_attention_pooling": False,
220
- "task_weights": [1, 1],
221
- "max_layers_to_freeze": 2
222
- }
 
 
 
 
223
  self.use_manual_hyperparameters = use_manual_hyperparameters
224
  self.use_wandb = use_wandb
225
  self.wandb_project = wandb_project
@@ -236,13 +223,19 @@ class MTLClassifier:
236
 
237
  # set up output directories
238
  if self.results_dir is not None:
239
- self.trials_results_path = f"{self.results_dir}/results.txt".replace("//","/")
240
-
 
 
241
  for output_dir in [self.model_save_path, self.results_dir]:
242
  if not os.path.exists(output_dir):
243
  os.makedirs(output_dir)
244
 
245
- self.config = {key: value for key, value in self.__dict__.items() if key in self.valid_option_dict}
 
 
 
 
246
 
247
  def validate_options(self):
248
  # confirm arguments are within valid options and compatible with each other
@@ -264,19 +257,35 @@ class MTLClassifier:
264
  f"Invalid option for {attr_name}. "
265
  f"Valid options for {attr_name}: {valid_options}"
266
  )
267
- raise ValueError(f"Invalid option for {attr_name}. Valid options for {attr_name}: {valid_options}")
 
 
268
 
269
  def run_manual_tuning(self):
270
  """
271
  Manual hyperparameter tuning and multi-task fine-tuning of pretrained model.
272
  """
273
- required_variable_names = ["train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"]
274
- required_variables = [self.train_path, self.val_path, self.pretrained_path, self.model_save_path, self.results_dir]
 
 
 
 
 
 
 
 
 
 
 
 
275
  req_var_dict = dict(zip(required_variable_names, required_variables))
276
  self.validate_additional_options(req_var_dict)
277
 
278
  if not self.use_manual_hyperparameters:
279
- raise ValueError("Manual hyperparameters are not enabled. Set use_manual_hyperparameters to True.")
 
 
280
 
281
  # Ensure manual_hyperparameters are set in the config
282
  self.config["manual_hyperparameters"] = self.manual_hyperparameters
@@ -302,8 +311,20 @@ class MTLClassifier:
302
  Hyperparameter optimization and/or multi-task fine-tuning of pretrained model.
303
  """
304
 
305
- required_variable_names = ["train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"]
306
- required_variables = [self.train_path, self.val_path, self.pretrained_path, self.model_save_path, self.results_dir]
 
 
 
 
 
 
 
 
 
 
 
 
307
  req_var_dict = dict(zip(required_variable_names, required_variables))
308
  self.validate_additional_options(req_var_dict)
309
 
@@ -322,7 +343,7 @@ class MTLClassifier:
322
  self.validate_additional_options(req_var_dict)
323
 
324
  eval_utils.load_and_evaluate_test_model(self.config)
325
-
326
  def save_model_without_heads(
327
  self,
328
  ):
@@ -335,4 +356,6 @@ class MTLClassifier:
335
  req_var_dict = dict(zip(required_variable_names, required_variables))
336
  self.validate_additional_options(req_var_dict)
337
 
338
- utils.save_model_without_heads(os.path.join(self.model_save_path, "GeneformerMultiTask"))
 
 
 
28
 
29
  import logging
30
  import os
31
+
32
+ from .mtl import eval_utils, train_utils, utils
 
33
 
34
  logger = logging.getLogger(__name__)
35
 
 
89
  wandb_project=None,
90
  gradient_clipping=False,
91
  max_grad_norm=None,
92
+ seed=42, # Default seed value
93
  ):
 
94
  """
95
  Initialize Geneformer multi-task classifier.
96
  **Parameters:**
 
163
  self.batch_size = batch_size
164
  self.n_trials = n_trials
165
  self.study_name = study_name
166
+
167
  if max_layers_to_freeze is None:
168
  # Dynamically determine the range of layers to freeze
169
  layer_freeze_range = utils.get_layer_freeze_range(pretrained_path)
170
+ self.max_layers_to_freeze = {"min": 1, "max": layer_freeze_range["max"]}
171
  else:
172
  self.max_layers_to_freeze = max_layers_to_freeze
173
 
 
176
  self.use_data_parallel = use_data_parallel
177
  self.use_attention_pooling = use_attention_pooling
178
  self.use_task_weights = use_task_weights
179
+ self.hyperparameters = (
180
+ hyperparameters
181
+ if hyperparameters is not None
182
+ else {
183
+ "learning_rate": {
184
+ "type": "float",
185
+ "low": 1e-5,
186
+ "high": 1e-3,
187
+ "log": True,
188
+ },
189
+ "warmup_ratio": {"type": "float", "low": 0.005, "high": 0.01},
190
+ "weight_decay": {"type": "float", "low": 0.01, "high": 0.1},
191
+ "dropout_rate": {"type": "float", "low": 0.0, "high": 0.7},
192
+ "lr_scheduler_type": {"type": "categorical", "choices": ["cosine"]},
193
+ "task_weights": {"type": "float", "low": 0.1, "high": 2.0},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  }
195
+ )
196
+ self.manual_hyperparameters = (
197
+ manual_hyperparameters
198
+ if manual_hyperparameters is not None
199
+ else {
200
+ "learning_rate": 0.001,
201
+ "warmup_ratio": 0.01,
202
+ "weight_decay": 0.1,
203
+ "dropout_rate": 0.1,
204
+ "lr_scheduler_type": "cosine",
205
+ "use_attention_pooling": False,
206
+ "task_weights": [1, 1],
207
+ "max_layers_to_freeze": 2,
208
+ }
209
+ )
210
  self.use_manual_hyperparameters = use_manual_hyperparameters
211
  self.use_wandb = use_wandb
212
  self.wandb_project = wandb_project
 
223
 
224
  # set up output directories
225
  if self.results_dir is not None:
226
+ self.trials_results_path = f"{self.results_dir}/results.txt".replace(
227
+ "//", "/"
228
+ )
229
+
230
  for output_dir in [self.model_save_path, self.results_dir]:
231
  if not os.path.exists(output_dir):
232
  os.makedirs(output_dir)
233
 
234
+ self.config = {
235
+ key: value
236
+ for key, value in self.__dict__.items()
237
+ if key in self.valid_option_dict
238
+ }
239
 
240
  def validate_options(self):
241
  # confirm arguments are within valid options and compatible with each other
 
257
  f"Invalid option for {attr_name}. "
258
  f"Valid options for {attr_name}: {valid_options}"
259
  )
260
+ raise ValueError(
261
+ f"Invalid option for {attr_name}. Valid options for {attr_name}: {valid_options}"
262
+ )
263
 
264
  def run_manual_tuning(self):
265
  """
266
  Manual hyperparameter tuning and multi-task fine-tuning of pretrained model.
267
  """
268
+ required_variable_names = [
269
+ "train_path",
270
+ "val_path",
271
+ "pretrained_path",
272
+ "model_save_path",
273
+ "results_dir",
274
+ ]
275
+ required_variables = [
276
+ self.train_path,
277
+ self.val_path,
278
+ self.pretrained_path,
279
+ self.model_save_path,
280
+ self.results_dir,
281
+ ]
282
  req_var_dict = dict(zip(required_variable_names, required_variables))
283
  self.validate_additional_options(req_var_dict)
284
 
285
  if not self.use_manual_hyperparameters:
286
+ raise ValueError(
287
+ "Manual hyperparameters are not enabled. Set use_manual_hyperparameters to True."
288
+ )
289
 
290
  # Ensure manual_hyperparameters are set in the config
291
  self.config["manual_hyperparameters"] = self.manual_hyperparameters
 
311
  Hyperparameter optimization and/or multi-task fine-tuning of pretrained model.
312
  """
313
 
314
+ required_variable_names = [
315
+ "train_path",
316
+ "val_path",
317
+ "pretrained_path",
318
+ "model_save_path",
319
+ "results_dir",
320
+ ]
321
+ required_variables = [
322
+ self.train_path,
323
+ self.val_path,
324
+ self.pretrained_path,
325
+ self.model_save_path,
326
+ self.results_dir,
327
+ ]
328
  req_var_dict = dict(zip(required_variable_names, required_variables))
329
  self.validate_additional_options(req_var_dict)
330
 
 
343
  self.validate_additional_options(req_var_dict)
344
 
345
  eval_utils.load_and_evaluate_test_model(self.config)
346
+
347
  def save_model_without_heads(
348
  self,
349
  ):
 
356
  req_var_dict = dict(zip(required_variable_names, required_variables))
357
  self.validate_additional_options(req_var_dict)
358
 
359
+ utils.save_model_without_heads(
360
+ os.path.join(self.model_save_path, "GeneformerMultiTask")
361
+ )
geneformer/perturber_utils.py CHANGED
@@ -1,18 +1,15 @@
1
  import itertools as it
2
  import logging
3
  import pickle
4
- import re
5
  from collections import defaultdict
6
- from typing import List
7
  from pathlib import Path
8
-
9
 
10
  import numpy as np
11
  import pandas as pd
12
- import seaborn as sns
13
  import torch
14
  from datasets import Dataset, load_from_disk
15
- from peft import LoraConfig, get_peft_model
16
  from transformers import (
17
  BertForMaskedLM,
18
  BertForSequenceClassification,
@@ -119,7 +116,7 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
119
  if model_type == "MTLCellClassifier-Quantized":
120
  model_type = "MTLCellClassifier"
121
  quantize = True
122
-
123
  if mode == "eval":
124
  output_hidden_states = True
125
  elif mode == "train":
@@ -131,7 +128,7 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
131
  "peft_config": None,
132
  "bnb_config": BitsAndBytesConfig(
133
  load_in_8bit=True,
134
- )
135
  }
136
  else:
137
  quantize = {
@@ -141,13 +138,13 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
141
  r=64,
142
  bias="none",
143
  task_type="TokenClassification",
144
- ),
145
  "bnb_config": BitsAndBytesConfig(
146
  load_in_4bit=True,
147
  bnb_4bit_use_double_quant=True,
148
  bnb_4bit_quant_type="nf4",
149
- bnb_4bit_compute_dtype=torch.bfloat16
150
- )
151
  }
152
  elif quantize is False:
153
  quantize = {"bnb_config": None}
@@ -186,7 +183,11 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
186
  # if eval mode, put the model in eval mode for fwd pass
187
  if mode == "eval":
188
  model.eval()
189
- if (quantize is False) or (quantize == {'bnb_config': None}) or (model_type == "MTLCellClassifier"):
 
 
 
 
190
  model = model.to("cuda")
191
  else:
192
  model.enable_input_require_grads()
@@ -279,18 +280,20 @@ def overexpress_indices(example):
279
  example["length"] = len(example["input_ids"])
280
  return example
281
 
 
282
  # if CLS token present, move to 1st rather than 0th position
283
  def overexpress_indices_special(example):
284
  indices = example["perturb_index"]
285
  if any(isinstance(el, list) for el in indices):
286
  indices = flatten_list(indices)
287
- insert_pos = 1 # Insert starting after CLS token
288
  for index in sorted(indices, reverse=False):
289
  example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
290
  insert_pos += 1
291
  example["length"] = len(example["input_ids"])
292
  return example
293
 
 
294
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
295
  def overexpress_tokens(example, max_len, special_token):
296
  # -100 indicates tokens to overexpress are not present in rank value encoding
@@ -310,7 +313,9 @@ def overexpress_tokens(example, max_len, special_token):
310
  # truncate to max input size, must also truncate original emb to be comparable
311
  if len(example["input_ids"]) > max_len:
312
  if special_token:
313
- example["input_ids"] = example["input_ids"][0:max_len-1]+[example["input_ids"][-1]]
 
 
314
  else:
315
  example["input_ids"] = example["input_ids"][0:max_len]
316
  example["length"] = len(example["input_ids"])
@@ -329,10 +334,13 @@ def truncate_by_n_overflow(example):
329
  example["length"] = len(example["input_ids"])
330
  return example
331
 
 
332
  def truncate_by_n_overflow_special(example):
333
  if example["n_overflow"] > 0:
334
  new_max_len = example["length"] - example["n_overflow"]
335
- example["input_ids"] = example["input_ids"][0:new_max_len-1]+[example["input_ids"][-1]]
 
 
336
  example["length"] = len(example["input_ids"])
337
  return example
338
 
@@ -477,19 +485,24 @@ def make_perturbation_batch_special(
477
  range_start = 1
478
  elif perturb_type in ["delete", "inhibit"]:
479
  range_start = 0
480
- range_start += 1 # Starting after the CLS token
481
  indices_to_perturb = [
482
- [i] for i in range(range_start, example_cell["length"][0]-1) # And excluding the EOS token
 
 
 
483
  ]
484
 
485
  # elif combo_lvl > 0 and anchor_token is None:
486
  ## to implement
487
- elif combo_lvl > 0 and (anchor_token is not None):
488
  example_input_ids = example_cell["input_ids"][0]
489
  anchor_index = example_input_ids.index(anchor_token[0])
490
  indices_to_perturb = [
491
  sorted([anchor_index, i]) if i != anchor_index else None
492
- for i in range(1, example_cell["length"][0]-1) # Exclude CLS and EOS tokens
 
 
493
  ]
494
  indices_to_perturb = [item for item in indices_to_perturb if item is not None]
495
  else:
@@ -508,7 +521,9 @@ def make_perturbation_batch_special(
508
  list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
509
  ]
510
  else:
511
- all_indices = [[i] for i in range(1, example_cell["length"][0]-1)] # Exclude CLS and EOS tokens
 
 
512
  all_indices = [
513
  index for index in all_indices if index not in indices_to_perturb
514
  ]
@@ -535,7 +550,7 @@ def make_perturbation_batch_special(
535
  )
536
  elif perturb_type == "overexpress":
537
  perturbation_dataset = perturbation_dataset.map(
538
- overexpress_indices_special, num_proc=num_proc_i
539
  )
540
 
541
  perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
@@ -743,7 +758,7 @@ def quant_cos_sims(
743
  # against original cell
744
  if cell_states_to_model is None or emb_mode == "gene":
745
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
746
-
747
  elif cell_states_to_model is not None and emb_mode == "cell":
748
  possible_states = get_possible_states(cell_states_to_model)
749
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
@@ -867,27 +882,28 @@ def validate_cell_states_to_model(cell_states_to_model):
867
  )
868
  raise
869
 
 
870
  class GeneIdHandler:
871
  def __init__(self, raise_errors=False):
872
  def invert_dict(dict_obj):
873
- return {v:k for k,v in dict_obj.items()}
874
-
875
  self.raise_errors = raise_errors
876
-
877
- with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
878
  self.gene_token_dict = pickle.load(f)
879
  self.token_gene_dict = invert_dict(self.gene_token_dict)
880
 
881
- with open(ENSEMBL_DICTIONARY_FILE, 'rb') as f:
882
  self.id_gene_dict = pickle.load(f)
883
  self.gene_id_dict = invert_dict(self.id_gene_dict)
884
-
885
  def ens_to_token(self, ens_id):
886
  if not self.raise_errors:
887
  return self.gene_token_dict.get(ens_id, ens_id)
888
  else:
889
  return self.gene_token_dict[ens_id]
890
-
891
  def token_to_ens(self, token):
892
  if not self.raise_errors:
893
  return self.token_gene_dict.get(token, token)
@@ -899,15 +915,15 @@ class GeneIdHandler:
899
  return self.gene_id_dict.get(ens_id, ens_id)
900
  else:
901
  return self.gene_id_dict[ens_id]
902
-
903
  def symbol_to_ens(self, symbol):
904
  if not self.raise_errors:
905
  return self.id_gene_dict.get(symbol, symbol)
906
  else:
907
  return self.id_gene_dict[symbol]
908
-
909
  def token_to_symbol(self, token):
910
  return self.ens_to_symbol(self.token_to_ens(token))
911
-
912
  def symbol_to_token(self, symbol):
913
- return self.ens_to_token(self.symbol_to_ens(symbol))
 
1
  import itertools as it
2
  import logging
3
  import pickle
 
4
  from collections import defaultdict
 
5
  from pathlib import Path
6
+ from typing import List
7
 
8
  import numpy as np
9
  import pandas as pd
 
10
  import torch
11
  from datasets import Dataset, load_from_disk
12
+ from peft import LoraConfig, get_peft_model
13
  from transformers import (
14
  BertForMaskedLM,
15
  BertForSequenceClassification,
 
116
  if model_type == "MTLCellClassifier-Quantized":
117
  model_type = "MTLCellClassifier"
118
  quantize = True
119
+
120
  if mode == "eval":
121
  output_hidden_states = True
122
  elif mode == "train":
 
128
  "peft_config": None,
129
  "bnb_config": BitsAndBytesConfig(
130
  load_in_8bit=True,
131
+ ),
132
  }
133
  else:
134
  quantize = {
 
138
  r=64,
139
  bias="none",
140
  task_type="TokenClassification",
141
+ ),
142
  "bnb_config": BitsAndBytesConfig(
143
  load_in_4bit=True,
144
  bnb_4bit_use_double_quant=True,
145
  bnb_4bit_quant_type="nf4",
146
+ bnb_4bit_compute_dtype=torch.bfloat16,
147
+ ),
148
  }
149
  elif quantize is False:
150
  quantize = {"bnb_config": None}
 
183
  # if eval mode, put the model in eval mode for fwd pass
184
  if mode == "eval":
185
  model.eval()
186
+ if (
187
+ (quantize is False)
188
+ or (quantize == {"bnb_config": None})
189
+ or (model_type == "MTLCellClassifier")
190
+ ):
191
  model = model.to("cuda")
192
  else:
193
  model.enable_input_require_grads()
 
280
  example["length"] = len(example["input_ids"])
281
  return example
282
 
283
+
284
  # if CLS token present, move to 1st rather than 0th position
285
  def overexpress_indices_special(example):
286
  indices = example["perturb_index"]
287
  if any(isinstance(el, list) for el in indices):
288
  indices = flatten_list(indices)
289
+ insert_pos = 1 # Insert starting after CLS token
290
  for index in sorted(indices, reverse=False):
291
  example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
292
  insert_pos += 1
293
  example["length"] = len(example["input_ids"])
294
  return example
295
 
296
+
297
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
298
  def overexpress_tokens(example, max_len, special_token):
299
  # -100 indicates tokens to overexpress are not present in rank value encoding
 
313
  # truncate to max input size, must also truncate original emb to be comparable
314
  if len(example["input_ids"]) > max_len:
315
  if special_token:
316
+ example["input_ids"] = example["input_ids"][0 : max_len - 1] + [
317
+ example["input_ids"][-1]
318
+ ]
319
  else:
320
  example["input_ids"] = example["input_ids"][0:max_len]
321
  example["length"] = len(example["input_ids"])
 
334
  example["length"] = len(example["input_ids"])
335
  return example
336
 
337
+
338
  def truncate_by_n_overflow_special(example):
339
  if example["n_overflow"] > 0:
340
  new_max_len = example["length"] - example["n_overflow"]
341
+ example["input_ids"] = example["input_ids"][0 : new_max_len - 1] + [
342
+ example["input_ids"][-1]
343
+ ]
344
  example["length"] = len(example["input_ids"])
345
  return example
346
 
 
485
  range_start = 1
486
  elif perturb_type in ["delete", "inhibit"]:
487
  range_start = 0
488
+ range_start += 1 # Starting after the CLS token
489
  indices_to_perturb = [
490
+ [i]
491
+ for i in range(
492
+ range_start, example_cell["length"][0] - 1
493
+ ) # And excluding the EOS token
494
  ]
495
 
496
  # elif combo_lvl > 0 and anchor_token is None:
497
  ## to implement
498
+ elif combo_lvl > 0 and (anchor_token is not None):
499
  example_input_ids = example_cell["input_ids"][0]
500
  anchor_index = example_input_ids.index(anchor_token[0])
501
  indices_to_perturb = [
502
  sorted([anchor_index, i]) if i != anchor_index else None
503
+ for i in range(
504
+ 1, example_cell["length"][0] - 1
505
+ ) # Exclude CLS and EOS tokens
506
  ]
507
  indices_to_perturb = [item for item in indices_to_perturb if item is not None]
508
  else:
 
521
  list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
522
  ]
523
  else:
524
+ all_indices = [
525
+ [i] for i in range(1, example_cell["length"][0] - 1)
526
+ ] # Exclude CLS and EOS tokens
527
  all_indices = [
528
  index for index in all_indices if index not in indices_to_perturb
529
  ]
 
550
  )
551
  elif perturb_type == "overexpress":
552
  perturbation_dataset = perturbation_dataset.map(
553
+ overexpress_indices_special, num_proc=num_proc_i
554
  )
555
 
556
  perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
 
758
  # against original cell
759
  if cell_states_to_model is None or emb_mode == "gene":
760
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
761
+
762
  elif cell_states_to_model is not None and emb_mode == "cell":
763
  possible_states = get_possible_states(cell_states_to_model)
764
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
 
882
  )
883
  raise
884
 
885
+
886
  class GeneIdHandler:
887
  def __init__(self, raise_errors=False):
888
  def invert_dict(dict_obj):
889
+ return {v: k for k, v in dict_obj.items()}
890
+
891
  self.raise_errors = raise_errors
892
+
893
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
894
  self.gene_token_dict = pickle.load(f)
895
  self.token_gene_dict = invert_dict(self.gene_token_dict)
896
 
897
+ with open(ENSEMBL_DICTIONARY_FILE, "rb") as f:
898
  self.id_gene_dict = pickle.load(f)
899
  self.gene_id_dict = invert_dict(self.id_gene_dict)
900
+
901
  def ens_to_token(self, ens_id):
902
  if not self.raise_errors:
903
  return self.gene_token_dict.get(ens_id, ens_id)
904
  else:
905
  return self.gene_token_dict[ens_id]
906
+
907
  def token_to_ens(self, token):
908
  if not self.raise_errors:
909
  return self.token_gene_dict.get(token, token)
 
915
  return self.gene_id_dict.get(ens_id, ens_id)
916
  else:
917
  return self.gene_id_dict[ens_id]
918
+
919
  def symbol_to_ens(self, symbol):
920
  if not self.raise_errors:
921
  return self.id_gene_dict.get(symbol, symbol)
922
  else:
923
  return self.id_gene_dict[symbol]
924
+
925
  def token_to_symbol(self, token):
926
  return self.ens_to_symbol(self.token_to_ens(token))
927
+
928
  def symbol_to_token(self, symbol):
929
+ return self.ens_to_token(self.symbol_to_ens(symbol))
geneformer/tokenizer.py CHANGED
@@ -22,28 +22,28 @@ Geneformer tokenizer.
22
 
23
  from __future__ import annotations
24
 
25
- import os
26
  import logging
 
27
  import pickle
28
  import warnings
 
29
  from pathlib import Path
30
  from typing import Literal
31
- from tqdm import tqdm
32
- from collections import Counter
33
 
34
- import numpy as np
35
- import scanpy as sc
36
  import loompy as lp
 
37
  import pandas as pd
 
38
  import scipy.sparse as sp
39
  from datasets import Dataset
 
40
 
41
  warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa
42
  import loompy as lp # noqa
43
 
44
  logger = logging.getLogger(__name__)
45
 
46
- from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE, ENSEMBL_MAPPING_FILE
47
 
48
 
49
  def rank_genes(gene_vector, gene_tokens):
@@ -65,20 +65,26 @@ def tokenize_cell(gene_vector, gene_tokens):
65
  # rank by median-scaled gene values
66
  return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
67
 
68
- def sum_ensembl_ids(data_directory,
69
- collapse_gene_ids,
70
- gene_mapping_dict,
71
- gene_token_dict,
72
- file_format = "loom",
73
- chunk_size = 512):
74
 
 
 
 
 
 
 
 
 
75
  if file_format == "loom":
76
- """
77
  Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
78
  """
79
  with lp.connect(data_directory) as data:
80
- assert "ensembl_id" in data.ra.keys(), "'ensembl_id' column missing from data.ra.keys()"
81
- gene_ids_in_dict = [gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()]
 
 
 
 
82
  if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
83
  token_genes_unique = True
84
  else:
@@ -89,45 +95,80 @@ def sum_ensembl_ids(data_directory,
89
  else:
90
  raise ValueError("Error: data Ensembl IDs non-unique.")
91
 
92
- gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id]
93
- gene_ids_collapsed_in_dict = [gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()]
 
 
 
 
94
 
95
- if (len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))) and token_genes_unique:
 
 
96
  return data_directory
97
  else:
98
- dedup_filename = data_directory.with_name(data_directory.stem + "__dedup.loom")
 
 
99
  data.ra["gene_ids_collapsed"] = gene_ids_collapsed
100
- dup_genes = [idx for idx, count in Counter(data.ra["gene_ids_collapsed"]).items() if count > 1]
 
 
 
 
101
  num_chunks = int(np.ceil(data.shape[1] / chunk_size))
102
  first_chunk = True
103
- for _, _, view in tqdm(data.scan(axis = 1, batch_size = chunk_size), total = num_chunks):
 
 
 
104
  def process_chunk(view, duplic_genes):
105
- data_count_view = pd.DataFrame(view, index=data.ra["gene_ids_collapsed"])
106
- unique_data_df = data_count_view.loc[~data_count_view.index.isin(duplic_genes)]
107
- dup_data_df = data_count_view.loc[data_count_view.index.isin([i for i in duplic_genes if "None" not in i])]
 
 
 
 
 
 
 
 
108
  summed_data = dup_data_df.groupby(dup_data_df.index).sum()
109
  if not summed_data.index.is_unique:
110
- raise ValueError("Error: Ensembl IDs in summed data frame non-unique.")
111
- data_count_view = pd.concat([unique_data_df, summed_data], axis=0)
 
 
 
 
112
  if not data_count_view.index.is_unique:
113
- raise ValueError("Error: Ensembl IDs in final data frame non-unique.")
 
 
114
  return data_count_view
 
115
  processed_chunk = process_chunk(view[:, :], dup_genes)
116
  processed_array = processed_chunk.to_numpy()
117
  new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
118
 
119
  if "n_counts" not in view.ca.keys():
120
- total_count_view = np.sum(view[:,:], axis=0).astype(int)
121
  view.ca["n_counts"] = total_count_view
122
 
123
- if first_chunk: # Create the Loom file with the first chunk
124
- lp.create(f"{dedup_filename}", processed_array, row_attrs=new_row_attrs, col_attrs=view.ca)
 
 
 
 
 
125
  first_chunk = False
126
- else: # Append subsequent chunks
127
- with lp.connect(dedup_filename, mode='r+') as dsout:
128
  dsout.add_columns(processed_array, col_attrs=view.ca)
129
  return dedup_filename
130
-
131
  elif file_format == "h5ad":
132
  """
133
  Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
@@ -136,8 +177,12 @@ def sum_ensembl_ids(data_directory,
136
 
137
  data = sc.read_h5ad(str(data_directory))
138
 
139
- assert "ensembl_id" in data.var.columns, "'ensembl_id' column missing from data.var"
140
- gene_ids_in_dict = [gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()]
 
 
 
 
141
  if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
142
  token_genes_unique = True
143
  else:
@@ -148,22 +193,29 @@ def sum_ensembl_ids(data_directory,
148
  else:
149
  raise ValueError("Error: data Ensembl IDs non-unique.")
150
 
151
- gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id]
152
- gene_ids_collapsed_in_dict = [gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()]
153
- if (len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))) and token_genes_unique:
 
 
 
 
 
 
154
  return data
155
-
156
  else:
157
  data.var["gene_ids_collapsed"] = gene_ids_collapsed
158
  data.var_names = gene_ids_collapsed
159
  data = data[:, ~data.var.index.isna()]
160
- dup_genes = [idx for idx, count in Counter(data.var_names).items() if count > 1]
 
 
161
 
162
  num_chunks = int(np.ceil(data.shape[0] / chunk_size))
163
 
164
  processed_genes = []
165
  for i in tqdm(range(num_chunks)):
166
-
167
  start_idx = i * chunk_size
168
  end_idx = min((i + 1) * chunk_size, data.shape[0])
169
  data_chunk = data[start_idx:end_idx, :]
@@ -171,29 +223,32 @@ def sum_ensembl_ids(data_directory,
171
  processed_chunks = []
172
  for dup_gene in dup_genes:
173
  data_dup_gene = data_chunk[:, data_chunk.var_names == dup_gene]
174
- df = pd.DataFrame.sparse.from_spmatrix(data_dup_gene.X,
175
- index=data_dup_gene.obs_names,
176
- columns=data_dup_gene.var_names)
 
 
177
  df_sum = pd.DataFrame(df.sum(axis=1))
178
- df_sum.columns = [dup_gene]
179
  df_sum.index = data_dup_gene.obs.index
180
  processed_chunks.append(df_sum)
181
 
182
- processed_chunks = pd.concat(processed_chunks, axis=1)
183
  processed_genes.append(processed_chunks)
184
- processed_genes = pd.concat(processed_genes, axis = 0)
185
- var_df = pd.DataFrame({"gene_ids_collapsed" : processed_genes.columns})
186
  var_df.index = processed_genes.columns
187
- processed_genes = sc.AnnData(X = processed_genes,
188
- obs = data.obs,
189
- var = var_df)
190
 
191
- data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
192
- data_dedup = sc.concat([data_dedup, processed_genes], axis = 1)
193
  data_dedup.obs = data.obs
194
- data_dedup.var = data_dedup.var.rename(columns = {"gene_ids_collapsed" : "ensembl_id"})
 
 
195
  return data_dedup
196
 
 
197
  class TranscriptomeTokenizer:
198
  def __init__(
199
  self,
@@ -258,10 +313,12 @@ class TranscriptomeTokenizer:
258
 
259
  # check for special token in gene_token_dict
260
  if self.special_token:
261
- if ("<cls>" not in self.gene_token_dict.keys()) and ("<eos>" not in self.gene_token_dict.keys()):
 
 
262
  logger.error(
263
- "<cls> and <eos> required in gene_token_dict when special_token = True."
264
- )
265
  raise
266
 
267
  # if collapsing duplicate gene IDs
@@ -272,14 +329,16 @@ class TranscriptomeTokenizer:
272
  with open(gene_mapping_file, "rb") as f:
273
  self.gene_mapping_dict = pickle.load(f)
274
  else:
275
- self.gene_mapping_dict = {k:k for k,_ in self.gene_token_dict.items()}
276
 
277
  # gene keys for full vocabulary
278
  self.gene_keys = list(self.gene_token_dict.keys())
279
 
280
  # Filter gene mapping dict for items that exist in gene_token_dict
281
  gene_keys_set = set(self.gene_token_dict.keys())
282
- self.gene_mapping_dict = {k: v for k, v in self.gene_mapping_dict.items() if v in gene_keys_set}
 
 
283
 
284
  # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
285
  self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
@@ -355,7 +414,14 @@ class TranscriptomeTokenizer:
355
  return tokenized_cells, cell_metadata
356
 
357
  def tokenize_anndata(self, adata_file_path, target_sum=10_000):
358
- adata = sum_ensembl_ids(adata_file_path, self.collapse_gene_ids, self.gene_mapping_dict, self.gene_token_dict, file_format = "h5ad", chunk_size = self.chunk_size)
 
 
 
 
 
 
 
359
 
360
  if self.custom_attr_name_dict is not None:
361
  file_cell_metadata = {
@@ -397,7 +463,7 @@ class TranscriptomeTokenizer:
397
  idx = filter_pass_loc[i : i + self.chunk_size]
398
 
399
  n_counts = adata[idx].obs["n_counts"].values[:, None]
400
- X_view0 = adata[idx,:].X
401
  X_view = X_view0[:, coding_miRNA_loc]
402
  X_norm = X_view / n_counts * target_sum / norm_factor_vector
403
  X_norm = sp.csr_matrix(X_norm)
@@ -423,7 +489,14 @@ class TranscriptomeTokenizer:
423
  }
424
 
425
  dedup_filename = loom_file_path.with_name(loom_file_path.stem + "__dedup.loom")
426
- loom_file_path = sum_ensembl_ids(loom_file_path, self.collapse_gene_ids, self.gene_mapping_dict, self.gene_token_dict, file_format = "loom", chunk_size = self.chunk_size)
 
 
 
 
 
 
 
427
 
428
  with lp.connect(str(loom_file_path)) as data:
429
  # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
@@ -544,4 +617,4 @@ class TranscriptomeTokenizer:
544
  output_dataset_truncated = output_dataset.map(
545
  format_cell_features, num_proc=self.nproc
546
  )
547
- return output_dataset_truncated
 
22
 
23
  from __future__ import annotations
24
 
 
25
  import logging
26
+ import os
27
  import pickle
28
  import warnings
29
+ from collections import Counter
30
  from pathlib import Path
31
  from typing import Literal
 
 
32
 
 
 
33
  import loompy as lp
34
+ import numpy as np
35
  import pandas as pd
36
+ import scanpy as sc
37
  import scipy.sparse as sp
38
  from datasets import Dataset
39
+ from tqdm import tqdm
40
 
41
  warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa
42
  import loompy as lp # noqa
43
 
44
  logger = logging.getLogger(__name__)
45
 
46
+ from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
47
 
48
 
49
  def rank_genes(gene_vector, gene_tokens):
 
65
  # rank by median-scaled gene values
66
  return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
67
 
 
 
 
 
 
 
68
 
69
+ def sum_ensembl_ids(
70
+ data_directory,
71
+ collapse_gene_ids,
72
+ gene_mapping_dict,
73
+ gene_token_dict,
74
+ file_format="loom",
75
+ chunk_size=512,
76
+ ):
77
  if file_format == "loom":
78
+ """
79
  Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
80
  """
81
  with lp.connect(data_directory) as data:
82
+ assert (
83
+ "ensembl_id" in data.ra.keys()
84
+ ), "'ensembl_id' column missing from data.ra.keys()"
85
+ gene_ids_in_dict = [
86
+ gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
87
+ ]
88
  if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
89
  token_genes_unique = True
90
  else:
 
95
  else:
96
  raise ValueError("Error: data Ensembl IDs non-unique.")
97
 
98
+ gene_ids_collapsed = [
99
+ gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id
100
+ ]
101
+ gene_ids_collapsed_in_dict = [
102
+ gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
103
+ ]
104
 
105
+ if (
106
+ len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))
107
+ ) and token_genes_unique:
108
  return data_directory
109
  else:
110
+ dedup_filename = data_directory.with_name(
111
+ data_directory.stem + "__dedup.loom"
112
+ )
113
  data.ra["gene_ids_collapsed"] = gene_ids_collapsed
114
+ dup_genes = [
115
+ idx
116
+ for idx, count in Counter(data.ra["gene_ids_collapsed"]).items()
117
+ if count > 1
118
+ ]
119
  num_chunks = int(np.ceil(data.shape[1] / chunk_size))
120
  first_chunk = True
121
+ for _, _, view in tqdm(
122
+ data.scan(axis=1, batch_size=chunk_size), total=num_chunks
123
+ ):
124
+
125
  def process_chunk(view, duplic_genes):
126
+ data_count_view = pd.DataFrame(
127
+ view, index=data.ra["gene_ids_collapsed"]
128
+ )
129
+ unique_data_df = data_count_view.loc[
130
+ ~data_count_view.index.isin(duplic_genes)
131
+ ]
132
+ dup_data_df = data_count_view.loc[
133
+ data_count_view.index.isin(
134
+ [i for i in duplic_genes if "None" not in i]
135
+ )
136
+ ]
137
  summed_data = dup_data_df.groupby(dup_data_df.index).sum()
138
  if not summed_data.index.is_unique:
139
+ raise ValueError(
140
+ "Error: Ensembl IDs in summed data frame non-unique."
141
+ )
142
+ data_count_view = pd.concat(
143
+ [unique_data_df, summed_data], axis=0
144
+ )
145
  if not data_count_view.index.is_unique:
146
+ raise ValueError(
147
+ "Error: Ensembl IDs in final data frame non-unique."
148
+ )
149
  return data_count_view
150
+
151
  processed_chunk = process_chunk(view[:, :], dup_genes)
152
  processed_array = processed_chunk.to_numpy()
153
  new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
154
 
155
  if "n_counts" not in view.ca.keys():
156
+ total_count_view = np.sum(view[:, :], axis=0).astype(int)
157
  view.ca["n_counts"] = total_count_view
158
 
159
+ if first_chunk: # Create the Loom file with the first chunk
160
+ lp.create(
161
+ f"{dedup_filename}",
162
+ processed_array,
163
+ row_attrs=new_row_attrs,
164
+ col_attrs=view.ca,
165
+ )
166
  first_chunk = False
167
+ else: # Append subsequent chunks
168
+ with lp.connect(dedup_filename, mode="r+") as dsout:
169
  dsout.add_columns(processed_array, col_attrs=view.ca)
170
  return dedup_filename
171
+
172
  elif file_format == "h5ad":
173
  """
174
  Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
 
177
 
178
  data = sc.read_h5ad(str(data_directory))
179
 
180
+ assert (
181
+ "ensembl_id" in data.var.columns
182
+ ), "'ensembl_id' column missing from data.var"
183
+ gene_ids_in_dict = [
184
+ gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
185
+ ]
186
  if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
187
  token_genes_unique = True
188
  else:
 
193
  else:
194
  raise ValueError("Error: data Ensembl IDs non-unique.")
195
 
196
+ gene_ids_collapsed = [
197
+ gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
198
+ ]
199
+ gene_ids_collapsed_in_dict = [
200
+ gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
201
+ ]
202
+ if (
203
+ len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))
204
+ ) and token_genes_unique:
205
  return data
206
+
207
  else:
208
  data.var["gene_ids_collapsed"] = gene_ids_collapsed
209
  data.var_names = gene_ids_collapsed
210
  data = data[:, ~data.var.index.isna()]
211
+ dup_genes = [
212
+ idx for idx, count in Counter(data.var_names).items() if count > 1
213
+ ]
214
 
215
  num_chunks = int(np.ceil(data.shape[0] / chunk_size))
216
 
217
  processed_genes = []
218
  for i in tqdm(range(num_chunks)):
 
219
  start_idx = i * chunk_size
220
  end_idx = min((i + 1) * chunk_size, data.shape[0])
221
  data_chunk = data[start_idx:end_idx, :]
 
223
  processed_chunks = []
224
  for dup_gene in dup_genes:
225
  data_dup_gene = data_chunk[:, data_chunk.var_names == dup_gene]
226
+ df = pd.DataFrame.sparse.from_spmatrix(
227
+ data_dup_gene.X,
228
+ index=data_dup_gene.obs_names,
229
+ columns=data_dup_gene.var_names,
230
+ )
231
  df_sum = pd.DataFrame(df.sum(axis=1))
232
+ df_sum.columns = [dup_gene]
233
  df_sum.index = data_dup_gene.obs.index
234
  processed_chunks.append(df_sum)
235
 
236
+ processed_chunks = pd.concat(processed_chunks, axis=1)
237
  processed_genes.append(processed_chunks)
238
+ processed_genes = pd.concat(processed_genes, axis=0)
239
+ var_df = pd.DataFrame({"gene_ids_collapsed": processed_genes.columns})
240
  var_df.index = processed_genes.columns
241
+ processed_genes = sc.AnnData(X=processed_genes, obs=data.obs, var=var_df)
 
 
242
 
243
+ data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
244
+ data_dedup = sc.concat([data_dedup, processed_genes], axis=1)
245
  data_dedup.obs = data.obs
246
+ data_dedup.var = data_dedup.var.rename(
247
+ columns={"gene_ids_collapsed": "ensembl_id"}
248
+ )
249
  return data_dedup
250
 
251
+
252
  class TranscriptomeTokenizer:
253
  def __init__(
254
  self,
 
313
 
314
  # check for special token in gene_token_dict
315
  if self.special_token:
316
+ if ("<cls>" not in self.gene_token_dict.keys()) and (
317
+ "<eos>" not in self.gene_token_dict.keys()
318
+ ):
319
  logger.error(
320
+ "<cls> and <eos> required in gene_token_dict when special_token = True."
321
+ )
322
  raise
323
 
324
  # if collapsing duplicate gene IDs
 
329
  with open(gene_mapping_file, "rb") as f:
330
  self.gene_mapping_dict = pickle.load(f)
331
  else:
332
+ self.gene_mapping_dict = {k: k for k, _ in self.gene_token_dict.items()}
333
 
334
  # gene keys for full vocabulary
335
  self.gene_keys = list(self.gene_token_dict.keys())
336
 
337
  # Filter gene mapping dict for items that exist in gene_token_dict
338
  gene_keys_set = set(self.gene_token_dict.keys())
339
+ self.gene_mapping_dict = {
340
+ k: v for k, v in self.gene_mapping_dict.items() if v in gene_keys_set
341
+ }
342
 
343
  # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
344
  self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
 
414
  return tokenized_cells, cell_metadata
415
 
416
  def tokenize_anndata(self, adata_file_path, target_sum=10_000):
417
+ adata = sum_ensembl_ids(
418
+ adata_file_path,
419
+ self.collapse_gene_ids,
420
+ self.gene_mapping_dict,
421
+ self.gene_token_dict,
422
+ file_format="h5ad",
423
+ chunk_size=self.chunk_size,
424
+ )
425
 
426
  if self.custom_attr_name_dict is not None:
427
  file_cell_metadata = {
 
463
  idx = filter_pass_loc[i : i + self.chunk_size]
464
 
465
  n_counts = adata[idx].obs["n_counts"].values[:, None]
466
+ X_view0 = adata[idx, :].X
467
  X_view = X_view0[:, coding_miRNA_loc]
468
  X_norm = X_view / n_counts * target_sum / norm_factor_vector
469
  X_norm = sp.csr_matrix(X_norm)
 
489
  }
490
 
491
  dedup_filename = loom_file_path.with_name(loom_file_path.stem + "__dedup.loom")
492
+ loom_file_path = sum_ensembl_ids(
493
+ loom_file_path,
494
+ self.collapse_gene_ids,
495
+ self.gene_mapping_dict,
496
+ self.gene_token_dict,
497
+ file_format="loom",
498
+ chunk_size=self.chunk_size,
499
+ )
500
 
501
  with lp.connect(str(loom_file_path)) as data:
502
  # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
 
617
  output_dataset_truncated = output_dataset.map(
618
  format_cell_features, num_proc=self.nproc
619
  )
620
+ return output_dataset_truncated