Christina Theodoris
commited on
Commit
·
feeecd0
1
Parent(s):
dc1481d
Add function to create remainder emb for in silico overexpression batch
Browse files
geneformer/in_silico_perturber.py
CHANGED
@@ -140,6 +140,18 @@ def make_comparison_batch(original_emb, indices_to_perturb):
|
|
140 |
all_embs_list += [torch.cat(emb_list)]
|
141 |
return torch.stack(all_embs_list)
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
# average embedding position of goal cell states
|
144 |
def get_cell_state_avg_embs(model,
|
145 |
filtered_input_data,
|
@@ -188,6 +200,7 @@ def get_cell_state_avg_embs(model,
|
|
188 |
|
189 |
# quantify cosine similarity of perturbed vs original or alternate states
|
190 |
def quant_cos_sims(model,
|
|
|
191 |
perturbation_batch,
|
192 |
forward_batch_size,
|
193 |
layer_to_quant,
|
@@ -226,8 +239,14 @@ def quant_cos_sims(model,
|
|
226 |
minibatch_emb = outputs.hidden_states[layer_to_quant]
|
227 |
if cell_states_to_model is None:
|
228 |
minibatch_comparison = comparison_batch[i:max_range]
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
|
230 |
-
|
231 |
for state in possible_states:
|
232 |
cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb, minibatch_emb, state_embs_dict[state])
|
233 |
del outputs
|
@@ -279,9 +298,9 @@ def pad_tensor_list(tensor_list, dynamic_or_constant, token_dictionary):
|
|
279 |
class InSilicoPerturber:
|
280 |
valid_option_dict = {
|
281 |
"perturb_type": {"delete","overexpress","inhibit","activate"},
|
282 |
-
"perturb_rank_shift": {None,
|
283 |
"genes_to_perturb": {"all", list},
|
284 |
-
"combos": {0,1,2},
|
285 |
"anchor_gene": {None, str},
|
286 |
"model_type": {"Pretrained","GeneClassifier","CellClassifier"},
|
287 |
"num_classes": {int},
|
@@ -326,7 +345,7 @@ class InSilicoPerturber:
|
|
326 |
"overexpress": move gene to front of rank value encoding
|
327 |
"inhibit": move gene to lower quartile of rank value encoding
|
328 |
"activate": move gene to higher quartile of rank value encoding
|
329 |
-
perturb_rank_shift : None,
|
330 |
Number of quartiles by which to shift rank of gene.
|
331 |
For example, if perturb_type="activate" and perturb_rank_shift=1:
|
332 |
genes in 4th quartile will move to middle of 3rd quartile.
|
@@ -414,6 +433,15 @@ class InSilicoPerturber:
|
|
414 |
self.tokens_to_perturb = [self.gene_token_dict[gene] for gene in self.genes_to_perturb]
|
415 |
|
416 |
def validate_options(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
417 |
for attr_name,valid_options in self.valid_option_dict.items():
|
418 |
attr_value = self.__dict__[attr_name]
|
419 |
if type(attr_value) not in {list, dict}:
|
@@ -442,7 +470,7 @@ class InSilicoPerturber:
|
|
442 |
elif self.perturb_type == "overexpress":
|
443 |
logger.warning(
|
444 |
"perturb_rank_shift set to None. " \
|
445 |
-
"If perturb type is
|
446 |
"of rank value encoding rather than shifted by quartile")
|
447 |
self.perturb_rank_shift = None
|
448 |
|
@@ -626,13 +654,14 @@ class InSilicoPerturber:
|
|
626 |
combo_lvl,
|
627 |
self.nproc)
|
628 |
cos_sims_data = quant_cos_sims(model,
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
|
|
636 |
|
637 |
if self.cell_states_to_model is None:
|
638 |
# update cos sims dict
|
@@ -699,6 +728,7 @@ class InSilicoPerturber:
|
|
699 |
0,
|
700 |
self.nproc)
|
701 |
cos_sims_data = quant_cos_sims(model,
|
|
|
702 |
perturbation_batch,
|
703 |
self.forward_batch_size,
|
704 |
layer_to_quant,
|
@@ -715,6 +745,7 @@ class InSilicoPerturber:
|
|
715 |
1,
|
716 |
self.nproc)
|
717 |
combo_cos_sims_data = quant_cos_sims(model,
|
|
|
718 |
combo_perturbation_batch,
|
719 |
self.forward_batch_size,
|
720 |
layer_to_quant,
|
|
|
140 |
all_embs_list += [torch.cat(emb_list)]
|
141 |
return torch.stack(all_embs_list)
|
142 |
|
143 |
+
# perturbed cell emb removing the activated/overexpressed/inhibited gene emb
|
144 |
+
# so that only non-perturbed gene embeddings are compared to each other
|
145 |
+
# in original or perturbed context
|
146 |
+
def make_perturbed_remainder_batch(emb_batch, indices_to_remove):
|
147 |
+
if type(indices_to_remove) == int:
|
148 |
+
indices_to_keep = [i for i in range(emb_batch.size()[1])]
|
149 |
+
indices_to_keep.pop(indices_to_remove)
|
150 |
+
perturbed_remainder_batch = torch.stack([emb[indices_to_keep,:] for emb in emb_batch])
|
151 |
+
elif type(indices_to_remove) == list:
|
152 |
+
perturbed_remainder_batch = torch.stack([make_comparison_batch(emb_batch[i],indices_to_remove[i]) for i in range(len(emb_batch))])
|
153 |
+
return perturbed_remainder_batch
|
154 |
+
|
155 |
# average embedding position of goal cell states
|
156 |
def get_cell_state_avg_embs(model,
|
157 |
filtered_input_data,
|
|
|
200 |
|
201 |
# quantify cosine similarity of perturbed vs original or alternate states
|
202 |
def quant_cos_sims(model,
|
203 |
+
perturb_type,
|
204 |
perturbation_batch,
|
205 |
forward_batch_size,
|
206 |
layer_to_quant,
|
|
|
239 |
minibatch_emb = outputs.hidden_states[layer_to_quant]
|
240 |
if cell_states_to_model is None:
|
241 |
minibatch_comparison = comparison_batch[i:max_range]
|
242 |
+
if perturb_type == "overexpress":
|
243 |
+
index_to_remove = 0
|
244 |
+
minibatch_emb = make_perturbed_remainder_batch(minibatch_emb, index_to_remove)
|
245 |
+
# elif (perturb_type == "inhibit") or (perturb_type == "activate"):
|
246 |
+
# index_to_remove = placeholder
|
247 |
+
# minibatch_emb = make_perturbed_remainder_batch(minibatch_emb, index_to_remove)
|
248 |
cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
|
249 |
+
elif cell_states_to_model is not None:
|
250 |
for state in possible_states:
|
251 |
cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb, minibatch_emb, state_embs_dict[state])
|
252 |
del outputs
|
|
|
298 |
class InSilicoPerturber:
|
299 |
valid_option_dict = {
|
300 |
"perturb_type": {"delete","overexpress","inhibit","activate"},
|
301 |
+
"perturb_rank_shift": {None, 1, 2, 3},
|
302 |
"genes_to_perturb": {"all", list},
|
303 |
+
"combos": {0, 1, 2},
|
304 |
"anchor_gene": {None, str},
|
305 |
"model_type": {"Pretrained","GeneClassifier","CellClassifier"},
|
306 |
"num_classes": {int},
|
|
|
345 |
"overexpress": move gene to front of rank value encoding
|
346 |
"inhibit": move gene to lower quartile of rank value encoding
|
347 |
"activate": move gene to higher quartile of rank value encoding
|
348 |
+
perturb_rank_shift : None, {1,2,3}
|
349 |
Number of quartiles by which to shift rank of gene.
|
350 |
For example, if perturb_type="activate" and perturb_rank_shift=1:
|
351 |
genes in 4th quartile will move to middle of 3rd quartile.
|
|
|
433 |
self.tokens_to_perturb = [self.gene_token_dict[gene] for gene in self.genes_to_perturb]
|
434 |
|
435 |
def validate_options(self):
|
436 |
+
# first disallow options under development
|
437 |
+
if self.perturb_type in ["inhibit", "activate"]:
|
438 |
+
logger.error(
|
439 |
+
f"In silico inhibition and activation currently under developemnt. " \
|
440 |
+
f"Current valid options for 'perturb_type': 'delete' or 'overexpress'"
|
441 |
+
)
|
442 |
+
raise
|
443 |
+
|
444 |
+
# confirm arguments are within valid options and compatible with each other
|
445 |
for attr_name,valid_options in self.valid_option_dict.items():
|
446 |
attr_value = self.__dict__[attr_name]
|
447 |
if type(attr_value) not in {list, dict}:
|
|
|
470 |
elif self.perturb_type == "overexpress":
|
471 |
logger.warning(
|
472 |
"perturb_rank_shift set to None. " \
|
473 |
+
"If perturb type is overexpress then gene is moved to front " \
|
474 |
"of rank value encoding rather than shifted by quartile")
|
475 |
self.perturb_rank_shift = None
|
476 |
|
|
|
654 |
combo_lvl,
|
655 |
self.nproc)
|
656 |
cos_sims_data = quant_cos_sims(model,
|
657 |
+
self.perturb_type,
|
658 |
+
perturbation_batch,
|
659 |
+
self.forward_batch_size,
|
660 |
+
layer_to_quant,
|
661 |
+
original_emb,
|
662 |
+
indices_to_perturb,
|
663 |
+
self.cell_states_to_model,
|
664 |
+
state_embs_dict)
|
665 |
|
666 |
if self.cell_states_to_model is None:
|
667 |
# update cos sims dict
|
|
|
728 |
0,
|
729 |
self.nproc)
|
730 |
cos_sims_data = quant_cos_sims(model,
|
731 |
+
self.perturb_type,
|
732 |
perturbation_batch,
|
733 |
self.forward_batch_size,
|
734 |
layer_to_quant,
|
|
|
745 |
1,
|
746 |
self.nproc)
|
747 |
combo_cos_sims_data = quant_cos_sims(model,
|
748 |
+
self.perturb_type,
|
749 |
combo_perturbation_batch,
|
750 |
self.forward_batch_size,
|
751 |
layer_to_quant,
|