# Christina Theodoris
# Commit message: Add data collator for cell classification and example for cell classification
# Commit: 088ea6e
#!/usr/bin/env python
# coding: utf-8
#
# Pretraining script for a Geneformer BERT masked-language model.
# Launch with DeepSpeed, e.g.:
#   deepspeed --num_gpus=12 --num_nodes=3 pretrain_geneformer_w_deepspeed.py --deepspeed ds_config.json

# --- standard-library imports ----------------------------------------------
import datetime
import os
import pickle
import random
import subprocess

# Runtime environment variables for the distributed launch. They are set
# before the third-party (torch / transformers / geneformer) imports below so
# they are already present whenever those libraries read the environment.
os.environ.update(
    {
        "NCCL_DEBUG": "INFO",
        "OMPI_MCA_opal_cuda_support": "true",
        "CONDA_OVERRIDE_GLIBC": "2.56",
    }
)

# --- third-party and project imports ---------------------------------------
import numpy as np
import pytz
import torch
from datasets import load_from_disk
from transformers import BertConfig, BertForMaskedLM, TrainingArguments

from geneformer import GeneformerPretrainer
# Make the run reproducible. Python's `random` module and NumPy share one
# seed, while torch's CPU generator and all CUDA generators share another.
seed_num = 0
seed_val = 42
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(seed_num)
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed_val)
del _seed_fn
# Root directory under which model checkpoints and logs are written, and the
# timezone used to timestamp the run.
rootdir = "/parent_ouput_directory"  # presumably a placeholder to edit per setup
timezone = pytz.timezone("US/Eastern")
# ---------------------------------------------------------------------------
# Model hyperparameters (BERT masked-LM architecture)
# ---------------------------------------------------------------------------
model_type = "bert"
max_input_size = 2**11              # maximum input length in tokens (2048)
num_layers = 6                      # number of hidden (encoder) layers
num_attn_heads = 4                  # attention heads per layer
num_embed_dim = 256                 # hidden / embedding dimension
intermed_size = 2 * num_embed_dim   # intermediate (feed-forward) width
activ_fn = "relu"                   # hidden activation function

# Weight initialization range, layer-norm epsilon, and dropout rates.
initializer_range = 0.02
layer_norm_eps = 1e-12
attention_probs_dropout_prob = 0.02
hidden_dropout_prob = 0.02
# ---------------------------------------------------------------------------
# Training hyperparameters
# ---------------------------------------------------------------------------
num_examples = 27_406_208     # examples (cells) in Genecorpus-30M after QC filtering
num_gpus = 12                 # number of GPUs
geneformer_batch_size = 12    # per-device batch size for training and eval
max_lr = 1e-3                 # peak learning rate
lr_schedule_fn = "linear"     # learning-rate schedule
warmup_steps = 10_000         # learning-rate warmup steps
epochs = 3                    # number of training epochs
optimizer = "adamw"           # used only to label the run; not passed to TrainingArguments
weight_decay = 0.001
# ---------------------------------------------------------------------------
# Output directories
# ---------------------------------------------------------------------------
# Run timestamp in local time, formatted YYMMDD_HHMMSS. An explicit
# %H%M%S is used instead of the locale-dependent %X, so the directory name does
# not depend on the locale. In the C locale the result is identical to the
# previous "%X with ':' removed".
# NOTE(review): every launched process computes its own timestamp; on a
# multi-process / multi-node launch, processes that start in different seconds
# would get different directory names — confirm this is acceptable.
current_date = datetime.datetime.now(tz=timezone)
datestamp = current_date.strftime("%y%m%d_%H%M%S")
run_name = (
    f"{datestamp}_geneformer_30M_L{num_layers}_emb{num_embed_dim}"
    f"_SL{max_input_size}_E{epochs}_B{geneformer_batch_size}_LR{max_lr}"
    f"_LS{lr_schedule_fn}_WU{warmup_steps}_O{optimizer}_DS{num_gpus}"
)
training_output_dir = f"{rootdir}/models/{run_name}/"
logging_dir = f"{rootdir}/runs/{run_name}/"
model_output_dir = os.path.join(training_output_dir, "models/")

