# Geneformer: geneformer/classifier.py
# Author: Christina Theodoris
# Last commit: 75c67a1 ("patch datasets save_to_disk"); file size 60.7 kB
"""
Geneformer classifier.
**Input data:**
| Cell state classifier:
| Single-cell transcriptomes as Geneformer rank value encodings with cell state labels in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
| Gene classifier:
| Dictionary in format {Gene_label: list(genes)} for gene labels and single-cell transcriptomes as Geneformer rank value encodings in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
**Usage:**
.. code-block :: python
>>> from geneformer import Classifier
>>> cc = Classifier(classifier="cell", # example of cell state classifier
... cell_state_dict={"state_key": "disease", "states": "all"},
... filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]},
... training_args=training_args,
... freeze_layers = 2,
... num_crossval_splits = 1,
... forward_batch_size=200,
... nproc=16)
>>> cc.prepare_data(input_data_file="path/to/input_data",
... output_directory="path/to/output_directory",
... output_prefix="output_prefix")
>>> all_metrics = cc.validate(model_directory="path/to/model",
...                          prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled_train.dataset",
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
... output_directory="path/to/output_directory",
... output_prefix="output_prefix",
... predict_eval=True)
>>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
... output_directory="path/to/output_directory",
... output_prefix="output_prefix",
... custom_class_order=["healthy","disease1","disease2"])
>>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl",
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
... title="disease",
... output_directory="path/to/output_directory",
... output_prefix="output_prefix",
... custom_class_order=["healthy","disease1","disease2"])
"""
import datetime
import logging
import math
import os
import pickle
import subprocess
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from tqdm.auto import tqdm, trange
from transformers import Trainer
from transformers.training_args import TrainingArguments

