ctheodoris commited on
Commit
933ca80
Β·
1 Parent(s): ec19834

update with 12L and 20L i4096 gc95M models, multitask and quantiz code

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -1
  2. MANIFEST.in +3 -3
  3. config.json +9 -8
  4. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json +24 -0
  5. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +3 -0
  6. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/config.json +0 -0
  7. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/optimizer.pt +0 -0
  8. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin +0 -0
  9. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/rng_state.pth +0 -0
  10. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/scheduler.pt +0 -0
  11. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/trainer_state.json +0 -0
  12. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/training_args.bin +0 -0
  13. geneformer/__init__.py +10 -5
  14. geneformer/classifier.py +74 -16
  15. geneformer/classifier_utils.py +117 -5
  16. geneformer/collator_for_classification.py +15 -19
  17. geneformer/emb_extractor.py +20 -13
  18. geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl +3 -0
  19. geneformer/{gene_name_id_dict.pkl β†’ gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl} +0 -0
  20. geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl +3 -0
  21. geneformer/gene_median_dictionary.pkl +0 -0
  22. geneformer/in_silico_perturber.py +733 -143
  23. geneformer/in_silico_perturber_stats.py +22 -6
  24. geneformer/mtl/__init__.py +0 -0
  25. geneformer/mtl/collators.py +66 -0
  26. geneformer/mtl/data.py +116 -0
  27. geneformer/mtl/eval_utils.py +81 -0
  28. geneformer/mtl/imports.py +46 -0
  29. geneformer/mtl/model.py +84 -0
  30. geneformer/mtl/optuna_utils.py +21 -0
  31. geneformer/mtl/train.py +242 -0
  32. geneformer/mtl/train_utils.py +126 -0
  33. geneformer/mtl/utils.py +106 -0
  34. geneformer/mtl_classifier.py +338 -0
  35. geneformer/perturber_utils.py +168 -16
  36. geneformer/pretrainer.py +0 -13
  37. geneformer/token_dictionary.pkl +0 -0
  38. geneformer/token_dictionary_gc95M.pkl +0 -0
  39. generation_config.json +5 -0
  40. {geneformer-12L-30M β†’ gf-12L-30M-i2048}/config.json +0 -0
  41. {geneformer-12L-30M β†’ gf-12L-30M-i2048}/pytorch_model.bin +0 -0
  42. {geneformer-12L-30M β†’ gf-12L-30M-i2048}/training_args.bin +0 -0
  43. gf-12L-95M-i4096/config.json +24 -0
  44. gf-12L-95M-i4096/generation_config.json +5 -0
  45. gf-12L-95M-i4096/model.safetensors +3 -0
  46. gf-12L-95M-i4096/training_args.bin +3 -0
  47. gf-12L-95M-i4096_CLcancer/config.json +25 -0
  48. gf-12L-95M-i4096_CLcancer/generation_config.json +5 -0
  49. gf-12L-95M-i4096_CLcancer/model.safetensors +3 -0
  50. gf-12L-95M-i4096_CLcancer/training_args.bin +3 -0
.gitattributes CHANGED
@@ -26,4 +26,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
29
- model.safetensors filter=lfs diff=lfs merge=lfs -text
 
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
29
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
MANIFEST.in CHANGED
@@ -1,3 +1,3 @@
1
- include geneformer/gene_median_dictionary.pkl
2
- include geneformer/token_dictionary.pkl
3
- include geneformer/gene_name_id_dict.pkl
 
1
+ include geneformer/gene_median_dictionary_95m.pkl
2
+ include geneformer/token_dictionary_95m.pkl
3
+ include geneformer/gene_name_id_dict_95m.pkl
config.json CHANGED
@@ -3,21 +3,22 @@
3
  "BertForMaskedLM"
4
  ],
5
  "attention_probs_dropout_prob": 0.02,
6
- "gradient_checkpointing": false,
7
  "hidden_act": "relu",
8
  "hidden_dropout_prob": 0.02,
9
- "hidden_size": 256,
10
  "initializer_range": 0.02,
11
- "intermediate_size": 512,
12
  "layer_norm_eps": 1e-12,
13
- "max_position_embeddings": 2048,
14
  "model_type": "bert",
15
- "num_attention_heads": 4,
16
- "num_hidden_layers": 6,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
- "transformers_version": "4.6.0",
 
20
  "type_vocab_size": 2,
21
  "use_cache": true,
22
- "vocab_size": 25426
23
  }
 
3
  "BertForMaskedLM"
4
  ],
5
  "attention_probs_dropout_prob": 0.02,
6
+ "classifier_dropout": null,
7
  "hidden_act": "relu",
8
  "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
  "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
  "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 4096,
14
  "model_type": "bert",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 12,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.37.1",
21
  "type_vocab_size": 2,
22
  "use_cache": true,
23
+ "vocab_size": 20275
24
  }
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.02,
6
+ "classifier_dropout": null,
7
+ "hidden_act": "relu",
8
+ "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 4096,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.37.2",
21
+ "type_vocab_size": 2,
22
+ "use_cache": true,
23
+ "vocab_size": 20275
24
+ }
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4
3
+ size 152363342
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/config.json RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/optimizer.pt RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/rng_state.pth RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/scheduler.pt RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/trainer_state.json RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 β†’ gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/training_args.bin RENAMED
File without changes
geneformer/__init__.py CHANGED
@@ -1,10 +1,12 @@
1
  # ruff: noqa: F401
2
  from pathlib import Path
 
 
3
 
4
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
5
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
6
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
7
- ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict.pkl"
8
 
9
  from . import (
10
  collator_for_classification,
@@ -25,4 +27,7 @@ from .pretrainer import GeneformerPretrainer
25
  from .tokenizer import TranscriptomeTokenizer
26
 
27
  from . import classifier # noqa # isort:skip
28
- from .classifier import Classifier # noqa # isort:skip
 
 
 
 
1
  # ruff: noqa: F401
2
  from pathlib import Path
3
+ import warnings
4
+ warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
5
 
6
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
7
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
8
+ ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
9
+ ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
10
 
11
  from . import (
12
  collator_for_classification,
 
27
  from .tokenizer import TranscriptomeTokenizer
28
 
29
  from . import classifier # noqa # isort:skip
30
+ from .classifier import Classifier # noqa # isort:skip
31
+
32
+ from . import mtl_classifier # noqa # isort:skip
33
+ from .mtl_classifier import MTLClassifier # noqa # isort:skip
geneformer/classifier.py CHANGED
@@ -72,6 +72,7 @@ logger = logging.getLogger(__name__)
72
  class Classifier:
73
  valid_option_dict = {
74
  "classifier": {"cell", "gene"},
 
75
  "cell_state_dict": {None, dict},
76
  "gene_class_dict": {None, dict},
77
  "filter_data": {None, dict},
@@ -93,6 +94,7 @@ class Classifier:
93
  def __init__(
94
  self,
95
  classifier=None,
 
96
  cell_state_dict=None,
97
  gene_class_dict=None,
98
  filter_data=None,
@@ -118,6 +120,13 @@ class Classifier:
118
 
119
  classifier : {"cell", "gene"}
120
  | Whether to fine-tune a cell state or gene classifier.
 
 
 
 
 
 
 
121
  cell_state_dict : None, dict
122
  | Cell states to fine-tune model to distinguish.
123
  | Two-item dictionary with keys: state_key and states
@@ -191,6 +200,7 @@ class Classifier:
191
  self.model_type = "CellClassifier"
192
  elif self.classifier == "gene":
193
  self.model_type = "GeneClassifier"
 
194
  self.cell_state_dict = cell_state_dict
195
  self.gene_class_dict = gene_class_dict
196
  self.filter_data = filter_data
@@ -256,7 +266,7 @@ class Classifier:
256
  f"Genes to classify {missing_genes} are not in token dictionary."
257
  )
258
  self.gene_class_dict = {
259
- k: set([self.gene_token_dict.get(gene) for gene in v])
260
  for k, v in self.gene_class_dict.items()
261
  }
262
  empty_classes = []
@@ -403,6 +413,15 @@ class Classifier:
403
  "Column name 'labels' must be reserved for class IDs. Please rename column."
404
  )
405
  raise
 
 
 
 
 
 
 
 
 
406
 
407
  if self.classifier == "cell":
408
  # remove cell states representing < rare_threshold of cells
@@ -505,6 +524,7 @@ class Classifier:
505
  output_directory,
506
  output_prefix,
507
  save_eval_output=True,
 
508
  ):
509
  """
510
  Train cell state or gene classifier using all data.
@@ -525,13 +545,20 @@ class Classifier:
525
  save_eval_output : bool
526
  | Whether to save cross-fold eval output
527
  | Saves as pickle file of dictionary of eval metrics
528
-
 
 
 
529
  **Output**
530
 
531
  Returns trainer after fine-tuning with all data.
532
 
533
  """
534
 
 
 
 
 
535
  ##### Load data and prepare output directory #####
536
  # load numerical id to class dictionary (id:class)
537
  with open(id_class_dict_file, "rb") as f:
@@ -563,7 +590,7 @@ class Classifier:
563
  )
564
  assert len(targets) == len(labels)
565
  data = cu.prep_gene_classifier_all_data(
566
- data, targets, labels, self.max_ncells, self.nproc
567
  )
568
 
569
  trainer = self.train_classifier(
@@ -582,12 +609,15 @@ class Classifier:
582
  split_id_dict=None,
583
  attr_to_split=None,
584
  attr_to_balance=None,
 
585
  max_trials=100,
586
  pval_threshold=0.1,
587
  save_eval_output=True,
588
  predict_eval=True,
589
  predict_trainer=False,
590
  n_hyperopt_trials=0,
 
 
591
  ):
592
  """
593
  (Cross-)validate cell state or gene classifier.
@@ -622,6 +652,9 @@ class Classifier:
622
  attr_to_balance : None, list
623
  | List of attribute keys on which to balance data while splitting on attr_to_split
624
  | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
 
 
 
625
  max_trials : None, int
626
  | Maximum number of trials of random splitting to try to achieve balanced other attribute
627
  | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
@@ -640,11 +673,17 @@ class Classifier:
640
  n_hyperopt_trials : int
641
  | Number of trials to run for hyperparameter optimization
642
  | If 0, will not optimize hyperparameters
 
 
643
  """
644
  if self.num_crossval_splits == 0:
645
  logger.error("num_crossval_splits must be 1 or 5 to validate.")
646
  raise
647
-
 
 
 
 
648
  # ensure number of genes in each class is > 5 if validating model
649
  if self.classifier == "gene":
650
  insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
@@ -725,7 +764,7 @@ class Classifier:
725
  else:
726
  # 5-fold cross-validate
727
  num_cells = len(data)
728
- fifth_cells = num_cells * 0.2
729
  num_eval = min((self.eval_size * num_cells), fifth_cells)
730
  start = i * fifth_cells
731
  end = start + num_eval
@@ -804,8 +843,19 @@ class Classifier:
804
  self.max_ncells,
805
  iteration_num,
806
  self.nproc,
 
807
  )
808
-
 
 
 
 
 
 
 
 
 
 
809
  if self.oos_test_size > 0:
810
  test_data = cu.prep_gene_classifier_split(
811
  data,
@@ -817,7 +867,14 @@ class Classifier:
817
  iteration_num,
818
  self.nproc,
819
  )
820
-
 
 
 
 
 
 
 
821
  if n_hyperopt_trials == 0:
822
  trainer = self.train_classifier(
823
  model_directory,
@@ -966,7 +1023,7 @@ class Classifier:
966
  subprocess.call(f"mkdir {output_directory}", shell=True)
967
 
968
  ##### Load model and training args #####
969
- model = pu.load_model(self.model_type, num_classes, model_directory, "train")
970
  def_training_args, def_freeze_layers = cu.get_default_train_args(
971
  model, self.classifier, train_data, output_directory
972
  )
@@ -990,14 +1047,14 @@ class Classifier:
990
  ##### Fine-tune the model #####
991
  # define the data collator
992
  if self.classifier == "cell":
993
- data_collator = DataCollatorForCellClassification()
994
  elif self.classifier == "gene":
995
- data_collator = DataCollatorForGeneClassification()
996
 
997
  # define function to initiate model
998
  def model_init():
999
  model = pu.load_model(
1000
- self.model_type, num_classes, model_directory, "train"
1001
  )
1002
 
1003
  if self.freeze_layers is not None:
@@ -1009,7 +1066,8 @@ class Classifier:
1009
  for param in module.parameters():
1010
  param.requires_grad = False
1011
 
1012
- model = model.to("cuda:0")
 
1013
  return model
1014
 
1015
  # create the trainer
@@ -1122,7 +1180,7 @@ class Classifier:
1122
  subprocess.call(f"mkdir {output_directory}", shell=True)
1123
 
1124
  ##### Load model and training args #####
1125
- model = pu.load_model(self.model_type, num_classes, model_directory, "train")
1126
 
1127
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1128
  model, self.classifier, train_data, output_directory
@@ -1152,9 +1210,9 @@ class Classifier:
1152
  ##### Fine-tune the model #####
1153
  # define the data collator
1154
  if self.classifier == "cell":
1155
- data_collator = DataCollatorForCellClassification()
1156
  elif self.classifier == "gene":
1157
- data_collator = DataCollatorForGeneClassification()
1158
 
1159
  # create the trainer
1160
  trainer = Trainer(
@@ -1276,7 +1334,7 @@ class Classifier:
1276
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1277
 
1278
  # load previously fine-tuned model
1279
- model = pu.load_model(self.model_type, num_classes, model_directory, "eval")
1280
 
1281
  # evaluate the model
1282
  result = self.evaluate_model(
 
72
  class Classifier:
73
  valid_option_dict = {
74
  "classifier": {"cell", "gene"},
75
+ "quantize": {bool, dict},
76
  "cell_state_dict": {None, dict},
77
  "gene_class_dict": {None, dict},
78
  "filter_data": {None, dict},
 
94
  def __init__(
95
  self,
96
  classifier=None,
97
+ quantize=False,
98
  cell_state_dict=None,
99
  gene_class_dict=None,
100
  filter_data=None,
 
120
 
121
  classifier : {"cell", "gene"}
122
  | Whether to fine-tune a cell state or gene classifier.
123
+ quantize : bool, dict
124
+ | Whether to fine-tune a quantized model.
125
+ | If True and no config provided, will use default.
126
+ | Will use custom config if provided.
127
+ | Configs should be provided as dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft).
128
+ | For example: {"bnb_config": BitsAndBytesConfig(...),
129
+ | "peft_config": LoraConfig(...)}
130
  cell_state_dict : None, dict
131
  | Cell states to fine-tune model to distinguish.
132
  | Two-item dictionary with keys: state_key and states
 
200
  self.model_type = "CellClassifier"
201
  elif self.classifier == "gene":
202
  self.model_type = "GeneClassifier"
203
+ self.quantize = quantize
204
  self.cell_state_dict = cell_state_dict
205
  self.gene_class_dict = gene_class_dict
206
  self.filter_data = filter_data
 
266
  f"Genes to classify {missing_genes} are not in token dictionary."
267
  )
268
  self.gene_class_dict = {
269
+ k: list(set([self.gene_token_dict.get(gene) for gene in v]))
270
  for k, v in self.gene_class_dict.items()
271
  }
272
  empty_classes = []
 
413
  "Column name 'labels' must be reserved for class IDs. Please rename column."
414
  )
415
  raise
416
+
417
+ if (attr_to_split is not None) and (attr_to_balance is None):
418
+ logger.error(
419
+ "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
420
+ )
421
+ raise
422
+
423
+ if not isinstance(attr_to_balance, list):
424
+ attr_to_balance = [attr_to_balance]
425
 
426
  if self.classifier == "cell":
427
  # remove cell states representing < rare_threshold of cells
 
524
  output_directory,
525
  output_prefix,
526
  save_eval_output=True,
527
+ gene_balance=False,
528
  ):
529
  """
530
  Train cell state or gene classifier using all data.
 
545
  save_eval_output : bool
546
  | Whether to save cross-fold eval output
547
  | Saves as pickle file of dictionary of eval metrics
548
+ gene_balance : None, bool
549
+ | Whether to automatically balance genes in training set.
550
+ | Only available for binary gene classifications.
551
+
552
  **Output**
553
 
554
  Returns trainer after fine-tuning with all data.
555
 
556
  """
557
 
558
+ if (gene_balance is True) and (len(self.gene_class_dict.values())!=2):
559
+ logger.error("Automatically balancing gene sets for training is only available for binary gene classifications.")
560
+ raise
561
+
562
  ##### Load data and prepare output directory #####
563
  # load numerical id to class dictionary (id:class)
564
  with open(id_class_dict_file, "rb") as f:
 
590
  )
591
  assert len(targets) == len(labels)
592
  data = cu.prep_gene_classifier_all_data(
593
+ data, targets, labels, self.max_ncells, self.nproc, gene_balance
594
  )
595
 
596
  trainer = self.train_classifier(
 
609
  split_id_dict=None,
610
  attr_to_split=None,
611
  attr_to_balance=None,
612
+ gene_balance=False,
613
  max_trials=100,
614
  pval_threshold=0.1,
615
  save_eval_output=True,
616
  predict_eval=True,
617
  predict_trainer=False,
618
  n_hyperopt_trials=0,
619
+ save_gene_split_datasets=True,
620
+ debug_gene_split_datasets=False,
621
  ):
622
  """
623
  (Cross-)validate cell state or gene classifier.
 
652
  attr_to_balance : None, list
653
  | List of attribute keys on which to balance data while splitting on attr_to_split
654
  | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
655
+ gene_balance : None, bool
656
+ | Whether to automatically balance genes in training set.
657
+ | Only available for binary gene classifications.
658
  max_trials : None, int
659
  | Maximum number of trials of random splitting to try to achieve balanced other attribute
660
  | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
 
673
  n_hyperopt_trials : int
674
  | Number of trials to run for hyperparameter optimization
675
  | If 0, will not optimize hyperparameters
676
+ save_gene_split_datasets : bool
677
+ | Whether or not to save train, valid, and test gene-labeled datasets
678
  """
679
  if self.num_crossval_splits == 0:
680
  logger.error("num_crossval_splits must be 1 or 5 to validate.")
681
  raise
682
+
683
+ if (gene_balance is True) and (len(self.gene_class_dict.values())!=2):
684
+ logger.error("Automatically balancing gene sets for training is only available for binary gene classifications.")
685
+ raise
686
+
687
  # ensure number of genes in each class is > 5 if validating model
688
  if self.classifier == "gene":
689
  insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
 
764
  else:
765
  # 5-fold cross-validate
766
  num_cells = len(data)
767
+ fifth_cells = int(np.floor(num_cells * 0.2))
768
  num_eval = min((self.eval_size * num_cells), fifth_cells)
769
  start = i * fifth_cells
770
  end = start + num_eval
 
843
  self.max_ncells,
844
  iteration_num,
845
  self.nproc,
846
+ gene_balance,
847
  )
848
+
849
+ if save_gene_split_datasets is True:
850
+ for split_name in ["train", "valid"]:
851
+ labeled_dataset_output_path = (
852
+ Path(output_dir) / f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
853
+ ).with_suffix(".dataset")
854
+ if split_name == "train":
855
+ train_data.save_to_disk(str(labeled_dataset_output_path))
856
+ elif split_name == "valid":
857
+ eval_data.save_to_disk(str(labeled_dataset_output_path))
858
+
859
  if self.oos_test_size > 0:
