hchen725 commited on
Commit
93035af
·
verified ·
1 Parent(s): cec6545

add cls specific option for perturbations

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +184 -121
geneformer/in_silico_perturber.py CHANGED
@@ -821,17 +821,6 @@ class InSilicoPerturber:
821
  stored_gene_embs_dict = defaultdict(list)
822
  for i in trange(len(filtered_input_data)):
823
  example_cell = filtered_input_data.select([i])
824
- full_original_emb = get_embs(
825
- model,
826
- example_cell,
827
- "gene",
828
- layer_to_quant,
829
- self.pad_token_id,
830
- self.forward_batch_size,
831
- self.token_gene_dict,
832
- summary_stat=None,
833
- silent=True,
834
- )
835
 
836
  # gene_list is used to assign cos sims back to genes
837
  # need to remove the anchor gene
@@ -839,6 +828,9 @@ class InSilicoPerturber:
839
  if self.anchor_token is not None:
840
  for token in self.anchor_token:
841
  gene_list.remove(token)
 
 
 
842
 
843
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
844
  example_cell,
@@ -861,6 +853,8 @@ class InSilicoPerturber:
861
  silent=True,
862
  )
863
 
 
 
864
  num_inds_perturbed = 1 + self.combos
865
 
866
  if self.perturb_type == "overexpress":
@@ -868,11 +862,22 @@ class InSilicoPerturber:
868
  elif self.perturb_type == "delete":
869
  perturbation_emb = full_perturbation_emb
870
 
871
- original_batch = pu.make_comparison_batch(
872
- full_original_emb, indices_to_perturb, perturb_group=False
 
 
 
 
 
 
 
 
873
  )
874
 
875
  if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
 
 
 
876
  gene_cos_sims = pu.quant_cos_sims(
877
  perturbation_emb,
878
  original_batch,
@@ -880,6 +885,7 @@ class InSilicoPerturber:
880
  self.state_embs_dict,
881
  emb_mode="gene",
882
  )
 
883
 
884
  if self.cell_states_to_model is not None:
885
  original_cell_emb = pu.compute_nonpadded_cell_embedding(
@@ -896,6 +902,8 @@ class InSilicoPerturber:
896
  self.state_embs_dict,
897
  emb_mode="cell",
898
  )
 
 
899
 
900
  if self.emb_mode == "cell_and_gene":
901
  # remove perturbed index for gene list
@@ -917,6 +925,9 @@ class InSilicoPerturber:
917
  (perturbed_gene, affected_gene)
918
  ] = gene_cos_sims[perturbation_i, gene_j].item()
919
 
 
 
 
920
  if self.cell_states_to_model is None:
921
  cos_sims_data = torch.mean(gene_cos_sims, dim=1)
