hchen725 commited on
Commit
704ef0d
·
verified ·
1 Parent(s): 4cdb505

Update geneformer/tokenizer.py

Browse files

Update to use Ensembl ID mapped throughout

Files changed (1) hide show
  1. geneformer/tokenizer.py +14 -30
geneformer/tokenizer.py CHANGED
@@ -63,17 +63,6 @@ logger = logging.getLogger(__name__)
63
 
64
  from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
65
 
66
- def rename_attr(data_ra_or_ca, old_name, new_name):
67
- """ Rename attributes
68
- Args:
69
- data_ra_or_ca: data as a record array or column attribute
70
- old_name (str): old name of attribute
71
- new_name (str): new name of attribute
72
- """
73
- data_ra_or_ca[new_name] = data_ra_or_ca[old_name]
74
- if new_name != old_name:
75
- del data_ra_or_ca[old_name]
76
-
77
  def rank_genes(gene_vector, gene_tokens):
78
  """
79
  Rank gene expression vector.
@@ -131,18 +120,16 @@ def sum_ensembl_ids(
131
  ]
132
 
133
  if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
134
- # Keep original Ensembl IDs as `ensembl_id_original`
135
- rename_attr(data.ra, "ensembl_id", "ensembl_id_original")
136
- data.ra["ensembl_id"] = gene_ids_collapsed
137
  return data_directory
138
  else:
139
  dedup_filename = data_directory.with_name(
140
  data_directory.stem + "__dedup.loom"
141
  )
142
- data.ra["gene_ids_collapsed"] = gene_ids_collapsed
143
  dup_genes = [
144
  idx
145
- for idx, count in Counter(data.ra["gene_ids_collapsed"]).items()
146
  if count > 1
147
  ]
148
  num_chunks = int(np.ceil(data.shape[1] / chunk_size))
@@ -153,7 +140,7 @@ def sum_ensembl_ids(
153
 
154
  def process_chunk(view, duplic_genes):
155
  data_count_view = pd.DataFrame(
156
- view, index=data.ra["gene_ids_collapsed"]
157
  )
158
  unique_data_df = data_count_view.loc[
159
  ~data_count_view.index.isin(duplic_genes)
@@ -179,7 +166,7 @@ def sum_ensembl_ids(
179
 
180
  processed_chunk = process_chunk(view[:, :], dup_genes)
181
  processed_array = processed_chunk.to_numpy()
182
- new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
183
 
184
  if "n_counts" not in view.ca.keys():
185
  total_count_view = np.sum(view[:, :], axis=0).astype(int)
@@ -230,11 +217,11 @@ def sum_ensembl_ids(
230
  gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
231
  ]
232
  if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
233
- data.var.ensembl_id = data.var.ensembl_id.map(gene_mapping_dict)
234
  return data
235
 
236
  else:
237
- data.var["gene_ids_collapsed"] = gene_ids_collapsed
238
  data.var_names = gene_ids_collapsed
239
  data = data[:, ~data.var.index.isna()]
240
  dup_genes = [
@@ -265,16 +252,13 @@ def sum_ensembl_ids(
265
  processed_chunks = pd.concat(processed_chunks, axis=1)
266
  processed_genes.append(processed_chunks)
267
  processed_genes = pd.concat(processed_genes, axis=0)
268
- var_df = pd.DataFrame({"gene_ids_collapsed": processed_genes.columns})
269
  var_df.index = processed_genes.columns
270
  processed_genes = sc.AnnData(X=processed_genes, obs=data.obs, var=var_df)
271
 
272
  data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
273
  data_dedup = sc.concat([data_dedup, processed_genes], axis=1)
274
  data_dedup.obs = data.obs
275
- data_dedup.var = data_dedup.var.rename(
276
- columns={"gene_ids_collapsed": "ensembl_id"}
277
- )
278
  return data_dedup
279
 
280
 
@@ -474,15 +458,15 @@ class TranscriptomeTokenizer:
474
  }
475
 
476
  coding_miRNA_loc = np.where(
477
- [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
478
  )[0]
479
  norm_factor_vector = np.array(
480
  [
481
  self.gene_median_dict[i]
482
- for i in adata.var["ensembl_id"][coding_miRNA_loc]
483
  ]
484
  )
485
- coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]
486
  coding_miRNA_tokens = np.array(
487
  [self.gene_token_dict[i] for i in coding_miRNA_ids]
488
  )
@@ -546,15 +530,15 @@ class TranscriptomeTokenizer:
546
  with lp.connect(str(loom_file_path)) as data:
547
  # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
548
  coding_miRNA_loc = np.where(
549
- [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
550
  )[0]
551
  norm_factor_vector = np.array(
552
  [
553
  self.gene_median_dict[i]
554
- for i in data.ra["ensembl_id"][coding_miRNA_loc]
555
  ]
556
  )
557
- coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
558
  coding_miRNA_tokens = np.array(
559
  [self.gene_token_dict[i] for i in coding_miRNA_ids]
560
  )
 
63
 
64
  from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
65
 
 
 
 
 
 
 
 
 
 
 
 
66
  def rank_genes(gene_vector, gene_tokens):
67
  """
