Add function for summing of Ensembl IDs

#377
by hchen725 - opened
.gitattributes CHANGED
@@ -14,6 +14,7 @@
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
@@ -25,5 +26,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
- geneformer/gene_name_id_dict.pkl filter=lfs diff=lfs merge=lfs -text
29
- model.safetensors filter=lfs diff=lfs merge=lfs -text
 
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pkl filter=lfs diff=lfs merge=lfs -text
18
  *.pt filter=lfs diff=lfs merge=lfs -text
19
  *.pth filter=lfs diff=lfs merge=lfs -text
20
  *.rar filter=lfs diff=lfs merge=lfs -text
 
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
29
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
 
geneformer/__init__.py CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
4
  GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
5
  TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
6
  ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
 
7
 
8
  from . import (
9
  collator_for_classification,
 
4
  GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
5
  TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
6
  ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
7
+ ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict.pkl"
8
 
9
  from . import (
10
  collator_for_classification,
geneformer/ensembl_mapping_dict_gc95M.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0819bcbd869cfa14279449b037eb9ed1d09a91310e77bd1a19d927465030e95c
3
+ size 3957652
geneformer/gene_median_dictionary_gc95M.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a51c53f6a771d64508dfaf61529df70e394c53bd20856926117ae5d641a24bf5
3
+ size 1512661
geneformer/gene_name_id_dict_gc95M.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fabfa0c2f49c598c59ae432a32c3499a5908c033756c663b5e0cddf58deea8e1
3
+ size 1660882
geneformer/token_dictionary_gc95M.pkl ADDED
Binary file (426 kB). View file
 
geneformer/tokenizer.py CHANGED
@@ -1,49 +1,40 @@
1
  """
2
  Geneformer tokenizer.
3
-
4
  **Input data:**
5
-
6
  | *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file.
7
  | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
8
  | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
9
-
10
  | *Optional col (cell) attribute:* "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria.
11
  | *Optional col (cell) attributes:* any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below.
12
-
13
  **Usage:**
14
-
15
  .. code-block :: python
16
-
17
  >>> from geneformer import TranscriptomeTokenizer
18
  >>> tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ"}, nproc=4)
19
  >>> tk.tokenize_data("data_directory", "output_directory", "output_prefix")
20
-
21
  **Description:**
22
-
23
  | Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
24
-
25
  | The discussion below references the .loom file format, but the analagous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.
26
-
27
  | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
28
-
29
  | No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes "cell_type" and "organ_major" and one would like to retain these attributes as labels in the tokenized dataset with the new names "cell_type" and "organ", respectively, the following custom attribute dictionary should be provided: {"cell_type": "cell_type", "organ_major": "organ"}.
30
-
31
  | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
32
-
33
  | If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.
34
-
35
  """
36
 
37
  from __future__ import annotations
38
 
 
39
  import logging
40
  import pickle
41
  import warnings
42
  from pathlib import Path
43
  from typing import Literal
 
 
44
 
45
- import anndata as ad
46
  import numpy as np
 
 
 
47
  import scipy.sparse as sp
48
  from datasets import Dataset
49
 
@@ -52,7 +43,7 @@ import loompy as lp # noqa
52
 
53
  logger = logging.getLogger(__name__)
54
 
55
- from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
56
 
57
 
58
  def rank_genes(gene_vector, gene_tokens):
@@ -74,6 +65,134 @@ def tokenize_cell(gene_vector, gene_tokens):
74
  # rank by median-scaled gene values
75
  return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  class TranscriptomeTokenizer:
79
  def __init__(
@@ -83,14 +202,14 @@ class TranscriptomeTokenizer:
83
  chunk_size=512,
84
  model_input_size=2048,
85
  special_token=False,
 
86
  gene_median_file=GENE_MEDIAN_FILE,
87
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
 
88
  ):
89
  """
90
  Initialize tokenizer.
91
-
92
  **Parameters:**
93
-
94
  custom_attr_name_dict : None, dict
95
  | Dictionary of custom attributes to be added to the dataset.
96
  | Keys are the names of the attributes in the loom file.
@@ -103,12 +222,15 @@ class TranscriptomeTokenizer:
103
  | Max input size of model to truncate input to.
104
  special_token : bool = False
105
  | Adds CLS token before and EOS token after rank value encoding.
 
 
106
  gene_median_file : Path
107
  | Path to pickle file containing dictionary of non-zero median
108
  | gene expression values across Genecorpus-30M.
109
  token_dictionary_file : Path
110
  | Path to pickle file containing token dictionary (Ensembl IDs:token).
111
-
 
112
  """
113
  # dictionary of custom attributes {output dataset column name: input .loom column name}
114
  self.custom_attr_name_dict = custom_attr_name_dict
@@ -134,9 +256,31 @@ class TranscriptomeTokenizer:
134
  with open(token_dictionary_file, "rb") as f:
135
  self.gene_token_dict = pickle.load(f)
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  # gene keys for full vocabulary
138
  self.gene_keys = list(self.gene_token_dict.keys())
139
 
 
 
 
 
140
  # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
141
  self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
142
 
@@ -150,9 +294,7 @@ class TranscriptomeTokenizer:
150
  ):
151
  """
152
  Tokenize .loom files in data_directory and save as tokenized .dataset in output_directory.
153
-
154
  **Parameters:**
155
-
156
  data_directory : Path
157
  | Path to directory containing loom files or anndata files
158
  output_directory : Path
@@ -163,7 +305,6 @@ class TranscriptomeTokenizer:
163
  | Format of input files. Can be "loom" or "h5ad".
164
  use_generator : bool
165
  | Whether to use generator or dict for tokenization.
166
-
167
  """
168
  tokenized_cells, cell_metadata = self.tokenize_files(
169
  Path(data_directory), file_format
@@ -214,7 +355,7 @@ class TranscriptomeTokenizer:
214
  return tokenized_cells, cell_metadata
215
 
216
  def tokenize_anndata(self, adata_file_path, target_sum=10_000):
217
- adata = ad.read(adata_file_path, backed="r")
218
 
219
  if self.custom_attr_name_dict is not None:
220
  file_cell_metadata = {
@@ -256,7 +397,8 @@ class TranscriptomeTokenizer:
256
  idx = filter_pass_loc[i : i + self.chunk_size]
257
 
258
  n_counts = adata[idx].obs["n_counts"].values[:, None]
259
- X_view = adata[idx, coding_miRNA_loc].X
 
260
  X_norm = X_view / n_counts * target_sum / norm_factor_vector
261
  X_norm = sp.csr_matrix(X_norm)
262
 
@@ -280,6 +422,9 @@ class TranscriptomeTokenizer:
280
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
281
  }
282
 
 
 
 
283
  with lp.connect(str(loom_file_path)) as data:
284
  # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
285
  coding_miRNA_loc = np.where(
@@ -341,6 +486,9 @@ class TranscriptomeTokenizer:
341
  else:
342
  file_cell_metadata = None
343
 
 
 
 
344
  return tokenized_cells, file_cell_metadata
345
 
346
  def create_dataset(
@@ -396,4 +544,4 @@ class TranscriptomeTokenizer:
396
  output_dataset_truncated = output_dataset.map(
397
  format_cell_features, num_proc=self.nproc
398
  )
399
- return output_dataset_truncated
 
1
  """
2
  Geneformer tokenizer.
 
3
  **Input data:**
 
4
  | *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file.
5
  | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
6
  | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
 
7
  | *Optional col (cell) attribute:* "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria.
8
  | *Optional col (cell) attributes:* any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below.
 
9
  **Usage:**
 
10
  .. code-block :: python
 
11
  >>> from geneformer import TranscriptomeTokenizer
12
  >>> tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ"}, nproc=4)
13
  >>> tk.tokenize_data("data_directory", "output_directory", "output_prefix")
 
14
  **Description:**
 
15
  | Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
 
16
  | The discussion below references the .loom file format, but the analagous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.
 
17
  | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
 
18
  | No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes "cell_type" and "organ_major" and one would like to retain these attributes as labels in the tokenized dataset with the new names "cell_type" and "organ", respectively, the following custom attribute dictionary should be provided: {"cell_type": "cell_type", "organ_major": "organ"}.
 
19
  | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
 
20
  | If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.
 
21
  """
22
 
23
  from __future__ import annotations
24
 
25
+ import os
26
  import logging
27
  import pickle
28
  import warnings
29
  from pathlib import Path
30
  from typing import Literal
31
+ from tqdm import tqdm
32
+ from collections import Counter
33
 
 
34
  import numpy as np
35
+ import scanpy as sc
36
+ import loompy as lp
37
+ import pandas as pd
38
  import scipy.sparse as sp
39
  from datasets import Dataset
40
 
 
43
 
44
  logger = logging.getLogger(__name__)
45
 
46
+ from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE, ENSEMBL_MAPPING_FILE
47
 
48
 
49
  def rank_genes(gene_vector, gene_tokens):
 
65
  # rank by median-scaled gene values
66
  return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
67
 
68
+ def sum_ensembl_ids(data_directory,
69
+ collapse_gene_ids,
70
+ gene_mapping_dict,
71
+ gene_token_dict,
72
+ file_format = "loom",
73
+ chunk_size = 512):
74
+
75
+ if file_format == "loom":
76
+ """
77
+ Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
78
+ """
79
+ with lp.connect(data_directory) as data:
80
+ assert "ensembl_id" in data.ra.keys(), "'ensembl_id' column missing from data.ra.keys()"
81
+ gene_ids_in_dict = [gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()]
82
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
83
+ token_genes_unique = True
84
+ else:
85
+ token_genes_unique = False
86
+ if collapse_gene_ids is False:
87
+ if token_genes_unique:
88
+ return data_directory
89
+ else:
90
+ raise ValueError("Error: data Ensembl IDs non-unique.")
91
+
92
+ gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id]
93
+ gene_ids_collapsed_in_dict = [gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()]
94
+
95
+ if (len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))) and token_genes_unique:
96
+ return data_directory
97
+ else:
98
+ dedup_filename = data_directory.with_name(data_directory.stem + "__dedup.loom")
99
+ data.ra["gene_ids_collapsed"] = gene_ids_collapsed
100
+ dup_genes = [idx for idx, count in Counter(data.ra["gene_ids_collapsed"]).items() if count > 1]
101
+ num_chunks = int(np.ceil(data.shape[1] / chunk_size))
102
+ first_chunk = True
103
+ for _, _, view in tqdm(data.scan(axis = 1, batch_size = chunk_size), total = num_chunks):
104
+ def process_chunk(view, duplic_genes):
105
+ data_count_view = pd.DataFrame(view, index=data.ra["gene_ids_collapsed"])
106
+ unique_data_df = data_count_view.loc[~data_count_view.index.isin(duplic_genes)]
107
+ dup_data_df = data_count_view.loc[data_count_view.index.isin([i for i in duplic_genes if "None" not in i])]
108
+ summed_data = dup_data_df.groupby(dup_data_df.index).sum()
109
+ if not summed_data.index.is_unique:
110
+ raise ValueError("Error: Ensembl IDs in summed data frame non-unique.")
111
+ data_count_view = pd.concat([unique_data_df, summed_data], axis=0)
112
+ if not data_count_view.index.is_unique:
113
+ raise ValueError("Error: Ensembl IDs in final data frame non-unique.")
114
+ return data_count_view
115
+ processed_chunk = process_chunk(view[:, :], dup_genes)
116
+ processed_array = processed_chunk.to_numpy()
117
+ new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
118
+
119
+ if "n_counts" not in view.ca.keys():
120
+ total_count_view = np.sum(view[:,:], axis=0).astype(int)
121
+ view.ca["n_counts"] = total_count_view
122
+
123
+ if first_chunk: # Create the Loom file with the first chunk
124
+ lp.create(f"{dedup_filename}", processed_array, row_attrs=new_row_attrs, col_attrs=view.ca)
125
+ first_chunk = False
126
+ else: # Append subsequent chunks
127
+ with lp.connect(dedup_filename, mode='r+') as dsout:
128
+ dsout.add_columns(processed_array, col_attrs=view.ca)
129
+ return dedup_filename
130
+
131
+ elif file_format == "h5ad":
132
+ """
133
+ Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
134
+ Returns adata object with deduplicated Ensembl IDs.
135
+ """
136
+
137
+ data = sc.read_h5ad(str(data_directory))
138
+
139
+ assert "ensembl_id" in data.var.columns, "'ensembl_id' column missing from data.var"
140
+ gene_ids_in_dict = [gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()]
141
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
142
+ token_genes_unique = True
143
+ else:
144
+ token_genes_unique = False
145
+ if collapse_gene_ids is False:
146
+ if token_genes_unique:
147
+ return data
148
+ else:
149
+ raise ValueError("Error: data Ensembl IDs non-unique.")
150
+
151
+ gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id]
152
+ gene_ids_collapsed_in_dict = [gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()]
153
+ if (len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))) and token_genes_unique:
154
+ return data
155
+
156
+ else:
157
+ data.var["gene_ids_collapsed"] = gene_ids_collapsed
158
+ data.var_names = gene_ids_collapsed
159
+ data = data[:, ~data.var.index.isna()]
160
+ dup_genes = [idx for idx, count in Counter(data.var_names).items() if count > 1]
161
+
162
+ num_chunks = int(np.ceil(data.shape[0] / chunk_size))
163
+
164
+ processed_genes = []
165
+ for i in tqdm(range(num_chunks)):
166
+
167
+ start_idx = i * chunk_size
168
+ end_idx = min((i + 1) * chunk_size, data.shape[0])
169
+ data_chunk = data[start_idx:end_idx, :]
170
+
171
+ processed_chunks = []
172
+ for dup_gene in dup_genes:
173
+ data_dup_gene = data_chunk[:, data_chunk.var_names == dup_gene]
174
+ df = pd.DataFrame.sparse.from_spmatrix(data_dup_gene.X,
175
+ index=data_dup_gene.obs_names,
176
+ columns=data_dup_gene.var_names)
177
+ df_sum = pd.DataFrame(df.sum(axis=1))
178
+ df_sum.columns = [dup_gene]
179
+ df_sum.index = data_dup_gene.obs.index
180
+ processed_chunks.append(df_sum)
181
+
182
+ processed_chunks = pd.concat(processed_chunks, axis=1)
183
+ processed_genes.append(processed_chunks)
184
+ processed_genes = pd.concat(processed_genes, axis = 0)
185
+ var_df = pd.DataFrame({"gene_ids_collapsed" : processed_genes.columns})
186
+ var_df.index = processed_genes.columns
187
+ processed_genes = sc.AnnData(X = processed_genes,
188
+ obs = data.obs,
189
+ var = var_df)
190
+
191
+ data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
192
+ data_dedup = sc.concat([data_dedup, processed_genes], axis = 1)
193
+ data_dedup.obs = data.obs
194
+ data_dedup.var = data_dedup.var.rename(columns = {"gene_ids_collapsed" : "ensembl_id"})
195
+ return data_dedup
196
 
