Kimata commited on
Commit
6f27114
·
1 Parent(s): ebb8b5c

change paths delimiters

Browse files
Files changed (1) hide show
  1. inference_2.py +3 -3
inference_2.py CHANGED
@@ -10,7 +10,7 @@ from models import image
10
 
11
  from onnx2pytorch import ConvertModel
12
 
13
- onnx_model = onnx.load('checkpoints\\efficientnet.onnx')
14
  pytorch_model = ConvertModel(onnx_model)
15
 
16
  #Set random seed for reproducibility.
@@ -75,7 +75,7 @@ def model_summary(args):
75
  def load_multimodal_model(args):
76
  '''Load multimodal model'''
77
  model = ETMC(args)
78
- ckpt = torch.load('checkpoints\\model.pth', map_location = torch.device('cpu'))
79
  model.load_state_dict(ckpt, strict = True)
80
  model.eval()
81
  return model
@@ -84,7 +84,7 @@ def load_img_modality_model(args):
84
  '''Loads image modality model.'''
85
  rgb_encoder = pytorch_model
86
 
87
- ckpt = torch.load('checkpoints\\model.pth', map_location = torch.device('cpu'))
88
  rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict = True)
89
  rgb_encoder.eval()
90
  return rgb_encoder
 
10
 
11
  from onnx2pytorch import ConvertModel
12
 
13
+ onnx_model = onnx.load('checkpoints/efficientnet.onnx')
14
  pytorch_model = ConvertModel(onnx_model)
15
 
16
  #Set random seed for reproducibility.
 
75
  def load_multimodal_model(args):
76
  '''Load multimodal model'''
77
  model = ETMC(args)
78
+ ckpt = torch.load('checkpoints/model.pth', map_location = torch.device('cpu'))
79
  model.load_state_dict(ckpt, strict = True)
80
  model.eval()
81
  return model
 
84
  '''Loads image modality model.'''
85
  rgb_encoder = pytorch_model
86
 
87
+ ckpt = torch.load('checkpoints/model.pth', map_location = torch.device('cpu'))
88
  rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict = True)
89
  rgb_encoder.eval()
90
  return rgb_encoder