{ "cells": [ { "cell_type": "markdown", "id": "65a2b29a-c678-4874-a1bf-5af3a7d00ed9", "metadata": {}, "source": [ "## Geneformer Fine-Tuning for Classification of Cardiomyopathy Disease States" ] }, { "cell_type": "markdown", "id": "1792e51c-86c3-406f-be5a-273c4e4aec20", "metadata": {}, "source": [ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications." ] }, { "cell_type": "markdown", "id": "3dad7564-b464-4d37-9188-17c0ae4ae59f", "metadata": {}, "source": [ "### Train cell classifier with 70% of data (with hyperparameters previously optimized based on 15% of data as validation set) and evaluate on held-out test set of 15% of data" ] }, { "cell_type": "markdown", "id": "9027e51e-7830-4ab8-aebf-b9779b3ea2c1", "metadata": {}, "source": [ "### Fine-tune the model for cell state classification" ] }, { "cell_type": "code", "execution_count": 2, "id": "efe3b79b-aa8f-416c-9755-7f9299d6a81e", "metadata": {}, "outputs": [], "source": [ "import datetime\n", "from geneformer import Classifier\n", "\n", "current_date = datetime.datetime.now()\n", "datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.hour:02d}{current_date.minute:02d}{current_date.second:02d}\"\n", "datestamp_min = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n", "\n", "output_prefix = \"cm_classifier_test\"\n", "output_dir = f\"/path/to/output_dir/{datestamp}\"\n", "!mkdir $output_dir" ] }, { "cell_type": "code", "execution_count": 3, "id": "f070ab20-1b18-4941-a5c7-89e23b519261", "metadata": {}, "outputs": [], "source": [ 
"filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n", "training_args = {\n", " \"num_train_epochs\": 0.9,\n", " \"learning_rate\": 0.000804,\n", " \"lr_scheduler_type\": \"polynomial\",\n", " \"warmup_steps\": 1812,\n", " \"weight_decay\":0.258828,\n", " \"per_device_train_batch_size\": 12,\n", " \"seed\": 73,\n", "}\n", "cc = Classifier(classifier=\"cell\",\n", " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n", " filter_data=filter_data_dict,\n", " training_args=training_args,\n", " max_ncells=None,\n", " freeze_layers = 2,\n", " num_crossval_splits = 1,\n", " forward_batch_size=200,\n", " nproc=16)" ] }, { "cell_type": "code", "execution_count": 4, "id": "0bced2e8-0a49-418e-a7f9-3981be256bd6", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9c409ca656ed4cb0b280d95e326c1bc7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/3 shards): 0%| | 0/115367 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "facb7207b57948aebb3f8681346e17d4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/1 shards): 0%| | 0/17228 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# previously balanced splits with prepare_data and validate functions\n", "# argument attr_to_split set to \"individual\" and attr_to_balance set to [\"disease\",\"lvef\",\"age\",\"sex\",\"length\"]\n", "train_ids = [\"1447\", \"1600\", \"1462\", \"1558\", \"1300\", \"1508\", \"1358\", \"1678\", \"1561\", \"1304\", \"1610\", \"1430\", \"1472\", \"1707\", \"1726\", \"1504\", \"1425\", \"1617\", \"1631\", \"1735\", \"1582\", \"1722\", \"1622\", \"1630\", \"1290\", \"1479\", \"1371\", \"1549\", \"1515\"]\n", "eval_ids = [\"1422\", \"1510\", \"1539\", \"1606\", \"1702\"]\n", "test_ids = [\"1437\", \"1516\", \"1602\", \"1685\", \"1718\"]\n", "\n", "train_test_id_split_dict = {\"attr_key\": \"individual\",\n", " \"train\": train_ids+eval_ids,\n", " \"test\": test_ids}\n", "\n", "# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n", "cc.prepare_data(input_data_file=\"/path/to/human_dcm_hcm_nf_2048_w_length.dataset\",\n", " output_directory=output_dir,\n", " output_prefix=output_prefix,\n", " split_id_dict=train_test_id_split_dict)" ] }, { "cell_type": "code", "execution_count": 5, "id": "73fe8b29-dd8f-4bf8-82c1-53196d73ed49", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "691e875524e441bca22b790a0f4a2a35", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "****** Validation split: 1/1 ******\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c2c4f53aa71a49b89c32c8ba573b0b0c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Filter (num_proc=16): 0%| | 0/115367 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "adf76144219747558bf39b7e776a68b3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Filter (num_proc=16): 0%| | 0/115367 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /gladstone/theodoris/home/ctheodoris/Geneformer and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n", "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
<table>\n", "  <thead>\n", "    <tr><th>Epoch</th><th>Training Loss</th><th>Validation Loss</th><th>Accuracy</th><th>Macro F1</th></tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><td>0</td><td>0.142400</td><td>0.389166</td><td>0.889797</td><td>0.693074</td></tr>\n", "  </tbody>\n", "</table>\n"
],
"text/plain": [
"