ctheodoris commited on
Commit
eb038a6
·
1 Parent(s): b2bbd7c

update to account for set of perturbed genes with aggregate_data

Browse files
geneformer/in_silico_perturber_stats.py CHANGED
@@ -196,8 +196,23 @@ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
196
  names = ["Cosine_shift", "Gene"]
197
  cos_sims_full_dfs = []
198
 
199
-
200
- gene_ids_df = cos_sims_df.loc[np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  tokens = gene_ids_df["Gene"]
202
  symbols = gene_ids_df["Gene_name"]
203
 
@@ -210,7 +225,6 @@ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
210
  df["Cosine_shift"] = cos_shift_data
211
  df["Gene"] = symbol
212
  cos_sims_full_dfs.append(df)
213
-
214
 
215
  return pd.concat(cos_sims_full_dfs)
216
 
 
196
  names = ["Cosine_shift", "Gene"]
197
  cos_sims_full_dfs = []
198
 
199
+ if isinstance(genes_perturbed,list):
200
+ if len(genes_perturbed)>1:
201
+ gene_ids_df = cos_sims_df.loc[np.isin([set(idx) for idx in cos_sims_df["Ensembl_ID"]], set(genes_perturbed)), :]
202
+ else:
203
+ gene_ids_df = cos_sims_df.loc[np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :]
204
+ else:
205
+ logger.error(
206
+ "aggregate_data is for perturbation of single gene or single group of genes. genes_to_perturb should be formatted as list."
207
+ )
208
+ raise
209
+
210
+ if gene_ids_df.empty:
211
+ logger.error(
212
+ "genes_to_perturb not found in data."
213
+ )
214
+ raise
215
+
216
  tokens = gene_ids_df["Gene"]
217
  symbols = gene_ids_df["Gene_name"]
218
 
 
225
  df["Cosine_shift"] = cos_shift_data
226
  df["Gene"] = symbol
227
  cos_sims_full_dfs.append(df)
 
228
 
229
  return pd.concat(cos_sims_full_dfs)
230