Christina Theodoris commited on
Commit
912860d
·
1 Parent(s): f0b6641

Add instructions for modeling only 2 states and modify stats script for that option

Browse files
geneformer/in_silico_perturber.py CHANGED
@@ -382,6 +382,8 @@ class InSilicoPerturber:
382
  Cell states to model if testing perturbations that achieve goal state change.
383
  Single-item dictionary with key being cell attribute (e.g. "disease").
384
  Value is tuple of three lists indicating start state, goal end state, and alternate possible end states.
 
 
385
  max_ncells : None, int
386
  Maximum number of cells to test.
387
  If None, will test all cells.
 
382
  Cell states to model if testing perturbations that achieve goal state change.
383
  Single-item dictionary with key being cell attribute (e.g. "disease").
384
  Value is tuple of three lists indicating start state, goal end state, and alternate possible end states.
385
+ If no alternate possible end states, third list should be empty or have a single element that is None.
386
+ (i.e. the third list should be [] or [None]).
387
  max_ncells : None, int
388
  Maximum number of cells to test.
389
  If None, will test all cells.
geneformer/in_silico_perturber_stats.py CHANGED
@@ -107,26 +107,37 @@ def get_impact_component(test_value, gaussian_mixture_model):
107
  return impact_component
108
 
109
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
110
- def isp_stats_to_goal_state(cos_sims_df, dict_list):
 
 
 
 
 
111
  random_tuples = []
112
  for i in trange(cos_sims_df.shape[0]):
113
  token = cos_sims_df["Gene"][i]
114
  for dict_i in dict_list:
115
  random_tuples += dict_i.get((token, "cell_emb"),[])
116
- goal_end_random_megalist = [goal_end for goal_end,alt_end,start_state in random_tuples]
117
- alt_end_random_megalist = [alt_end for goal_end,alt_end,start_state in random_tuples]
118
- start_state_random_megalist = [start_state for goal_end,alt_end,start_state in random_tuples]
 
 
 
 
 
119
 
120
  # downsample to improve speed of ranksums
121
  if len(goal_end_random_megalist) > 100_000:
122
  random.seed(42)
123
  goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
124
- if len(alt_end_random_megalist) > 100_000:
125
- random.seed(42)
126
- alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000)
127
  if len(start_state_random_megalist) > 100_000:
128
  random.seed(42)
129
  start_state_random_megalist = random.sample(start_state_random_megalist, k=100_000)
 
 
 
 
130
 
131
  names=["Gene",
132
  "Gene_name",
@@ -135,6 +146,9 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list):
135
  "Shift_from_alt_end",
136
  "Goal_end_vs_random_pval",
137
  "Alt_end_vs_random_pval"]
 
 
 
138
  cos_sims_full_df = pd.DataFrame(columns=names)
139
 
140
  for i in trange(cos_sims_df.shape[0]):
@@ -145,29 +159,39 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list):
145
 
146
  for dict_i in dict_list:
147
  cos_shift_data += dict_i.get((token, "cell_emb"),[])
148
-
149
- goal_end_cos_sim_megalist = [goal_end for goal_end,alt_end,start_state in cos_shift_data]
150
- alt_end_cos_sim_megalist = [alt_end for goal_end,alt_end,start_state in cos_shift_data]
 
 
 
 
 
151
 
152
  mean_goal_end = np.mean(goal_end_cos_sim_megalist)
153
- mean_alt_end = np.mean(alt_end_cos_sim_megalist)
154
-
155
  pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue
156
- pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
157
-
158
- data_i = [token,
159
- name,
160
- ensembl_id,
161
- mean_goal_end,
162
- mean_alt_end,
163
- pval_goal_end,
164
- pval_alt_end]
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
167
  cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
168
 
169
  cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
170
- cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
 
171
 
172
  # quantify number of detections of each gene
173
  cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
@@ -376,6 +400,8 @@ class InSilicoPerturberStats:
376
  Cell states to model if testing perturbations that achieve goal state change.
377
  Single-item dictionary with key being cell attribute (e.g. "disease").
378
  Value is tuple of three lists indicating start state, goal end state, and alternate possible end states.
 
 
379
  token_dictionary_file : Path
380
  Path to pickle file containing token dictionary (Ensembl ID:token).
381
  gene_name_id_dictionary_file : Path
@@ -506,7 +532,7 @@ class InSilicoPerturberStats:
506
  index=[i for i in range(len(gene_list))])
507
 
508
  if self.mode == "goal_state_shift":
509
- cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list)
510
 
511
  elif self.mode == "vs_null":
512
  null_dict_list = read_dictionaries(null_dist_data_directory, "cell", self.anchor_token)
 
