Update app.py
Browse files
app.py
CHANGED
@@ -1,44 +1,41 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
import pandas as pd
|
5 |
-
import torch.nn.functional as F
|
6 |
-
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
|
7 |
-
from lavis.models.base_model import FAPMConfig
|
8 |
-
import spaces
|
9 |
import gradio as gr
|
10 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
from esm import pretrained, FastaBatchedDataset
|
12 |
|
13 |
-
# from transformers import EsmTokenizer, EsmModel
|
14 |
|
|
|
|
|
|
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
20 |
|
21 |
-
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
22 |
-
# model_esm.to('cuda')
|
23 |
-
model_esm.eval()
|
24 |
|
|
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
# model_esm.to('cuda')
|
29 |
-
# model_esm.eval()
|
30 |
|
31 |
-
@spaces.GPU
|
32 |
-
def generate_caption(protein, prompt):
|
33 |
-
# Process the image and the prompt
|
34 |
-
# with open('/home/user/app/example.fasta', 'w') as f:
|
35 |
-
# f.write('>{}\n'.format("protein_name"))
|
36 |
-
# f.write('{}\n'.format(protein.strip()))
|
37 |
-
# os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D /home/user/app/example.fasta /home/user/app --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
|
38 |
-
# esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
|
39 |
-
# model=model_esm, alphabet=alphabet,
|
40 |
-
# include='per_tok', repr_layers=[36], truncation_seq_length=1024)
|
41 |
|
|
|
|
|
|
|
42 |
protein_name = 'protein_name'
|
43 |
protein_seq = protein
|
44 |
include = 'per_tok'
|
@@ -51,8 +48,6 @@ def generate_caption(protein, prompt):
|
|
51 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
52 |
print("batches prepared")
|
53 |
|
54 |
-
model_esm.to('cuda')
|
55 |
-
|
56 |
data_loader = torch.utils.data.DataLoader(
|
57 |
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
|
58 |
)
|
@@ -70,7 +65,6 @@ def generate_caption(protein, prompt):
|
|
70 |
if torch.cuda.is_available():
|
71 |
toks = toks.to(device="cuda", non_blocking=True)
|
72 |
out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
|
73 |
-
logits = out["logits"].to(device="cpu")
|
74 |
representations = {
|
75 |
layer: t.to(device="cpu") for layer, t in out["representations"].items()
|
76 |
}
|
@@ -105,39 +99,40 @@ def generate_caption(protein, prompt):
|
|
105 |
esm_emb = outputs.last_hidden_state.detach()[0]
|
106 |
'''
|
107 |
print("esm embedding generated")
|
108 |
-
esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t()
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
#
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
3 |
+
import spaces
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import requests
|
6 |
+
import copy
|
7 |
+
import torch
|
8 |
+
from PIL import Image, ImageDraw, ImageFont
|
9 |
+
import io
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import matplotlib.patches as patches
|
12 |
+
|
13 |
+
import random
|
14 |
+
import numpy as np
|
15 |
from esm import pretrained, FastaBatchedDataset
|
16 |
|
|
|
17 |
|
18 |
+
models = {
|
19 |
+
'facebook/esm2_t36_3B_UR50D': pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D').to("cuda").eval(),
|
20 |
+
}
|
21 |
|
22 |
+
processors = {
|
23 |
+
'microsoft/Florence-2-large-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-large-ft', trust_remote_code=True),
|
24 |
+
'microsoft/Florence-2-large': AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True),
|
25 |
+
'microsoft/Florence-2-base-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True),
|
26 |
+
'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True),
|
27 |
+
}
|
28 |
|
|
|
|
|
|
|
29 |
|
30 |
+
DESCRIPTION = "Esm2 embedding"
|
31 |
|
32 |
+
colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
|
33 |
+
'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
|
|
|
|
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
@spaces.GPU
|
37 |
+
def run_example(protein, model_id='facebook/esm2_t36_3B_UR50D'):
|
38 |
+
model_esm, alphabet = models[model_id]
|
39 |
protein_name = 'protein_name'
|
40 |
protein_seq = protein
|
41 |
include = 'per_tok'
|
|
|
48 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
49 |
print("batches prepared")
|
50 |
|
|
|
|
|
51 |
data_loader = torch.utils.data.DataLoader(
|
52 |
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
|
53 |
)
|
|
|
65 |
if torch.cuda.is_available():
|
66 |
toks = toks.to(device="cuda", non_blocking=True)
|
67 |
out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
|
|
|
68 |
representations = {
|
69 |
layer: t.to(device="cpu") for layer, t in out["representations"].items()
|
70 |
}
|
|
|
99 |
esm_emb = outputs.last_hidden_state.detach()[0]
|
100 |
'''
|
101 |
print("esm embedding generated")
|
102 |
+
esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t()
|
103 |
+
torch.save(esm_emb, 'example.pt')
|
104 |
+
return gr.File.update(value="example.pt", visible=True)
|
105 |
+
|
106 |
+
css = """
|
107 |
+
#output {
|
108 |
+
height: 500px;
|
109 |
+
overflow: auto;
|
110 |
+
border: 1px solid #ccc;
|
111 |
+
}
|
112 |
+
"""
|
113 |
+
|
114 |
+
with gr.Blocks(css=css) as demo:
|
115 |
+
gr.Markdown(DESCRIPTION)
|
116 |
+
with gr.Tab(label="Esm2 embedding generation"):
|
117 |
+
with gr.Row():
|
118 |
+
with gr.Column():
|
119 |
+
input_protein = gr.Textbox(type="text", label="Upload sequence")
|
120 |
+
model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='microsoft/Florence-2-large')
|
121 |
+
submit_btn = gr.Button(value="Submit")
|
122 |
+
with gr.Column():
|
123 |
+
button = gr.Button("Export")
|
124 |
+
pt = gr.File(interactive=False, visible=False)
|
125 |
+
# gr.Examples(
|
126 |
+
# examples=[
|
127 |
+
# ["image1.jpg", 'Object Detection'],
|
128 |
+
# ],
|
129 |
+
# inputs=[input_img, task_prompt],
|
130 |
+
# outputs=[output_text, output_img],
|
131 |
+
# fn=process_image,
|
132 |
+
# cache_examples=True,
|
133 |
+
# label='Try examples'
|
134 |
+
# )
|
135 |
+
|
136 |
+
button.click(run_example, [input_protein, model_selector], pt)
|
137 |
+
|
138 |
+
demo.launch(debug=True)
|