from . import DataCollatorForCellClassification, DataCollatorForGeneClassification
from . import classifier_utils as cu
from . import evaluation_utils as eu
from . import perturber_utils as pu
from .tokenizer import TOKEN_DICTIONARY_FILE
# Apply seaborn's default plotting theme once, at import time.
sns.set()
# Logger named after this module (via __name__).
logger = logging.getLogger(name=__name__)
class Classifier:
# Allowed values for each constructor option, consumed by validate_options.
# Each set mixes literal allowed values (e.g. "cell", None, 0, 1, 5) with
# types (int, float, bool, str, dict); an option is accepted if it equals one
# of the literals or is an instance of one of the listed types.
valid_option_dict = dict(
    classifier={"cell", "gene"},
    cell_state_dict={None, dict},
    gene_class_dict={None, dict},
    filter_data={None, dict},
    rare_threshold={int, float},
    max_ncells={None, int},
    max_ncells_per_class={None, int},
    training_args={None, dict},
    freeze_layers={int},
    num_crossval_splits={0, 1, 5},
    split_sizes={None, dict},
    no_eval={bool},
    stratify_splits_col={None, str},
    forward_batch_size={int},
    token_dictionary_file={None, str},
    nproc={int},
    ngpu={int},
)
def __init__(
self,
classifier=None,
cell_state_dict=None,
gene_class_dict=None,
filter_data=None,
rare_threshold=0,
max_ncells=None,
max_ncells_per_class=None,
training_args=None,
ray_config=None,
freeze_layers=0,
num_crossval_splits=1,
split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1},
stratify_splits_col=None,
no_eval=False,
forward_batch_size=100,
token_dictionary_file=None,
nproc=4,
ngpu=1,
):
"""
Initialize Geneformer classifier.
**Parameters:**
classifier : {"cell", "gene"}
| Whether to fine-tune a cell state or gene classifier.
cell_state_dict : None, dict
| Cell states to fine-tune model to distinguish.
| Two-item dictionary with keys: state_key and states
| state_key: key specifying name of column in .dataset that defines the states to model
| states: list of values in the state_key column that specifies the states to model
| Alternatively, instead of a list of states, can specify "all" to use all states in that state key from input data.
| Of note, if using "all", states will be defined after data is filtered.
| Must have at least 2 states to model.
| For example: {"state_key": "disease",
| "states": ["nf", "hcm", "dcm"]}
| or
| {"state_key": "disease",
| "states": "all"}
gene_class_dict : None, dict
| Gene classes to fine-tune model to distinguish.
| Dictionary in format: {Gene_label_A: list(geneA1, geneA2, ...),
| Gene_label_B: list(geneB1, geneB2, ...)}
| Gene values should be Ensembl IDs.
filter_data : None, dict
| Default is to fine-tune with all input data.
| Otherwise, dictionary specifying .dataset column name and list of values to filter by.
rare_threshold : float
| Threshold below which rare cell states should be removed.
| For example, setting to 0.05 will remove cell states representing
| < 5% of the total cells from the cell state classifier's possible classes.
max_ncells : None, int
| Maximum number of cells to use for fine-tuning.
| Default is to fine-tune with all input data.
max_ncells_per_class : None, int
| Maximum number of cells per cell class to use for fine-tuning.
| Of note, will be applied after max_ncells above.
| (Only valid for cell classification.)
training_args : None, dict
| Training arguments for fine-tuning.
| If None, defaults will be inferred for 6 layer Geneformer.
| Otherwise, will use the Hugging Face defaults:
| https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
| Note: Hyperparameter tuning is highly recommended, rather than using defaults.
ray_config : None, dict
| Training argument ranges for tuning hyperparameters with Ray.
freeze_layers : int
| Number of layers to freeze from fine-tuning.
| 0: no layers will be frozen; 2: first two layers will be frozen; etc.
num_crossval_splits : {0, 1, 5}
| 0: train on all data without splitting
| 1: split data into train and eval sets by designated split_sizes["valid"]
| 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"]
split_sizes : None, dict
| Dictionary of proportion of data to hold out for train, validation, and test sets
| {"train": 0.8, "valid": 0.1, "test": 0.1} if intending 80/10/10 train/valid/test split
stratify_splits_col : None, str
| Name of column in .dataset to be used for stratified splitting.
| Proportion of each class in this column will be the same in the splits as in the original dataset.
no_eval : bool
| If True, will skip eval step and use all data for training.
| Otherwise, will perform eval during training.
forward_batch_size : int
| Batch size for forward pass (for evaluation, not training).
token_dictionary_file : None, str
| Default is to use token dictionary file from Geneformer
| Otherwise, will load custom gene token dictionary.
nproc : int
| Number of CPU processes to use.
ngpu : int
| Number of GPUs available.
"""
self.classifier = classifier
if self.classifier == "cell":
self.model_type = "CellClassifier"
elif self.classifier == "gene":
self.model_type = "GeneClassifier"
self.cell_state_dict = cell_state_dict
self.gene_class_dict = gene_class_dict
self.filter_data = filter_data
self.rare_threshold = rare_threshold
self.max_ncells = max_ncells
self.max_ncells_per_class = max_ncells_per_class
self.training_args = training_args
self.ray_config = ray_config
self.freeze_layers = freeze_layers
self.num_crossval_splits = num_crossval_splits
self.split_sizes = split_sizes
self.train_size = self.split_sizes["train"]
self.valid_size = self.split_sizes["valid"]
self.oos_test_size = self.split_sizes["test"]
self.eval_size = self.valid_size / (self.train_size + self.valid_size)
self.stratify_splits_col = stratify_splits_col
self.no_eval = no_eval
self.forward_batch_size = forward_batch_size
self.token_dictionary_file = token_dictionary_file
self.nproc = nproc
self.ngpu = ngpu
if self.training_args is None:
logger.warning(
"Hyperparameter tuning is highly recommended for optimal results. "
"No training_args provided; using default hyperparameters."
)
self.validate_options()
if self.filter_data is None:
self.filter_data = dict()
if self.classifier == "cell":
if self.cell_state_dict["states"] != "all":
self.filter_data[
self.cell_state_dict["state_key"]
] = self.cell_state_dict["states"]
# load token dictionary (Ensembl IDs:token)
if self.token_dictionary_file is None:
self.token_dictionary_file = TOKEN_DICTIONARY_FILE
with open(self.token_dictionary_file, "rb") as f:
self.gene_token_dict = pickle.load(f)
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
# filter genes for gene classification for those in token dictionary
if self.classifier == "gene":
all_gene_class_values = set(pu.flatten_list(self.gene_class_dict.values()))
missing_genes = [
gene
for gene in all_gene_class_values
if gene not in self.gene_token_dict.keys()
]
if len(missing_genes) == len(all_gene_class_values):
logger.error(
"None of the provided genes to classify are in token dictionary."
)
raise
elif len(missing_genes) > 0:
logger.warning(
f"Genes to classify {missing_genes} are not in token dictionary."
)
self.gene_class_dict = {
k: set([self.gene_token_dict.get(gene) for gene in v])
for k, v in self.gene_class_dict.items()
}
empty_classes = []
for k, v in self.gene_class_dict.items():
if len(v) == 0:
empty_classes += [k]
if len(empty_classes) > 0:
logger.error(
f"Class(es) {empty_classes} did not contain any genes in the token dictionary."
)
raise
def validate_options(self):
# confirm arguments are within valid options and compatible with each other
for attr_name, valid_options in self.valid_option_dict.items():
attr_value = self.__dict__[attr_name]
if not isinstance(attr_value, (list, dict)):
if attr_value in valid_options:
continue
valid_type = False
for option in valid_options:
if (option in [int, float, list, dict, bool, str]) and isinstance(
attr_value, option
):
valid_type = True
break
if valid_type:
continue
logger.error(
f"Invalid option for {attr_name}. "
f"Valid options for {attr_name}: {valid_options}"
)
raise
if self.filter_data is not None:
for key, value in self.filter_data.items():
if not isinstance(value, list):
self.filter_data[key] = [value]
logger.warning(
"Values in filter_data dict must be lists. "
f"Changing {key} value to list ([{value}])."
)
if self.classifier == "cell":
if set(self.cell_state_dict.keys()) != set(["state_key", "states"]):
logger.error(
"Invalid keys for cell_state_dict. "
"The cell_state_dict should have only 2 keys: state_key and states"
)
raise
if self.cell_state_dict["states"] != "all":
if not isinstance(self.cell_state_dict["states"], list):
logger.error(
"States in cell_state_dict should be list of states to model."
)
raise
if len(self.cell_state_dict["states"]) < 2:
logger.error(
"States in cell_state_dict should contain at least 2 states to classify."
)
raise
if self.classifier == "gene":
if len(self.gene_class_dict.keys()) < 2:
logger.error(
"Gene_class_dict should contain at least 2 gene classes to classify."
)
raise
if sum(self.split_sizes.values()) != 1:
logger.error("Train, validation, and test proportions should sum to 1.")
raise
def prepare_data(
    self,
    input_data_file,
    output_directory,
    output_prefix,
    split_id_dict=None,
    test_size=None,
    attr_to_split=None,
    attr_to_balance=None,
    max_trials=100,
    pval_threshold=0.1,
):
    """
    Prepare data for cell state or gene classification.

    **Parameters**

    input_data_file : Path
        | Path to directory containing .dataset input
    output_directory : Path
        | Path to directory where prepared data will be saved
    output_prefix : str
        | Prefix for output file
    split_id_dict : None, dict
        | Dictionary of IDs for train and test splits
        | Three-item dictionary with keys: attr_key, train, test
        | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
        | train: list of IDs in the attr_key column to include in the train split
        | test: list of IDs in the attr_key column to include in the test split
        | For example: {"attr_key": "individual",
        |               "train": ["patient1", "patient2", "patient3", "patient4"],
        |               "test": ["patient5", "patient6"]}
    test_size : None, float
        | Proportion of data to be saved separately and held out for test set
        | (e.g. 0.2 if intending hold out 20%)
        | If None, will inherit from split_sizes["test"] from Classifier
        | The training set will be further split to train / validation in self.validate
        | Note: only available for CellClassifiers
    attr_to_split : None, str
        | Key for attribute on which to split data while balancing potential confounders
        | e.g. "patient_id" for splitting by patient while balancing other characteristics
        | Note: only available for CellClassifiers
    attr_to_balance : None, list
        | List of attribute keys on which to balance data while splitting on attr_to_split
        | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
        | Note: only available for CellClassifiers
    max_trials : None, int
        | Maximum number of trials of random splitting to try to achieve balanced other attributes
        | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
        | Note: only available for CellClassifiers
    pval_threshold : None, float
        | P-value threshold to use for attribute balancing across splits
        | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance

    **Output**

    Writes {output_prefix}_id_class_dict.pkl and either
    {output_prefix}_labeled_train.dataset + {output_prefix}_labeled_test.dataset
    (when split_id_dict is given, or for a cell classifier with 0 < test_size < 1)
    or {output_prefix}_labeled.dataset to output_directory.

    **Raises**

    ValueError
        | If the input already has the column reserved for class IDs
        | ("label" for cell classifiers, "labels" for gene classifiers).
    """
    if test_size is None:
        test_size = self.oos_test_size
    output_directory = Path(output_directory)

    # prepare data and labels for classification
    data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)

    # the class-ID column name must be free before labels are added below
    reserved_col = "label" if self.classifier == "cell" else "labels"
    if reserved_col in data.features:
        msg = f"Column name '{reserved_col}' must be reserved for class IDs. Please rename column."
        logger.error(msg)
        raise ValueError(msg)

    if self.classifier == "cell":
        # remove cell states representing < rare_threshold of cells
        data = cu.remove_rare(
            data, self.rare_threshold, self.cell_state_dict["state_key"], self.nproc
        )
        # downsample max cells and max per class
        data = cu.downsample_and_shuffle(
            data, self.max_ncells, self.max_ncells_per_class, self.cell_state_dict
        )
        # rename cell state column to "label"
        data = cu.rename_cols(data, self.cell_state_dict["state_key"])

    # convert classes to numerical labels and save as id_class_dict
    # of note, will label all genes in gene_class_dict
    # if (cross-)validating, genes will be relabeled in column "labels" for each split
    # at the time of training with Classifier.validate
    data, id_class_dict = cu.label_classes(
        self.classifier, data, self.gene_class_dict, self.nproc
    )

    # save id_class_dict for future reference
    id_class_output_path = (
        output_directory / f"{output_prefix}_id_class_dict"
    ).with_suffix(".pkl")
    with open(id_class_output_path, "wb") as f:
        pickle.dump(id_class_dict, f)

    if split_id_dict is not None:
        # user-specified split by ID values in the attr_key column
        attr_key = split_id_dict["attr_key"]
        data_dict = {
            split: pu.filter_by_dict(
                data, {attr_key: split_id_dict[split]}, self.nproc
            )
            for split in ("train", "test")
        }
        self._save_train_test_splits(data_dict, output_directory, output_prefix)
    elif test_size is not None and self.classifier == "cell" and 0 < test_size < 1:
        if attr_to_split is None:
            data_dict = data.train_test_split(
                test_size=test_size,
                stratify_by_column=self.stratify_splits_col,
                seed=42,
            )
        else:
            # split on attr_to_split while balancing attr_to_balance across splits
            data_dict, balance_df = cu.balance_attr_splits(
                data,
                attr_to_split,
                attr_to_balance,
                test_size,
                max_trials,
                pval_threshold,
                self.cell_state_dict["state_key"],
                self.nproc,
            )
            balance_df.to_csv(
                output_directory / f"{output_prefix}_train_test_balance_df.csv"
            )
        self._save_train_test_splits(data_dict, output_directory, output_prefix)
    else:
        # no held-out test split: save all labeled data together
        data_output_path = (
            output_directory / f"{output_prefix}_labeled"
        ).with_suffix(".dataset")
        data.save_to_disk(str(data_output_path))

