Spaces:
Sleeping
Sleeping
import streamlit as st | |
from nn import NeuralNetwork | |
import json | |
from utils import sigmoid, sigmoid_prime | |
# XOR truth table: row i of INPUTS is the operand pair whose expected
# result is row i of OUTPUTS.
INPUTS = [
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1],
]
OUTPUTS = [
    [0],
    [1],
    [1],
    [0],
]
def resetSession():
    """Clear the stored model and zero its training counter in the Streamlit session."""
    session = st.session_state
    session.nn = None
    session.train_count = 0
## Controller Function
def runNN():
    """Evaluate the session's model on every XOR input and render a results table.

    For each row of ``INPUTS`` the model's ``predict`` is called with the two
    operands and ``activation=sigmoid``. The input label, expected value, raw
    prediction, rounded prediction and a correct/incorrect verdict are collected
    column-wise and shown once with ``st.dataframe``.

    If no model is stored in ``st.session_state``, a warning is shown instead.
    """
    nn = st.session_state.get("nn")
    if nn is None:
        st.warning("Create a model before running it.")
        return

    table = {
        "input": [],
        "expected": [],
        "predicted": [],
        "rounded": [],
        "correct": [],
    }
    # Pair each input row with its expected row rather than relying on a
    # hard-coded row count that silently breaks if the data set changes.
    for inp, out in zip(INPUTS, OUTPUTS):
        x1, x2 = inp[0], inp[1]
        expected = out[0]
        result = nn.predict(x1, x2, activation=sigmoid)
        rounded = round(result)  # computed once; used for display and scoring
        table["input"].append(f"{x1} xor {x2}")
        table["expected"].append(expected)
        table["predicted"].append(result)
        table["rounded"].append(rounded)
        table["correct"].append('correct' if rounded == expected else 'incorrect')
    st.dataframe(table)
def sidebar():
    """Render the sidebar controls and dispatch button clicks to their handlers.

    The epoch and learning-rate sliders and the New/Reset model buttons are
    always shown. The Train, Run and Save controls appear only while a model is
    stored in ``st.session_state.nn``. That check runs after the New/Reset
    buttons are handled, so it reflects a model created or cleared in this same
    run.
    """
    # Neural network controls
    st.sidebar.header('Neural Network Controls')
    st.sidebar.text('Number of epochs')
    epochs = st.sidebar.slider('Epochs', 1, 10000, 500)
    st.sidebar.text('Learning rate')
    # NOTE(review): the integer slider value is forwarded unchanged as ``alpha``;
    # confirm NeuralNetwork.train expects this scale.
    alphas = st.sidebar.slider('Alphas', 1, 100, 20)
    col1, col2 = st.sidebar.columns(2)
    if col1.button('New Model'):
        btnNewModel()
    if col2.button('Reset Model'):
        # Go through the dedicated handler so the "Model reset" confirmation shows.
        btnResetModel()
    if st.session_state.get("nn") is not None:
        if st.sidebar.button('Train Model'):
            btnTrainModel(epochs, alphas)
        if st.sidebar.button('Run Neural Network'):
            btnRunModel()
        st.sidebar.download_button(
            label="Save Model",
            data=json.dumps(st.session_state.nn.getModelJson()),
            file_name="model.json",
            mime="application/json",
        )
def btnNewModel():
    """Start over: clear the session, store a freshly built NeuralNetwork, and confirm."""
    resetSession()
    fresh_model = NeuralNetwork()
    st.session_state.nn = fresh_model
    st.sidebar.text("New model created")
def btnTrainModel(epochs, alphas):
    """Train the stored model on the XOR data and report how many times it was trained.

    ``epochs`` and ``alphas`` are passed to ``NeuralNetwork.train`` as its
    ``epochs`` and ``alpha`` keyword arguments.
    """
    session = st.session_state
    session.nn.train(inputs=INPUTS, outputs=OUTPUTS, epochs=epochs, alpha=alphas)
    session.train_count = session.train_count + 1
    st.sidebar.text(f"Model trained {session.train_count} times")
def btnRunModel():
    """Handle the "Run Neural Network" button by delegating to ``runNN``."""
    return runNN()
def btnResetModel():
    """Discard the current model and show a confirmation line in the sidebar."""
    resetSession()
    confirmation = "Model reset"
    st.sidebar.text(confirmation)
def app():
    """Build the main page: title, intro text, network diagram, sidebar and references."""
    from pathlib import Path

    st.title('Simple Neural Network App')
    st.write('I followed a tutorial in the reference and changed it to apply good programming practices.')
    st.write('This is the Neural Network image we are trying to implement!')
    # Resolve the diagram relative to this script rather than the current working
    # directory, so the image loads no matter where the app is launched from.
    st.image(str(Path(__file__).resolve().parent / 'nn.png'), width=550)
    sidebar()
    st.markdown(
        "### References\n"
        "* https://www.codingame.com/playgrounds/59631/neural-network-xor-example-from-scratch-no-libs\n"
    )
# Build the page only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    app()