860
  test_data = cu.prep_gene_classifier_split(
861
  data,
 
867
  iteration_num,
868
  self.nproc,
869
  )
870
+ if save_gene_split_datasets is True:
871
+ test_labeled_dataset_output_path = (
872
+ Path(output_dir) / f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
873
+ ).with_suffix(".dataset")
874
+ test_data.save_to_disk(str(test_labeled_dataset_output_path))
875
+ if debug_gene_split_datasets is True:
876
+ logger.error("Exiting after saving gene split datasets given debug_gene_split_datasets = True.")
877
+ raise
878
  if n_hyperopt_trials == 0:
879
  trainer = self.train_classifier(
880
  model_directory,
 
1023
  subprocess.call(f"mkdir {output_directory}", shell=True)
1024
 
1025
  ##### Load model and training args #####
1026
+ model = pu.load_model(self.model_type, num_classes, model_directory, "train", quantize=self.quantize)
1027
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1028
  model, self.classifier, train_data, output_directory
1029
  )
 
1047
  ##### Fine-tune the model #####
1048
  # define the data collator
1049
  if self.classifier == "cell":
1050
+ data_collator = DataCollatorForCellClassification(token_dictionary=self.token_dictionary)
1051
  elif self.classifier == "gene":
1052
+ data_collator = DataCollatorForGeneClassification(token_dictionary=self.token_dictionary)
1053
 
1054
  # define function to initiate model
1055
  def model_init():
1056
  model = pu.load_model(
1057
+ self.model_type, num_classes, model_directory, "train", quantize=self.quantize
1058
  )
1059
 
1060
  if self.freeze_layers is not None:
 
1066
  for param in module.parameters():
1067
  param.requires_grad = False
1068
 
1069
+ if self.quantize is False:
1070
+ model = model.to("cuda:0")
1071
  return model
1072
 
1073
  # create the trainer
 
1180
  subprocess.call(f"mkdir {output_directory}", shell=True)
1181
 
1182
  ##### Load model and training args #####
1183
+ model = pu.load_model(self.model_type, num_classes, model_directory, "train", quantize=self.quantize)
1184
 
1185
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1186
  model, self.classifier, train_data, output_directory
 
1210
  ##### Fine-tune the model #####
1211
  # define the data collator
1212
  if self.classifier == "cell":
1213
+ data_collator = DataCollatorForCellClassification(token_dictionary=self.token_dictionary)
1214
  elif self.classifier == "gene":
1215
+ data_collator = DataCollatorForGeneClassification(token_dictionary=self.token_dictionary)
1216
 
1217
  # create the trainer
1218
  trainer = Trainer(
 
1334
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1335
 
1336
  # load previously fine-tuned model
1337
+ model = pu.load_model(self.model_type, num_classes, model_directory, "eval", quantize=self.quantize)
1338
 
1339
  # evaluate the model
1340
  result = self.evaluate_model(
geneformer/classifier_utils.py CHANGED
@@ -137,21 +137,22 @@ def label_gene_classes(example, class_id_dict, gene_class_dict):
137
 
138
 
139
  def prep_gene_classifier_train_eval_split(
140
- data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
141
  ):
142
  # generate cross-validation splits
143
  train_data = prep_gene_classifier_split(
144
- data, targets, labels, train_index, "train", max_ncells, iteration_num, num_proc
145
  )
146
  eval_data = prep_gene_classifier_split(
147
- data, targets, labels, eval_index, "eval", max_ncells, iteration_num, num_proc
148
  )
149
  return train_data, eval_data
150
 
151
 
152
  def prep_gene_classifier_split(
153
- data, targets, labels, index, subset_name, max_ncells, iteration_num, num_proc
154
  ):
 
155
  # generate cross-validation splits
156
  targets = np.array(targets)
157
  labels = np.array(labels)
@@ -172,6 +173,10 @@ def prep_gene_classifier_split(
172
  f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
173
  )
174
 
 
 
 
 
175
  # subsample to max_ncells
176
  subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
177
 
@@ -187,7 +192,7 @@ def prep_gene_classifier_split(
187
  return subset_data
188
 
189
 
190
- def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
191
  targets = np.array(targets)
192
  labels = np.array(labels)
193
  label_dict_train = dict(zip(targets, labels))
@@ -205,6 +210,9 @@ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
205
  f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
206
  )
207
 
 
 
 
208
  # subsample to max_ncells
209
  train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
210
 
@@ -220,6 +228,110 @@ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
220
  return train_data
221
 
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  def balance_attr_splits(
224
  data,
225
  attr_to_split,
 
137
 
138
 
139
  def prep_gene_classifier_train_eval_split(
140
+ data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc, balance=False
141
  ):
142
  # generate cross-validation splits
143
  train_data = prep_gene_classifier_split(
144
+ data, targets, labels, train_index, "train", max_ncells, iteration_num, num_proc, balance
145
  )
146
  eval_data = prep_gene_classifier_split(
147
+ data, targets, labels, eval_index, "eval", max_ncells, iteration_num, num_proc, balance
148
  )
149
  return train_data, eval_data
150
 
151
 
152
  def prep_gene_classifier_split(
153
+ data, targets, labels, index, subset_name, max_ncells, iteration_num, num_proc, balance=False
154
  ):
155
+
156
  # generate cross-validation splits
157
  targets = np.array(targets)
158
  labels = np.array(labels)
 
173
  f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
174
  )
175
 
176
+ # balance gene subsets if train
177
+ if (subset_name == "train") and (balance is True):
178
+ subset_data, label_dict_subset = balance_gene_split(subset_data, label_dict_subset, num_proc)
179
+
180
  # subsample to max_ncells
181
  subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
182
 
 
192
  return subset_data
193
 
194
 
195
+ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc, balance=False):
196
  targets = np.array(targets)
197
  labels = np.array(labels)
198
  label_dict_train = dict(zip(targets, labels))
 
210
  f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
211
  )
212
 
213
+ if balance is True:
214
+ train_data, label_dict_train = balance_gene_split(train_data, label_dict_train, num_proc)
215
+
216
  # subsample to max_ncells
217
  train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
218
 
 
228
  return train_data
229
 
230
 
231
+ def balance_gene_split(subset_data, label_dict_subset, num_proc):
232
+ # count occurrence of genes in each label category
233
+ label0_counts, label1_counts = count_genes_for_balancing(subset_data, label_dict_subset, num_proc)
234
+ label_ratio_0to1 = label0_counts/label1_counts
235
+
236
+ if 8/10 <= label_ratio_0to1 <= 10/8:
237
+ # gene sets already balanced
238
+ logger.info(
239
+ "Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
240
+ )
241
+ return subset_data, label_dict_subset
242
+ else:
243
+ label_ratio_0to1_orig = label_ratio_0to1+0
244
+ label_dict_subset_orig = label_dict_subset.copy()
245
+ # balance gene sets
246
+ max_ntrials = 25
247
+ boost = 1
248
+ if label_ratio_0to1 > 10/8:
249
+ # downsample label 0
250
+ for i in range(max_ntrials):
251
+ label0 = 0
252
+ label0_genes = [k for k,v in label_dict_subset.items() if v == label0]
253
+ label0_ngenes = len(label0_genes)
254
+ label0_nremove = max(1,int(np.floor(label0_ngenes - label0_ngenes/(label_ratio_0to1*boost))))
255
+ random.seed(i)
256
+ label0_remove_genes = random.sample(label0_genes, label0_nremove)
257
+ label_dict_subset_new = {k:v for k,v in label_dict_subset.items() if k not in label0_remove_genes}
258
+ label0_counts, label1_counts = count_genes_for_balancing(subset_data, label_dict_subset_new, num_proc)
259
+ label_ratio_0to1 = label0_counts/label1_counts
260
+ if 8/10 <= label_ratio_0to1 <= 10/8:
261
+ # if gene sets now balanced, return new filtered data and new label_dict_subset
262
+ return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
263
+ elif label_ratio_0to1 > 10/8:
264
+ boost = boost*1.1
265
+ elif label_ratio_0to1 < 8/10:
266
+ boost = boost*0.9
267
+ else:
268
+ # downsample label 1
269
+ for i in range(max_ntrials):
270
+ label1 = 1
271
+ label1_genes = [k for k,v in label_dict_subset.items() if v == label1]
272
+ label1_ngenes = len(label1_genes)
273
+ label1_nremove = max(1,int(np.floor(label1_ngenes - label1_ngenes/((1/label_ratio_0to1)*boost))))
274
+ random.seed(i)
275
+ label1_remove_genes = random.sample(label1_genes, label1_nremove)
276
+ label_dict_subset_new = {k:v for k,v in label_dict_subset.items() if k not in label1_remove_genes}
277
+ label0_counts, label1_counts = count_genes_for_balancing(subset_data, label_dict_subset_new, num_proc)
278
+ label_ratio_0to1 = label0_counts/label1_counts
279
+ if 8/10 <= label_ratio_0to1 <= 10/8:
280
+ # if gene sets now balanced, return new filtered data and new label_dict_subset
281
+ return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
282
+ elif label_ratio_0to1 < 8/10:
283
+ boost = boost*1.1
284
+ elif label_ratio_0to1 > 10/8:
285
+ boost = boost*0.9
286
+
287
+ assert i+1 == max_ntrials
288
+ if (label_ratio_0to1 <= label_ratio_0to1_orig < 8/10) or (10/8 > label_ratio_0to1_orig >= label_ratio_0to1):
289
+ label_ratio_0to1 = label_ratio_0to1_orig
290
+ label_dict_subset_new = label_dict_subset_orig
291
+ logger.warning(
292
+ f"Gene sets were not able to be balanced within 0.8-1.25 fold after {max_ntrials} trials. Imbalance level: {label_ratio_0to1}\n"
293
+ )
294
+ return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
295
+
296
+
297
+ def count_genes_for_balancing(subset_data, label_dict_subset, num_proc):
298
+ def count_targets(example):
299
+ labels = [
300
+ label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
301
+ ]
302
+ counter_labels = Counter(labels)
303
+ # get count of labels 0 or 1, or if absent, return 0
304
+ example["labels_counts"] = [counter_labels.get(0,0),counter_labels.get(1,0)]
305
+ return example
306
+
307
+ subset_data = subset_data.map(count_targets, num_proc=num_proc)
308
+
309
+ label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
310
+ label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
311
+
312
+ subset_data = subset_data.remove_columns("labels_counts")
313
+
314
+ return label0_counts, label1_counts
315
+
316
+
317
+ def filter_data_balanced_genes(subset_data, label_dict_subset, num_proc):
318
+ # function to filter by whether contains labels
319
+ def if_contains_subset_label(example):
320
+ a = list(label_dict_subset.keys())
321
+ b = example["input_ids"]
322
+ return not set(a).isdisjoint(b)
323
+
324
+ # filter dataset for examples containing classes for this split
325
+ logger.info("Filtering data for balanced genes")
326
+ subset_data_len_orig = len(subset_data)
327
+ subset_data = subset_data.filter(if_contains_subset_label, num_proc=num_proc)
328
+ logger.info(
329
+ f"Filtered {round((1-len(subset_data)/subset_data_len_orig)*100)}%; {len(subset_data)} remain\n"
330
+ )
331
+
332
+ return subset_data, label_dict_subset
333
+
334
+
335
  def balance_attr_splits(
336
  data,
337
  attr_to_split,
geneformer/collator_for_classification.py CHANGED
@@ -18,12 +18,6 @@ from transformers import (
18
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
19
  from transformers.utils.generic import _is_tensorflow, _is_torch
20
 
21
- from . import TOKEN_DICTIONARY_FILE
22
-
23
- # load token dictionary (Ensembl IDs:token)
24
- with open(TOKEN_DICTIONARY_FILE, "rb") as f:
25
- token_dictionary = pickle.load(f)
26
-
27
  EncodedInput = List[int]
28
  logger = logging.get_logger(__name__)
29
  VERY_LARGE_INTEGER = int(
@@ -85,16 +79,18 @@ class TensorType(ExplicitEnum):
85
 
86
 
87
  class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
88
- mask_token = "<mask>"
89
- mask_token_id = token_dictionary.get("<mask>")
90
- pad_token = "<pad>"
91
- pad_token_id = token_dictionary.get("<pad>")
92
- padding_side = "right"
93
- all_special_ids = [
94
- token_dictionary.get("<mask>"),
95
- token_dictionary.get("<pad>")
96
- ]
97
- model_input_names = ["input_ids"]
 
 
98
 
99
  def _get_padding_truncation_strategies(
100
  self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
@@ -550,8 +546,7 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
550
  label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
551
  The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
552
  """
553
-
554
- tokenizer = PrecollatorForGeneAndCellClassification()
555
  class_type = "gene"
556
  padding: Union[bool, str, PaddingStrategy] = True
557
  max_length: Optional[int] = None
@@ -559,8 +554,9 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
559
  label_pad_token_id: int = -100
560
 
561
  def __init__(self, *args, **kwargs) -> None:
 
562
  super().__init__(
563
- tokenizer=self.tokenizer,
564
  padding=self.padding,
565
  max_length=self.max_length,
566
  pad_to_multiple_of=self.pad_to_multiple_of,
 
18
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
19
  from transformers.utils.generic import _is_tensorflow, _is_torch
20
 
 
 
 
 
 
 
21
  EncodedInput = List[int]
22
  logger = logging.get_logger(__name__)
23
  VERY_LARGE_INTEGER = int(
 
79
 
80
 
81
  class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
82
+ def __init__(self, *args, **kwargs) -> None:
83
+ super().__init__(mask_token="<mask>", pad_token="<pad>")
84
+
85
+ self.token_dictionary = kwargs.get("token_dictionary")
86
+ self.padding_side = "right"
87
+ self.model_input_names = ["input_ids"]
88
+ self.mask_token_id = self.token_dictionary.get("<mask>")
89
+ self.pad_token_id = self.token_dictionary.get("<pad>")
90
+ self.all_special_ids = [
91
+ self.token_dictionary.get("<mask>"),
92
+ self.token_dictionary.get("<pad>")
93
+ ]
94
 
95
  def _get_padding_truncation_strategies(
96
  self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
 
546
  label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
547
  The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
548
  """
549
+
 
550
  class_type = "gene"
551
  padding: Union[bool, str, PaddingStrategy] = True
552
  max_length: Optional[int] = None
 
554
  label_pad_token_id: int = -100
555
 
556
  def __init__(self, *args, **kwargs) -> None:
557
+ self.token_dictionary = kwargs.pop("token_dictionary")
558
  super().__init__(
559
+ tokenizer=PrecollatorForGeneAndCellClassification(token_dictionary=self.token_dictionary),
560
  padding=self.padding,
561
  max_length=self.max_length,
562
  pad_to_multiple_of=self.pad_to_multiple_of,
geneformer/emb_extractor.py CHANGED
@@ -286,12 +286,20 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
286
  sc.tl.umap(adata, random_state=seed)
287
  sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
288
  sns.set_style("white")
289
- default_kwargs_dict = {"palette": "Set2", "size": 200}
290
  if kwargs_dict is not None:
291
  default_kwargs_dict.update(kwargs_dict)
292
 
293
- with plt.rc_context():
294
- sc.pl.umap(adata, color=label, **default_kwargs_dict)
 
 
 
 
 
 
 
 
295
  plt.savefig(output_file, bbox_inches="tight")
296
 
297
 
@@ -470,7 +478,6 @@ class EmbExtractor:
470
  ... emb_mode="cell",
471
  ... filter_data={"cell_type":["cardiomyocyte"]},
472
  ... max_ncells=1000,
473
- ... max_ncells_to_plot=1000,
474
  ... emb_layer=-1,
475
  ... emb_label=["disease", "cell_type"],
476
  ... labels_to_plot=["disease", "cell_type"])
@@ -783,15 +790,15 @@ class EmbExtractor:
783
  logger.error("Plotting UMAP requires 'labels_to_plot'. ")
784
  raise
785
 
786
- if max_ncells_to_plot > self.max_ncells:
787
- max_ncells_to_plot = self.max_ncells
788
- logger.warning(
789
- "max_ncells_to_plot must be <= max_ncells. "
790
- f"Changing max_ncells_to_plot to {self.max_ncells}."
791
- )
792
-
793
- if (max_ncells_to_plot is not None) and (max_ncells_to_plot < self.max_ncells):
794
- embs = embs.sample(max_ncells_to_plot, axis=0)
795
 
796
  if self.emb_label is None:
797
  label_len = 0
 
286
  sc.tl.umap(adata, random_state=seed)
287
  sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
288
  sns.set_style("white")
289
+ default_kwargs_dict = {"size": 200}
290
  if kwargs_dict is not None:
291
  default_kwargs_dict.update(kwargs_dict)
292
 
293
+ cats = set(embs_df[label])
294
+
295
+ with plt.rc_context():
296
+ ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
297
+ ax.legend(markerscale=2,
298
+ frameon=False,
299
+ loc="center left",
300
+ bbox_to_anchor=(1, 0.5),
301
+ ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3))
302
+ plt.show()
303
  plt.savefig(output_file, bbox_inches="tight")
304
 
305
 
 
478
  ... emb_mode="cell",
479
  ... filter_data={"cell_type":["cardiomyocyte"]},
480
  ... max_ncells=1000,
 
481
  ... emb_layer=-1,
482
  ... emb_label=["disease", "cell_type"],
483
  ... labels_to_plot=["disease", "cell_type"])
 
790
  logger.error("Plotting UMAP requires 'labels_to_plot'. ")
791
  raise
792
 
793
+ if max_ncells_to_plot is not None:
794
+ if max_ncells_to_plot > self.max_ncells:
795
+ max_ncells_to_plot = self.max_ncells
796
+ logger.warning(
797
+ "max_ncells_to_plot must be <= max_ncells. "
798
+ f"Changing max_ncells_to_plot to {self.max_ncells}."
799
+ )
800
+ elif max_ncells_to_plot < self.max_ncells:
801
+ embs = embs.sample(max_ncells_to_plot, axis=0)
802
 
803
  if self.emb_label is None:
804
  label_len = 0
geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3b589bb5ec75040d05fc44dd6bf0184cf87f3c362cf158d196a6ed3b7fe5f39
3
+ size 940965
geneformer/{gene_name_id_dict.pkl β†’ gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl} RENAMED
File without changes
geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab9dc40973fa5224d77b793e2fd114cacf3d08423ed9c4c49caf0ba9c7f218f1
3
+ size 788424
geneformer/gene_median_dictionary.pkl DELETED
Binary file (941 kB)
 
geneformer/in_silico_perturber.py CHANGED
@@ -63,7 +63,7 @@ class InSilicoPerturber:
63
  "anchor_gene": {None, str},
64
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
65
  "num_classes": {int},
66
- "emb_mode": {"cell", "cell_and_gene"},
67
  "cell_emb_style": {"mean_pool"},
68
  "filter_data": {None, dict},
69
  "cell_states_to_model": {None, dict},
@@ -71,6 +71,7 @@ class InSilicoPerturber:
71
  "max_ncells": {None, int},
72
  "cell_inds_to_perturb": {"all", dict},
73
  "emb_layer": {-1, 0},
 
74
  "forward_batch_size": {int},
75
  "nproc": {int},
76
  }
@@ -94,7 +95,8 @@ class InSilicoPerturber:
94
  emb_layer=-1,
95
  forward_batch_size=100,
96
  nproc=4,
97
- token_dictionary_file=TOKEN_DICTIONARY_FILE,
 
98
  ):
