# Page residue from capture (not code): "Spaces: Sleeping"
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch

import disvae
import transforms as trans
P_MODEL = "model/drilling_ds_btcvae"


@st.cache_resource
def load_decode_function():
    """Build and return a function that decodes a latent vector to a time series.

    Loads the KL dictionary and the VAE checkpoint from ``P_MODEL`` and chains
    (presumably in the listed order) the inverse latent sorting, the VAE
    decoder, the inverse min-max scaling and the inverse sum-field imaging
    transform.

    Decorated with ``st.cache_resource`` because Streamlit re-executes the
    whole script on every widget interaction. Without caching, the model
    would be reloaded from disk each time a slider moves. The cached object
    is shared across reruns and sessions.

    Returns:
        ``decode(latent)``: runs the chained pipeline under ``torch.no_grad()``
        through ``trans.np_sample``. It is assumed to take and return NumPy
        arrays (inferred from the ``np_sample`` name) — TODO confirm.
    """
    sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
    vae = disvae.load_model(P_MODEL)
    # NOTE(review): the decoder output is shared across sessions; confirm that
    # load_model leaves the model in eval mode.
    # Normalisation bounds passed to the scaler; presumably they must match
    # those used at training time — TODO confirm.
    scaler = trans.MinMaxScaler(
        _min=torch.tensor([1.3]),
        _max=torch.tensor([4.0]),
        min_norm=0.3,
        max_norm=0.6,
    )
    imaging = trans.SumField()
    # Inverse transforms, from latent space back to the signal domain.
    _dec = trans.sequential_function(
        sorter.inv,
        vae.decoder,
        scaler.inv,
        imaging.inv,
    )

    def decode(latent):
        # Inference only: disable autograd tracking.
        with torch.no_grad():
            return trans.np_sample(_dec)(latent)

    return decode
# Number of latent dimensions exposed as sliders; the rest are held at 0.
N_SLIDERS = 3
# Total latent size fed to the decoder (3 slider dims + 7 zero dims).
LATENT_DIM = 10

decode = load_decode_function()

st.markdown("**Latent Space Parameters**")
slider_values = [
    st.slider(f"Latent Dimension {i}", min_value=-3.0, max_value=3.0, value=0.0)
    for i in range(N_SLIDERS)
]
# Pad the user-controlled dimensions with zeros up to the full latent size.
latent_vector = np.concatenate(
    [np.array(slider_values), np.zeros(LATENT_DIM - N_SLIDERS)], axis=0
)
ts = decode(latent_vector)

st.markdown("**Generated Time Series**")
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(ts.ravel())
# Fixed y-range so plots stay visually comparable as the sliders change.
ax.set_ylim([0, 4])
st.pyplot(fig)
# Streamlit reruns this script on every interaction; close the figure so
# pyplot's global figure registry does not keep growing.
plt.close(fig)