GV05 commited on
Commit
093675e
·
1 Parent(s): 8dbf7f7

built space

Browse files
Files changed (6) hide show
  1. MnistVAEmodel.pt +3 -0
  2. app.py +35 -0
  3. model.py +50 -0
  4. original_5.png +0 -0
  5. original_8.png +0 -0
  6. requirements.txt +3 -0
MnistVAEmodel.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d6ab1a824858a37b3dbeffce09cd2de481906e689b4817e505cb2550e992d3d
3
+ size 4796991
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model import VariationalAutoEncoder
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+
8
+ INPUT_DIM = 784
9
+ H_DIM = 512
10
+ Z_DIM = 256
11
+
12
+ model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
13
+ model.load_state_dict(torch.load("MnistVAEmodel.pth"))
14
+ model.eval()
15
+ def predict(img):
16
+ img = img.convert('1')
17
+ img = transforms.ToTensor()(img)
18
+ img = transforms.CenterCrop(size=28)(img)
19
+ print(type(img), img.shape)
20
+ mu, sigma = model.encode(img.view(1, INPUT_DIM))
21
+
22
+ res = []
23
+ for example in range(10):
24
+ epsilon = torch.randn_like(sigma)
25
+ z = mu + sigma * epsilon
26
+ out = model.decode(z)
27
+ out = out.view(-1,1,28,28)
28
+ res.append(transforms.ToPILImage()(out[0]))
29
+ return res
30
+
31
+ title = "Variational-Autoencoder-on-MNIST "
32
+ description = "TO DO"
33
+ examples = ["original_5.png", "original_8.png"]
34
+ gr.Interface(fn=predict, inputs = gr.inputs.Image(), outputs= gr.outputs.Gallery(),
35
+ examples=examples, title=title, description=description).launch(inline=False)
model.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class VariationalAutoEncoder(nn.Module):
6
+ # Input image -> hidden dim -> mean, std -> parametirazation trick -> Decoder -> output image
7
+ def __init__(self, inpud_dim, h_dim=200, z_dim=20):
8
+ super().__init__()
9
+
10
+ # encoder
11
+ self.img_2hid = nn.Linear(inpud_dim, h_dim)
12
+ self.hid_2mu = nn.Linear(h_dim, z_dim)
13
+ self.hid_2sigma = nn.Linear(h_dim, z_dim)
14
+
15
+ # decoder
16
+ self.z_2hi = nn.Linear(z_dim, h_dim)
17
+ self.hid_2img = nn.Linear(h_dim, inpud_dim)
18
+
19
+ self.relu = nn.ReLU()
20
+
21
+ def encode(self, x):
22
+ # q_phi(z/x)
23
+ h = self.relu(self.img_2hid(x))
24
+ mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
25
+
26
+ return mu, sigma
27
+
28
+ def decode(self, z):
29
+ # p_theta(x/z)
30
+ h = self.relu(self.z_2hi(z))
31
+ x = self.hid_2img(h)
32
+ return torch.sigmoid(x) # image values should be between zero and one.
33
+
34
+ def forward(self, x):
35
+ mu, sigma = self.encode(x)
36
+ # parametirazation trick
37
+ epsilon = torch.randn_like(sigma) # Returns a tensor with the same size as input that is filled with random numbers from a normal distribution with mean 0 and variance 1
38
+ z_reparametrized = mu + sigma * epsilon
39
+ x_reconstructed = self.decode(z_reparametrized)
40
+ return x_reconstructed, mu, sigma # 2 parts of loss: 1- mu, sigma pushed to normal distribution. 2 the x_reconstructed should be same as x
41
+
42
+ if __name__ == "__main__":
43
+
44
+ x = torch.randn(4,28*28)
45
+ vae = VariationalAutoEncoder(inpud_dim=784)
46
+ x_reconstructed, mu, sigma = vae(x)
47
+ print(x_reconstructed.shape)
48
+ print(mu.shape)
49
+ print(sigma.shape)
50
+ print(torch.mean(mu))
original_5.png ADDED
original_8.png ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision