hchen725 committed
Commit c6c8c88 · verified · 1 Parent(s): 471eefc

Update geneformer/emb_extractor.py


add custom token dictionary, exclude CLS and EOS from cell mean if present in token dictionary, add extracting just CLS

Files changed (1):
  geneformer/emb_extractor.py (+49, -12)
geneformer/emb_extractor.py CHANGED
@@ -38,12 +38,13 @@ def get_embs(
     layer_to_quant,
     pad_token_id,
     forward_batch_size,
+    token_gene_dict,
     summary_stat=None,
     silent=False,
 ):
     model_input_size = pu.get_model_input_size(model)
     total_batch_length = len(filtered_input_data)
-
+
     if summary_stat is None:
         embs_list = []
     elif summary_stat is not None:
@@ -67,9 +68,21 @@ def get_embs(
                 k: [TDigest() for _ in range(emb_dims)] for k in gene_set
             }
 
+    # Check if CLS and SEP tokens are present in the token dictionary
+    lowercase_token_gene_dict = {k: v.lower() for k, v in token_gene_dict.items()}
+    cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
+    sep_present = any("sep" in value for value in lowercase_token_gene_dict.values())
+    if emb_mode == "cls":
+        assert cls_present, "CLS token missing in token dictionary"
+    else:
+        if cls_present:
+            logger.warning("CLS token present in token dictionary, excluding from average")
+        if sep_present:
+            logger.warning("SEP token present in token dictionary, excluding from average")
+
     overall_max_len = 0
-
-    for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
+
+    for i in trange(0, total_batch_length, forward_batch_size, leave = (not silent)):
         max_range = min(i + forward_batch_size, total_batch_length)
 
         minibatch = filtered_input_data.select([i for i in range(i, max_range)])
@@ -90,9 +103,16 @@ def get_embs(
             )
 
         embs_i = outputs.hidden_states[layer_to_quant]
-
+
         if emb_mode == "cell":
-            mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
+            if cls_present:
+                non_cls_embs = embs_i[:, 1:, :]  # Exclude the CLS token embedding
+                if sep_present:
+                    mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
+                else:
+                    mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
+            else:
+                mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
             if summary_stat is None:
                 embs_list.append(mean_embs)
             elif summary_stat is not None:
@@ -121,7 +141,13 @@ def get_embs(
                         accumulate_tdigests(
                             embs_tdigests_dict[int(k)], dict_h[k], emb_dims
                         )
-
+                del embs_h
+                del dict_h
+        elif emb_mode == "cls":
+            cls_embs = embs_i[:, 0, :].cpu()  # CLS token embedding
+            embs_list.append(cls_embs)
+            del cls_embs
+
         overall_max_len = max(overall_max_len, max_len)
         del outputs
         del minibatch
@@ -129,7 +155,8 @@ def get_embs(
         del embs_i
 
         torch.cuda.empty_cache()
-
+
+
     if summary_stat is None:
         if emb_mode == "cell":
             embs_stack = torch.cat(embs_list, dim=0)
@@ -142,6 +169,8 @@ def get_embs(
                 1,
                 pu.pad_3d_tensor,
             )
+        elif emb_mode == "cls":
+            embs_stack = torch.cat(embs_list, dim=0)
 
     # calculate summary stat embs from approximated tdigests
     elif summary_stat is not None:
@@ -348,7 +377,7 @@ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
         bbox_to_anchor=(0.5, 1),
         facecolor="white",
     )
-
+    print(f"Output file: {output_file}")
     plt.savefig(output_file, bbox_inches="tight")
 
 
@@ -356,7 +385,7 @@ class EmbExtractor:
     valid_option_dict = {
         "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
         "num_classes": {int},
-        "emb_mode": {"cell", "gene"},
+        "emb_mode": {"cell", "gene", "cls"},
         "cell_emb_style": {"mean_pool"},
         "gene_emb_style": {"mean_pool"},
         "filter_data": {None, dict},
@@ -365,6 +394,7 @@ class EmbExtractor:
         "emb_label": {None, list},
         "labels_to_plot": {None, list},
         "forward_batch_size": {int},
+        "token_dictionary_file": {None, str},
         "nproc": {int},
         "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
     }
@@ -384,7 +414,7 @@ class EmbExtractor:
         forward_batch_size=100,
         nproc=4,
         summary_stat=None,
-        token_dictionary_file=TOKEN_DICTIONARY_FILE,
+        token_dictionary_file=None,
     ):
         """
         Initialize embedding extractor.
@@ -434,6 +464,7 @@ class EmbExtractor:
             | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
             | Non-exact is slower but more memory-efficient.
         token_dictionary_file : Path
+            | Defaults to the Geneformer token dictionary
            | Path to pickle file containing token dictionary (Ensembl ID:token).
 
         **Examples:**
@@ -463,6 +494,7 @@ class EmbExtractor:
         self.emb_layer = emb_layer
         self.emb_label = emb_label
         self.labels_to_plot = labels_to_plot
+        self.token_dictionary_file = token_dictionary_file
         self.forward_batch_size = forward_batch_size
         self.nproc = nproc
         if (summary_stat is not None) and ("exact" in summary_stat):
@@ -475,6 +507,8 @@ class EmbExtractor:
         self.validate_options()
 
         # load token dictionary (Ensembl IDs:token)
+        if self.token_dictionary_file is None:
+            token_dictionary_file = TOKEN_DICTIONARY_FILE
         with open(token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)
 
@@ -490,7 +524,7 @@ class EmbExtractor:
                 continue
             valid_type = False
             for option in valid_options:
-                if (option in [int, list, dict, bool]) and isinstance(
+                if (option in [int, list, dict, bool, str]) and isinstance(
                     attr_value, option
                 ):
                     valid_type = True
@@ -570,6 +604,7 @@ class EmbExtractor:
             layer_to_quant,
             self.pad_token_id,
             self.forward_batch_size,
+            self.token_gene_dict,
             self.summary_stat,
         )
 
@@ -584,6 +619,8 @@ class EmbExtractor:
         elif self.summary_stat is not None:
             embs_df = pd.DataFrame(embs).T
             embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
+        elif self.emb_mode == "cls":
+            embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
 
         # save embeddings to output_path
         if cell_state is None:
@@ -781,7 +818,7 @@ class EmbExtractor:
                     f"not present in provided embeddings dataframe."
                 )
                 continue
-            output_prefix_label = output_prefix + f"_umap_{label}"
+            output_prefix_label = output_prefix + f"_umap_{label}"
             output_file = (
                 Path(output_directory) / output_prefix_label
             ).with_suffix(".pdf")
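
Note on the cell-mean change in get_embs above: the sketch below is a minimal, self-contained illustration (plain torch, not the library code) of why the lengths passed to pu.mean_nonpadding_embs drop by 1 or 2 when CLS and SEP tokens are present. The helper name cell_mean_excluding_special is hypothetical; as in the committed logic, the trailing SEP is only subtracted when a CLS token is also present.

import torch

def cell_mean_excluding_special(embs, original_lens, cls_present, sep_present):
    # embs: (batch, seq_len, emb_dim) hidden states, right-padded
    # original_lens: (batch,) unpadded lengths, counting CLS/SEP if present
    if cls_present:
        embs = embs[:, 1:, :]  # drop the CLS position at the front
        original_lens = original_lens - (2 if sep_present else 1)  # also drop trailing SEP if present
    # mask the remaining gene positions; everything past original_lens is SEP or padding
    mask = torch.arange(embs.size(1))[None, :] < original_lens[:, None]
    mask = mask.unsqueeze(-1).to(embs.dtype)
    return (embs * mask).sum(dim=1) / original_lens[:, None].to(embs.dtype)

# e.g. two cells of lengths 5 and 3 (including CLS and SEP), hidden size 4
embs = torch.randn(2, 6, 4)
lens = torch.tensor([5, 3])
print(cell_mean_excluding_special(embs, lens, cls_present=True, sep_present=True).shape)  # torch.Size([2, 4])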
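
Usage note: the options added in this commit can be exercised roughly as sketched below. The EmbExtractor constructor arguments are the ones visible in this diff; the extract_embs() call, its argument names and order, and the file paths are assumptions for illustration and are not part of this change.

from geneformer import EmbExtractor

# extract CLS-token cell embeddings with a custom token dictionary
embex = EmbExtractor(
    model_type="Pretrained",
    num_classes=0,
    emb_mode="cls",  # new mode added in this commit
    emb_layer=-1,
    forward_batch_size=100,
    nproc=4,
    token_dictionary_file="/path/to/custom_token_dictionary.pkl",  # new option; None falls back to TOKEN_DICTIONARY_FILE
)

# assumed interface: model directory, tokenized .dataset path, output directory, output prefix
embs_df = embex.extract_embs(
    "/path/to/model",
    "/path/to/tokenized_data.dataset",
    "/path/to/output_dir",
    "output_prefix",
)

With emb_mode="cell" the constructor is used as before; the behavioral difference from this commit is that CLS and SEP positions are excluded from the per-cell mean whenever they are present in the token dictionary.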