Update geneformer/perturber_utils.py
#361
by
hchen725
- opened
- geneformer/perturber_utils.py +44 -13
geneformer/perturber_utils.py
CHANGED
@@ -228,21 +228,41 @@ def overexpress_indices(example):
|
|
228 |
example["length"] = len(example["input_ids"])
|
229 |
return example
|
230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
233 |
-
def overexpress_tokens(example,
|
|
|
234 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
235 |
if example["perturb_index"] != [-100]:
|
236 |
example = delete_indices(example)
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
# truncate to max input size, must also truncate original emb to be comparable
|
243 |
-
if len(example["input_ids"]) >
|
244 |
-
|
245 |
-
|
|
|
|
|
246 |
example["length"] = len(example["input_ids"])
|
247 |
return example
|
248 |
|
@@ -259,6 +279,12 @@ def truncate_by_n_overflow(example):
|
|
259 |
example["length"] = len(example["input_ids"])
|
260 |
return example
|
261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
|
264 |
# indices_to_remove is list of indices to remove
|
@@ -321,7 +347,7 @@ def remove_perturbed_indices_set(
|
|
321 |
|
322 |
|
323 |
def make_perturbation_batch(
|
324 |
-
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
325 |
) -> tuple[Dataset, List[int]]:
|
326 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
327 |
if perturb_type in ["overexpress", "activate"]:
|
@@ -383,9 +409,14 @@ def make_perturbation_batch(
|
|
383 |
delete_indices, num_proc=num_proc_i
|
384 |
)
|
385 |
elif perturb_type == "overexpress":
|
386 |
-
|
387 |
-
|
388 |
-
|
|
|
|
|
|
|
|
|
|
|
389 |
|
390 |
perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
|
391 |
|
@@ -758,4 +789,4 @@ class GeneIdHandler:
|
|
758 |
return self.ens_to_symbol(self.token_to_ens(token))
|
759 |
|
760 |
def symbol_to_token(self, symbol):
|
761 |
-
return self.ens_to_token(self.symbol_to_ens(symbol))
|
|
|
228 |
example["length"] = len(example["input_ids"])
|
229 |
return example
|
230 |
|
231 |
+
# if CLS token present, move to 1st rather than 0th position
|
232 |
+
def overexpress_indices_special(example):
    """Move the tokens at ``example["perturb_index"]`` to directly after index 0.

    Counterpart of the plain overexpression reordering for encodings whose
    first entry must stay in place (e.g. a CLS token): the selected tokens are
    re-inserted starting at position 1 instead of position 0.

    Args:
        example: mapping with ``"input_ids"`` (a mutable list, modified in
            place) and ``"perturb_index"`` (positions into ``input_ids``;
            nested lists are flattened with ``flatten_list``).

    Returns:
        The same ``example`` with ``"input_ids"`` reordered and ``"length"``
        refreshed. The moved tokens keep their relative order.
    """
    indices = example["perturb_index"]
    if any(isinstance(el, list) for el in indices):
        indices = flatten_list(indices)

    # Walk the positions in ascending order and advance the insertion slot.
    # Popping at `index` and re-inserting at `insert_pos <= index` leaves every
    # position above `index` untouched, so the remaining indices stay valid.
    # (Popping in descending order while inserting at a fixed slot shifts the
    # not-yet-processed lower positions by one and moves the wrong tokens.)
    insert_pos = 1
    for index in sorted(indices):
        example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
        insert_pos += 1

    example["length"] = len(example["input_ids"])
    return example
|
241 |
|
242 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
243 |
+
def overexpress_tokens(example, special_token):
    """Insert ``example["tokens_to_perturb"]`` at the front of ``input_ids``.

    Used when the genes to overexpress are given as tokens and are not
    necessarily present in the cell's rank value encoding.

    Args:
        example: mapping with ``"input_ids"`` (mutable list), ``"length"``
            (length before perturbation), ``"perturb_index"`` and
            ``"tokens_to_perturb"``.
        special_token: truthy when the first and last entries of
            ``input_ids`` must stay in place (e.g. CLS at the front); the
            tokens are then inserted at position 1 and truncation preserves
            the final entry.

    Returns:
        ``example`` with the tokens inserted, ``input_ids`` cut back to the
        original length if it grew, and ``"length"`` refreshed.
    """
    original_len = example["length"]
    # -100 indicates tokens to overexpress are not present in rank value encoding
    if example["perturb_index"] != [-100]:
        example = delete_indices(example)

    if special_token:
        # Slice insertion keeps the tokens in their given order after index 0.
        example["input_ids"][1:1] = example["tokens_to_perturb"]
    else:
        # NOTE(review): this call runs after delete_indices and also when
        # perturb_index is the [-100] sentinel; it looks like a leftover and
        # probably should be removed -- confirm against overexpress_indices
        # and delete_indices before changing it.
        example = overexpress_indices(example)
        example["input_ids"][0:0] = example["tokens_to_perturb"]

    # truncate to max input size, must also truncate original emb to be comparable
    if len(example["input_ids"]) > original_len:
        if special_token:
            # Drop every overflowing entry just before the final one so the
            # last token survives; deleting a single element is only enough
            # when exactly one token overflows.
            del example["input_ids"][original_len - 1 : -1]
        else:
            example["input_ids"] = example["input_ids"][0:original_len]
    example["length"] = len(example["input_ids"])
    return example
|
268 |
|
|
|
279 |
example["length"] = len(example["input_ids"])
|
280 |
return example
|
281 |
|
282 |
+
def truncate_by_n_overflow_special(example):
    """Shorten ``input_ids`` by ``n_overflow`` entries while keeping the last one.

    The target length is ``example["length"] - example["n_overflow"]``. The
    result is the leading ``target - 1`` entries followed by the original
    final entry, and ``"length"`` is updated to match.
    """
    token_ids = example["input_ids"]
    target_len = example["length"] - example["n_overflow"]

    # Leading part, leaving one slot for the preserved final entry.
    head = token_ids[0 : target_len - 1]
    tail = [token_ids[-1]]

    example["input_ids"] = head + tail
    example["length"] = len(example["input_ids"])
    return example
|
287 |
+
|
288 |
|
289 |
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
|
290 |
# indices_to_remove is list of indices to remove
|
|
|
347 |
|
348 |
|
349 |
def make_perturbation_batch(
|
350 |
+
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc, special_token
|
351 |
) -> tuple[Dataset, List[int]]:
|
352 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
353 |
if perturb_type in ["overexpress", "activate"]:
|
|
|
409 |
delete_indices, num_proc=num_proc_i
|
410 |
)
|
411 |
elif perturb_type == "overexpress":
|
412 |
+
if special_token:
|
413 |
+
perturbation_dataset = perturbation_dataset.map(
|
414 |
+
overexpress_indices_special, num_proc=num_proc_i
|
415 |
+
)
|
416 |
+
else:
|
417 |
+
perturbation_dataset = perturbation_dataset.map(
|
418 |
+
overexpress_indices, num_proc=num_proc_i
|
419 |
+
)
|
420 |
|
421 |
perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
|
422 |
|
|
|
789 |
return self.ens_to_symbol(self.token_to_ens(token))
|
790 |
|
791 |
def symbol_to_token(self, symbol):
    """Convert ``symbol`` to a token by chaining ``symbol_to_ens`` and ``ens_to_token``."""
    ens_id = self.symbol_to_ens(symbol)
    return self.ens_to_token(ens_id)
|