vincent-doan commited on
Commit
54770f1
·
1 Parent(s): 8dfd4b9

Added inference method for RCAN

Browse files
Files changed (1) hide show
  1. models/RCAN/rcan.py +10 -4
models/RCAN/rcan.py CHANGED
@@ -107,6 +107,15 @@ class RCAN(nn.Module):
107
  reconstructed_image = self.reconstruction_conv(upscaled_image) # Reconstruction
108
 
109
  return reconstructed_image
 
 
 
 
 
 
 
 
 
110
 
111
  if __name__ == '__main__':
112
  current_dir = os.path.dirname(os.path.realpath(__file__))
@@ -116,7 +125,4 @@ if __name__ == '__main__':
116
  model.eval()
117
  with torch.no_grad():
118
  input_image = Image.open('images/demo.png')
119
- input_tensor = ToTensor()(input_image).unsqueeze(0)
120
- output_tensor = model(input_tensor)
121
- print(input_tensor.shape)
122
- print(output_tensor.shape)
 
107
  reconstructed_image = self.reconstruction_conv(upscaled_image) # Reconstruction
108
 
109
  return reconstructed_image
110
+
111
+ def inference(self, x):
112
+ """
113
+ x is a PIL image
114
+ """
115
+ x = ToTensor()(x).unsqueeze(0)
116
+ x = self.forward(x)
117
+ x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
118
+ return x
119
 
120
  if __name__ == '__main__':
121
  current_dir = os.path.dirname(os.path.realpath(__file__))
 
125
  model.eval()
126
  with torch.no_grad():
127
  input_image = Image.open('images/demo.png')
128
+ output_image = model.inference(input_image)