0x90e commited on
Commit
9f0dff5
·
1 Parent(s): 6887d0a

Better cuda detection.

Browse files
Files changed (1) hide show
  1. test.py +6 -1
test.py CHANGED
@@ -20,7 +20,12 @@ output_dir = sys.argv[2]
20
  device = torch.device('cuda' if is_cuda() else 'cpu')
21
 
22
  model = arch.RRDB_Net(1, 1, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
23
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cuda' if is_cuda() else 'cpu')), strict=True)
 
 
 
 
 
24
  model.eval()
25
 
26
  for k, v in model.named_parameters():
 
20
  device = torch.device('cuda' if is_cuda() else 'cpu')
21
 
22
  model = arch.RRDB_Net(1, 1, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
23
+
24
+ if is_cuda():
25
+ model.load_state_dict(torch.load(model_path), strict=True)
26
+ else:
27
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=True)
28
+
29
  model.eval()
30
 
31
  for k, v in model.named_parameters():