Jonas Becker commited on
Commit
af94f43
·
1 Parent(s): c53ddec

Added cache function

Browse files
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -5,34 +5,32 @@ import torch
5
  import disvae
6
  import transforms as trans
7
 
8
- P_MODEL = "models/btcvae_celeba"
 
 
9
 
10
- # Decode Funktion --------------------------------------------------
11
- sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
12
- vae = disvae.load_model(P_MODEL)
13
- scaler = trans.MinMaxScaler(_min=torch.tensor([1.3]),_max=torch.tensor([4.0]),min_norm=0.3,max_norm=0.6)
14
- imaging = trans.SumField()
15
 
16
- _dec = trans.sequential_function(
17
- sorter.inv,
18
- vae.decoder
19
- )
20
 
21
- def decode(latent):
22
- with torch.no_grad():
23
- return trans.np_sample(_dec)(latent)
 
24
 
25
  # GUI -----------------------------------------------------------
26
 
 
 
27
  latent_vector = np.array([st.slider(f"L{l}",min_value=-3.0,max_value=3.0,value=0.0) for l in range(3)])
28
  latent_vector = np.concatenate([latent_vector,np.zeros(7)],axis=0)
29
 
30
  value = decode(latent_vector)
31
 
32
- value = np.swapaxes(np.swapaxes(value, 0, 2), 0, 1)# * 255
33
 
34
- # st.write(value)
35
  st.image(value, use_column_width="always")
36
-
37
- # x = st.slider("Select a value")
38
- # st.write(x, "squared is", x * x)
 
5
  import disvae
6
  import transforms as trans
7
 
8
+ @st.cache_resource
9
+ def load_decode_function():
10
+ P_MODEL = "models/btcvae_celeba"
11
 
12
+ sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
13
+ vae = disvae.load_model(P_MODEL)
 
 
 
14
 
15
+ _dec = trans.sequential_function(
16
+ sorter.inv,
17
+ vae.decoder
18
+ )
19
 
20
+ def decode(latent):
21
+ with torch.no_grad():
22
+ return trans.np_sample(_dec)(latent)
23
+ return decode
24
 
25
  # GUI -----------------------------------------------------------
26
 
27
+ decode = load_decode_function()
28
+
29
  latent_vector = np.array([st.slider(f"L{l}",min_value=-3.0,max_value=3.0,value=0.0) for l in range(3)])
30
  latent_vector = np.concatenate([latent_vector,np.zeros(7)],axis=0)
31
 
32
  value = decode(latent_vector)
33
 
34
+ value = np.swapaxes(np.swapaxes(value, 0, 2), 0, 1)
35
 
 
36
  st.image(value, use_column_width="always")