Update geneformer/tokenizer.py

#415
by hchen725 - opened
geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eac0fb0b3007267871b6305ac0003ceba19d4f28d85686cb9067ecf142787869
3
+ size 584125
geneformer/tokenizer.py CHANGED
@@ -63,7 +63,6 @@ logger = logging.getLogger(__name__)
63
 
64
  from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
65
 
66
-
67
  def rank_genes(gene_vector, gene_tokens):
68
  """
69
  Rank gene expression vector.
@@ -100,15 +99,18 @@ def sum_ensembl_ids(
100
  assert (
101
  "ensembl_id" in data.ra.keys()
102
  ), "'ensembl_id' column missing from data.ra.keys()"
 
 
 
 
 
 
103
  gene_ids_in_dict = [
104
  gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
105
  ]
106
- if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
107
- token_genes_unique = True
108
- else:
109
- token_genes_unique = False
110
  if collapse_gene_ids is False:
111
- if token_genes_unique:
 
112
  return data_directory
113
  else:
114
  raise ValueError("Error: data Ensembl IDs non-unique.")
@@ -120,18 +122,17 @@ def sum_ensembl_ids(
120
  gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
121
  ]
122
 
123
- if (
124
- len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))
125
- ) and token_genes_unique:
126
  return data_directory
127
  else:
128
  dedup_filename = data_directory.with_name(
129
  data_directory.stem + "__dedup.loom"
130
  )
131
- data.ra["gene_ids_collapsed"] = gene_ids_collapsed
132
  dup_genes = [
133
  idx
134
- for idx, count in Counter(data.ra["gene_ids_collapsed"]).items()
135
  if count > 1
136
  ]
137
  num_chunks = int(np.ceil(data.shape[1] / chunk_size))
@@ -142,7 +143,7 @@ def sum_ensembl_ids(
142
 
143
  def process_chunk(view, duplic_genes):
144
  data_count_view = pd.DataFrame(
145
- view, index=data.ra["gene_ids_collapsed"]
146
  )
147
  unique_data_df = data_count_view.loc[
148
  ~data_count_view.index.isin(duplic_genes)
@@ -168,7 +169,7 @@ def sum_ensembl_ids(
168
 
169
  processed_chunk = process_chunk(view[:, :], dup_genes)
170
  processed_array = processed_chunk.to_numpy()
171
- new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
172
 
173
  if "n_counts" not in view.ca.keys():
174
  total_count_view = np.sum(view[:, :], axis=0).astype(int)
@@ -198,32 +199,36 @@ def sum_ensembl_ids(
198
  assert (
199
  "ensembl_id" in data.var.columns
200
  ), "'ensembl_id' column missing from data.var"
 
 
 
 
 
 
 
201
  gene_ids_in_dict = [
202
  gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
203
  ]
204
- if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
205
- token_genes_unique = True
206
- else:
207
- token_genes_unique = False
208
  if collapse_gene_ids is False:
209
- if token_genes_unique:
 
210
  return data
211
  else:
212
  raise ValueError("Error: data Ensembl IDs non-unique.")
213
 
 
214
  gene_ids_collapsed = [
215
  gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
216
  ]
217
  gene_ids_collapsed_in_dict = [
218
  gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
219
  ]
220
- if (
221
- len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))
222
- ) and token_genes_unique:
223
  return data
224
 
225
  else:
226
- data.var["gene_ids_collapsed"] = gene_ids_collapsed
227
  data.var_names = gene_ids_collapsed
228
  data = data[:, ~data.var.index.isna()]
229
  dup_genes = [
@@ -254,16 +259,13 @@ def sum_ensembl_ids(
254
  processed_chunks = pd.concat(processed_chunks, axis=1)
255
  processed_genes.append(processed_chunks)
256
  processed_genes = pd.concat(processed_genes, axis=0)
257
- var_df = pd.DataFrame({"gene_ids_collapsed": processed_genes.columns})
258
  var_df.index = processed_genes.columns
259
  processed_genes = sc.AnnData(X=processed_genes, obs=data.obs, var=var_df)
260
 
261
  data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
262
  data_dedup = sc.concat([data_dedup, processed_genes], axis=1)
263
  data_dedup.obs = data.obs
264
- data_dedup.var = data_dedup.var.rename(
265
- columns={"gene_ids_collapsed": "ensembl_id"}
266
- )
267
  return data_dedup
268
 
269
 
@@ -463,15 +465,15 @@ class TranscriptomeTokenizer:
463
  }
464
 
465
  coding_miRNA_loc = np.where(
466
- [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
467
  )[0]
468
  norm_factor_vector = np.array(
469
  [
470
  self.gene_median_dict[i]
471
- for i in adata.var["ensembl_id"][coding_miRNA_loc]
472
  ]
473
  )
474
- coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]
475
  coding_miRNA_tokens = np.array(
476
  [self.gene_token_dict[i] for i in coding_miRNA_ids]
477
  )
@@ -521,6 +523,7 @@ class TranscriptomeTokenizer:
521
  file_cell_metadata = {
522
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
523
  }
 
524
 
525
  dedup_filename = loom_file_path.with_name(loom_file_path.stem + "__dedup.loom")
526
  loom_file_path = sum_ensembl_ids(
@@ -535,15 +538,15 @@ class TranscriptomeTokenizer:
535
  with lp.connect(str(loom_file_path)) as data:
536
  # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
537
  coding_miRNA_loc = np.where(
538
- [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
539
  )[0]
540
  norm_factor_vector = np.array(
541
  [
542
  self.gene_median_dict[i]
543
- for i in data.ra["ensembl_id"][coding_miRNA_loc]
544
  ]
545
  )
546
- coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
547
  coding_miRNA_tokens = np.array(
548
  [self.gene_token_dict[i] for i in coding_miRNA_ids]
549
  )
@@ -596,6 +599,11 @@ class TranscriptomeTokenizer:
596
  if str(dedup_filename) == str(loom_file_path):
597
  os.remove(str(dedup_filename))
598
 
 
 
 
 
 
599
  return tokenized_cells, file_cell_metadata
600
 
601
  def create_dataset(
 
63
 
64
  from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
65
 
 
66
  def rank_genes(gene_vector, gene_tokens):
67
  """