99
  """
100
  Initialize in silico perturber.
@@ -129,16 +131,16 @@ class InSilicoPerturber:
129
  | ENSEMBL ID of gene to use as anchor in combination perturbations.
130
  | For example, if combos=1 and anchor_gene="ENSG00000148400":
131
  | anchor gene will be perturbed in combination with each other gene.
132
- model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
133
- | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
134
  num_classes : int
135
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
136
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
137
- emb_mode : {"cell", "cell_and_gene"}
138
- | Whether to output impact of perturbation on cell and/or gene embeddings.
139
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
140
  cell_emb_style : "mean_pool"
141
- | Method for summarizing cell embeddings.
142
  | Currently only option is mean pooling of gene embeddings for given cell.
143
  filter_data : None, dict
144
  | Default is to use all input data for in silico perturbation study.
@@ -183,6 +185,8 @@ class InSilicoPerturber:
183
  | Number of CPU processes to use.
184
  token_dictionary_file : Path
185
  | Path to pickle file containing token dictionary (Ensembl ID:token).
 
 
186
  """
187
  try:
188
  set_start_method("spawn")
@@ -219,15 +223,31 @@ class InSilicoPerturber:
219
  self.emb_layer = emb_layer
220
  self.forward_batch_size = forward_batch_size
221
  self.nproc = nproc
 
 
222
 
223
  self.validate_options()
224
 
225
  # load token dictionary (Ensembl IDs:token)
 
 
226
  with open(token_dictionary_file, "rb") as f:
227
  self.gene_token_dict = pickle.load(f)
228
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
229
 
230
  self.pad_token_id = self.gene_token_dict.get("<pad>")
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  if self.anchor_gene is None:
233
  self.anchor_token = None
@@ -285,7 +305,7 @@ class InSilicoPerturber:
285
  continue
286
  valid_type = False
287
  for option in valid_options:
288
- if (option in [bool, int, list, dict]) and isinstance(
289
  attr_value, option
290
  ):
291
  valid_type = True
@@ -426,22 +446,46 @@ class InSilicoPerturber:
426
  self.max_len = pu.get_model_input_size(model)
427
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
428
 
429
-
430
  ### filter input data ###
431
  # general filtering of input data based on filter_data argument
432
  filtered_input_data = pu.load_and_filter(
433
  self.filter_data, self.nproc, input_data_file
434
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
436
 
437
  if self.perturb_group is True:
438
- self.isp_perturb_set(
439
- model, filtered_input_data, layer_to_quant, output_path_prefix
440
- )
 
 
 
 
 
441
  else:
442
- self.isp_perturb_all(
443
- model, filtered_input_data, layer_to_quant, output_path_prefix
444
- )
 
 
 
 
 
445
 
446
  def apply_additional_filters(self, filtered_input_data):
447
  # additional filtering of input data dependent on isp mode
@@ -486,6 +530,7 @@ class InSilicoPerturber:
486
  layer_to_quant: int,
487
  output_path_prefix: str,
488
  ):
 
489
  def make_group_perturbation_batch(example):
490
  example_input_ids = example["input_ids"]
491
  example["tokens_to_perturb"] = self.tokens_to_perturb
@@ -504,7 +549,7 @@ class InSilicoPerturber:
504
  if self.perturb_type == "delete":
505
  example = pu.delete_indices(example)
506
  elif self.perturb_type == "overexpress":
507
- example = pu.overexpress_tokens(example, self.max_len)
508
  example["n_overflow"] = pu.calc_n_overflow(
509
  self.max_len,
510
  example["length"],
@@ -678,8 +723,6 @@ class InSilicoPerturber:
678
  cos_sims_dict = self.update_perturbation_dictionary(
679
  cos_sims_dict,
680
  cos_sims_data,
681
- filtered_input_data,
682
- indices_to_perturb,
683
  gene_list,
684
  )
685
  else:
@@ -688,8 +731,6 @@ class InSilicoPerturber:
688
  cos_sims_dict[state] = self.update_perturbation_dictionary(
689
  cos_sims_dict[state],
690
  cos_sims_data[state],
691
- filtered_input_data,
692
- indices_to_perturb,
693
  gene_list,
694
  )
695
  del minibatch
@@ -711,6 +752,256 @@ class InSilicoPerturber:
711
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
712
  )
713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
714
  def isp_perturb_all(
715
  self,
716
  model,
@@ -729,8 +1020,10 @@ class InSilicoPerturber:
729
 
730
  if self.emb_mode == "cell_and_gene":
731
  stored_gene_embs_dict = defaultdict(list)
732
- for i in trange(len(filtered_input_data)):
733
- example_cell = filtered_input_data.select([i])
 
 
734
  full_original_emb = get_embs(
735
  model,
736
  example_cell,
@@ -738,18 +1031,33 @@ class InSilicoPerturber:
738
  layer_to_quant,
739
  self.pad_token_id,
740
  self.forward_batch_size,
741
- token_gene_dict=self.token_gene_dict,
742
  summary_stat=None,
743
  silent=True,
744
  )
745
-
 
 
 
 
 
746
  # gene_list is used to assign cos sims back to genes
747
- # need to remove the anchor gene
748
  gene_list = example_cell["input_ids"][0][:]
 
749
  if self.anchor_token is not None:
750
  for token in self.anchor_token:
751
  gene_list.remove(token)
752
-
 
 
 
 
 
 
 
 
 
 
753
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
754
  example_cell,
755
  self.perturb_type,
@@ -759,148 +1067,430 @@ class InSilicoPerturber:
759
  self.nproc,
760
  )
761
 
762
- full_perturbation_emb = get_embs(
763
- model,
764
- perturbation_batch,
765
- "gene",
766
- layer_to_quant,
767
- self.pad_token_id,
768
- self.forward_batch_size,
769
- token_gene_dict=self.token_gene_dict,
770
- summary_stat=None,
771
- silent=True,
772
- )
773
-
774
- num_inds_perturbed = 1 + self.combos
775
- # need to remove overexpressed gene to quantify cosine shifts
776
- if self.perturb_type == "overexpress":
777
- perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
778
- gene_list = gene_list[
779
- num_inds_perturbed:
780
- ] # index 0 is not overexpressed
781
-
782
- elif self.perturb_type == "delete":
783
- perturbation_emb = full_perturbation_emb
784
 
785
- original_batch = pu.make_comparison_batch(
786
- full_original_emb, indices_to_perturb, perturb_group=False
787
- )
788
-
789
- if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
790
- gene_cos_sims = pu.quant_cos_sims(
791
- perturbation_emb,
792
- original_batch,
793
- self.cell_states_to_model,
794
- self.state_embs_dict,
795
- emb_mode="gene",
796
- )
797
- if self.cell_states_to_model is not None:
798
- original_cell_emb = pu.compute_nonpadded_cell_embedding(
799
- full_original_emb, "mean_pool"
800
- )
801
- perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
802
- full_perturbation_emb, "mean_pool"
803
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804
 
805
- cell_cos_sims = pu.quant_cos_sims(
806
- perturbation_cell_emb,
807
- original_cell_emb,
808
- self.cell_states_to_model,
809
- self.state_embs_dict,
810
- emb_mode="cell",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
811
  )
812
 
813
- if self.emb_mode == "cell_and_gene":
814
- # remove perturbed index for gene list
815
- perturbed_gene_dict = {
816
- gene: gene_list[:i] + gene_list[i + 1 :]
817
- for i, gene in enumerate(gene_list)
 
 
818
  }
819
 
820
- for perturbation_i, perturbed_gene in enumerate(gene_list):
821
- for gene_j, affected_gene in enumerate(
822
- perturbed_gene_dict[perturbed_gene]
823
- ):
824
- try:
825
- stored_gene_embs_dict[
826
- (perturbed_gene, affected_gene)
827
- ].append(gene_cos_sims[perturbation_i, gene_j].item())
828
- except KeyError:
829
- stored_gene_embs_dict[
830
- (perturbed_gene, affected_gene)
831
- ] = gene_cos_sims[perturbation_i, gene_j].item()
832
 
833
- if self.cell_states_to_model is None:
834
- cos_sims_data = torch.mean(gene_cos_sims, dim=1)
835
- cos_sims_dict = self.update_perturbation_dictionary(
836
- cos_sims_dict,
837
- cos_sims_data,
838
- filtered_input_data,
839
- indices_to_perturb,
840
- gene_list,
841
- )
842
- else:
843
- cos_sims_data = cell_cos_sims
844
- for state in cos_sims_dict.keys():
845
- cos_sims_dict[state] = self.update_perturbation_dictionary(
846
- cos_sims_dict[state],
847
- cos_sims_data[state],
848
- filtered_input_data,
849
- indices_to_perturb,
850
- gene_list,
851
- )
 
 
 
852
 
853
- # save dict to disk every 100 cells
854
- if i % 100 == 0:
855
- pu.write_perturbation_dictionary(
856
- cos_sims_dict,
857
- f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
858
- )
859
- if self.emb_mode == "cell_and_gene":
860
- pu.write_perturbation_dictionary(
861
- stored_gene_embs_dict,
862
- f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
863
- )
864
 
865
- # reset and clear memory every 1000 cells
866
- if i % 1000 == 0:
867
- pickle_batch += 1
868
- if self.cell_states_to_model is None:
869
- cos_sims_dict = defaultdict(list)
870
- else:
871
- cos_sims_dict = {
872
- state: defaultdict(list)
873
- for state in pu.get_possible_states(self.cell_states_to_model)
874
- }
875
 
876
- if self.emb_mode == "cell_and_gene":
877
- stored_gene_embs_dict = defaultdict(list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
878
 
879
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
880
 
881
- pu.write_perturbation_dictionary(
882
- cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
883
- )
 
 
 
 
 
884
 
885
- if self.emb_mode == "cell_and_gene":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886
  pu.write_perturbation_dictionary(
887
- stored_gene_embs_dict,
888
- f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
889
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
890
 
 
891
  def update_perturbation_dictionary(
892
  self,
893
  cos_sims_dict: defaultdict,
894
  cos_sims_data: torch.Tensor,
895
- filtered_input_data: Dataset,
896
- indices_to_perturb: List[List[int]],
897
  gene_list=None,
898
  ):
899
  if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
900
  logger.error(
901
  f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
902
- cos_sims_data.shape[0] = {cos_sims_data.shape[0]}.\n \
903
- len(gene_list) = {len(gene_list)}."
904
  )
905
  raise
906
 
@@ -924,4 +1514,4 @@ class InSilicoPerturber:
924
  for i, cos in enumerate(cos_sims_data.tolist()):
925
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
926
 
927
- return cos_sims_dict
 
63
  "anchor_gene": {None, str},
64
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
65
  "num_classes": {int},
66
+ "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
67
  "cell_emb_style": {"mean_pool"},
68
  "filter_data": {None, dict},
69
  "cell_states_to_model": {None, dict},
 
71
  "max_ncells": {None, int},
72
  "cell_inds_to_perturb": {"all", dict},
73
  "emb_layer": {-1, 0},
74
+ "token_dictionary_file" : {None, str},
75
  "forward_batch_size": {int},
76
  "nproc": {int},
77
  }
 
95
  emb_layer=-1,
96
  forward_batch_size=100,
97
  nproc=4,
98
+ token_dictionary_file=None,
99
+ clear_mem_ncells=1000,
100
  ):
101
  """
102
  Initialize in silico perturber.
 
131
  | ENSEMBL ID of gene to use as anchor in combination perturbations.
132
  | For example, if combos=1 and anchor_gene="ENSG00000148400":
133
  | anchor gene will be perturbed in combination with each other gene.
134
+ model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"}
135
+ | Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
136
  num_classes : int
137
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
138
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
139
+ emb_mode : {"cls", "cell", "cls_and_gene","cell_and_gene"}
140
+ | Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings.
141
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
142
  cell_emb_style : "mean_pool"
143
+ | Method for summarizing cell embeddings if not using CLS token.
144
  | Currently only option is mean pooling of gene embeddings for given cell.
145
  filter_data : None, dict
146
  | Default is to use all input data for in silico perturbation study.
 
185
  | Number of CPU processes to use.
186
  token_dictionary_file : Path
187
  | Path to pickle file containing token dictionary (Ensembl ID:token).
188
+ clear_mem_ncells : int
189
+ | Clear memory every n cells.
190
  """
191
  try:
192
  set_start_method("spawn")
 
223
  self.emb_layer = emb_layer
224
  self.forward_batch_size = forward_batch_size
225
  self.nproc = nproc
226
+ self.token_dictionary_file = token_dictionary_file
227
+ self.clear_mem_ncells = clear_mem_ncells
228
 
229
  self.validate_options()
230
 
231
  # load token dictionary (Ensembl IDs:token)
232
+ if self.token_dictionary_file is None:
233
+ token_dictionary_file = TOKEN_DICTIONARY_FILE
234
  with open(token_dictionary_file, "rb") as f:
235
  self.gene_token_dict = pickle.load(f)
236
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
237
 
238
  self.pad_token_id = self.gene_token_dict.get("<pad>")
239
+ self.cls_token_id = self.gene_token_dict.get("<cls>")
240
+ self.eos_token_id = self.gene_token_dict.get("<eos>")
241
+
242
+
243
+ # Identify if special token is present in the token dictionary
244
+ if (self.cls_token_id is not None) and (self.eos_token_id is not None):
245
+ self.special_token = True
246
+ else:
247
+ if "cls" in self.emb_mode:
248
+ logger.error(f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary.")
249
+ raise
250
+ self.special_token = False
251
 
252
  if self.anchor_gene is None:
253
  self.anchor_token = None
 
305
  continue
306
  valid_type = False
307
  for option in valid_options:
308
+ if (option in [bool, int, list, dict, str]) and isinstance(
309
  attr_value, option
310
  ):
311
  valid_type = True
 
446
  self.max_len = pu.get_model_input_size(model)
447
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
448
 
 
449
  ### filter input data ###
450
  # general filtering of input data based on filter_data argument
451
  filtered_input_data = pu.load_and_filter(
452
  self.filter_data, self.nproc, input_data_file
453
  )
454
+
455
+ # Ensure emb_mode is cls if first token of the filtered input data is cls token
456
+ if self.special_token:
457
+ if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and ("cls" not in self.emb_mode):
458
+ logger.error(
459
+ "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
460
+ )
461
+ raise
462
+ if ("cls" in self.emb_mode):
463
+ if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (filtered_input_data["input_ids"][0][-1] != self.eos_token_id):
464
+ logger.error(
465
+ "Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
466
+ )
467
+ raise
468
+
469
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
470
 
471
  if self.perturb_group is True:
472
+ if (self.special_token) and ("cls" in self.emb_mode):
473
+ self.isp_perturb_set_special(
474
+ model, filtered_input_data, layer_to_quant, output_path_prefix
475
+ )
476
+ else:
477
+ self.isp_perturb_set(
478
+ model, filtered_input_data, layer_to_quant, output_path_prefix
479
+ )
480
  else:
481
+ if (self.special_token) and ("cls" in self.emb_mode):
482
+ self.isp_perturb_all_special(
483
+ model, filtered_input_data, layer_to_quant, output_path_prefix
484
+ )
485
+ else:
486
+ self.isp_perturb_all(
487
+ model, filtered_input_data, layer_to_quant, output_path_prefix
488
+ )
489
 
490
  def apply_additional_filters(self, filtered_input_data):
491
  # additional filtering of input data dependent on isp mode
 
530
  layer_to_quant: int,
531
  output_path_prefix: str,
532
  ):
533
+
534
  def make_group_perturbation_batch(example):
535
  example_input_ids = example["input_ids"]
536
  example["tokens_to_perturb"] = self.tokens_to_perturb
 
549
  if self.perturb_type == "delete":
550
  example = pu.delete_indices(example)
551
  elif self.perturb_type == "overexpress":
552
+ example = pu.overexpress_tokens(example, self.max_len, self.special_token)
553
  example["n_overflow"] = pu.calc_n_overflow(
554
  self.max_len,
555
  example["length"],
 
723
  cos_sims_dict = self.update_perturbation_dictionary(
724
  cos_sims_dict,
725
  cos_sims_data,
 
 
726
  gene_list,
727
  )
728
  else:
 
731
  cos_sims_dict[state] = self.update_perturbation_dictionary(
732
  cos_sims_dict[state],
733
  cos_sims_data[state],
 
 
734
  gene_list,
735
  )
736
  del minibatch
 
752
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
753
  )
754
 