@staticmethod
def _save_train_test_splits(data_dict, output_directory, output_prefix):
    """Save data_dict["train"] / data_dict["test"] as {output_prefix}_labeled_{train,test}.dataset."""
    for split in ("train", "test"):
        split_output_path = (
            Path(output_directory) / f"{output_prefix}_labeled_{split}"
        ).with_suffix(".dataset")
        data_dict[split].save_to_disk(str(split_output_path))
def train_all_data(
    self,
    model_directory,
    prepared_input_data_file,
    id_class_dict_file,
    output_directory,
    output_prefix,
    save_eval_output=True,
):
    """
    Train cell state or gene classifier using all data.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    prepared_input_data_file : Path
        | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    output_directory : Path
        | Path to directory where model and eval data will be saved
        | A subdirectory {yymmdd}_geneformer_{classifier}Classifier_{output_prefix}/ is created inside it.
    output_prefix : str
        | Prefix for output files
    save_eval_output : bool
        | Not used by this method (no evaluation is performed when training on all data);
        | accepted for interface parity with Classifier.validate.

    **Output**

    Returns trainer after fine-tuning with all data.
    """
    ##### Load data and prepare output directory #####
    # load numerical id to class dictionary (id:class)
    # NOTE: unpickling can execute arbitrary code; only load files you produced
    # (e.g. via Classifier.prepare_data).
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)
    class_id_dict = {v: k for k, v in id_class_dict.items()}

    # load previously filtered and prepared data
    data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
    data = data.shuffle(seed=42)  # reshuffle in case users provide unshuffled data

    # define output directory path (accepts str or Path)
    datestamp = datetime.datetime.now().strftime("%y%m%d")
    output_directory = str(output_directory)
    if output_directory[-1:] != "/":  # add slash for dir if not present
        output_directory = output_directory + "/"
    output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
    # create directly instead of via a shell command: safe for paths with
    # spaces/shell metacharacters and tolerant of an existing directory
    os.makedirs(output_dir, exist_ok=True)

    # get number of classes for classifier
    num_classes = cu.get_num_classes(id_class_dict)

    if self.classifier == "gene":
        # one (gene token, class id) pair per gene, in matching order
        targets = pu.flatten_list(self.gene_class_dict.values())
        labels = pu.flatten_list(
            [
                [class_id_dict[label]] * len(genes)
                for label, genes in self.gene_class_dict.items()
            ]
        )
        assert len(targets) == len(labels)
        data = cu.prep_gene_classifier_all_data(
            data, targets, labels, self.max_ncells, self.nproc
        )

    trainer = self.train_classifier(
        model_directory, num_classes, data, None, output_dir
    )

    return trainer