107
  return impact_component
108
 
109
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
110
+ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
111
+ if (cell_states_to_model["disease"][2] == []) | (cell_states_to_model["disease"][2] == [None]):
112
+ alt_end_state_exists = False
113
+ elif (len(cell_states_to_model["disease"][2]) > 0) & (cell_states_to_model["disease"][2] != [None]):
114
+ alt_end_state_exists = True
115
+
116
  random_tuples = []
117
  for i in trange(cos_sims_df.shape[0]):
118
  token = cos_sims_df["Gene"][i]
119
  for dict_i in dict_list:
120
  random_tuples += dict_i.get((token, "cell_emb"),[])
121
+
122
+ if alt_end_state_exists == False:
123
+ goal_end_random_megalist = [goal_end for goal_end,start_state in random_tuples]
124
+ start_state_random_megalist = [start_state for goal_end,start_state in random_tuples]
125
+ elif alt_end_state_exists == True:
126
+ goal_end_random_megalist = [goal_end for goal_end,alt_end,start_state in random_tuples]
127
+ alt_end_random_megalist = [alt_end for goal_end,alt_end,start_state in random_tuples]
128
+ start_state_random_megalist = [start_state for goal_end,alt_end,start_state in random_tuples]
129
 
130
  # downsample to improve speed of ranksums
131
  if len(goal_end_random_megalist) > 100_000:
132
  random.seed(42)
133
  goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
 
 
 
134
  if len(start_state_random_megalist) > 100_000:
135
  random.seed(42)
136
  start_state_random_megalist = random.sample(start_state_random_megalist, k=100_000)
137
+ if alt_end_state_exists == True:
138
+ if len(alt_end_random_megalist) > 100_000:
139
+ random.seed(42)
140
+ alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000)
141
 
142
  names=["Gene",
143
  "Gene_name",
 
146
  "Shift_from_alt_end",
147
  "Goal_end_vs_random_pval",
148
  "Alt_end_vs_random_pval"]
149
+ if alt_end_state_exists == False:
150
+ names.remove("Shift_from_alt_end")
151
+ names.remove("Alt_end_vs_random_pval")
152
  cos_sims_full_df = pd.DataFrame(columns=names)
153
 
154
  for i in trange(cos_sims_df.shape[0]):
 
159
 
160
  for dict_i in dict_list:
161
  cos_shift_data += dict_i.get((token, "cell_emb"),[])
162
+
163
+ if alt_end_state_exists == False:
164
+ goal_end_cos_sim_megalist = [goal_end for goal_end,start_state in cos_shift_data]
165
+ elif alt_end_state_exists == True:
166
+ goal_end_cos_sim_megalist = [goal_end for goal_end,alt_end,start_state in cos_shift_data]
167
+ alt_end_cos_sim_megalist = [alt_end for goal_end,alt_end,start_state in cos_shift_data]
168
+ mean_alt_end = np.mean(alt_end_cos_sim_megalist)
169
+ pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
170
 
171
  mean_goal_end = np.mean(goal_end_cos_sim_megalist)
 
 
172
  pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue
 
 
 
 
 
 
 
 
 
173
 
174
+ if alt_end_state_exists == False:
175
+ data_i = [token,
176
+ name,
177
+ ensembl_id,
178
+ mean_goal_end,
179
+ pval_goal_end]
180
+ elif alt_end_state_exists == True:
181
+ data_i = [token,
182
+ name,
183
+ ensembl_id,
184
+ mean_goal_end,
185
+ mean_alt_end,
186
+ pval_goal_end,
187
+ pval_alt_end]
188
+
189
  cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
190
  cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
191
 
192
  cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
193
+ if alt_end_state_exists == True:
194
+ cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
195
 
196
  # quantify number of detections of each gene
197
  cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
 
400
  Cell states to model if testing perturbations that achieve goal state change.
401
  Single-item dictionary with key being cell attribute (e.g. "disease").
402
  Value is tuple of three lists indicating start state, goal end state, and alternate possible end states.
403
+ If no alternate possible end states, third list should be empty or have a single element that is None.
404
+ (i.e. the third list should be [] or [None]).
405
  token_dictionary_file : Path
406
  Path to pickle file containing token dictionary (Ensembl ID:token).
407
  gene_name_id_dictionary_file : Path
 
532
  index=[i for i in range(len(gene_list))])
533
 
534
  if self.mode == "goal_state_shift":
535
+ cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list, self.cell_states_to_model)
536
 
537
  elif self.mode == "vs_null":
538
  null_dict_list = read_dictionaries(null_dist_data_directory, "cell", self.anchor_token)