755
+
756
+ def isp_perturb_set_special(
757
+ self,
758
+ model,
759
+ filtered_input_data: Dataset,
760
+ layer_to_quant: int,
761
+ output_path_prefix: str,
762
+ ):
763
+
764
+ def make_group_perturbation_batch(example):
765
+ example_input_ids = example["input_ids"]
766
+ example["tokens_to_perturb"] = self.tokens_to_perturb
767
+ indices_to_perturb = [
768
+ example_input_ids.index(token) if token in example_input_ids else None
769
+ for token in self.tokens_to_perturb
770
+ ]
771
+ indices_to_perturb = [
772
+ item for item in indices_to_perturb if item is not None
773
+ ]
774
+ if len(indices_to_perturb) > 0:
775
+ example["perturb_index"] = indices_to_perturb
776
+ else:
777
+ # -100 indicates tokens to overexpress are not present in rank value encoding
778
+ example["perturb_index"] = [-100]
779
+ if self.perturb_type == "delete":
780
+ example = pu.delete_indices(example)
781
+ elif self.perturb_type == "overexpress":
782
+ example = pu.overexpress_tokens(example, self.max_len, self.special_token)
783
+ example["n_overflow"] = pu.calc_n_overflow(
784
+ self.max_len,
785
+ example["length"],
786
+ self.tokens_to_perturb,
787
+ indices_to_perturb,
788
+ )
789
+ return example
790
+
791
+ total_batch_length = len(filtered_input_data)
792
+ if self.cell_states_to_model is None:
793
+ cos_sims_dict = defaultdict(list)
794
+ else:
795
+ cos_sims_dict = {
796
+ state: defaultdict(list)
797
+ for state in pu.get_possible_states(self.cell_states_to_model)
798
+ }
799
+
800
+ perturbed_data = filtered_input_data.map(
801
+ make_group_perturbation_batch, num_proc=self.nproc
802
+ )
803
+
804
+ if self.perturb_type == "overexpress":
805
+ filtered_input_data = filtered_input_data.add_column(
806
+ "n_overflow", perturbed_data["n_overflow"]
807
+ )
808
+ filtered_input_data = filtered_input_data.map(
809
+ pu.truncate_by_n_overflow_special, num_proc=self.nproc
810
+ )
811
+
812
+ if self.emb_mode == "cls_and_gene":
813
+ stored_gene_embs_dict = defaultdict(list)
814
+
815
+ # iterate through batches
816
+ for i in trange(0, total_batch_length, self.forward_batch_size):
817
+ max_range = min(i + self.forward_batch_size, total_batch_length)
818
+ inds_select = [i for i in range(i, max_range)]
819
+
820
+ minibatch = filtered_input_data.select(inds_select)
821
+ perturbation_batch = perturbed_data.select(inds_select)
822
+
823
+ ##### CLS Embedding Mode #####
824
+ if self.emb_mode == "cls":
825
+ indices_to_perturb = perturbation_batch["perturb_index"]
826
+
827
+ original_cls_emb = get_embs(
828
+ model,
829
+ minibatch,
830
+ "cls",
831
+ layer_to_quant,
832
+ self.pad_token_id,
833
+ self.forward_batch_size,
834
+ token_gene_dict=self.token_gene_dict,
835
+ summary_stat=None,
836
+ silent=True,
837
+ )
838
+
839
+ perturbation_cls_emb = get_embs(
840
+ model,
841
+ perturbation_batch,
842
+ "cls",
843
+ layer_to_quant,
844
+ self.pad_token_id,
845
+ self.forward_batch_size,
846
+ token_gene_dict=self.token_gene_dict,
847
+ summary_stat=None,
848
+ silent=True,
849
+ )
850
+
851
+ # Calculate the cosine similarities
852
+ cls_cos_sims = pu.quant_cos_sims(
853
+ perturbation_cls_emb,
854
+ original_cls_emb,
855
+ self.cell_states_to_model,
856
+ self.state_embs_dict,
857
+ emb_mode="cell")
858
+
859
+ # Update perturbation dictionary
860
+ if self.cell_states_to_model is None:
861
+ cos_sims_dict = self.update_perturbation_dictionary(
862
+ cos_sims_dict,
863
+ cls_cos_sims,
864
+ gene_list = None,
865
+ )
866
+ else:
867
+ for state in cos_sims_dict.keys():
868
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
869
+ cos_sims_dict[state],
870
+ cls_cos_sims[state],
871
+ gene_list = None,
872
+ )
873
+
874
+ ##### CLS and Gene Embedding Mode #####
875
+ elif self.emb_mode == "cls_and_gene":
876
+ full_original_emb = get_embs(
877
+ model,
878
+ minibatch,
879
+ "gene",
880
+ layer_to_quant,
881
+ self.pad_token_id,
882
+ self.forward_batch_size,
883
+ self.token_gene_dict,
884
+ summary_stat=None,
885
+ silent=True,
886
+ )
887
+ indices_to_perturb = perturbation_batch["perturb_index"]
888
+ # remove indices that were perturbed
889
+ original_emb = pu.remove_perturbed_indices_set(
890
+ full_original_emb,
891
+ self.perturb_type,
892
+ indices_to_perturb,
893
+ self.tokens_to_perturb,
894
+ minibatch["length"],
895
+ )
896
+ full_perturbation_emb = get_embs(
897
+ model,
898
+ perturbation_batch,
899
+ "gene",
900
+ layer_to_quant,
901
+ self.pad_token_id,
902
+ self.forward_batch_size,
903
+ self.token_gene_dict,
904
+ summary_stat=None,
905
+ silent=True,
906
+ )
907
+
908
+ # remove special tokens and padding
909
+ original_emb = original_emb[:, 1:-1, :]
910
+ if self.perturb_type == "overexpress":
911
+ perturbation_emb = full_perturbation_emb[:,1+len(self.tokens_to_perturb):-1,:]
912
+ elif self.perturb_type == "delete":
913
+ perturbation_emb = full_perturbation_emb[:,1:max(perturbation_batch["length"])-1,:]
914
+
915
+ n_perturbation_genes = perturbation_emb.size()[1]
916
+
917
+ gene_cos_sims = pu.quant_cos_sims(
918
+ perturbation_emb,
919
+ original_emb,
920
+ self.cell_states_to_model,
921
+ self.state_embs_dict,
922
+ emb_mode="gene",
923
+ )
924
+
925
+ # get cls emb
926
+ original_cls_emb = full_original_emb[:,0,:]
927
+ perturbation_cls_emb = full_perturbation_emb[:,0,:]
928
+
929
+ cls_cos_sims = pu.quant_cos_sims(
930
+ perturbation_cls_emb,
931
+ original_cls_emb,
932
+ self.cell_states_to_model,
933
+ self.state_embs_dict,
934
+ emb_mode="cell",
935
+ )
936
+
937
+ # get cosine similarities in gene embeddings
938
+ # since getting gene embeddings, need gene names
939
+
940
+ gene_list = minibatch["input_ids"]
941
+ # need to truncate gene_list
942
+ genes_to_exclude = self.tokens_to_perturb + [self.cls_token_id, self.eos_token_id]
943
+ gene_list = [
944
+ [g for g in genes if g not in genes_to_exclude][
945
+ :n_perturbation_genes
946
+ ]
947
+ for genes in gene_list
948
+ ]
949
+
950
+ for cell_i, genes in enumerate(gene_list):
951
+ for gene_j, affected_gene in enumerate(genes):
952
+ if len(self.genes_to_perturb) > 1:
953
+ tokens_to_perturb = tuple(self.tokens_to_perturb)
954
+ else:
955
+ tokens_to_perturb = self.tokens_to_perturb[0]
956
+
957
+ # fill in the gene cosine similarities
958
+ try:
959
+ stored_gene_embs_dict[
960
+ (tokens_to_perturb, affected_gene)
961
+ ].append(gene_cos_sims[cell_i, gene_j].item())
962
+ except KeyError:
963
+ stored_gene_embs_dict[
964
+ (tokens_to_perturb, affected_gene)
965
+ ] = gene_cos_sims[cell_i, gene_j].item()
966
+
967
+ if self.cell_states_to_model is None:
968
+ cos_sims_dict = self.update_perturbation_dictionary(
969
+ cos_sims_dict,
970
+ cls_cos_sims,
971
+ gene_list = None,
972
+ )
973
+ else:
974
+ for state in cos_sims_dict.keys():
975
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
976
+ cos_sims_dict[state],
977
+ cls_cos_sims[state],
978
+ gene_list = None,
979
+ )
980
+ del full_original_emb
981
+ del original_emb
982
+ del full_perturbation_emb
983
+ del perturbation_emb
984
+ del gene_cos_sims
985
+
986
+ del original_cls_emb
987
+ del perturbation_cls_emb
988
+ del cls_cos_sims
989
+ del minibatch
990
+ del perturbation_batch
991
+
992
+ torch.cuda.empty_cache()
993
+
994
+ pu.write_perturbation_dictionary(
995
+ cos_sims_dict,
996
+ f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
997
+ )
998
+
999
+ if self.emb_mode == "cls_and_gene":
1000
+ pu.write_perturbation_dictionary(
1001
+ stored_gene_embs_dict,
1002
+ f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
1003
+ )
1004
+
1005
  def isp_perturb_all(
1006
  self,
1007
  model,
 
1020
 
1021
  if self.emb_mode == "cell_and_gene":
1022
  stored_gene_embs_dict = defaultdict(list)
1023
+
1024
+ num_inds_perturbed = 1 + self.combos
1025
+ for h in trange(len(filtered_input_data)):
1026
+ example_cell = filtered_input_data.select([h])
1027
  full_original_emb = get_embs(
1028
  model,
1029
  example_cell,
 
1031
  layer_to_quant,
1032
  self.pad_token_id,
1033
  self.forward_batch_size,
1034
+ self.token_gene_dict,
1035
  summary_stat=None,
1036
  silent=True,
1037
  )
1038
+
1039
+ if self.cell_states_to_model is not None:
1040
+ original_cell_emb = pu.compute_nonpadded_cell_embedding(
1041
+ full_original_emb, "mean_pool"
1042
+ )
1043
+
1044
  # gene_list is used to assign cos sims back to genes
 
1045
  gene_list = example_cell["input_ids"][0][:]
1046
+ # need to remove the anchor gene
1047
  if self.anchor_token is not None:
1048
  for token in self.anchor_token:
1049
  gene_list.remove(token)
1050
+ # index 0 is not overexpressed so remove
1051
+ if self.perturb_type == "overexpress":
1052
+ gene_list = gene_list[
1053
+ num_inds_perturbed:
1054
+ ]
1055
+ # remove perturbed index for gene list dict
1056
+ perturbed_gene_dict = {
1057
+ gene: gene_list[:i] + gene_list[i + 1 :]
1058
+ for i, gene in enumerate(gene_list)
1059
+ }
1060
+
1061
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
1062
  example_cell,
1063
  self.perturb_type,
 
1067
  self.nproc,
1068
  )
1069
 
1070
+ ispall_total_batch_length = len(perturbation_batch)
1071
+ for i in trange(0, ispall_total_batch_length, self.forward_batch_size, leave=False):
1072
+ ispall_max_range = min(i + self.forward_batch_size, ispall_total_batch_length)
1073
+ perturbation_minibatch = perturbation_batch.select([i for i in range(i, ispall_max_range)])
1074
+ indices_to_perturb_mini = indices_to_perturb[i : ispall_max_range]
1075
+ gene_list_mini = gene_list[i : ispall_max_range] # only perturbed genes from this minibatch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1076
 
1077
+ full_perturbation_emb = get_embs(
1078
+ model,
1079
+ perturbation_minibatch,
1080
+ "gene",
1081
+ layer_to_quant,
1082
+ self.pad_token_id,
1083
+ self.forward_batch_size,
1084
+ self.token_gene_dict,
1085
+ summary_stat=None,
1086
+ silent=True,
 
 
 
 
 
 
 
 
1087
  )
1088
+
1089
+ del perturbation_minibatch
1090
+
1091
+ # need to remove overexpressed gene to quantify cosine shifts
1092
+ if self.perturb_type == "overexpress":
1093
+ perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
1094
+
1095
+ elif self.perturb_type == "delete":
1096
+ perturbation_emb = full_perturbation_emb
1097
+
1098
+
1099
+ if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
1100
+ original_emb_minibatch = pu.make_comparison_batch(
1101
+ full_original_emb, indices_to_perturb_mini, perturb_group=False
1102
+ )
1103
+ gene_cos_sims = pu.quant_cos_sims(
1104
+ perturbation_emb,
1105
+ original_emb_minibatch,
1106
+ self.cell_states_to_model,
1107
+ self.state_embs_dict,
1108
+ emb_mode="gene",
1109
+ )
1110
+ del original_emb_minibatch
1111
+
1112
+ if self.cell_states_to_model is not None:
1113
+ perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
1114
+ full_perturbation_emb, "mean_pool"
1115
+ )
1116
+
1117
+ cell_cos_sims = pu.quant_cos_sims(
1118
+ perturbation_cell_emb,
1119
+ original_cell_emb,
1120
+ self.cell_states_to_model,
1121
+ self.state_embs_dict,
1122
+ emb_mode="cell",
1123
+ )
1124
+ del perturbation_cell_emb
1125
+
1126
+ if self.emb_mode == "cell_and_gene":
1127
 
1128
+ for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1129
+ for gene_j, affected_gene in enumerate(
1130
+ perturbed_gene_dict[perturbed_gene]
1131
+ ):
1132
+ try:
1133
+ stored_gene_embs_dict[
1134
+ (perturbed_gene, affected_gene)
1135
+ ].append(gene_cos_sims[perturbation_i, gene_j].item())
1136
+ except KeyError:
1137
+ stored_gene_embs_dict[
1138
+ (perturbed_gene, affected_gene)
1139
+ ] = gene_cos_sims[perturbation_i, gene_j].item()
1140
+
1141
+ del full_perturbation_emb
1142
+
1143
+ if self.cell_states_to_model is None:
1144
+ cos_sims_data = torch.mean(gene_cos_sims, dim=1)
1145
+ cos_sims_dict = self.update_perturbation_dictionary(
1146
+ cos_sims_dict,
1147
+ cos_sims_data,
1148
+ gene_list_mini,
1149
+ )
1150
+ else:
1151
+ cos_sims_data = cell_cos_sims
1152
+ for state in cos_sims_dict.keys():
1153
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1154
+ cos_sims_dict[state],
1155
+ cos_sims_data[state],
1156
+ gene_list_mini,
1157
+ )
1158
+
1159
+ # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1160
+ if i % self.clear_mem_ncells/10 == 0:
1161
+ pu.write_perturbation_dictionary(
1162
+ cos_sims_dict,
1163
+ f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1164
+ )
1165
+ if self.emb_mode == "cell_and_gene":
1166
+ pu.write_perturbation_dictionary(
1167
+ stored_gene_embs_dict,
1168
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1169
+ )
1170
+
1171
+ # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1172
+ if i % self.clear_mem_ncells == 0:
1173
+ pickle_batch += 1
1174
+ if self.cell_states_to_model is None:
1175
+ cos_sims_dict = defaultdict(list)
1176
+ else:
1177
+ cos_sims_dict = {
1178
+ state: defaultdict(list)
1179
+ for state in pu.get_possible_states(self.cell_states_to_model)
1180
+ }
1181
+
1182
+ if self.emb_mode == "cell_and_gene":
1183
+ stored_gene_embs_dict = defaultdict(list)
1184
+
1185
+ torch.cuda.empty_cache()
1186
+
1187
+ pu.write_perturbation_dictionary(
1188
+ cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}"
1189
+ )
1190
+
1191
+ if self.emb_mode == "cell_and_gene":
1192
+ pu.write_perturbation_dictionary(
1193
+ stored_gene_embs_dict,
1194
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1195
  )
1196
 
1197
+ pickle_batch = -1
1198
+ if self.cell_states_to_model is None:
1199
+ cos_sims_dict = defaultdict(list)
1200
+ else:
1201
+ cos_sims_dict = {
1202
+ state: defaultdict(list)
1203
+ for state in pu.get_possible_states(self.cell_states_to_model)
1204
  }
1205
 
1206
+ if self.emb_mode == "cell_and_gene":
1207
+ stored_gene_embs_dict = defaultdict(list)
 
 
 
 
 
 
 
 
 
 
1208
 
1209
+ # clear memory between cells
1210
+ del perturbation_batch
1211
+ del full_original_emb
1212
+ if self.cell_states_to_model is not None:
1213
+ del original_cell_emb
1214
+ torch.cuda.empty_cache()
1215
+
1216
+ def isp_perturb_all_special(
1217
+ self,
1218
+ model,
1219
+ filtered_input_data: Dataset,
1220
+ layer_to_quant: int,
1221
+ output_path_prefix: str,
1222
+ ):
1223
+ pickle_batch = -1
1224
+ if self.cell_states_to_model is None:
1225
+ cos_sims_dict = defaultdict(list)
1226
+ else:
1227
+ cos_sims_dict = {
1228
+ state: defaultdict(list)
1229
+ for state in pu.get_possible_states(self.cell_states_to_model)
1230
+ }
1231
 
1232
+ if self.emb_mode == "cls_and_gene":
1233
+ stored_gene_embs_dict = defaultdict(list)
 
 
 
 
 
 
 
 
 
1234
 
1235
+ num_inds_perturbed = 1 + self.combos
1236
+ for h in trange(len(filtered_input_data)):
1237
+ example_cell = filtered_input_data.select([h])
 
 
 
 
 
 
 
1238
 
1239
+ # get original example cell cls and/or gene embs for comparison
1240
+ if self.emb_mode == "cls":
1241
+ original_cls_emb = get_embs(
1242
+ model,
1243
+ example_cell,
1244
+ "cls",
1245
+ layer_to_quant,
1246
+ self.pad_token_id,
1247
+ self.forward_batch_size,
1248
+ self.token_gene_dict,
1249
+ summary_stat=None,
1250
+ silent=True,
1251
+ )
1252
+ elif self.emb_mode == "cls_and_gene":
1253
+ full_original_emb = get_embs(
1254
+ model,
1255
+ example_cell,
1256
+ "gene",
1257
+ layer_to_quant,
1258
+ self.pad_token_id,
1259
+ self.forward_batch_size,
1260
+ self.token_gene_dict,
1261
+ summary_stat=None,
1262
+ silent=True,
1263
+ )
1264
+ original_cls_emb = full_original_emb[:,0,:].clone().detach()
1265
+
1266
+ # gene_list is used to assign cos sims back to genes
1267
+ gene_list = example_cell["input_ids"][0][:]
1268
 
1269
+ # need to remove special tokens
1270
+ for token in [self.cls_token_id, self.eos_token_id]:
1271
+ gene_list.remove(token)
1272
+ # need to remove the anchor gene
1273
+ if self.anchor_token is not None:
1274
+ for token in self.anchor_token:
1275
+ gene_list.remove(token)
1276
+ # index 0 is not overexpressed so remove
1277
+ if self.perturb_type == "overexpress":
1278
+ gene_list = gene_list[
1279
+ num_inds_perturbed:
1280
+ ]
1281
+ # remove perturbed index for gene list dict
1282
+ perturbed_gene_dict = {
1283
+ gene: gene_list[:i] + gene_list[i + 1 :]
1284
+ for i, gene in enumerate(gene_list)
1285
+ }
1286
 
1287
+ perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special(
1288
+ example_cell,
1289
+ self.perturb_type,
1290
+ self.tokens_to_perturb,
1291
+ self.anchor_token,
1292
+ self.combos,
1293
+ self.nproc,
1294
+ )
1295
 