def validate(
    self,
    model_directory,
    prepared_input_data_file,
    id_class_dict_file,
    output_directory,
    output_prefix,
    split_id_dict=None,
    attr_to_split=None,
    attr_to_balance=None,
    max_trials=100,
    pval_threshold=0.1,
    save_eval_output=True,
    predict_eval=True,
    predict_trainer=False,
    n_hyperopt_trials=0,
):
    """
    (Cross-)validate cell state or gene classifier.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    prepared_input_data_file : Path
        | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    output_directory : Path
        | Path to directory where model and eval data will be saved
        | A subdirectory {yymmdd}_geneformer_{classifier}Classifier_{output_prefix}/ is created inside it.
    output_prefix : str
        | Prefix for output files
    split_id_dict : None, dict
        | Dictionary of IDs for train and eval splits
        | Three-item dictionary with keys: attr_key, train, eval
        | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
        | train: list of IDs in the attr_key column to include in the train split
        | eval: list of IDs in the attr_key column to include in the eval split
        | For example: {"attr_key": "individual",
        |               "train": ["patient1", "patient2", "patient3", "patient4"],
        |               "eval": ["patient5", "patient6"]}
        | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
    attr_to_split : None, str
        | Key for attribute on which to split data while balancing potential confounders
        | e.g. "patient_id" for splitting by patient while balancing other characteristics
        | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
    attr_to_balance : None, list
        | List of attribute keys on which to balance data while splitting on attr_to_split
        | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
    max_trials : None, int
        | Maximum number of trials of random splitting to try to achieve balanced other attribute
        | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
    pval_threshold : None, float
        | P-value threshold to use for attribute balancing across splits
        | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
    save_eval_output : bool
        | Whether to save cross-fold eval output
        | Saves as pickle file of dictionary of eval metrics
    predict_eval : bool
        | Whether or not to save eval predictions
        | Saves as a pickle file of self.evaluate predictions
    predict_trainer : bool
        | Whether or not to save eval predictions from trainer
        | Saves as a pickle file of trainer predictions
    n_hyperopt_trials : int
        | Number of trials to run for hyperparameter optimization
        | If 0, will not optimize hyperparameters

    **Output**

    Returns a dict with "conf_matrix" (DataFrame summed over splits), per-split
    "macro_f1" and "acc" lists, and "all_roc_metrics" (None unless binary).
    Returns None when n_hyperopt_trials > 0 for a cell classifier, or for a
    gene classifier with no held-out test proportion (split_sizes["test"] == 0).

    **Raises**

    ValueError
        | If num_crossval_splits is 0, or (gene classifier) any class has fewer than 5 genes.
    """
    if self.num_crossval_splits == 0:
        msg = "num_crossval_splits must be 1 or 5 to validate."
        logger.error(msg)
        raise ValueError(msg)

    # ensure each gene class has at least 5 genes if validating model
    if self.classifier == "gene":
        insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
        if insuff_classes:
            msg = f"Insufficient # of members in class(es) {insuff_classes} to (cross-)validate."
            logger.error(msg)
            raise ValueError(msg)

    ##### Load data and prepare output directory #####
    # load numerical id to class dictionary (id:class)
    # NOTE: unpickling can execute arbitrary code; only load files you produced.
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)
    class_id_dict = {v: k for k, v in id_class_dict.items()}

    # load previously filtered and prepared data
    data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
    data = data.shuffle(seed=42)  # reshuffle in case users provide unshuffled data

    # define output directory path (accepts str or Path)
    datestamp = datetime.datetime.now().strftime("%y%m%d")
    output_directory = str(output_directory)
    if output_directory[-1:] != "/":  # add slash for dir if not present
        output_directory = output_directory + "/"
    output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
    # create directly instead of via a shell command (safe for any path)
    os.makedirs(output_dir, exist_ok=True)

    # get number of classes for classifier
    num_classes = cu.get_num_classes(id_class_dict)

    ##### (Cross-)validate the model #####
    results = []
    all_conf_mat = np.zeros((num_classes, num_classes))
    iteration_num = 1
    if self.classifier == "cell":
        for i in trange(self.num_crossval_splits):
            print(
                f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
            )
            ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
            if self.num_crossval_splits == 1:
                # single 1-eval_size:eval_size split
                if split_id_dict is not None:
                    attr_key = split_id_dict["attr_key"]
                    data_dict = {
                        "train": pu.filter_by_dict(
                            data, {attr_key: split_id_dict["train"]}, self.nproc
                        ),
                        "test": pu.filter_by_dict(
                            data, {attr_key: split_id_dict["eval"]}, self.nproc
                        ),
                    }
                elif attr_to_split is not None:
                    data_dict, balance_df = cu.balance_attr_splits(
                        data,
                        attr_to_split,
                        attr_to_balance,
                        self.eval_size,
                        max_trials,
                        pval_threshold,
                        self.cell_state_dict["state_key"],
                        self.nproc,
                    )
                    balance_df.to_csv(
                        f"{output_dir}/{output_prefix}_train_valid_balance_df.csv"
                    )
                else:
                    data_dict = data.train_test_split(
                        test_size=self.eval_size,
                        stratify_by_column=self.stratify_splits_col,
                        seed=42,
                    )
                train_data = data_dict["train"]
                eval_data = data_dict["test"]
            else:
                # 5-fold cross-validate: fold i evaluates on a contiguous block
                # of the (shuffled) cells starting at i * num_cells // 5.
                # Indices must be ints for range(); the block is excluded from
                # training with O(n) range arithmetic instead of list lookups.
                num_cells = len(data)
                fifth_cells = num_cells // 5
                num_eval = min(int(self.eval_size * num_cells), fifth_cells)
                start = i * fifth_cells
                end = start + num_eval
                eval_indices = list(range(start, end))
                train_indices = list(range(0, start)) + list(range(end, num_cells))
                eval_data = data.select(eval_indices)
                train_data = data.select(train_indices)

            if n_hyperopt_trials == 0:
                trainer = self.train_classifier(
                    model_directory,
                    num_classes,
                    train_data,
                    eval_data,
                    ksplit_output_dir,
                    predict_trainer,
                )
            else:
                # hyperparameter search only; no metrics are collected
                self.hyperopt_classifier(
                    model_directory,
                    num_classes,
                    train_data,
                    eval_data,
                    ksplit_output_dir,
                    n_trials=n_hyperopt_trials,
                )
                if iteration_num == self.num_crossval_splits:
                    return
                iteration_num = iteration_num + 1
                continue

            result = self.evaluate_model(
                trainer.model,
                num_classes,
                id_class_dict,
                eval_data,
                predict_eval,
                ksplit_output_dir,
                output_prefix,
            )
            results += [result]
            all_conf_mat = all_conf_mat + result["conf_mat"]
            iteration_num = iteration_num + 1

    elif self.classifier == "gene":
        # set up (cross-)validation splits: one (gene token, class id) pair per gene
        targets = pu.flatten_list(self.gene_class_dict.values())
        labels = pu.flatten_list(
            [
                [class_id_dict[label]] * len(genes)
                for label, genes in self.gene_class_dict.items()
            ]
        )
        assert len(targets) == len(labels)
        n_splits = int(1 / (1 - self.train_size))
        skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
        # (Cross-)validate
        # NOTE(review): eval_size is a fraction of train+valid while
        # oos_test_size is a fraction of all data; confirm whether valid_size
        # was intended in this ratio.
        test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
        for train_index, eval_index, test_index in tqdm(
            skf.split(targets, labels, test_ratio)
        ):
            print(
                f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
            )
            ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
            # filter data for examples containing classes for this split
            # subsample to max_ncells and relabel data in column "labels"
            train_data, eval_data = cu.prep_gene_classifier_train_eval_split(
                data,
                targets,
                labels,
                train_index,
                eval_index,
                self.max_ncells,
                iteration_num,
                self.nproc,
            )

            if self.oos_test_size > 0:
                test_data = cu.prep_gene_classifier_split(
                    data,
                    targets,
                    labels,
                    test_index,
                    "test",
                    self.max_ncells,
                    iteration_num,
                    self.nproc,
                )

            if n_hyperopt_trials == 0:
                trainer = self.train_classifier(
                    model_directory,
                    num_classes,
                    train_data,
                    eval_data,
                    ksplit_output_dir,
                    predict_trainer,
                )
                result = self.evaluate_model(
                    trainer.model,
                    num_classes,
                    id_class_dict,
                    eval_data,
                    predict_eval,
                    ksplit_output_dir,
                    output_prefix,
                )
            else:
                self.hyperopt_classifier(
                    model_directory,
                    num_classes,
                    train_data,
                    eval_data,
                    ksplit_output_dir,
                    n_trials=n_hyperopt_trials,
                )

                model = cu.load_best_model(
                    ksplit_output_dir, self.model_type, num_classes
                )

                if self.oos_test_size > 0:
                    # evaluate the tuned model on the held-out test genes
                    result = self.evaluate_model(
                        model,
                        num_classes,
                        id_class_dict,
                        test_data,
                        predict_eval,
                        ksplit_output_dir,
                        output_prefix,
                    )
                else:
                    if iteration_num == self.num_crossval_splits:
                        return
                    iteration_num = iteration_num + 1
                    continue

            results += [result]
            all_conf_mat = all_conf_mat + result["conf_mat"]
            # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
            if iteration_num == self.num_crossval_splits:
                break
            iteration_num = iteration_num + 1

    class_names = list(id_class_dict.values())
    all_conf_mat_df = pd.DataFrame(all_conf_mat, columns=class_names, index=class_names)
    all_metrics = {
        "conf_matrix": all_conf_mat_df,
        "macro_f1": [result["macro_f1"] for result in results],
        "acc": [result["acc"] for result in results],
    }
    all_roc_metrics = None  # roc metrics not reported for multiclass
    if num_classes == 2:
        mean_fpr = np.linspace(0, 1, 100)
        all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
        all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
        all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
        mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
            all_tpr, all_roc_auc, all_tpr_wt
        )
        all_roc_metrics = {
            "mean_tpr": mean_tpr,
            "mean_fpr": mean_fpr,
            "all_roc_auc": all_roc_auc,
            "roc_auc": roc_auc,
            "roc_auc_sd": roc_auc_sd,
        }
    all_metrics["all_roc_metrics"] = all_roc_metrics

    if save_eval_output is True:
        eval_metrics_output_path = (
            Path(output_dir) / f"{output_prefix}_eval_metrics_dict"
        ).with_suffix(".pkl")
        with open(eval_metrics_output_path, "wb") as f:
            pickle.dump(all_metrics, f)

    return all_metrics
