Update geneformer/perturber_utils.py

#361
by hchen725 - opened
Files changed (1) hide show
  1. geneformer/perturber_utils.py +44 -13
geneformer/perturber_utils.py CHANGED
@@ -228,21 +228,41 @@ def overexpress_indices(example):
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
233
- def overexpress_tokens(example, max_len):
 
234
  # -100 indicates tokens to overexpress are not present in rank value encoding
235
  if example["perturb_index"] != [-100]:
236
  example = delete_indices(example)
237
- [
238
- example["input_ids"].insert(0, token)
239
- for token in example["tokens_to_perturb"][::-1]
240
- ]
 
 
 
 
 
 
 
241
 
242
  # truncate to max input size, must also truncate original emb to be comparable
243
- if len(example["input_ids"]) > max_len:
244
- example["input_ids"] = example["input_ids"][0:max_len]
245
-
 
 
246
  example["length"] = len(example["input_ids"])
247
  return example
248
 
@@ -259,6 +279,12 @@ def truncate_by_n_overflow(example):
259
  example["length"] = len(example["input_ids"])
260
  return example
261
 
 
 
 
 
 
 
262
 
263
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
264
  # indices_to_remove is list of indices to remove
@@ -321,7 +347,7 @@ def remove_perturbed_indices_set(
321
 
322
 
323
  def make_perturbation_batch(
324
- example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
325
  ) -> tuple[Dataset, List[int]]:
326
  if combo_lvl == 0 and tokens_to_perturb == "all":
327
  if perturb_type in ["overexpress", "activate"]:
@@ -383,9 +409,14 @@ def make_perturbation_batch(
383
  delete_indices, num_proc=num_proc_i
384
  )
385
  elif perturb_type == "overexpress":
386
- perturbation_dataset = perturbation_dataset.map(
387
- overexpress_indices, num_proc=num_proc_i
388
- )
 
 
 
 
 
389
 
390
  perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
391
 
@@ -758,4 +789,4 @@ class GeneIdHandler:
758
  return self.ens_to_symbol(self.token_to_ens(token))
759
 
760
  def symbol_to_token(self, symbol):
761
- return self.ens_to_token(self.symbol_to_ens(symbol))
 
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
231
# if CLS token present, move to 1st rather than 0th position
def overexpress_indices_special(example):
    """Move the tokens at ``example["perturb_index"]`` to directly after position 0.

    Position 0 is left untouched (per the comment above, it holds the CLS
    token); the selected tokens end up at positions 1, 2, ... in ascending
    order of their original index, and every other token keeps its relative
    order.

    Args:
        example: mapping with at least ``"input_ids"`` (a list, modified in
            place) and ``"perturb_index"`` (a list of indices, or a list
            containing nested lists of indices, into ``input_ids``).
            Indices are assumed to be distinct and >= 1 - TODO confirm
            against the caller.

    Returns:
        The same ``example`` with ``"input_ids"`` reordered and ``"length"``
        set to ``len(example["input_ids"])``.
    """
    indices = example["perturb_index"]
    if any(isinstance(el, list) for el in indices):
        # flatten_list is defined elsewhere in this module.
        indices = flatten_list(indices)

    input_ids = example["input_ids"]
    # Process indices in ascending order and advance the insertion point.
    # Popping at `index` and re-inserting at `insert_pos <= index` leaves every
    # position greater than `index` unchanged, so the remaining (larger)
    # indices still point at the intended tokens.  Popping in descending order
    # while inserting at a fixed position would shift the not-yet-processed
    # lower indices by one and move the wrong tokens.
    for insert_pos, index in enumerate(sorted(indices), start=1):
        input_ids.insert(insert_pos, input_ids.pop(index))

    example["length"] = len(input_ids)
    return example
241
 
242
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
243
def overexpress_tokens(example, special_token):
    """Place ``example["tokens_to_perturb"]`` at the front of ``input_ids``.

    If the tokens are already present in the encoding (``perturb_index`` is
    not the ``[-100]`` sentinel) they are first removed via ``delete_indices``
    (defined elsewhere in this module), so each token ends up in the
    sequence once.  The result is then cut back to the length recorded in
    ``example["length"]`` on entry.

    Args:
        example: mapping with ``"input_ids"`` (list, modified in place),
            ``"perturb_index"``, ``"tokens_to_perturb"`` and ``"length"``.
            ``"length"`` is presumably the pre-perturbation length of
            ``input_ids`` - TODO confirm against the caller.
        special_token: truthy when position 0 must be preserved (the tokens
            are inserted starting at position 1) and the final token must
            survive truncation; falsy to insert at position 0 and truncate
            from the end.

    Returns:
        The same ``example`` with updated ``"input_ids"`` and ``"length"``.
    """
    original_len = example["length"]
    # -100 indicates tokens to overexpress are not present in rank value encoding
    if example["perturb_index"] != [-100]:
        example = delete_indices(example)

    input_ids = example["input_ids"]
    insert_at = 1 if special_token else 0
    # Slice assignment inserts the tokens in their given order at `insert_at`,
    # which is what inserting them one by one in reverse order produced.
    # NOTE: no additional index-based move is performed here; the perturbed
    # positions were already deleted above, and with the [-100] sentinel an
    # index-based move would act on an unrelated token.
    input_ids[insert_at:insert_at] = example["tokens_to_perturb"]

    # truncate to max input size, must also truncate original emb to be comparable
    if len(input_ids) > original_len:
        if special_token:
            # Keep the final token and drop however many tokens overflowed
            # from just before it (not only a single one).
            input_ids = input_ids[: original_len - 1] + [input_ids[-1]]
        else:
            input_ids = input_ids[:original_len]
        example["input_ids"] = input_ids

    example["length"] = len(example["input_ids"])
    return example
268
 
 
279
  example["length"] = len(example["input_ids"])
280
  return example
281
 
282
def truncate_by_n_overflow_special(example):
    """Shorten ``input_ids`` by ``n_overflow`` tokens while keeping the last token.

    The target length is ``example["length"] - example["n_overflow"]``.  The
    result consists of the first ``target - 1`` tokens followed by the final
    token of the original sequence (presumably a trailing special token -
    TODO confirm).

    Args:
        example: mapping with ``"input_ids"`` (list), ``"length"`` and
            ``"n_overflow"`` (ints).

    Returns:
        The same ``example`` with ``"input_ids"`` replaced by the truncated
        list and ``"length"`` set to its new length.
    """
    input_ids = example["input_ids"]
    new_max_len = example["length"] - example["n_overflow"]
    # Clamp at 0: a negative slice stop would wrap around and keep most of the
    # sequence instead of truncating it when n_overflow >= length.  In that
    # case only the final token is kept.
    n_leading = max(new_max_len - 1, 0)
    # input_ids[-1:] equals [input_ids[-1]] for non-empty input and is simply
    # empty (rather than raising IndexError) when there are no tokens.
    example["input_ids"] = input_ids[:n_leading] + input_ids[-1:]
    example["length"] = len(example["input_ids"])
    return example
287
+
288
 
289
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
290
  # indices_to_remove is list of indices to remove
 
347
 
348
 
349
  def make_perturbation_batch(
350
+ example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc, special_token
351
  ) -> tuple[Dataset, List[int]]:
352
  if combo_lvl == 0 and tokens_to_perturb == "all":
353
  if perturb_type in ["overexpress", "activate"]:
 
409
  delete_indices, num_proc=num_proc_i
410
  )
411
  elif perturb_type == "overexpress":
412
+ if special_token:
413
+ perturbation_dataset = perturbation_dataset.map(
414
+ overexpress_indices_special, num_proc=num_proc_i
415
+ )
416
+ else:
417
+ perturbation_dataset = perturbation_dataset.map(
418
+ overexpress_indices, num_proc=num_proc_i
419
+ )
420
 
421
  perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
422
 
 
789
  return self.ens_to_symbol(self.token_to_ens(token))
790
 
791
  def symbol_to_token(self, symbol):
792
+ return self.ens_to_token(self.symbol_to_ens(symbol))