1296
+ ispall_total_batch_length = len(perturbation_batch)
1297
+ for i in trange(0, ispall_total_batch_length, self.forward_batch_size, leave=False):
1298
+ ispall_max_range = min(i + self.forward_batch_size, ispall_total_batch_length)
1299
+ perturbation_minibatch = perturbation_batch.select([i for i in range(i, ispall_max_range)])
1300
+ indices_to_perturb_mini = indices_to_perturb[i : ispall_max_range]
1301
+ gene_list_mini = gene_list[i : ispall_max_range] # only perturbed genes from this minibatch
1302
+
1303
+ ##### CLS Embedding Mode #####
1304
+ if self.emb_mode == "cls":
1305
+ # Extract cls embeddings from perturbed cells
1306
+ perturbation_cls_emb = get_embs(
1307
+ model,
1308
+ perturbation_minibatch,
1309
+ "cls",
1310
+ layer_to_quant,
1311
+ self.pad_token_id,
1312
+ self.forward_batch_size,
1313
+ self.token_gene_dict,
1314
+ summary_stat=None,
1315
+ silent=True,
1316
+ )
1317
+
1318
+ # Calculate cosine similarities
1319
+ cls_cos_sims = pu.quant_cos_sims(
1320
+ perturbation_cls_emb,
1321
+ original_cls_emb,
1322
+ self.cell_states_to_model,
1323
+ self.state_embs_dict,
1324
+ emb_mode="cell",
1325
+ )
1326
+
1327
+ if self.cell_states_to_model is None:
1328
+ cos_sims_dict = self.update_perturbation_dictionary(
1329
+ cos_sims_dict,
1330
+ cls_cos_sims,
1331
+ gene_list_mini,
1332
+ )
1333
+ else:
1334
+
1335
+ for state in cos_sims_dict.keys():
1336
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1337
+ cos_sims_dict[state],
1338
+ cls_cos_sims[state],
1339
+ gene_list_mini,
1340
+ )
1341
+
1342
+ del perturbation_minibatch
1343
+ del perturbation_cls_emb
1344
+ del cls_cos_sims
1345
+
1346
+ ##### CLS and Gene Embedding Mode #####
1347
+ elif self.emb_mode == "cls_and_gene":
1348
+ full_perturbation_emb = get_embs(
1349
+ model,
1350
+ perturbation_minibatch,
1351
+ "gene",
1352
+ layer_to_quant,
1353
+ self.pad_token_id,
1354
+ self.forward_batch_size,
1355
+ self.token_gene_dict,
1356
+ summary_stat=None,
1357
+ silent=True,
1358
+ )
1359
+
1360
+ # need to remove overexpressed gene and cls/eos to quantify cosine shifts
1361
+ if self.perturb_type == "overexpress":
1362
+ perturbation_emb = full_perturbation_emb[:, 1+num_inds_perturbed:-1, :].clone().detach()
1363
+ elif self.perturb_type == "delete":
1364
+ perturbation_emb = full_perturbation_emb[:, 1:-1, :].clone().detach()
1365
+
1366
+ original_emb_minibatch = pu.make_comparison_batch(
1367
+ full_original_emb, indices_to_perturb_mini, perturb_group=False
1368
+ )
1369
+
1370
+ original_emb_minibatch = original_emb_minibatch[:, 1:-1, :].clone().detach()
1371
+ gene_cos_sims = pu.quant_cos_sims(
1372
+ perturbation_emb,
1373
+ original_emb_minibatch,
1374
+ self.cell_states_to_model,
1375
+ self.state_embs_dict,
1376
+ emb_mode="gene",
1377
+ )
1378
+
1379
+ for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1380
+ for gene_j, affected_gene in enumerate(
1381
+ perturbed_gene_dict[perturbed_gene]
1382
+ ):
1383
+ try:
1384
+ stored_gene_embs_dict[
1385
+ (perturbed_gene, affected_gene)
1386
+ ].append(gene_cos_sims[perturbation_i, gene_j].item())
1387
+ except KeyError:
1388
+ stored_gene_embs_dict[
1389
+ (perturbed_gene, affected_gene)
1390
+ ] = gene_cos_sims[perturbation_i, gene_j].item()
1391
+
1392
+ # get cls emb
1393
+ perturbation_cls_emb = full_perturbation_emb[:,0,:].clone().detach()
1394
+
1395
+ cls_cos_sims = pu.quant_cos_sims(
1396
+ perturbation_cls_emb,
1397
+ original_cls_emb,
1398
+ self.cell_states_to_model,
1399
+ self.state_embs_dict,
1400
+ emb_mode="cell",
1401
+ )
1402
+
1403
+ if self.cell_states_to_model is None:
1404
+ cos_sims_dict = self.update_perturbation_dictionary(
1405
+ cos_sims_dict,
1406
+ cls_cos_sims,
1407
+ gene_list_mini,
1408
+ )
1409
+ else:
1410
+ for state in cos_sims_dict.keys():
1411
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1412
+ cos_sims_dict[state],
1413
+ cls_cos_sims[state],
1414
+ gene_list_mini,
1415
+ )
1416
+
1417
+ del perturbation_minibatch
1418
+ del original_emb_minibatch
1419
+ del full_perturbation_emb
1420
+ del perturbation_emb
1421
+ del perturbation_cls_emb
1422
+ del cls_cos_sims
1423
+ del gene_cos_sims
1424
+
1425
+ # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1426
+ if i % max(1,self.clear_mem_ncells/10) == 0:
1427
+ pu.write_perturbation_dictionary(
1428
+ cos_sims_dict,
1429
+ f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1430
+ )
1431
+ if self.emb_mode == "cls_and_gene":
1432
+ pu.write_perturbation_dictionary(
1433
+ stored_gene_embs_dict,
1434
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1435
+ )
1436
+
1437
+ # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1438
+ if i % self.clear_mem_ncells == 0:
1439
+ pickle_batch += 1
1440
+ if self.cell_states_to_model is None:
1441
+ cos_sims_dict = defaultdict(list)
1442
+ else:
1443
+ cos_sims_dict = {
1444
+ state: defaultdict(list)
1445
+ for state in pu.get_possible_states(self.cell_states_to_model)
1446
+ }
1447
+
1448
+ if self.emb_mode == "cls_and_gene":
1449
+ stored_gene_embs_dict = defaultdict(list)
1450
+
1451
+ torch.cuda.empty_cache()
1452
+
1453
  pu.write_perturbation_dictionary(
1454
+ cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}"
 
1455
  )
1456
+
1457
+ if self.emb_mode == "cls_and_gene":
1458
+ pu.write_perturbation_dictionary(
1459
+ stored_gene_embs_dict,
1460
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1461
+ )
1462
+
1463
+ pickle_batch = -1
1464
+ if self.cell_states_to_model is None:
1465
+ cos_sims_dict = defaultdict(list)
1466
+ else:
1467
+ cos_sims_dict = {
1468
+ state: defaultdict(list)
1469
+ for state in pu.get_possible_states(self.cell_states_to_model)
1470
+ }
1471
+
1472
+ if self.emb_mode == "cls_and_gene":
1473
+ stored_gene_embs_dict = defaultdict(list)
1474
+
1475
+ # clear memory between cells
1476
+ del perturbation_batch
1477
+ del original_cls_emb
1478
+ if self.emb_mode == "cls_and_gene":
1479
+ del full_original_emb
1480
+ torch.cuda.empty_cache()
1481
 
1482
+
1483
  def update_perturbation_dictionary(
1484
  self,
1485
  cos_sims_dict: defaultdict,
1486
  cos_sims_data: torch.Tensor,
 
 
1487
  gene_list=None,
1488
  ):
1489
  if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
1490
  logger.error(
1491
  f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
1492
+ {cos_sims_data.shape[0]=}.\n \
1493
+ {len(gene_list)=}."
1494
  )
1495
  raise
1496
 
 
1514
  for i, cos in enumerate(cos_sims_data.tolist()):
1515
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
1516
 
1517
+ return cos_sims_dict
geneformer/in_silico_perturber_stats.py CHANGED
@@ -114,6 +114,7 @@ def read_dictionaries(
114
  state_dict[state_value][key] += new_dict[key]
115
  except KeyError:
116
  state_dict[state_value][key] = new_dict[key]
 
117
  if not file_found:
118
  logger.error(
119
  "No raw data for processing found within provided directory. "
@@ -237,13 +238,16 @@ def find(variable, x):
237
 
238
 
239
  def isp_aggregate_gene_shifts(
240
- cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict
241
  ):
242
  cos_shift_data = dict()
243
  for i in trange(cos_sims_df.shape[0]):
244
  token = cos_sims_df["Gene"][i]
245
  for dict_i in dict_list:
246
- affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
 
 
 
247
  for key in affected_pairs:
248
  if key in cos_shift_data.keys():
249
  cos_shift_data[key] += dict_i.get(key, [])
@@ -256,11 +260,11 @@ def isp_aggregate_gene_shifts(
256
  cos_sims_full_df = pd.DataFrame()
257
  cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
258
  cos_sims_full_df["Gene_name"] = [
259
- cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"][0]
260
  for k, v in cos_data_mean.items()
261
  ]
262
  cos_sims_full_df["Ensembl_ID"] = [
263
- cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"][0]
264
  for k, v in cos_data_mean.items()
265
  ]
266
 
@@ -690,7 +694,7 @@ class InSilicoPerturberStats:
690
  | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
691
  | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
692
  combos : {0,1,2}
693
- | Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
694
  anchor_gene : None, str
695
  | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
696
  | For example, if combos=1 and anchor_gene="ENSG00000136574":
@@ -1014,7 +1018,7 @@ class InSilicoPerturberStats:
1014
  },
1015
  index=[i for i in range(len(gene_list))],
1016
  )
1017
-
1018
  if self.mode == "goal_state_shift":
1019
  cos_sims_df = isp_stats_to_goal_state(
1020
  cos_sims_df_initial,
@@ -1045,11 +1049,23 @@ class InSilicoPerturberStats:
1045
  cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list, self.genes_perturbed)
1046
 
1047
  elif self.mode == "aggregate_gene_shifts":
 
 
 
 
 
 
 
 
 
 
 
1048
  cos_sims_df = isp_aggregate_gene_shifts(
1049
  cos_sims_df_initial,
1050
  dict_list,
1051
  self.gene_token_id_dict,
1052
  self.gene_id_name_dict,
 
1053
  )
1054
 
1055
  # save perturbation stats to output_path
 
114
  state_dict[state_value][key] += new_dict[key]
115
  except KeyError:
116
  state_dict[state_value][key] = new_dict[key]
117
+
118
  if not file_found:
119
  logger.error(
120
  "No raw data for processing found within provided directory. "
 
238
 
239
 
240
  def isp_aggregate_gene_shifts(
241
+ cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict, token_dtype
242
  ):
243
  cos_shift_data = dict()
244
  for i in trange(cos_sims_df.shape[0]):
245
  token = cos_sims_df["Gene"][i]
246
  for dict_i in dict_list:
247
+ if token_dtype == "nontuple":
248
+ affected_pairs = [k for k, v in dict_i.items() if k[0] == token]
249
+ else:
250
+ affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
251
  for key in affected_pairs:
252
  if key in cos_shift_data.keys():
253
  cos_shift_data[key] += dict_i.get(key, [])
 
260
  cos_sims_full_df = pd.DataFrame()
261
  cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
262
  cos_sims_full_df["Gene_name"] = [
263
+ cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"].item()
264
  for k, v in cos_data_mean.items()
265
  ]
266
  cos_sims_full_df["Ensembl_ID"] = [
267
+ cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"].item()
268
  for k, v in cos_data_mean.items()
269
  ]
270
 
 
694
  | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
695
  | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
696
  combos : {0,1,2}
697
+ | Whether genex perturbed in isp experiment were perturbed individually (0), in pairs (1), or in triplets (2).
698
  anchor_gene : None, str
699
  | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
700
  | For example, if combos=1 and anchor_gene="ENSG00000136574":
 
1018
  },
1019
  index=[i for i in range(len(gene_list))],
1020
  )
1021
+
1022
  if self.mode == "goal_state_shift":
1023
  cos_sims_df = isp_stats_to_goal_state(
1024
  cos_sims_df_initial,
 
1049
  cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list, self.genes_perturbed)
1050
 
1051
  elif self.mode == "aggregate_gene_shifts":
1052
+ if (self.genes_perturbed == "all") and (self.combos == 0):
1053
+ tuple_types = [True if isinstance(genes, tuple) else False for genes in gene_list]
1054
+ if all(tuple_types):
1055
+ token_dtype = "tuple"
1056
+ elif not any(tuple_types):
1057
+ token_dtype = "nontuple"
1058
+ else:
1059
+ token_dtype = "mix"
1060
+ else:
1061
+ token_dtype = "mix"
1062
+
1063
  cos_sims_df = isp_aggregate_gene_shifts(
1064
  cos_sims_df_initial,
1065
  dict_list,
1066
  self.gene_token_id_dict,
1067
  self.gene_id_name_dict,
1068
+ token_dtype
1069
  )
1070
 
1071
  # save perturbation stats to output_path
