Update app.py
Browse files
app.py
CHANGED
@@ -13,25 +13,35 @@ from data.evaluate_data.utils import Ontology
|
|
13 |
import difflib
|
14 |
import re
|
15 |
from transformers import MistralForCausalLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
|
|
17 |
# Load the trained model
|
18 |
def get_model(type='Molecule Function'):
|
19 |
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
20 |
if type == 'Molecule Function':
|
21 |
-
model.load_checkpoint("model/checkpoint_mf2.pth")
|
|
|
22 |
model.Qformer.bert = torch.load('model/mf2_bert.pth', map_location=torch.device('cpu'))
|
23 |
model.to('cuda')
|
24 |
elif type == 'Biological Process':
|
25 |
-
model.load_checkpoint("model/checkpoint_bp1.pth")
|
|
|
26 |
model.Qformer.bert = torch.load('model/bp1_bert.pth', map_location=torch.device('cpu'))
|
27 |
model.to('cuda')
|
28 |
elif type == 'Cellar Component':
|
29 |
-
model.load_checkpoint("model/checkpoint_cc2.pth")
|
|
|
30 |
model.Qformer.bert = torch.load('model/cc2_bert.pth', map_location=torch.device('cpu'))
|
31 |
model.to('cuda')
|
32 |
return model
|
33 |
|
34 |
-
|
35 |
models = {
|
36 |
'Molecule Function': get_model('Molecule Function'),
|
37 |
'Biological Process': get_model('Biological Process'),
|
|
|
13 |
import difflib
|
14 |
import re
|
15 |
from transformers import MistralForCausalLM
|
16 |
+
from huggingface_hub import hf_hub_download
|
17 |
+
bp_param = hf_hub_download(repo_id="wenkai/FAPM", filename="model/checkpoint_bp1.pth")
|
18 |
+
mf_param = hf_hub_download(repo_id="wenkai/FAPM", filename="model/checkpoint_mf2.pth")
|
19 |
+
cc_param = hf_hub_download(repo_id="wenkai/FAPM", filename="model/checkpoint_cc2.pth")
|
20 |
+
# hf_hub_download(repo_id="wenkai/FAPM", filename="model/mf2_bert.pth")
|
21 |
+
# hf_hub_download(repo_id="wenkai/FAPM", filename="model/bp1_bert.pth")
|
22 |
+
# hf_hub_download(repo_id="wenkai/FAPM", filename="model/cc2_bert.pth")
|
23 |
|
24 |
+
# bert_param = BertModel.from_pretrained("bert-base-uncased").state_dict()
|
25 |
# Load the trained model
|
26 |
def get_model(type='Molecule Function'):
|
27 |
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
28 |
if type == 'Molecule Function':
|
29 |
+
# model.load_checkpoint("model/checkpoint_mf2.pth")
|
30 |
+
model.load_checkpoint(mf_param)
|
31 |
model.Qformer.bert = torch.load('model/mf2_bert.pth', map_location=torch.device('cpu'))
|
32 |
model.to('cuda')
|
33 |
elif type == 'Biological Process':
|
34 |
+
# model.load_checkpoint("model/checkpoint_bp1.pth")
|
35 |
+
model.load_checkpoint(bp_param)
|
36 |
model.Qformer.bert = torch.load('model/bp1_bert.pth', map_location=torch.device('cpu'))
|
37 |
model.to('cuda')
|
38 |
elif type == 'Cellar Component':
|
39 |
+
# model.load_checkpoint("model/checkpoint_cc2.pth")
|
40 |
+
model.load_checkpoint(cc_param)
|
41 |
model.Qformer.bert = torch.load('model/cc2_bert.pth', map_location=torch.device('cpu'))
|
42 |
model.to('cuda')
|
43 |
return model
|
44 |
|
|
|
45 |
models = {
|
46 |
'Molecule Function': get_model('Molecule Function'),
|
47 |
'Biological Process': get_model('Biological Process'),
|