Christina Theodoris commited on
Commit
5426788
·
1 Parent(s): b73028f

Add Geneformer tokenizer and updated model card

Browse files
README.md CHANGED
@@ -1,17 +1,14 @@
1
  # Geneformer
2
  Geneformer is a transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
3
 
4
- <!---
5
  See [our manuscript](manuscript_link) for details.
6
- -->
7
 
8
  # Model Description
9
  Geneformer is transformer model pretrained on a [Genecorpus-30M](dataset_link), a pretraining corpus comprised of ~30 million single cell transcriptomes from a broad range of human tissues. Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell normalized by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across Genecorpus-30M to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by normalizing them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
10
 
11
  The rank value encoding of each single cell’s transcriptome then proceeds through six transformer encoder units. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
12
 
13
- <!--- We detail applications and results in [our manuscript](manuscript_link). -->
14
- During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. Fine-tuning Geneformer towards a diverse panel of downstream tasks relevant to chromatin and network dynamics using limited task-specific data demonstrated that Geneformer consistently boosted predictive accuracy. Applied to disease modeling with limited patient data, Geneformer identified candidate therapeutic targets. Overall, Geneformer represents an invaluable pretrained model from which fine-tuning towards a broad range of downstream applications can be pursued to accelerate discovery of key network regulators and candidate therapeutic targets.
15
 
16
  # Application
17
- The pretrained Geneformer model can be used directly, for example for in silico deletion analysis, but is best used by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
 
1
  # Geneformer
2
  Geneformer is a transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
3
 
 
4
  See [our manuscript](manuscript_link) for details.
 
5
 
6
  # Model Description
7
  Geneformer is transformer model pretrained on a [Genecorpus-30M](dataset_link), a pretraining corpus comprised of ~30 million single cell transcriptomes from a broad range of human tissues. Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell normalized by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across Genecorpus-30M to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by normalizing them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
8
 
9
  The rank value encoding of each single cell’s transcriptome then proceeds through six transformer encoder units. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
10
 
11
+ We detail applications and results in [our manuscript](manuscript_link). During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. Fine-tuning Geneformer towards a diverse panel of downstream tasks relevant to chromatin and network dynamics using limited task-specific data demonstrated that Geneformer consistently boosted predictive accuracy. Applied to disease modeling with limited patient data, Geneformer identified candidate therapeutic targets. Overall, Geneformer represents an invaluable pretrained model from which fine-tuning towards a broad range of downstream applications can be pursued to accelerate discovery of key network regulators and candidate therapeutic targets.
 
12
 
13
  # Application
14
+ The pretrained Geneformer model can be used directly, for example for in silico deletion analysis, but is best used by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
geneformer/__init__.py ADDED
File without changes
geneformer/gene_median_dictionary.pkl ADDED
Binary file (941 kB). View file
 
geneformer/token_dictionary.pkl ADDED
Binary file (788 kB). View file
 