geneformer/mtl/__init__.py ADDED
File without changes
geneformer/mtl/collators.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #imports
2
+ import torch
3
+
4
+ from ..collator_for_classification import DataCollatorForGeneClassification
5
+
6
+ """
7
+ Geneformer collator for multi-task cell classification.
8
+ """
9
+
10
+ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
11
+ class_type = "cell"
12
+
13
+ def __init__(self, *args, **kwargs) -> None:
14
+ super().__init__(*args, **kwargs)
15
+
16
+ def _prepare_batch(self, features):
17
+ # Process inputs as usual
18
+ batch = self.tokenizer.pad(
19
+ features,
20
+ class_type=self.class_type,
21
+ padding=self.padding,
22
+ max_length=self.max_length,
23
+ pad_to_multiple_of=self.pad_to_multiple_of,
24
+ return_tensors="pt",
25
+ )
26
+
27
+ # Check if labels are present
28
+ if "label" in features[0]:
29
+ # Initialize labels dictionary for all tasks
30
+ labels = {task: [] for task in features[0]["label"].keys()}
31
+
32
+ # Populate labels for each task
33
+ for feature in features:
34
+ for task, label in feature["label"].items():
35
+ labels[task].append(label)
36
+
37
+ # Convert label lists to tensors, handling dictionaries appropriately
38
+ for task in labels:
39
+ if isinstance(labels[task][0], (list, torch.Tensor)):
40
+ dtype = torch.long
41
+ labels[task] = torch.tensor(labels[task], dtype=dtype)
42
+ elif isinstance(labels[task][0], dict):
43
+ # Handle dict specifically if needed
44
+ pass # Resolve nested data structure
45
+
46
+ # Update the batch to include task-specific labels
47
+ batch["labels"] = labels
48
+ else:
49
+ # If no labels are present, create empty labels for all tasks
50
+ batch["labels"] = {task: torch.tensor([], dtype=torch.long) for task in features[0]["input_ids"].keys()}
51
+
52
+ return batch
53
+
54
+ def __call__(self, features):
55
+ batch = self._prepare_batch(features)
56
+
57
+ for k, v in batch.items():
58
+ if torch.is_tensor(v):
59
+ batch[k] = v.clone().detach()
60
+ elif isinstance(v, dict):
61
+ # Assuming nested structure needs conversion
62
+ batch[k] = {task: torch.tensor(labels, dtype=torch.int64) for task, labels in v.items()}
63
+ else:
64
+ batch[k] = torch.tensor(v, dtype=torch.int64)
65
+
66
+ return batch
geneformer/mtl/data.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .imports import *
2
+ import os
3
+ from .collators import DataCollatorForMultitaskCellClassification
4
+
5
+ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
6
+ try:
7
+ dataset = load_from_disk(dataset_path)
8
+
9
+ task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
10
+ task_to_column = dict(zip(task_names, config["task_columns"]))
11
+ config["task_names"] = task_names
12
+
13
+ if not is_test:
14
+ available_columns = set(dataset.column_names)
15
+ for column in task_to_column.values():
16
+ if column not in available_columns:
17
+ raise KeyError(f"Column {column} not found in the dataset. Available columns: {list(available_columns)}")
18
+
19
+ label_mappings = {}
20
+ task_label_mappings = {}
21
+ cell_id_mapping = {}
22
+ num_labels_list = []
23
+
24
+ # Load or create task label mappings
25
+ if not is_test:
26
+ for task, column in task_to_column.items():
27
+ unique_values = sorted(set(dataset[column])) # Ensure consistency
28
+ label_mappings[column] = {label: idx for idx, label in enumerate(unique_values)}
29
+ task_label_mappings[task] = label_mappings[column]
30
+ num_labels_list.append(len(unique_values))
31
+
32
+ # Print the mappings for each task with dataset type prefix
33
+ for task, mapping in task_label_mappings.items():
34
+ print(f"{dataset_type.capitalize()} mapping for {task}: {mapping}") # sanity check, for train/validation splits
35
+
36
+ # Save the task label mappings as a pickle file
37
+ with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
38
+ pickle.dump(task_label_mappings, f)
39
+ else:
40
+ # Load task label mappings from pickle file for test data
41
+ with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
42
+ task_label_mappings = pickle.load(f)
43
+
44
+ # Infer num_labels_list from task_label_mappings
45
+ for task, mapping in task_label_mappings.items():
46
+ num_labels_list.append(len(mapping))
47
+
48
+ # Store unique cell IDs in a separate dictionary
49
+ for idx, record in enumerate(dataset):
50
+ cell_id = record.get('unique_cell_id', idx)
51
+ cell_id_mapping[idx] = cell_id
52
+
53
+ # Transform records to the desired format
54
+ transformed_dataset = []
55
+ for idx, record in enumerate(dataset):
56
+ transformed_record = {}
57
+ transformed_record['input_ids'] = torch.tensor(record['input_ids'], dtype=torch.long)
58
+
59
+ # Use index-based cell ID for internal tracking
60
+ transformed_record['cell_id'] = idx
61
+
62
+ if not is_test:
63
+ # Prepare labels
64
+ label_dict = {}
65
+ for task, column in task_to_column.items():
66
+ label_value = record[column]
67
+ label_index = task_label_mappings[task][label_value]
68
+ label_dict[task] = label_index
69
+ transformed_record['label'] = label_dict
70
+ else:
71
+ # Create dummy labels for test data
72
+ label_dict = {task: -1 for task in config["task_names"]}
73
+ transformed_record['label'] = label_dict
74
+
75
+ transformed_dataset.append(transformed_record)
76
+
77
+ return transformed_dataset, cell_id_mapping, num_labels_list
78
+ except KeyError as e:
79
+ print(f"Missing configuration or dataset key: {e}")
80
+ except Exception as e:
81
+ print(f"An error occurred while loading or preprocessing data: {e}")
82
+ return None, None, None
83
+
84
+ def preload_and_process_data(config):
85
+ # Load and preprocess data once
86
+ train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
87
+ val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
88
+ return train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list
89
+
90
+ def get_data_loader(preprocessed_dataset, batch_size):
91
+ nproc = os.cpu_count() ### I/O operations
92
+
93
+ data_collator = DataCollatorForMultitaskCellClassification()
94
+
95
+ loader = DataLoader(preprocessed_dataset, batch_size=batch_size, shuffle=True,
96
+ collate_fn=data_collator, num_workers=nproc, pin_memory=True)
97
+ return loader
98
+ def preload_data(config):
99
+ # Preprocessing the data before the Optuna trials start
100
+ train_loader = get_data_loader("train", config)
101
+ val_loader = get_data_loader("val", config)
102
+ return train_loader, val_loader
103
+
104
+ def load_and_preprocess_test_data(config):
105
+ """
106
+ Load and preprocess test data, treating it as unlabeled.
107
+ """
108
+ return load_and_preprocess_data(config["test_path"], config, is_test=True)
109
+
110
+ def prepare_test_loader(config):
111
+ """
112
+ Prepare DataLoader for the test dataset.
113
+ """
114
+ test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
115
+ test_loader = get_data_loader(test_dataset, config['batch_size'])
116
+ return test_loader, cell_id_mapping, num_labels_list
geneformer/mtl/eval_utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .imports import *
2
+ import pandas as pd
3
+ from .data import prepare_test_loader
4
+ from .model import GeneformerMultiTask
5
+
6
+ def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
7
+ task_pred_labels = {task_name: [] for task_name in config["task_names"]}
8
+ task_pred_probs = {task_name: [] for task_name in config["task_names"]}
9
+ cell_ids = []
10
+
11
+ # Load task label mappings from pickle file
12
+ with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
13
+ task_label_mappings = pickle.load(f)
14
+
15
+ model.eval()
16
+ with torch.no_grad():
17
+ for batch in test_loader:
18
+ input_ids = batch['input_ids'].to(device)
19
+ attention_mask = batch['attention_mask'].to(device)
20
+ _, logits, _ = model(input_ids, attention_mask)
21
+ for sample_idx in range(len(batch['input_ids'])):
22
+ cell_id = cell_id_mapping[batch['cell_id'][sample_idx].item()]
23
+ cell_ids.append(cell_id)
24
+ for i, task_name in enumerate(config["task_names"]):
25
+ pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
26
+ pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
27
+ task_pred_labels[task_name].append(pred_label)
28
+ task_pred_probs[task_name].append(pred_prob)
29
+
30
+ # Save test predictions with cell IDs and probabilities to CSV
31
+ test_results_dir = config["results_dir"]
32
+ os.makedirs(test_results_dir, exist_ok=True)
33
+ test_preds_file = os.path.join(test_results_dir, "test_preds.csv")
34
+
35
+ rows = []
36
+ for sample_idx in range(len(cell_ids)):
37
+ row = {'Cell ID': cell_ids[sample_idx]}
38
+ for task_name in config["task_names"]:
39
+ row[f'{task_name} Prediction'] = task_pred_labels[task_name][sample_idx]
40
+ row[f'{task_name} Probabilities'] = ','.join(map(str, task_pred_probs[task_name][sample_idx]))
41
+ rows.append(row)
42
+
43
+ df = pd.DataFrame(rows)
44
+ df.to_csv(test_preds_file, index=False)
45
+ print(f"Test predictions saved to {test_preds_file}")
46
+
47
+ def load_and_evaluate_test_model(config):
48
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+ test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
50
+ model_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
51
+ hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
52
+
53
+ # Load the saved best hyperparameters
54
+ with open(hyperparams_path, 'r') as f:
55
+ best_hyperparams = json.load(f)
56
+
57
+ # Extract the task weights if present, otherwise set to None
58
+ task_weights = best_hyperparams.get("task_weights", None)
59
+ normalized_task_weights = task_weights if task_weights else []
60
+
61
+ # Print the loaded hyperparameters
62
+ print("Loaded hyperparameters:")
63
+ for param, value in best_hyperparams.items():
64
+ if param == "task_weights":
65
+ print(f"normalized_task_weights: {value}")
66
+ else:
67
+ print(f"{param}: {value}")
68
+
69
+ best_model_path = os.path.join(model_directory, "pytorch_model.bin")
70
+ best_model = GeneformerMultiTask(
71
+ config["pretrained_path"],
72
+ num_labels_list,
73
+ dropout_rate=best_hyperparams["dropout_rate"],
74
+ use_task_weights=config["use_task_weights"],
75
+ task_weights=normalized_task_weights
76
+ )
77
+ best_model.load_state_dict(torch.load(best_model_path))
78
+ best_model.to(device)
79
+
80
+ evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
81
+ print("Evaluation completed.")
geneformer/mtl/imports.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pickle
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import DataLoader
8
+
9
+ from itertools import chain
10
+ import warnings
11
+ from enum import Enum
12
+ from typing import Dict, List, Optional, Union
13
+ import sys
14
+ import os
15
+ import json
16
+ import gc
17
+ import functools
18
+ import pandas as pd
19
+
20
+ from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, roc_curve
21
+ from sklearn.preprocessing import LabelEncoder
22
+ from sklearn.model_selection import train_test_split
23
+
24
+ import optuna
25
+
26
+ from transformers import (
27
+ BertConfig,
28
+ BertModel,
29
+ AdamW,
30
+ get_linear_schedule_with_warmup,
31
+ get_cosine_schedule_with_warmup,
32
+ DataCollatorForTokenClassification,
33
+ SpecialTokensMixin,
34
+ BatchEncoding,
35
+ get_scheduler,
36
+ )
37
+ from transformers.utils import logging, to_py_obj
38
+
39
+ from datasets import load_from_disk
40
+
41
+ # local modules
42
+ from .data import preload_and_process_data, get_data_loader
43
+ from .model import GeneformerMultiTask
44
+ from .utils import save_model
45
+ from .optuna_utils import create_optuna_study
46
+ from .collators import DataCollatorForMultitaskCellClassification
geneformer/mtl/model.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertModel, BertConfig
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ class AttentionPool(nn.Module):
6
+ """Attention-based pooling layer."""
7
+ def __init__(self, hidden_size):
8
+ super(AttentionPool, self).__init__()
9
+ self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
10
+ nn.init.xavier_uniform_(self.attention_weights) # https://pytorch.org/docs/stable/nn.init.html
11
+
12
+ def forward(self, hidden_states):
13
+ attention_scores = torch.matmul(hidden_states, self.attention_weights)
14
+ attention_scores = torch.softmax(attention_scores, dim=1)
15
+ pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
16
+ return pooled_output
17
+
18
+ class GeneformerMultiTask(nn.Module):
19
+ def __init__(self, pretrained_path, num_labels_list, dropout_rate=0.1, use_task_weights=False, task_weights=None, max_layers_to_freeze=0, use_attention_pooling=False):
20
+ super(GeneformerMultiTask, self).__init__()
21
+ self.config = BertConfig.from_pretrained(pretrained_path)
22
+ self.bert = BertModel(self.config)
23
+ self.num_labels_list = num_labels_list
24
+ self.use_task_weights = use_task_weights
25
+ self.dropout = nn.Dropout(dropout_rate)
26
+ self.use_attention_pooling = use_attention_pooling
27
+
28
+ if use_task_weights and (task_weights is None or len(task_weights) != len(num_labels_list)):
29
+ raise ValueError("Task weights must be defined and match the number of tasks when 'use_task_weights' is True.")
30
+ self.task_weights = task_weights if use_task_weights else [1.0] * len(num_labels_list)
31
+
32
+ # Freeze the specified initial layers
33
+ for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
34
+ for param in layer.parameters():
35
+ param.requires_grad = False
36
+
37
+ self.attention_pool = AttentionPool(self.config.hidden_size) if use_attention_pooling else None
38
+
39
+ self.classification_heads = nn.ModuleList([
40
+ nn.Linear(self.config.hidden_size, num_labels) for num_labels in num_labels_list
41
+ ])
42
+ # initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
43
+ for head in self.classification_heads:
44
+ nn.init.xavier_uniform_(head.weight)
45
+ nn.init.zeros_(head.bias)
46
+
47
+ def forward(self, input_ids, attention_mask, labels=None):
48
+ try:
49
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
50
+ except Exception as e:
51
+ raise RuntimeError(f"Error during BERT forward pass: {e}")
52
+
53
+ sequence_output = outputs.last_hidden_state
54
+
55
+ try:
56
+ pooled_output = self.attention_pool(sequence_output) if self.use_attention_pooling else sequence_output[:, 0, :]
57
+ pooled_output = self.dropout(pooled_output)
58
+ except Exception as e:
59
+ raise RuntimeError(f"Error during pooling and dropout: {e}")
60
+
61
+ total_loss = 0
62
+ logits = []
63
+ losses = []
64
+
65
+ for task_id, (head, num_labels) in enumerate(zip(self.classification_heads, self.num_labels_list)):
66
+ try:
67
+ task_logits = head(pooled_output)
68
+ except Exception as e:
69
+ raise RuntimeError(f"Error during forward pass of classification head {task_id}: {e}")
70
+
71
+ logits.append(task_logits)
72
+
73
+ if labels is not None:
74
+ try:
75
+ loss_fct = nn.CrossEntropyLoss()
76
+ task_loss = loss_fct(task_logits.view(-1, num_labels), labels[task_id].view(-1))
77
+ if self.use_task_weights:
78
+ task_loss *= self.task_weights[task_id]
79
+ total_loss += task_loss
80
+ losses.append(task_loss.item())
81
+ except Exception as e:
82
+ raise RuntimeError(f"Error during loss computation for task {task_id}: {e}")
83
+
84
+ return total_loss, logits, losses if labels is not None else logits
geneformer/mtl/optuna_utils.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import optuna
2
+ from optuna.integration import TensorBoardCallback
3
+
4
+ def save_trial_callback(study, trial, trials_result_path):
5
+ with open(trials_result_path, "a") as f:
6
+ f.write(f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n")
7
+
8
+ def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
9
+ study = optuna.create_study(direction="maximize")
10
+
11
+ # init TensorBoard callback
12
+ tensorboard_callback = TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro")
13
+
14
+ # callback and TensorBoard callback
15
+ callbacks = [
16
+ lambda study, trial: save_trial_callback(study, trial, trials_result_path),
17
+ tensorboard_callback
18
+ ]
19
+
20
+ study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
21
+ return study
geneformer/mtl/train.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .imports import *
2
+ from .data import preload_and_process_data, get_data_loader
3
+ from .model import GeneformerMultiTask
4
+ from .utils import calculate_task_specific_metrics
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ import pandas as pd
7
+ import os
8
+ from tqdm import tqdm
9
+ import random
10
+ import numpy as np
11
+ import torch
12
+
13
+
14
+ def set_seed(seed):
15
+ random.seed(seed)
16
+ np.random.seed(seed)
17
+ torch.manual_seed(seed)
18
+ torch.cuda.manual_seed_all(seed)
19
+ torch.backends.cudnn.deterministic = True
20
+ torch.backends.cudnn.benchmark = False
21
+
22
+ def initialize_wandb(config):
23
+ if config.get("use_wandb", False):
24
+ import wandb
25
+ wandb.init(project=config["wandb_project"], config=config)
26
+ print("Weights & Biases (wandb) initialized and will be used for logging.")
27
+ else:
28
+ print("Weights & Biases (wandb) is not enabled. Logging will use other methods.")
29
+
30
+ def create_model(config, num_labels_list, device):
31
+ model = GeneformerMultiTask(
32
+ config["pretrained_path"],
33
+ num_labels_list,
34
+ dropout_rate=config["dropout_rate"],
35
+ use_task_weights=config["use_task_weights"],
36
+ task_weights=config["task_weights"],
37
+ max_layers_to_freeze=config["max_layers_to_freeze"],
38
+ use_attention_pooling=config["use_attention_pooling"]
39
+ )
40
+ if config["use_data_parallel"]:
41
+ model = nn.DataParallel(model)
42
+ return model.to(device)
43
+
44
+ def setup_optimizer_and_scheduler(model, config, total_steps):
45
+ optimizer = AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
46
+ warmup_steps = int(config["warmup_ratio"] * total_steps)
47
+
48
+ if config["lr_scheduler_type"] == "linear":
49
+ scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
50
+ elif config["lr_scheduler_type"] == "cosine":
51
+ scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, num_cycles=0.5)
52
+
53
+ return optimizer, scheduler
54
+
55
+ def train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch):
56
+ model.train()
57
+ progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
58
+ for batch_idx, batch in enumerate(progress_bar):
59
+ optimizer.zero_grad()
60
+ input_ids = batch['input_ids'].to(device)
61
+ attention_mask = batch['attention_mask'].to(device)
62
+ labels = [batch['labels'][task_name].to(device) for task_name in config["task_names"]]
63
+
64
+ loss, _, _ = model(input_ids, attention_mask, labels)
65
+ loss.backward()
66
+
67
+ if config["gradient_clipping"]:
68
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
69
+
70
+ optimizer.step()
71
+ scheduler.step()
72
+
73
+ writer.add_scalar('Training Loss', loss.item(), epoch * len(train_loader) + batch_idx)
74
+ if config.get("use_wandb", False):
75
+ wandb.log({'Training Loss': loss.item()})
76
+
77
+ # Update progress bar
78
+ progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})
79
+
80
+ return loss.item() # Return the last batch loss
81
+
82
+ def validate_model(model, val_loader, device, config):
83
+ model.eval()
84
+ val_loss = 0.0
85
+ task_true_labels = {task_name: [] for task_name in config["task_names"]}
86
+ task_pred_labels = {task_name: [] for task_name in config["task_names"]}
87
+ task_pred_probs = {task_name: [] for task_name in config["task_names"]}
88
+
89
+ with torch.no_grad():
90
+ for batch in val_loader:
91
+ input_ids = batch['input_ids'].to(device)
92
+ attention_mask = batch['attention_mask'].to(device)
93
+ labels = [batch['labels'][task_name].to(device) for task_name in config["task_names"]]
94
+ loss, logits, _ = model(input_ids, attention_mask, labels)
95
+ val_loss += loss.item()
96
+
97
+ for sample_idx in range(len(batch['input_ids'])):
98
+ for i, task_name in enumerate(config["task_names"]):
99
+ true_label = batch['labels'][task_name][sample_idx].item()
100
+ pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
101
+ pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
102
+ task_true_labels[task_name].append(true_label)
103
+ task_pred_labels[task_name].append(pred_label)
104
+ task_pred_probs[task_name].append(pred_prob)
105
+
106
+ val_loss /= len(val_loader)
107
+ return val_loss, task_true_labels, task_pred_labels, task_pred_probs
108
+
109
+ def log_metrics(task_metrics, val_loss, config, writer, epochs):
110
+ for task_name, metrics in task_metrics.items():
111
+ print(f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}")
112
+ if config.get("use_wandb", False):
113
+ import wandb
114
+ wandb.log({
115
+ f'{task_name} Validation F1 Macro': metrics['f1'],
116
+ f'{task_name} Validation Accuracy': metrics['accuracy']
117
+ })
118
+
119
+ writer.add_scalar('Validation Loss', val_loss, epochs)
120
+ for task_name, metrics in task_metrics.items():
121
+ writer.add_scalar(f'{task_name} - Validation F1 Macro', metrics['f1'], epochs)
122
+ writer.add_scalar(f'{task_name} - Validation Accuracy', metrics['accuracy'], epochs)
123
+
124
+ def save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config, trial_number=None):
125
+ if trial_number is not None:
126
+ trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
127
+ os.makedirs(trial_results_dir, exist_ok=True)
128
+ val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
129
+ else:
130
+ val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
131
+
132
+ rows = []
133
+ for sample_idx in range(len(val_cell_id_mapping)):
134
+ row = {'Cell ID': val_cell_id_mapping[sample_idx]}
135
+ for task_name in config["task_names"]:
136
+ row[f'{task_name} True'] = task_true_labels[task_name][sample_idx]
137
+ row[f'{task_name} Pred'] = task_pred_labels[task_name][sample_idx]
138
+ row[f'{task_name} Probabilities'] = ','.join(map(str, task_pred_probs[task_name][sample_idx]))
139
+ rows.append(row)
140
+
141
+ df = pd.DataFrame(rows)
142
+ df.to_csv(val_preds_file, index=False)
143
+ print(f"Validation predictions saved to {val_preds_file}")
144
+
145
+
146
+ def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
147
+ set_seed(config["seed"])
148
+ initialize_wandb(config)
149
+
150
+ model = create_model(config, num_labels_list, device)
151
+ total_steps = len(train_loader) * config["epochs"]
152
+ optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
153
+
154
+ log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
155
+ writer = SummaryWriter(log_dir=log_dir)
156
+
157
+ epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
158
+ for epoch in epoch_progress:
159
+ last_loss = train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch)
160
+ epoch_progress.set_postfix({'last_loss': f"{last_loss:.4f}"})
161
+
162
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(model, val_loader, device, config)
163
+ task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
164
+
165
+ log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
166
+ writer.close()
167
+
168
+ save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config)
169
+
170
+ if config.get("use_wandb", False):
171
+ import wandb
172
+ wandb.finish()
173
+
174
+ print(f"\nFinal Validation Loss: {val_loss:.4f}")
175
+ return val_loss, model # Return both the validation loss and the trained model
176
+
177
+ def objective(trial, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, config, device):
178
+ set_seed(config["seed"]) # Set the seed before each trial
179
+ initialize_wandb(config)
180
+
181
+ # Hyperparameters
182
+ config["learning_rate"] = trial.suggest_float("learning_rate", config["hyperparameters"]["learning_rate"]["low"], config["hyperparameters"]["learning_rate"]["high"], log=config["hyperparameters"]["learning_rate"]["log"])
183
+ config["warmup_ratio"] = trial.suggest_float("warmup_ratio", config["hyperparameters"]["warmup_ratio"]["low"], config["hyperparameters"]["warmup_ratio"]["high"])
184
+ config["weight_decay"] = trial.suggest_float("weight_decay", config["hyperparameters"]["weight_decay"]["low"], config["hyperparameters"]["weight_decay"]["high"])
185
+ config["dropout_rate"] = trial.suggest_float("dropout_rate", config["hyperparameters"]["dropout_rate"]["low"], config["hyperparameters"]["dropout_rate"]["high"])
186
+ config["lr_scheduler_type"] = trial.suggest_categorical("lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"])
187
+ config["use_attention_pooling"] = trial.suggest_categorical("use_attention_pooling", [True, False])
188
+
189
+ if config["use_task_weights"]:
190
+ config["task_weights"] = [trial.suggest_float(f"task_weight_{i}", config["hyperparameters"]["task_weights"]["low"], config["hyperparameters"]["task_weights"]["high"]) for i in range(len(num_labels_list))]
191
+ weight_sum = sum(config["task_weights"])
192
+ config["task_weights"] = [weight / weight_sum for weight in config["task_weights"]]
193
+ else:
194
+ config["task_weights"] = None
195
+
196
+ # Fix for max_layers_to_freeze
197
+ if isinstance(config["max_layers_to_freeze"], dict):
198
+ config["max_layers_to_freeze"] = trial.suggest_int("max_layers_to_freeze", config["max_layers_to_freeze"]["min"], config["max_layers_to_freeze"]["max"])
199
+ elif isinstance(config["max_layers_to_freeze"], int):
200
+ # If it's already an int, we don't need to suggest it
201
+ pass
202
+ else:
203
+ raise ValueError("Invalid type for max_layers_to_freeze. Expected dict or int.")
204
+
205
+ model = create_model(config, num_labels_list, device)
206
+ total_steps = len(train_loader) * config["epochs"]
207
+ optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
208
+
209
+ log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
210
+ writer = SummaryWriter(log_dir=log_dir)
211
+
212
+ for epoch in range(config["epochs"]):
213
+ train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch)
214
+
215
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(model, val_loader, device, config)
216
+ task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
217
+
218
+ log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
219
+ writer.close()
220
+
221
+ save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config, trial.number)
222
+
223
+ trial.set_user_attr("model_state_dict", model.state_dict())
224
+ trial.set_user_attr("task_weights", config["task_weights"])
225
+
226
+ trial.report(val_loss, config["epochs"])
227
+
228
+ if trial.should_prune():
229
+ raise optuna.TrialPruned()
230
+
231
+ if config.get("use_wandb", False):
232
+ import wandb
233
+ wandb.log({
234
+ "trial_number": trial.number,
235
+ "val_loss": val_loss,
236
+ **{f"{task_name}_f1": metrics['f1'] for task_name, metrics in task_metrics.items()},
237
+ **{f"{task_name}_accuracy": metrics['accuracy'] for task_name, metrics in task_metrics.items()},
238
+ **{k: v for k, v in config.items() if k in ["learning_rate", "warmup_ratio", "weight_decay", "dropout_rate", "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze"]}
239
+ })
240
+ wandb.finish()
241
+
242
+ return val_loss
geneformer/mtl/train_utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .imports import *
2
+ from .data import preload_and_process_data, get_data_loader
3
+ from .train import objective, train_model
4
+ from .model import GeneformerMultiTask
5
+ from .utils import save_model
6
+ import random
7
+
8
+ def set_seed(seed):
9
+ random.seed(seed)
10
+ np.random.seed(seed)
11
+ torch.manual_seed(seed)
12
+ torch.cuda.manual_seed_all(seed)
13
+ torch.backends.cudnn.deterministic = True
14
+ torch.backends.cudnn.benchmark = False
15
+
16
+ def run_manual_tuning(config):
17
+ # Set seed for reproducibility
18
+ set_seed(config["seed"])
19
+
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list = preload_and_process_data(config)
22
+ train_loader = get_data_loader(train_dataset, config['batch_size'])
23
+ val_loader = get_data_loader(val_dataset, config['batch_size'])
24
+
25
+ # Print the manual hyperparameters being used
26
+ print("\nManual hyperparameters being used:")
27
+ for key, value in config["manual_hyperparameters"].items():
28
+ print(f"{key}: {value}")
29
+ print() # Add an empty line for better readability
30
+
31
+ # Use the manual hyperparameters
32
+ for key, value in config["manual_hyperparameters"].items():
33
+ config[key] = value
34
+
35
+ # Train the model
36
+ val_loss, trained_model = train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list)
37
+
38
+ print(f"\nValidation loss with manual hyperparameters: {val_loss}")
39
+
40
+ # Save the trained model
41
+ model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
42
+ save_model(trained_model, model_save_directory)
43
+
44
+ # Save the hyperparameters
45
+ hyperparams_to_save = {
46
+ **config["manual_hyperparameters"],
47
+ "dropout_rate": config["dropout_rate"],
48
+ "use_task_weights": config["use_task_weights"],
49
+ "task_weights": config["task_weights"],
50
+ "max_layers_to_freeze": config["max_layers_to_freeze"],
51
+ "use_attention_pooling": config["use_attention_pooling"]
52
+ }
53
+ hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
54
+ with open(hyperparams_path, 'w') as f:
55
+ json.dump(hyperparams_to_save, f)
56
+ print(f"Manual hyperparameters saved to {hyperparams_path}")
57
+
58
+ return val_loss
59
+
60
+ def run_optuna_study(config):
61
+ # Set seed for reproducibility
62
+ set_seed(config["seed"])
63
+
64
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
+ train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list = preload_and_process_data(config)
66
+ train_loader = get_data_loader(train_dataset, config['batch_size'])
67
+ val_loader = get_data_loader(val_dataset, config['batch_size'])
68
+
69
+ if config["use_manual_hyperparameters"]:
70
+ train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list)
71
+ else:
72
+ objective_with_config_and_data = functools.partial(
73
+ objective,
74
+ train_loader=train_loader,
75
+ val_loader=val_loader,
76
+ train_cell_id_mapping=train_cell_id_mapping,
77
+ val_cell_id_mapping=val_cell_id_mapping,
78
+ num_labels_list=num_labels_list,
79
+ config=config,
80
+ device=device
81
+ )
82
+
83
+ study = optuna.create_study(
84
+ direction='minimize', # Minimize validation loss
85
+ study_name=config["study_name"],
86
+ #storage=config["storage"],
87
+ load_if_exists=True
88
+ )
89
+
90
+ study.optimize(
91
+ objective_with_config_and_data,
92
+ n_trials=config["n_trials"]
93
+ )
94
+
95
+ # After finding the best trial
96
+ best_params = study.best_trial.params
97
+ best_task_weights = study.best_trial.user_attrs["task_weights"]
98
+ print("Saving the best model and its hyperparameters...")
99
+
100
+ # Saving model as before
101
+ best_model = GeneformerMultiTask(
102
+ config["pretrained_path"],
103
+ num_labels_list,
104
+ dropout_rate=best_params["dropout_rate"],
105
+ use_task_weights=config["use_task_weights"],
106
+ task_weights=best_task_weights
107
+ )
108
+
109
+ # Get the best model state dictionary
110
+ best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
111
+
112
+ # Remove the "module." prefix from the state dictionary keys if present
113
+ best_model_state_dict = {k.replace("module.", ""): v for k, v in best_model_state_dict.items()}
114
+
115
+ # Load the modified state dictionary into the model, skipping unexpected keys
116
+ best_model.load_state_dict(best_model_state_dict, strict=False)
117
+
118
+ model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
119
+ save_model(best_model, model_save_directory)
120
+
121
+ # Additionally, save the best hyperparameters and task weights
122
+ hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
123
+
124
+ with open(hyperparams_path, 'w') as f:
125
+ json.dump({**best_params, "task_weights": best_task_weights}, f)
126
+ print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
geneformer/mtl/utils.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .imports import *
2
+ from sklearn.metrics import f1_score, accuracy_score
3
+ from sklearn.preprocessing import LabelEncoder
4
+ from transformers import BertModel, BertConfig, AutoConfig
5
+ import os
6
+ import shutil
7
+
8
+ def save_model(model, model_save_directory):
9
+ if not os.path.exists(model_save_directory):
10
+ os.makedirs(model_save_directory)
11
+
12
+ # Get the state dict
13
+ if isinstance(model, nn.DataParallel):
14
+ model_state_dict = model.module.state_dict() # Use model.module to access the underlying model
15
+ else:
16
+ model_state_dict = model.state_dict()
17
+
18
+ # Remove the "module." prefix from the keys if present
19
+ model_state_dict = {k.replace("module.", ""): v for k, v in model_state_dict.items()}
20
+
21
+ model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
22
+ torch.save(model_state_dict, model_save_path)
23
+
24
+ # Save the model configuration
25
+ if isinstance(model, nn.DataParallel):
26
+ model.module.config.to_json_file(os.path.join(model_save_directory, "config.json"))
27
+ else:
28
+ model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
29
+
30
+ print(f"Model and configuration saved to {model_save_directory}")
31
+
32
+ def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
33
+ task_metrics = {}
34
+ for task_name in task_true_labels.keys():
35
+ true_labels = task_true_labels[task_name]
36
+ pred_labels = task_pred_labels[task_name]
37
+ f1 = f1_score(true_labels, pred_labels, average='macro')
38
+ accuracy = accuracy_score(true_labels, pred_labels)
39
+ task_metrics[task_name] = {'f1': f1, 'accuracy': accuracy}
40
+ return task_metrics
41
+
42
+ def calculate_combined_f1(combined_labels, combined_preds):
43
+ # Initialize the LabelEncoder
44
+ le = LabelEncoder()
45
+
46
+ # Fit and transform combined labels and predictions to numerical values
47
+ le.fit(combined_labels + combined_preds)
48
+ encoded_true_labels = le.transform(combined_labels)
49
+ encoded_pred_labels = le.transform(combined_preds)
50
+
51
+ # Print out the mapping for sanity check
52
+ print("\nLabel Encoder Mapping:")
53
+ for index, class_label in enumerate(le.classes_):
54
+ print(f"'{class_label}': {index}")
55
+
56
+ # Calculate accuracy
57
+ accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
58
+
59
+ # Calculate F1 Macro score
60
+ f1 = f1_score(encoded_true_labels, encoded_pred_labels, average='macro')
61
+
62
+ return f1, accuracy
63
+
64
+ def save_model_without_heads(original_model_save_directory):
65
+ # Create a new directory for the model without heads
66
+ new_model_save_directory = original_model_save_directory + "_No_Heads"
67
+ if not os.path.exists(new_model_save_directory):
68
+ os.makedirs(new_model_save_directory)
69
+
70
+ # Load the model state dictionary
71
+ model_state_dict = torch.load(os.path.join(original_model_save_directory, "pytorch_model.bin"))
72
+
73
+ # Initialize a new BERT model without the classification heads
74
+ config = BertConfig.from_pretrained(os.path.join(original_model_save_directory, "config.json"))
75
+ model_without_heads = BertModel(config)
76
+
77
+ # Filter the state dict to exclude classification heads
78
+ model_without_heads_state_dict = {k: v for k, v in model_state_dict.items() if not k.startswith("classification_heads")}
79
+
80
+ # Load the filtered state dict into the model
81
+ model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
82
+
83
+ # Save the model without heads
84
+ model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
85
+ torch.save(model_without_heads.state_dict(), model_save_path)
86
+
87
+ # Copy the configuration file
88
+ shutil.copy(os.path.join(original_model_save_directory, "config.json"), new_model_save_directory)
89
+
90
+ print(f"Model without classification heads saved to {new_model_save_directory}")
91
+
92
+
93
+ def get_layer_freeze_range(pretrained_path):
94
+ """
95
+ Dynamically determines the number of layers to freeze based on the model depth from its configuration.
96
+ Args:
97
+ pretrained_path (str): Path to the pretrained model directory or model identifier.
98
+ Returns:
99
+ dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze.
100
+ """
101
+ if pretrained_path:
102
+ config = AutoConfig.from_pretrained(pretrained_path)
103
+ total_layers = config.num_hidden_layers
104
+ return {"min": 0, "max": total_layers - 1}
105
+ else:
106
+ return {"min": 0, "max": 0}
geneformer/mtl_classifier.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer multi-task cell classifier.
3
+
4
+ **Input data:**
5
+
6
+ | Single-cell transcriptomes as Geneformer rank value encodings with cell state labels for each task in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py). Must contain "unique_cell_id" column for logging.
7
+
8
+ **Usage:**
9
+
10
+ .. code-block :: python
11
+
12
+ >>> from geneformer import MTLClassifier
13
+ >>> mc = MTLClassifier(task_columns = ["task1", "task2"],
14
+ ... study_name = "mtl",
15
+ ... pretrained_path = "/path/pretrained/model",
16
+ ... train_path = "/path/train/set",
17
+ ... val_path = "/path/eval/set",
18
+ ... test_path = "/path/test/set",
19
+ ... model_save_path = "/results/directory/save_path",
20
+ ... trials_result_path = "/results/directory/results.txt",
21
+ ... results_dir = "/results/directory",
22
+ ... tensorboard_log_dir = "/results/tblogdir",
23
+ ... hyperparameters = hyperparameters)
24
+ >>> mc.run_optuna_study()
25
+ >>> mc.load_and_evaluate_test_model()
26
+ >>> mc.save_model_without_heads()
27
+ """
28
+
29
+ import logging
30
+ import os
31
+ from .mtl import train_utils
32
+ from .mtl import utils
33
+ from .mtl import eval_utils
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ class MTLClassifier:
39
+ valid_option_dict = {
40
+ "task_columns": {list},
41
+ "train_path": {None, str},
42
+ "val_path": {None, str},
43
+ "test_path": {None, str},
44
+ "pretrained_path": {None, str},
45
+ "model_save_path": {None, str},
46
+ "results_dir": {None, str},
47
+ "batch_size": {None, int},
48
+ "n_trials": {None, int},
49
+ "study_name": {None, str},
50
+ "max_layers_to_freeze": {None, dict},
51
+ "epochs": {None, int},
52
+ "tensorboard_log_dir": {None, str},
53
+ "use_data_parallel": {None, bool},
54
+ "use_attention_pooling": {None, bool},
55
+ "use_task_weights": {None, bool},
56
+ "hyperparameters": {None, dict},
57
+ "manual_hyperparameters": {None, dict},
58
+ "use_manual_hyperparameters": {None, bool},
59
+ "use_wandb": {None, bool},
60
+ "wandb_project": {None, str},
61
+ "gradient_clipping": {None, bool},
62
+ "max_grad_norm": {None, int, float},
63
+ "seed": {None, int},
64
+ "trials_result_path": {None, str},
65
+ }
66
+
67
+ def __init__(
68
+ self,
69
+ task_columns=None,
70
+ train_path=None,
71
+ val_path=None,
72
+ test_path=None,
73
+ pretrained_path=None,
74
+ model_save_path=None,
75
+ results_dir=None,
76
+ trials_result_path=None,
77
+ batch_size=4,
78
+ n_trials=15,
79
+ study_name="mtl",
80
+ max_layers_to_freeze=None,
81
+ epochs=1,
82
+ tensorboard_log_dir="/results/tblogdir",
83
+ use_data_parallel=False,
84
+ use_attention_pooling=True,
85
+ use_task_weights=True,
86
+ hyperparameters=None, # Default is None
87
+ manual_hyperparameters=None, # Default is None
88
+ use_manual_hyperparameters=False, # Default is False
89
+ use_wandb=False,
90
+ wandb_project=None,
91
+ gradient_clipping=False,
92
+ max_grad_norm=None,
93
+ seed=42 # Default seed value
94
+ ):
95
+
96
+ """
97
+ Initialize Geneformer multi-task classifier.
98
+ **Parameters:**
99
+ task_columns : list
100
+ | List of tasks for cell state classification
101
+ | Input data columns are labeled with corresponding task names
102
+ study_name : None, str
103
+ | Study name for labeling output files
104
+ pretrained_path : None, str
105
+ | Path to pretrained model
106
+ train_path : None, str
107
+ | Path to training dataset with task columns and "unique_cell_id" column
108
+ val_path : None, str
109
+ | Path to validation dataset with task columns and "unique_cell_id" column
110
+ test_path : None, str
111
+ | Path to test dataset with task columns and "unique_cell_id" column
112
+ model_save_path : None, str
113
+ | Path to directory to save output model (either full model or model without heads)
114
+ trials_result_path : None, str
115
+ | Path to directory to save hyperparameter tuning trial results
116
+ results_dir : None, str
117
+ | Path to directory to save results
118
+ tensorboard_log_dir : None, str
119
+ | Path to directory for Tensorboard logging results
120
+ use_data_parallel : None, bool
121
+ | Whether to use data parallelization
122
+ use_attention_pooling : None, bool
123
+ | Whether to use attention pooling
124
+ use_task_weights : None, bool
125
+ | Whether to use task weights
126
+ batch_size : None, int
127
+ | Batch size to use
128
+ n_trials : None, int
129
+ | Number of trials for hyperparameter tuning
130
+ epochs : None, int
131
+ | Number of epochs for training
132
+ max_layers_to_freeze : None, dict
133
+ | Dictionary with keys "min" and "max" indicating the min and max layers to freeze from fine-tuning (int)
134
+ | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
135
+ hyperparameters : None, dict
136
+ | Dictionary of categorical max and min for each hyperparameter for tuning
137
+ | For example:
138
+ | {"learning_rate": {"type":"float", "low":"1e-5", "high":"1e-3", "log":True}, "task_weights": {...}, ...}
139
+ manual_hyperparameters : None, dict
140
+ | Dictionary of manually set value for each hyperparameter
141
+ | For example:
142
+ | {"learning_rate": 0.001, "task_weights": [1, 1], ...}
143
+ use_manual_hyperparameters : None, bool
144
+ | Whether to use manually set hyperparameters
145
+ use_wandb : None, bool
146
+ | Whether to use Weights & Biases for logging
147
+ wandb_project : None, str
148
+ | Weights & Biases project name
149
+ gradient_clipping : None, bool
150
+ | Whether to use gradient clipping
151
+ max_grad_norm : None, int, float
152
+ | Maximum norm for gradient clipping
153
+ seed : None, int
154
+ | Random seed
155
+ """
156
+
157
+ self.task_columns = task_columns
158
+ self.train_path = train_path
159
+ self.val_path = val_path
160
+ self.test_path = test_path
161
+ self.pretrained_path = pretrained_path
162
+ self.model_save_path = model_save_path
163
+ self.results_dir = results_dir
164
+ self.trials_result_path = trials_result_path
165
+ self.batch_size = batch_size
166
+ self.n_trials = n_trials
167
+ self.study_name = study_name
168
+
169
+ if max_layers_to_freeze is None:
170
+ # Dynamically determine the range of layers to freeze
171
+ layer_freeze_range = utils.get_layer_freeze_range(pretrained_path)
172
+ self.max_layers_to_freeze = {"min": 1, "max": layer_freeze_range['max']}
173
+ else:
174
+ self.max_layers_to_freeze = max_layers_to_freeze
175
+
176
+ self.epochs = epochs
177
+ self.tensorboard_log_dir = tensorboard_log_dir
178
+ self.use_data_parallel = use_data_parallel
179
+ self.use_attention_pooling = use_attention_pooling
180
+ self.use_task_weights = use_task_weights
181
+ self.hyperparameters = hyperparameters if hyperparameters is not None else {
182
+ "learning_rate": {
183
+ "type": "float",
184
+ "low": 1e-5,
185
+ "high": 1e-3,
186
+ "log": True
187
+ },
188
+ "warmup_ratio": {
189
+ "type": "float",
190
+ "low": 0.005,
191
+ "high": 0.01
192
+ },
193
+ "weight_decay": {
194
+ "type": "float",
195
+ "low": 0.01,
196
+ "high": 0.1
197
+ },
198
+ "dropout_rate": {
199
+ "type": "float",
200
+ "low": 0.0,
201
+ "high": 0.7
202
+ },
203
+ "lr_scheduler_type": {
204
+ "type": "categorical",
205
+ "choices": ["cosine"]
206
+ },
207
+ "task_weights": {
208
+ "type": "float",
209
+ "low": 0.1,
210
+ "high": 2.0
211
+ }
212
+ }
213
+ self.manual_hyperparameters = manual_hyperparameters if manual_hyperparameters is not None else {
214
+ "learning_rate": 0.001,
215
+ "warmup_ratio": 0.01,
216
+ "weight_decay": 0.1,
217
+ "dropout_rate": 0.1,
218
+ "lr_scheduler_type": "cosine",
219
+ "use_attention_pooling": False,
220
+ "task_weights": [1, 1],
221
+ "max_layers_to_freeze": 2
222
+ }
223
+ self.use_manual_hyperparameters = use_manual_hyperparameters
224
+ self.use_wandb = use_wandb
225
+ self.wandb_project = wandb_project
226
+ self.gradient_clipping = gradient_clipping
227
+ self.max_grad_norm = max_grad_norm
228
+ self.seed = seed
229
+
230
+ if self.use_manual_hyperparameters:
231
+ logger.warning(
232
+ "Hyperparameter tuning is highly recommended for optimal results."
233
+ )
234
+
235
+ self.validate_options()
236
+
237
+ # set up output directories
238
+ if self.results_dir is not None:
239
+ self.trials_results_path = f"{self.results_dir}/results.txt".replace("//","/")
240
+
241
+ for output_dir in [self.model_save_path, self.results_dir]:
242
+ if not os.path.exists(output_dir):
243
+ os.makedirs(output_dir)
244
+
245
+ self.config = {key: value for key, value in self.__dict__.items() if key in self.valid_option_dict}
246
+
247
+ def validate_options(self):
248
+ # confirm arguments are within valid options and compatible with each other
249
+ for attr_name, valid_options in self.valid_option_dict.items():
250
+ attr_value = self.__dict__[attr_name]
251
+ if not isinstance(attr_value, (list, dict)):
252
+ if attr_value in valid_options:
253
+ continue
254
+ valid_type = False
255
+ for option in valid_options:
256
+ if (option in [int, float, list, dict, bool, str]) and isinstance(
257
+ attr_value, option
258
+ ):
259
+ valid_type = True
260
+ break
261
+ if valid_type:
262
+ continue
263
+ logger.error(
264
+ f"Invalid option for {attr_name}. "
265
+ f"Valid options for {attr_name}: {valid_options}"
266
+ )
267
+ raise ValueError(f"Invalid option for {attr_name}. Valid options for {attr_name}: {valid_options}")
268
+
269
+ def run_manual_tuning(self):
270
+ """
271
+ Manual hyperparameter tuning and multi-task fine-tuning of pretrained model.
272
+ """
273
+ required_variable_names = ["train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"]
274
+ required_variables = [self.train_path, self.val_path, self.pretrained_path, self.model_save_path, self.results_dir]
275
+ req_var_dict = dict(zip(required_variable_names, required_variables))
276
+ self.validate_additional_options(req_var_dict)
277
+
278
+ if not self.use_manual_hyperparameters:
279
+ raise ValueError("Manual hyperparameters are not enabled. Set use_manual_hyperparameters to True.")
280
+
281
+ # Ensure manual_hyperparameters are set in the config
282
+ self.config["manual_hyperparameters"] = self.manual_hyperparameters
283
+ self.config["use_manual_hyperparameters"] = True
284
+
285
+ train_utils.run_manual_tuning(self.config)
286
+
287
+ def validate_additional_options(self, req_var_dict):
288
+ missing_variable = False
289
+ for variable_name, variable in req_var_dict.items():
290
+ if variable is None:
291
+ logger.warning(
292
+ f"Please provide value to MTLClassifier for required variable {variable_name}"
293
+ )
294
+ missing_variable = True
295
+ if missing_variable is True:
296
+ raise ValueError("Missing required variables for MTLClassifier")
297
+
298
+ def run_optuna_study(
299
+ self,
300
+ ):
301
+ """
302
+ Hyperparameter optimization and/or multi-task fine-tuning of pretrained model.
303
+ """
304
+
305
+ required_variable_names = ["train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"]
306
+ required_variables = [self.train_path, self.val_path, self.pretrained_path, self.model_save_path, self.results_dir]
307
+ req_var_dict = dict(zip(required_variable_names, required_variables))
308
+ self.validate_additional_options(req_var_dict)
309
+
310
+ train_utils.run_optuna_study(self.config)
311
+
312
+ def load_and_evaluate_test_model(
313
+ self,
314
+ ):
315
+ """
316
+ Loads previously fine-tuned multi-task model and evaluates on test data.
317
+ """
318
+
319
+ required_variable_names = ["test_path", "model_save_path", "results_dir"]
320
+ required_variables = [self.test_path, self.model_save_path, self.results_dir]
321
+ req_var_dict = dict(zip(required_variable_names, required_variables))
322
+ self.validate_additional_options(req_var_dict)
323
+
324
+ eval_utils.load_and_evaluate_test_model(self.config)
325
+
326
+ def save_model_without_heads(
327
+ self,
328
+ ):
329
+ """
330
+ Save previously fine-tuned multi-task model without classification heads.
331
+ """
332
+
333
+ required_variable_names = ["model_save_path"]
334
+ required_variables = [self.model_save_path]
335
+ req_var_dict = dict(zip(required_variable_names, required_variables))
336
+ self.validate_additional_options(req_var_dict)
337
+
338
+ utils.save_model_without_heads(os.path.join(self.model_save_path, "GeneformerMultiTask"))
geneformer/perturber_utils.py CHANGED
@@ -12,13 +12,17 @@ import pandas as pd
12
  import seaborn as sns
13
  import torch
14
  from datasets import Dataset, load_from_disk
 
15
  from transformers import (
16
  BertForMaskedLM,
17
  BertForSequenceClassification,
18
  BertForTokenClassification,
 
19
  )
20
 
21
- from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE, ENSEMBL_DICTIONARY_FILE
 
 
22
 
23
 
24
  logger = logging.getLogger(__name__)
@@ -111,17 +115,49 @@ def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
111
 
112
 
113
  # load model to GPU
114
- def load_model(model_type, num_classes, model_directory, mode):
 
 
 
 
115
  if mode == "eval":
116
  output_hidden_states = True
117
  elif mode == "train":
118
  output_hidden_states = False
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  if model_type == "Pretrained":
121
  model = BertForMaskedLM.from_pretrained(
122
  model_directory,
123
  output_hidden_states=output_hidden_states,
124
  output_attentions=False,
 
125
  )
126
  elif model_type == "GeneClassifier":
127
  model = BertForTokenClassification.from_pretrained(
@@ -129,6 +165,7 @@ def load_model(model_type, num_classes, model_directory, mode):
129
  num_labels=num_classes,
130
  output_hidden_states=output_hidden_states,
131
  output_attentions=False,
 
132
  )
133
  elif model_type == "CellClassifier":
134
  model = BertForSequenceClassification.from_pretrained(
@@ -136,11 +173,24 @@ def load_model(model_type, num_classes, model_directory, mode):
136
  num_labels=num_classes,
137
  output_hidden_states=output_hidden_states,
138
  output_attentions=False,
 
 
 
 
 
 
 
 
 
139
  )
140
  # if eval mode, put the model in eval mode for fwd pass
141
  if mode == "eval":
142
  model.eval()
143
- model = model.to("cuda")
 
 
 
 
144
  return model
145
 
146
 
@@ -222,27 +272,47 @@ def overexpress_indices(example):
222
  indices = example["perturb_index"]
223
  if any(isinstance(el, list) for el in indices):
224
  indices = flatten_list(indices)
225
- for index in sorted(indices, reverse=True):
226
- example["input_ids"].insert(0, example["input_ids"].pop(index))
227
-
 
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
233
- def overexpress_tokens(example, max_len):
234
  # -100 indicates tokens to overexpress are not present in rank value encoding
235
  if example["perturb_index"] != [-100]:
236
  example = delete_indices(example)
237
- [
238
- example["input_ids"].insert(0, token)
239
- for token in example["tokens_to_perturb"][::-1]
240
- ]
 
 
 
 
 
 
241
 
242
  # truncate to max input size, must also truncate original emb to be comparable
243
  if len(example["input_ids"]) > max_len:
244
- example["input_ids"] = example["input_ids"][0:max_len]
245
-
 
 
246
  example["length"] = len(example["input_ids"])
247
  return example
248
 
@@ -259,6 +329,13 @@ def truncate_by_n_overflow(example):
259
  example["length"] = len(example["input_ids"])
260
  return example
261
 
 
 
 
 
 
 
 
262
 
263
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
264
  # indices_to_remove is list of indices to remove
@@ -392,7 +469,81 @@ def make_perturbation_batch(
392
  return perturbation_dataset, indices_to_perturb
393
 
394
 
395
- # perturbed cell emb removing the activated/overexpressed/inhibited gene emb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  # so that only non-perturbed gene embeddings are compared to each other
397
  # in original or perturbed context
398
  def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
@@ -589,9 +740,10 @@ def quant_cos_sims(
589
  cos = torch.nn.CosineSimilarity(dim=1)
590
 
591
  # if emb_mode == "gene", can only calculate gene cos sims
592
- # against original cell anyways
593
  if cell_states_to_model is None or emb_mode == "gene":
594
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
 
595
  elif cell_states_to_model is not None and emb_mode == "cell":
596
  possible_states = get_possible_states(cell_states_to_model)
597
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
@@ -758,4 +910,4 @@ class GeneIdHandler:
758
  return self.ens_to_symbol(self.token_to_ens(token))
759
 
760
  def symbol_to_token(self, symbol):
761
- return self.ens_to_token(self.symbol_to_ens(symbol))
 
12
  import seaborn as sns
13
  import torch
14
  from datasets import Dataset, load_from_disk
15
+ from peft import LoraConfig, get_peft_model
16
  from transformers import (
17
  BertForMaskedLM,
18
  BertForSequenceClassification,
19
  BertForTokenClassification,
20
+ BitsAndBytesConfig,
21
  )
22
 
23
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
24
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
25
+ ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
26
 
27
 
28
  logger = logging.getLogger(__name__)
 
115
 
116
 
117
  # load model to GPU
118
+ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
119
+ if model_type == "MTLCellClassifier-Quantized":
120
+ model_type = "MTLCellClassifier"
121
+ quantize = True
122
+
123
  if mode == "eval":
124
  output_hidden_states = True
125
  elif mode == "train":
126
  output_hidden_states = False
127
 
128
+ if quantize is True:
129
+ if model_type == "MTLCellClassifier":
130
+ quantize = {
131
+ "peft_config": None,
132
+ "bnb_config": BitsAndBytesConfig(
133
+ load_in_8bit=True,
134
+ )
135
+ }
136
+ else:
137
+ quantize = {
138
+ "peft_config": LoraConfig(
139
+ lora_alpha=128,
140
+ lora_dropout=0.1,
141
+ r=64,
142
+ bias="none",
143
+ task_type="TokenClassification",
144
+ ),
145
+ "bnb_config": BitsAndBytesConfig(
146
+ load_in_4bit=True,
147
+ bnb_4bit_use_double_quant=True,
148
+ bnb_4bit_quant_type="nf4",
149
+ bnb_4bit_compute_dtype=torch.bfloat16
150
+ )
151
+ }
152
+ elif quantize is False:
153
+ quantize = {"bnb_config": None}
154
+
155
  if model_type == "Pretrained":
156
  model = BertForMaskedLM.from_pretrained(
157
  model_directory,
158
  output_hidden_states=output_hidden_states,
159
  output_attentions=False,
160
+ quantization_config=quantize["bnb_config"],
161
  )
162
  elif model_type == "GeneClassifier":
163
  model = BertForTokenClassification.from_pretrained(
 
165
  num_labels=num_classes,
166
  output_hidden_states=output_hidden_states,
167
  output_attentions=False,
168
+ quantization_config=quantize["bnb_config"],
169
  )
170
  elif model_type == "CellClassifier":
171
  model = BertForSequenceClassification.from_pretrained(
 
173
  num_labels=num_classes,
174
  output_hidden_states=output_hidden_states,
175
  output_attentions=False,
176
+ quantization_config=quantize["bnb_config"],
177
+ )
178
+ elif model_type == "MTLCellClassifier":
179
+ model = BertForMaskedLM.from_pretrained(
180
+ model_directory,
181
+ num_labels=num_classes,
182
+ output_hidden_states=output_hidden_states,
183
+ output_attentions=False,
184
+ quantization_config=quantize["bnb_config"],
185
  )
186
  # if eval mode, put the model in eval mode for fwd pass
187
  if mode == "eval":
188
  model.eval()
189
+ if (quantize is False) or (quantize == {'bnb_config': None}) or (model_type == "MTLCellClassifier"):
190
+ model = model.to("cuda")
191
+ else:
192
+ model.enable_input_require_grads()
193
+ model = get_peft_model(model, quantize["peft_config"])
194
  return model
195
 
196
 
 
272
  indices = example["perturb_index"]
273
  if any(isinstance(el, list) for el in indices):
274
  indices = flatten_list(indices)
275
+ insert_pos = 0
276
+ for index in sorted(indices, reverse=False):
277
+ example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
278
+ insert_pos += 1
279
  example["length"] = len(example["input_ids"])
280
  return example
281
 
282
+ # if CLS token present, move to 1st rather than 0th position
283
+ def overexpress_indices_special(example):
284
+ indices = example["perturb_index"]
285
+ if any(isinstance(el, list) for el in indices):
286
+ indices = flatten_list(indices)
287
+ insert_pos = 1 # Insert starting after CLS token
288
+ for index in sorted(indices, reverse=False):
289
+ example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
290
+ insert_pos += 1
291
+ example["length"] = len(example["input_ids"])
292
+ return example
293
 
294
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
295
+ def overexpress_tokens(example, max_len, special_token):
296
  # -100 indicates tokens to overexpress are not present in rank value encoding
297
  if example["perturb_index"] != [-100]:
298
  example = delete_indices(example)
299
+ if special_token:
300
+ [
301
+ example["input_ids"].insert(1, token)
302
+ for token in example["tokens_to_perturb"][::-1]
303
+ ]
304
+ else:
305
+ [
306
+ example["input_ids"].insert(0, token)
307
+ for token in example["tokens_to_perturb"][::-1]
308
+ ]
309
 
310
  # truncate to max input size, must also truncate original emb to be comparable
311
  if len(example["input_ids"]) > max_len:
312
+ if special_token:
313
+ example["input_ids"] = example["input_ids"][0:max_len-1]+[example["input_ids"][-1]]
314
+ else:
315
+ example["input_ids"] = example["input_ids"][0:max_len]
316
  example["length"] = len(example["input_ids"])
317
  return example
318
 
 
329
  example["length"] = len(example["input_ids"])
330
  return example
331
 
332
+ def truncate_by_n_overflow_special(example):
333
+ if example["n_overflow"] > 0:
334
+ new_max_len = example["length"] - example["n_overflow"]
335
+ example["input_ids"] = example["input_ids"][0:new_max_len-1]+[example["input_ids"][-1]]
336
+ example["length"] = len(example["input_ids"])
337
+ return example
338
+
339
 
340
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
341
  # indices_to_remove is list of indices to remove
 
469
  return perturbation_dataset, indices_to_perturb
470
 
471
 
472
+ def make_perturbation_batch_special(
473
+ example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
474
+ ) -> tuple[Dataset, List[int]]:
475
+ if combo_lvl == 0 and tokens_to_perturb == "all":
476
+ if perturb_type in ["overexpress", "activate"]:
477
+ range_start = 1
478
+ elif perturb_type in ["delete", "inhibit"]:
479
+ range_start = 0
480
+ range_start += 1 # Starting after the CLS token
481
+ indices_to_perturb = [
482
+ [i] for i in range(range_start, example_cell["length"][0]-1) # And excluding the EOS token
483
+ ]
484
+
485
+ # elif combo_lvl > 0 and anchor_token is None:
486
+ ## to implement
487
+ elif combo_lvl > 0 and (anchor_token is not None):
488
+ example_input_ids = example_cell["input_ids"][0]
489
+ anchor_index = example_input_ids.index(anchor_token[0])
490
+ indices_to_perturb = [
491
+ sorted([anchor_index, i]) if i != anchor_index else None
492
+ for i in range(1, example_cell["length"][0]-1) # Exclude CLS and EOS tokens
493
+ ]
494
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
495
+ else:
496
+ example_input_ids = example_cell["input_ids"][0]
497
+ indices_to_perturb = [
498
+ [example_input_ids.index(token)] if token in example_input_ids else None
499
+ for token in tokens_to_perturb
500
+ ]
501
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
502
+
503
+ # create all permutations of combo_lvl of modifiers from tokens_to_perturb
504
+ if combo_lvl > 0 and (anchor_token is None):
505
+ if tokens_to_perturb != "all":
506
+ if len(tokens_to_perturb) == combo_lvl + 1:
507
+ indices_to_perturb = [
508
+ list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
509
+ ]
510
+ else:
511
+ all_indices = [[i] for i in range(1, example_cell["length"][0]-1)] # Exclude CLS and EOS tokens
512
+ all_indices = [
513
+ index for index in all_indices if index not in indices_to_perturb
514
+ ]
515
+ indices_to_perturb = [
516
+ [[j for i in indices_to_perturb for j in i], x] for x in all_indices
517
+ ]
518
+
519
+ length = len(indices_to_perturb)
520
+ perturbation_dataset = Dataset.from_dict(
521
+ {
522
+ "input_ids": example_cell["input_ids"] * length,
523
+ "perturb_index": indices_to_perturb,
524
+ }
525
+ )
526
+
527
+ if length < 400:
528
+ num_proc_i = 1
529
+ else:
530
+ num_proc_i = num_proc
531
+
532
+ if perturb_type == "delete":
533
+ perturbation_dataset = perturbation_dataset.map(
534
+ delete_indices, num_proc=num_proc_i
535
+ )
536
+ elif perturb_type == "overexpress":
537
+ perturbation_dataset = perturbation_dataset.map(
538
+ overexpress_indices_special, num_proc=num_proc_i
539
+ )
540
+
541
+ perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
542
+
543
+ return perturbation_dataset, indices_to_perturb
544
+
545
+
546
+ # original cell emb removing the activated/overexpressed/inhibited gene emb
547
  # so that only non-perturbed gene embeddings are compared to each other
548
  # in original or perturbed context
549
  def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
 
740
  cos = torch.nn.CosineSimilarity(dim=1)
741
 
742
  # if emb_mode == "gene", can only calculate gene cos sims
743
+ # against original cell
744
  if cell_states_to_model is None or emb_mode == "gene":
745
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
746
+
747
  elif cell_states_to_model is not None and emb_mode == "cell":
748
  possible_states = get_possible_states(cell_states_to_model)
749
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
 
910
  return self.ens_to_symbol(self.token_to_ens(token))
911
 
912
  def symbol_to_token(self, symbol):
913
+ return self.ens_to_token(self.symbol_to_ens(symbol))
geneformer/pretrainer.py CHANGED
@@ -32,8 +32,6 @@ from transformers.training_args import ParallelMode
32
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
  from transformers.utils.generic import _is_tensorflow, _is_torch
34
 
35
- from . import TOKEN_DICTIONARY_FILE
36
-
37
  logger = logging.get_logger(__name__)
38
  EncodedInput = List[int]
39
  VERY_LARGE_INTEGER = int(
@@ -52,9 +50,6 @@ _is_torch_generator_available = False
52
  if version.parse(torch.__version__) >= version.parse("1.6"):
53
  _is_torch_generator_available = True
54
 
55
- with open(TOKEN_DICTIONARY_FILE, "rb") as f:
56
- token_dictionary = pickle.load(f)
57
-
58
 
59
  class ExplicitEnum(Enum):
60
  """
