Question about the CLS token in cell type annotation task
Congrats on the great work! I have a question regarding the cell type annotation task. I see from your code that the model has only two special tokens, PAD and MASK, for MLM pretraining. But for the downstream cell type annotation task, the classifier in BertForSequenceClassification takes the first token of the hidden states. Since there is no CLS token, do you actually use the first gene's embedding to represent the cell for the cell type annotation task?
Thank you for your interest in Geneformer! If you use the Huggingface code for sequence (cell) classification, the first token's embedding is used as a representation of the cell. Because these are contextual embeddings, the embedding of the first token varies with the context of the remaining genes in each cell. (This is also the concept of the CLS token.) The first token in this case carries the additional information of being a particular gene. If you'd like to use a CLS token instead, you can add it to the model dictionary, tokenize the data with the CLS token prepended to each encoding, and fine-tune the model with those encodings. Alternatively, you can modify the Huggingface code to use mean-pooling as its pooling layer if you prefer.
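The two steps above (extending the dictionary and prepending the token) can be sketched roughly as follows. This is a minimal illustration, not the actual Geneformer tokenizer: the token names, ids, and the 2048-token length cap are assumptions for the example.

```python
# Hedged sketch: adding a hypothetical "<cls>" entry to a Geneformer-style
# token dictionary and prepending its id to each cell's encoding.
# Ids and dictionary layout here are illustrative, not the real vocabulary.

def add_cls_token(token_dictionary: dict) -> int:
    """Append a new <cls> entry after the existing vocabulary and return its id."""
    cls_id = max(token_dictionary.values()) + 1
    token_dictionary["<cls>"] = cls_id
    return cls_id

def prepend_cls(encoding: list, cls_id: int, max_len: int = 2048) -> list:
    """Put the CLS id first, truncating so the length stays within max_len."""
    return [cls_id] + encoding[: max_len - 1]

# usage with a toy vocabulary (PAD and MASK are the two existing specials)
vocab = {"<pad>": 0, "<mask>": 1, "GENE_A": 2, "GENE_B": 3}
cls_id = add_cls_token(vocab)
print(prepend_cls([2, 3], cls_id))  # [4, 2, 3]
```

After retokenizing this way, the model would need fine-tuning so the new token learns a useful cell-level representation.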
Of note, extracting/plotting cell embeddings with our emb_extractor function uses mean-pooling by default, as discussed in closed discussion #215. Additionally, it is important to note that, as discussed in closed discussion #69, while the CLS token could have been used as the cell embedding for cell classification, summarizing the embedding in a single token presents an issue for the in silico perturbation strategy. In that application, we derive the cell embedding shift in response to perturbation by comparing the embeddings of all genes other than the perturbed gene, so that we quantify the perturbation’s effect on its context. A CLS token could not accomplish this, since the contributions of the individual genes would be inseparable.
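For intuition, mean-pooling a cell embedding from per-gene hidden states can be sketched as below. This is a toy illustration of the pooling idea only, not the emb_extractor implementation; the toy vectors and the attention-mask convention (1 = real gene token, 0 = padding) are assumptions for the example.

```python
# Hedged sketch: average per-gene contextual embeddings over non-padded
# positions to obtain a single cell embedding (mean-pooling).

def mean_pool(hidden_states, attention_mask):
    """Average token embeddings over positions where the mask is 1.

    hidden_states: list of per-token embedding vectors (lists of floats)
    attention_mask: list of 0/1 flags, 1 = real gene token
    """
    kept = [h for h, m in zip(hidden_states, attention_mask) if m == 1]
    dim = len(kept[0])
    n = len(kept)
    return [sum(vec[d] for vec in kept) / n for d in range(dim)]

# usage: two real gene tokens plus one padded position
hs = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 0]
print(mean_pool(hs, mask))  # [2.0, 3.0]
```

Because each gene keeps its own embedding before pooling, one can also exclude a perturbed gene from the average, which is what makes the perturbation comparison described above possible.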
Thank you for your kind reply; I understand now, and it makes sense! Another question: I saw in your config file that you used absolute position embeddings, but there is no position-related description in the paper. Did you use a learnable positional embedding for the top 2048 genes obtained by rank value encoding?
Thank you for your question! In this case, the rank value encodings define position based on the relative rank of the genes. We used the Huggingface implementation, as in this repository, with regard to presenting the encodings to the model.
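To make the rank-based positions concrete, here is a minimal sketch of the ordering step: genes are sorted by descending expression, and a gene's index in the resulting sequence is the position the learned absolute position embedding then encodes. The gene names, values, and the 2048 cap are illustrative assumptions; the actual pipeline also normalizes expression (e.g. by gene medians) before ranking, which this sketch omits.

```python
# Hedged sketch: rank value encoding orders genes by descending expression,
# so position in the sequence reflects relative rank, not genomic location.

def rank_value_encode(expression: dict, max_len: int = 2048) -> list:
    """Return gene names sorted by descending expression, truncated to max_len."""
    ranked = sorted(expression, key=expression.get, reverse=True)
    return ranked[:max_len]

# usage with toy (already normalized) expression values
expr = {"GENE_A": 0.2, "GENE_B": 5.1, "GENE_C": 1.7}
print(rank_value_encode(expr))  # ['GENE_B', 'GENE_C', 'GENE_A']
```

The position embedding table itself is the standard learnable `position_embeddings` of the Huggingface BERT implementation; only the meaning of "position" changes, from word order to expression rank.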