Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -1,30 +1,28 @@
1
  import gradio as gr
2
  import torch
3
  import kornia as K
 
4
  from kornia.geometry.transform import resize
5
- import cv2
6
- import numpy as np
7
- from torchvision import transforms
8
  from torchvision.utils import make_grid
9
 
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
- def read_image(f_name):
13
- image_to_tensor = transforms.ToTensor()
14
- img = image_to_tensor(cv2.imread(f_name, cv2.IMREAD_COLOR))
15
- resized_image = resize(img,(50, 50))
16
- return resized_image
17
 
18
- def predict(images, eps):
19
- eps = float(eps)
20
- f_names = [img.name for img in images]
21
- images = [read_image(f) for f in f_names]
22
  images = torch.stack(images, dim = 0).to(device)
23
  zca = K.enhance.ZCAWhitening(eps=eps, compute_inv=True)
24
  zca.fit(images)
25
  zca_images = zca(images)
26
- grid_zca = make_grid(zca_images, nrow=3, normalize=True).cpu().numpy()
27
- return np.transpose(grid_zca,[1,2,0])
28
 
29
  title = 'ZCA Whitening with Kornia!'
30
  description = '''[ZCA Whitening](https://paperswithcode.com/method/zca-whitening) is an image preprocessing method that leads to a transformation of data such that the covariance matrix is the identity matrix, leading to decorrelated features:
@@ -49,7 +47,7 @@ iface = gr.Interface(fn=predict,
49
  'Carnation.jpg',
50
  'Orchid.jpg',
51
  'Peony.jpg'
52
- ], 0.01]]
53
  )
54
 
55
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
  import kornia as K
4
+ from kornia.core import Tensor
5
  from kornia.geometry.transform import resize
6
+
 
 
7
  from torchvision.utils import make_grid
8
 
9
+ eps: float = 0.01
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ def read_image(f_name: str) -> Tensor:
13
+ # load the image using the rust backend
14
+ img: Tensor = K.io.load_image(file.name, K.io.ImageLoadType.RGB32)
15
+ img = img[None] # 1xCxHxW / fp32 / [0, 1]
16
+ return resize(img,(50, 50))
17
 
18
+ def predict(images):
19
+ images = [read_image(f.name) for f in f_names]
 
 
20
  images = torch.stack(images, dim = 0).to(device)
21
  zca = K.enhance.ZCAWhitening(eps=eps, compute_inv=True)
22
  zca.fit(images)
23
  zca_images = zca(images)
24
+ grid_zca = make_grid(zca_images, nrow=3, normalize=True)
25
+ return K.tensor_to_image(grid_zca)
26
 
27
  title = 'ZCA Whitening with Kornia!'
28
  description = '''[ZCA Whitening](https://paperswithcode.com/method/zca-whitening) is an image preprocessing method that leads to a transformation of data such that the covariance matrix is the identity matrix, leading to decorrelated features:
 
47
  'Carnation.jpg',
48
  'Orchid.jpg',
49
  'Peony.jpg'
50
+ ]]]
51
  )
52
 
53
  if __name__ == "__main__":