wenkai commited on
Commit
4b441dc
·
verified ·
1 Parent(s): ab60ac5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -30
app.py CHANGED
@@ -7,23 +7,135 @@ from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
7
  from lavis.models.base_model import FAPMConfig
8
  import spaces
9
  import gradio as gr
10
- from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
12
-
13
- # from transformers import EsmTokenizer, EsmModel
14
-
15
-
16
- # Load the model
17
- model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
18
- model.load_checkpoint("model/checkpoint_mf2.pth")
19
- model.to('cuda')
20
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  @spaces.GPU
23
  def generate_caption(protein, prompt):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- esm_emb = torch.load('data/emb_esm2_3b/P18281.pt')['representations'][36]
26
- torch.save(esm_emb, 'data/emb_esm2_3b/example.pt')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  '''
28
  inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True).to('cuda')
29
  with torch.no_grad():
@@ -32,17 +144,50 @@ def generate_caption(protein, prompt):
32
  '''
33
  print("esm embedding generated")
34
  esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
35
- print("esm embedding processed")
 
 
 
36
  samples = {'name': ['protein_name'],
37
  'image': torch.unsqueeze(esm_emb, dim=0),
38
  'text_input': ['none'],
39
  'prompt': [prompt]}
40
 
41
- # Generate the output
42
- prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
43
- repetition_penalty=1.0)
44
-
45
- return prediction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # return "test"
47
 
48
 
