# CaxtonEmeraldS's picture
# Update app.py
# 61cc1ca verified
import streamlit as st
from tensorflow import keras
import os
import matplotlib.pyplot as plt
from io import BytesIO
from NNVisualiser import NNVisualiser
import glob
import inspect
from tensorflow.keras.models import save_model
import tempfile
import re
import zipfile
import io
# Function to create a ZIP file of all PNG files
def create_zip_of_png_files(directory=None):
    """Bundle every ``*.png`` file in a directory into an in-memory ZIP archive.

    Args:
        directory: Folder to collect PNG files from. Defaults to the current
            working directory, which is where this app's plots are written.

    Returns:
        An ``io.BytesIO`` rewound to offset 0 that holds the archive. Each
        image is stored under its bare file name, and entries are added in
        sorted order so the archive layout is reproducible.
    """
    if directory is None:
        directory = os.getcwd()
    # Only regular files: a sub-directory whose name ends in ".png" must not
    # end up in the archive.
    png_files = sorted(
        name for name in os.listdir(directory)
        if name.endswith('.png') and os.path.isfile(os.path.join(directory, name))
    )
    # Build the archive in memory so nothing extra is written to disk.
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
        for png_file in png_files:
            zip_file.write(os.path.join(directory, png_file), arcname=png_file)
    zip_buffer.seek(0)  # Rewind so consumers read from the beginning.
    return zip_buffer
def generate_title_from_method_name(method_name):
    """Turn a camelCase ``plot*`` method name into a readable page title.

    A leading ``plot`` prefix (if present) is dropped. The rest is split into
    words, each starting at a capital letter, and wrapped as
    ``"Plotting <words> Plot "`` (note the trailing space). Only ASCII letters
    that belong to a capitalised word are kept, so digits and any lowercase
    text before the first capital are discarded.

    Example: ``"plotFlowForNetwork"`` -> ``"Plotting Flow For Network Plot "``.
    """
    prefix = "plot"
    stem = method_name[len(prefix):] if method_name.startswith(prefix) else method_name
    # Every capital letter opens a new word; the lowercase letters after it join it.
    words = re.findall(r'[A-Z][a-z]*', stem)
    readable = " ".join(words)
    return "Plotting " + readable + " Plot "
def downloadKerasModel(keras_model=None):
    """Serialise a Keras model to the native ``.keras`` format and return its bytes.

    Args:
        keras_model: Model to serialise. Defaults to the module-level ``model``
            that the app loads from the sidebar selection.

    Returns:
        The raw bytes of the saved ``.keras`` file, suitable for
        ``st.download_button``.
    """
    if keras_model is None:
        keras_model = model
    # A temporary *directory* is removed automatically, together with the saved
    # file. The previous NamedTemporaryFile(delete=False) left one file behind
    # on disk for every call.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.keras")
        save_model(keras_model, model_path)
        # Re-open the path rather than reading through a handle that was opened
        # before Keras wrote the file.
        with open(model_path, "rb") as model_file:
            return model_file.read()
# Function to build the folder hierarchy up to max_depth levels below the root
# (excluding files and hidden folders)
# @st.cache_data
def generate_folder_hierarchy(root_folder, max_depth=7):
    """Return the sub-directory tree of *root_folder* as nested dicts.

    Each key is a directory name and each value is a dict of that directory's
    own sub-directories (an empty dict for leaves). Files are ignored. Also
    ignored, and never descended into, are directories whose name starts with
    ``.`` (e.g. ``.git``) and directories named exactly ``"1"``. The ``"1"``
    folders are presumably model version folders; confirm against the
    repository layout.

    Args:
        root_folder: Directory to scan; it does not appear in the result itself.
        max_depth: Deepest level to include, counting the root's immediate
            children as level 1. Nothing deeper is included or even walked.

    Returns:
        A nested ``dict[str, dict]``. Sibling keys are inserted in sorted order
        so the dropdowns built from them are stable across platforms.
    """
    folder_dict = {}
    for dirpath, dirnames, _filenames in os.walk(root_folder):
        rel_path = os.path.relpath(dirpath, root_folder)
        # The root itself has relpath "." and contributes no path components.
        parts = [] if rel_path == os.curdir else rel_path.split(os.sep)
        depth = len(parts)
        if depth >= max_depth:
            # Any child would be deeper than max_depth, so stop os.walk from
            # descending instead of walking a subtree only to discard it.
            dirnames[:] = []
        else:
            # In-place assignment prunes what os.walk visits next; sorting
            # fixes the insertion order of sibling keys.
            dirnames[:] = sorted(
                d for d in dirnames if not d.startswith('.') and d != '1'
            )
        # Create (or reuse) the nested dict for each path component.
        node = folder_dict
        for part in parts:
            node = node.setdefault(part, {})
    return folder_dict
@st.cache_data
def getPlotMethods():
    """List the names of NNVisualiser's functions whose names start with ``plot``.

    Names come back in the alphabetical order produced by
    ``inspect.getmembers``. ``st.cache_data`` keeps the result across reruns.
    """
    plot_names = []
    for member_name, _member in inspect.getmembers(NNVisualiser, inspect.isfunction):
        if member_name.startswith('plot'):
            plot_names.append(member_name)
    return plot_names
# Build the model-selection tree from the folders under root_folder.
root_folder = os.getcwd()  # Replace with your folder path
folder_hierarchy = generate_folder_hierarchy(root_folder)