68
  Rank gene expression vector.
 
99
  assert (
100
  "ensembl_id" in data.ra.keys()
101
  ), "'ensembl_id' column missing from data.ra.keys()"
102
+
103
+ assert (
104
+ "ensembl_id_collapsed" not in data.ra.keys()
105
+ ), "'ensembl_id_collapsed' column already exists in data.ra.keys()"
106
+ # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
107
+ # Comparing to gene_token_dict here, would not perform any mapping steps
108
  gene_ids_in_dict = [
109
  gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
110
  ]
 
 
 
 
111
  if collapse_gene_ids is False:
112
+
113
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
114
  return data_directory
115
  else:
116
  raise ValueError("Error: data Ensembl IDs non-unique.")
 
122
  gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
123
  ]
124
 
125
+ if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
126
+ data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
 
127
  return data_directory
128
  else:
129
  dedup_filename = data_directory.with_name(
130
  data_directory.stem + "__dedup.loom"
131
  )
132
+ data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
133
  dup_genes = [
134
  idx
135
+ for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
136
  if count > 1
137
  ]
138
  num_chunks = int(np.ceil(data.shape[1] / chunk_size))
 
143
 
144
  def process_chunk(view, duplic_genes):
145
  data_count_view = pd.DataFrame(
146
+ view, index=data.ra["ensembl_id_collapsed"]
147
  )
