|
""" |
|
Geneformer multi-task cell classifier. |
|
|
|
**Input data:** |
|
|
|
| Single-cell transcriptomes as Geneformer rank value encodings with cell state labels for each task in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py). Must contain "unique_cell_id" column for logging. |
|
|
|
**Usage:** |
|
|
|
.. code-block :: python |
|
|
|
>>> from geneformer import MTLClassifier |
|
>>> mc = MTLClassifier(task_columns = ["task1", "task2"], |
|
... study_name = "mtl", |
|
... pretrained_path = "/path/pretrained/model", |
|
... train_path = "/path/train/set", |
|
... val_path = "/path/eval/set", |
|
... test_path = "/path/test/set", |
|
... model_save_path = "/results/directory/save_path", |
|
... trials_result_path = "/results/directory/results.txt", |
|
... results_dir = "/results/directory", |
|
... tensorboard_log_dir = "/results/tblogdir", |
|
... hyperparameters = hyperparameters) |
|
>>> mc.run_optuna_study() |
|
>>> mc.load_and_evaluate_test_model() |
|
>>> mc.save_model_without_heads() |
|
""" |
|
|
|
import logging |
|
import os |
|
from .mtl import train_utils |
|
from .mtl import utils |
|
from .mtl import eval_utils |
|
|
|
# Logger for this module; named after the module itself.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class MTLClassifier:
    """
    Geneformer multi-task cell classifier.

    Stores the run options as attributes, validates them against
    ``valid_option_dict`` and gathers them into ``self.config``, the dict that
    is passed to the ``train_utils`` / ``eval_utils`` entry points.
    """

    # Allowed values for each option. Type entries (int, str, ...) are checked
    # with isinstance; non-type entries (None) are matched by value.
    valid_option_dict = {
        "task_columns": {list},
        "train_path": {None, str},
        "val_path": {None, str},
        "test_path": {None, str},
        "pretrained_path": {None, str},
        "model_save_path": {None, str},
        "results_dir": {None, str},
        "batch_size": {None, int},
        "n_trials": {None, int},
        "study_name": {None, str},
        "max_layers_to_freeze": {None, dict},
        "epochs": {None, int},
        "tensorboard_log_dir": {None, str},
        "use_data_parallel": {None, bool},
        "use_attention_pooling": {None, bool},
        "use_task_weights": {None, bool},
        "hyperparameters": {None, dict},
        "manual_hyperparameters": {None, dict},
        "use_manual_hyperparameters": {None, bool},
        "use_wandb": {None, bool},
        "wandb_project": {None, str},
        "gradient_clipping": {None, bool},
        "max_grad_norm": {None, int, float},
        "seed": {None, int},
        "trials_result_path": {None, str},
    }

    def __init__(
        self,
        task_columns=None,
        train_path=None,
        val_path=None,
        test_path=None,
        pretrained_path=None,
        model_save_path=None,
        results_dir=None,
        trials_result_path=None,
        batch_size=4,
        n_trials=15,
        study_name="mtl",
        max_layers_to_freeze=None,
        epochs=1,
        tensorboard_log_dir="/results/tblogdir",
        use_data_parallel=False,
        use_attention_pooling=True,
        use_task_weights=True,
        hyperparameters=None,
        manual_hyperparameters=None,
        use_manual_hyperparameters=False,
        use_wandb=False,
        wandb_project=None,
        gradient_clipping=False,
        max_grad_norm=None,
        seed=42,
    ):
        """
        Initialize Geneformer multi-task classifier.

        **Parameters:**

        task_columns : list
            | List of tasks for cell state classification
            | Input data columns are labeled with corresponding task names
        study_name : None, str
            | Study name for labeling output files
        pretrained_path : None, str
            | Path to pretrained model
        train_path : None, str
            | Path to training dataset with task columns and "unique_cell_id" column
        val_path : None, str
            | Path to validation dataset with task columns and "unique_cell_id" column
        test_path : None, str
            | Path to test dataset with task columns and "unique_cell_id" column
        model_save_path : None, str
            | Path to directory to save output model (either full model or model without heads)
            | Created if it does not exist
        trials_result_path : None, str
            | Path to the file in which hyperparameter tuning trial results are saved
            | Defaults to "<results_dir>/results.txt" when results_dir is given
        results_dir : None, str
            | Path to directory to save results
            | Created if it does not exist
        tensorboard_log_dir : None, str
            | Path to directory for Tensorboard logging results
        use_data_parallel : None, bool
            | Whether to use data parallelization
        use_attention_pooling : None, bool
            | Whether to use attention pooling
        use_task_weights : None, bool
            | Whether to use task weights
        batch_size : None, int
            | Batch size to use
        n_trials : None, int
            | Number of trials for hyperparameter tuning
        epochs : None, int
            | Number of epochs for training
        max_layers_to_freeze : None, dict
            | Dictionary with keys "min" and "max" indicating the min and max layers to freeze from fine-tuning (int)
            | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
            | If None, "min" is 1 and "max" comes from utils.get_layer_freeze_range(pretrained_path)
        hyperparameters : None, dict
            | Search space for each hyperparameter for tuning (type plus low/high bounds, or categorical choices)
            | For example:
            | {"learning_rate": {"type": "float", "low": 1e-5, "high": 1e-3, "log": True}, "task_weights": {...}, ...}
        manual_hyperparameters : None, dict
            | Dictionary of manually set value for each hyperparameter
            | For example:
            | {"learning_rate": 0.001, "task_weights": [1, 1], ...}
            | If None, the default "task_weights" has one weight of 1 per task column
        use_manual_hyperparameters : None, bool
            | Whether to use manually set hyperparameters
        use_wandb : None, bool
            | Whether to use Weights & Biases for logging
        wandb_project : None, str
            | Weights & Biases project name
        gradient_clipping : None, bool
            | Whether to use gradient clipping
        max_grad_norm : None, int, float
            | Maximum norm for gradient clipping
        seed : None, int
            | Random seed

        **Raises:**

        ValueError
            | If any option fails validate_options
        """
        self.task_columns = task_columns
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.pretrained_path = pretrained_path
        self.model_save_path = model_save_path
        self.results_dir = results_dir
        self.trials_result_path = trials_result_path
        self.batch_size = batch_size
        self.n_trials = n_trials
        self.study_name = study_name

        if max_layers_to_freeze is None:
            # NOTE(review): derived from pretrained_path via a helper whose
            # behavior for pretrained_path=None is not visible here; confirm.
            layer_freeze_range = utils.get_layer_freeze_range(pretrained_path)
            self.max_layers_to_freeze = {"min": 1, "max": layer_freeze_range["max"]}
        else:
            self.max_layers_to_freeze = max_layers_to_freeze

        self.epochs = epochs
        self.tensorboard_log_dir = tensorboard_log_dir
        self.use_data_parallel = use_data_parallel
        self.use_attention_pooling = use_attention_pooling
        self.use_task_weights = use_task_weights

        if hyperparameters is None:
            hyperparameters = {
                "learning_rate": {
                    "type": "float",
                    "low": 1e-5,
                    "high": 1e-3,
                    "log": True,
                },
                "warmup_ratio": {"type": "float", "low": 0.005, "high": 0.01},
                "weight_decay": {"type": "float", "low": 0.01, "high": 0.1},
                "dropout_rate": {"type": "float", "low": 0.0, "high": 0.7},
                "lr_scheduler_type": {"type": "categorical", "choices": ["cosine"]},
                "task_weights": {"type": "float", "low": 0.1, "high": 2.0},
            }
        self.hyperparameters = hyperparameters

        if manual_hyperparameters is None:
            # One weight per task; fall back to the historical two-task default
            # when task_columns is missing or empty.
            if isinstance(task_columns, list) and task_columns:
                default_task_weights = [1] * len(task_columns)
            else:
                default_task_weights = [1, 1]
            manual_hyperparameters = {
                "learning_rate": 0.001,
                "warmup_ratio": 0.01,
                "weight_decay": 0.1,
                "dropout_rate": 0.1,
                "lr_scheduler_type": "cosine",
                "use_attention_pooling": False,
                "task_weights": default_task_weights,
                "max_layers_to_freeze": 2,
            }
        self.manual_hyperparameters = manual_hyperparameters

        self.use_manual_hyperparameters = use_manual_hyperparameters
        self.use_wandb = use_wandb
        self.wandb_project = wandb_project
        self.gradient_clipping = gradient_clipping
        self.max_grad_norm = max_grad_norm
        self.seed = seed

        if self.use_manual_hyperparameters:
            logger.warning(
                "Hyperparameter tuning is highly recommended for optimal results."
            )

        self.validate_options()

        if self.results_dir is not None:
            default_trials_path = os.path.join(self.results_dir, "results.txt")
            # Fill in the real option (which is copied into self.config) only
            # when the caller did not supply an explicit path.
            if self.trials_result_path is None:
                self.trials_result_path = default_trials_path
            # Legacy (misspelled) attribute name, kept for backward compatibility.
            self.trials_results_path = default_trials_path

        # Only create directories that were provided: os.path.exists(None)
        # raises TypeError, and both paths are allowed to be None.
        for output_dir in (self.model_save_path, self.results_dir):
            if output_dir is not None:
                os.makedirs(output_dir, exist_ok=True)

        self.config = {
            key: value
            for key, value in self.__dict__.items()
            if key in self.valid_option_dict
        }

    def validate_options(self):
        """
        Check every option in ``valid_option_dict`` against its allowed values.

        **Raises:**

        ValueError
            | If an option's value is neither one of the listed literal values
            | nor an instance of one of the listed types (an error is also logged)
        """
        for attr_name, valid_options in self.valid_option_dict.items():
            attr_value = self.__dict__[attr_name]
            if self._is_valid_option(attr_value, valid_options):
                continue
            message = (
                f"Invalid option for {attr_name}. "
                f"Valid options for {attr_name}: {valid_options}"
            )
            logger.error(message)
            raise ValueError(message)

    @staticmethod
    def _is_valid_option(value, valid_options):
        """Return True if value equals a literal in valid_options or is an instance of a listed type."""
        # Lists/dicts are unhashable, so they cannot be looked up in the set.
        if not isinstance(value, (list, dict)) and value in valid_options:
            return True
        allowed_types = tuple(
            option
            for option in valid_options
            if option in (int, float, list, dict, bool, str)
        )
        # bool subclasses int: don't let True/False satisfy an int/float option
        # unless bool itself is allowed.
        if isinstance(value, bool) and bool not in allowed_types:
            return False
        return isinstance(value, allowed_types)

    def run_manual_tuning(self):
        """
        Manual hyperparameter tuning and multi-task fine-tuning of pretrained model.

        Passes ``self.config`` with the current ``manual_hyperparameters`` to
        ``train_utils.run_manual_tuning``.

        **Raises:**

        ValueError
            | If a required path is missing or use_manual_hyperparameters is False
        """
        self._require(
            "train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"
        )

        if not self.use_manual_hyperparameters:
            raise ValueError("Manual hyperparameters are not enabled. Set use_manual_hyperparameters to True.")

        # self.config is a snapshot taken in __init__; refresh these entries
        # from the current attributes before training.
        self.config["manual_hyperparameters"] = self.manual_hyperparameters
        self.config["use_manual_hyperparameters"] = True

        train_utils.run_manual_tuning(self.config)

    def validate_additional_options(self, req_var_dict):
        """
        Ensure that every variable required by a specific method was provided.

        **Parameters:**

        req_var_dict : dict
            | Mapping of required variable name to its current value

        **Raises:**

        ValueError
            | If any value is None (a warning is logged for each missing name)
        """
        missing = [name for name, value in req_var_dict.items() if value is None]
        for name in missing:
            logger.warning(
                f"Please provide value to MTLClassifier for required variable {name}"
            )
        if missing:
            raise ValueError(
                f"Missing required variables for MTLClassifier: {', '.join(missing)}"
            )

    def _require(self, *names):
        """Check that each named attribute is not None via validate_additional_options."""
        self.validate_additional_options({name: getattr(self, name) for name in names})

    def run_optuna_study(
        self,
    ):
        """
        Hyperparameter optimization and/or multi-task fine-tuning of pretrained model.

        **Raises:**

        ValueError
            | If train_path, val_path, pretrained_path, model_save_path or results_dir is missing
        """
        self._require(
            "train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"
        )
        train_utils.run_optuna_study(self.config)

    def load_and_evaluate_test_model(
        self,
    ):
        """
        Loads previously fine-tuned multi-task model and evaluates on test data.

        **Raises:**

        ValueError
            | If test_path, model_save_path or results_dir is missing
        """
        self._require("test_path", "model_save_path", "results_dir")
        eval_utils.load_and_evaluate_test_model(self.config)

    def save_model_without_heads(
        self,
    ):
        """
        Save previously fine-tuned multi-task model without classification heads.

        Operates on the "GeneformerMultiTask" subdirectory of model_save_path.

        **Raises:**

        ValueError
            | If model_save_path is missing
        """
        self._require("model_save_path")
        utils.save_model_without_heads(
            os.path.join(self.model_save_path, "GeneformerMultiTask")
        )
|
|