Christina Theodoris commited on
Commit
6caf480
·
1 Parent(s): 61f15d2

Add memory-efficient method for computing emb summary statistics

Browse files
Files changed (1) hide show
  1. geneformer/emb_extractor.py +62 -19
geneformer/emb_extractor.py CHANGED
@@ -14,7 +14,8 @@ Usage:
14
  emb_label=["disease","cell_type"],
15
  labels_to_plot=["disease","cell_type"],
16
  forward_batch_size=100,
17
- nproc=16)
 
18
  embs = embex.extract_embs("path/to/model",
19
  "path/to/input_data",
20
  "path/to/output_directory",
@@ -33,6 +34,7 @@ import matplotlib.pyplot as plt
33
  import numpy as np
34
  import pandas as pd
35
  import pickle
 
36
  import scanpy as sc
37
  import seaborn as sns
38
  import torch
@@ -54,20 +56,28 @@ from .in_silico_perturber import downsample_and_sort, \
54
 
55
  logger = logging.getLogger(__name__)
56
 
57
- # average embedding position of goal cell states
58
  def get_embs(model,
59
  filtered_input_data,
60
  emb_mode,
61
  layer_to_quant,
62
  pad_token_id,
63
- forward_batch_size):
 
64
 
65
  model_input_size = get_model_input_size(model)
66
  total_batch_length = len(filtered_input_data)
67
- if ((total_batch_length-1)/forward_batch_size).is_integer():
68
- forward_batch_size = forward_batch_size-1
69
 
70
- embs_list = []
 
 
 
 
 
 
 
 
 
71
  for i in trange(0, total_batch_length, forward_batch_size):
72
  max_range = min(i+forward_batch_size, total_batch_length)
73
 
@@ -81,29 +91,52 @@ def get_embs(model,
81
  max_len,
82
  pad_token_id,
83
  model_input_size)
84
-
85
  with torch.no_grad():
86
  outputs = model(
87
  input_ids = input_data_minibatch.to("cuda"),
88
  attention_mask = gen_attention_mask(minibatch)
89
  )
90
-
91
  embs_i = outputs.hidden_states[layer_to_quant]
92
 
93
  if emb_mode == "cell":
94
  mean_embs = mean_nonpadding_embs(embs_i, original_lens)
95
- embs_list += [mean_embs]
 
 
 
 
 
96
 
97
  del outputs
98
  del minibatch
99
  del input_data_minibatch
100
  del embs_i
101
  del mean_embs
102
- torch.cuda.empty_cache()
103
-
104
- embs_stack = torch.cat(embs_list)
 
 
 
 
 
 
 
 
 
105
  return embs_stack
106
 
 
 
 
 
 
 
 
 
 
107
  def label_embs(embs, downsampled_data, emb_labels):
108
  embs_df = pd.DataFrame(embs.cpu())
109
  if emb_labels is not None:
@@ -131,7 +164,6 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
131
 
132
  sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
133
 
134
-
135
  def gen_heatmap_class_colors(labels, df):
136
  pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)
137
  lut = dict(zip(map(str, Counter(labels).keys()), pal))
@@ -208,6 +240,7 @@ class EmbExtractor:
208
  "labels_to_plot": {None, list},
209
  "forward_batch_size": {int},
210
  "nproc": {int},
 
211
  }
212
  def __init__(
213
  self,
@@ -222,6 +255,7 @@ class EmbExtractor:
222
  labels_to_plot=None,
223
  forward_batch_size=100,
224
  nproc=4,
 
225
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
226
  ):
227
  """
@@ -263,6 +297,10 @@ class EmbExtractor:
263
  Batch size for forward pass.
264
  nproc : int
265
  Number of CPU processes to use.
 
 
 
 
266
  token_dictionary_file : Path
267
  Path to pickle file containing token dictionary (Ensembl ID:token).
268
  """
@@ -278,6 +316,7 @@ class EmbExtractor:
278
  self.labels_to_plot = labels_to_plot
279
  self.forward_batch_size = forward_batch_size
280
  self.nproc = nproc
 
281
 
282
  self.validate_options()
283
 
@@ -353,14 +392,19 @@ class EmbExtractor:
353
  self.emb_mode,
354
  layer_to_quant,
355
  self.pad_token_id,
356
- self.forward_batch_size)
357
- embs_df = label_embs(embs, downsampled_data, self.emb_label)
358
 
 
 
 
 
 
359
  # save embeddings to output_path
360
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
361
  embs_df.to_csv(output_path)
362
-
363
- return embs_df
364
 
