FAPM_demo / app.py
wenkai's picture
Update app.py
da19b49 verified
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
# from esm_scripts.extract import run_demo
from esm import pretrained, FastaBatchedDataset
from data.evaluate_data.utils import Ontology
import difflib
import re
from transformers import MistralForCausalLM
from huggingface_hub import hf_hub_download
# Fetch the three FAPM checkpoints (biological process, molecular function,
# cellular component) from the Hugging Face Hub. The returned values are used
# by get_model as local checkpoint paths.
_FAPM_REPO = "wenkai/FAPM"
bp_param = hf_hub_download(repo_id=_FAPM_REPO, filename="model/checkpoint_bp1.pth")
mf_param = hf_hub_download(repo_id=_FAPM_REPO, filename="model/checkpoint_mf2.pth")
cc_param = hf_hub_download(repo_id=_FAPM_REPO, filename="model/checkpoint_cc2.pth")
# NOTE(review): the matching Q-Former BERT weights (model/*_bert.pth) are NOT
# downloaded here; get_model reads them from the local 'model/' directory, so
# they must ship with the app.
# Load the trained model.
# Q-Former BERT module file for each ontology label.
_BERT_WEIGHTS = {
    'Molecule Function': 'model/mf2_bert.pth',
    'Biological Process': 'model/bp1_bert.pth',
    'Cellular Component': 'model/cc2_bert.pth',
}
# Legacy / misspelled labels accepted for backward compatibility.
_MODEL_TYPE_ALIASES = {'Cellar Component': 'Cellular Component'}


def get_model(type='Molecule Function'):
    """Build the FAPM model for one GO sub-ontology and move it to the GPU.

    Args:
        type: Ontology label: 'Molecule Function', 'Biological Process' or
            'Cellular Component' ('Cellar Component' is accepted as an alias).
            The name shadows the ``type`` builtin but is kept so keyword
            callers keep working.

    Returns:
        A ``Blip2ProteinMistral`` with that ontology's checkpoint and Q-Former
        BERT module loaded, moved to ``cuda``.

    Raises:
        ValueError: if ``type`` is not a known ontology label. Validating up
            front means a typo can no longer silently yield an untrained model.
    """
    ontology = _MODEL_TYPE_ALIASES.get(type, type)
    if ontology not in _BERT_WEIGHTS:
        known = ', '.join(sorted(set(_BERT_WEIGHTS) | set(_MODEL_TYPE_ALIASES)))
        raise ValueError(f"Unknown model type {type!r}; expected one of: {known}")

    checkpoints = {
        'Molecule Function': mf_param,
        'Biological Process': bp_param,
        'Cellular Component': cc_param,
    }
    model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
    model.load_checkpoint(checkpoints[ontology])
    # torch.load unpickles a whole module object here, so only trusted files
    # may be loaded. NOTE(review): PyTorch releases that default to
    # weights_only=True may reject this file; confirm the pinned torch version.
    model.Qformer.bert = torch.load(_BERT_WEIGHTS[ontology], map_location=torch.device('cpu'))
    model.to('cuda')
    return model
# One FAPM model per GO sub-ontology, keyed by the label used in `choices` and
# in the output. The second element of each pair is the argument handed to
# get_model ('Cellar Component' is the spelling get_model checks for).
models = {
    label: get_model(model_type)
    for label, model_type in (
        ('Molecule Function', 'Molecule Function'),
        ('Biological Process', 'Biological Process'),
        ('Cellular Component', 'Cellar Component'),
    )
}
_MISTRAL_CHECKPOINT = "teknium/OpenHermes-2.5-Mistral-7B"
_ESM_CHECKPOINT = 'esm2_t36_3B_UR50D'

# Mistral language model (fp16) passed to every FAPM model's generate() call.
mistral_model = MistralForCausalLM.from_pretrained(
    _MISTRAL_CHECKPOINT,
    torch_dtype=torch.float16,
)
mistral_model.to('cuda')

# ESM-2 protein language model plus its alphabet (used to build the batch
# converter); moved to the GPU and put in eval mode for inference only.
model_esm, alphabet = pretrained.load_model_and_alphabet(_ESM_CHECKPOINT)
model_esm.to('cuda')
model_esm.eval()
# GO ontology graph (not referenced elsewhere in this file).
godb = Ontology('data/go1.4-basic.obo', with_rels=True)