geneformer/tokenizer.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer tokenizer.
3
+
4
+ Usage:
5
+ from geneformer.tokenizer import Tokenizer
6
+ tk = Tokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
7
+ tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
8
+ """
9
+
10
+ import pickle
11
+ from pathlib import Path
12
+
13
+ import loompy as lp
14
+ import numpy as np
15
+ from datasets import Dataset
16
+
17
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
18
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
19
+
20
+
21
+ def tokenize_cell(gene_vector, gene_tokens):
22
+ """
23
+ Convert normalized gene expression vector to tokenized rank value encoding.
24
+ """
25
+ # create array of gene vector with token indices
26
+ # mask undetected genes
27
+ nonzero_mask = np.nonzero(gene_vector)[0]
28
+ # sort by median-scaled gene values
29
+ sorted_indices = np.argsort(-gene_vector[nonzero_mask])
30
+ # tokenize
31
+ sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
32
+ return sentence_tokens
33
+
34
+
35
+ class Tokenizer:
36
+ def __init__(
37
+ self,
38
+ custom_attr_name_dict,
39
+ nproc=1,
40
+ gene_median_file=GENE_MEDIAN_FILE,
41
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
42
+ ):
43
+ """
44
+ Initialize tokenizer.
45
+
46
+ Parameters
47
+ ----------
48
+ custom_attr_name_dict : dict
49
+ Dictionary of custom attributes to be added to the dataset.
50
+ Keys are the names of the attributes in the loom file.
51
+ Values are the names of the attributes in the dataset.
52
+ nproc : int
53
+ Number of processes to use for dataset mapping.
54
+ gene_median_file : Path
55
+ Path to pickle file containing dictionary of non-zero median
56
+ gene expression values across Genecorpus-30M.
57
+ token_dictionary_file : Path
58
+ Path to pickle file containing token dictionary (Ensembl IDs:token).
59
+ """
60
+ # dictionary of custom attributes {output dataset column name: input .loom column name}
61
+ self.custom_attr_name_dict = custom_attr_name_dict
62
+
63
+ # number of processes for dataset mapping
64
+ self.nproc = nproc
65
+
66
+ # load dictionary of gene normalization factors
67
+ # (non-zero median value of expression across Genecorpus-30M)
68
+ with open(gene_median_file, "rb") as f:
69
+ self.gene_median_dict = pickle.load(f)
70
+
71
+ # load token dictionary (Ensembl IDs:token)
72
+ with open(token_dictionary_file, "rb") as f:
73
+ self.gene_token_dict = pickle.load(f)
74
+
75
+ # gene keys for full vocabulary
76
+ self.gene_keys = list(self.gene_median_dict.keys())
77
+
78
+ # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
79
+ self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
80
+
81
+ def tokenize_data(self, loom_data_directory, output_directory, output_prefix):
82
+ """
83
+ Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.
84
+
85
+ Parameters
86
+ ----------
87
+ loom_data_directory : Path
88
+ Path to directory containing loom files
89
+ output_directory : Path
90
+ Path to directory where tokenized data will be saved as .dataset
91
+ output_prefix : str
92
+ Prefix for output .dataset
93
+ """
94
+ tokenized_cells, cell_metadata = self.tokenize_files(loom_data_directory)
95
+ tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)
96
+
97
+ output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
98
+ tokenized_dataset.save_to_disk(output_path)
99
+
100
+ def tokenize_files(self, loom_data_directory):
101
+ tokenized_cells = []
102
+ cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.keys()}
103
+
104
+ # loops through directories to tokenize .loom files
105
+ for loom_file_path in loom_data_directory.glob("*.loom"):
106
+ print(f"Tokenizing {loom_file_path}")
107
+ file_tokenized_cells, file_cell_metadata = self.tokenize_file(
108
+ loom_file_path
109
+ )
110
+ tokenized_cells += file_tokenized_cells
111
+ cell_metadata.update(file_cell_metadata)
112
+
113
+ return tokenized_cells, cell_metadata
114
+
115
+ def tokenize_file(self, loom_file_path):
116
+ file_cell_metadata = {
117
+ attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
118
+ }
119
+
120
+ with lp.connect(str(loom_file_path)) as data:
121
+ # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
122
+ coding_miRNA_loc = np.where(
123
+ [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
124
+ )[0]
125
+ norm_factor_vector = np.array(
126
+ [
127
+ self.gene_median_dict[i]
128
+ for i in data.ra["ensembl_id"][coding_miRNA_loc]
129
+ ]
130
+ )
131
+ coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
132
+ coding_miRNA_tokens = np.array(
133
+ [self.gene_token_dict[i] for i in coding_miRNA_ids]
134
+ )
135
+
136
+ # define coordinates of cells passing filters for inclusion (e.g. QC)
137
+ try:
138
+ data.ca["filter_pass"]
139
+ except NameError:
140
+ var_exists = False
141
+ else:
142
+ var_exists = True
143
+
144
+ if var_exists is True:
145
+ filter_pass_loc = np.where(
146
+ [True if i == 1 else False for i in data.ca["filter_pass"]]
147
+ )[0]
148
+ elif var_exists is False:
149
+ print(
150
+ f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
151
+ )
152
+ filter_pass_loc = np.array([i for i in range(data.shape[1])])
153
+
154
+ # scan through .loom files and tokenize cells
155
+ tokenized_cells = []
156
+ for (_ix, _selection, view) in data.scan(items=filter_pass_loc, axis=1):
157
+ # select subview with protein-coding and miRNA genes
158
+ subview = view.view[coding_miRNA_loc, :]
159
+
160
+ # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision
161
+ # and normalize by gene normalization factors
162
+ subview_norm_array = (
163
+ subview[:, :]
164
+ / subview.ca.n_counts
165
+ * 10_000
166
+ / norm_factor_vector[:, None]
167
+ )
168
+ # tokenize subview gene vectors
169
+ tokenized_cells += [
170
+ tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
171
+ for i in range(subview_norm_array.shape[1])
172
+ ]
173
+
174
+ # add custom attributes for subview to dict
175
+ for k in file_cell_metadata.keys():
176
+ file_cell_metadata[k] += subview.ca[k].tolist()
177
+
178
+ return tokenized_cells, file_cell_metadata
179
+
180
+ def create_dataset(self, tokenized_cells, cell_metadata):
181
+ # create dict for dataset creation
182
+ dataset_dict = {"input_ids": tokenized_cells}
183
+ dataset_dict.update(cell_metadata)
184
+
185
+ # create dataset
186
+ output_dataset = Dataset.from_dict(dataset_dict)
187
+
188
+ # truncate dataset
189
+ def truncate(example):
190
+ example["input_ids"] = example["input_ids"][0:2048]
191
+ return example
192
+
193
+ output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)
194
+
195
+ # measure lengths of dataset
196
+ def measure_length(example):
197
+ example["length"] = len(example["input_ids"])
198
+ return example
199
+
200
+ output_dataset_truncated_w_length = output_dataset_truncated.map(
201
+ measure_length, num_proc=self.nproc
202
+ )
203
+
204
+ return output_dataset_truncated_w_length