@@ -109,15 +104,7 @@ class GeneformerPreCollator(SpecialTokensMixin):
109
  super().__init__(mask_token="<mask>", pad_token="<pad>")
110
 
111
  self.token_dictionary = kwargs.get("token_dictionary")
112
- # self.mask_token = "<mask>"
113
- # self.mask_token_id = self.token_dictionary.get("<mask>")
114
- # self.pad_token = "<pad>"
115
- # self.pad_token_id = self.token_dictionary.get("<pad>")
116
  self.padding_side = "right"
117
- # self.all_special_ids = [
118
- # self.token_dictionary.get("<mask>"),
119
- # self.token_dictionary.get("<pad>"),
120
- # ]
121
  self.model_input_names = ["input_ids"]
122
 
123
  def convert_ids_to_tokens(self, value):
 
32
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
  from transformers.utils.generic import _is_tensorflow, _is_torch
34
 
 
 
35
  logger = logging.get_logger(__name__)
36
  EncodedInput = List[int]
37
  VERY_LARGE_INTEGER = int(
 
50
  if version.parse(torch.__version__) >= version.parse("1.6"):
51
  _is_torch_generator_available = True
52
 
 
 
 
53
 
54
  class ExplicitEnum(Enum):
55
  """
 
104
  super().__init__(mask_token="<mask>", pad_token="<pad>")
105
 
106
  self.token_dictionary = kwargs.get("token_dictionary")
 
 
 
 
107
  self.padding_side = "right"
 
 
 
 
108
  self.model_input_names = ["input_ids"]
109
 
110
  def convert_ids_to_tokens(self, value):
geneformer/token_dictionary.pkl DELETED
Binary file (788 kB)
 
geneformer/token_dictionary_gc95M.pkl CHANGED
Binary files a/geneformer/token_dictionary_gc95M.pkl and b/geneformer/token_dictionary_gc95M.pkl differ
 
generation_config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "pad_token_id": 0,
4
+ "transformers_version": "4.37.1"
5
+ }
{geneformer-12L-30M β†’ gf-12L-30M-i2048}/config.json RENAMED
File without changes
{geneformer-12L-30M β†’ gf-12L-30M-i2048}/pytorch_model.bin RENAMED
File without changes
{geneformer-12L-30M β†’ gf-12L-30M-i2048}/training_args.bin RENAMED
File without changes
gf-12L-95M-i4096/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.02,
6
+ "classifier_dropout": null,
7
+ "hidden_act": "relu",
8
+ "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 4096,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.37.1",
21
+ "type_vocab_size": 2,
22
+ "use_cache": true,
23
+ "vocab_size": 20275
24
+ }
gf-12L-95M-i4096/generation_config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "pad_token_id": 0,
4
+ "transformers_version": "4.37.1"
5
+ }
gf-12L-95M-i4096/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c
3
+ size 152012980
gf-12L-95M-i4096/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d
3
+ size 4920
gf-12L-95M-i4096_CLcancer/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/gladstone/theodoris/lab/pretrained_models/encoder/240402_194213_geneformer_94M_L12_emb512_SL4096_E3_B4_LR0.0005_LScosine_WU5000_Oadamw_DS8/models",
3
+ "architectures": [
4
+ "BertForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.02,
7
+ "classifier_dropout": null,
8
+ "hidden_act": "relu",
9
+ "hidden_dropout_prob": 0.02,
10
+ "hidden_size": 512,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1024,
13
+ "layer_norm_eps": 1e-12,
14
+ "max_position_embeddings": 4096,
15
+ "model_type": "bert",
16
+ "num_attention_heads": 8,
17
+ "num_hidden_layers": 12,
18
+ "pad_token_id": 0,
19
+ "position_embedding_type": "absolute",
20
+ "torch_dtype": "float32",
21
+ "transformers_version": "4.37.1",
22
+ "type_vocab_size": 2,
23
+ "use_cache": true,
24
+ "vocab_size": 20275
25
+ }
gf-12L-95M-i4096_CLcancer/generation_config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "pad_token_id": 0,
4
+ "transformers_version": "4.37.1"
5
+ }
gf-12L-95M-i4096_CLcancer/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2451adeed240c165634fea60ccba17063da8a2843ea9fcdcc0ce185720bf0dc2
3
+ size 152012980
gf-12L-95M-i4096_CLcancer/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37074f3ea62a6ba0a312c38526c20c2dccbb068a2c7ee8c7c73b435dd90ab7b1
3
+ size 5048