# Refuse to overwrite a previously saved model.
model_output_file = os.path.join(model_output_dir, "pytorch_model.bin")
if os.path.isfile(model_output_file):
    raise FileExistsError("Model already saved to this directory.")

# Create the training and model output directories, including any missing
# parents such as `{rootdir}/models`. The previous shell `mkdir` created no
# parents, ignored its exit status and broke on paths containing spaces.
# exist_ok=True tolerates the directory already existing (e.g. created by
# another launched process), while real failures such as permission errors
# now raise.
os.makedirs(training_output_dir, exist_ok=True)
os.makedirs(model_output_dir, exist_ok=True)
# Gene Ensembl ID -> token dictionary (e.g.
# https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/datasets/token_dictionary.pkl).
# NOTE: unpickling can execute arbitrary code; only load a trusted file here.
with open("token_dictionary.pkl", mode="rb") as token_dict_file:
    token_dictionary = pickle.load(token_dict_file)
# Build a freshly initialized BERT masked-LM from the hyperparameters defined
# above and switch it to training mode (`.train()` returns the model itself).
config = BertConfig(
    model_type=model_type,
    vocab_size=len(token_dictionary),  # genes + 2 for <mask> and <pad> tokens
    pad_token_id=token_dictionary.get("<pad>"),
    max_position_embeddings=max_input_size,
    hidden_size=num_embed_dim,
    num_hidden_layers=num_layers,
    num_attention_heads=num_attn_heads,
    intermediate_size=intermed_size,
    hidden_act=activ_fn,
    initializer_range=initializer_range,
    layer_norm_eps=layer_norm_eps,
    attention_probs_dropout_prob=attention_probs_dropout_prob,
    hidden_dropout_prob=hidden_dropout_prob,
)
model = BertForMaskedLM(config).train()
# ---------------------------------------------------------------------------
# Training arguments
# ---------------------------------------------------------------------------
# Under data-parallel training every optimizer step consumes
# `geneformer_batch_size` examples on each GPU. One epoch is therefore
# num_examples / (batch size * number of GPUs) steps; the previous formula
# ignored the GPU count, so it did not yield the intended 8 saves per epoch.
# Integer division keeps save_steps an int: recent transformers versions reject
# a non-integer save_steps greater than 1. The previous value was 285481.33.
# NOTE(review): this assumes num_gpus equals the total number of training
# processes (world size). The launch line in the header uses
# --num_gpus=12 --num_nodes=3; if the world size is larger, update num_gpus.
steps_per_epoch = num_examples // (geneformer_batch_size * num_gpus)
save_steps = max(1, steps_per_epoch // 8)  # 8 saves per epoch

training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": False,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": weight_decay,
    "per_device_train_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    # No evaluation is run (do_eval=False, no eval dataset), so there is no
    # "best" checkpoint to reload. Recent transformers versions also raise
    # when load_best_model_at_end is set and the eval strategy ("no" here)
    # differs from the save strategy ("steps").
    "load_best_model_at_end": False,
    "save_strategy": "steps",
    "save_steps": save_steps,
    "logging_steps": 1000,
    "output_dir": training_output_dir,
    "logging_dir": logging_dir,
}
training_args = TrainingArguments(**training_args)
print("Starting training.")

# Pretraining corpus (e.g.
# https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048.dataset).
train_dataset = load_from_disk("genecorpus_30M_2048.dataset")

trainer = GeneformerPretrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # File of lengths of each example cell (e.g.
    # https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl).
    example_lengths_file="genecorpus_30M_2048_sorted_lengths.pkl",
    token_dictionary=token_dictionary,
)

# Run training, then write the final model into the dedicated model directory.
trainer.train()
trainer.save_model(model_output_dir)