Update main.py
Browse files
main.py
CHANGED
@@ -20,9 +20,9 @@ def under_max(image):
|
|
20 |
return image
|
21 |
|
22 |
class Model(object):
|
23 |
-
def __init__(self, gpu=
|
24 |
config = Config()
|
25 |
-
config.device = 'cuda:{}'.format(gpu)
|
26 |
model, _ = caption_model.build_model(config)
|
27 |
checkpoint = torch.load('./checkpoint.pth', map_location='cpu')
|
28 |
model.load_state_dict(checkpoint['model'])
|
|
|
20 |
return image
|
21 |
|
22 |
class Model(object):
|
23 |
+
def __init__(self, gpu=None):
|
24 |
config = Config()
|
25 |
+
config.device = 'cpu' if gpu is None else 'cuda:{}'.format(gpu)
|
26 |
model, _ = caption_model.build_model(config)
|
27 |
checkpoint = torch.load('./checkpoint.pth', map_location='cpu')
|
28 |
model.load_state_dict(checkpoint['model'])
|