# GO term descriptions: '|'-separated "<id>|<text>" records without a header.
go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
go_des.columns = ['id', 'text']
go_des = go_des.dropna()
# Ids are stored with '_' and normalised to ':' (presumably 'GO_xxxxxxx' -> 'GO:xxxxxxx').
go_des['id'] = go_des['id'].map(lambda go_id: go_id.replace('_', ':'))
go_obo_set = set(go_des['id'])
go_des['text'] = go_des['text'].map(lambda text: text.lower())

# Bidirectional lookups between lower-cased description text and GO id.
GO_dict = dict(zip(go_des['text'], go_des['id']))
Func_dict = dict(zip(go_des['id'], go_des['text']))
def _term_choices(terms):
    """Map each lower-cased description to itself for the GO ids in terms['gos'].

    The keys are the candidate strings used for fuzzy matching in
    generate_caption.
    """
    descriptions = (Func_dict[go_id] for go_id in set(terms['gos']))
    return {desc.lower(): desc for desc in descriptions}


terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
choices_mf = _term_choices(terms_mf)
terms_bp = pd.read_pickle('data/terms/bp_terms.pkl')
choices_bp = _term_choices(terms_bp)
terms_cc = pd.read_pickle('data/terms/cc_terms.pkl')
choices_cc = _term_choices(terms_cc)

# Candidate vocabulary per model, keyed exactly like `models`.
choices = {
    'Molecule Function': choices_mf,
    'Biological Process': choices_bp,
    'Cellular Component': choices_cc,
}
# Residues kept per sequence: ESM input is truncated to this length and the
# per-residue embedding is zero-padded up to it before reaching FAPM.
_MAX_SEQ_LEN = 1024

# Heading sentence for each ontology's bullet list, keyed like `models`.
_SECTION_TEMPLATES = {
    'Molecule Function': "Based on the given amino acid sequence, the protein appears to have a primary function of \n{} \n",
    'Biological Process': "It is likely involved in the following process: \n{} \n",
    'Cellular Component': "It's subcellular localization is within the: \n{}",
}


def _clean_sequence(protein):
    """Normalise raw user input into a bare residue string.

    FASTA header lines (starting with '>') are dropped, all whitespace
    including line breaks is removed and the result is upper-cased.
    ``None`` yields ''.
    """
    if protein is None:
        return ''
    body = ''.join(
        line for line in protein.splitlines() if not line.lstrip().startswith('>')
    )
    # NOTE(review): upper-casing assumes the ESM alphabet only contains
    # upper-case residue tokens; confirm against esm.Alphabet.
    return ''.join(body.split()).upper()


def _esm_embedding(protein_seq, repr_layer=36, toks_per_batch=4096):
    """Return the per-residue ESM representation of one sequence from ``repr_layer``.

    The sequence is truncated to ``_MAX_SEQ_LEN`` residues and position 0 (the
    BOS token) is dropped, so the result has one row per kept residue
    (presumably shape (num_residues, embed_dim)).
    """
    num_layers = model_esm.num_layers
    if not -(num_layers + 1) <= repr_layer <= num_layers:
        raise ValueError(
            f"repr_layer {repr_layer} is out of range for a {num_layers}-layer ESM model"
        )
    # Negative layer indices count back from the last layer.
    layer = (repr_layer + num_layers + 1) % (num_layers + 1)

    dataset = FastaBatchedDataset(['protein_name'], [protein_seq])
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(_MAX_SEQ_LEN),
        batch_sampler=batches,
    )
    embedding = None
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model_esm(toks, repr_layers=[layer], return_contacts=False)
            reps = out["representations"][layer].to(device="cpu")
            for i in range(len(labels)):
                truncate_len = min(_MAX_SEQ_LEN, len(strs[i]))
                # clone() so the result is not a view into the whole batch tensor
                # (see https://github.com/pytorch/pytorch/issues/1995).
                embedding = reps[i, 1: truncate_len + 1].clone()
    if embedding is None:
        raise RuntimeError("ESM produced no representation for the input sequence")
    return embedding


def _parse_prediction(raw):
    """Parse the model's '; '-separated output into (term text, score) pairs.

    Fragments are presumably tuple literals such as "('atp binding', 0.9)".
    ``ast.literal_eval`` is used instead of ``eval`` because the text comes
    from a language model conditioned on a user-supplied prompt and must never
    be executed as code. Fragments that are not 2+-element literals are skipped.
    """
    import ast

    pairs = []
    for fragment in raw.split('; '):
        try:
            item = ast.literal_eval(fragment)
        except (ValueError, SyntaxError, TypeError, RecursionError):
            continue
        if isinstance(item, (tuple, list)) and len(item) >= 2:
            pairs.append((item[0], item[1]))
    return pairs


