Update app.py
Browse files
app.py
CHANGED
@@ -19,6 +19,7 @@ def get_model(type='Molecule Function'):
|
|
19 |
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
20 |
if type == 'Molecule Function':
|
21 |
model.load_checkpoint("model/checkpoint_mf2.pth")
|
|
|
22 |
model.to('cuda')
|
23 |
elif type == 'Biological Process':
|
24 |
model.load_checkpoint("model/checkpoint_bp1.pth")
|
@@ -26,6 +27,7 @@ def get_model(type='Molecule Function'):
|
|
26 |
model.to('cuda')
|
27 |
elif type == 'Cellar Component':
|
28 |
model.load_checkpoint("model/checkpoint_cc2.pth")
|
|
|
29 |
model.to('cuda')
|
30 |
return model
|
31 |
|
|
|
19 |
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
20 |
if type == 'Molecule Function':
|
21 |
model.load_checkpoint("model/checkpoint_mf2.pth")
|
22 |
+
model.Qformer.bert = torch.load('model/mf2_bert.pth', map_location=torch.device('cpu'))
|
23 |
model.to('cuda')
|
24 |
elif type == 'Biological Process':
|
25 |
model.load_checkpoint("model/checkpoint_bp1.pth")
|
|
|
27 |
model.to('cuda')
|
28 |
elif type == 'Cellar Component':
|
29 |
model.load_checkpoint("model/checkpoint_cc2.pth")
|
30 |
+
model.Qformer.bert = torch.load('model/cc2_bert.pth', map_location=torch.device('cpu'))
|
31 |
model.to('cuda')
|
32 |
return model
|
33 |
|