# Gradio demo app for the FAPM protein function prediction model.
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
# from esm_scripts.extract import run_demo
from esm import pretrained, FastaBatchedDataset
from data.evaluate_data.utils import Ontology
import difflib
import re
from transformers import MistralForCausalLM
from huggingface_hub import hf_hub_download
# Fetch the three task-specific FAPM checkpoints from the Hugging Face Hub.
# The returned values are handed to model.load_checkpoint() in get_model.
_FAPM_REPO_ID = "wenkai/FAPM"
bp_param = hf_hub_download(
    repo_id=_FAPM_REPO_ID,
    filename="model/checkpoint_bp1.pth",
)
mf_param = hf_hub_download(
    repo_id=_FAPM_REPO_ID,
    filename="model/checkpoint_mf2.pth",
)
cc_param = hf_hub_download(
    repo_id=_FAPM_REPO_ID,
    filename="model/checkpoint_cc2.pth",
)
# The Q-Former BERT weights are NOT downloaded here: get_model reads them from
# the relative paths 'model/*_bert.pth'. The disabled hub downloads are kept
# for reference.
# hf_hub_download(repo_id="wenkai/FAPM", filename="model/mf2_bert.pth")
# hf_hub_download(repo_id="wenkai/FAPM", filename="model/bp1_bert.pth")
# hf_hub_download(repo_id="wenkai/FAPM", filename="model/cc2_bert.pth")
# bert_param = BertModel.from_pretrained("bert-base-uncased").state_dict()
# Load the trained model
def get_model(type='Molecule Function'):
    """Build a FAPM model and load the weights for one GO aspect.

    Args:
        type: Which aspect to load: 'Molecule Function', 'Biological Process'
            or 'Cellular Component'. The historical misspelling
            'Cellar Component' is still accepted so existing callers keep
            working. (The name shadows the builtin ``type``; it is kept
            because callers may pass it by keyword.)

    Returns:
        The Blip2ProteinMistral instance with its checkpoint loaded, its
        Q-Former BERT replaced by the object stored in the matching
        'model/*_bert.pth' file, and moved to 'cuda'.

    Raises:
        ValueError: If ``type`` is not one of the supported aspect names.
            Previously an unknown name silently returned a model without any
            weights loaded.
    """
    # Resolve the aspect first so a bad name fails fast, before the (large)
    # model is constructed.
    if type == 'Molecule Function':
        checkpoint_path, bert_path = mf_param, 'model/mf2_bert.pth'
    elif type == 'Biological Process':
        checkpoint_path, bert_path = bp_param, 'model/bp1_bert.pth'
    elif type in ('Cellular Component', 'Cellar Component'):
        checkpoint_path, bert_path = cc_param, 'model/cc2_bert.pth'
    else:
        raise ValueError(
            f"Unknown model type {type!r}; expected 'Molecule Function', "
            "'Biological Process' or 'Cellular Component'"
        )

    model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
    model.load_checkpoint(checkpoint_path)
    # The BERT file is loaded onto the CPU first; the whole model is moved to
    # the GPU in one step afterwards.
    model.Qformer.bert = torch.load(bert_path, map_location=torch.device('cpu'))
    model.to('cuda')
    return model
# One FAPM model per GO aspect, keyed by the display name used elsewhere in
# this file (e.g. in `choices`). The second item of each pair is the key that
# get_model is called with; it recognises the third aspect under the spelling
# 'Cellar Component'.
_ASPECT_LOADER_KEYS = (
    ('Molecule Function', 'Molecule Function'),
    ('Biological Process', 'Biological Process'),
    ('Cellular Component', 'Cellar Component'),
)
models = {
    display_name: get_model(loader_key)
    for display_name, loader_key in _ASPECT_LOADER_KEYS
}

# Load the mistral model (half precision) and place it on the GPU.
mistral_model = MistralForCausalLM.from_pretrained(
    "teknium/OpenHermes-2.5-Mistral-7B",
    torch_dtype=torch.float16,
)
mistral_model.to('cuda')

# Load the ESM2 model and its alphabet; the model is put on the GPU in eval mode.
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm.to('cuda')
model_esm.eval()
# Gene Ontology graph loaded from the OBO file.
godb = Ontology('data/go1.4-basic.obo', with_rels=True)

