ctheodoris davidjwen commited on
Commit
cb89107
·
verified ·
1 Parent(s): ebc1e09

Upload in_silico_perturber.py (#432)

Browse files

- Upload in_silico_perturber.py (9f0fa5d773efa20d201f2d548baff6edca7d08d5)


Co-authored-by: David Wen <[email protected]>

Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +27 -5
geneformer/in_silico_perturber.py CHANGED
@@ -40,7 +40,7 @@ import pickle
40
  from collections import defaultdict
41
 
42
  import torch
43
- from datasets import Dataset, disable_progress_bars
44
  from multiprocess import set_start_method
45
  from tqdm.auto import trange
46
 
@@ -48,7 +48,9 @@ from . import TOKEN_DICTIONARY_FILE
48
  from . import perturber_utils as pu
49
  from .emb_extractor import get_embs
50
 
51
- disable_progress_bars()
 
 
52
 
53
  logger = logging.getLogger(__name__)
54
 
@@ -794,6 +796,8 @@ class InSilicoPerturber:
794
  return example
795
 
796
  total_batch_length = len(filtered_input_data)
 
 
797
  if self.cell_states_to_model is None:
798
  cos_sims_dict = defaultdict(list)
799
  else:
@@ -878,7 +882,7 @@ class InSilicoPerturber:
878
  )
879
 
880
  ##### CLS and Gene Embedding Mode #####
881
- elif self.emb_mode == "cls_and_gene":
882
  full_original_emb = get_embs(
883
  model,
884
  minibatch,
@@ -891,6 +895,7 @@ class InSilicoPerturber:
891
  silent=True,
892
  )
893
  indices_to_perturb = perturbation_batch["perturb_index"]
 
894
  # remove indices that were perturbed
895
  original_emb = pu.remove_perturbed_indices_set(
896
  full_original_emb,
@@ -899,6 +904,7 @@ class InSilicoPerturber:
899
  self.tokens_to_perturb,
900
  minibatch["length"],
901
  )
 
902
  full_perturbation_emb = get_embs(
903
  model,
904
  perturbation_batch,
@@ -910,7 +916,7 @@ class InSilicoPerturber:
910
  summary_stat=None,
911
  silent=True,
912
  )
913
-
914
  # remove special tokens and padding
915
  original_emb = original_emb[:, 1:-1, :]
916
  if self.perturb_type == "overexpress":
@@ -921,9 +927,25 @@ class InSilicoPerturber:
921
  perturbation_emb = full_perturbation_emb[
922
  :, 1 : max(perturbation_batch["length"]) - 1, :
923
  ]
924
-
925
  n_perturbation_genes = perturbation_emb.size()[1]
926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
927
  gene_cos_sims = pu.quant_cos_sims(
928
  perturbation_emb,
929
  original_emb,
 
40
  from collections import defaultdict
41
 
42
  import torch
43
+ from datasets import Dataset
44
  from multiprocess import set_start_method
45
  from tqdm.auto import trange
46
 
 
48
  from . import perturber_utils as pu
49
  from .emb_extractor import get_embs
50
 
51
+ import datasets
52
+ datasets.logging.disable_progress_bar()
53
+
54
 
55
  logger = logging.getLogger(__name__)
56
 
 
796
  return example
797
 
798
  total_batch_length = len(filtered_input_data)
799
+
800
+
801
  if self.cell_states_to_model is None:
802
  cos_sims_dict = defaultdict(list)
803
  else:
 
882
  )
883
 
884
  ##### CLS and Gene Embedding Mode #####
885
+ elif self.emb_mode == "cls_and_gene":
886
  full_original_emb = get_embs(
887
  model,
888
  minibatch,
 
895
  silent=True,
896
  )
897
  indices_to_perturb = perturbation_batch["perturb_index"]
898
+
899
  # remove indices that were perturbed
900
  original_emb = pu.remove_perturbed_indices_set(
901
  full_original_emb,
 
904
  self.tokens_to_perturb,
905
  minibatch["length"],
906
  )
907
+
908
  full_perturbation_emb = get_embs(
909
  model,
910
  perturbation_batch,
 
916
  summary_stat=None,
917
  silent=True,
918
  )
919
+
920
  # remove special tokens and padding
921
  original_emb = original_emb[:, 1:-1, :]
922
  if self.perturb_type == "overexpress":
 
927
  perturbation_emb = full_perturbation_emb[
928
  :, 1 : max(perturbation_batch["length"]) - 1, :
929
  ]
930
+
931
  n_perturbation_genes = perturbation_emb.size()[1]
932
 
933
+ # truncate the original embedding as necessary
934
+ if self.perturb_type == "overexpress":
935
+ def calc_perturbation_length(ids):
936
+ if ids == [-100]:
937
+ return 0
938
+ else:
939
+ return len(ids)
940
+
941
+ max_tensor_size = max([length - calc_perturbation_length(ids) - 2 for length, ids in zip(minibatch["length"], indices_to_perturb)])
942
+
943
+ max_n_overflow = max(minibatch["n_overflow"])
944
+ if max_n_overflow > 0 and perturbation_emb.size()[1] < original_emb.size()[1]:
945
+ original_emb = original_emb[:, 0 : perturbation_emb.size()[1], :]
946
+ elif perturbation_emb.size()[1] < original_emb.size()[1]:
947
+ original_emb = original_emb[:, 0:max_tensor_size, :]
948
+
949
  gene_cos_sims = pu.quant_cos_sims(
950
  perturbation_emb,
951
  original_emb,