Questions about the classification model

#215
by nwt - opened

Thank you for your great work!
I am following https://huggingface.co/ctheodoris/Geneformer/blob/main/examples/cell_classification.ipynb for the cell classification task, and I load the pretrained model with the following code, replacing the path with "Geneformer/" (the main directory of this repository):

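# reload pretrained model
from transformers import BertForSequenceClassification

# organ_label_dict (label -> class index) is built earlier in the example notebook
model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/",
                                                      num_labels=len(organ_label_dict.keys()),
                                                      output_attentions=False,
                                                      output_hidden_states=False).to("cuda")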

As far as I know, BertForSequenceClassification puts a linear layer on top of the pooling layer, and the pooling layer only uses the hidden state of the first token (passed through a dense layer with tanh activation), which in this setting is the highest-ranked gene. So the classification model defined this way doesn't use mean pooling of the gene embeddings. Is something wrong with my understanding?
Thanks!
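For reference, the head described in the question can be written out with plain arrays. The following is a minimal NumPy sketch, not the Transformers source: the function names and random weights are hypothetical, weights are stored as (in, out) for simplicity, and dropout is omitted because it is inactive at inference time. The pooler reads only position 0, applies a dense layer and tanh, and a linear layer maps the result to one logit per label.

```python
import numpy as np

def first_token_pool(hidden_states, w_pool, b_pool):
    """BERT-style pooler sketch: dense + tanh on the hidden state at position 0.

    hidden_states: array of shape (batch, seq_len, hidden)
    """
    first_token = hidden_states[:, 0, :]            # (batch, hidden)
    return np.tanh(first_token @ w_pool + b_pool)   # (batch, hidden)

def classify(pooled, w_cls, b_cls):
    """Linear classifier producing one logit per label."""
    return pooled @ w_cls + b_cls                   # (batch, num_labels)

rng = np.random.default_rng(0)
batch, seq_len, hidden, num_labels = 2, 5, 8, 3
hidden_states = rng.normal(size=(batch, seq_len, hidden))
w_pool, b_pool = rng.normal(size=(hidden, hidden)), np.zeros(hidden)
w_cls, b_cls = rng.normal(size=(hidden, num_labels)), np.zeros(num_labels)

logits = classify(first_token_pool(hidden_states, w_pool, b_pool), w_cls, b_cls)
print(logits.shape)  # (2, 3)

# With the final hidden states held fixed, changing every position except
# the first leaves the logits unchanged: the head reads position 0 only.
perturbed = hidden_states.copy()
perturbed[:, 1:, :] += 10.0
same = classify(first_token_pool(perturbed, w_pool, b_pool), w_cls, b_cls)
print(np.allclose(logits, same))  # True
```

Note that in the full model the first token's final hidden state is itself computed by self-attention over all genes in the cell, so the other genes still influence the prediction indirectly; the sketch only shows that the classification head reads a single position rather than an average.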

Thank you for your interest in Geneformer! Please refer to the Hugging Face Transformers code for how sequence (cell) classification works in BertForSequenceClassification, both for training the classifier and for using it to predict classes in new data.

To extract/plot cell embeddings and to perform in silico perturbation with the tools in this repository, we use mean pooling of the embedding layer specified by the user, averaging the embeddings of all the genes in the cell to generate the cell embedding. We use the same process whether the model is the pretrained one or a model fine-tuned for gene or cell classification. If you prefer, you can use a different method to train the cell classifier and then load the fine-tuned model to extract/plot embeddings or perform in silico perturbation with the tools provided in this repository.
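The mean pooling described in the reply can be sketched the same way. This is a minimal NumPy illustration, not the repository's implementation: it assumes a batch of per-gene embeddings already taken from the chosen layer, plus a hypothetical 0/1 attention mask marking real (non-padding) genes, and it averages only the real genes so that padding does not dilute cells with fewer genes.

```python
import numpy as np

def mean_pool(gene_embs, attention_mask):
    """Average gene embeddings per cell, ignoring padding positions.

    gene_embs: (n_cells, seq_len, hidden) embeddings from one chosen layer
    attention_mask: (n_cells, seq_len), 1 for real genes, 0 for padding
    """
    mask = attention_mask[:, :, None].astype(gene_embs.dtype)  # (n_cells, seq_len, 1)
    summed = (gene_embs * mask).sum(axis=1)                     # (n_cells, hidden)
    counts = np.clip(mask.sum(axis=1), 1, None)                 # avoid division by zero
    return summed / counts                                      # (n_cells, hidden)

# Two toy cells with 2-dimensional gene embeddings; the second cell has
# only two real genes followed by one padding position.
gene_embs = np.array([
    [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]],
    [[2.0, 2.0], [4.0, 6.0], [9.0, 9.0]],  # last row is padding
])
attention_mask = np.array([[1, 1, 1], [1, 1, 0]])

cell_embs = mean_pool(gene_embs, attention_mask)
print(cell_embs)  # cell 1 -> [3, 2], cell 2 -> [3, 4]
```

Unlike the first-token head, every real gene contributes equally to the cell embedding here. Embeddings pooled this way could also serve as input to a separately trained classifier, which is one form the "different method" mentioned in the reply could take.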

ctheodoris changed discussion status to closed