# GO descriptions: a '|'-separated, header-less file with an id column and a
# text column. Rows with missing values are dropped.
go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
go_des.columns = ['id', 'text']
go_des = go_des.dropna()
# Ids are stored with '_' in the file; convert them to the 'GO:...' form.
go_des['id'] = go_des['id'].apply(lambda go_id: go_id.replace('_', ':'))
go_obo_set = set(go_des['id'].tolist())
go_des['text'] = go_des['text'].apply(str.lower)
# Two-way lookup between lower-cased description text and GO id.
GO_dict = dict(zip(go_des['text'], go_des['id']))
Func_dict = dict(zip(go_des['id'], go_des['text']))


def _load_term_choices(pickle_path):
    """Return (terms, choices) for one aspect's pickled term table.

    `choices` maps the lower-cased description of every distinct GO id in the
    table's 'gos' column to that description.
    """
    terms = pd.read_pickle(pickle_path)
    descriptions = [Func_dict[go_id] for go_id in set(terms['gos'])]
    return terms, {text.lower(): text for text in descriptions}


terms_mf, choices_mf = _load_term_choices('data/terms/mf_terms.pkl')
terms_bp, choices_bp = _load_term_choices('data/terms/bp_terms.pkl')
terms_cc, choices_cc = _load_term_choices('data/terms/cc_terms.pkl')

# Candidate vocabularies keyed by the same names as `models`.
choices = {
    'Molecule Function': choices_mf,
    'Biological Process': choices_bp,
    'Cellular Component': choices_cc,
}
def _embed_protein(protein_seq, repr_layer=36, truncation_seq_length=1024, toks_per_batch=4096):
    """Run the ESM2 model on one sequence and return its per-residue embedding (on CPU)."""
    dataset = FastaBatchedDataset(['protein_name'], [protein_seq])
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(truncation_seq_length),
        batch_sampler=batches,
    )

    num_layers = model_esm.num_layers
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not -(num_layers + 1) <= repr_layer <= num_layers:
        raise ValueError(
            f"repr_layer {repr_layer} is out of range for a {num_layers}-layer ESM model"
        )
    # Map a negative (counted-from-the-end) layer index onto its non-negative equivalent.
    layer = (repr_layer + num_layers + 1) % (num_layers + 1)

    embedding = None
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model_esm(toks, repr_layers=[layer], return_contacts=False)
            layer_repr = out["representations"][layer].to(device="cpu")
            # The dataset holds exactly one sequence, so it is row 0 of the batch.
            truncate_len = min(truncation_seq_length, len(strs[0]))
            # Position 0 is skipped (the original code exposed it separately as the
            # "bos" representation). clone() so the result is not a view into the
            # larger batch tensor; see https://github.com/pytorch/pytorch/issues/1995
            embedding = layer_repr[0, 1: truncate_len + 1].clone()
    return embedding


def _match_terms(raw_prediction, vocabulary):
    """Turn one model output string into de-duplicated 'term(prob)' strings.

    `raw_prediction` is a '; '-separated list of entries that each index as
    (text, probability). Every text is fuzzy-matched (cutoff 0.9) against
    `vocabulary`; unmatched or unparsable entries are dropped.
    """
    import ast  # local import: only needed here, keeps the file's import block untouched

    matched = []
    seen = set()
    for chunk in raw_prediction.split('; '):
        # SECURITY: this text is produced by a model conditioned on user input, so
        # it must not go through eval(); literal_eval only accepts Python literals.
        try:
            entry = ast.literal_eval(chunk)
            txt, prob = entry[0], entry[1]
        except (ValueError, SyntaxError, TypeError, IndexError, KeyError):
            print(f"Skipping unparsable prediction entry: {chunk!r}")
            continue
        if not isinstance(txt, str):
            print(f"Skipping unparsable prediction entry: {chunk!r}")
            continue
        sim_list = difflib.get_close_matches(txt.lower(), vocabulary, n=1, cutoff=0.9)
        if not sim_list:
            continue
        t_standard = sim_list[0]
        if t_standard not in seen:
            seen.add(t_standard)
            matched.append(t_standard + f'({prob})')
    return matched