def hyperopt_classifier(
    self,
    model_directory,
    num_classes,
    train_data,
    eval_data,
    output_directory,
    n_trials=100,
):
    """
    Optimize fine-tuning hyperparameters with Ray Tune for cell state or gene classification.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    num_classes : int
        | Number of classes for classifier
    train_data : Dataset
        | Loaded training .dataset input
        | For cell classifier, labels in column "label".
        | For gene classifier, labels in column "labels".
    eval_data : None, Dataset
        | (Optional) Loaded evaluation .dataset input
        | For cell classifier, labels in column "label".
        | For gene classifier, labels in column "labels".
    output_directory : Path
        | Path to directory where fine-tuned model and Ray Tune results will be saved
    n_trials : int
        | Number of trials to run for hyperparameter optimization

    **Output**

    Returns the Trainer on which trainer.hyperparameter_search was run.
    If self.ray_config is None, a default search space is used (with a warning).

    **Raises**

    FileExistsError
        | If output_directory already contains pytorch_model.bin.
    """
    # Ray is imported lazily so it is only required when hyperparameter
    # optimization is actually used
    import ray
    from ray import tune
    from ray.tune.search.hyperopt import HyperOptSearch

    ray.shutdown()  # engage new ray session
    ray.init()

    ##### Validate and prepare data #####
    train_data, eval_data = cu.validate_and_clean_cols(
        train_data, eval_data, self.classifier
    )

    if (self.no_eval is True) and (eval_data is not None):
        logger.warning(
            "no_eval set to True; hyperparameter optimization requires eval, proceeding with eval"
        )

    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
    if os.path.isfile(saved_model_test):
        msg = "Model already saved to this designated output directory."
        logger.error(msg)
        raise FileExistsError(msg)
    # make output directory (directly, not via a shell command)
    os.makedirs(output_directory, exist_ok=True)

    ##### Load model and training args #####
    model = pu.load_model(self.model_type, num_classes, model_directory, "train")
    def_training_args, def_freeze_layers = cu.get_default_train_args(
        model, self.classifier, train_data, output_directory
    )
    del model  # only needed to derive defaults; trials build their own via model_init

    if self.training_args is not None:
        def_training_args.update(self.training_args)
    # keep >= 1: TrainingArguments rejects logging_steps=0 under the "steps"
    # logging strategy, which small datasets would otherwise round down to
    logging_steps = max(
        1,
        round(len(train_data) / def_training_args["per_device_train_batch_size"] / 10),
    )
    def_training_args["logging_steps"] = logging_steps
    def_training_args["output_dir"] = output_directory
    if eval_data is None:
        def_training_args["evaluation_strategy"] = "no"
        def_training_args["load_best_model_at_end"] = False
    def_training_args.update(
        {"save_strategy": "epoch", "save_total_limit": 1}
    )  # only save last model for each run
    training_args_init = TrainingArguments(**def_training_args)

    ##### Fine-tune the model #####
    # define the data collator
    if self.classifier == "cell":
        data_collator = DataCollatorForCellClassification()
    elif self.classifier == "gene":
        data_collator = DataCollatorForGeneClassification()

    # Resolve layer freezing once, outside model_init. Assigning
    # def_freeze_layers inside model_init would make it a local there
    # (UnboundLocalError when self.freeze_layers is None).
    freeze_layers = (
        self.freeze_layers if self.freeze_layers is not None else def_freeze_layers
    )

    # define function to initiate model (called by the Trainer for each trial)
    def model_init():
        model = pu.load_model(self.model_type, num_classes, model_directory, "train")
        if freeze_layers > 0:
            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
            for module in modules_to_freeze:
                for param in module.parameters():
                    param.requires_grad = False
        model = model.to("cuda:0")
        return model

    # create the trainer
    trainer = Trainer(
        model_init=model_init,
        args=training_args_init,
        data_collator=data_collator,
        train_dataset=train_data,
        eval_dataset=eval_data,
        compute_metrics=cu.compute_metrics,
    )

    # specify raytune hyperparameter search space
    if self.ray_config is None:
        logger.warning(
            "No ray_config provided. Proceeding with default, but ranges may need adjustment depending on model."
        )
        ray_config = {
            "num_train_epochs": tune.choice([1]),
            "learning_rate": tune.loguniform(1e-6, 1e-3),
            "weight_decay": tune.uniform(0.0, 0.3),
            "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
            "warmup_steps": tune.uniform(100, 2000),
            "seed": tune.uniform(0, 100),
            "per_device_train_batch_size": tune.choice(
                [def_training_args["per_device_train_batch_size"]]
            ),
        }
    else:
        ray_config = self.ray_config

    hyperopt_search = HyperOptSearch(metric="eval_macro_f1", mode="max")

    # optimize hyperparameters
    trainer.hyperparameter_search(
        direction="maximize",
        backend="ray",
        resources_per_trial={"cpu": int(self.nproc / self.ngpu), "gpu": 1},
        hp_space=lambda _: ray_config,
        search_alg=hyperopt_search,
        n_trials=n_trials,  # number of trials
        progress_reporter=tune.CLIReporter(
            max_report_frequency=600,
            sort_by_metric=True,
            max_progress_rows=n_trials,
            mode="max",
            metric="eval_macro_f1",
            metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
        ),
        local_dir=output_directory,
    )

    return trainer
