|
import os |
|
import torch |
|
import torch.nn as nn |
|
import pandas as pd |
|
import torch.nn.functional as F |
|
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral |
|
from lavis.models.base_model import FAPMConfig |
|
import spaces |
|
import gradio as gr |
|
from esm_scripts.extract import run_demo |
|
from esm import pretrained, FastaBatchedDataset |
|
|
|
|
|
|
|
|
|
|
|
# Build the FAPM model (ESM-2 3B embedding size), load the trained checkpoint
# and move it to the GPU once at import time; generate_caption() reuses this
# module-level instance for every request.
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
model.load_checkpoint("model/checkpoint_mf2.pth")
model.to('cuda')
# A freshly constructed torch.nn.Module starts in training mode. Switch to eval
# mode so dropout (and similar train-only behaviour) is disabled while the
# model is only used for generation.
# NOTE(review): assumes Blip2ProteinMistral is an nn.Module (it supports
# .to()), as LAVIS models are.
model.eval()
|
|
|
|
|
@spaces.GPU
def generate_caption(protein, prompt):
    """Generate candidate function descriptions for a protein with FAPM.

    Args:
        protein: Text from the "Upload sequence" textbox.
            NOTE(review): currently UNUSED. The embedding below is always read
            from the precomputed file for P18281, so every request describes
            that protein regardless of input.
            TODO: compute the ESM-2 embedding from ``protein`` instead. The
            file imports ``run_demo`` and ``esm.pretrained``, presumably for
            this.
        prompt: Free-text prompt forwarded to the model (the UI description
            suggests using it for taxonomy information).

    Returns:
        Whatever ``model.generate`` returns for the single sample. The call
        requests 10 captions from a 15-beam search.
    """
    # Precomputed ESM-2 (3B) output; the layer-36 representations are used.
    esm_emb = torch.load('data/emb_esm2_3b/P18281.pt')['representations'][36]
    # NOTE(review): this rewrites a copy of the embedding on every request, and
    # nothing in this file reads it back. Kept in case another script consumes
    # example.pt; confirm and remove if unused.
    torch.save(esm_emb, 'data/emb_esm2_3b/example.pt')
    print("esm embedding generated")

    # Zero-pad the 2-D embedding along its first (row) dimension to 1024 rows
    # (F.pad pads the last dim, hence the transpose round-trip), then move it
    # to the GPU where the model lives.
    esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
    print("esm embedding processed")

    # Batch of one sample. The embedding travels under the 'image' key,
    # presumably inherited from the BLIP-2 sample format this model is built on.
    samples = {
        'name': ['protein_name'],
        'image': torch.unsqueeze(esm_emb, dim=0),
        'text_input': ['none'],
        'prompt': [prompt],
    }

    # Pure inference: avoid building autograd state during generation.
    with torch.no_grad():
        prediction = model.generate(
            samples,
            length_penalty=0.,
            num_beams=15,
            num_captions=10,
            temperature=1.,
            repetition_penalty=1.0,
        )

    return prediction
|
|
|
|
|
|
|
|
|
# Markdown text shown above the Gradio form.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.

The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""

# Two free-text inputs are passed positionally to
# generate_caption(protein, prompt); its return value is shown in a textbox.
iface = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Textbox(type="text", label="Upload sequence"),
        gr.Textbox(type="text", label="Prompt"),
    ],
    outputs=gr.Textbox(label="Generated description"),
    description=description,
)

# Start the web server (runs when the script is executed).
iface.launch()
|
|
|
|
|
|
|
|
|
|