# Source: Geneformer / geneformer/mtl_classifier.py
# Last commit by ctheodoris: "edit docs formatting" (ef094b2)
"""
Geneformer multi-task cell classifier.
**Input data:**
| Single-cell transcriptomes as Geneformer rank value encodings with cell state labels for each task in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py). Must contain "unique_cell_id" column for logging.
**Usage:**
.. code-block :: python
>>> from geneformer import MTLClassifier
>>> mc = MTLClassifier(task_columns = ["task1", "task2"],
... study_name = "mtl",
... pretrained_path = "/path/pretrained/model",
... train_path = "/path/train/set",
... val_path = "/path/eval/set",
... test_path = "/path/test/set",
... model_save_path = "/results/directory/save_path",
... trials_result_path = "/results/directory/results.txt",
... results_dir = "/results/directory",
... tensorboard_log_dir = "/results/tblogdir",
... hyperparameters = hyperparameters)
>>> mc.run_optuna_study()
>>> mc.load_and_evaluate_test_model()
>>> # mc.save_model_without_heads()  (not currently available: this method is commented out in this module)
"""
import logging
import os
from .mtl import eval_utils, train_utils, utils
# Module-level logger; MTLClassifier uses it for option-validation warnings/errors
# and for the warning emitted when manual hyperparameters are enabled.
logger = logging.getLogger(__name__)
class MTLClassifier:
    """
    Geneformer multi-task cell classifier.

    Validates the constructor options against ``valid_option_dict``, snapshots
    the validated options into ``self.config`` and hands that dict to the
    ``train_utils`` / ``eval_utils`` entry points for hyperparameter tuning,
    multi-task fine-tuning and test-set evaluation.
    """

    # Allowed values per option: ``None`` means the option may be left unset;
    # type objects mean any instance of that type is accepted.
    valid_option_dict = {
        "task_columns": {list},
        "train_path": {None, str},
        "val_path": {None, str},
        "test_path": {None, str},
        "pretrained_path": {None, str},
        "model_save_path": {None, str},
        "results_dir": {None, str},
        "batch_size": {None, int},
        "n_trials": {None, int},
        "study_name": {None, str},
        "max_layers_to_freeze": {None, dict},
        "epochs": {None, int},
        "tensorboard_log_dir": {None, str},
        "use_data_parallel": {None, bool},
        "use_attention_pooling": {None, bool},
        "use_task_weights": {None, bool},
        "hyperparameters": {None, dict},
        "manual_hyperparameters": {None, dict},
        "use_manual_hyperparameters": {None, bool},
        "use_wandb": {None, bool},
        "wandb_project": {None, str},
        "gradient_clipping": {None, bool},
        "max_grad_norm": {None, int, float},
        "seed": {None, int},
        "trials_result_path": {None, str},
    }

    def __init__(
        self,
        task_columns=None,
        train_path=None,
        val_path=None,
        test_path=None,
        pretrained_path=None,
        model_save_path=None,
        results_dir=None,
        trials_result_path=None,
        batch_size=4,
        n_trials=15,
        study_name="mtl",
        max_layers_to_freeze=None,
        epochs=1,
        tensorboard_log_dir="/results/tblogdir",
        use_data_parallel=False,
        use_attention_pooling=True,
        use_task_weights=True,
        hyperparameters=None,  # Default is None
        manual_hyperparameters=None,  # Default is None
        use_manual_hyperparameters=False,  # Default is False
        use_wandb=False,
        wandb_project=None,
        gradient_clipping=False,
        max_grad_norm=None,
        seed=42,  # Default seed value
    ):
        """
        Initialize Geneformer multi-task classifier.

        **Parameters:**

        task_columns : list
            | List of tasks for cell state classification
            | Input data columns are labeled with corresponding task names
        study_name : None, str
            | Study name for labeling output files
        pretrained_path : None, str
            | Path to pretrained model
        train_path : None, str
            | Path to training dataset with task columns and "unique_cell_id" column
        val_path : None, str
            | Path to validation dataset with task columns and "unique_cell_id" column
        test_path : None, str
            | Path to test dataset with task columns and "unique_cell_id" column
        model_save_path : None, str
            | Path to directory to save output model (either full model or model without heads)
            | Created if it does not exist.
        trials_result_path : None, str
            | Path to the file for hyperparameter tuning trial results
            | If None and results_dir is given, defaults to <results_dir>/results.txt
        results_dir : None, str
            | Path to directory to save results
            | Created if it does not exist.
        tensorboard_log_dir : None, str
            | Path to directory for Tensorboard logging results
        use_data_parallel : None, bool
            | Whether to use data parallelization
        use_attention_pooling : None, bool
            | Whether to use attention pooling
        use_task_weights : None, bool
            | Whether to use task weights
        batch_size : None, int
            | Batch size to use
        n_trials : None, int
            | Number of trials for hyperparameter tuning
        epochs : None, int
            | Number of epochs for training
        max_layers_to_freeze : None, dict
            | Dictionary with keys "min" and "max" indicating the min and max layers to freeze from fine-tuning (int)
            | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
            | If None, "min" is 1 and "max" is derived from the pretrained model.
        hyperparameters : None, dict
            | Dictionary describing the search space for each hyperparameter for tuning
            | For example:
            | {"learning_rate": {"type": "float", "low": 1e-5, "high": 1e-3, "log": True}, "task_weights": {...}, ...}
        manual_hyperparameters : None, dict
            | Dictionary of manually set value for each hyperparameter
            | For example:
            | {"learning_rate": 0.001, "task_weights": [1, 1], ...}
        use_manual_hyperparameters : None, bool
            | Whether to use manually set hyperparameters
        use_wandb : None, bool
            | Whether to use Weights & Biases for logging
        wandb_project : None, str
            | Weights & Biases project name
        gradient_clipping : None, bool
            | Whether to use gradient clipping
        max_grad_norm : None, int, float
            | Maximum norm for gradient clipping
        seed : None, int
            | Random seed

        **Raises:**

        ValueError
            | If any option fails validation against ``valid_option_dict``.
        """
        self.task_columns = task_columns
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.pretrained_path = pretrained_path
        self.model_save_path = model_save_path
        self.results_dir = results_dir
        self.trials_result_path = trials_result_path
        self.batch_size = batch_size
        self.n_trials = n_trials
        self.study_name = study_name
        if max_layers_to_freeze is None:
            # Dynamically determine the range of layers to freeze from the
            # pretrained model. NOTE(review): this runs even when
            # pretrained_path is None; what happens then depends on
            # utils.get_layer_freeze_range -- confirm.
            layer_freeze_range = utils.get_layer_freeze_range(pretrained_path)
            self.max_layers_to_freeze = {"min": 1, "max": layer_freeze_range["max"]}
        else:
            self.max_layers_to_freeze = max_layers_to_freeze
        self.epochs = epochs
        self.tensorboard_log_dir = tensorboard_log_dir
        self.use_data_parallel = use_data_parallel
        self.use_attention_pooling = use_attention_pooling
        self.use_task_weights = use_task_weights
        # Default Optuna search space, built fresh per instance so callers never
        # share (and mutate) the same dict.
        self.hyperparameters = (
            hyperparameters
            if hyperparameters is not None
            else {
                "learning_rate": {
                    "type": "float",
                    "low": 1e-5,
                    "high": 1e-3,
                    "log": True,
                },
                "warmup_ratio": {"type": "float", "low": 0.005, "high": 0.01},
                "weight_decay": {"type": "float", "low": 0.01, "high": 0.1},
                "dropout_rate": {"type": "float", "low": 0.0, "high": 0.7},
                "lr_scheduler_type": {"type": "categorical", "choices": ["cosine"]},
                "task_weights": {"type": "float", "low": 0.1, "high": 2.0},
            }
        )
        self.manual_hyperparameters = (
            manual_hyperparameters
            if manual_hyperparameters is not None
            else {
                "learning_rate": 0.001,
                "warmup_ratio": 0.01,
                "weight_decay": 0.1,
                "dropout_rate": 0.1,
                "lr_scheduler_type": "cosine",
                "use_attention_pooling": False,
                "task_weights": [1, 1],
                "max_layers_to_freeze": 2,
            }
        )
        self.use_manual_hyperparameters = use_manual_hyperparameters
        self.use_wandb = use_wandb
        self.wandb_project = wandb_project
        self.gradient_clipping = gradient_clipping
        self.max_grad_norm = max_grad_norm
        self.seed = seed

        if self.use_manual_hyperparameters:
            logger.warning(
                "Hyperparameter tuning is highly recommended for optimal results."
            )

        self.validate_options()

        # Default the trial-results file to <results_dir>/results.txt when the
        # caller did not supply one. Only ``trials_result_path`` (singular) is an
        # option, so only that attribute is copied into self.config below.
        if self.results_dir is not None:
            default_trials_path = os.path.join(self.results_dir, "results.txt")
            if self.trials_result_path is None:
                self.trials_result_path = default_trials_path
            # Kept for backward compatibility with code reading the old attribute name.
            self.trials_results_path = default_trials_path

        # Create output directories. Unset (None) paths are skipped, since
        # os.path.exists(None) / os.makedirs(None) raise TypeError.
        for output_dir in (self.model_save_path, self.results_dir):
            if output_dir is not None:
                os.makedirs(output_dir, exist_ok=True)

        # Snapshot of the validated options passed to the train/eval utilities.
        self.config = {
            key: value
            for key, value in self.__dict__.items()
            if key in self.valid_option_dict
        }

    def validate_options(self):
        """
        Check every option in ``valid_option_dict`` against its allowed values.

        An option passes if it is ``None`` and ``None`` is allowed for it, or if
        it is an instance of one of the allowed types.

        Raises:
            ValueError: if any option has a disallowed value or type.
        """
        for attr_name, valid_options in self.valid_option_dict.items():
            attr_value = self.__dict__[attr_name]
            # Compare against None by identity instead of ``attr_value in valid_options``:
            # set membership hashes attr_value (TypeError for e.g. a set) and would
            # accept a bare type object such as ``batch_size=int``.
            if attr_value is None and None in valid_options:
                continue
            allowed_types = tuple(
                option for option in valid_options if isinstance(option, type)
            )
            if isinstance(attr_value, allowed_types):
                continue
            logger.error(
                f"Invalid option for {attr_name}. "
                f"Valid options for {attr_name}: {valid_options}"
            )
            raise ValueError(
                f"Invalid option for {attr_name}. Valid options for {attr_name}: {valid_options}"
            )

    def _validate_required(self, variable_names):
        """Ensure each named attribute is set (not None) before running a pipeline stage."""
        req_var_dict = {name: getattr(self, name) for name in variable_names}
        self.validate_additional_options(req_var_dict)

    def run_manual_tuning(self):
        """
        Manual hyperparameter tuning and multi-task fine-tuning of pretrained model.

        Raises:
            ValueError: if a required path is missing or manual hyperparameters
                are not enabled.
        """
        self._validate_required(
            ["train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"]
        )
        if not self.use_manual_hyperparameters:
            raise ValueError(
                "Manual hyperparameters are not enabled. Set use_manual_hyperparameters to True."
            )
        # Ensure manual_hyperparameters are set in the config
        self.config["manual_hyperparameters"] = self.manual_hyperparameters
        self.config["use_manual_hyperparameters"] = True
        train_utils.run_manual_tuning(self.config)

    def validate_additional_options(self, req_var_dict):
        """
        Warn about every required variable that is None, then fail if any were.

        Parameters:
            req_var_dict: mapping of variable name -> value to check.

        Raises:
            ValueError: if at least one value is None.
        """
        missing_variable = False
        for variable_name, variable in req_var_dict.items():
            if variable is None:
                logger.warning(
                    f"Please provide value to MTLClassifier for required variable {variable_name}"
                )
                missing_variable = True
        if missing_variable:
            raise ValueError("Missing required variables for MTLClassifier")

    def run_optuna_study(
        self,
    ):
        """
        Hyperparameter optimization and/or multi-task fine-tuning of pretrained model.

        Raises:
            ValueError: if a required path is missing.
        """
        self._validate_required(
            ["train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"]
        )
        train_utils.run_optuna_study(self.config)

    def load_and_evaluate_test_model(
        self,
    ):
        """
        Loads previously fine-tuned multi-task model and evaluates on test data.

        Raises:
            ValueError: if a required path is missing.
        """
        self._validate_required(["test_path", "model_save_path", "results_dir"])
        eval_utils.load_and_evaluate_test_model(self.config)
# def save_model_without_heads(
# self,
# ):
# """
# Save previously fine-tuned multi-task model without classification heads.
# """
# required_variable_names = ["model_save_path"]
# required_variables = [self.model_save_path]
# req_var_dict = dict(zip(required_variable_names, required_variables))
# self.validate_additional_options(req_var_dict)
# utils.save_model_without_heads(
# os.path.join(self.model_save_path, "GeneformerMultiTask")
# )