hail75 commited on
Commit
6e95c2f
·
1 Parent(s): 1ba6b94
Files changed (1) hide show
  1. models/SRFlow/srflow.py +11 -11
models/SRFlow/srflow.py CHANGED
@@ -41,9 +41,8 @@ def return_SRFlow_result(lr, divide, conf_path='models/SRFlow/code/confs/SRFlow_
41
  sr = rgb(torch.clamp(sr_t, 0, 1))
42
  sr = sr[:h * scale, :w * scale]
43
 
44
-
45
- sr = Image.fromarray((sr).astype('uint8'))
46
- return sr
47
 
48
  def return_SRFlow_result_from_tensor(lr_tensor, divide=True):
49
  """
@@ -68,8 +67,8 @@ def return_SRFlow_result_from_tensor(lr_tensor, divide=True):
68
  sr_tensor = torch.cat(sr_list, dim=0)
69
 
70
  if not divide:
71
- sr_tensor /= 255
72
-
73
  return sr_tensor
74
 
75
  if __name__ == '__main__':
@@ -77,10 +76,11 @@ if __name__ == '__main__':
77
 
78
  lr_tensor = PILToTensor()(lr).unsqueeze(0)
79
 
80
- sr = return_SRFlow_result_from_tensor(lr_tensor)
81
- print(sr.shape)
 
82
 
83
- # Show SR image of the first one in the batch
84
- plt.imshow(np.transpose(sr[0].cpu().detach().numpy(), (1, 2, 0)))
85
- # plt.axis('off')
86
- plt.show()
 
41
  sr = rgb(torch.clamp(sr_t, 0, 1))
42
  sr = sr[:h * scale, :w * scale]
43
 
44
+ sr_img = Image.fromarray((sr).astype('uint8'))
45
+ return sr_img
 
46
 
47
  def return_SRFlow_result_from_tensor(lr_tensor, divide=True):
48
  """
 
67
  sr_tensor = torch.cat(sr_list, dim=0)
68
 
69
  if not divide:
70
+ sr_tensor = sr_tensor / 255
71
+
72
  return sr_tensor
73
 
74
  if __name__ == '__main__':
 
76
 
77
  lr_tensor = PILToTensor()(lr).unsqueeze(0)
78
 
79
+ random_tensor = torch.randn(8, 3, 64, 64)
80
+ sr = return_SRFlow_result_from_tensor(random_tensor, divide=False)
81
+ print(sr)
82
 
83
+ # # Show SR image of the first one in the batch
84
+ # plt.imshow(np.transpose(sr[0].cpu().detach().numpy(), (1, 2, 0)))
85
+ # # plt.axis('off')
86
+ # plt.show()