def train_classifier(
    self,
    model_directory,
    num_classes,
    train_data,
    eval_data,
    output_directory,
    predict=False,
):
    """
    Fine-tune model for cell state or gene classification.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    num_classes : int
        | Number of classes for classifier
    train_data : Dataset
        | Loaded training .dataset input
        | For cell classifier, labels in column "label".
        | For gene classifier, labels in column "labels".
    eval_data : None, Dataset
        | (Optional) Loaded evaluation .dataset input
        | For cell classifier, labels in column "label".
        | For gene classifier, labels in column "labels".
    output_directory : Path
        | Path to directory where fine-tuned model will be saved
    predict : bool
        | Whether or not to save eval predictions from trainer
        | (forced to False, with a warning, for gene classifiers or when
        | there is no eval data to predict on)

    **Returns**

    trainer : Trainer
        | The trainer after training; the model is also saved to output_directory.

    **Raises**

    RuntimeError
        | If a saved model already exists in output_directory.
    """
    ##### Validate and prepare data #####
    train_data, eval_data = cu.validate_and_clean_cols(
        train_data, eval_data, self.classifier
    )

    if (self.no_eval is True) and (eval_data is not None):
        logger.warning(
            "no_eval set to True; model will be trained without evaluation."
        )
        eval_data = None

    if (self.classifier == "gene") and (predict is True):
        logger.warning(
            "Predictions during training not currently available for gene classifiers; setting predict to False."
        )
        predict = False

    # trainer.predict() needs an eval set; without this check the call below
    # would fail only after the entire training run had completed.
    if (predict is True) and (eval_data is None):
        logger.warning(
            "No eval data available to generate predictions; setting predict to False."
        )
        predict = False

    # ensure not overwriting previously saved model
    # (check both the legacy PyTorch and the safetensors weight file names)
    for weights_name in ("pytorch_model.bin", "model.safetensors"):
        if os.path.isfile(os.path.join(output_directory, weights_name)):
            error_msg = "Model already saved to this designated output directory."
            logger.error(error_msg)
            raise RuntimeError(error_msg)

    # make output directory (no shell involved; creates missing parents and
    # is a no-op if the directory already exists)
    os.makedirs(output_directory, exist_ok=True)

    ##### Load model and training args #####
    model = pu.load_model(self.model_type, num_classes, model_directory, "train")
    def_training_args, def_freeze_layers = cu.get_default_train_args(
        model, self.classifier, train_data, output_directory
    )

    if self.training_args is not None:
        def_training_args.update(self.training_args)

    # log roughly ten times per pass over train_data at the per-device batch
    # size, but never 0: TrainingArguments rejects logging_steps=0 when the
    # logging strategy is "steps" (small datasets would otherwise round to 0)
    logging_steps = max(
        1,
        round(
            len(train_data) / def_training_args["per_device_train_batch_size"] / 10
        ),
    )
    def_training_args["logging_steps"] = logging_steps
    def_training_args["output_dir"] = output_directory
    if eval_data is None:
        def_training_args["evaluation_strategy"] = "no"
        def_training_args["load_best_model_at_end"] = False
    training_args_init = TrainingArguments(**def_training_args)

    # user-specified freeze_layers overrides the default returned above
    if self.freeze_layers is not None:
        def_freeze_layers = self.freeze_layers

    if def_freeze_layers > 0:
        modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
        for module in modules_to_freeze:
            for param in module.parameters():
                param.requires_grad = False

    ##### Fine-tune the model #####
    # define the data collator
    if self.classifier == "cell":
        data_collator = DataCollatorForCellClassification()
    elif self.classifier == "gene":
        data_collator = DataCollatorForGeneClassification()

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=data_collator,
        train_dataset=train_data,
        eval_dataset=eval_data,
        compute_metrics=cu.compute_metrics,
    )

    # train the classifier
    trainer.train()
    trainer.save_model(output_directory)
    if predict is True:
        # make eval predictions and save predictions and metrics
        predictions = trainer.predict(eval_data)
        prediction_output_path = os.path.join(output_directory, "predictions.pkl")
        with open(prediction_output_path, "wb") as f:
            pickle.dump(predictions, f)
        trainer.save_metrics("eval", predictions.metrics)
    return trainer
