Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,9 @@ import spaces
|
|
9 |
import gradio as gr
|
10 |
from esm_scripts.extract import run_demo
|
11 |
from esm import pretrained, FastaBatchedDataset
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
# Load the model
|
@@ -20,6 +23,21 @@ model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
|
20 |
model_esm.to('cuda')
|
21 |
model_esm.eval()
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
@spaces.GPU
|
25 |
def generate_caption(protein, prompt):
|
@@ -106,7 +124,16 @@ def generate_caption(protein, prompt):
|
|
106 |
prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
|
107 |
repetition_penalty=1.0)
|
108 |
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
# return "test"
|
111 |
|
112 |
|
|
|
9 |
import gradio as gr
|
10 |
from esm_scripts.extract import run_demo
|
11 |
from esm import pretrained, FastaBatchedDataset
|
12 |
+
from data.evaluate_data.utils import Ontology
|
13 |
+
import difflib
|
14 |
+
import re
|
15 |
|
16 |
|
17 |
# Load the model
|
|
|
23 |
model_esm.to('cuda')
|
24 |
model_esm.eval()
|
25 |
|
26 |
+
godb = Ontology(f'data/go1.4-basic.obo', with_rels=True)
|
27 |
+
go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
|
28 |
+
go_des.columns = ['id', 'text']
|
29 |
+
go_des = go_des.dropna()
|
30 |
+
go_des['id'] = go_des['id'].apply(lambda x: re.sub('_', ':', x))
|
31 |
+
go_obo_set = set(go_des['id'].tolist())
|
32 |
+
go_des['text'] = go_des['text'].apply(lambda x: x.lower())
|
33 |
+
GO_dict = dict(zip(go_des['text'], go_des['id']))
|
34 |
+
Func_dict = dict(zip(go_des['id'], go_des['text']))
|
35 |
+
|
36 |
+
# terms_mf = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
|
37 |
+
terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
|
38 |
+
choices_mf = [Func_dict[i] for i in list(set(terms_mf['gos']))]
|
39 |
+
choices = {x.lower(): x for x in choices_mf}
|
40 |
+
|
41 |
|
42 |
@spaces.GPU
|
43 |
def generate_caption(protein, prompt):
|
|
|
124 |
prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
|
125 |
repetition_penalty=1.0)
|
126 |
|
127 |
+
x = prediction[0]
|
128 |
+
x = [eval(i) for i in x.split('; ')]
|
129 |
+
pred_terms = []
|
130 |
+
for i in x:
|
131 |
+
txt = i[0]
|
132 |
+
prob = i[1]
|
133 |
+
sim_list = difflib.get_close_matches(txt.lower(), choices, n=1, cutoff=0.9)
|
134 |
+
if len(sim_list) > 0:
|
135 |
+
pred_terms.append((sim_list[0], prob))
|
136 |
+
return str(pred_terms)
|
137 |
# return "test"
|
138 |
|
139 |
|