Update app.py
Browse files
app.py
CHANGED
@@ -32,12 +32,12 @@ def get_model(type='Molecule Function'):
|
|
32 |
models = {
|
33 |
'Molecule Function': get_model('Molecule Function'),
|
34 |
'Biological Process': get_model('Biological Process'),
|
35 |
-
'
|
36 |
}
|
37 |
|
38 |
|
39 |
# Load the mistral model
|
40 |
-
mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16)
|
41 |
|
42 |
# Load ESM2 model
|
43 |
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
@@ -54,14 +54,23 @@ go_des['text'] = go_des['text'].apply(lambda x: x.lower())
|
|
54 |
GO_dict = dict(zip(go_des['text'], go_des['id']))
|
55 |
Func_dict = dict(zip(go_des['id'], go_des['text']))
|
56 |
|
57 |
-
# terms_mf = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
|
58 |
terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
|
59 |
choices_mf = [Func_dict[i] for i in list(set(terms_mf['gos']))]
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
@spaces.GPU
|
64 |
-
def generate_caption(
|
65 |
# Process the protein sequence and the prompt
|
66 |
# with open('/home/user/app/example.fasta', 'w') as f:
|
67 |
# f.write('>{}\n'.format("protein_name"))
|
@@ -144,36 +153,40 @@ def generate_caption(model_id, protein, prompt):
|
|
144 |
'text_input': ['none'],
|
145 |
'prompt': [prompt]}
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
|
|
|
|
|
|
165 |
if prompt == 'none':
|
166 |
res_str = "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
|
167 |
else:
|
168 |
res_str = "No available predictions for this protein, you can use other two types of model or try another sequence!"
|
169 |
-
if len(
|
170 |
return res_str
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
|
|
177 |
return res_str
|
178 |
# return "test"
|
179 |
|
@@ -205,7 +218,6 @@ with gr.Blocks(css=css) as demo:
|
|
205 |
with gr.Tab(label="Protein caption"):
|
206 |
with gr.Row():
|
207 |
with gr.Column():
|
208 |
-
model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='Molecule Function')
|
209 |
input_protein = gr.Textbox(type="text", label="Upload sequence")
|
210 |
prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
|
211 |
submit_btn = gr.Button(value="Submit")
|
@@ -214,20 +226,20 @@ with gr.Blocks(css=css) as demo:
|
|
214 |
# O14813 train index 127, 266, 738, 1060 test index 4
|
215 |
gr.Examples(
|
216 |
examples=[
|
217 |
-
["
|
218 |
-
["
|
219 |
-
["
|
220 |
-
[
|
221 |
-
[
|
222 |
-
[
|
223 |
],
|
224 |
-
inputs=[
|
225 |
outputs=[output_text],
|
226 |
fn=generate_caption,
|
227 |
cache_examples=True,
|
228 |
label='Try examples'
|
229 |
)
|
230 |
-
submit_btn.click(generate_caption, [
|
231 |
|
232 |
demo.launch(debug=True)
|
233 |
|
|
|
32 |
models = {
|
33 |
'Molecule Function': get_model('Molecule Function'),
|
34 |
'Biological Process': get_model('Biological Process'),
|
35 |
+
'Cellular Component': get_model('Cellar Component'),
|
36 |
}
|
37 |
|
38 |
|
39 |
# Load the mistral model
|
40 |
+
mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16).to('cuda')
|
41 |
|
42 |
# Load ESM2 model
|
43 |
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
|
|
54 |
GO_dict = dict(zip(go_des['text'], go_des['id']))
|
55 |
Func_dict = dict(zip(go_des['id'], go_des['text']))
|
56 |
|
|
|
57 |
terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
|
58 |
choices_mf = [Func_dict[i] for i in list(set(terms_mf['gos']))]
|
59 |
+
choices_mf = {x.lower(): x for x in choices_mf}
|
60 |
+
terms_bp = pd.read_pickle('data/terms/bp_terms.pkl')
|
61 |
+
choices_bp = [Func_dict[i] for i in list(set(terms_bp['gos']))]
|
62 |
+
choices_bp = {x.lower(): x for x in choices_bp}
|
63 |
+
terms_cc = pd.read_pickle('data/terms/cc_terms.pkl')
|
64 |
+
choices_cc = [Func_dict[i] for i in list(set(terms_cc['gos']))]
|
65 |
+
choices_cc = {x.lower(): x for x in choices_cc}
|
66 |
+
choices = {
|
67 |
+
'Molecule Function': choices_mf,
|
68 |
+
'Biological Process': choices_bp,
|
69 |
+
'Cellular Component': choices_cc,
|
70 |
+
}
|
71 |
|
72 |
@spaces.GPU
|
73 |
+
def generate_caption(protein, prompt):
|
74 |
# Process the protein sequence and the prompt
|
75 |
# with open('/home/user/app/example.fasta', 'w') as f:
|
76 |
# f.write('>{}\n'.format("protein_name"))
|
|
|
153 |
'text_input': ['none'],
|
154 |
'prompt': [prompt]}
|
155 |
|
156 |
+
union_pred_terms = []
|
157 |
+
for model_id in models.keys():
|
158 |
+
model = models[model_id]
|
159 |
+
# Generate the output
|
160 |
+
prediction = model.generate(mistral_model, samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
|
161 |
+
repetition_penalty=1.0)
|
162 |
+
x = prediction[0]
|
163 |
+
x = [eval(i) for i in x.split('; ')]
|
164 |
+
pred_terms = []
|
165 |
+
temp = []
|
166 |
+
for i in x:
|
167 |
+
txt = i[0]
|
168 |
+
prob = i[1]
|
169 |
+
sim_list = difflib.get_close_matches(txt.lower(), choices[model_id], n=1, cutoff=0.9)
|
170 |
+
if len(sim_list) > 0:
|
171 |
+
t_standard = sim_list[0]
|
172 |
+
if t_standard not in temp:
|
173 |
+
pred_terms.append(t_standard+f'({prob})')
|
174 |
+
temp.append(t_standard)
|
175 |
+
union_pred_terms.append(pred_terms)
|
176 |
+
|
177 |
if prompt == 'none':
|
178 |
res_str = "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
|
179 |
else:
|
180 |
res_str = "No available predictions for this protein, you can use other two types of model or try another sequence!"
|
181 |
+
if len(union_pred_terms[0]) == 0 and len(union_pred_terms[1]) == 0 and len(union_pred_terms[2]) == 0:
|
182 |
return res_str
|
183 |
+
res_str = ''
|
184 |
+
if len(union_pred_terms[0]) != 0:
|
185 |
+
res_str += f"Based on the given amino acid sequence, the protein appears to have a primary function of {', '.join(pred_terms)}. "
|
186 |
+
if len(union_pred_terms[1]) != 0:
|
187 |
+
res_str += f"It is likely involved in the {', '.join(pred_terms)}. "
|
188 |
+
if len(union_pred_terms[2]) != 0:
|
189 |
+
res_str += f"It's subcellular localization is within the {', '.join(pred_terms)}."
|
190 |
return res_str
|
191 |
# return "test"
|
192 |
|
|
|
218 |
with gr.Tab(label="Protein caption"):
|
219 |
with gr.Row():
|
220 |
with gr.Column():
|
|
|
221 |
input_protein = gr.Textbox(type="text", label="Upload sequence")
|
222 |
prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
|
223 |
submit_btn = gr.Button(value="Submit")
|
|
|
226 |
# O14813 train index 127, 266, 738, 1060 test index 4
|
227 |
gr.Examples(
|
228 |
examples=[
|
229 |
+
["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
|
230 |
+
["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
|
231 |
+
["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
|
232 |
+
['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
|
233 |
+
['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
|
234 |
+
['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
|
235 |
],
|
236 |
+
inputs=[input_protein, prompt],
|
237 |
outputs=[output_text],
|
238 |
fn=generate_caption,
|
239 |
cache_examples=True,
|
240 |
label='Try examples'
|
241 |
)
|
242 |
+
submit_btn.click(generate_caption, [input_protein, prompt], [output_text])
|
243 |
|
244 |
demo.launch(debug=True)
|
245 |
|