def evaluate_model(
    self,
    model,
    num_classes,
    id_class_dict,
    eval_data,
    predict=False,
    output_directory=None,
    output_prefix=None,
):
    """
    Evaluate the fine-tuned model.

    **Parameters**

    model : nn.Module
        | Loaded fine-tuned model (e.g. trainer.model)
    num_classes : int
        | Number of classes for classifier
    id_class_dict : dict
        | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    eval_data : Dataset
        | Loaded evaluation .dataset input
    predict : bool
        | Whether or not to save eval predictions
    output_directory : Path
        | Path to directory where eval data will be saved
        | (required when predict is True)
    output_prefix : str
        | Prefix for output files

    **Returns**

    dict
        | Keys "conf_mat", "macro_f1", "acc", "roc_metrics", holding the values
        | returned by eu.get_metrics.
    """
    ##### Evaluate the model #####
    # the numerical class IDs are passed to get_metrics as the label set
    labels = id_class_dict.keys()
    y_pred, y_true, logits_list = eu.classifier_predict(
        model, self.classifier, eval_data, self.forward_batch_size
    )
    conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
        y_pred, y_true, logits_list, num_classes, labels
    )
    if predict is True:
        pred_dict = {
            "pred_ids": y_pred,
            "label_ids": y_true,
            "predictions": logits_list,
        }
        # Build the filename directly rather than via Path.with_suffix():
        # with_suffix replaces everything after the last "." in the name, so a
        # prefix such as "run.v2" would have been written to "run.pkl".
        pred_dict_output_path = (
            Path(output_directory) / f"{output_prefix}_pred_dict.pkl"
        )
        with open(pred_dict_output_path, "wb") as f:
            pickle.dump(pred_dict, f)

    return {
        "conf_mat": conf_mat,
        "macro_f1": macro_f1,
        "acc": acc,
        "roc_metrics": roc_metrics,
    }
def evaluate_saved_model(
    self,
    model_directory,
    id_class_dict_file,
    test_data_file,
    output_directory,
    output_prefix,
    predict=True,
):
    """
    Evaluate a previously fine-tuned, saved model on test data.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    test_data_file : Path
        | Path to directory containing test .dataset
    output_directory : Path
        | Path to directory where eval data will be saved
    output_prefix : str
        | Prefix for output files
    predict : bool
        | Whether or not to save eval predictions

    **Returns**

    all_metrics : dict
        | "conf_matrix" (DataFrame indexed/columned by class label), "macro_f1",
        | "acc", and "all_roc_metrics" (None unless num_classes == 2).
        | Also pickled to {output_directory}/{output_prefix}_test_metrics_dict.pkl.
    """
    # load numerical id to class dictionary (id:class)
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)

    # get number of classes for classifier
    num_classes = cu.get_num_classes(id_class_dict)

    # load previously filtered and prepared data
    test_data = pu.load_and_filter(None, self.nproc, test_data_file)

    # load previously fine-tuned model
    model = pu.load_model(self.model_type, num_classes, model_directory, "eval")

    # evaluate the model
    result = self.evaluate_model(
        model,
        num_classes,
        id_class_dict,
        test_data,
        predict=predict,
        output_directory=output_directory,
        output_prefix=output_prefix,
    )

    class_labels = list(id_class_dict.values())
    all_conf_mat_df = pd.DataFrame(
        result["conf_mat"],
        columns=class_labels,
        index=class_labels,
    )
    all_metrics = {
        "conf_matrix": all_conf_mat_df,
        "macro_f1": result["macro_f1"],
        "acc": result["acc"],
    }

    all_roc_metrics = None  # roc metrics not reported for multiclass
    if num_classes == 2:
        all_roc_metrics = {
            "mean_tpr": result["roc_metrics"]["interp_tpr"],
            "mean_fpr": np.linspace(0, 1, 100),
            "all_roc_auc": result["roc_metrics"]["auc"],
        }
    all_metrics["all_roc_metrics"] = all_roc_metrics

    # Build the filename directly rather than via Path.with_suffix(), which
    # would truncate prefixes containing "." (e.g. "run.v2" -> "run.pkl").
    test_metrics_output_path = (
        Path(output_directory) / f"{output_prefix}_test_metrics_dict.pkl"
    )
    with open(test_metrics_output_path, "wb") as f:
        pickle.dump(all_metrics, f)

    return all_metrics