@spaces.GPU
def generate_caption(protein, prompt):
    """Predict GO-term descriptions for a protein sequence.

    The sequence is embedded with the ESM2 model, zero-padded to 1024 rows,
    and passed to each model in `models`; every generated term is matched
    against that model's entry in `choices`.

    Args:
        protein: Amino-acid sequence as text. Whitespace (e.g. line breaks
            from a pasted sequence) is removed before embedding.
        prompt: Optional taxonomy prompt; ``None`` becomes 'none', anything
            else is lower-cased.

    Returns:
        A Markdown string listing the matched terms per aspect, or a short
        message when the sequence is empty or nothing could be matched.
    """
    max_len = 1024

    # Drop whitespace so the string length used to slice the embedding counts
    # residues only, and reject empty input before touching the GPU.
    protein_seq = "".join(protein.split()) if protein else ""
    if not protein_seq:
        return "Please provide a protein sequence."

    esm_emb = _embed_protein(protein_seq, repr_layer=36, truncation_seq_length=max_len)
    # Zero-pad along the residue axis up to max_len rows, then move to the GPU.
    esm_emb = F.pad(esm_emb.t(), (0, max_len - len(esm_emb))).t().to('cuda')

    # NOTE(review): an empty prompt is passed to the model as '' (not 'none'),
    # exactly as before — confirm which value the model expects for "no prompt".
    prompt = 'none' if prompt is None else prompt.lower()
    samples = {'name': ['protein_name'],
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': ['none'],
               'prompt': [prompt]}

    predictions = {}
    for model_id, model in models.items():
        prediction = model.generate(mistral_model, samples, length_penalty=0., num_beams=15,
                                    num_captions=10, temperature=1., repetition_penalty=1.0)
        predictions[model_id] = _match_terms(prediction[0], choices[model_id])

    if not any(predictions.values()):
        # Only suggest removing the prompt when one was actually supplied (the
        # original condition was inverted).
        if prompt not in ('', 'none'):
            return "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
        return "No available predictions for this protein, you can use other two types of model or try another sequence!"

    def _bullets(terms):
        return ''.join('- ' + term + '\n' for term in terms)

    res_str = ''
    mf_terms = predictions.get('Molecule Function', [])
    bp_terms = predictions.get('Biological Process', [])
    cc_terms = predictions.get('Cellular Component', [])
    if mf_terms:
        res_str += f"Based on the given amino acid sequence, the protein appears to have a primary function of \n{_bullets(mf_terms)} \n"
    if bp_terms:
        res_str += f"It is likely involved in the following process: \n{_bullets(bp_terms)} \n"
    if cc_terms:
        res_str += f"It's subcellular localization is within the: \n{_bullets(cc_terms)}"
    return res_str
def save_feedback(inputs):
    """Append one piece of free-text user feedback to 'feedback.txt'.

    Args:
        inputs: The feedback text. ``None`` is treated as empty; blank
            feedback is acknowledged but not written to the file.

    Returns:
        The acknowledgement message shown to the user.
    """
    print(inputs)
    text = '' if inputs is None else str(inputs)
    if text.strip():
        # Explicit UTF-8 so non-ASCII feedback does not depend on the
        # platform's default encoding.
        with open('feedback.txt', 'a', encoding='utf-8') as f:
            f.write(text + '\n')
    return "Thanks your advice!"
# In-memory list of (vote_id, "upvote" | "downvote") tuples appended by
# upvote() and downvote(); nothing in the visible code writes it to disk.
feedback_data = []
def chatbot_respond(message, history=None):
    """Return a fixed reply together with the extended chat history.

    Args:
        message: The incoming user message.
        history: Previous (message, response) pairs; ``None`` (the default)
            means an empty history. The passed list is never modified.

    Returns:
        A ``(response, new_history)`` tuple where ``new_history`` is a new
        list: ``history`` followed by ``(message, response)``.
    """
    response = "yes"
    # None instead of a mutable [] default, which is shared between calls.
    previous = [] if history is None else history
    return response, previous + [(message, response)]
# Handlers for the like / dislike buttons
def upvote(vote_id):
    """Record a positive vote for `vote_id` and return the confirmation text."""
    vote_record = (vote_id, "upvote")
    feedback_data.append(vote_record)
    print(f"Current feedback data: {feedback_data}")
    confirmation = "You liked this prediction"
    return confirmation
def downvote(vote_id):
    """Record a negative vote for `vote_id` and return the confirmation text."""
    vote_record = (vote_id, "downvote")
    feedback_data.append(vote_record)
    print(f"Current feedback data: {feedback_data}")
    confirmation = "You disliked this prediction"
    return confirmation
