Christina Theodoris
committed on
Commit
·
268e566
1
Parent(s):
57b9778
Fix min_genes to be >= tokens to perturb as a group
Browse files
geneformer/in_silico_perturber.py
CHANGED
@@ -58,6 +58,16 @@ def measure_length(example):
|
|
58 |
example["length"] = len(example["input_ids"])
|
59 |
return example
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
def forward_pass_single_cell(model, example_cell, layer_to_quant):
|
62 |
example_cell.set_format(type="torch")
|
63 |
input_data = example_cell["input_ids"]
|
@@ -75,8 +85,8 @@ def perturb_emb_by_index(emb, indices):
|
|
75 |
return emb[mask]
|
76 |
|
77 |
def delete_indices(example):
|
78 |
-
indices = example["perturb_index"]
|
79 |
-
if
|
80 |
indices = flatten_list(indices)
|
81 |
for index in sorted(indices, reverse=True):
|
82 |
del example["input_ids"][index]
|
@@ -84,10 +94,10 @@ def delete_indices(example):
|
|
84 |
|
85 |
# for genes_to_perturb = "all" where only genes within cell are overexpressed
|
86 |
def overexpress_indices(example):
|
87 |
-
|
88 |
-
if
|
89 |
-
|
90 |
-
for index in sorted(
|
91 |
example["input_ids"].insert(0, example["input_ids"].pop(index))
|
92 |
return example
|
93 |
|
@@ -165,7 +175,7 @@ def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group)
|
|
165 |
continue
|
166 |
emb_list = []
|
167 |
start = 0
|
168 |
-
if
|
169 |
indices = flatten_list(indices)
|
170 |
for i in sorted(indices):
|
171 |
emb_list += [original_emb[start:i]]
|
@@ -724,8 +734,9 @@ class InSilicoPerturber:
|
|
724 |
state_embs_dict = None
|
725 |
else:
|
726 |
# get dictionary of average cell state embeddings for comparison
|
|
|
727 |
state_embs_dict = get_cell_state_avg_embs(model,
|
728 |
-
|
729 |
self.cell_states_to_model,
|
730 |
layer_to_quant,
|
731 |
self.pad_token_id,
|
@@ -758,14 +769,7 @@ class InSilicoPerturber:
|
|
758 |
"No cells remain after filtering. Check filtering criteria.")
|
759 |
raise
|
760 |
data_shuffled = data.shuffle(seed=42)
|
761 |
-
|
762 |
-
# if max number of cells is defined, then subsample to this max number
|
763 |
-
if self.max_ncells != None:
|
764 |
-
num_cells = min(self.max_ncells,num_cells)
|
765 |
-
data_subset = data_shuffled.select([i for i in range(num_cells)])
|
766 |
-
# sort dataset with largest cell first to encounter any memory errors earlier
|
767 |
-
data_sorted = data_subset.sort("length",reverse=True)
|
768 |
-
return data_sorted
|
769 |
|
770 |
# load model to GPU
|
771 |
def load_model(self, model_directory):
|
@@ -804,17 +808,29 @@ class InSilicoPerturber:
|
|
804 |
if self.anchor_token is not None:
|
805 |
def if_has_tokens_to_perturb(example):
|
806 |
return (len(set(example["input_ids"]).intersection(self.anchor_token))==len(self.anchor_token))
|
807 |
-
filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
|
808 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
809 |
if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
|
810 |
# minimum # genes needed for perturbation test
|
811 |
min_genes = len(self.tokens_to_perturb)
|
|
|
812 |
def if_has_tokens_to_perturb(example):
|
813 |
-
return (len(set(example["input_ids"]).intersection(self.tokens_to_perturb))
|
814 |
filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
|
815 |
-
|
|
|
|
|
|
|
|
|
816 |
cos_sims_dict = defaultdict(list)
|
817 |
pickle_batch = -1
|
|
|
818 |
|
819 |
# make perturbation batch w/ single perturbation in multiple cells
|
820 |
if self.perturb_group == True:
|
|
|
58 |
example["length"] = len(example["input_ids"])
|
59 |
return example
|
60 |
|
61 |
+
def downsample_and_sort(data_shuffled, max_ncells):
|
62 |
+
num_cells = len(data_shuffled)
|
63 |
+
# if max number of cells is defined, then subsample to this max number
|
64 |
+
if max_ncells != None:
|
65 |
+
num_cells = min(max_ncells,num_cells)
|
66 |
+
data_subset = data_shuffled.select([i for i in range(num_cells)])
|
67 |
+
# sort dataset with largest cell first to encounter any memory errors earlier
|
68 |
+
data_sorted = data_subset.sort("length",reverse=True)
|
69 |
+
return data_sorted
|
70 |
+
|
71 |
def forward_pass_single_cell(model, example_cell, layer_to_quant):
|
72 |
example_cell.set_format(type="torch")
|
73 |
input_data = example_cell["input_ids"]
|
|
|
85 |
return emb[mask]
|
86 |
|
87 |
def delete_indices(example):
|
88 |
+
indices = example["perturb_index"]
|
89 |
+
if any(isinstance(el, list) for el in indices):
|
90 |
indices = flatten_list(indices)
|
91 |
for index in sorted(indices, reverse=True):
|
92 |
del example["input_ids"][index]
|
|
|
94 |
|
95 |
# for genes_to_perturb = "all" where only genes within cell are overexpressed
|
96 |
def overexpress_indices(example):
|
97 |
+
indices = example["perturb_index"]
|
98 |
+
if any(isinstance(el, list) for el in indices):
|
99 |
+
indices = flatten_list(indices)
|
100 |
+
for index in sorted(indices, reverse=True):
|
101 |
example["input_ids"].insert(0, example["input_ids"].pop(index))
|
102 |
return example
|
103 |
|
|
|
175 |
continue
|
176 |
emb_list = []
|
177 |
start = 0
|
178 |
+
if any(isinstance(el, list) for el in indices):
|
179 |
indices = flatten_list(indices)
|
180 |
for i in sorted(indices):
|
181 |
emb_list += [original_emb[start:i]]
|
|
|
734 |
state_embs_dict = None
|
735 |
else:
|
736 |
# get dictionary of average cell state embeddings for comparison
|
737 |
+
downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
738 |
state_embs_dict = get_cell_state_avg_embs(model,
|
739 |
+
downsampled_data,
|
740 |
self.cell_states_to_model,
|
741 |
layer_to_quant,
|
742 |
self.pad_token_id,
|
|
|
769 |
"No cells remain after filtering. Check filtering criteria.")
|
770 |
raise
|
771 |
data_shuffled = data.shuffle(seed=42)
|
772 |
+
return data_shuffled
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
773 |
|
774 |
# load model to GPU
|
775 |
def load_model(self, model_directory):
|
|
|
808 |
if self.anchor_token is not None:
|
809 |
def if_has_tokens_to_perturb(example):
|
810 |
return (len(set(example["input_ids"]).intersection(self.anchor_token))==len(self.anchor_token))
|
811 |
+
filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
|
812 |
+
if len(filtered_input_data) == 0:
|
813 |
+
logger.error(
|
814 |
+
"No cells in dataset contain anchor gene.")
|
815 |
+
raise
|
816 |
+
else:
|
817 |
+
logger.info(f"# cells with anchor gene: {len(filtered_input_data)}")
|
818 |
+
|
819 |
if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
|
820 |
# minimum # genes needed for perturbation test
|
821 |
min_genes = len(self.tokens_to_perturb)
|
822 |
+
|
823 |
def if_has_tokens_to_perturb(example):
|
824 |
+
return (len(set(example["input_ids"]).intersection(self.tokens_to_perturb))>=min_genes)
|
825 |
filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
|
826 |
+
if len(filtered_input_data) == 0:
|
827 |
+
logger.error(
|
828 |
+
"No cells in dataset contain all genes to perturb as a group.")
|
829 |
+
raise
|
830 |
+
|
831 |
cos_sims_dict = defaultdict(list)
|
832 |
pickle_batch = -1
|
833 |
+
filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
834 |
|
835 |
# make perturbation batch w/ single perturbation in multiple cells
|
836 |
if self.perturb_group == True:
|