def plot_conf_mat(
    self,
    conf_mat_dict,
    output_directory,
    output_prefix,
    custom_class_order=None,
):
    """
    Draw one confusion-matrix plot per model from evaluation results.

    **Parameters**

    conf_mat_dict : dict
        | Mapping of model_name : confusion_matrix_DataFrame
        | (all_metrics["conf_matrix"] from self.validate)
    output_directory : Path
        | Directory the plots are written to
    output_prefix : str
        | Prefix for output file
    custom_class_order : None, list
        | Optional class ordering for the plots;
        | the same ordering is applied to every model.
    """
    # one eu.plot_confusion_matrix call per model entry
    for name_of_model, matrix_df in conf_mat_dict.items():
        eu.plot_confusion_matrix(
            matrix_df,
            name_of_model,
            output_directory,
            output_prefix,
            custom_class_order,
        )
def plot_roc(
    self,
    roc_metric_dict,
    model_style_dict,
    title,
    output_directory,
    output_prefix,
):
    """
    Draw ROC curves from the results of evaluating fine-tuned model(s).

    **Parameters**

    roc_metric_dict : dict
        | Mapping of model_name : roc_metrics
        | (all_metrics["all_roc_metrics"] from self.validate)
    model_style_dict : dict[dict]
        | Mapping of model_name : {style_attribute: style}
        | where style includes color and linestyle
        | e.g. {'Model_A': {'color': 'black', 'linestyle': '-'}, 'Model_B': ...}
    title : str
        | Plot title (e.g. 'Dosage-sensitive vs -insensitive factors')
    output_directory : Path
        | Directory the plot is written to
    output_prefix : str
        | Prefix for output file
    """
    # all inputs are forwarded positionally to eu.plot_ROC
    roc_plot_args = (
        roc_metric_dict,
        model_style_dict,
        title,
        output_directory,
        output_prefix,
    )
    eu.plot_ROC(*roc_plot_args)
def plot_predictions(
    self,
    predictions_file,
    id_class_dict_file,
    title,
    output_directory,
    output_prefix,
    custom_class_order=None,
    kwargs_dict=None,
):
    """
    Plot prediction results of evaluating the fine-tuned model.

    **Parameters**

    predictions_file : path
        | Path of model predictions output to plot
        | (saved output from self.validate if predict_eval=True)
        | (or saved output from self.evaluate_saved_model)
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    title : str
        | Title for legend containing class labels.
    output_directory : Path
        | Path to directory where plots will be saved
    output_prefix : str
        | Prefix for output file
    custom_class_order : None, list
        | List of classes in custom order for plots.
        | Same order will be used for all models.
    kwargs_dict : None, dict
        | Dictionary of kwargs to pass to plotting function.

    **Raises**

    ValueError
        | If a predictions dict lacks the "pred_ids"/"label_ids"/"predictions"
        | keys, or if the logits are not 2-D with one column per class in
        | id_class_dict.
    """
    # load predictions
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open(predictions_file, "rb") as f:
        predictions = pickle.load(f)

    # load numerical id to class dictionary (id:class)
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)

    if isinstance(predictions, dict):
        # format is output from self.evaluate_saved_model
        required_keys = ("pred_ids", "label_ids", "predictions")
        missing_keys = [key for key in required_keys if key not in predictions]
        if missing_keys:
            raise ValueError(
                f"Predictions dict is missing required keys: {missing_keys}"
            )
        predictions_logits = np.array(predictions["predictions"])
        true_ids = predictions["label_ids"]
    else:
        # format is output from self.validate if predict_eval=True
        # (an object exposing .predictions and .label_ids attributes)
        predictions_logits = np.asarray(predictions.predictions)
        true_ids = predictions.label_ids

    # explicit check (not assert) so validation survives python -O
    num_classes = len(id_class_dict)
    if predictions_logits.ndim != 2 or predictions_logits.shape[1] != num_classes:
        raise ValueError(
            f"Predictions have shape {predictions_logits.shape}, expected "
            f"(n_samples, {num_classes}) to match id_class_dict."
        )

    classes = list(id_class_dict.values())
    true_labels = [id_class_dict[idx] for idx in true_ids]
    predictions_df = pd.DataFrame(predictions_logits, columns=classes)
    if custom_class_order is not None:
        predictions_df = predictions_df.reindex(columns=custom_class_order)
    predictions_df["true"] = true_labels

    # rank of each class, used to sort rows by their true class
    class_order = custom_class_order if custom_class_order is not None else classes
    custom_dict = {class_name: rank for rank, class_name in enumerate(class_order)}
    predictions_df = predictions_df.sort_values(
        by=["true"], key=lambda x: x.map(custom_dict)
    )

    eu.plot_predictions(
        predictions_df, title, output_directory, output_prefix, kwargs_dict
    )