def _match_terms(pairs, vocabulary):
    """Snap generated terms to known descriptions and annotate them with scores.

    Each term is replaced by its single closest entry of ``vocabulary``
    (difflib ratio >= 0.9); unmatched terms are dropped and only the first
    occurrence of each matched description is kept, formatted 'term(score)'.
    """
    matched = []
    seen = set()
    for text, prob in pairs:
        best = difflib.get_close_matches(str(text).lower(), vocabulary, n=1, cutoff=0.9)
        if best and best[0] not in seen:
            seen.add(best[0])
            matched.append(f'{best[0]}({prob})')
    return matched


def _format_predictions(predictions):
    """Render the non-empty per-ontology term lists as markdown bullet sections."""
    sections = []
    for model_id, terms in predictions.items():
        template = _SECTION_TEMPLATES.get(model_id)
        if terms and template is not None:
            bullets = ''.join(f'- {term}\n' for term in terms)
            sections.append(template.format(bullets))
    return ''.join(sections)


@spaces.GPU
def generate_caption(protein, prompt):
    """Predict GO terms for a protein sequence with all three FAPM models.

    Args:
        protein: Amino-acid sequence entered by the user (plain or FASTA text).
        prompt: Optional taxonomy prompt (e.g. 'Homo'); None or blank means
            no prompt.

    Returns:
        Markdown text listing the matched GO terms with their scores for each
        ontology, or a hint message when nothing could be predicted.
    """
    protein_seq = _clean_sequence(protein)
    if not protein_seq:
        return "Please input a protein sequence."

    esm_emb = _esm_embedding(protein_seq)
    # Zero-pad the residue axis up to _MAX_SEQ_LEN rows, then move to the GPU.
    esm_emb = F.pad(esm_emb.t(), (0, _MAX_SEQ_LEN - len(esm_emb))).t().to('cuda')

    # 'none' is the placeholder for a missing prompt. The example rows pass ''
    # for "no prompt", so blank input is treated the same as None.
    if prompt is None or not prompt.strip():
        prompt = 'none'
    else:
        prompt = prompt.strip().lower()

    samples = {
        'name': ['protein_name'],
        'image': torch.unsqueeze(esm_emb, dim=0),
        'text_input': ['none'],
        'prompt': [prompt],
    }

    predictions = {}
    for model_id, model in models.items():
        prediction = model.generate(
            mistral_model, samples, length_penalty=0., num_beams=15,
            num_captions=10, temperature=1., repetition_penalty=1.0,
        )
        predictions[model_id] = _match_terms(_parse_prediction(prediction[0]), choices[model_id])

    if not any(predictions.values()):
        # Only suggest removing the prompt when the user actually gave one.
        if prompt == 'none':
            return "No available predictions for this protein, you can use other two types of model or try another sequence!"
        return "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
    return _format_predictions(predictions)
def save_feedback(inputs):
    """Append one free-text feedback entry to ``feedback.txt``.

    Used as the Feedback button handler. The file is written as UTF-8 so
    non-ASCII feedback cannot fail on hosts whose default encoding is not
    UTF-8. Blank or ``None`` submissions are echoed but not written.

    Returns:
        The acknowledgement shown to the user.
    """
    print(inputs)
    if inputs is not None and inputs.strip():
        with open('feedback.txt', 'a', encoding='utf-8') as f:
            f.write(inputs + '\n')
    return "Thanks your advice!"
# In-memory log of (vote_id, vote) tuples appended by upvote/downvote; it is
# not persisted anywhere in this file.
feedback_data = list()
def chatbot_respond(message, history=None):
    """Placeholder chat handler: always answers "yes".

    Not wired to any component in this file.

    Args:
        message: The user's chat message.
        history: Previous (message, response) pairs; ``None`` means empty.
            The default is ``None`` rather than a shared mutable list.

    Returns:
        ``(response, new_history)``, where ``new_history`` is a new list and the
        caller's ``history`` is left unmodified.
    """
    response = "yes"
    return response, [*(history or []), (message, response)]
# Handlers for the like/dislike buttons.
def _record_vote(vote_id, kind):
    """Append ``(vote_id, kind)`` to ``feedback_data`` and print the whole log."""
    feedback_data.append((vote_id, kind))
    print(f"Current feedback data: {feedback_data}")


