hail75 commited on
Commit
0564946
·
1 Parent(s): 82f1d3f
Files changed (1) hide show
  1. models/SRFlow/srflow.py +5 -12
models/SRFlow/srflow.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image
7
  import matplotlib.pyplot as plt
8
  from torchvision.transforms import PILToTensor, ToPILImage
9
 
10
- def return_SRFlow_result(lr, divide=True, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml'):
11
  """
12
  Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
13
 
@@ -23,10 +23,6 @@ def return_SRFlow_result(lr, divide=True, conf_path='models/SRFlow/code/confs/SR
23
 
24
  lr = PILToTensor()(lr).permute(1, 2, 0).numpy()
25
 
26
- if not divide:
27
- lr = np.round(255 * lr).astype(np.uint8)
28
-
29
-
30
  scale = opt['scale']
31
  pad_factor = 2
32
 
@@ -45,7 +41,7 @@ def return_SRFlow_result(lr, divide=True, conf_path='models/SRFlow/code/confs/SR
45
  sr_img = Image.fromarray((sr).astype('uint8'))
46
  return sr_img
47
 
48
- def return_SRFlow_result_from_tensor(lr_tensor, divide):
49
  """
50
  Apply Super-Resolution using SRFlow model to the input batched BCHW tensor.
51
 
@@ -60,25 +56,22 @@ def return_SRFlow_result_from_tensor(lr_tensor, divide):
60
 
61
  for b in range(batch_size):
62
  lr_image = ToPILImage()(lr_tensor[b])
63
- sr_image = return_SRFlow_result(lr_image, divide)
64
  sr_tensor = PILToTensor()(sr_image).unsqueeze(0)
65
 
66
  sr_list.append(sr_tensor)
67
 
68
  sr_tensor = torch.cat(sr_list, dim=0)
69
 
70
- if not divide:
71
- sr_tensor = torch.ones_like(sr_tensor) - sr_tensor / 255
72
-
73
  return sr_tensor
74
 
75
  if __name__ == '__main__':
76
  lr = Image.open('images/demo.png')
77
 
78
- lr_tensor = PILToTensor()(lr).unsqueeze(0) / 255
79
  print(lr_tensor.shape)
80
  random_tensor = torch.randn(8, 3, 64, 64)
81
- sr = return_SRFlow_result_from_tensor(lr_tensor, divide=False)
82
  print(sr)
83
 
84
  # Show SR image of the first one in the batch
 
7
  import matplotlib.pyplot as plt
8
  from torchvision.transforms import PILToTensor, ToPILImage
9
 
10
+ def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml'):
11
  """
12
  Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
13
 
 
23
 
24
  lr = PILToTensor()(lr).permute(1, 2, 0).numpy()
25
 
 
 
 
 
26
  scale = opt['scale']
27
  pad_factor = 2
28
 
 
41
  sr_img = Image.fromarray((sr).astype('uint8'))
42
  return sr_img
43
 
44
+ def return_SRFlow_result_from_tensor(lr_tensor):
45
  """
46
  Apply Super-Resolution using SRFlow model to the input batched BCHW tensor.
47
 
 
56
 
57
  for b in range(batch_size):
58
  lr_image = ToPILImage()(lr_tensor[b])
59
+ sr_image = return_SRFlow_result(lr_image)
60
  sr_tensor = PILToTensor()(sr_image).unsqueeze(0)
61
 
62
  sr_list.append(sr_tensor)
63
 
64
  sr_tensor = torch.cat(sr_list, dim=0)
65
 
 
 
 
66
  return sr_tensor
67
 
68
  if __name__ == '__main__':
69
  lr = Image.open('images/demo.png')
70
 
71
+ lr_tensor = PILToTensor()(lr).unsqueeze(0)
72
  print(lr_tensor.shape)
73
  random_tensor = torch.randn(8, 3, 64, 64)
74
+ sr = return_SRFlow_result_from_tensor(lr_tensor)
75
  print(sr)
76
 
77
  # Show SR image of the first one in the batch