68
  Rank gene expression vector.
 
120
  ]
121
 
122
  if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
123
+ data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
 
 
124
  return data_directory
125
  else:
126
  dedup_filename = data_directory.with_name(
127
  data_directory.stem + "__dedup.loom"
128
  )
129
+ data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
130
  dup_genes = [
131
  idx
132
+ for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
133
  if count > 1
134
  ]
135
  num_chunks = int(np.ceil(data.shape[1] / chunk_size))
 
140
 
141
  def process_chunk(view, duplic_genes):
142
  data_count_view = pd.DataFrame(
143
+ view, index=data.ra["ensembl_id_collapsed"]
144
  )
145
  unique_data_df = data_count_view.loc[
146
  ~data_count_view.index.isin(duplic_genes)
 
166
 
167
  processed_chunk = process_chunk(view[:, :], dup_genes)
168
  processed_array = processed_chunk.to_numpy()
169
+ new_row_attrs = {"ensembl_id_collapsed": processed_chunk.index.to_numpy()}
170
 
171
  if "n_counts" not in view.ca.keys():
172
  total_count_view = np.sum(view[:, :], axis=0).astype(int)
 
217
  gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
218
  ]
219
  if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
220
+ data.var["ensembl_id_collapsed"] = data.var.ensembl_id.map(gene_mapping_dict)
221
  return data
222
 
223
  else:
224
+ data.var["ensembl_id_collapsed"] = gene_ids_collapsed
225
  data.var_names = gene_ids_collapsed
226
  data = data[:, ~data.var.index.isna()]
227
  dup_genes = [
 
252
  processed_chunks = pd.concat(processed_chunks, axis=1)
253
  processed_genes.append(processed_chunks)
254
  processed_genes = pd.concat(processed_genes, axis=0)
255
+ var_df = pd.DataFrame({"ensembl_id_collapsed": processed_genes.columns})
256
  var_df.index = processed_genes.columns
257
  processed_genes = sc.AnnData(X=processed_genes, obs=data.obs, var=var_df)
258
 
259
  data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
260
  data_dedup = sc.concat([data_dedup, processed_genes], axis=1)
261
  data_dedup.obs = data.obs
 
 
 
262
  return data_dedup
263
 
264
 
 
458
  }
459
 
460
  coding_miRNA_loc = np.where(
461
+ [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id_collapsed"]]
462
  )[0]
463
  norm_factor_vector = np.array(
464
  [
465
  self.gene_median_dict[i]
466
+ for i in adata.var["ensembl_id_collapsed"][coding_miRNA_loc]
467
  ]
468
  )
469
+ coding_miRNA_ids = adata.var["ensembl_id_collapsed"][coding_miRNA_loc]
470
  coding_miRNA_tokens = np.array(
471
  [self.gene_token_dict[i] for i in coding_miRNA_ids]
472
  )
 
530
  with lp.connect(str(loom_file_path)) as data:
531
  # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
532
  coding_miRNA_loc = np.where(
533
+ [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id_collapsed"]]
534
  )[0]
535
  norm_factor_vector = np.array(
536
  [
537
  self.gene_median_dict[i]
538
+ for i in data.ra["ensembl_id_collapsed"][coding_miRNA_loc]
539
  ]
540
  )
541
+ coding_miRNA_ids = data.ra["ensembl_id_collapsed"][coding_miRNA_loc]
542
  coding_miRNA_tokens = np.array(
543
  [self.gene_token_dict[i] for i in coding_miRNA_ids]
544
  )