197
  class TranscriptomeTokenizer:
198
  def __init__(
 
202
  chunk_size=512,
203
  model_input_size=2048,
204
  special_token=False,
205
+ collapse_gene_ids=True,
206
  gene_median_file=GENE_MEDIAN_FILE,
207
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
208
+ gene_mapping_file=ENSEMBL_MAPPING_FILE,
209
  ):
210
  """
211
  Initialize tokenizer.
 
212
  **Parameters:**
 
213
  custom_attr_name_dict : None, dict
214
  | Dictionary of custom attributes to be added to the dataset.
215
  | Keys are the names of the attributes in the loom file.
 
222
  | Max input size of model to truncate input to.
223
  special_token : bool = False
224
  | Adds CLS token before and EOS token after rank value encoding.
225
+ collapse_gene_ids : bool = True
226
+ | Whether to collapse gene IDs based on gene mapping dictionary.
227
  gene_median_file : Path
228
  | Path to pickle file containing dictionary of non-zero median
229
  | gene expression values across Genecorpus-30M.
230
  token_dictionary_file : Path
231
  | Path to pickle file containing token dictionary (Ensembl IDs:token).
232
+ gene_mapping_file : None, Path
233
+ | Path to pickle file containing dictionary for collapsing gene IDs.
234
  """