# Define the FAPM interface
# `description` is Markdown text rendered at the top of the page through
# gr.Markdown(description); it links to the paper, the model hub entry and
# the source repository.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
Our paper is available at [BioRxiv](https://www.biorxiv.org/content/10.1101/2024.05.07.593067v1)
The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main).
Thanks for the support from ProtonUnfold Tech. Co., Ltd (https://www.protonunfold.com/)."""
# iface = gr.Interface(
# fn=generate_caption,
# inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
# outputs=gr.Textbox(label="Generated description"),
# description=description
# )
# # Launch the interface
# iface.launch()
# Custom stylesheet handed to gr.Blocks(css=css).
# Of the selectors below, only .submit-btn, .vote-container and .vote-btn are
# referenced through `elem_classes` in the layout visible in this file; no
# visible component uses the remaining ones.
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
.submit-btn {
display: flex;
width: 100%;
gap: 10px;
}
.vote-container {
display: flex;
width: 100%;
gap: 10px;
}
.vote-buttons {
display: flex;
width: 100%;
gap: 10px;
}
.vote-btn {
display: flex;
width: 100%;
gap: 10px;
}
.vote-content {
flex: 3;
}
/* Style for the upvote button */
.upvote-button {
width: 500px; /* Set button width */
height: 50px; /* Set button height */
font-size: 20px; /* Set font size */
background-color: #d4edda; /* Set background color */
border-radius: 5px; /* Rounded corners */
}
/* Style for the downvote button */
.downvote-button {
width: 50px; /* Set button width */
height: 50px; /* Set button height */
font-size: 20px; /* Set font size */
background-color: #f8d7da; /* Set background color */
border-radius: 5px; /* Rounded corners */
}
.feedback {
width: 40px; /* Set button width */
height: 40px; /* Set button height */
font-size: 16px; /* Set font size */
background-color: #f8d7da; /* Set background color */
}
"""
# Build the Gradio page: sequence/prompt inputs, the prediction output,
# like/dislike voting, a free-text feedback box and cached examples.
# NOTE(review): the nesting of the rows below was reconstructed from a source
# whose indentation was lost — confirm it matches the intended layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown(description)
    # vote_id = gr.State(0)
    with gr.Tab(label="Protein caption"):
        with gr.Row():
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
                submit_btn = gr.Button(value="Submit", elem_classes=["submit-btn"])
            with gr.Column():
                # output_text = gr.Textbox(label="Output Text")
                with gr.Accordion('Prediction:', open=True):
                    output_markdown = gr.Markdown(label="Output")
        # Container for the vote buttons and the vote result
        with gr.Row(elem_classes=["vote-container"]):
            # Vote button group
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        # Label restored from mis-encoded text (thumbs up).
                        upvote_button = gr.Button("👍", elem_classes=["vote-btn"])
                    with gr.Column():
                        # Label restored from mis-encoded text (thumbs down).
                        downvote_button = gr.Button("👎", elem_classes=["vote-btn"])
            with gr.Column():
                vote_markdown = gr.Markdown(label="Output")
        with gr.Row():
            inputs = gr.Textbox(type="text", label="Your feedback")
            feedback_markdown = gr.Markdown(label="Output")
        with gr.Row():
            with gr.Column():
                feedback_btn = gr.Button(value="Feedback")
            # feedback_temp1 = gr.Markdown(label="Output")
            with gr.Column():
                feedback_temp2 = gr.Markdown(label="Output")
        # cache_examples=True makes Gradio run generate_caption on the examples
        # ahead of time and serve the stored outputs.
        gr.Examples(
            examples=[
                ["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
                ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
                ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
                ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
                ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
                ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
            ],
            inputs=[input_protein, prompt],
            outputs=[output_markdown],
            fn=generate_caption,
            cache_examples=True,
            label='Try examples'
        )
        # Wire the buttons to their handlers. The votes are keyed by the
        # current content of the sequence textbox.
        submit_btn.click(generate_caption, [input_protein, prompt], [output_markdown])
        upvote_button.click(upvote, input_protein, vote_markdown)
        downvote_button.click(downvote, input_protein, vote_markdown)
        feedback_btn.click(save_feedback, [inputs], [feedback_markdown])
demo.launch(debug=True)