---
license: cc-by-nc-nd-4.0
datasets:
- vandijklab/immune-c2s
language:
- en
tags:
- pytorch
- causal-lm
- scRNA-seq
---

## Overview

This is the Pythia-160m model developed by EleutherAI, fine-tuned using Cell2Sentence on full scRNA-seq cells. Cell2Sentence is a novel method for adapting large language models to single-cell transcriptomics: we transform single-cell RNA sequencing data into sequences of gene names ordered by expression level, termed "cell sentences". For more details, we refer to the paper linked below. This model was trained on the immune tissue dataset from Domínguez et al. using 8 A100 40GB GPUs for approximately 20 hours on the following tasks:
- conditional cell generation
- unconditional cell generation
- cell type prediction
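
To make the "cell sentence" idea from the overview concrete, here is a minimal sketch of the forward transformation: keep the genes with nonzero expression and list their names from highest to lowest expression. The function name `expression_to_cell_sentence`, the dictionary input format, and the alphabetical tie-breaking are assumptions of this sketch; any normalization or filtering applied in the paper and the GitHub repository is not reproduced here.

```python
from typing import Dict


def expression_to_cell_sentence(expression: Dict[str, float]) -> str:
    """Hypothetical helper: names of expressed genes, ordered by expression (highest first)."""
    # Genes with zero expression do not appear in the cell sentence.
    expressed = [(gene, value) for gene, value in expression.items() if value > 0]
    # Sort by decreasing expression; ties are broken alphabetically so the
    # output is deterministic (this tie-breaking rule is an assumption).
    expressed.sort(key=lambda item: (-item[1], item[0]))
    return " ".join(gene for gene, _ in expressed)


cell = {"CD3E": 5.0, "MS4A1": 0.0, "IL7R": 3.0, "CD8A": 2.0}
print(expression_to_cell_sentence(cell))  # CD3E IL7R CD8A
```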

Cell2Sentence Links:
- GitHub: https://github.com/vandijklab/cell2sentence-ft
- Paper: https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3

Pythia Links:
- GitHub: https://github.com/EleutherAI/pythia
- Paper: https://arxiv.org/abs/2304.01373
- Hugging Face: https://huggingface.co/EleutherAI/pythia-160m

## Evaluation

This model was evaluated using k-nearest-neighbor (KNN) classification and Gromov-Wasserstein (GW) distance. Each generated cell is labeled with the cell type used in the prompt that generated it. Ground-truth cells were sampled with replacement from a held-out test dataset. Generated cells are converted to expression vectors using the method described in the paper. Higher is better for the KNN scores (↑) and lower is better for the GW distance (↓). For complete details on the experiments, we refer to the paper.

| Model | k=3 NN (↑) | k=5 NN (↑) | k=10 NN (↑) | k=25 NN (↑) | GW (↓) |
|---|---|---|---|---|---|
| scGEN | 0.2376 | 0.2330 | 0.2377 | 0.2335 | 315.9505 |
| scVI | 0.2436 | 0.2400 | 0.2425 | 0.2348 | 302.1285 |
| scDiffusion | 0.2335 | 0.2288 | 0.2368 | 0.2306 | 72.0208 |
| scGPT | 0.1838 | 0.1788 | 0.1811 | 0.1882 | 2989.8066 |
| C2S (Pythia-160m) | 0.2588 | 0.2565 | 0.2746 | 0.2715 | 54.3040 |
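
As a rough sketch of the KNN part of this evaluation, the code below (i) fits a linear map from log rank to expression on real cells, (ii) uses it to turn cell sentences into expression vectors over a fixed gene order, and (iii) computes k-NN classification accuracy. The linear log-rank model is our reading of the conversion method described in the paper, and the helper names (`fit_log_rank_model`, `sentence_to_vector`, `knn_accuracy`), the clipping of negative values to zero, and the choice of which set acts as the labeled reference versus the query are assumptions of this sketch. The GW distance is not sketched; see the paper for the exact protocol.

```python
import math
from collections import Counter
from typing import Dict, List, Sequence, Tuple


def fit_log_rank_model(cells: List[Dict[str, float]]) -> Tuple[float, float]:
    """Least-squares fit of expression = slope * log10(rank) + intercept (assumed model)."""
    xs: List[float] = []
    ys: List[float] = []
    for cell in cells:
        # Rank nonzero values from highest (rank 1) to lowest.
        values = sorted((v for v in cell.values() if v > 0), reverse=True)
        for rank, value in enumerate(values, start=1):
            xs.append(math.log10(rank))
            ys.append(value)
    if len(set(xs)) < 2:
        raise ValueError("need at least one cell with two or more expressed genes")
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    var_x = sum((x - mean_x) ** 2 for x in xs)
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = cov_xy / var_x
    return slope, mean_y - slope * mean_x


def sentence_to_vector(sentence: str, genes: Sequence[str], slope: float, intercept: float) -> List[float]:
    """Map a cell sentence to an expression vector over a fixed gene order."""
    predicted: Dict[str, float] = {}
    for rank, gene in enumerate(sentence.split(), start=1):
        # Keep the first occurrence of a repeated gene; clip negatives to zero (assumption).
        predicted.setdefault(gene, max(0.0, slope * math.log10(rank) + intercept))
    # Genes absent from the sentence are treated as unexpressed.
    return [predicted.get(gene, 0.0) for gene in genes]


def knn_accuracy(ref_X, ref_y, query_X, query_y, k: int) -> float:
    """Fraction of query vectors whose k-NN majority label (Euclidean) matches their label."""
    correct = 0
    for x, true_label in zip(query_X, query_y):
        neighbors = sorted((math.dist(x, r), label) for r, label in zip(ref_X, ref_y))
        votes = Counter(label for _, label in neighbors[:k])
        # Among tied vote counts, most_common returns the label seen first (nearest neighbor).
        predicted = votes.most_common(1)[0][0]
        correct += predicted == true_label
    return correct / len(query_y)
```

Applied to generated cells labeled by their prompt cell type and to held-out ground-truth cells, this produces a score comparable in spirit to the KNN columns above, though it is not guaranteed to match the paper's procedure.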

## Sample Code

Below we provide an example of how to use the model for conditional cell generation, together with a post-processing function that removes invalid genes and merges duplicated genes by averaging their ranks.
The example limits generation to `max_length=1024` tokens; to generate full cells, increase the `max_length` generation parameter to 9200.
If full cell generation is required, we recommend an A100 GPU for inference speed and memory capacity.
Prompts for unconditional cell generation and cell type prediction are included as comments, but we do not include an example cell sentence to fill in the cell type prediction prompt.
We refer to the paper and GitHub repository for instructions on how to transform expression vectors into cell sentences.

```python
import json
import re
from collections import Counter
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def post_process_generated_cell_sentences(
    cell_sentence: str,
    gene_dictionary: List[str],
) -> List[str]:
    """
    Post-processing function for generated cell sentences.
    Invalid genes are removed and ranks of duplicated genes are averaged.

    Arguments:
        cell_sentence: generated cell sentence string
        gene_dictionary: list of gene vocabulary (all uppercase)
    Returns:
        post_processed_sentence: list of gene names after post-processing
    """
    generated_gene_names = cell_sentence.split(" ")
    generated_gene_names = [generated_gene.upper() for generated_gene in generated_gene_names]

    #--- Remove nonsense genes ---#
    valid_genes = set(gene_dictionary)  # set gives fast membership checks
    generated_gene_names = [gene_name for gene_name in generated_gene_names if gene_name in valid_genes]

    #--- Average ranks ---#
    gene_name_to_occurrences = Counter(generated_gene_names)  # gene name --> number of occurrences
    post_processed_sentence = generated_gene_names.copy()  # copy of generated gene list
    for gene_name in gene_name_to_occurrences:
        if gene_name_to_occurrences[gene_name] > 1:
            # Find positions of all occurrences of the duplicated gene.
            # Note: positions are taken from post_processed_sentence, because removing duplicates
            # shortens the list and indices in the original list would no longer be accurate.
            occurrence_positions = [idx for idx, elem in enumerate(post_processed_sentence) if elem == gene_name]
            average_position = int(sum(occurrence_positions) / len(occurrence_positions))

            # Remove all occurrences
            post_processed_sentence = [elem for elem in post_processed_sentence if elem != gene_name]

            # Reinsert gene_name once, at the average position
            post_processed_sentence.insert(average_position, gene_name)

    return post_processed_sentence


vocab_path = "pbmc_vocab.json"  # gene vocabulary (uppercase gene names)
with open(vocab_path, "r") as f:
    gene_dictionary = json.load(f)

model_name = "vandijklab/pythia-160m-c2s"
device = torch.device("cuda")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # requires flash-attn; remove to use the default attention
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Conditional cell generation (ccg) prompt
cell_type = "T Cell"
ccg = f"Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."

# Prompts for other forms of generation.
# Unconditional cell generation (ucg):
# ucg = "Display a cell's genes by expression level, in descending order."
# Cell type prediction (ctp); replace CELL_SENTENCE with a real cell sentence:
# cellsentence = "CELL_SENTENCE"
# ctp = (
#     "Identify the cell type most likely associated with these highly expressed genes listed in descending order. "
#     + cellsentence
#     + "Name the cell type connected to these genes, ranked from highest to lowest expression."
# )

tokens = tokenizer(ccg, return_tensors="pt")
input_ids = tokens["input_ids"].to(device)
attention_mask = tokens["attention_mask"].to(device)

with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=1024,  # increase to 9200 to generate full cells
        top_k=50,
        top_p=0.95,
    )

output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# The prompt ends with a period, so everything after the first "?", "." or ":" is the generated cell sentence
cell_sentence = "".join(re.split(r"\?|\.|:", output_text)[1:]).strip()
processed_genes = post_process_generated_cell_sentences(cell_sentence, gene_dictionary)
```
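
The three prompts in the sample code can be collected into one small helper, which also shows how a cell sentence is slotted into the cell type prediction prompt. This is a minimal sketch: `build_prompt` and the task keys `"ccg"`, `"ucg"` and `"ctp"` are hypothetical names, while the prompt strings are copied verbatim from the sample code. The `ctp` prompt concatenates the cell sentence exactly as the commented-out `ctp` line does, with no separator added before the final sentence; whether the training template used a space or other separator there is not stated in this card, so check the GitHub repository before relying on it.

```python
from typing import Optional

# Prompt strings copied verbatim from the sample code.
CCG_TEMPLATE = "Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."
UCG_PROMPT = "Display a cell's genes by expression level, in descending order."
CTP_PREFIX = (
    "Identify the cell type most likely associated with these highly expressed genes "
    "listed in descending order. "
)
CTP_SUFFIX = "Name the cell type connected to these genes, ranked from highest to lowest expression."


def build_prompt(task: str, cell_type: Optional[str] = None, cell_sentence: Optional[str] = None) -> str:
    """Hypothetical helper returning the ccg, ucg or ctp prompt."""
    if task == "ccg":  # conditional cell generation
        if cell_type is None:
            raise ValueError("ccg requires cell_type")
        return CCG_TEMPLATE.format(cell_type=cell_type)
    if task == "ucg":  # unconditional cell generation
        return UCG_PROMPT
    if task == "ctp":  # cell type prediction
        if cell_sentence is None:
            raise ValueError("ctp requires cell_sentence")
        # Same concatenation as the commented-out ctp prompt (no separator added).
        return CTP_PREFIX + cell_sentence + CTP_SUFFIX
    raise ValueError(f"unknown task: {task!r}")
```

For example, `build_prompt("ccg", cell_type="T Cell")` returns the same string as `ccg` in the sample code, and the result can be passed to `tokenizer` unchanged.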