wenkai committed
Commit ab60ac5 · verified · 1 Parent(s): 4533bc2

Update app.py

Files changed (1)
  1. app.py +66 -53
app.py CHANGED
@@ -1,53 +1,66 @@
- import os
- import torch
- import torch.nn as nn
- import pandas as pd
- import torch.nn.functional as F
- from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
- from lavis.models.base_model import FAPMConfig
- import spaces
- import gradio as gr
-
-
- # Load the model
- model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
- model.load_checkpoint("model/checkpoint_mf2.pth")
- model.to('cuda')
-
-
- @spaces.GPU
- def generate_caption(protein, prompt):
-     # Process the image and the prompt
-     with open('data/fasta/example.fasta', 'w') as f:
-         f.write('>{}\n'.format("protein_name"))
-         f.write('{}\n'.format(protein.strip()))
-     os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D data/fasta/example.fasta data/emb_esm2_3b --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
-     esm_emb = torch.load("data/emb_esm2_3b/protein_name.pt")['representations'][36]
-     esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
-     samples = {'name': ['test_protein'],
-                'image': torch.unsqueeze(esm_emb, dim=0),
-                'text_input': ['none'],
-                'prompt': [prompt]}
-     # Generate the output
-     prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1., repetition_penalty=1.0)
-
-     return prediction
-
- # Define the FAPM interface
- description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
-
- The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
-
- iface = gr.Interface(
-     fn=generate_caption,
-     inputs=[gr.Textbox(type="pil", label="Upload sequence"), gr.Textbox(label="Prompt", value="taxonomy prompt")],
-     outputs=gr.Textbox(label="Generated description"),
-     description=description
- )
-
- # Launch the interface
- iface.launch()
-
-
-
-
+ import os
+ import torch
+ import torch.nn as nn
+ import pandas as pd
+ import torch.nn.functional as F
+ from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
+ from lavis.models.base_model import FAPMConfig
+ import spaces
+ import gradio as gr
+ from esm_scripts.extract import run_demo
+ from esm import pretrained, FastaBatchedDataset
+
+ # from transformers import EsmTokenizer, EsmModel
+
+
+ # Load the model
+ model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
+ model.load_checkpoint("model/checkpoint_mf2.pth")
+ model.to('cuda')
+
+
+ @spaces.GPU
+ def generate_caption(protein, prompt):
+
+     esm_emb = torch.load('data/emb_esm2_3b/P18281.pt')['representations'][36]
+     torch.save(esm_emb, 'data/emb_esm2_3b/example.pt')
+     '''
+     inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True).to('cuda')
+     with torch.no_grad():
+         outputs = model_esm(**inputs)
+     esm_emb = outputs.last_hidden_state.detach()[0]
+     '''
+     print("esm embedding generated")
+     esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
+     print("esm embedding processed")
+     samples = {'name': ['protein_name'],
+                'image': torch.unsqueeze(esm_emb, dim=0),
+                'text_input': ['none'],
+                'prompt': [prompt]}
+
+     # Generate the output
+     prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
+                                 repetition_penalty=1.0)
+
+     return prediction
+     # return "test"
+
+
+ # Define the FAPM interface
+ description = """Quick demonstration of the FAPM model for protein function prediction. Upload an protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
+
+ The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
+
+ iface = gr.Interface(
+     fn=generate_caption,
+     inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
+     outputs=gr.Textbox(label="Generated description"),
+     description=description
+ )
+
+ # Launch the interface
+ iface.launch()
+
+
+
+
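
Note: in this version `generate_caption` reads a precomputed layer-36 ESM-2 3B embedding from `data/emb_esm2_3b/P18281.pt` instead of extracting one for the submitted sequence. For reference, a minimal sketch of how the same per-token representation could be computed at request time with the fair-esm package (which the file imports via `from esm import pretrained, ...`) is shown below; the `embed_sequence` helper and its defaults are illustrative assumptions, not part of the repository.

```python
# Sketch only: runtime ESM-2 3B embedding with fair-esm, assuming the `esm`
# package is installed and a CUDA device is available.
import torch
import torch.nn.functional as F
from esm import pretrained

esm_model, alphabet = pretrained.esm2_t36_3B_UR50D()
esm_model.eval().to('cuda')
batch_converter = alphabet.get_batch_converter()

def embed_sequence(seq: str, max_len: int = 1024) -> torch.Tensor:
    """Return layer-36 per-residue representations, zero-padded to max_len."""
    seq = seq.strip()[:max_len]
    _, _, tokens = batch_converter([("protein_name", seq)])
    with torch.no_grad():
        out = esm_model(tokens.to('cuda'), repr_layers=[36])
    # Drop the BOS/EOS positions so the length matches the sequence.
    emb = out["representations"][36][0, 1:len(seq) + 1]
    # Pad along the sequence dimension, mirroring the F.pad call in generate_caption.
    return F.pad(emb.t(), (0, max_len - emb.size(0))).t()
```

The returned tensor has the shape that `generate_caption` expects for `samples['image']` once wrapped with `torch.unsqueeze(..., dim=0)`.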