Update geneformer/emb_extractor.py
Browse files: add custom token dictionary; exclude CLS and EOS tokens from the cell mean if present in the token dictionary; add option to extract only the CLS embedding
- geneformer/emb_extractor.py +49 -12
geneformer/emb_extractor.py
CHANGED
@@ -38,12 +38,13 @@ def get_embs(
|
|
38 |
layer_to_quant,
|
39 |
pad_token_id,
|
40 |
forward_batch_size,
|
|
|
41 |
summary_stat=None,
|
42 |
silent=False,
|
43 |
):
|
44 |
model_input_size = pu.get_model_input_size(model)
|
45 |
total_batch_length = len(filtered_input_data)
|
46 |
-
|
47 |
if summary_stat is None:
|
48 |
embs_list = []
|
49 |
elif summary_stat is not None:
|
@@ -67,9 +68,21 @@ def get_embs(
|
|
67 |
k: [TDigest() for _ in range(emb_dims)] for k in gene_set
|
68 |
}
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
overall_max_len = 0
|
71 |
-
|
72 |
-
for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
|
73 |
max_range = min(i + forward_batch_size, total_batch_length)
|
74 |
|
75 |
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
@@ -90,9 +103,16 @@ def get_embs(
|
|
90 |
)
|
91 |
|
92 |
embs_i = outputs.hidden_states[layer_to_quant]
|
93 |
-
|
94 |
if emb_mode == "cell":
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
if summary_stat is None:
|
97 |
embs_list.append(mean_embs)
|
98 |
elif summary_stat is not None:
|
@@ -121,7 +141,13 @@ def get_embs(
|
|
121 |
accumulate_tdigests(
|
122 |
embs_tdigests_dict[int(k)], dict_h[k], emb_dims
|
123 |
)
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
overall_max_len = max(overall_max_len, max_len)
|
126 |
del outputs
|
127 |
del minibatch
|
@@ -129,7 +155,8 @@ def get_embs(
|
|
129 |
del embs_i
|
130 |
|
131 |
torch.cuda.empty_cache()
|
132 |
-
|
|
|
133 |
if summary_stat is None:
|
134 |
if emb_mode == "cell":
|
135 |
embs_stack = torch.cat(embs_list, dim=0)
|
@@ -142,6 +169,8 @@ def get_embs(
|
|
142 |
1,
|
143 |
pu.pad_3d_tensor,
|
144 |
)
|
|
|
|
|
145 |
|
146 |
# calculate summary stat embs from approximated tdigests
|
147 |
elif summary_stat is not None:
|
@@ -348,7 +377,7 @@ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
|
348 |
bbox_to_anchor=(0.5, 1),
|
349 |
facecolor="white",
|
350 |
)
|
351 |
-
|
352 |
plt.savefig(output_file, bbox_inches="tight")
|
353 |
|
354 |
|
@@ -356,7 +385,7 @@ class EmbExtractor:
|
|
356 |
valid_option_dict = {
|
357 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
358 |
"num_classes": {int},
|
359 |
-
"emb_mode": {"cell", "gene"},
|
360 |
"cell_emb_style": {"mean_pool"},
|
361 |
"gene_emb_style": {"mean_pool"},
|
362 |
"filter_data": {None, dict},
|
@@ -365,6 +394,7 @@ class EmbExtractor:
|
|
365 |
"emb_label": {None, list},
|
366 |
"labels_to_plot": {None, list},
|
367 |
"forward_batch_size": {int},
|
|
|
368 |
"nproc": {int},
|
369 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
370 |
}
|
@@ -384,7 +414,7 @@ class EmbExtractor:
|
|
384 |
forward_batch_size=100,
|
385 |
nproc=4,
|
386 |
summary_stat=None,
|
387 |
-
token_dictionary_file=
|
388 |
):
|
389 |
"""
|
390 |
Initialize embedding extractor.
|
@@ -434,6 +464,7 @@ class EmbExtractor:
|
|
434 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
435 |
| Non-exact is slower but more memory-efficient.
|
436 |
token_dictionary_file : Path
|
|
|
437 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
438 |
|
439 |
**Examples:**
|
@@ -463,6 +494,7 @@ class EmbExtractor:
|
|
463 |
self.emb_layer = emb_layer
|
464 |
self.emb_label = emb_label
|
465 |
self.labels_to_plot = labels_to_plot
|
|
|
466 |
self.forward_batch_size = forward_batch_size
|
467 |
self.nproc = nproc
|
468 |
if (summary_stat is not None) and ("exact" in summary_stat):
|
@@ -475,6 +507,8 @@ class EmbExtractor:
|
|
475 |
self.validate_options()
|
476 |
|
477 |
# load token dictionary (Ensembl IDs:token)
|
|
|
|
|
478 |
with open(token_dictionary_file, "rb") as f:
|
479 |
self.gene_token_dict = pickle.load(f)
|
480 |
|
@@ -490,7 +524,7 @@ class EmbExtractor:
|
|
490 |
continue
|
491 |
valid_type = False
|
492 |
for option in valid_options:
|
493 |
-
if (option in [int, list, dict, bool]) and isinstance(
|
494 |
attr_value, option
|
495 |
):
|
496 |
valid_type = True
|
@@ -570,6 +604,7 @@ class EmbExtractor:
|
|
570 |
layer_to_quant,
|
571 |
self.pad_token_id,
|
572 |
self.forward_batch_size,
|
|
|
573 |
self.summary_stat,
|
574 |
)
|
575 |
|
@@ -584,6 +619,8 @@ class EmbExtractor:
|
|
584 |
elif self.summary_stat is not None:
|
585 |
embs_df = pd.DataFrame(embs).T
|
586 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
|
|
|
|
587 |
|
588 |
# save embeddings to output_path
|
589 |
if cell_state is None:
|
@@ -781,7 +818,7 @@ class EmbExtractor:
|
|
781 |
f"not present in provided embeddings dataframe."
|
782 |
)
|
783 |
continue
|
784 |
-
output_prefix_label =
|
785 |
output_file = (
|
786 |
Path(output_directory) / output_prefix_label
|
787 |
).with_suffix(".pdf")
|
|
|
38 |
layer_to_quant,
|
39 |
pad_token_id,
|
40 |
forward_batch_size,
|
41 |
+
token_gene_dict,
|
42 |
summary_stat=None,
|
43 |
silent=False,
|
44 |
):
|
45 |
model_input_size = pu.get_model_input_size(model)
|
46 |
total_batch_length = len(filtered_input_data)
|
47 |
+
|
48 |
if summary_stat is None:
|
49 |
embs_list = []
|
50 |
elif summary_stat is not None:
|
|
|
68 |
k: [TDigest() for _ in range(emb_dims)] for k in gene_set
|
69 |
}
|
70 |
|
71 |
+
# Check whether CLS and SEP tokens are present in the token dictionary
|
72 |
+
lowercase_token_gene_dict = {k: v.lower() for k, v in token_gene_dict.items()}
|
73 |
+
cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
|
74 |
+
sep_present = any("sep" in value for value in lowercase_token_gene_dict.values())
|
75 |
+
if emb_mode == "cls":
|
76 |
+
assert cls_present, "CLS token missing in token dictionary"
|
77 |
+
else:
|
78 |
+
if cls_present:
|
79 |
+
logger.warning("CLS token present in token dictionary, excluding from average")
|
80 |
+
if sep_present:
|
81 |
+
logger.warning("SEP token present in token dictionary, excluding from average")
|
82 |
+
|
83 |
overall_max_len = 0
|
84 |
+
|
85 |
+
for i in trange(0, total_batch_length, forward_batch_size, leave = (not silent)):
|
86 |
max_range = min(i + forward_batch_size, total_batch_length)
|
87 |
|
88 |
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
|
|
103 |
)
|
104 |
|
105 |
embs_i = outputs.hidden_states[layer_to_quant]
|
106 |
+
|
107 |
if emb_mode == "cell":
|
108 |
+
if cls_present:
|
109 |
+
non_cls_embs = embs_i[:, 1:, :] # all token embeddings except the first position (CLS)
|
110 |
+
if sep_present:
|
111 |
+
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
|
112 |
+
else:
|
113 |
+
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, origina_lens - 1)
|
114 |
+
else:
|
115 |
+
mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
|
116 |
if summary_stat is None:
|
117 |
embs_list.append(mean_embs)
|
118 |
elif summary_stat is not None:
|
|
|
141 |
accumulate_tdigests(
|
142 |
embs_tdigests_dict[int(k)], dict_h[k], emb_dims
|
143 |
)
|
144 |
+
del embs_h
|
145 |
+
del dict_h
|
146 |
+
elif emb_mode == "cls":
|
147 |
+
cls_embs = embs_i[:,0,:].cpu() # embedding at the CLS token position (index 0)
|
148 |
+
embs_list.append(cls_embs)
|
149 |
+
del cls_embs
|
150 |
+
|
151 |
overall_max_len = max(overall_max_len, max_len)
|
152 |
del outputs
|
153 |
del minibatch
|
|
|
155 |
del embs_i
|
156 |
|
157 |
torch.cuda.empty_cache()
|
158 |
+
|
159 |
+
|
160 |
if summary_stat is None:
|
161 |
if emb_mode == "cell":
|
162 |
embs_stack = torch.cat(embs_list, dim=0)
|
|
|
169 |
1,
|
170 |
pu.pad_3d_tensor,
|
171 |
)
|
172 |
+
elif emb_mode == "cls":
|
173 |
+
embs_stack = torch.cat(embs_list, dim=0)
|
174 |
|
175 |
# calculate summary stat embs from approximated tdigests
|
176 |
elif summary_stat is not None:
|
|
|
377 |
bbox_to_anchor=(0.5, 1),
|
378 |
facecolor="white",
|
379 |
)
|
380 |
+
print(f"Output file: {output_file}")
|
381 |
plt.savefig(output_file, bbox_inches="tight")
|
382 |
|
383 |
|
|
|
385 |
valid_option_dict = {
|
386 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
387 |
"num_classes": {int},
|
388 |
+
"emb_mode": {"cell", "gene", "cls"},
|
389 |
"cell_emb_style": {"mean_pool"},
|
390 |
"gene_emb_style": {"mean_pool"},
|
391 |
"filter_data": {None, dict},
|
|
|
394 |
"emb_label": {None, list},
|
395 |
"labels_to_plot": {None, list},
|
396 |
"forward_batch_size": {int},
|
397 |
+
"token_dictionary_file" : {None, str},
|
398 |
"nproc": {int},
|
399 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
400 |
}
|
|
|
414 |
forward_batch_size=100,
|
415 |
nproc=4,
|
416 |
summary_stat=None,
|
417 |
+
token_dictionary_file=None,
|
418 |
):
|
419 |
"""
|
420 |
Initialize embedding extractor.
|
|
|
464 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
465 |
| Non-exact is slower but more memory-efficient.
|
466 |
token_dictionary_file : Path
|
467 |
+
| Defaults to the Geneformer token dictionary when None.
|
468 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
469 |
|
470 |
**Examples:**
|
|
|
494 |
self.emb_layer = emb_layer
|
495 |
self.emb_label = emb_label
|
496 |
self.labels_to_plot = labels_to_plot
|
497 |
+
self.token_dictionary_file = token_dictionary_file
|
498 |
self.forward_batch_size = forward_batch_size
|
499 |
self.nproc = nproc
|
500 |
if (summary_stat is not None) and ("exact" in summary_stat):
|
|
|
507 |
self.validate_options()
|
508 |
|
509 |
# load token dictionary (Ensembl IDs:token)
|
510 |
+
if self.token_dictionary_file is None:
|
511 |
+
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
512 |
with open(token_dictionary_file, "rb") as f:
|
513 |
self.gene_token_dict = pickle.load(f)
|
514 |
|
|
|
524 |
continue
|
525 |
valid_type = False
|
526 |
for option in valid_options:
|
527 |
+
if (option in [int, list, dict, bool, str]) and isinstance(
|
528 |
attr_value, option
|
529 |
):
|
530 |
valid_type = True
|
|
|
604 |
layer_to_quant,
|
605 |
self.pad_token_id,
|
606 |
self.forward_batch_size,
|
607 |
+
self.token_gene_dict,
|
608 |
self.summary_stat,
|
609 |
)
|
610 |
|
|
|
619 |
elif self.summary_stat is not None:
|
620 |
embs_df = pd.DataFrame(embs).T
|
621 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
622 |
+
elif self.emb_mode == "cls":
|
623 |
+
embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
|
624 |
|
625 |
# save embeddings to output_path
|
626 |
if cell_state is None:
|
|
|
818 |
f"not present in provided embeddings dataframe."
|
819 |
)
|
820 |
continue
|
821 |
+
output_prefix_label = output_prefix + f"_umap_{label}"
|
822 |
output_file = (
|
823 |
Path(output_directory) / output_prefix_label
|
824 |
).with_suffix(".pdf")
|