Update app.py
Browse files
app.py
CHANGED
@@ -86,16 +86,16 @@ def generate_caption(protein, prompt):
|
|
86 |
repr_layers = [36]
|
87 |
truncation_seq_length = 1024
|
88 |
toks_per_batch = 4096
|
89 |
-
print("start")
|
90 |
dataset = FastaBatchedDataset([protein_name], [protein_seq])
|
91 |
-
print("dataset prepared")
|
92 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
93 |
-
print("batches prepared")
|
94 |
|
95 |
data_loader = torch.utils.data.DataLoader(
|
96 |
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
|
97 |
)
|
98 |
-
print(f"Read sequences")
|
99 |
return_contacts = "contacts" in include
|
100 |
|
101 |
assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
|
@@ -142,7 +142,7 @@ def generate_caption(protein, prompt):
|
|
142 |
outputs = model_esm(**inputs)
|
143 |
esm_emb = outputs.last_hidden_state.detach()[0]
|
144 |
'''
|
145 |
-
print("esm embedding generated")
|
146 |
esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
|
147 |
if prompt is None:
|
148 |
prompt = 'none'
|
@@ -182,13 +182,15 @@ def generate_caption(protein, prompt):
|
|
182 |
return res_str
|
183 |
res_str = ''
|
184 |
if len(union_pred_terms[0]) != 0:
|
185 |
-
|
|
|
186 |
if len(union_pred_terms[1]) != 0:
|
187 |
-
|
|
|
188 |
if len(union_pred_terms[2]) != 0:
|
189 |
-
|
|
|
190 |
return res_str
|
191 |
-
# return "test"
|
192 |
|
193 |
|
194 |
# Define the FAPM interface
|
@@ -226,7 +228,9 @@ with gr.Blocks(css=css) as demo:
|
|
226 |
prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
|
227 |
submit_btn = gr.Button(value="Submit")
|
228 |
with gr.Column():
|
229 |
-
output_text = gr.Textbox(label="Output Text")
|
|
|
|
|
230 |
# O14813 train index 127, 266, 738, 1060 test index 4
|
231 |
gr.Examples(
|
232 |
examples=[
|
@@ -238,12 +242,12 @@ with gr.Blocks(css=css) as demo:
|
|
238 |
['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
|
239 |
],
|
240 |
inputs=[input_protein, prompt],
|
241 |
-
outputs=[
|
242 |
fn=generate_caption,
|
243 |
cache_examples=True,
|
244 |
label='Try examples'
|
245 |
)
|
246 |
-
submit_btn.click(generate_caption, [input_protein, prompt], [
|
247 |
|
248 |
demo.launch(debug=True)
|
249 |
|
|
|
86 |
repr_layers = [36]
|
87 |
truncation_seq_length = 1024
|
88 |
toks_per_batch = 4096
|
89 |
+
# print("start")
|
90 |
dataset = FastaBatchedDataset([protein_name], [protein_seq])
|
91 |
+
# print("dataset prepared")
|
92 |
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
93 |
+
# print("batches prepared")
|
94 |
|
95 |
data_loader = torch.utils.data.DataLoader(
|
96 |
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
|
97 |
)
|
98 |
+
# print(f"Read sequences")
|
99 |
return_contacts = "contacts" in include
|
100 |
|
101 |
assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
|
|
|
142 |
outputs = model_esm(**inputs)
|
143 |
esm_emb = outputs.last_hidden_state.detach()[0]
|
144 |
'''
|
145 |
+
# print("esm embedding generated")
|
146 |
esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
|
147 |
if prompt is None:
|
148 |
prompt = 'none'
|
|
|
182 |
return res_str
|
183 |
res_str = ''
|
184 |
if len(union_pred_terms[0]) != 0:
|
185 |
+
temp = ['- '+i+'\n' for i in union_pred_terms[0]]
|
186 |
+
res_str += f"Based on the given amino acid sequence, the protein appears to have a primary function of \n{temp} \n"
|
187 |
if len(union_pred_terms[1]) != 0:
|
188 |
+
temp = ['- ' + i + '\n' for i in union_pred_terms[1]]
|
189 |
+
res_str += f"It is likely involved in the following process: \n{temp} \n"
|
190 |
if len(union_pred_terms[2]) != 0:
|
191 |
+
temp = ['- ' + i + '\n' for i in union_pred_terms[2]]
|
192 |
+
res_str += f"It's subcellular localization is within the: \n{temp}"
|
193 |
return res_str
|
|
|
194 |
|
195 |
|
196 |
# Define the FAPM interface
|
|
|
228 |
prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
|
229 |
submit_btn = gr.Button(value="Submit")
|
230 |
with gr.Column():
|
231 |
+
# output_text = gr.Textbox(label="Output Text")
|
232 |
+
with gr.Accordion('Answer:', open=True):
|
233 |
+
output_markdown = gr.Markdown(label="Output")
|
234 |
# O14813 train index 127, 266, 738, 1060 test index 4
|
235 |
gr.Examples(
|
236 |
examples=[
|
|
|
242 |
['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
|
243 |
],
|
244 |
inputs=[input_protein, prompt],
|
245 |
+
outputs=[output_markdown],
|
246 |
fn=generate_caption,
|
247 |
cache_examples=True,
|
248 |
label='Try examples'
|
249 |
)
|
250 |
+
submit_btn.click(generate_caption, [input_protein, prompt], [output_markdown])
|
251 |
|
252 |
demo.launch(debug=True)
|
253 |
|