Christina Theodoris commited on
Commit
c2679c4
·
1 Parent(s): 771c8bd

Fix isp perturb_group dims, reformat cell states dict to keyed, add attn mask

Browse files
examples/extract_and_plot_cell_embeddings.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
examples/in_silico_perturbation.ipynb CHANGED
@@ -33,7 +33,10 @@
33
  " emb_mode=\"cell\",\n",
34
  " cell_emb_style=\"mean_pool\",\n",
35
  " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
36
- " cell_states_to_model={\"disease\":([\"dcm\"],[\"nf\"],[\"hcm\"])},\n",
 
 
 
37
  " max_ncells=2000,\n",
38
  " emb_layer=0,\n",
39
  " forward_batch_size=400,\n",
 
33
  " emb_mode=\"cell\",\n",
34
  " cell_emb_style=\"mean_pool\",\n",
35
  " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
36
+ " cell_states_to_model={'state_key': 'disease', \n",
37
+ " 'start_state': 'dcm', \n",
38
+ " 'goal_state': 'nf', \n",
39
+ " 'alt_states': ['hcm']},\n",
40
  " max_ncells=2000,\n",
41
  " emb_layer=0,\n",
42
  " forward_batch_size=400,\n",
geneformer/emb_extractor.py CHANGED
@@ -43,32 +43,17 @@ from transformers import BertForMaskedLM, BertForTokenClassification, BertForSeq
43
 
44
  from .tokenizer import TOKEN_DICTIONARY_FILE
45
 
46
- from .in_silico_perturber import load_and_filter, \
47
- downsample_and_sort, \
 
 
48
  load_model, \
49
- quant_layers, \
50
- downsample_and_sort, \
51
  pad_tensor_list, \
52
- get_model_input_size
53
-
54
 
55
  logger = logging.getLogger(__name__)
56
 
57
- # get cell embeddings excluding padding
58
- def mean_nonpadding_embs(embs, original_lens):
59
- # mask based on padding lengths
60
- mask = torch.arange(embs.size(1)).unsqueeze(0).to("cuda") < original_lens.unsqueeze(1)
61
-
62
- # extend mask dimensions to match the embeddings tensor
63
- mask = mask.unsqueeze(2).expand_as(embs)
64
-
65
- # use the mask to zero out the embeddings in padded areas
66
- masked_embs = embs * mask.float()
67
-
68
- # sum and divide by the lengths to get the mean of non-padding embs
69
- mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
70
- return mean_embs
71
-
72
  # average embedding position of goal cell states