235
  # dictionary of custom attributes {output dataset column name: input .loom column name}
236
  self.custom_attr_name_dict = custom_attr_name_dict
 
256
  with open(token_dictionary_file, "rb") as f:
257
  self.gene_token_dict = pickle.load(f)
258
 
259
+ # check for special token in gene_token_dict
260
+ if self.special_token:
261
+ if ("<cls>" not in self.gene_token_dict.keys()) and ("<eos>" not in self.gene_token_dict.keys()):
262
+ logger.error(
263
+ "<cls> and <eos> required in gene_token_dict when special_token = True."
264
+ )
265
+ raise
266
+
267
+ # if collapsing duplicate gene IDs
268
+ self.collapse_gene_ids = collapse_gene_ids
269
+
270
+ # load gene mappings dictionary (Ensembl IDs:Ensembl ID)
271
+ if gene_mapping_file is not None:
272
+ with open(gene_mapping_file, "rb") as f:
273
+ self.gene_mapping_dict = pickle.load(f)
274
+ else:
275
+ self.gene_mapping_dict = {k:k for k,_ in self.gene_token_dict.items()}
276
+
277
  # gene keys for full vocabulary
278
  self.gene_keys = list(self.gene_token_dict.keys())
279
 
280
+ # Filter gene mapping dict for items that exist in gene_token_dict
281
+ gene_keys_set = set(self.gene_token_dict.keys())
282
+ self.gene_mapping_dict = {k: v for k, v in self.gene_mapping_dict.items() if v in gene_keys_set}
283
+
284
  # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
