Upload in_silico_perturber_stats.py

#313
by davidjwen - opened
geneformer/in_silico_perturber_stats.py CHANGED
@@ -192,16 +192,27 @@ def get_impact_component(test_value, gaussian_mixture_model):
192
 
193
 
194
  # aggregate data for single perturbation in multiple cells
195
- def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
196
- names = ["Cosine_shift"]
197
- cos_sims_full_df = pd.DataFrame(columns=names)
198
 
199
- cos_shift_data = []
200
- token = cos_sims_df["Gene"][0]
201
- for dict_i in dict_list:
202
- cos_shift_data += dict_i.get((token, "cell_emb"), [])
203
- cos_sims_full_df["Cosine_shift"] = cos_shift_data
204
- return cos_sims_full_df
 
 
 
 
 
 
 
 
 
 
 
205
 
206
 
207
  def find(variable, x):
@@ -1017,8 +1028,8 @@ class InSilicoPerturberStats:
1017
  cos_sims_df_initial, dict_list, self.combos, self.anchor_token
1018
  )
1019
 
1020
- elif self.mode == "aggregate_data":
1021
- cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
1022
 
1023
  elif self.mode == "aggregate_gene_shifts":
1024
  cos_sims_df = isp_aggregate_gene_shifts(
 
192
 
193
 
194
  # aggregate data for single perturbation in multiple cells
195
+ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
196
+ names = ["Cosine_shift", "Gene"]
197
+ cos_sims_full_dfs = []
198
 
199
+
200
+ gene_ids_df = cos_sims_df.loc[np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :]
201
+ tokens = gene_ids_df["Gene"]
202
+ symbols = gene_ids_df["Gene_name"]
203
+
204
+ for token, symbol in zip(tokens, symbols):
205
+ cos_shift_data = []
206
+ for dict_i in dict_list:
207
+ cos_shift_data += dict_i.get((token, "cell_emb"), [])
208
+
209
+ df = pd.DataFrame(columns=names)
210
+ df["Cosine_shift"] = cos_shift_data
211
+ df["Gene"] = symbol
212
+ cos_sims_full_dfs.append(df)
213
+
214
+
215
+ return pd.concat(cos_sims_full_dfs)
216
 
217
 
218
  def find(variable, x):
 
1028
  cos_sims_df_initial, dict_list, self.combos, self.anchor_token
1029
  )
1030
 
1031
+ elif self.mode == "aggregate_data":
1032
+ cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list, self.genes_perturbed)
1033
 
1034
  elif self.mode == "aggregate_gene_shifts":
1035
  cos_sims_df = isp_aggregate_gene_shifts(