922
  cos_sims_dict = self.update_perturbation_dictionary(
@@ -963,12 +974,6 @@ class InSilicoPerturber:
963
  if self.emb_mode == "cell_and_gene":
964
  stored_gene_embs_dict = defaultdict(list)
965
 
966
- del full_original_emb
967
- del perturbation_batch
968
- del full_perturbation_emb
969
- del perturbation_emb
970
- del original_batch
971
-
972
  torch.cuda.empty_cache()
973
 
974
  pu.write_perturbation_dictionary(
@@ -1002,29 +1007,23 @@ class InSilicoPerturber:
1002
  stored_gene_embs_dict = defaultdict(list)
1003
  for i in trange(len(filtered_input_data)):
1004
  example_cell = filtered_input_data.select([i])
1005
- full_original_emb = get_embs(
1006
- model,
1007
- example_cell,
1008
- "gene",
1009
- layer_to_quant,
1010
- self.pad_token_id,
1011
- self.forward_batch_size,
1012
- self.token_gene_dict,
1013
- summary_stat=None,
1014
- silent=True,
1015
- )
1016
-
1017
  # gene_list is used to assign cos sims back to genes
1018
  # need to remove the anchor gene
1019
  gene_list = example_cell["input_ids"][0][:]
 
 
 
 
1020
  if self.anchor_token is not None:
1021
  for token in self.anchor_token:
1022
  gene_list.remove(token)
 
 
 
1023
 
1024
  # Also exclude special token from gene_list
1025
- if self.special_token:
1026
- for token in [self.cls_token_id, self.eos_token_id]:
1027
- gene_list.remove(token)
1028
 
1029
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special(
1030
  example_cell,
@@ -1034,102 +1033,171 @@ class InSilicoPerturber:
1034
  self.combos,
1035
  self.nproc,
1036
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1037
 
1038
- full_perturbation_emb = get_embs(
1039
- model,
1040
- perturbation_batch,
1041
- "gene",
1042
- layer_to_quant,
1043
- self.pad_token_id,
1044
- self.forward_batch_size,
1045
- self.token_gene_dict,
1046
- summary_stat=None,
1047
- silent=True,
1048
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1049
 
1050
- num_inds_perturbed = 1 + self.combos
1051
-
1052
- # need to remove overexpressed gene and cls/eos to quantify cosine shifts
1053
- if self.perturb_type == "overexpress":
1054
- perturbation_emb = full_perturbation_emb[:, 1+num_inds_perturbed:-1, :]
1055
- elif self.perturb_type == "delete":
1056
- perturbation_emb = full_perturbation_emb[:, 1:-1, :]
 
 
 
 
1057
 
1058
- original_batch = pu.make_comparison_batch(
1059
- full_original_emb, indices_to_perturb, perturb_group=False
1060
- )
 
1061
 
1062
- original_batch = original_batch[:, 1:-1, :]
 
 
 
 
 
 
 
 
 
1063
 
1064
- if self.cell_states_to_model is None or self.emb_mode == "cls_and_gene":
1065
- gene_cos_sims = pu.quant_cos_sims(
1066
- perturbation_emb,
1067
- original_batch,
1068
- self.cell_states_to_model,
1069
- self.state_embs_dict,
1070
- emb_mode="gene",
1071
- )
1072
 
1073
- if self.cell_states_to_model is not None:
1074
- # get cls emb
1075
- original_cell_emb = full_original_emb[:,0,:]
1076
- perturbation_cell_emb = full_perturbation_emb[:,0,:]
 
 
 
 
 
1077
 
1078
- cell_cos_sims = pu.quant_cos_sims(
1079
- perturbation_cell_emb,
1080
- original_cell_emb,
1081
- self.cell_states_to_model,
1082
- self.state_embs_dict,
1083
- emb_mode="cell",
1084
- )
1085
 
1086
- if self.emb_mode == "cls_and_gene":
1087
- # remove perturbed index for gene list
1088
- perturbed_gene_dict = {
1089
- gene: gene_list[:i] + gene_list[i + 1 :]
1090
- for i, gene in enumerate(gene_list)
1091
- }
 
 
 
 
 
 
1092
 
1093
- for perturbation_i, perturbed_gene in enumerate(gene_list):
1094
- for gene_j, affected_gene in enumerate(
1095
- perturbed_gene_dict[perturbed_gene]
1096
- ):
1097
- try:
1098
- stored_gene_embs_dict[
1099
- (perturbed_gene, affected_gene)
1100
- ].append(gene_cos_sims[perturbation_i, gene_j].item())
1101
- except KeyError:
1102
- stored_gene_embs_dict[
1103
- (perturbed_gene, affected_gene)
1104
- ] = gene_cos_sims[perturbation_i, gene_j].item()
1105
 
1106
- if self.cell_states_to_model is None:
1107
- original_cls_emb = full_original_emb[:,0,:]
1108
- perturbation_cls_emb = full_perturbation_emb[:,0,:]
1109
- cos_sims_data = pu.quant_cos_sims(
1110
- perturbation_cls_emb,
1111
- original_cls_emb,
1112
- self.cell_states_to_model,
1113
- self.state_embs_dict,
1114
- emb_mode="cell",
1115
- )
1116
- cos_sims_dict = self.update_perturbation_dictionary(
1117
- cos_sims_dict,
1118
- cos_sims_data,
1119
- filtered_input_data,
1120
- indices_to_perturb,
1121
- gene_list,
1122
- )
1123
- else:
1124
- cos_sims_data = cell_cos_sims
1125
- for state in cos_sims_dict.keys():
1126
- cos_sims_dict[state] = self.update_perturbation_dictionary(
1127
- cos_sims_dict[state],
1128
- cos_sims_data[state],
1129
  filtered_input_data,
1130
  indices_to_perturb,
1131
  gene_list,
1132
  )
 
 
 
 
 
 
 
 
 
 
1133
 
1134
  # save dict to disk every 100 cells
1135
  if i % self.clear_mem_ncells/10 == 0:
@@ -1157,11 +1225,6 @@ class InSilicoPerturber:
1157
  if self.emb_mode == "cls_and_gene":
1158
  stored_gene_embs_dict = defaultdict(list)
1159
 
1160
- del full_original_emb
1161
- del perturbation_batch
1162
- del full_perturbation_emb
1163
- del perturbation_emb
1164
- del original_batch
1165
  torch.cuda.empty_cache()
1166
 
1167
  pu.write_perturbation_dictionary(
 
821
  stored_gene_embs_dict = defaultdict(list)
822
  for i in trange(len(filtered_input_data)):
823
  example_cell = filtered_input_data.select([i])
 
 
 
 
 
 
 
 
 
 
 
824
 
825
  # gene_list is used to assign cos sims back to genes
826
  # need to remove the anchor gene
 
828
  if self.anchor_token is not None:
829
  for token in self.anchor_token:
830
  gene_list.remove(token)
831
+ else:
832
+ if self.perturb_type == "overexpress":
833
+ gene_list = gene_list[1:]
834
 
835
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
836
  example_cell,
 
853
  silent=True,
854
  )
855
 
856
+ del perturbation_batch
857
+
858
  num_inds_perturbed = 1 + self.combos
859
 
860
  if self.perturb_type == "overexpress":
 
862
  elif self.perturb_type == "delete":
863
  perturbation_emb = full_perturbation_emb
864
 
865
+ full_original_emb = get_embs(
866
+ model,
867
+ example_cell,
868
+ "gene",
869
+ layer_to_quant,
870
+ self.pad_token_id,
871
+ self.forward_batch_size,
872
+ self.token_gene_dict,
873
+ summary_stat=None,
874
+ silent=True,
875
  )
876
 
877
  if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
878
+ original_batch = pu.make_comparison_batch(
879
+ full_original_emb, indices_to_perturb, perturb_group=False
880
+ )
881
  gene_cos_sims = pu.quant_cos_sims(
882
  perturbation_emb,
883
  original_batch,
 
885
  self.state_embs_dict,
886
  emb_mode="gene",
887
  )
888
+ del original_batch
889
 
890
  if self.cell_states_to_model is not None:
891
  original_cell_emb = pu.compute_nonpadded_cell_embedding(
 
902
  self.state_embs_dict,
903
  emb_mode="cell",
904
  )
905
+ del original_cell_emb
906
+ del perturbation_cell_emb
907
 
908
  if self.emb_mode == "cell_and_gene":
909
  # remove perturbed index for gene list
 
925
  (perturbed_gene, affected_gene)
926
  ] = gene_cos_sims[perturbation_i, gene_j].item()
927
 
928
+ del full_original_emb
929
+ del full_perturbation_emb
930
+
931
  if self.cell_states_to_model is None:
932
  cos_sims_data = torch.mean(gene_cos_sims, dim=1)
933
  cos_sims_dict = self.update_perturbation_dictionary(
 
974
  if self.emb_mode == "cell_and_gene":
975
  stored_gene_embs_dict = defaultdict(list)
976
 
 
 
 
 
 
 
977
  torch.cuda.empty_cache()
978
 
979
  pu.write_perturbation_dictionary(
 
1007
  stored_gene_embs_dict = defaultdict(list)
1008
  for i in trange(len(filtered_input_data)):
1009
  example_cell = filtered_input_data.select([i])
1010
+
 
 
 
 
 
 
 
 
 
 
 
1011
  # gene_list is used to assign cos sims back to genes
1012
  # need to remove the anchor gene
1013
  gene_list = example_cell["input_ids"][0][:]
1014
+ if self.special_token:
1015
+ for token in [self.cls_token_id, self.eos_token_id]:
1016
+ gene_list.remove(token)
1017
+
1018
  if self.anchor_token is not None:
1019
  for token in self.anchor_token:
1020
  gene_list.remove(token)
1021
+ else:
1022
+ if self.perturb_type == "overexpress":
1023
+ gene_list = gene_list[1:]
1024
 
1025
  # Also exclude special token from gene_list
1026
+
 
 
1027
 
1028
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special(
1029
  example_cell,
 
1033
  self.combos,
1034
  self.nproc,
1035
  )
1036
+ if self.emb_mode == "cls":
1037
+ perturbation_cls_emb = get_embs(
1038
+ model,
1039
+ perturbation_batch,
1040
+ "cls",
1041
+ layer_to_quant,
1042
+ self.pad_token_id,
1043
+ self.forward_batch_size,
1044
+ self.token_gene_dict,
1045
+ summary_stat=None,
1046
+ silent=True,
1047
+ )
1048
+
1049
+ original_cls_emb = get_embs(
1050
+ model,
1051
+ example_cell,
1052
+ "cls",
1053
+ layer_to_quant,
1054
+ self.pad_token_id,
1055
+ self.forward_batch_size,
1056
+ self.token_gene_dict,
1057
+ summary_stat=None,
1058
+ silent=True,
1059
+ )
1060
+
1061
+ if self.cell_states_to_model is None:
1062
+ cos_sims_data = pu.quant_cos_sims(
1063
+ perturbation_cls_emb,
1064
+ original_cls_emb,
1065
+ self.cell_states_to_model,
1066
+ self.state_embs_dict,
1067
+ emb_mode="cell",
1068
+ )
1069
 
1070
+ cos_sims_dict = self.update_perturbation_dictionary(
1071
+ cos_sims_dict,
1072
+ cos_sims_data,
1073
+ filtered_input_data,
1074
+ indices_to_perturb,
1075
+ gene_list,
1076
+ )
1077
+ else:
1078
+ cos_sims_data = cell_cos_sims
1079
+ for state in cos_sims_dict.keys():
1080
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1081
+ cos_sims_dict[state],
1082
+ cos_sims_data[state],
1083
+ filtered_input_data,
1084
+ indices_to_perturb,
1085
+ gene_list,
1086
+ )
1087
+ else:
1088
+ full_perturbation_emb = get_embs(
1089
+ model,
1090
+ perturbation_batch,
1091
+ "gene",
1092
+ layer_to_quant,
1093
+ self.pad_token_id,
1094
+ self.forward_batch_size,
1095
+ self.token_gene_dict,
1096
+ summary_stat=None,
1097
+ silent=True,
1098
+ )
1099
+ del perturbation_batch
1100
+ num_inds_perturbed = 1 + self.combos
1101
+
1102
+ # need to remove overexpressed gene and cls/eos to quantify cosine shifts
1103
+ if self.perturb_type == "overexpress":
1104
+ perturbation_emb = full_perturbation_emb[:, 1+num_inds_perturbed:-1, :]
1105
+ elif self.perturb_type == "delete":
1106
+ perturbation_emb = full_perturbation_emb[:, 1:-1, :]
1107
 
1108
+ full_original_emb = get_embs(
1109
+ model,
1110
+ example_cell,
1111
+ "gene",
1112
+ layer_to_quant,
1113
+ self.pad_token_id,
1114
+ self.forward_batch_size,
1115
+ self.token_gene_dict,
1116
+ summary_stat=None,
1117
+ silent=True,
1118
+ )
1119
 
1120
+ if self.cell_states_to_model is None or self.emb_mode == "cls_and_gene":
1121
+ original_batch = pu.make_comparison_batch(
1122
+ full_original_emb, indices_to_perturb, perturb_group=False
1123
+ )
1124
 
1125
+ original_batch = original_batch[:, 1:-1, :]
1126
+ gene_cos_sims = pu.quant_cos_sims(
1127
+ perturbation_emb,
1128
+ original_batch,
1129
+ self.cell_states_to_model,
1130
+ self.state_embs_dict,
1131
+ emb_mode="gene",
1132
+ )
1133
+ del perturbation_emb
1134
+ del original_batch
1135
 
1136
+ if self.cell_states_to_model is not None:
1137
+ # get cls emb
1138
+ original_cls_emb = full_original_emb[:,0,:]
1139
+ perturbation_cls_emb = full_perturbation_emb[:,0,:]
 
 
 
 
1140
 
1141
+ cell_cos_sims = pu.quant_cos_sims(
1142
+ perturbation_cls_emb,
1143
+ original_cls_emb,
1144
+ self.cell_states_to_model,
1145
+ self.state_embs_dict,
1146
+ emb_mode="cell",
1147
+ )
1148
+ del original_cls_emb
1149
+ del perturbation_cls_emb
1150
 
1151
+ if self.emb_mode == "cls_and_gene":
1152
+ # remove perturbed index for gene list
1153
+ perturbed_gene_dict = {
1154
+ gene: gene_list[:i] + gene_list[i + 1 :]
1155
+ for i, gene in enumerate(gene_list)
1156
+ }
 
1157
 
1158
+ for perturbation_i, perturbed_gene in enumerate(gene_list):
1159
+ for gene_j, affected_gene in enumerate(
1160
+ perturbed_gene_dict[perturbed_gene]
1161
+ ):
1162
+ try:
1163
+ stored_gene_embs_dict[
1164
+ (perturbed_gene, affected_gene)
1165
+ ].append(gene_cos_sims[perturbation_i, gene_j].item())
1166
+ except KeyError:
1167
+ stored_gene_embs_dict[
1168
+ (perturbed_gene, affected_gene)
1169
+ ] = gene_cos_sims[perturbation_i, gene_j].item()
1170
 
1171
+ if self.cell_states_to_model is None:
1172
+ original_cls_emb = full_original_emb[:,0,:]
1173
+ perturbation_cls_emb = full_perturbation_emb[:,0,:]
1174
+ cos_sims_data = pu.quant_cos_sims(
1175
+ perturbation_cls_emb,
1176
+ original_cls_emb,
1177
+ self.cell_states_to_model,
1178
+ self.state_embs_dict,
1179
+ emb_mode="cell",
1180
+ )
1181
+ del original_cls_emb
1182
+ del perturbation_cls_emb
1183
 
1184
+ cos_sims_dict = self.update_perturbation_dictionary(
1185
+ cos_sims_dict,
1186
+ cos_sims_data,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1187
  filtered_input_data,
1188
  indices_to_perturb,
1189
  gene_list,
1190
  )
1191
+ else:
1192
+ cos_sims_data = cell_cos_sims
1193
+ for state in cos_sims_dict.keys():
1194
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1195
+ cos_sims_dict[state],
1196
+ cos_sims_data[state],
1197
+ filtered_input_data,
1198
+ indices_to_perturb,
1199
+ gene_list,
1200
+ )
1201
 
1202
  # save dict to disk every 100 cells
1203
  if i % self.clear_mem_ncells/10 == 0:
 
1225
  if self.emb_mode == "cls_and_gene":
1226
  stored_gene_embs_dict = defaultdict(list)
1227
 
 
 
 
 
 
1228
  torch.cuda.empty_cache()
1229
 
1230
  pu.write_perturbation_dictionary(