"""Streamlit demo: decode a slider-chosen latent vector with a VAE and plot the result.

The first few latent dimensions are exposed as sliders; the remaining ones are
held at zero. The latent vector is passed through an inverse transform pipeline
and the generated time series is plotted.
"""
import numpy as np
import streamlit as st
import torch
import matplotlib.pyplot as plt

import disvae
import transforms as trans

P_MODEL = "model/drilling_ds_btcvae"

# Total latent size passed to the decoder, and how many leading dimensions
# are user-controllable via sliders (the rest are zero-padded).
LATENT_DIM = 10
N_SLIDER_DIMS = 3


@st.cache_resource
def load_decode_function():
    """Build the latent-to-series decoding pipeline once and cache it.

    ``st.cache_resource`` keeps the returned callable across Streamlit reruns,
    so the model at ``P_MODEL`` is not reloaded on every widget interaction.

    Returns:
        A callable ``decode(latent)`` that runs the inverse pipeline on
        ``latent`` under ``torch.no_grad()`` and returns the decoded sample.
    """
    sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
    vae = disvae.load_model(P_MODEL)
    # NOTE(review): these bounds presumably mirror the scaling applied to the
    # training data — confirm against the training pipeline.
    scaler = trans.MinMaxScaler(
        _min=torch.tensor([1.3]),
        _max=torch.tensor([4.0]),
        min_norm=0.3,
        max_norm=0.6,
    )
    imaging = trans.SumField()
    # Chain of inverse transforms: latent re-ordering, VAE decoder, inverse
    # scaling, inverse imaging (assumed to be applied in the listed order).
    _dec = trans.sequential_function(
        sorter.inv, vae.decoder, scaler.inv, imaging.inv
    )

    def decode(latent):
        """Decode ``latent`` through the inverse pipeline without autograd."""
        with torch.no_grad():  # inference only; no gradient tracking needed
            return trans.np_sample(_dec)(latent)

    return decode


decode = load_decode_function()

st.markdown("**Latent Space Parameters**")
slider_values = [
    st.slider(f"Latent Dimension {i}", min_value=-3.0, max_value=3.0, value=0.0)
    for i in range(N_SLIDER_DIMS)
]
# Dimensions beyond the sliders are fixed at 0 (presumably the prior mean —
# assumption, depends on the model's prior).
latent_vector = np.concatenate(
    [np.array(slider_values), np.zeros(LATENT_DIM - N_SLIDER_DIMS)], axis=0
)
ts = decode(latent_vector)

st.markdown("**Generated Time Series**")
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(ts.ravel())
ax.set_ylim([0, 4])
st.pyplot(fig)
# Streamlit re-executes this script on every widget change; without closing,
# pyplot keeps every figure ever created alive and memory grows per rerun.
plt.close(fig)