def upvote(vote_id):
    """Record a like for ``vote_id`` and return the confirmation text."""
    _record_vote(vote_id, "upvote")
    return "You liked this prediction"


def downvote(vote_id):
    """Record a dislike for ``vote_id`` and return the confirmation text."""
    _record_vote(vote_id, "downvote")
    return "You disliked this prediction"
# Markdown shown at the top of the FAPM interface.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
Our paper is available at [BioRxiv](https://www.biorxiv.org/content/10.1101/2024.05.07.593067v1)
The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main).
Thanks for the support from ProtonUnfold Tech. Co., Ltd (https://www.protonunfold.com/)."""
# iface = gr.Interface(
# fn=generate_caption,
# inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
# outputs=gr.Textbox(label="Generated description"),
# description=description
# )
# # Launch the interface
# iface.launch()
# Stylesheet passed to gr.Blocks(css=css) below.
# NOTE(review): in this file only "submit-btn", "vote-container" and "vote-btn"
# are attached through elem_classes, and no component sets elem_id="output",
# so the #output, .vote-buttons, .vote-content, .upvote-button,
# .downvote-button and .feedback rules appear unused. The 500px
# .upvote-button width (vs 50px for .downvote-button) looks like a typo;
# confirm before relying on these rules.
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
.submit-btn {
display: flex;
width: 100%;
gap: 10px;
}
.vote-container {
display: flex;
width: 100%;
gap: 10px;
}
.vote-buttons {
display: flex;
width: 100%;
gap: 10px;
}
.vote-btn {
display: flex;
width: 100%;
gap: 10px;
}
.vote-content {
flex: 3;
}
/* Style for the upvote button */
.upvote-button {
width: 500px; /* Set button width */
height: 50px; /* Set button height */
font-size: 20px; /* Set font size */
background-color: #d4edda; /* Set background color */
border-radius: 5px; /* Rounded corners */
}
/* Style for the downvote button */
.downvote-button {
width: 50px; /* Set button width */
height: 50px; /* Set button height */
font-size: 20px; /* Set font size */
background-color: #f8d7da; /* Set background color */
border-radius: 5px; /* Rounded corners */
}
.feedback {
width: 40px; /* Set button width */
height: 40px; /* Set button height */
font-size: 16px; /* Set font size */
background-color: #f8d7da; /* Set background color */
}
"""
# Build and launch the Gradio UI: sequence/prompt inputs, prediction output,
# like/dislike voting, free-text feedback and cached example sequences.
with gr.Blocks(css=css) as demo:
    gr.Markdown(description)
    with gr.Tab(label="Protein caption"):
        with gr.Row():
            # Left column: sequence and optional taxonomy prompt.
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
                submit_btn = gr.Button(value="Submit", elem_classes=["submit-btn"])
            # Right column: prediction output, voting and feedback.
            with gr.Column():
                with gr.Accordion('Prediction:', open=True):
                    output_markdown = gr.Markdown(label="Output")
                # Container for the vote buttons and the vote result text.
                with gr.Row(elem_classes=["vote-container"]):
                    # Vote button group. Labels are the thumbs-up/down emoji
                    # (they were previously mis-encoded in this file).
                    with gr.Column():
                        with gr.Row():
                            with gr.Column():
                                upvote_button = gr.Button("\U0001F44D", elem_classes=["vote-btn"])
                            with gr.Column():
                                downvote_button = gr.Button("\U0001F44E", elem_classes=["vote-btn"])
                    with gr.Column():
                        vote_markdown = gr.Markdown(label="Output")
                with gr.Row():
                    feedback_input = gr.Textbox(type="text", label="Your feedback")
                    feedback_markdown = gr.Markdown(label="Output")
                with gr.Row():
                    with gr.Column():
                        feedback_btn = gr.Button(value="Feedback")
                    with gr.Column():
                        # Placeholder column; no handler writes to it.
                        feedback_temp2 = gr.Markdown(label="Output")
        gr.Examples(
            examples=[
                ["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
                ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
                ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
                ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
                ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
                ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
            ],
            inputs=[input_protein, prompt],
            outputs=[output_markdown],
            fn=generate_caption,
            cache_examples=True,
            label='Try examples'
        )
    submit_btn.click(generate_caption, [input_protein, prompt], [output_markdown])
    # Votes are recorded against the sequence currently in the input box.
    upvote_button.click(upvote, input_protein, vote_markdown)
    downvote_button.click(downvote, input_protein, vote_markdown)
    feedback_btn.click(save_feedback, [feedback_input], [feedback_markdown])
demo.launch(debug=True)