# Streamlit app
st.title("Repository : Simple ANN Models with UAT Architecture")
st.write("A Collection of ANN Models with a 1-xReLU-1 Architecture for Basic 1D Functions on Bounded Intervals")

# Sidebar cascade: each dropdown lists the sub-folders of the folder picked in
# the previous one (repository -> ... -> neurons count). st.selectbox returns
# None for an empty option list, which would otherwise crash the next lookup
# with a KeyError, so stop with an explanation instead.
selection_labels = (
    "Select Model Repository",
    "Select Initialisation",
    "Select Sample Size",
    "Select Batch Size",
    "Select Epochs Count",
    "Select Function",
    "Select Neurons Count",
)
selected_parts = []
level = folder_hierarchy
for label in selection_labels:
    choice = st.sidebar.selectbox(label, list(level.keys()))
    if choice is None:
        st.warning(
            f"No sub-folders found for '{label}' under "
            f"{os.path.join(root_folder, *selected_parts)}."
        )
        st.stop()
    selected_parts.append(choice)
    level = level[choice]
repo, initialisation, sampleSize, batchSize, epochs, functions, neurons = selected_parts

# Display the selected values
st.write(f"You selected: {repo} : {initialisation} : {sampleSize} : {batchSize} : {epochs} : {functions} : {neurons}")

# Resolve the model relative to the same root the hierarchy was built from, so
# that changing root_folder keeps the two consistent.
modelPath = os.path.join(root_folder, *selected_parts)
model = keras.models.load_model(modelPath)
visualiser = NNVisualiser(model)
visualiser.setSavePlots(True)
# Function to get layer and neuron information
def get_layer_info(model):
    """Describe each layer of *model* in order.

    Returns a list with one dict per entry of ``model.layers`` containing:
        'index': the layer's position in ``model.layers`` (0-based),
        'type':  the layer's class name,
        'units': the layer's ``units`` attribute (neuron count), or ``None``
                 when the layer does not define one.
    """
    return [
        {
            'index': position,
            'type': layer.__class__.__name__,
            'units': getattr(layer, 'units', None),
        }
        for position, layer in enumerate(model.layers)
    ]
layer_info = get_layer_info(model)
# Extract layer indices and neuron counts (neuron count is None for layers
# without a `units` attribute).
layer_indices = [layer['index'] for layer in layer_info]
neuron_counts = [layer['units'] for layer in layer_info]

# Dropdown for selecting plots from NNVisualiser
plotMethods = getPlotMethods()
selectedPlotMethod = st.sidebar.selectbox("Select Plot", plotMethods)

# Remove PNGs left over from the previous run so that only the figures produced
# below are displayed and offered for download.
for image_file in glob.glob("*.png"):
    try:
        os.remove(image_file)
    except OSError as exc:
        st.write(f"Error in removing previous plots ({image_file}: {exc})")

st.session_state.title_text = generate_title_from_method_name(selectedPlotMethod)
st.title(st.session_state.title_text)

# Call the selected NNVisualiser plot method. It draws directly rather than
# returning a figure; with setSavePlots(True) its output is presumably saved
# as PNG files in the working directory, which is what the app zips and
# displays.
method = getattr(visualiser, selectedPlotMethod, None)
if method is not None:
    if 'Neuron' in selectedPlotMethod:
        # Only layers with a positive `units` count have addressable neurons.
        # Offering the others would make range(None) raise a TypeError.
        neuron_layer_indices = [
            index for index, units in zip(layer_indices, neuron_counts) if units
        ]
        if neuron_layer_indices:
            selected_layer_index = st.sidebar.selectbox("Select Layer Index", neuron_layer_indices)
            neuron_indices = list(range(neuron_counts[selected_layer_index]))
            selected_neuron_index = st.sidebar.selectbox("Select Neuron Index", neuron_indices)
            method(selected_layer_index, selected_neuron_index)
        else:
            st.warning("The selected model has no layers with neurons to plot.")
    elif 'Layer' in selectedPlotMethod:
        selected_layer_index = st.sidebar.selectbox("Select Layer Index", layer_indices)
        method(selected_layer_index)
    else:
        method()

# Prepare the download payloads once per full run.
st.session_state.kerasModelToDownload = downloadKerasModel()
st.session_state.plotsToDownload = create_zip_of_png_files()
@st.fragment()
def downloads():
    """Render the "Download Model" and "Download Plots" buttons.

    Reuses the payloads the main script run stored in ``st.session_state``
    (``kerasModelToDownload`` / ``plotsToDownload``), so the model is not
    serialised a second time on every rerun. If a payload is missing, it is
    built here instead.
    """
    model_bytes = st.session_state.get("kerasModelToDownload")
    if model_bytes is None:
        model_bytes = downloadKerasModel()
    plots_zip = st.session_state.get("plotsToDownload")
    if plots_zip is None:
        plots_zip = create_zip_of_png_files()

    st.download_button(
        label="Download Model",
        data=model_bytes,
        file_name="model.keras",
        mime="application/octet-stream",
    )
    st.download_button(
        label="Download Plots",
        data=plots_zip,
        file_name="images.zip",
        mime="application/zip",
    )
# Download buttons live in the sidebar.
with st.sidebar:
    downloads()

# Show the PNG file(s) the selected plot method wrote to the working directory.
# glob order is arbitrary, so sort to keep multi-figure plots in a stable order.
image_files = sorted(glob.glob("*.png"))
if image_files:
    st.image(image_files)
else:
    st.info("The selected plot did not produce any images.")