Add option for variable input_size and to add CLS/SEP Tokens (#299)
Browse files
- Add option for variable input_size and to add CLS/SEP Tokens (944c98c4ed4fa9bd539c2a7d8db5351161d85012)
Co-authored-by: Han Chen <[email protected]>
- geneformer/tokenizer.py +22 -8
geneformer/tokenizer.py
CHANGED
@@ -81,14 +81,14 @@ class TranscriptomeTokenizer:
|
|
81 |
custom_attr_name_dict=None,
|
82 |
nproc=1,
|
83 |
chunk_size=512,
|
|
|
|
|
84 |
gene_median_file=GENE_MEDIAN_FILE,
|
85 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
86 |
):
|
87 |
"""
|
88 |
Initialize tokenizer.
|
89 |
-
|
90 |
**Parameters:**
|
91 |
-
|
92 |
custom_attr_name_dict : None, dict
|
93 |
| Dictionary of custom attributes to be added to the dataset.
|
94 |
| Keys are the names of the attributes in the loom file.
|
@@ -97,6 +97,10 @@ class TranscriptomeTokenizer:
|
|
97 |
| Number of processes to use for dataset mapping.
|
98 |
chunk_size: int = 512
|
99 |
| Chunk size for anndata tokenizer.
|
|
|
|
|
|
|
|
|
100 |
gene_median_file : Path
|
101 |
| Path to pickle file containing dictionary of non-zero median
|
102 |
| gene expression values across Genecorpus-30M.
|
@@ -112,6 +116,12 @@ class TranscriptomeTokenizer:
|
|
112 |
# chunk size for anndata tokenizer
|
113 |
self.chunk_size = chunk_size
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
# load dictionary of gene normalization factors
|
116 |
# (non-zero median value of expression across Genecorpus-30M)
|
117 |
with open(gene_median_file, "rb") as f:
|
@@ -137,9 +147,7 @@ class TranscriptomeTokenizer:
|
|
137 |
):
|
138 |
"""
|
139 |
Tokenize .loom files in data_directory and save as tokenized .dataset in output_directory.
|
140 |
-
|
141 |
**Parameters:**
|
142 |
-
|
143 |
data_directory : Path
|
144 |
| Path to directory containing loom files or anndata files
|
145 |
output_directory : Path
|
@@ -324,7 +332,7 @@ class TranscriptomeTokenizer:
|
|
324 |
file_cell_metadata[k] += subview.ca[k].tolist()
|
325 |
else:
|
326 |
file_cell_metadata = None
|
327 |
-
|
328 |
return tokenized_cells, file_cell_metadata
|
329 |
|
330 |
def create_dataset(
|
@@ -357,8 +365,14 @@ class TranscriptomeTokenizer:
|
|
357 |
example["input_ids_uncropped"] = example["input_ids"]
|
358 |
example["length_uncropped"] = len(example["input_ids"])
|
359 |
|
360 |
-
# Truncate/Crop input_ids to size
|
361 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
362 |
example["length"] = len(example["input_ids"])
|
363 |
|
364 |
return example
|
@@ -366,4 +380,4 @@ class TranscriptomeTokenizer:
|
|
366 |
output_dataset_truncated = output_dataset.map(
|
367 |
format_cell_features, num_proc=self.nproc
|
368 |
)
|
369 |
-
return output_dataset_truncated
|
|
|
81 |
custom_attr_name_dict=None,
|
82 |
nproc=1,
|
83 |
chunk_size=512,
|
84 |
+
input_size=2048,
|
85 |
+
special_token=False,
|
86 |
gene_median_file=GENE_MEDIAN_FILE,
|
87 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
88 |
):
|
89 |
"""
|
90 |
Initialize tokenizer.
|
|
|
91 |
**Parameters:**
|
|
|
92 |
custom_attr_name_dict : None, dict
|
93 |
| Dictionary of custom attributes to be added to the dataset.
|
94 |
| Keys are the names of the attributes in the loom file.
|
|
|
97 |
| Number of processes to use for dataset mapping.
|
98 |
chunk_size: int = 512
|
99 |
| Chunk size for anndata tokenizer.
|
100 |
+
input_size: int = 2048
|
101 |
+
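| Maximum number of tokens per cell; input_ids are truncated to this length (including the CLS and SEP tokens when special_token is True).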
|
102 |
+
special_token: bool = False
|
103 |
+
| Whether to add a CLS token at the start and a SEP token at the end of each cell's input_ids.
|
104 |
gene_median_file : Path
|
105 |
| Path to pickle file containing dictionary of non-zero median
|
106 |
| gene expression values across Genecorpus-30M.
|
|
|
116 |
# chunk size for anndata tokenizer
|
117 |
self.chunk_size = chunk_size
|
118 |
|
119 |
+
# input size for tokenization
|
120 |
+
self.input_size = input_size
|
121 |
+
|
122 |
+
# add CLS and SEP tokens
|
123 |
+
self.special_token = special_token
|
124 |
+
|
125 |
# load dictionary of gene normalization factors
|
126 |
# (non-zero median value of expression across Genecorpus-30M)
|
127 |
with open(gene_median_file, "rb") as f:
|
|
|
147 |
):
|
148 |
"""
|
149 |
Tokenize .loom files in data_directory and save as tokenized .dataset in output_directory.
|
|
|
150 |
**Parameters:**
|
|
|
151 |
data_directory : Path
|
152 |
| Path to directory containing loom files or anndata files
|
153 |
output_directory : Path
|
|
|
332 |
file_cell_metadata[k] += subview.ca[k].tolist()
|
333 |
else:
|
334 |
file_cell_metadata = None
|
335 |
+
|
336 |
return tokenized_cells, file_cell_metadata
|
337 |
|
338 |
def create_dataset(
|
|
|
365 |
example["input_ids_uncropped"] = example["input_ids"]
|
366 |
example["length_uncropped"] = len(example["input_ids"])
|
367 |
|
368 |
+
# Truncate/Crop input_ids to input size
|
369 |
+
if tk.special_token:
|
370 |
+
example["input_ids"] = example["input_ids"][0:self.input_size-2] # truncate to leave space for CLS and SEP token
|
371 |
+
example["input_ids"] = np.insert(example["input_ids"], 0, self.gene_token_dict.get("<cls>"))
|
372 |
+
example["input_ids"] = np.insert(example["input_ids"], len(example["input_ids"]), self.gene_token_dict.get("<sep>"))
|
373 |
+
else:
|
374 |
+
# Truncate/Crop input_ids to input size
|
375 |
+
example["input_ids"] = example["input_ids"][0:self.input_size]
|
376 |
example["length"] = len(example["input_ids"])
|
377 |
|
378 |
return example
|
|
|
380 |
output_dataset_truncated = output_dataset.map(
|
381 |
format_cell_features, num_proc=self.nproc
|
382 |
)
|
383 |
+
return output_dataset_truncated
|