---
license: cc-by-nc-nd-4.0
datasets:
- vandijklab/immune-c2s
language:
- en
tags:
- pytorch
- causal-lm
- scRNA-seq
---
# Overview
This is the Pythia-160m model developed by EleutherAI, fine-tuned with Cell2Sentence on *full* scRNA-seq cells.
Cell2Sentence is a novel method for adapting large language models to single-cell transcriptomics.
We transform single-cell RNA sequencing data into sequences of gene names ordered by expression level, termed "cell sentences".
For more details, we refer to the paper linked below.
This model was trained on the immune tissue dataset from [Domínguez et al.](https://www.science.org/doi/10.1126/science.abl5197)
using 8 A100 40GB GPUs for approximately 20 hours on the following tasks:
1. conditional cell generation
2. unconditional cell generation
3. cell type prediction
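To make the "cell sentence" idea concrete, here is a minimal sketch of the transformation for a single cell. It assumes the cell's expression values and the matching gene names are given as plain sequences: genes with zero expression are dropped and the remaining genes are sorted from highest to lowest expression. The function name is hypothetical, and ties are broken by original gene order, which may differ from the official Cell2Sentence implementation.

```python
import numpy as np


def expression_to_cell_sentence(expression, gene_names):
    """Hypothetical helper: list genes with nonzero expression, highest first."""
    expression = np.asarray(expression, dtype=float)
    gene_names = np.asarray(gene_names)
    nonzero = expression > 0
    # Stable sort on negated values: descending expression, ties keep input order
    order = np.argsort(-expression[nonzero], kind="stable")
    return " ".join(gene_names[nonzero][order])


genes = ["CD3E", "MS4A1", "CD8A", "NKG7"]
counts = [5.0, 0.0, 2.0, 7.0]
print(expression_to_cell_sentence(counts, genes))  # NKG7 CD3E CD8A
```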
## Cell2Sentence Links:
GitHub: <https://github.com/vandijklab/cell2sentence-ft>

Paper: <https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3>
## Pythia Links:
GitHub: <https://github.com/EleutherAI/pythia>

Paper: <https://arxiv.org/abs/2304.01373>

Hugging Face: <https://huggingface.co/EleutherAI/pythia-160m>
# Evaluation
This model was evaluated with k-nearest-neighbor (KNN) classification and Gromov-Wasserstein (GW) distance.
Each generated cell is labeled with the cell type specified in the prompt used to generate it.
Ground-truth cells were sampled with replacement from a held-out test dataset.
Generated cells are converted back to expression vectors using the method described in the paper.
For complete experimental details, see the paper.
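The exact procedure for turning generated cell sentences back into expression vectors is given in the paper. As we understand it, it relies on a linear relationship between a gene's expression and the logarithm of its rank; the sketch below assumes that relationship, fits it by least squares on hypothetical reference data, and clips negative predictions to zero (also an assumption). The function names are hypothetical and not part of the Cell2Sentence codebase.

```python
import numpy as np


def fit_rank_model(ranks, expression):
    """Least-squares fit of expression ≈ slope * log10(rank) + intercept."""
    slope, intercept = np.polyfit(np.log10(np.asarray(ranks, dtype=float)), expression, deg=1)
    return slope, intercept


def cell_sentence_to_expression(genes, vocabulary, slope, intercept):
    """Map a ranked gene list (rank 1 = most expressed) to a dense vector over `vocabulary`."""
    index = {gene: i for i, gene in enumerate(vocabulary)}
    vector = np.zeros(len(vocabulary))
    # Assumes duplicates were already merged by post-processing; unknown genes are skipped
    for rank, gene in enumerate(genes, start=1):
        if gene in index:
            # Clip at zero so low-ranked genes never get negative expression (assumption)
            vector[index[gene]] = max(slope * np.log10(rank) + intercept, 0.0)
    return vector


# Hypothetical reference data: ranks and expression values observed in real cells
slope, intercept = fit_rank_model([1, 2, 3, 4], [3.0, 2.4, 2.05, 1.8])
print(cell_sentence_to_expression(["CD3E", "CD8A"], ["CD8A", "CD3E", "MS4A1"], slope, intercept))
```

Genes that never appear in the generated sentence keep an expression of zero, matching the convention that cell sentences list only nonzero genes.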
| Model | k=3 NN (↑) | k=5 NN (↑) | k=10 NN (↑) | k=25 NN (↑) | GW (↓) |
| :---- | :---: | :---: | :---: | :---: | :----: |
| scGEN | 0.2376 | 0.2330 | 0.2377 | 0.2335 | 315.9505 |
| scVI | 0.2436 | 0.2400 | 0.2425 | 0.2348 | 302.1285 |
| scDiffusion | 0.2335 | 0.2288 | 0.2368 | 0.2306 | 72.0208 |
| scGPT | 0.1838 | 0.1788 | 0.1811 | 0.1882 | 2989.8066 |
| **C2S (Pythia-160m)** | **0.2588** | **0.2565** | **0.2746** | **0.2715** | **54.3040** |
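One plausible reading of the k-NN columns is classification accuracy: how often the majority label among a generated cell's k nearest ground-truth cells matches the cell type in its prompt. The sketch below assumes exactly that, with Euclidean distance and the classifier fit on real (ground-truth) expression vectors; the paper's protocol (distance metric, preprocessing, which set is used for fitting) may differ. The function name and toy data are hypothetical.

```python
from collections import Counter

import numpy as np


def knn_label_accuracy(real_X, real_labels, gen_X, gen_labels, k=3):
    """Fraction of generated cells whose k nearest real cells vote for the prompt cell type."""
    real_X = np.asarray(real_X, dtype=float)
    gen_X = np.asarray(gen_X, dtype=float)
    # Euclidean distances between every generated and every real cell: (n_gen, n_real)
    dists = np.linalg.norm(gen_X[:, None, :] - real_X[None, :, :], axis=-1)
    # Indices of the k closest real cells for each generated cell
    neighbors = np.argsort(dists, axis=1)[:, :k]
    correct = 0
    for neighbor_idx, prompt_label in zip(neighbors, gen_labels):
        # Majority vote; ties go to the label encountered first among the neighbors
        votes = Counter(real_labels[j] for j in neighbor_idx)
        predicted = votes.most_common(1)[0][0]
        correct += int(predicted == prompt_label)
    return correct / len(gen_labels)


real_X = [[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]]
real_labels = ["T cell", "T cell", "T cell", "B cell", "B cell", "B cell"]
gen_X = [[0.5, 0.5], [10.5, 10.5]]
print(knn_label_accuracy(real_X, real_labels, gen_X, ["T cell", "T cell"], k=3))  # 0.5
```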
# Sample Code
Below is an example of conditional cell generation, together with a post-processing function that removes invalid genes and merges duplicated genes by averaging their ranks.
To generate full cells, increase the `max_length` generation parameter to 9200.
Because full cell sentences are long, we recommend an A100 GPU for inference speed and memory capacity when full cell generation is required.
Prompts for unconditional cell generation and cell type prediction are included as comments, but we do not provide an example cell sentence to fill in the cell type prediction prompt.
See the paper and GitHub repository for instructions on transforming expression vectors into cell sentences.
```python
import json
import re
from collections import Counter
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def post_process_generated_cell_sentences(
    cell_sentence: str,
    gene_dictionary: List[str],
) -> List[str]:
    """
    Post-processing function for generated cell sentences.
    Invalid genes are removed and ranks of duplicated genes are averaged.

    Arguments:
        cell_sentence: generated cell sentence string
        gene_dictionary: list of gene vocabulary (all uppercase)
    Returns:
        post_processed_sentence: list of gene names after post-processing
    """
    generated_gene_names = cell_sentence.split(" ")
    generated_gene_names = [gene.upper() for gene in generated_gene_names]

    #--- Remove nonsense genes ---#
    valid_genes = set(gene_dictionary)  # set for fast membership checks
    generated_gene_names = [gene for gene in generated_gene_names if gene in valid_genes]

    #--- Average ranks ---#
    gene_name_to_occurrences = Counter(generated_gene_names)  # gene name --> number of occurrences
    post_processed_sentence = generated_gene_names.copy()  # copy of generated gene list
    for gene_name, count in gene_name_to_occurrences.items():
        if count > 1:
            # Look up positions in post_processed_sentence, not the original list:
            # earlier iterations already collapsed other duplicates, so the list has shrunk
            occurrence_positions = [idx for idx, elem in enumerate(post_processed_sentence) if elem == gene_name]
            average_position = int(sum(occurrence_positions) / len(occurrence_positions))
            # Remove all occurrences, then reinsert gene_name once at the average position
            post_processed_sentence = [elem for elem in post_processed_sentence if elem != gene_name]
            post_processed_sentence.insert(average_position, gene_name)
    return post_processed_sentence


vocab_path = "pbmc_vocab.json"  # JSON list of valid (uppercase) gene names
with open(vocab_path, "r") as f:
    gene_dictionary = json.load(f)

model_name = "vandijklab/pythia-160m-c2s"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package; remove to use default attention
).to(torch.device("cuda"))
tokenizer = AutoTokenizer.from_pretrained(model_name)

cell_type = "T Cell"
ccg = f"Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."

# Prompts for other forms of generation:
# ucg = "Display a cell's genes by expression level, in descending order."
# cellsentence = "CELL_SENTENCE"
# ctp = ("Identify the cell type most likely associated with these highly expressed genes listed in descending order. "
#        + cellsentence
#        + "Name the cell type connected to these genes, ranked from highest to lowest expression.")

tokens = tokenizer(ccg, return_tensors="pt")
input_ids = tokens["input_ids"].to(torch.device("cuda"))
attention_mask = tokens["attention_mask"].to(torch.device("cuda"))

with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=1024,  # increase to 9200 for full cell generation
        top_k=50,
        top_p=0.95,
    )

output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Split on "?", "." and ":" and drop the first piece (the prompt);
# note these characters are also removed from the generated text
cell_sentence = "".join(re.split(r"\?|\.|:", output_text)[1:]).strip()
processed_genes = post_process_generated_cell_sentences(cell_sentence, gene_dictionary)
```
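Once a ranked gene list is available (for example the output of `post_process_generated_cell_sentences`, or a cell sentence built from an expression vector), the commented-out cell type prediction prompt can be filled in. This is a minimal sketch: the helper name and the optional `max_genes` truncation are hypothetical, and the single space placed before the closing instruction is an assumption, since the commented template concatenates the cell sentence and the closing instruction directly.

```python
from typing import List, Optional

# Prompt fragments copied from the commented-out `ctp` template in the sample code
CTP_PREFIX = (
    "Identify the cell type most likely associated with these highly expressed genes "
    "listed in descending order. "
)
CTP_SUFFIX = "Name the cell type connected to these genes, ranked from highest to lowest expression."


def build_cell_type_prompt(genes: List[str], max_genes: Optional[int] = None) -> str:
    """Hypothetical helper: fill the cell type prediction template with a ranked gene list."""
    # Optionally keep only the top-ranked genes (assumption, not from the model card)
    cell_sentence = " ".join(genes[:max_genes])
    # Assumption: one space separates the cell sentence from the closing instruction
    return CTP_PREFIX + cell_sentence + " " + CTP_SUFFIX


prompt = build_cell_type_prompt(["CD3E", "CD8A", "NKG7"], max_genes=2)
print(prompt)
```

The resulting string can be tokenized and passed to `model.generate` in the same way as the `ccg` prompt.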