wenkai commited on
Commit
d0dd902
·
verified ·
1 Parent(s): dfc4d26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -13,25 +13,35 @@ from data.evaluate_data.utils import Ontology
13
  import difflib
14
  import re
15
  from transformers import MistralForCausalLM
 
 
 
 
 
 
 
16
 
 
17
  # Load the trained model
18
  def get_model(type='Molecule Function'):
19
  model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
20
  if type == 'Molecule Function':
21
- model.load_checkpoint("model/checkpoint_mf2.pth")
 
22
  model.Qformer.bert = torch.load('model/mf2_bert.pth', map_location=torch.device('cpu'))
23
  model.to('cuda')
24
  elif type == 'Biological Process':
25
- model.load_checkpoint("model/checkpoint_bp1.pth")
 
26
  model.Qformer.bert = torch.load('model/bp1_bert.pth', map_location=torch.device('cpu'))
27
  model.to('cuda')
28
  elif type == 'Cellar Component':
29
- model.load_checkpoint("model/checkpoint_cc2.pth")
 
30
  model.Qformer.bert = torch.load('model/cc2_bert.pth', map_location=torch.device('cpu'))
31
  model.to('cuda')
32
  return model
33
 
34
-
35
  models = {
36
  'Molecule Function': get_model('Molecule Function'),
37
  'Biological Process': get_model('Biological Process'),
 
13
  import difflib
14
  import re
15
  from transformers import MistralForCausalLM
16
+ from huggingface_hub import hf_hub_download
17
+ bp_param = hf_hub_download(repo_id="wenkai/FAPM", filename="model/checkpoint_bp1.pth")
18
+ mf_param = hf_hub_download(repo_id="wenkai/FAPM", filename="model/checkpoint_mf2.pth")
19
+ cc_param = hf_hub_download(repo_id="wenkai/FAPM", filename="model/checkpoint_cc2.pth")
20
+ # hf_hub_download(repo_id="wenkai/FAPM", filename="model/mf2_bert.pth")
21
+ # hf_hub_download(repo_id="wenkai/FAPM", filename="model/bp1_bert.pth")
22
+ # hf_hub_download(repo_id="wenkai/FAPM", filename="model/cc2_bert.pth")
23
 
24
+ # bert_param = BertModel.from_pretrained("bert-base-uncased").state_dict()
25
  # Load the trained model
26
  def get_model(type='Molecule Function'):
27
  model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
28
  if type == 'Molecule Function':
29
+ # model.load_checkpoint("model/checkpoint_mf2.pth")
30
+ model.load_checkpoint(mf_param)
31
  model.Qformer.bert = torch.load('model/mf2_bert.pth', map_location=torch.device('cpu'))
32
  model.to('cuda')
33
  elif type == 'Biological Process':
34
+ # model.load_checkpoint("model/checkpoint_bp1.pth")
35
+ model.load_checkpoint(bp_param)
36
  model.Qformer.bert = torch.load('model/bp1_bert.pth', map_location=torch.device('cpu'))
37
  model.to('cuda')
38
  elif type == 'Cellar Component':
39
+ # model.load_checkpoint("model/checkpoint_cc2.pth")
40
+ model.load_checkpoint(cc_param)
41
  model.Qformer.bert = torch.load('model/cc2_bert.pth', map_location=torch.device('cpu'))
42
  model.to('cuda')
43
  return model
44
 
 
45
  models = {
46
  'Molecule Function': get_model('Molecule Function'),
47
  'Biological Process': get_model('Biological Process'),