73
  def get_embs(model,
74
  filtered_input_data,
@@ -99,7 +84,8 @@ def get_embs(model,
99
 
100
  with torch.no_grad():
101
  outputs = model(
102
- input_ids = input_data_minibatch.to("cuda")
 
103
  )
104
 
105
  embs_i = outputs.hidden_states[layer_to_quant]
 
43
 
44
  from .tokenizer import TOKEN_DICTIONARY_FILE
45
 
46
+ from .in_silico_perturber import downsample_and_sort, \
47
+ gen_attention_mask, \
48
+ get_model_input_size, \
49
+ load_and_filter, \
50
  load_model, \
51
+ mean_nonpadding_embs, \
 
52
  pad_tensor_list, \
53
+ quant_layers
 
54
 
55
  logger = logging.getLogger(__name__)
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # average embedding position of goal cell states
58
  def get_embs(model,
59
  filtered_input_data,
 
84
 
85
  with torch.no_grad():
86
  outputs = model(
87
+ input_ids = input_data_minibatch.to("cuda"),
88
+ attention_mask = gen_attention_mask(minibatch)
89
  )
90
 
91
  embs_i = outputs.hidden_states[layer_to_quant]
geneformer/in_silico_perturber.py CHANGED
@@ -13,7 +13,7 @@ Usage:
13
  emb_mode="cell",
14
  cell_emb_style="mean_pool",
15
  filter_data={"cell_type":["cardiomyocyte"]},
16
- cell_states_to_model={"disease":(["dcm"],["ctrl"],["hcm"])},
17
  max_ncells=None,
18
  emb_layer=-1,
19
  forward_batch_size=100,
@@ -106,10 +106,11 @@ def downsample_and_sort(data_shuffled, max_ncells):
106
  return data_sorted
107
 
108
  def get_possible_states(cell_states_to_model):
109
- if list(cell_states_to_model.values())[3] is not None:
110
- return list(cell_states_to_model.values())[1:3] + list(cell_states_to_model.values())[3]
111
- else:
112
- return list(cell_states_to_model.values())[1:3]
 
113
 
114
  def forward_pass_single_cell(model, example_cell, layer_to_quant):
115
  example_cell.set_format(type="torch")
@@ -152,6 +153,21 @@ def overexpress_tokens(example):
152
  [example["input_ids"].insert(0, token) for token in example["tokens_to_perturb"][::-1]]
153
  return example
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  def make_perturbation_batch(example_cell,
156
  perturb_type,
157
  tokens_to_perturb,
@@ -249,7 +265,7 @@ def get_cell_state_avg_embs(model,
249
 
250
  def filter_states(example):
251
  state_key = cell_states_to_model["state_key"]
252
- return example[state_key] in possible_state
253
  filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
254
  total_batch_length = len(filtered_input_data_state)
255
  if ((total_batch_length-1)/forward_batch_size).is_integer():
@@ -262,15 +278,17 @@ def get_cell_state_avg_embs(model,
262
  state_minibatch.set_format(type="torch")
263
 
264
  input_data_minibatch = state_minibatch["input_ids"]
265
- original_lens += [tensor.numel() for tensor in input_data_minibatch]
266
  input_data_minibatch = pad_tensor_list(input_data_minibatch,
267
  max_len,
268
  pad_token_id,
269
  model_input_size)
 
270
 
271
  with torch.no_grad():
272
  outputs = model(
273
- input_ids = input_data_minibatch.to("cuda")
 
274
  )
275
 
276
  state_embs_i = outputs.hidden_states[layer_to_quant]
@@ -278,11 +296,10 @@ def get_cell_state_avg_embs(model,
278
  del outputs
279
  del state_minibatch
280
  del input_data_minibatch
 
281
  del state_embs_i
282
  torch.cuda.empty_cache()
283
 
284
- # import here to avoid circular imports
285
- from .emb_extractor import mean_nonpadding_embs
286
  state_embs = torch.cat(state_embs_list)
287
  avg_state_emb = mean_nonpadding_embs(state_embs, torch.Tensor(original_lens).to("cuda"))
288
  avg_state_emb = torch.mean(avg_state_emb, dim=0, keepdim=True)
@@ -333,6 +350,7 @@ def quant_cos_sims(model,
333
  needs_pad_or_trunc = True
334
  else:
335
  needs_pad_or_trunc = False
 
336
 
337
  if needs_pad_or_trunc == True:
338
  max_len = min(max(minibatch_length_set),model_input_size)
@@ -345,14 +363,17 @@ def quant_cos_sims(model,
345
  perturbation_minibatch.set_format(type="torch")
346
 
347
  input_data_minibatch = perturbation_minibatch["input_ids"]
 
348
 
349
  # extract embeddings for perturbation minibatch
350
  with torch.no_grad():
351
  outputs = model(
352
- input_ids = input_data_minibatch.to("cuda")
 
353
  )
354
  del input_data_minibatch
355
  del perturbation_minibatch
 
356
 
357
  if len(indices_to_perturb)>1:
358
  minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
@@ -387,43 +408,29 @@ def quant_cos_sims(model,
387
  original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
388
  original_minibatch.set_format(type="torch")
389
  original_input_data_minibatch = original_minibatch["input_ids"]
 
390
  # extract embeddings for original minibatch
391
  with torch.no_grad():
392
  original_outputs = model(
393
- input_ids = original_input_data_minibatch.to("cuda")
 
394
  )
395
  del original_input_data_minibatch
396
  del original_minibatch
 
397
 
398
  if len(indices_to_perturb)>1:
399
  original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
400
  else:
401
  original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
402
 
403
- # remove perturbed index before calculating the cos sims
404
- def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
405
- # indices_to_remove is list of indices to remove
406
- gene_dim -= 1 # removing a dim in calling the function
407
- indices_to_keep = [i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove]
408
- num_dims = emb.dim()
409
- emb_slice = [slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)]
410
- sliced_emb = emb[emb_slice]
411
- return sliced_emb
412
-
413
- # this could probably be optimized
414
  gene_dim = 1
415
-
416
- # current there's the case if a gene is not expressed and is being overexpressed,
417
- # the dimensions will be thrown off --> not removing indices to get around that issue
418
- # not sure what's the best way to handle it
419
  if perturb_type != "overexpress":
420
- original_minibatch_emb = torch.stack([
421
- remove_indices_from_emb(original_minibatch_emb[i, :, :], idx, gene_dim) for
422
- i, idx in enumerate(indices_to_perturb)
423
- ])
424
-
425
- # do the averaging here
426
-
427
 
428
  # cosine similarity between original emb and batch items
429
  if cell_states_to_model is None:
@@ -433,6 +440,7 @@ def quant_cos_sims(model,
433
  minibatch_comparison = make_comparison_batch(original_minibatch_emb,
434
  indices_to_perturb,
435
  perturb_group)
 
436
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
437
  elif cell_states_to_model is not None:
438
  for state in possible_states:
@@ -462,12 +470,17 @@ def quant_cos_sims(model,
462
  return cos_sims_vs_alt_dict
463
 
464
  # calculate cos sim shift of perturbation with respect to origin and alternative cell
465
- def cos_sim_shift(original_emb, minibatch_emb, alt_emb, perturb_group, original_minibatch_lengths = None, minibatch_lengths = None,):
 
 
 
 
 
466
  cos = torch.nn.CosineSimilarity(dim=2)
467
  if not perturb_group:
468
  original_emb = torch.mean(original_emb,dim=0,keepdim=True)
469
  original_emb = original_emb[None, :]
470
- origin_v_end = torch.squeeze(cos(original_emb, alt_emb))
471
  else:
472
  if original_emb.size() != minibatch_emb.size():
473
  logger.error(
@@ -476,26 +489,22 @@ def cos_sim_shift(original_emb, minibatch_emb, alt_emb, perturb_group, original_
476
  f"minibatch_emb is {minibatch_emb.size()}. "
477
  )
478
  raise
479
- from .emb_extractor import mean_nonpadding_embs
480
 
481
  if original_minibatch_lengths is not None:
482
  original_emb = mean_nonpadding_embs(original_emb, original_minibatch_lengths)
483
- # not sure if the else is necessary, but keeping it here in case
484
- else:
485
- original_emb = torch.mean(original_emb,dim=1,keepdim=True)
486
 
487
- alt_emb = torch.unsqueeze(alt_emb, 1)
488
- origin_v_end = cos(original_emb, alt_emb)
489
  origin_v_end = torch.squeeze(origin_v_end)
490
-
491
  if minibatch_lengths is not None:
492
  perturb_emb = mean_nonpadding_embs(minibatch_emb, minibatch_lengths)
493
  else:
494
  perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
495
 
496
- perturb_v_end = cos(perturb_emb, alt_emb)
497
  perturb_v_end = torch.squeeze(perturb_v_end)
498
-
499
  return [(perturb_v_end-origin_v_end).to("cpu")]
500
 
501
  def pad_list(input_ids, pad_token_id, max_len):
@@ -555,6 +564,30 @@ def pad_tensor_list(tensor_list, dynamic_or_constant, pad_token_id, model_input_
555
  # return stacked tensors
556
  return torch.stack(tensor_list)
557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  class InSilicoPerturber:
559
  valid_option_dict = {
560
  "perturb_type": {"delete","overexpress","inhibit","activate"},
@@ -640,9 +673,15 @@ class InSilicoPerturber:
640
  Otherwise, dictionary specifying .dataset column name and list of values to filter by.
641
  cell_states_to_model: None, dict
642
  Cell states to model if testing perturbations that achieve goal state change.
643
- Single-item dictionary with key being cell attribute (e.g. "disease").
644
- Value is tuple of three lists indicating start state, goal end state, and alternate possible end states.
645
- If no alternate possible end states, third list should be empty (i.e. the third list should be []).
 
 
 
 
 
 
646
  max_ncells : None, int
647
  Maximum number of cells to test.
648
  If None, will test all cells.
@@ -775,9 +814,14 @@ class InSilicoPerturber:
775
  if len(self.cell_states_to_model.items()) == 1:
776
  logger.warning(
777
  "The single value dictionary for cell_states_to_model will be " \
778
- "replaced with explicitly modeling start and end states. " \
779
- "Please specify state_key, start_state, end_state, and alt_states " \
780
- "in the cell_states_to_model dictionary for future use."
 
 
 
 
 
781
  )
782
  for key,value in self.cell_states_to_model.items():
783
  if (len(value) == 3) and isinstance(value, tuple):
@@ -786,7 +830,7 @@ class InSilicoPerturber:
786
  all_values = value[0]+value[1]+value[2]
787
  if len(all_values) == len(set(all_values)):
788
  continue
789
- # reformat to the new format
790
  state_values = flatten_list(list(self.cell_states_to_model.values()))
791
  self.cell_states_to_model = {
792
  "state_key": list(self.cell_states_to_model.keys())[0],
@@ -795,11 +839,13 @@ class InSilicoPerturber:
795
  "alt_states": state_values[2:][0]
796
  }
797
  elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
798
- if self.cell_states_to_model["start_state"] is None or self.cell_states_to_model["goal_state"] is None:
 
 
799
  logger.error(
800
- "Please specify 'start_state' and 'goal_state' in cell_states_to_model.")
801
  raise
802
-
803
  if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
804
  logger.error(
805
  "All states must be unique.")
@@ -818,13 +864,13 @@ class InSilicoPerturber:
818
 
819
  else:
820
  logger.error(
821
- "states_to_model must only have the following four keys: 'state_key', 'start_state', 'goal_state', 'alt_states'." \
822
- "For example, cell_states_to_model={ \
823
- 'state_key': 'disease', \
824
- 'start_state': 'dcm', \
825
- 'goal_state': 'nf'', \
826
- 'alt_states': ['hcm', 'other1', 'other2'] \
827
- }"
828
  )
829
  raise
830
 
@@ -877,12 +923,13 @@ class InSilicoPerturber:
877
  if self.cell_states_to_model is None:
878
  state_embs_dict = None
879
  else:
880
- # make sure that all states are valid; save time on filtering
881
  state_name = self.cell_states_to_model["state_key"]
 
882
  for value in get_possible_states(self.cell_states_to_model):
883
- if value not in filtered_input_data[state_name]:
884
  logger.error(
885
- f"{value} is not a valid value in {state_name}.")
886
  raise
887
  # get dictionary of average cell state embeddings for comparison
888
  downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
@@ -1019,7 +1066,7 @@ class InSilicoPerturber:
1019
  data_list = []
1020
  for data in list(cos_sims_data.values()):
1021
  data_item = data.to("cuda")
1022
- data_list += [data_item]
1023
  cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
1024
 
1025
  with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
@@ -1213,4 +1260,4 @@ class InSilicoPerturber:
1213
 
1214
  # save remainder cells
1215
  with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1216
- pickle.dump(cos_sims_dict, fp)
 
13
  emb_mode="cell",
14
  cell_emb_style="mean_pool",
15
  filter_data={"cell_type":["cardiomyocyte"]},
16
+ cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
17
  max_ncells=None,
18
  emb_layer=-1,
19
  forward_batch_size=100,
 
106
  return data_sorted
107
 
108
  def get_possible_states(cell_states_to_model):
109
+ possible_states = []
110
+ for key in ["start_state","goal_state"]:
111
+ possible_states += [cell_states_to_model[key]]
112
+ possible_states += cell_states_to_model.get("alt_states",[])
113
+ return possible_states
114
 
115
  def forward_pass_single_cell(model, example_cell, layer_to_quant):
116
  example_cell.set_format(type="torch")
 
153
  [example["input_ids"].insert(0, token) for token in example["tokens_to_perturb"][::-1]]
154
  return example
155
 
156
+ def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
157
+ # indices_to_remove is list of indices to remove
158
+ indices_to_keep = [i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove]
159
+ num_dims = emb.dim()
160
+ emb_slice = [slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)]
161
+ sliced_emb = emb[emb_slice]
162
+ return sliced_emb
163
+
164
+ def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
165
+ output_batch = torch.stack([
166
+ remove_indices_from_emb(emb_batch[i, :, :], idx, gene_dim-1) for
167
+ i, idx in enumerate(list_of_indices_to_remove)
168
+ ])
169
+ return output_batch
170
+
171
  def make_perturbation_batch(example_cell,
172
  perturb_type,
173
  tokens_to_perturb,
 
265
 
266
  def filter_states(example):
267
  state_key = cell_states_to_model["state_key"]
268
+ return example[state_key] in [possible_state]
269
  filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
270
  total_batch_length = len(filtered_input_data_state)
271
  if ((total_batch_length-1)/forward_batch_size).is_integer():
 
278
  state_minibatch.set_format(type="torch")
279
 
280
  input_data_minibatch = state_minibatch["input_ids"]
281
+ original_lens += state_minibatch["length"]
282
  input_data_minibatch = pad_tensor_list(input_data_minibatch,
283
  max_len,
284
  pad_token_id,
285
  model_input_size)
286
+ attention_mask = gen_attention_mask(state_minibatch, max_len)
287
 
288
  with torch.no_grad():
289
  outputs = model(
290
+ input_ids = input_data_minibatch.to("cuda"),
291
+ attention_mask = attention_mask
292
  )
293
 
294
  state_embs_i = outputs.hidden_states[layer_to_quant]
 
296
  del outputs
297
  del state_minibatch
298
  del input_data_minibatch
299
+ del attention_mask
300
  del state_embs_i
301
  torch.cuda.empty_cache()
302
 
 
 
303
  state_embs = torch.cat(state_embs_list)
304
  avg_state_emb = mean_nonpadding_embs(state_embs, torch.Tensor(original_lens).to("cuda"))
305
  avg_state_emb = torch.mean(avg_state_emb, dim=0, keepdim=True)
 
350
  needs_pad_or_trunc = True
351
  else:
352
  needs_pad_or_trunc = False
353
+ max_len = max(minibatch_length_set)
354
 
355
  if needs_pad_or_trunc == True:
356
  max_len = min(max(minibatch_length_set),model_input_size)
 
363
  perturbation_minibatch.set_format(type="torch")
364
 
365
  input_data_minibatch = perturbation_minibatch["input_ids"]
366
+ attention_mask = gen_attention_mask(perturbation_minibatch, max_len)
367
 
368
  # extract embeddings for perturbation minibatch
369
  with torch.no_grad():
370
  outputs = model(
371
+ input_ids = input_data_minibatch.to("cuda"),
372
+ attention_mask = attention_mask
373
  )
374
  del input_data_minibatch
375
  del perturbation_minibatch
376
+ del attention_mask
377
 
378
  if len(indices_to_perturb)>1:
379
  minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
 
408
  original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
409
  original_minibatch.set_format(type="torch")
410
  original_input_data_minibatch = original_minibatch["input_ids"]
411
+ attention_mask = gen_attention_mask(original_minibatch, original_max_len)
412
  # extract embeddings for original minibatch
413
  with torch.no_grad():
414
  original_outputs = model(
415
+ input_ids = original_input_data_minibatch.to("cuda"),
416
+ attention_mask = attention_mask
417
  )
418
  del original_input_data_minibatch
419
  del original_minibatch
420
+ del attention_mask
421
 
422
  if len(indices_to_perturb)>1:
423
  original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
424
  else:
425
  original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
426
 
427
+ # embedding dimension of the genes
 
 
 
 
 
 
 
 
 
 
428
  gene_dim = 1
429
+ # exclude overexpression due to case when genes are not expressed but being overexpressed
 
 
 
430
  if perturb_type != "overexpress":
431
+ original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
432
+ indices_to_perturb,
433
+ gene_dim)
 
 
 
 
434
 
435
  # cosine similarity between original emb and batch items
436
  if cell_states_to_model is None:
 
440
  minibatch_comparison = make_comparison_batch(original_minibatch_emb,
441
  indices_to_perturb,
442
  perturb_group)
443
+
444
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
445
  elif cell_states_to_model is not None:
446
  for state in possible_states:
 
470
  return cos_sims_vs_alt_dict
471
 
472
  # calculate cos sim shift of perturbation with respect to origin and alternative cell
473
+ def cos_sim_shift(original_emb,
474
+ minibatch_emb,
475
+ end_emb,
476
+ perturb_group,
477
+ original_minibatch_lengths = None,
478
+ minibatch_lengths = None):
479
  cos = torch.nn.CosineSimilarity(dim=2)
480
  if not perturb_group:
481
  original_emb = torch.mean(original_emb,dim=0,keepdim=True)
482
  original_emb = original_emb[None, :]
483
+ origin_v_end = torch.squeeze(cos(original_emb, end_emb)) #test
484
  else:
485
  if original_emb.size() != minibatch_emb.size():
486
  logger.error(
 
489
  f"minibatch_emb is {minibatch_emb.size()}. "
490
  )
491
  raise
 
492
 
493
  if original_minibatch_lengths is not None:
494
  original_emb = mean_nonpadding_embs(original_emb, original_minibatch_lengths)
495
+ # else:
496
+ # original_emb = torch.mean(original_emb,dim=1,keepdim=True)
 
497
 
498
+ end_emb = torch.unsqueeze(end_emb, 1)
499
+ origin_v_end = cos(original_emb, end_emb)
500
  origin_v_end = torch.squeeze(origin_v_end)
 
501
  if minibatch_lengths is not None:
502
  perturb_emb = mean_nonpadding_embs(minibatch_emb, minibatch_lengths)
503
  else:
504
  perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
505
 
506
+ perturb_v_end = cos(perturb_emb, end_emb)
507
  perturb_v_end = torch.squeeze(perturb_v_end)
 
508
  return [(perturb_v_end-origin_v_end).to("cpu")]
509
 
510
  def pad_list(input_ids, pad_token_id, max_len):
 
564
  # return stacked tensors
565
  return torch.stack(tensor_list)
566
 
567
+ def gen_attention_mask(minibatch_encoding, max_len = None):
568
+ if max_len == None:
569
+ max_len = max(minibatch_encoding["length"])
570
+ original_lens = minibatch_encoding["length"]
571
+ attention_mask = [[1]*original_len
572
+ +[0]*(max_len - original_len)
573
+ for original_len in original_lens]
574
+ return torch.tensor(attention_mask).to("cuda")
575
+
576
+ # get cell embeddings excluding padding
577
+ def mean_nonpadding_embs(embs, original_lens):
578
+ # mask based on padding lengths
579
+ mask = torch.arange(embs.size(1)).unsqueeze(0).to("cuda") < original_lens.unsqueeze(1)
580
+
581
+ # extend mask dimensions to match the embeddings tensor
582
+ mask = mask.unsqueeze(2).expand_as(embs)
583
+
584
+ # use the mask to zero out the embeddings in padded areas
585
+ masked_embs = embs * mask.float()
586
+
587
+ # sum and divide by the lengths to get the mean of non-padding embs
588
+ mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
589
+ return mean_embs
590
+
591
  class InSilicoPerturber:
592
  valid_option_dict = {
593
  "perturb_type": {"delete","overexpress","inhibit","activate"},
 
673
  Otherwise, dictionary specifying .dataset column name and list of values to filter by.
674
  cell_states_to_model: None, dict
675
  Cell states to model if testing perturbations that achieve goal state change.
676
+ Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
677
+ state_key: key specifying name of column in .dataset that defines the start/goal states
678
+ start_state: value in the state_key column that specifies the start state
679
+ goal_state: value in the state_key column taht specifies the goal end state
680
+ alt_states: list of values in the state_key column that specify the alternate end states
681
+ For example: {"state_key": "disease",
682
+ "start_state": "dcm",
683
+ "goal_state": "nf",
684
+ "alt_states": ["hcm", "other1", "other2"]}
685
  max_ncells : None, int
686
  Maximum number of cells to test.
687
  If None, will test all cells.
 
814
  if len(self.cell_states_to_model.items()) == 1:
815
  logger.warning(
816
  "The single value dictionary for cell_states_to_model will be " \
817
+ "replaced with a dictionary with named keys for start, goal, and alternate states. " \
818
+ "Please specify state_key, start_state, goal_state, and alt_states " \
819
+ "in the cell_states_to_model dictionary for future use. " \
820
+ "For example, cell_states_to_model={" \
821
+ "'state_key': 'disease', " \
822
+ "'start_state': 'dcm', " \
823
+ "'goal_state': 'nf', " \
824
+ "'alt_states': ['hcm', 'other1', 'other2']}"
825
  )
826
  for key,value in self.cell_states_to_model.items():
827
  if (len(value) == 3) and isinstance(value, tuple):
 
830
  all_values = value[0]+value[1]+value[2]
831
  if len(all_values) == len(set(all_values)):
832
  continue
833
+ # reformat to the new named key format
834
  state_values = flatten_list(list(self.cell_states_to_model.values()))
835
  self.cell_states_to_model = {
836
  "state_key": list(self.cell_states_to_model.keys())[0],
 
839
  "alt_states": state_values[2:][0]
840
  }
841
  elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
842
+ if (self.cell_states_to_model["state_key"] is None) \
843
+ or (self.cell_states_to_model["start_state"] is None) \
844
+ or (self.cell_states_to_model["goal_state"] is None):
845
  logger.error(
846
+ "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
847
  raise
848
+
849
  if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
850
  logger.error(
851
  "All states must be unique.")
 
864
 
865
  else:
866
  logger.error(
867
+ "cell_states_to_model must only have the following four keys: " \
868
+ "'state_key', 'start_state', 'goal_state', 'alt_states'." \
869
+ "For example, cell_states_to_model={" \
870
+ "'state_key': 'disease', " \
871
+ "'start_state': 'dcm', " \
872
+ "'goal_state': 'nf', " \
873
+ "'alt_states': ['hcm', 'other1', 'other2']}"
874
  )
875
  raise
876
 
 
923
  if self.cell_states_to_model is None:
924
  state_embs_dict = None
925
  else:
926
+ # confirm that all states are valid to prevent futile filtering
927
  state_name = self.cell_states_to_model["state_key"]
928
+ state_values = filtered_input_data[state_name]
929
  for value in get_possible_states(self.cell_states_to_model):
930
+ if value not in state_values:
931
  logger.error(
932
+ f"{value} is not present in the dataset's {state_name} attribute.")
933
  raise
934
  # get dictionary of average cell state embeddings for comparison
935
  downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
 
1066
  data_list = []
1067
  for data in list(cos_sims_data.values()):
1068
  data_item = data.to("cuda")
1069
+ data_list += [data_item[j].item()]
1070
  cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
1071
 
1072
  with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
 
1260
 
1261
  # save remainder cells
1262
  with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1263
+ pickle.dump(cos_sims_dict, fp)
geneformer/in_silico_perturber_stats.py CHANGED
@@ -6,7 +6,10 @@ Usage:
6
  ispstats = InSilicoPerturberStats(mode="goal_state_shift",
7
  combos=0,
8
  anchor_gene=None,
9
- cell_states_to_model={"disease":(["dcm"],["ctrl"],["hcm"])})
 
 
 
10
  ispstats.get_stats("path/to/input_data",
11
  None,
12
  "path/to/output_directory",
@@ -26,6 +29,8 @@ from scipy.stats import ranksums
26
  from sklearn.mixture import GaussianMixture
27
  from tqdm.notebook import trange, tqdm
28
 
 
 
29
  from .tokenizer import TOKEN_DICTIONARY_FILE
30
 
31
  GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
@@ -123,10 +128,10 @@ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
123
 
124
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
125
  def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_perturbed):
126
- cell_state_key = list(cell_states_to_model.keys())[0]
127
- if cell_states_to_model[cell_state_key][2] == []:
128
  alt_end_state_exists = False
129
- elif (len(cell_states_to_model[cell_state_key][2]) > 0) and (cell_states_to_model[cell_state_key][2] != [None]):
130
  alt_end_state_exists = True
131
 
132
  # for single perturbation in multiple cells, there are no random perturbations to compare to
@@ -231,10 +236,12 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_
231
  # quantify number of detections of each gene
232
  cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
233
 
234
- # sort by shift to desired state
235
- cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_to_goal_end",
 
 
236
  "Goal_end_FDR"],
237
- ascending=[False,True])
238
 
239
  return cos_sims_full_df
240
 
@@ -272,9 +279,11 @@ def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
272
 
273
  cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(cos_sims_full_df["Test_vs_null_pval"])
274
 
275
- cos_sims_full_df = cos_sims_full_df.sort_values(by=["Test_vs_null_avg_shift",
 
 
276
  "Test_vs_null_FDR"],
277
- ascending=[False,True])
278
  return cos_sims_full_df
279
 
280
  # stats for identifying perturbations with largest effect within a given set of cells
@@ -441,9 +450,15 @@ class InSilicoPerturberStats:
441
  analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
442
  cell_states_to_model: None, dict
443
  Cell states to model if testing perturbations that achieve goal state change.
444
- Single-item dictionary with key being cell attribute (e.g. "disease").
445
- Value is tuple of three lists indicating start state, goal end state, and alternate possible end states.
446
- If no alternate possible end states, third list should be empty (i.e. the third list should be []).
 
 
 
 
 
 
447
  token_dictionary_file : Path
448
  Path to pickle file containing token dictionary (Ensembl ID:token).
449
  gene_name_id_dictionary_file : Path
@@ -494,6 +509,17 @@ class InSilicoPerturberStats:
494
 
495
  if self.cell_states_to_model is not None:
496
  if len(self.cell_states_to_model.items()) == 1:
 
 
 
 
 
 
 
 
 
 
 
497
  for key,value in self.cell_states_to_model.items():
498
  if (len(value) == 3) and isinstance(value, tuple):
499
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
@@ -501,14 +527,50 @@ class InSilicoPerturberStats:
501
  all_values = value[0]+value[1]+value[2]
502
  if len(all_values) == len(set(all_values)):
503
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  else:
505
  logger.error(
506
- "Cell states to model must be a single-item dictionary with " \
507
- "key being cell attribute (e.g. 'disease') and value being " \
508
- "tuple of three lists indicating start state, goal end state, and alternate possible end states. " \
509
- "Values should all be unique. " \
510
- "For example: {'disease':(['start_state'],['ctrl'],['alt_end'])}")
 
 
 
511
  raise
 
512
  if self.anchor_gene is not None:
513
  self.anchor_gene = None
514
  logger.warning(
@@ -565,6 +627,7 @@ class InSilicoPerturberStats:
565
  "Gene_name": gene name
566
  "Ensembl_ID": gene Ensembl ID
567
  "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
 
568
 
569
  "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
570
  "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
 
6
  ispstats = InSilicoPerturberStats(mode="goal_state_shift",
7
  combos=0,
8
  anchor_gene=None,
9
+ cell_states_to_model={"state_key": "disease",
10
+ "start_state": "dcm",
11
+ "goal_state": "nf",
12
+ "alt_states": ["hcm", "other1", "other2"]})
13
  ispstats.get_stats("path/to/input_data",
14
  None,
15
  "path/to/output_directory",
 
29
  from sklearn.mixture import GaussianMixture
30
  from tqdm.notebook import trange, tqdm
31
 
32
+ from .in_silico_perturber import flatten_list
33
+
34
  from .tokenizer import TOKEN_DICTIONARY_FILE
35
 
36
  GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
 
128
 
129
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
130
  def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_perturbed):
131
+ cell_state_key = cell_states_to_model["start_state"]
132
+ if "alt_states" not in cell_states_to_model.keys():
133
  alt_end_state_exists = False
134
+ elif (len(cell_states_to_model["alt_states"]) > 0) and (cell_states_to_model["alt_states"] != [None]):
135
  alt_end_state_exists = True
136
 
137
  # for single perturbation in multiple cells, there are no random perturbations to compare to
 
236
  # quantify number of detections of each gene
237
  cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
238
 
239
+ # sort by shift to desired state\
240
+ cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]]
241
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
242
+ "Shift_to_goal_end",
243
  "Goal_end_FDR"],
244
+ ascending=[False,False,True])
245
 
246
  return cos_sims_full_df
247
 
 
279
 
280
  cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(cos_sims_full_df["Test_vs_null_pval"])
281
 
282
+ cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]]
283
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
284
+ "Test_vs_null_avg_shift",
285
  "Test_vs_null_FDR"],
286
+ ascending=[False,False,True])
287
  return cos_sims_full_df
288
 
289
  # stats for identifying perturbations with largest effect within a given set of cells
 
450
  analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
451
  cell_states_to_model: None, dict
452
  Cell states to model if testing perturbations that achieve goal state change.
453
+ Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
454
+ state_key: key specifying name of column in .dataset that defines the start/goal states
455
+ start_state: value in the state_key column that specifies the start state
456
+ goal_state: value in the state_key column taht specifies the goal end state
457
+ alt_states: list of values in the state_key column that specify the alternate end states
458
+ For example: {"state_key": "disease",
459
+ "start_state": "dcm",
460
+ "goal_state": "nf",
461
+ "alt_states": ["hcm", "other1", "other2"]}
462
  token_dictionary_file : Path
463
  Path to pickle file containing token dictionary (Ensembl ID:token).
464
  gene_name_id_dictionary_file : Path
 
509
 
510
  if self.cell_states_to_model is not None:
511
  if len(self.cell_states_to_model.items()) == 1:
512
+ logger.warning(
513
+ "The single value dictionary for cell_states_to_model will be " \
514
+ "replaced with a dictionary with named keys for start, goal, and alternate states. " \
515
+ "Please specify state_key, start_state, goal_state, and alt_states " \
516
+ "in the cell_states_to_model dictionary for future use. " \
517
+ "For example, cell_states_to_model={" \
518
+ "'state_key': 'disease', " \
519
+ "'start_state': 'dcm', " \
520
+ "'goal_state': 'nf', " \
521
+ "'alt_states': ['hcm', 'other1', 'other2']}"
522
+ )
523
  for key,value in self.cell_states_to_model.items():
524
  if (len(value) == 3) and isinstance(value, tuple):
525
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
 
527
  all_values = value[0]+value[1]+value[2]
528
  if len(all_values) == len(set(all_values)):
529
  continue
530
+ # reformat to the new named key format
531
+ state_values = flatten_list(list(self.cell_states_to_model.values()))
532
+ self.cell_states_to_model = {
533
+ "state_key": list(self.cell_states_to_model.keys())[0],
534
+ "start_state": state_values[0][0],
535
+ "goal_state": state_values[1][0],
536
+ "alt_states": state_values[2:][0]
537
+ }
538
+ elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
539
+ if (self.cell_states_to_model["state_key"] is None) \
540
+ or (self.cell_states_to_model["start_state"] is None) \
541
+ or (self.cell_states_to_model["goal_state"] is None):
542
+ logger.error(
543
+ "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
544
+ raise
545
+
546
+ if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
547
+ logger.error(
548
+ "All states must be unique.")
549
+ raise
550
+
551
+ if self.cell_states_to_model["alt_states"] is not None:
552
+ if type(self.cell_states_to_model["alt_states"]) is not list:
553
+ logger.error(
554
+ "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
555
+ )
556
+ raise
557
+ if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
558
+ logger.error(
559
+ "All states must be unique.")
560
+ raise
561
+
562
  else:
563
  logger.error(
564
+ "cell_states_to_model must only have the following four keys: " \
565
+ "'state_key', 'start_state', 'goal_state', 'alt_states'." \
566
+ "For example, cell_states_to_model={" \
567
+ "'state_key': 'disease', " \
568
+ "'start_state': 'dcm', " \
569
+ "'goal_state': 'nf', " \
570
+ "'alt_states': ['hcm', 'other1', 'other2']}"
571
+ )
572
  raise
573
+
574
  if self.anchor_gene is not None:
575
  self.anchor_gene = None
576
  logger.warning(
 
627
  "Gene_name": gene name
628
  "Ensembl_ID": gene Ensembl ID
629
  "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
630
+ "Sig": 1 if FDR<0.05, otherwise 0
631
 
632
  "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
633
  "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation