Spaces:
Running
Running
import numpy as np | |
import streamlit as st | |
import torch | |
import disvae | |
import transforms as trans | |
def load_decode_function():
    """Build and return a ``decode(latent)`` callable for the bundled VAE.

    The model stored under ``models/btcvae_celeba`` is loaded through
    ``disvae``. A ``trans.LatentSorter`` is built from that model's KL
    dictionary, and its ``inv`` step is chained in front of the VAE decoder
    with ``trans.sequential_function``.

    Returns:
        A function taking a single latent argument. It wraps the chained
        decoder with ``trans.np_sample`` and evaluates it with autograd
        disabled, returning whatever that wrapper produces.
    """
    # NOTE(review): nothing here is cached, so each call reloads the model.
    # If this turns out to be slow under Streamlit, consider a cache
    # decorator -- confirm against the installed Streamlit version.
    model_path = "models/btcvae_celeba"

    kl_dict = disvae.get_kl_dict(model_path)
    latent_sorter = trans.LatentSorter(kl_dict)
    model = disvae.load_model(model_path)

    # Presumably `inv` undoes the sorter's latent reordering before the
    # decoder sees the vector -- verify against trans.LatentSorter.
    decoder_chain = trans.sequential_function(
        latent_sorter.inv,
        model.decoder,
    )

    def decode(latent):
        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            sample_fn = trans.np_sample(decoder_chain)
            return sample_fn(latent)

    return decode
# GUI -----------------------------------------------------------
decode = load_decode_function()

# Three user-controlled latent values, labelled L0..L2, each a float slider
# ranging from -3.0 to 3.0 and starting at 0.0.
slider_values = [
    st.slider(f"L{idx}", min_value=-3.0, max_value=3.0, value=0.0)
    for idx in range(3)
]

# Append seven zeros so the decoder receives a 10-element vector.
zero_padding = np.zeros(7)
latent_vector = np.concatenate((np.array(slider_values), zero_padding), axis=0)

value = decode(latent_vector)

# Two swaps move the leading axis to position 2: (a, b, c) -> (b, c, a).
# Presumably this converts channels-first output to the channels-last
# layout st.image expects -- confirm against the decoder's output shape.
value = np.swapaxes(value, 0, 2)
value = np.swapaxes(value, 0, 1)

st.image(value, use_column_width="always")