ctheodoris
commited on
update isp for cls perturb set
Browse files
geneformer/perturber_utils.py
CHANGED
@@ -620,9 +620,10 @@ def quant_cos_sims(
|
|
620 |
cos = torch.nn.CosineSimilarity(dim=1)
|
621 |
|
622 |
# if emb_mode == "gene", can only calculate gene cos sims
|
623 |
-
# against original cell
|
624 |
if cell_states_to_model is None or emb_mode == "gene":
|
625 |
cos_sims = cos(perturbation_emb, original_emb).to("cuda")
|
|
|
626 |
elif cell_states_to_model is not None and emb_mode == "cell":
|
627 |
possible_states = get_possible_states(cell_states_to_model)
|
628 |
cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
|
|
|
620 |
cos = torch.nn.CosineSimilarity(dim=1)
|
621 |
|
622 |
# if emb_mode == "gene", can only calculate gene cos sims
|
623 |
+
# against original cell
|
624 |
if cell_states_to_model is None or emb_mode == "gene":
|
625 |
cos_sims = cos(perturbation_emb, original_emb).to("cuda")
|
626 |
+
|
627 |
elif cell_states_to_model is not None and emb_mode == "cell":
|
628 |
possible_states = get_possible_states(cell_states_to_model)
|
629 |
cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
|