365
  def plot_embs(self,
366
  embs,
@@ -446,5 +490,4 @@ class EmbExtractor:
446
  continue
447
  output_prefix_label = output_prefix + f"_heatmap_{label}"
448
  output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
449
- plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
450
-
 
14
  emb_label=["disease","cell_type"],
15
  labels_to_plot=["disease","cell_type"],
16
  forward_batch_size=100,
17
+ nproc=16,
18
+ summary_stat=None)
19
  embs = embex.extract_embs("path/to/model",
20
  "path/to/input_data",
21
  "path/to/output_directory",
 
34
  import numpy as np
35
  import pandas as pd
36
  import pickle
37
+ from tdigest import TDigest
38
  import scanpy as sc
39
  import seaborn as sns
40
  import torch
 
56
 
57
  logger = logging.getLogger(__name__)
58
 
59
+ # extract embeddings
60
  def get_embs(model,
61
  filtered_input_data,
62
  emb_mode,
63
  layer_to_quant,
64
  pad_token_id,
65
+ forward_batch_size,
66
+ summary_stat):
67
 
68
  model_input_size = get_model_input_size(model)
69
  total_batch_length = len(filtered_input_data)
 
 
70
 
71
+ if summary_stat is None:
72
+ embs_list = []
73
+ elif summary_stat is not None:
74
+ # test embedding extraction for example cell and extract # emb dims
75
+ example = filtered_input_data.select([i for i in range(1)])
76
+ example.set_format(type="torch")
77
+ emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
78
+ # initiate tdigests for # of emb dims
79
+ embs_tdigests = [TDigest() for _ in range(emb_dims)]
80
+
81
  for i in trange(0, total_batch_length, forward_batch_size):
82
  max_range = min(i+forward_batch_size, total_batch_length)
83
 
 
91
  max_len,
92
  pad_token_id,
93
  model_input_size)
94
+
95
  with torch.no_grad():
96
  outputs = model(
97
  input_ids = input_data_minibatch.to("cuda"),
98
  attention_mask = gen_attention_mask(minibatch)
99
  )
100
+
101
  embs_i = outputs.hidden_states[layer_to_quant]
102
 
103
  if emb_mode == "cell":
104
  mean_embs = mean_nonpadding_embs(embs_i, original_lens)
105
+ if summary_stat is None:
106
+ embs_list += [mean_embs]
107
+ elif summary_stat is not None:
108
+ # update tdigests with current batch for each emb dim
109
+ # note: tdigest batch update known to be slow so updating serially
110
+ [embs_tdigests[j].update(mean_embs[i,j].item()) for i in range(mean_embs.size(0)) for j in range(emb_dims)]
111
 
112
  del outputs
113
  del minibatch
114
  del input_data_minibatch
115
  del embs_i
116
  del mean_embs
117
+ torch.cuda.empty_cache()
118
+
119
+ if summary_stat is None:
120
+ embs_stack = torch.cat(embs_list)
121
+ # calculate summary stat embs from approximated tdigests
122
+ elif summary_stat is not None:
123
+ if summary_stat == "mean":
124
+ summary_emb_list = [embs_tdigests[i].trimmed_mean(0,100) for i in range(emb_dims)]
125
+ elif summary_stat == "median":
126
+ summary_emb_list = [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
127
+ embs_stack = torch.tensor(summary_emb_list)
128
+
129
  return embs_stack
130
 
131
+ def test_emb(model, example, layer_to_quant):
132
+ with torch.no_grad():
133
+ outputs = model(
134
+ input_ids = example.to("cuda")
135
+ )
136
+
137
+ embs_test = outputs.hidden_states[layer_to_quant]
138
+ return embs_test.size()[2]
139
+
140
  def label_embs(embs, downsampled_data, emb_labels):
141
  embs_df = pd.DataFrame(embs.cpu())
142
  if emb_labels is not None:
 
164
 
165
  sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
166
 
 
167
  def gen_heatmap_class_colors(labels, df):
168
  pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)
169
  lut = dict(zip(map(str, Counter(labels).keys()), pal))
 
240
  "labels_to_plot": {None, list},
241
  "forward_batch_size": {int},
242
  "nproc": {int},
243
+ "summary_stat": {None, "mean", "median"},
244
  }
245
  def __init__(
246
  self,
 
255
  labels_to_plot=None,
256
  forward_batch_size=100,
257
  nproc=4,
258
+ summary_stat=None,
259
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
260
  ):
261
  """
 
297
  Batch size for forward pass.
298
  nproc : int
299
  Number of CPU processes to use.
300
+ summary_stat : {None, "mean", "median"}
301
+ If not None, outputs only approximated mean or median embedding of input data.
302
+ Recommended if encountering memory constraints while generating goal embedding positions.
303
+ Slower but more memory-efficient.
304
  token_dictionary_file : Path
305
  Path to pickle file containing token dictionary (Ensembl ID:token).
306
  """
 
316
  self.labels_to_plot = labels_to_plot
317
  self.forward_batch_size = forward_batch_size
318
  self.nproc = nproc
319
+ self.summary_stat = summary_stat
320
 
321
  self.validate_options()
322
 
 
392
  self.emb_mode,
393
  layer_to_quant,
394
  self.pad_token_id,
395
+ self.forward_batch_size,
396
+ self.summary_stat)
397
 
398
+ if self.summary_stat is None:
399
+ embs_df = label_embs(embs, downsampled_data, self.emb_label)
400
+ elif self.summary_stat is not None:
401
+ embs_df = pd.DataFrame(embs.cpu()).T
402
+
403
  # save embeddings to output_path
404
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
405
  embs_df.to_csv(output_path)
406
+
407
+ return embs_df
408
 
409
  def plot_embs(self,
410
  embs,
 
490
  continue
491
  output_prefix_label = output_prefix + f"_heatmap_{label}"
492
  output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
493
+ plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)