ctheodoris
commited on
Commit
·
25dd1da
1
Parent(s):
eb038a6
update perturber stats to reflect cos sim and emb_extractor to suppress warnings for non-cls
Browse files
geneformer/emb_extractor.py
CHANGED
@@ -78,7 +78,7 @@ def get_embs(
|
|
78 |
gene_token_dict = {v:k for k,v in token_gene_dict.items()}
|
79 |
cls_token_id = gene_token_dict["<cls>"]
|
80 |
assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
|
81 |
-
|
82 |
if cls_present:
|
83 |
logger.warning("CLS token present in token dictionary, excluding from average.")
|
84 |
if eos_present:
|
|
|
78 |
gene_token_dict = {v:k for k,v in token_gene_dict.items()}
|
79 |
cls_token_id = gene_token_dict["<cls>"]
|
80 |
assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
|
81 |
+
elif emb_mode == "cell":
|
82 |
if cls_present:
|
83 |
logger.warning("CLS token present in token dictionary, excluding from average.")
|
84 |
if eos_present:
|
geneformer/in_silico_perturber_stats.py
CHANGED
@@ -193,9 +193,8 @@ def get_impact_component(test_value, gaussian_mixture_model):
|
|
193 |
|
194 |
# aggregate data for single perturbation in multiple cells
|
195 |
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
|
196 |
-
names = ["
|
197 |
cos_sims_full_dfs = []
|
198 |
-
|
199 |
if isinstance(genes_perturbed,list):
|
200 |
if len(genes_perturbed)>1:
|
201 |
gene_ids_df = cos_sims_df.loc[np.isin([set(idx) for idx in cos_sims_df["Ensembl_ID"]], set(genes_perturbed)), :]
|
@@ -222,7 +221,7 @@ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
|
|
222 |
cos_shift_data += dict_i.get((token, "cell_emb"), [])
|
223 |
|
224 |
df = pd.DataFrame(columns=names)
|
225 |
-
df["
|
226 |
df["Gene"] = symbol
|
227 |
cos_sims_full_dfs.append(df)
|
228 |
|
@@ -233,6 +232,8 @@ def find(variable, x):
|
|
233 |
try:
|
234 |
if x in variable: # Test if variable is iterable and contains x
|
235 |
return True
|
|
|
|
|
236 |
except (ValueError, TypeError):
|
237 |
return x == variable # Test if variable is x if non-iterable
|
238 |
|
@@ -273,15 +274,15 @@ def isp_aggregate_gene_shifts(
|
|
273 |
cos_sims_full_df["Affected_Ensembl_ID"] = [
|
274 |
gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
|
275 |
]
|
276 |
-
cos_sims_full_df["
|
277 |
-
cos_sims_full_df["
|
278 |
cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
|
279 |
|
280 |
specific_val = "cell_emb"
|
281 |
cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
|
282 |
-
# reorder so cell embs are at the top and all are subordered by magnitude of cosine
|
283 |
cos_sims_full_df = cos_sims_full_df.sort_values(
|
284 |
-
by=(["temp", "
|
285 |
).drop("temp", axis=1)
|
286 |
|
287 |
return cos_sims_full_df
|
@@ -939,11 +940,11 @@ class InSilicoPerturberStats:
|
|
939 |
| 1: within impact component; 0: not within impact component
|
940 |
| "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
|
941 |
|
942 |
-
| In case of aggregating gene shifts:
|
943 |
| "Perturbed": ID(s) of gene(s) being perturbed
|
944 |
| "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
|
945 |
-
| "
|
946 |
-
| "
|
947 |
"""
|
948 |
|
949 |
if self.mode not in [
|
|
|
193 |
|
194 |
# aggregate data for single perturbation in multiple cells
|
195 |
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
|
196 |
+
names = ["Cosine_sim", "Gene"]
|
197 |
cos_sims_full_dfs = []
|
|
|
198 |
if isinstance(genes_perturbed,list):
|
199 |
if len(genes_perturbed)>1:
|
200 |
gene_ids_df = cos_sims_df.loc[np.isin([set(idx) for idx in cos_sims_df["Ensembl_ID"]], set(genes_perturbed)), :]
|
|
|
221 |
cos_shift_data += dict_i.get((token, "cell_emb"), [])
|
222 |
|
223 |
df = pd.DataFrame(columns=names)
|
224 |
+
df["Cosine_sim"] = cos_shift_data
|
225 |
df["Gene"] = symbol
|
226 |
cos_sims_full_dfs.append(df)
|
227 |
|
|
|
232 |
try:
|
233 |
if x in variable: # Test if variable is iterable and contains x
|
234 |
return True
|
235 |
+
elif x == variable:
|
236 |
+
return True
|
237 |
except (ValueError, TypeError):
|
238 |
return x == variable # Test if variable is x if non-iterable
|
239 |
|
|
|
274 |
cos_sims_full_df["Affected_Ensembl_ID"] = [
|
275 |
gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
|
276 |
]
|
277 |
+
cos_sims_full_df["Cosine_sim_mean"] = [v[0] for k, v in cos_data_mean.items()]
|
278 |
+
cos_sims_full_df["Cosine_sim_stdev"] = [v[1] for k, v in cos_data_mean.items()]
|
279 |
cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
|
280 |
|
281 |
specific_val = "cell_emb"
|
282 |
cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
|
283 |
+
# reorder so cell embs are at the top and all are subordered by magnitude of cosine sim
|
284 |
cos_sims_full_df = cos_sims_full_df.sort_values(
|
285 |
+
by=(["temp", "Cosine_sim_mean"]), ascending=[False, True]
|
286 |
).drop("temp", axis=1)
|
287 |
|
288 |
return cos_sims_full_df
|
|
|
940 |
| 1: within impact component; 0: not within impact component
|
941 |
| "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
|
942 |
|
943 |
+
| In case of aggregating data / gene shifts:
|
944 |
| "Perturbed": ID(s) of gene(s) being perturbed
|
945 |
| "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
|
946 |
+
| "Cosine_sim_mean": mean of cosine similarity of cell or affected gene in original vs. perturbed
|
947 |
+
| "Cosine_sim_stdev": standard deviation of cosine similarity of cell or affected gene in original vs. perturbed
|
948 |
"""
|
949 |
|
950 |
if self.mode not in [
|