---
license: cc-by-nc-nd-4.0
datasets:
- vandijklab/immune-c2s
language:
- en
tags:
- pytorch
- causal-lm
- scRNA-seq
---

# Overview

This is the Pythia-160m model developed by EleutherAI, fine-tuned using Cell2Sentence on *full* scRNA-seq cells.
Cell2Sentence is a novel method for adapting large language models to single-cell transcriptomics.
We transform single-cell RNA sequencing data into sequences of gene names ordered from highest to lowest expression level, termed "cell sentences".
For more details, we refer to the paper linked below.
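
To illustrate the idea of a cell sentence, here is a minimal sketch that ranks genes by expression and joins their names. It assumes the input is a plain mapping from gene name to expression value; the function name `expression_to_cell_sentence` and the `max_genes` parameter are hypothetical, and any normalization applied before ranking is left out. The paper and GitHub repository describe the actual preprocessing.

```python
from typing import Dict, Optional


def expression_to_cell_sentence(
    expression: Dict[str, float],
    max_genes: Optional[int] = None,
) -> str:
    """Order genes by descending expression and join their names with spaces."""
    # Keep only genes with nonzero expression.
    expressed = [(gene, value) for gene, value in expression.items() if value > 0]
    # Sort from highest to lowest expression; ties are broken alphabetically
    # so the output is deterministic (an assumption made for this sketch).
    expressed.sort(key=lambda item: (-item[1], item[0]))
    genes = [gene.upper() for gene, _ in expressed]
    if max_genes is not None:
        genes = genes[:max_genes]
    return " ".join(genes)


# Toy cell with four genes; GAPDH has zero expression and is dropped.
toy_cell = {"CD3E": 5.0, "MALAT1": 12.0, "GAPDH": 0.0, "IL7R": 2.5}
sentence = expression_to_cell_sentence(toy_cell)
print(sentence)  # MALAT1 CD3E IL7R
```

Truncating with `max_genes` gives a shorter sentence, which is useful when the full gene list would exceed the context length used for generation.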

This model was trained on the immune tissue dataset from [Domínguez et al.](https://www.science.org/doi/10.1126/science.abl5197) using 8 A100 40GB GPUs for approximately 20 hours on the following tasks:
1. conditional cell generation
2. unconditional cell generation
3. cell type prediction

## Cell2Sentence Links:
GitHub: <https://github.com/vandijklab/cell2sentence-ft>
Paper: <https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3>

## Pythia Links:
GitHub: <https://github.com/EleutherAI/pythia>
Paper: <https://arxiv.org/abs/2304.01373>
Hugging Face: <https://huggingface.co/EleutherAI/pythia-160m>

# Evaluation

This model was evaluated using k-nearest-neighbor (KNN) classification accuracy and Gromov-Wasserstein (GW) distance.
The label of a generated cell is the cell type specified in the prompt used to generate it.
Ground truth cells were sampled with replacement from a held-out test dataset.
The generated cells are converted to expression vectors using the method described in the paper.
For complete details on the experiments, we refer to the paper.
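
To make the KNN metric concrete, here is a minimal sketch. It assumes each generated cell is classified by majority vote among its k nearest ground truth cells (Euclidean distance) and scored against the cell type from its prompt. The direction of the comparison, the distance metric, the helper names, and the toy two-gene vectors are all assumptions for illustration; the paper defines the actual protocol.

```python
import math
from collections import Counter
from typing import List, Sequence


def knn_predict(
    query: Sequence[float],
    reference: List[Sequence[float]],
    reference_labels: List[str],
    k: int,
) -> str:
    """Majority label among the k reference cells closest to the query."""
    ranked = sorted(
        (math.dist(query, ref), label)
        for ref, label in zip(reference, reference_labels)
    )
    votes = Counter(label for _, label in ranked[:k])
    return votes.most_common(1)[0][0]


def knn_accuracy(generated, prompt_labels, reference, reference_labels, k):
    """Fraction of generated cells whose predicted label matches the prompt label."""
    hits = sum(
        knn_predict(cell, reference, reference_labels, k) == label
        for cell, label in zip(generated, prompt_labels)
    )
    return hits / len(generated)


# Toy ground truth cells (hypothetical 2-gene expression vectors).
real_cells = [[5.0, 0.0], [4.5, 0.5], [0.0, 5.0], [0.5, 4.0]]
real_labels = ["T cell", "T cell", "B cell", "B cell"]

# Toy generated cells, labelled with the cell type from their prompts.
generated_cells = [[4.8, 0.2], [0.2, 4.6], [4.0, 1.0]]
prompt_labels = ["T cell", "B cell", "B cell"]

accuracy = knn_accuracy(generated_cells, prompt_labels, real_cells, real_labels, k=3)
print(f"k=3 accuracy: {accuracy:.3f}")  # 2 of 3 generated cells match their prompt label
```

The third toy cell was prompted as a B cell but lies among the T cells, so it counts as an error. That is the kind of mismatch this metric is meant to surface.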

| Model | k=3 NN (↑) | k=5 NN (↑) | k=10 NN (↑) | k=25 NN (↑) | GW (↓) |
| :---- | :---: | :---: | :---: | :---: | :----: |
| scGEN | 0.2376 | 0.2330 | 0.2377 | 0.2335 | 315.9505 |
| scVI | 0.2436 | 0.2400 | 0.2425 | 0.2348 | 302.1285 |
| scDiffusion | 0.2335 | 0.2288 | 0.2368 | 0.2306 | 72.0208 |
| scGPT | 0.1838 | 0.1788 | 0.1811 | 0.1882 | 2989.8066 |
| **C2S (Pythia-160m)** | **0.2588** | **0.2565** | **0.2746** | **0.2715** | **54.3040** |
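
Metrics such as these are computed on expression vectors, so each generated cell sentence has to be mapped back from gene ranks to values. The sketch below assumes a linear relationship between log-rank and expression. The `slope` and `intercept` values are hypothetical placeholders that would normally be fitted on real data, and the function name is illustrative. The paper describes the actual conversion.

```python
import math
from typing import List


def cell_sentence_to_expression(
    genes: List[str],
    vocabulary: List[str],
    slope: float,
    intercept: float,
) -> List[float]:
    """Map ranked gene names to an expression vector aligned with `vocabulary`."""
    index = {gene: position for position, gene in enumerate(vocabulary)}
    vector = [0.0] * len(vocabulary)  # genes absent from the sentence stay at zero
    for rank, gene in enumerate(genes, start=1):
        if gene in index:
            # Assumed model: expression = intercept + slope * log10(rank),
            # clipped at zero so low-ranked genes never go negative.
            vector[index[gene]] = max(0.0, intercept + slope * math.log10(rank))
    return vector


vocabulary = ["CD3E", "GAPDH", "IL7R", "MALAT1"]
ranked_genes = ["MALAT1", "CD3E", "IL7R"]  # e.g. output of the post-processing step

# Hypothetical coefficients chosen only for this illustration.
vector = cell_sentence_to_expression(ranked_genes, vocabulary, slope=-1.0, intercept=3.0)
print(vector)
```

With a negative slope, the top-ranked gene receives the largest value, and genes that do not appear in the sentence keep an expression of zero.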

# Sample Code

The example below shows how to use the model to conditionally generate a cell. It includes a post-processing function that removes invalid genes and merges duplicated genes.
To generate full cells, the `max_length` generation parameter should be changed to 9200.
If full cell generation is required, we recommend using an A100 GPU for its inference speed and memory capacity.
Prompts for unconditional cell generation and cell type prediction are included as comments, but we do not include an example cell sentence to fill in the cell type prediction prompt.
We refer to the paper and GitHub repository for instructions on how to transform expression vectors into cell sentences.

```python
import json
import re
from collections import Counter
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def post_process_generated_cell_sentences(
    cell_sentence: str,
    gene_dictionary: List
):
    """
    Post-processing function for generated cell sentences.
    Invalid genes are removed and ranks of duplicated genes are averaged.

    Arguments:
        cell_sentence:    generated cell sentence string
        gene_dictionary:  list of gene vocabulary (all uppercase)

    Returns:
        post_processed_sentence:  list of gene names after post-processing steps
    """
    generated_gene_names = cell_sentence.split(" ")
    generated_gene_names = [generated_gene.upper() for generated_gene in generated_gene_names]

    #--- Remove nonsense genes ---#
    generated_gene_names = [gene_name for gene_name in generated_gene_names if gene_name in gene_dictionary]

    #--- Average ranks ---#
    gene_name_to_occurrences = Counter(generated_gene_names)  # get mapping of gene name --> number of occurrences
    post_processed_sentence = generated_gene_names.copy()  # copy of generated gene list

    for gene_name in gene_name_to_occurrences:
        # Invalid genes were already removed above, so only duplicates need handling
        if gene_name_to_occurrences[gene_name] > 1:
            # Find positions of all occurrences of duplicated generated gene in list
            # Note: using post_processed_sentence here; since duplicates are being removed, list will be
            #   getting shorter. Getting indices in original list will no longer be accurate positions
            occurrence_positions = [idx for idx, elem in enumerate(post_processed_sentence) if elem == gene_name]
            average_position = int(sum(occurrence_positions) / len(occurrence_positions))

            # Remove occurrences
            post_processed_sentence = [elem for elem in post_processed_sentence if elem != gene_name]

            # Reinsert gene_name at average position
            post_processed_sentence.insert(average_position, gene_name)

    return post_processed_sentence


vocab_path = "pbmc_vocab.json"

with open(vocab_path, "r") as f:
    gene_dictionary = json.load(f)

model_name = "vandijklab/pythia-160m-c2s"

# Note: attn_implementation="flash_attention_2" requires the flash-attn package
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2"
).to(torch.device("cuda"))
tokenizer = AutoTokenizer.from_pretrained(model_name)

cell_type = "T Cell"
ccg = f"Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."

# Prompts for other forms of generation.
# ucg = "Display a cell's genes by expression level, in descending order."
# cellsentence = "CELL_SENTENCE"
# ctp = "Identify the cell type most likely associated with these highly expressed genes listed in descending order. "
#     + cellsentence +
#     "Name the cell type connected to these genes, ranked from highest to lowest expression."

tokens = tokenizer(ccg, return_tensors='pt')
input_ids = tokens['input_ids'].to(torch.device("cuda"))
attention_mask = tokens['attention_mask'].to(torch.device("cuda"))

with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=1024,
        top_k=50,
        top_p=0.95,
    )

# Drop the prompt (everything up to the first "?", "." or ":") and keep the generated genes
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
cell_sentence = "".join(re.split(r"\?|\.|:", output_text)[1:]).strip()
processed_genes = post_process_generated_cell_sentences(cell_sentence, gene_dictionary)
```
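
The three prompt styles in the sample code can be wrapped in a small helper. This is a minimal sketch: the prompt strings are copied from the sample code, while the function name `build_prompt`, the task keys, and the `separator` placed between the cell sentence and the closing instruction are assumptions. The sample code concatenates the strings directly, so the exact spacing should be checked against the GitHub repository.

```python
from typing import Optional

CCG_TEMPLATE = "Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."
UCG_PROMPT = "Display a cell's genes by expression level, in descending order."
CTP_PREFIX = "Identify the cell type most likely associated with these highly expressed genes listed in descending order. "
CTP_SUFFIX = "Name the cell type connected to these genes, ranked from highest to lowest expression."


def build_prompt(
    task: str,
    cell_type: Optional[str] = None,
    cell_sentence: Optional[str] = None,
    separator: str = " ",  # assumption: not specified in the sample code
) -> str:
    """Return the prompt for one of the three tasks: 'ccg', 'ucg' or 'ctp'."""
    if task == "ccg":  # conditional cell generation
        if cell_type is None:
            raise ValueError("conditional cell generation needs a cell_type")
        return CCG_TEMPLATE.format(cell_type=cell_type)
    if task == "ucg":  # unconditional cell generation
        return UCG_PROMPT
    if task == "ctp":  # cell type prediction
        if cell_sentence is None:
            raise ValueError("cell type prediction needs a cell_sentence")
        return CTP_PREFIX + cell_sentence + separator + CTP_SUFFIX
    raise ValueError(f"unknown task: {task!r}")


ccg_prompt = build_prompt("ccg", cell_type="T Cell")
ctp_prompt = build_prompt("ctp", cell_sentence="MALAT1 CD3E IL7R")
print(ccg_prompt)
print(ctp_prompt)
```

The returned string can be passed to the tokenizer in place of `ccg` in the sample code. For cell type prediction, the generated continuation is the predicted label, so the gene post-processing step does not apply.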