In silico Perturbation Out of Memory
I realise this topic has been discussed before, but I have gone through the previous responses and none of those solutions worked for me. I would like to set max_ncells to None, but the highest I can go is 6500. I tried reducing cell_inds_to_perturb and forward_batch_size, but neither was sufficient, and since the OOM occurs before anything is saved to disk, clearing memory more often has no effect. I am using an AWS instance with 24GB of GPU RAM. Thanks in advance!
isp = InSilicoPerturber(
    perturb_type="delete",
    perturb_rank_shift=None,
    genes_to_perturb=genes,
    combos=0,
    anchor_gene=None,
    model_type="CellClassifier",
    num_classes=3,
    emb_mode="cell",
    cell_emb_style="mean_pool",
    filter_data={"cell_type": ["Cardiomyocyte1", "Cardiomyocyte2", "Cardiomyocyte3"]},
    cell_states_to_model={
        "state_key": "disease",
        "start_state": "dcm",
        "goal_state": "nf",
        "alt_states": ["hcm"],
    },
    cell_inds_to_perturb=[{"start": 1, "end": 1000}],
    max_ncells=6500,
    emb_layer=0,
    forward_batch_size=3,
    nproc=16,
)

isp.perturb_data(
    "./CellClassifier_cardiomyopathies/",
    "./human_dcm_hcm_nf.dataset",
    "./perturb_out/",
    "prefix_",
)
Error:
Traceback (most recent call last):

/home/ubuntu/Geneformer/examples/in_silico_perturbation.py:44 in <module>

    41
    42 cell_ind_ranges_list=[{'start': 1, 'end': 1000}]
    43 for x in cell_ind_ranges_list:
❱   44     perform_perturb(x)
    45
    46 #with ProcessPoolExecutor(max_workers=len(cell_ind_ranges_list)) as executor:
    47 #    executor.map(perform_perturb, cell_ind_ranges_list)

/home/ubuntu/Geneformer/examples/in_silico_perturbation.py:34 in perform_perturb

    31     )
    32
    33     # outputs intermediate files from in silico perturbation
❱   34     isp.perturb_data(
    35         "./CellClassifier_cardiomyopathies/",
    36         "/home/ubuntu/Geneformer/example_input_files/cell_classification/disease_classif
    37         "./perturb_out/",

/home/ubuntu/miniconda3/lib/python3.10/site-packages/geneformer/in_silico_perturber.py:958 in perturb_data

   955                     raise
   956             # get dictionary of average cell state embeddings for comparison
   957             downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
❱  958             state_embs_dict = get_cell_state_avg_embs(model,
   959                                                       downsampled_data,
   960                                                       self.cell_states_to_model,
   961                                                       layer_to_quant,

/home/ubuntu/miniconda3/lib/python3.10/site-packages/geneformer/in_silico_perturber.py:303 in get_cell_state_avg_embs

   300             torch.cuda.empty_cache()
   301
   302         state_embs = torch.cat(state_embs_list)
❱  303         avg_state_emb = mean_nonpadding_embs(state_embs, torch.Tensor(original_lens).to(
   304         avg_state_emb = torch.mean(avg_state_emb, dim=0, keepdim=True)
   305         state_embs_dict[possible_state] = avg_state_emb
   306     return state_embs_dict

/home/ubuntu/miniconda3/lib/python3.10/site-packages/geneformer/in_silico_perturber.py:587 in mean_nonpadding_embs

   584     mask = mask.unsqueeze(2).expand_as(embs)
   585
   586     # use the mask to zero out the embeddings in padded areas
❱  587     masked_embs = embs * mask.float()
   588
   589     # sum and divide by the lengths to get the mean of non-padding embs
   590     mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
Thank you for your interest in Geneformer and for your question! In this case the OOM comes from calculating the target embedding positions for the cell states being modeled. When max_ncells is set to None, all of the cells in the dataset are used to calculate these initial positions, even though cell_inds_to_perturb subsets the cells used for the perturbation tests. We will look into this memory error to see if it's possible to obtain these initial embedding positions in a more memory-efficient manner. In the meantime, you could obtain the goal embedding positions separately and modify the code to supply them to each of your runs on subsets of cells. At the end, you can put all the results into the same directory for the stats module to analyze them together.
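For illustration, here is a minimal sketch of that approach, assuming state_embs_dict is computed once (for example on a downsampled run) and cached to disk; the file name, the monkeypatching, and the *args/**kwargs signature are assumptions rather than part of the released API, so adjust them to the version of in_silico_perturber.py you have installed:

import torch
import geneformer.in_silico_perturber as isp_module

# First run only: after get_cell_state_avg_embs has produced state_embs_dict,
# persist it for reuse, e.g.:
#     torch.save(state_embs_dict, "state_embs_dict.pt")

# Subsequent runs over other cell_inds_to_perturb ranges: load the cached
# dictionary and swap out the module-level function that would otherwise
# recompute it over all cells (the step that triggers the OOM above).
cached_state_embs = torch.load("state_embs_dict.pt")

def load_cached_state_embs(*args, **kwargs):
    # ignore the usual arguments (model, downsampled_data, ...) and return
    # the precomputed start/goal/alt state embeddings instead
    return cached_state_embs

isp_module.get_cell_state_avg_embs = load_cached_state_embs

Apply the patch before calling isp.perturb_data in each run; because perturb_data looks the function up through the module's globals, the cached dictionary should then be used in place of the full recomputation.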
As a quick fix:
downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)

num_shards = 40  # data will be split into 40 shards
shard_num = 0    # which shard will be analysed

shards = []
for i in range(num_shards):
    shard = downsampled_data.shard(num_shards=num_shards, index=i)
    shards.append(shard)

print(f'shard num is {shard_num}')
downsampled_data = shards[shard_num]
Here I'm splitting the data into 40 shards and analysing them separately. With my current setup (an AWS g4dn.metal instance) this will take an estimated 24 days though, so I would greatly appreciate any pointers to make this more efficient!
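One way to make the per-shard runs less painful (a sketch only; the driver/worker layout, the script name, and the output prefix convention are assumptions rather than anything Geneformer provides) is to take the shard index from the command line and launch each shard in a fresh process, so GPU memory is fully released between shards and shards can also be spread across several instances:

import subprocess
import sys

NUM_SHARDS = 40

if len(sys.argv) > 1:
    # worker mode, e.g. `python run_shards.py 7`: analyse a single shard
    shard_num = int(sys.argv[1])
    print(f"shard num is {shard_num}")
    # ... build the InSilicoPerturber as above and call perturb_data with an
    #     output prefix such as f"shard{shard_num}_" so that every shard's
    #     results land in the same directory for the stats module
else:
    # driver mode: run the shards one after another in fresh processes so
    # CUDA memory is released when each shard's process exits
    for i in range(NUM_SHARDS):
        subprocess.run([sys.executable, sys.argv[0], str(i)], check=True)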
Edit: Running the entire dataset seems to be overkill. FDR and cosine shift values don't change much after ~20% of the cells are analysed, which reduces the analysis time significantly.
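If that holds more generally, an alternative to sharding everything would be to cap max_ncells at roughly 20% of the filtered cells instead of setting it to None. A rough, untested sketch (the 0.2 fraction comes from the observation above; the counting logic and column access are assumptions about the tokenized dataset):

from datasets import load_from_disk

# count the cardiomyocytes that would survive filter_data, then keep ~20%
data = load_from_disk("./human_dcm_hcm_nf.dataset")
n_filtered = sum(ct.startswith("Cardiomyocyte") for ct in data["cell_type"])
max_ncells = int(0.2 * n_filtered)  # pass this to InSilicoPerturber instead of None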