285
  self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
286
 
 
294
  ):
295
  """
296
  Tokenize .loom files in data_directory and save as tokenized .dataset in output_directory.
 
297
  **Parameters:**
 
298
  data_directory : Path
299
  | Path to directory containing loom files or anndata files
300
  output_directory : Path
 
305
  | Format of input files. Can be "loom" or "h5ad".
306
  use_generator : bool
307
  | Whether to use generator or dict for tokenization.
 
308
  """
309
  tokenized_cells, cell_metadata = self.tokenize_files(
310
  Path(data_directory), file_format
 
355
  return tokenized_cells, cell_metadata
356
 
357
  def tokenize_anndata(self, adata_file_path, target_sum=10_000):
358
+ adata = sum_ensembl_ids(adata_file_path, self.collapse_gene_ids, self.gene_mapping_dict, self.gene_token_dict, file_format = "h5ad", chunk_size = self.chunk_size)
359
 
360
  if self.custom_attr_name_dict is not None:
361
  file_cell_metadata = {
 
397
  idx = filter_pass_loc[i : i + self.chunk_size]
398
 
399
  n_counts = adata[idx].obs["n_counts"].values[:, None]
400
+ X_view0 = adata[idx,:].X
401
+ X_view = X_view0[:, coding_miRNA_loc]
402
  X_norm = X_view / n_counts * target_sum / norm_factor_vector
403
  X_norm = sp.csr_matrix(X_norm)
404
 
 
422
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
423
  }
424
 
425
+ dedup_filename = loom_file_path.with_name(loom_file_path.stem + "__dedup.loom")
426
+ loom_file_path = sum_ensembl_ids(loom_file_path, self.collapse_gene_ids, self.gene_mapping_dict, self.gene_token_dict, file_format = "loom", chunk_size = self.chunk_size)
427
+
428
  with lp.connect(str(loom_file_path)) as data:
429
  # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
430
  coding_miRNA_loc = np.where(
 
486
  else:
487
  file_cell_metadata = None
488
 
489
+ if str(dedup_filename) == str(loom_file_path):
490
+ os.remove(str(dedup_filename))
491
+
492
  return tokenized_cells, file_cell_metadata
493
 
494
  def create_dataset(
 
544
  output_dataset_truncated = output_dataset.map(
545
  format_cell_features, num_proc=self.nproc
546
  )
547
+ return output_dataset_truncated