dflevine13 committed
Commit 66090a5 · verified · 1 Parent(s): 2576b96

added example to readme

Files changed (1): README.md +103 -0

README.md CHANGED
@@ -6,12 +6,17 @@ license: apache-2.0
This model uses Cell2Sentence fine-tuning on the Pythia-160m model developed by EleutherAI.

Cell2Sentence Links:

GitHub: <https://github.com/vandijklab/cell2sentence-ft>

Paper: <https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3>

Pythia Links:

GitHub: <https://github.com/EleutherAI/pythia>

Paper: <https://arxiv.org/abs/2304.01373>

Hugging Face: <https://huggingface.co/EleutherAI/pythia-160m>

## Model Details
 
@@ -23,3 +28,101 @@ This model is trained on the immune tissue dataset from [Domínguez et al.](http
1. conditional cell generation
2. unconditional cell generation
3. cell type prediction

## Sample Code

We provide an example of how to use the model to conditionally generate a cell, together with a post-processing function that removes duplicate and invalid genes.
Prompts for unconditional cell generation and cell type prediction are included as well, but we do not include an example cell sentence to fill in the cell type prediction prompt.
We refer to the paper and GitHub repository for instructions on how to transform expression vectors into cell sentences; a rough illustrative sketch of this transformation is also included after the example below.

```
import json
import re
from collections import Counter
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def post_process_generated_cell_sentences(
    cell_sentence: str,
    gene_dictionary: List
):
    """
    Post-processing function for generated cell sentences.
    Invalid genes are removed and ranks of duplicated genes are averaged.

    Arguments:
        cell_sentence: generated cell sentence string
        gene_dictionary: list of gene vocabulary (all uppercase)

    Returns:
        post_processed_sentence: list of gene names after post-processing steps
    """
    generated_gene_names = cell_sentence.split(" ")
    generated_gene_names = [generated_gene.upper() for generated_gene in generated_gene_names]

    #--- Remove nonsense genes ---#
    generated_gene_names = [gene_name for gene_name in generated_gene_names if gene_name in gene_dictionary]

    #--- Average ranks ---#
    gene_name_to_occurrences = Counter(generated_gene_names)  # get mapping of gene name --> number of occurrences
    post_processed_sentence = generated_gene_names.copy()  # copy of generated gene list

    for gene_name in gene_name_to_occurrences:
        if gene_name_to_occurrences[gene_name] > 1:
            # Find positions of all occurrences of the duplicated gene in the list.
            # Note: using post_processed_sentence here; since duplicates are being removed, the list
            # gets shorter, so indices computed on the original list would no longer be accurate.
            occurrence_positions = [idx for idx, elem in enumerate(post_processed_sentence) if elem == gene_name]
            average_position = int(sum(occurrence_positions) / len(occurrence_positions))

            # Remove all occurrences of the duplicated gene
            post_processed_sentence = [elem for elem in post_processed_sentence if elem != gene_name]

            # Reinsert the gene once at its average position
            post_processed_sentence.insert(average_position, gene_name)

    return post_processed_sentence

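#--- Illustrative check of the post-processing behaviour (not part of the original example) ---#
# With a hypothetical three-gene vocabulary, the out-of-vocabulary token "FOO" is dropped and the
# duplicated "CD3E" is collapsed to a single occurrence at the (integer) average of its two positions.
_toy_vocab = ["CD3E", "CD8A", "GZMB"]
assert post_process_generated_cell_sentences("cd3e FOO cd8a cd3e gzmb", _toy_vocab) == ["CD8A", "CD3E", "GZMB"]
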
#--- Load the gene vocabulary (list of valid gene names, all uppercase) ---#
genes_path = "pbmc_vocab.json"

with open(genes_path, "r") as f:
    gene_dictionary = json.load(f)

model_name = "vandijklab/pythia-160m-c2s"

# Note: attn_implementation="flash_attention_2" requires the flash-attn package and a supported GPU;
# if it is unavailable, the argument can be dropped to fall back to the default attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2"
).to(torch.device("cuda"))
tokenizer = AutoTokenizer.from_pretrained(model_name)

cell_type = "T Cell"
ccg = f"Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."

# Prompts for other forms of generation.
# ucg = "Display a cell's genes by expression level, in descending order."
# cellsentence = "CELL_SENTENCE"
# ctp = f"Identify the cell type most likely associated with these highly expressed genes listed in descending order. {cellsentence} Name the cell type connected to these genes, ranked from highest to lowest expression."

tokens = tokenizer(ccg, return_tensors='pt')
input_ids = tokens['input_ids'].to(torch.device("cuda"))
attention_mask = tokens['attention_mask'].to(torch.device("cuda"))

with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=1024,
        top_k=50,
        top_p=0.95,
    )

output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Strip the prompt: everything after the first '?', '.' or ':' in the decoded text is the generated cell sentence.
cell_sentence = "".join(re.split(r"\?|\.|:", output_text)[1:]).strip()
processed_genes = post_process_generated_cell_sentences(cell_sentence, gene_dictionary)
```
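
The example above assumes that a cell sentence is already available when one is needed (e.g., for the cell type prediction prompt). As a rough illustration of the transformation mentioned earlier, the sketch below orders the genes of a single expression vector by decreasing expression and joins the nonzero gene names into a space-separated cell sentence. The helper name and the plain NumPy input are illustrative assumptions, not code from the Cell2Sentence repository; the exact preprocessing used for training (including normalization) is described in the paper and GitHub repository.

```
import numpy as np

def expression_vector_to_cell_sentence(expression, gene_names):
    """Illustrative sketch: order genes by decreasing expression and join the nonzero ones."""
    expression = np.asarray(expression)
    order = np.argsort(-expression)                          # gene indices, highest expression first
    nonzero = [idx for idx in order if expression[idx] > 0]  # keep only expressed genes
    return " ".join(gene_names[idx] for idx in nonzero)

# Hypothetical usage with the cell type prediction prompt shown above:
# cellsentence = expression_vector_to_cell_sentence(expression_vector, gene_dictionary)
# ctp = f"Identify the cell type most likely associated with these highly expressed genes listed in descending order. {cellsentence} Name the cell type connected to these genes, ranked from highest to lowest expression."
```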