hchen725 commited on
Commit
2c8d3f5
·
verified ·
1 Parent(s): 664f71e

Update geneformer/emb_extractor.py

Browse files

Check to make sure that all the emb_labels exist in the tokenized data before extracting embedding

Files changed (1) hide show
  1. geneformer/emb_extractor.py +7 -7
geneformer/emb_extractor.py CHANGED
@@ -411,7 +411,7 @@ class EmbExtractor:
411
  self,
412
  model_type="Pretrained",
413
  num_classes=0,
414
- emb_mode="cls",
415
  cell_emb_style="mean_pool",
416
  gene_emb_style="mean_pool",
417
  filter_data=None,
@@ -596,6 +596,12 @@ class EmbExtractor:
596
  filtered_input_data = pu.load_and_filter(
597
  self.filter_data, self.nproc, input_data_file
598
  )
 
 
 
 
 
 
599
  if cell_state is not None:
600
  filtered_input_data = pu.filter_by_dict(
601
  filtered_input_data, cell_state, self.nproc
@@ -719,12 +725,6 @@ class EmbExtractor:
719
  )
720
  raise
721
 
722
- if self.emb_label is not None:
723
- logger.error(
724
- "For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
725
- )
726
- raise
727
-
728
  state_embs_dict = dict()
729
  state_key = cell_states_to_model["state_key"]
730
  for k, v in cell_states_to_model.items():
 
411
  self,
412
  model_type="Pretrained",
413
  num_classes=0,
414
+ emb_mode="cell",
415
  cell_emb_style="mean_pool",
416
  gene_emb_style="mean_pool",
417
  filter_data=None,
 
596
  filtered_input_data = pu.load_and_filter(
597
  self.filter_data, self.nproc, input_data_file
598
  )
599
+
600
+ # Check to make sure that all the labels exist in the tokenized data:
601
+ if self.emb_label is not None:
602
+ for label in self.emb_label:
603
+ assert label in list(filtered_input_data.features), f"Attribute `{label}` not present in dataset features"
604
+
605
  if cell_state is not None:
606
  filtered_input_data = pu.filter_by_dict(
607
  filtered_input_data, cell_state, self.nproc
 
725
  )
726
  raise
727
 
 
 
 
 
 
 
728
  state_embs_dict = dict()
729
  state_key = cell_states_to_model["state_key"]
730
  for k, v in cell_states_to_model.items():