@@ -51,16 +196,50 @@ description = """Quick demonstration of the FAPM model for protein function pred
51
 
52
  The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
53
 
54
- iface = gr.Interface(
55
- fn=generate_caption,
56
- inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
57
- outputs=gr.Textbox(label="Generated description"),
58
- description=description
59
- )
60
-
61
- # Launch the interface
62
- iface.launch()
63
-
64
-
65
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
7
  from lavis.models.base_model import FAPMConfig
8
  import spaces
9
  import gradio as gr
10
+ # from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
12
+ from data.evaluate_data.utils import Ontology
13
+ import difflib
14
+ import re
15
+ from transformers import MistralForCausalLM
16
+
17
+ # Load the trained model
18
+ def get_model(type='Molecule Function'):
19
+ model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
20
+ if type == 'Molecule Function':
21
+ model.load_checkpoint("model/checkpoint_mf2.pth")
22
+ model.to('cuda')
23
+ elif type == 'Biological Process':
24
+ model.load_checkpoint("model/checkpoint_bp1.pth")
25
+ model.to('cuda')
26
+ elif type == 'Cellar Component':
27
+ model.load_checkpoint("model/checkpoint_cc2.pth")
28
+ model.to('cuda')
29
+ return model
30
+
31
+
32
+ models = {
33
+ 'Molecule Function': get_model('Molecule Function'),
34
+ 'Biological Process': get_model('Biological Process'),
35
+ 'Cellular Component': get_model('Cellar Component'),
36
+ }
37
+
38
+
39
+ # Load the mistral model
40
+ mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16).to('cuda')
41
+
42
+ # Load ESM2 model
43
+ model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
44
+ model_esm.to('cuda')
45
+ model_esm.eval()
46
+
47
+ godb = Ontology(f'data/go1.4-basic.obo', with_rels=True)
48
+ go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
49
+ go_des.columns = ['id', 'text']
50
+ go_des = go_des.dropna()
51
+ go_des['id'] = go_des['id'].apply(lambda x: re.sub('_', ':', x))
52
+ go_obo_set = set(go_des['id'].tolist())
53
+ go_des['text'] = go_des['text'].apply(lambda x: x.lower())
54
+ GO_dict = dict(zip(go_des['text'], go_des['id']))
55
+ Func_dict = dict(zip(go_des['id'], go_des['text']))
56
+
57
+ terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
58
+ choices_mf = [Func_dict[i] for i in list(set(terms_mf['gos']))]
59
+ choices_mf = {x.lower(): x for x in choices_mf}
60
+ terms_bp = pd.read_pickle('data/terms/bp_terms.pkl')
61
+ choices_bp = [Func_dict[i] for i in list(set(terms_bp['gos']))]
62
+ choices_bp = {x.lower(): x for x in choices_bp}
63
+ terms_cc = pd.read_pickle('data/terms/cc_terms.pkl')
64
+ choices_cc = [Func_dict[i] for i in list(set(terms_cc['gos']))]
65
+ choices_cc = {x.lower(): x for x in choices_cc}
66
+ choices = {
67
+ 'Molecule Function': choices_mf,
68
+ 'Biological Process': choices_bp,
69
+ 'Cellular Component': choices_cc,
70
+ }
71
 
72
  @spaces.GPU
73
  def generate_caption(protein, prompt):
74
+ # Process the image and the prompt
75
+ # with open('/home/user/app/example.fasta', 'w') as f:
76
+ # f.write('>{}\n'.format("protein_name"))
77
+ # f.write('{}\n'.format(protein.strip()))
78
+ # os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D /home/user/app/example.fasta /home/user/app --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
79
+ # esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
80
+ # model=model_esm, alphabet=alphabet,
81
+ # include='per_tok', repr_layers=[36], truncation_seq_length=1024)
82
+
83
+ protein_name = 'protein_name'
84
+ protein_seq = protein
85
+ include = 'per_tok'
86
+ repr_layers = [36]
87
+ truncation_seq_length = 1024
88
+ toks_per_batch = 4096
89
+ print("start")
90
+ dataset = FastaBatchedDataset([protein_name], [protein_seq])
91
+ print("dataset prepared")
92
+ batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
93
+ print("batches prepared")
94
+
95
+ data_loader = torch.utils.data.DataLoader(
96
+ dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
97
+ )
98
+ print(f"Read sequences")
99
+ return_contacts = "contacts" in include
100
+
101
+ assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
102
+ repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
103
 
104
+ with torch.no_grad():
105
+ for batch_idx, (labels, strs, toks) in enumerate(data_loader):
106
+ print(
107
+ f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
108
+ )
109
+ if torch.cuda.is_available():
110
+ toks = toks.to(device="cuda", non_blocking=True)
111
+ out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
112
+ representations = {
113
+ layer: t.to(device="cpu") for layer, t in out["representations"].items()
114
+ }
115
+ if return_contacts:
116
+ contacts = out["contacts"].to(device="cpu")
117
+ for i, label in enumerate(labels):
118
+ result = {"label": label}
119
+ truncate_len = min(truncation_seq_length, len(strs[i]))
120
+ # Call clone on tensors to ensure tensors are not views into a larger representation
121
+ # See https://github.com/pytorch/pytorch/issues/1995
122
+ if "per_tok" in include:
123
+ result["representations"] = {
124
+ layer: t[i, 1: truncate_len + 1].clone()
125
+ for layer, t in representations.items()
126
+ }
127
+ if "mean" in include:
128
+ result["mean_representations"] = {
129
+ layer: t[i, 1: truncate_len + 1].mean(0).clone()
130
+ for layer, t in representations.items()
131
+ }
132
+ if "bos" in include:
133
+ result["bos_representations"] = {
134
+ layer: t[i, 0].clone() for layer, t in representations.items()
135
+ }
136
+ if return_contacts:
137
+ result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
138
+ esm_emb = result['representations'][36]
139
  '''
140
  inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True).to('cuda')
141
  with torch.no_grad():
 
144
  '''
145
  print("esm embedding generated")
146
  esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
147
+ if prompt is None:
148
+ prompt = 'none'
149
+ else:
150
+ prompt = prompt.lower()
151
  samples = {'name': ['protein_name'],
152
  'image': torch.unsqueeze(esm_emb, dim=0),
153
  'text_input': ['none'],
154
  'prompt': [prompt]}
155
 
156
+ union_pred_terms = []
157
+ for model_id in models.keys():
158
+ model = models[model_id]
159
+ # Generate the output
160
+ prediction = model.generate(mistral_model, samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
161
+ repetition_penalty=1.0)
162
+ x = prediction[0]
163
+ x = [eval(i) for i in x.split('; ')]
164
+ pred_terms = []
165
+ temp = []
166
+ for i in x:
167
+ txt = i[0]
168
+ prob = i[1]
169
+ sim_list = difflib.get_close_matches(txt.lower(), choices[model_id], n=1, cutoff=0.9)
170
+ if len(sim_list) > 0:
171
+ t_standard = sim_list[0]
172
+ if t_standard not in temp:
173
+ pred_terms.append(t_standard+f'({prob})')
174
+ temp.append(t_standard)
175
+ union_pred_terms.append(pred_terms)
176
+
177
+ if prompt == 'none':
178
+ res_str = "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
179
+ else:
180
+ res_str = "No available predictions for this protein, you can use other two types of model or try another sequence!"
181
+ if len(union_pred_terms[0]) == 0 and len(union_pred_terms[1]) == 0 and len(union_pred_terms[2]) == 0:
182
+ return res_str
183
+ res_str = ''
184
+ if len(union_pred_terms[0]) != 0:
185
+ res_str += f"Based on the given amino acid sequence, the protein appears to have a primary function of {', '.join(pred_terms)}. "
186
+ if len(union_pred_terms[1]) != 0:
187
+ res_str += f"It is likely involved in the {', '.join(pred_terms)}. "
188
+ if len(union_pred_terms[2]) != 0:
189
+ res_str += f"It's subcellular localization is within the {', '.join(pred_terms)}."
190
+ return res_str
191
  # return "test"
192
 
193
 
 
196
 
197
  The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
198
 
199
+ # iface = gr.Interface(
200
+ # fn=generate_caption,
201
+ # inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
202
+ # outputs=gr.Textbox(label="Generated description"),
203
+ # description=description
204
+ # )
205
+ # # Launch the interface
206
+ # iface.launch()
207
+
208
+ css = """
209
+ #output {
210
+ height: 500px;
211
+ overflow: auto;
212
+ border: 1px solid #ccc;
213
+ }
214
+ """
215
+
216
+ with gr.Blocks(css=css) as demo:
217
+ gr.Markdown(description)
218
+ with gr.Tab(label="Protein caption"):
219
+ with gr.Row():
220
+ with gr.Column():
221
+ input_protein = gr.Textbox(type="text", label="Upload sequence")
222
+ prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
223
+ submit_btn = gr.Button(value="Submit")
224
+ with gr.Column():
225
+ output_text = gr.Textbox(label="Output Text")
226
+ # O14813 train index 127, 266, 738, 1060 test index 4
227
+ gr.Examples(
228
+ examples=[
229
+ ["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
230
+ ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
231
+ ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
232
+ ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
233
+ ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
234
+ ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
235
+ ],
236
+ inputs=[input_protein, prompt],
237
+ outputs=[output_text],
238
+ fn=generate_caption,
239
+ cache_examples=True,
240
+ label='Try examples'
241
+ )
242
+ submit_btn.click(generate_caption, [input_protein, prompt], [output_text])
243
+
244
+ demo.launch(debug=True)
245