148
  unique_data_df = data_count_view.loc[
149
  ~data_count_view.index.isin(duplic_genes)
 
169
 
170
  processed_chunk = process_chunk(view[:, :], dup_genes)
171
  processed_array = processed_chunk.to_numpy()
172
+ new_row_attrs = {"ensembl_id_collapsed": processed_chunk.index.to_numpy()}
173
 
174
  if "n_counts" not in view.ca.keys():
175
  total_count_view = np.sum(view[:, :], axis=0).astype(int)
 
199
  assert (
200
  "ensembl_id" in data.var.columns
201
  ), "'ensembl_id' column missing from data.var"
202
+
203
+ assert (
204
+ "ensembl_id_collapsed" not in data.var.columns
205
+ ), "'ensembl_id_collapsed' column already exists in data.var"
206
+
207
+ # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
208
+ # Comparing to gene_token_dict here, would not perform any mapping steps
209
  gene_ids_in_dict = [
210
  gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
211
  ]
 
 
 
 
212
  if collapse_gene_ids is False:
213
+
214
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
215
  return data
216
  else:
217
  raise ValueError("Error: data Ensembl IDs non-unique.")
218
 
219
+ # Check for when if collapse_gene_ids is True
220
  gene_ids_collapsed = [
221
  gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
222
  ]
223
  gene_ids_collapsed_in_dict = [
224
  gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
225
  ]
226
+ if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
227
+ data.var["ensembl_id_collapsed"] = data.var.ensembl_id.map(gene_mapping_dict)
 
228
  return data
229
 
230
  else:
231
+ data.var["ensembl_id_collapsed"] = gene_ids_collapsed
232
  data.var_names = gene_ids_collapsed
233
  data = data[:, ~data.var.index.isna()]
234
  dup_genes = [
 
259
  processed_chunks = pd.concat(processed_chunks, axis=1)
260
  processed_genes.append(processed_chunks)
261
  processed_genes = pd.concat(processed_genes, axis=0)
262
+ var_df = pd.DataFrame({"ensembl_id_collapsed": processed_genes.columns})
263
  var_df.index = processed_genes.columns
264
  processed_genes = sc.AnnData(X=processed_genes, obs=data.obs, var=var_df)
265
 
266
  data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
267
  data_dedup = sc.concat([data_dedup, processed_genes], axis=1)
268
  data_dedup.obs = data.obs
 
 
 
269
  return data_dedup
270
 
271
 
 
465
  }
466
 
467
  coding_miRNA_loc = np.where(
468
+ [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id_collapsed"]]
469
  )[0]
470
  norm_factor_vector = np.array(
471
  [
472
  self.gene_median_dict[i]
473
+ for i in adata.var["ensembl_id_collapsed"][coding_miRNA_loc]
474
  ]
475
  )
476
+ coding_miRNA_ids = adata.var["ensembl_id_collapsed"][coding_miRNA_loc]
477
  coding_miRNA_tokens = np.array(
478
  [self.gene_token_dict[i] for i in coding_miRNA_ids]
479
  )
 
523
  file_cell_metadata = {
524
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
525
  }
526
+ loom_file_path_original = loom_file_path
527
 
528
  dedup_filename = loom_file_path.with_name(loom_file_path.stem + "__dedup.loom")
529
  loom_file_path = sum_ensembl_ids(
 
538
  with lp.connect(str(loom_file_path)) as data:
539
  # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
540
  coding_miRNA_loc = np.where(
541
+ [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id_collapsed"]]
542
  )[0]
543
  norm_factor_vector = np.array(
544
  [
545
  self.gene_median_dict[i]
546
+ for i in data.ra["ensembl_id_collapsed"][coding_miRNA_loc]
547
  ]
548
  )
549
+ coding_miRNA_ids = data.ra["ensembl_id_collapsed"][coding_miRNA_loc]
550
  coding_miRNA_tokens = np.array(
551
  [self.gene_token_dict[i] for i in coding_miRNA_ids]
552
  )
 
599
  if str(dedup_filename) == str(loom_file_path):
600
  os.remove(str(dedup_filename))
601
 
602
+ with lp.connect(str(loom_file_path_original)) as data:
603
+ if "ensembl_id_collapsed" in data.ra.keys():
604
+ del data.ra["ensembl_id_collapsed"]
605
+
606
+
607
  return tokenized_cells, file_cell_metadata
608
 
609
  def create_dataset(