Christina Theodoris commited on
Commit
624349c
·
1 Parent(s): 4302f48

Add option to output embs as tensor

Browse files
examples/extract_and_plot_cell_embeddings.ipynb CHANGED
@@ -29,6 +29,7 @@
29
  " nproc=16)\n",
30
  "\n",
31
  "# extracts embedding from input data\n",
 
32
  "# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
33
  "embs = embex.extract_embs(\"../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224\",\n",
34
  " \"path/to/input_data/\",\n",
 
29
  " nproc=16)\n",
30
  "\n",
31
  "# extracts embedding from input data\n",
32
+ "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
33
  "# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
34
  "embs = embex.extract_embs(\"../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224\",\n",
35
  " \"path/to/input_data/\",\n",
geneformer/emb_extractor.py CHANGED
@@ -40,7 +40,7 @@ import seaborn as sns
40
  import torch
41
  from collections import Counter
42
  from pathlib import Path
43
- from tqdm.notebook import trange
44
  from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
45
 
46
  from .tokenizer import TOKEN_DICTIONARY_FILE
@@ -64,7 +64,6 @@ def get_embs(model,
64
  pad_token_id,
65
  forward_batch_size,
66
  summary_stat):
67
-
68
  model_input_size = get_model_input_size(model)
69
  total_batch_length = len(filtered_input_data)
70
 
@@ -138,7 +137,7 @@ def test_emb(model, example, layer_to_quant):
138
  return embs_test.size()[2]
139
 
140
  def label_embs(embs, downsampled_data, emb_labels):
141
- embs_df = pd.DataFrame(embs.cpu())
142
  if emb_labels is not None:
143
  for label in emb_labels:
144
  emb_label = downsampled_data[label]
@@ -367,7 +366,8 @@ class EmbExtractor:
367
  model_directory,
368
  input_data_file,
369
  output_directory,
370
- output_prefix):
 
371
  """
372
  Extract embeddings from input data and save as results in output_directory.
373
 
@@ -381,6 +381,9 @@ class EmbExtractor:
381
  Path to directory where embedding data will be saved as csv
382
  output_prefix : str
383
  Prefix for output file
 
 
 
384
  """
385
 
386
  filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
@@ -398,13 +401,16 @@ class EmbExtractor:
398
  if self.summary_stat is None:
399
  embs_df = label_embs(embs, downsampled_data, self.emb_label)
400
  elif self.summary_stat is not None:
401
- embs_df = pd.DataFrame(embs.cpu()).T
402
 
403
  # save embeddings to output_path
404
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
405
  embs_df.to_csv(output_path)
406
-
407
- return embs_df
 
 
 
408
 
409
  def plot_embs(self,
410
  embs,
 
40
  import torch
41
  from collections import Counter
42
  from pathlib import Path
43
+ from tqdm.auto import trange
44
  from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
45
 
46
  from .tokenizer import TOKEN_DICTIONARY_FILE
 
64
  pad_token_id,
65
  forward_batch_size,
66
  summary_stat):
 
67
  model_input_size = get_model_input_size(model)
68
  total_batch_length = len(filtered_input_data)
69
 
 
137
  return embs_test.size()[2]
138
 
139
  def label_embs(embs, downsampled_data, emb_labels):
140
+ embs_df = pd.DataFrame(embs.cpu().numpy())
141
  if emb_labels is not None:
142
  for label in emb_labels:
143
  emb_label = downsampled_data[label]
 
366
  model_directory,
367
  input_data_file,
368
  output_directory,
369
+ output_prefix,
370
+ output_torch_embs=False):
371
  """
372
  Extract embeddings from input data and save as results in output_directory.
373
 
 
381
  Path to directory where embedding data will be saved as csv
382
  output_prefix : str
383
  Prefix for output file
384
+ output_torch_embs : bool
385
+ Whether or not to also output the embeddings as a tensor.
386
+ Note, if true, will output embeddings as both dataframe and tensor.
387
  """
388
 
389
  filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
 
401
  if self.summary_stat is None:
402
  embs_df = label_embs(embs, downsampled_data, self.emb_label)
403
  elif self.summary_stat is not None:
404
+ embs_df = pd.DataFrame(embs.cpu().numpy()).T
405
 
406
  # save embeddings to output_path
407
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
408
  embs_df.to_csv(output_path)
409
+
410
+ if output_torch_embs == True:
411
+ return embs_df, embs
412
+ else:
413
+ return embs_df
414
 
415
  def plot_embs(self,
416
  embs,
geneformer/in_silico_perturber.py CHANGED
@@ -34,7 +34,7 @@ import seaborn as sns; sns.set()
34
  import torch
35
  from collections import defaultdict
36
  from datasets import Dataset, load_from_disk
37
- from tqdm.notebook import trange
38
  from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
39
 
40
  from .tokenizer import TOKEN_DICTIONARY_FILE
 
34
  import torch
35
  from collections import defaultdict
36
  from datasets import Dataset, load_from_disk
37
+ from tqdm.auto import trange
38
  from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
39
 
40
  from .tokenizer import TOKEN_DICTIONARY_FILE
geneformer/in_silico_perturber_stats.py CHANGED
@@ -27,7 +27,7 @@ import statsmodels.stats.multitest as smt
27
  from pathlib import Path
28
  from scipy.stats import ranksums
29
  from sklearn.mixture import GaussianMixture
30
- from tqdm.notebook import trange, tqdm
31
 
32
  from .in_silico_perturber import flatten_list
33
 
 
27
  from pathlib import Path
28
  from scipy.stats import ranksums
29
  from sklearn.mixture import GaussianMixture
30
+ from tqdm.auto import trange, tqdm
31
 
32
  from .in_silico_perturber import flatten_list
33
 
setup.py CHANGED
@@ -16,6 +16,7 @@ setup(
16
  "datasets",
17
  "loompy",
18
  "numpy",
 
19
  "transformers",
20
  ],
21
  )
 
16
  "datasets",
17
  "loompy",
18
  "numpy",
19
+ "tdigest",
20
  "transformers",
21
  ],
22
  )