Spaces:
Runtime error
Runtime error
Better cuda detection.
Browse files
test.py
CHANGED
@@ -20,7 +20,12 @@ output_dir = sys.argv[2]
|
|
20 |
device = torch.device('cuda' if is_cuda() else 'cpu')
|
21 |
|
22 |
model = arch.RRDB_Net(1, 1, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
24 |
model.eval()
|
25 |
|
26 |
for k, v in model.named_parameters():
|
|
|
20 |
device = torch.device('cuda' if is_cuda() else 'cpu')
|
21 |
|
22 |
model = arch.RRDB_Net(1, 1, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
|
23 |
+
|
24 |
+
if is_cuda():
|
25 |
+
model.load_state_dict(torch.load(model_path), strict=True)
|
26 |
+
else:
|
27 |
+
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=True)
|
28 |
+
|
29 |
model.eval()
|
30 |
|
31 |
for k, v in model.named_parameters():
|