Questions about the classification model
Thank you for your great work!
I am following https://huggingface.co/ctheodoris/Geneformer/blob/main/examples/cell_classification.ipynb for the cell classification task, and I load the pretrained model with the following code, where the path is replaced by "Geneformer/" (the main directory of this repository):
from transformers import BertForSequenceClassification

# reload pretrained model (organ_label_dict comes from the example notebook)
model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/",
                                                      num_labels=len(organ_label_dict.keys()),
                                                      output_attentions=False,
                                                      output_hidden_states=False).to("cuda")
As far as I know, BertForSequenceClassification uses a linear layer on top of the pooling layer, and the pooling layer simply takes the first token, which in this setting is the highest-ranked gene. So a classification model defined this way doesn't use mean pooling of the gene embeddings. Is there something wrong with my understanding?
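To make the question concrete: in the Hugging Face transformers implementation, BertForSequenceClassification feeds the pooler output (the final hidden state of the first token, passed through a dense layer and tanh) through dropout and a linear classifier. The NumPy sketch below mimics only that pooling step, with random weights standing in for learned ones and dropout omitted as in evaluation mode; the function name bert_style_pool is hypothetical. It shows that changing every gene except the first leaves the pooled vector unchanged.

```python
import numpy as np

def bert_style_pool(hidden, W, b):
    """Mimic BERT-style pooling: take the first token's hidden state for
    each sequence, apply a dense layer (y = x W^T + b), then tanh."""
    first_token = hidden[:, 0, :]          # (batch, hidden)
    return np.tanh(first_token @ W.T + b)  # (batch, hidden)

rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 4, 3))   # 2 cells, 4 genes each, hidden size 3
W = rng.normal(size=(3, 3))           # stand-in for the learned dense weights
b = np.zeros(3)

pooled = bert_style_pool(hidden, W, b)

perturbed = hidden.copy()
perturbed[:, 1:, :] += 10.0           # change every gene except the first
pooled_perturbed = bert_style_pool(perturbed, W, b)
```

Because only position 0 enters the pooler, pooled and pooled_perturbed are identical; a classifier on top of this pooler sees the other genes only indirectly, through self-attention into the first token's hidden state.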
Thanks!
Thank you for your interest in Geneformer! Please refer to the Hugging Face transformers code for how sequence (cell) classification works with their function, both for training the classifier and for using it to predict classes in new data. To extract or plot cell embeddings and to perform in silico perturbation with the tools in this repository, we use mean pooling of the embedding layer specified by the user, averaging the embeddings of all the genes in the cell to generate the cell embedding. We use the same process whether the model is the pretrained one or a model fine-tuned for gene or cell classification. If you prefer, you can use a different method to train the cell classifier and then load the fine-tuned model to extract or plot embeddings or to perform in silico perturbation with the tools provided in this repository.
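For contrast with first-token pooling, here is a minimal NumPy sketch of mean pooling over a user-selected layer. It assumes hidden states shaped like those a transformers model returns with output_hidden_states=True (one (batch, seq_len, hidden) array per layer) and an attention mask marking padded positions, which are excluded from the average. This is not the repository's own implementation; mean_pool_layer and its layer parameter are hypothetical names.

```python
import numpy as np

def mean_pool_layer(hidden_states, attention_mask, layer=-1):
    """Average the embeddings of all non-padding genes in each cell.

    hidden_states: sequence of arrays, one per layer, each (batch, seq_len, hidden)
    attention_mask: (batch, seq_len), 1 for real genes and 0 for padding
    layer: index of the layer whose embeddings are pooled (hypothetical parameter)
    """
    emb = np.asarray(hidden_states[layer], dtype=float)
    mask = np.asarray(attention_mask, dtype=float)[:, :, None]  # (batch, seq_len, 1)
    summed = (emb * mask).sum(axis=1)                           # sum over real genes
    counts = np.clip(mask.sum(axis=1), 1e-9, None)              # genes per cell
    return summed / counts                                      # (batch, hidden)

# Toy example: 2 layers, 2 cells, up to 3 genes, hidden size 2.
layer0 = np.zeros((2, 3, 2))
layer1 = np.array([[[1, 2], [3, 4], [5, 6]],
                   [[1, 1], [3, 3], [100, 100]]])  # last gene of cell 2 is padding
attention_mask = np.array([[1, 1, 1],
                           [1, 1, 0]])

cell_emb = mean_pool_layer((layer0, layer1), attention_mask, layer=-1)
```

Here cell_emb is [[3, 4], [2, 2]]: every real gene contributes equally, and the padded position (value 100) is ignored.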