Update app.py
app.py CHANGED
@@ -1,91 +1,28 @@
-import
-from transformers import AutoProcessor, AutoModelForCausalLM
-import spaces
-import torch.nn.functional as F
-import copy
+import os
 import torch
-
-import
-import
+import torch.nn as nn
+import pandas as pd
+import torch.nn.functional as F
+from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
+from lavis.models.base_model import FAPMConfig
+import spaces
+import gradio as gr
+from esm_scripts.extract import run_demo
 from esm import pretrained, FastaBatchedDataset
 
-
-def get_model(model_id):
-    a, b = pretrained.load_model_and_alphabet(model_id.split('/')[1])
-    a.to('cuda').eval()
-    return (a, b)
-
-models = {
-    'facebook/esm2_t36_3B_UR50D': get_model('facebook/esm2_t36_3B_UR50D'),
-}
+# from transformers import EsmTokenizer, EsmModel
 
 
-
-
-
-
-          'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
+# Load the model
+model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
+model.load_checkpoint("model/checkpoint_mf2.pth")
+model.to('cuda')
 
 
 @spaces.GPU
-def run_example(protein, model_id='facebook/esm2_t36_3B_UR50D'):
-    model_esm, alphabet = models[model_id]
-    protein_name = 'protein_name'
-    protein_seq = protein
-    include = 'per_tok'
-    repr_layers = [36]
-    truncation_seq_length = 1024
-    toks_per_batch = 4096
-    print("start")
-    dataset = FastaBatchedDataset([protein_name], [protein_seq])
-    print("dataset prepared")
-    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
-    print("batches prepared")
-
-    data_loader = torch.utils.data.DataLoader(
-        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
-    )
-    print(f"Read sequences")
-    return_contacts = "contacts" in include
-
-    assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
-    repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
+def generate_caption(protein, prompt):
 
-
-    for batch_idx, (labels, strs, toks) in enumerate(data_loader):
-        print(
-            f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
-        )
-        if torch.cuda.is_available():
-            toks = toks.to(device="cuda", non_blocking=True)
-        out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
-        representations = {
-            layer: t.to(device="cpu") for layer, t in out["representations"].items()
-        }
-        if return_contacts:
-            contacts = out["contacts"].to(device="cpu")
-        for i, label in enumerate(labels):
-            result = {"label": label}
-            truncate_len = min(truncation_seq_length, len(strs[i]))
-            # Call clone on tensors to ensure tensors are not views into a larger representation
-            # See https://github.com/pytorch/pytorch/issues/1995
-            if "per_tok" in include:
-                result["representations"] = {
-                    layer: t[i, 1: truncate_len + 1].clone()
-                    for layer, t in representations.items()
-                }
-            if "mean" in include:
-                result["mean_representations"] = {
-                    layer: t[i, 1: truncate_len + 1].mean(0).clone()
-                    for layer, t in representations.items()
-                }
-            if "bos" in include:
-                result["bos_representations"] = {
-                    layer: t[i, 0].clone() for layer, t in representations.items()
-                }
-            if return_contacts:
-                result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
-    esm_emb = result['representations'][36]
+    esm_emb = torch.load('data/emb_esm2_3b/P18281.pt')['representations'][36]
     '''
     inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True).to('cuda')
     with torch.no_grad():
@@ -93,40 +30,36 @@ def run_example(protein, model_id='facebook/esm2_t36_3B_UR50D'):
     esm_emb = outputs.last_hidden_state.detach()[0]
     '''
     print("esm embedding generated")
-    esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    button.click(run_example, [input_protein, model_selector], pt)
-
-demo.launch(debug=True)
+    esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
+    print("esm embedding processed")
+    samples = {'name': ['protein_name'],
+               'image': torch.unsqueeze(esm_emb, dim=0),
+               'text_input': ['none'],
+               'prompt': [prompt]}
+
+    # Generate the output
+    prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
+                                repetition_penalty=1.0)
+
+    return prediction
+    # return "test"
+
+
+# Define the FAPM interface
+description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the prompt to provide taxonomy information.
+
+The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
+
+iface = gr.Interface(
+    fn=generate_caption,
+    inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
+    outputs=gr.Textbox(label="Generated description"),
+    description=description
+)
+
+# Launch the interface
+iface.launch()
+
+
+
+
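After this commit, generate_caption loads a precomputed layer-36 ESM-2 embedding for one fixed protein (data/emb_esm2_3b/P18281.pt) rather than embedding the submitted sequence. Below is a minimal sketch of how that embedding could be computed on the fly, built only from the esm calls that appear in the removed run_example (pretrained.load_model_and_alphabet, FastaBatchedDataset, get_batch_indices, get_batch_converter); the helper name embed_sequence is hypothetical, not part of the commit, and it stays on CPU for simplicity where the old code moved the model to CUDA.

# Sketch only: recompute a layer-36 ESM-2 embedding for one sequence,
# mirroring the extraction loop removed in this commit.
import torch
from esm import pretrained, FastaBatchedDataset

def embed_sequence(protein_seq, truncation_seq_length=1024, toks_per_batch=4096):
    # Same checkpoint the old get_model loaded ('facebook/esm2_t36_3B_UR50D').
    model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
    model_esm.eval()
    dataset = FastaBatchedDataset(['protein_name'], [protein_seq])
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(truncation_seq_length),
        batch_sampler=batches,
    )
    with torch.no_grad():
        for labels, strs, toks in data_loader:
            out = model_esm(toks, repr_layers=[36])
            truncate_len = min(truncation_seq_length, len(strs[0]))
            # Drop the BOS position and anything past the truncation length,
            # as the removed per_tok branch did.
            return out["representations"][36][0, 1:truncate_len + 1].clone()

The result would then be padded to the fixed 1024-token width exactly as the app does: esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda').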