CNNClassifier / app.py
HarnithaS's picture
added prediction code
3ea7ff1
import streamlit as st
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model
# Install the streamlit_drawable_canvas package if you haven't already
# !pip install streamlit_drawable_canvas
# Import the st_canvas function
from streamlit_drawable_canvas import st_canvas
def preprocess_image(drawing, size=(28, 28)):
    """Turn raw canvas pixels into a normalized single-channel image array.

    Args:
        drawing: Array-like of pixel values; it is cast with ``np.uint8`` so
            values are expected to already lie in 0-255 (presumably the RGBA
            ``image_data`` from ``st_canvas`` - confirm for other callers).
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        Float array of shape ``(size[1], size[0], 1)`` with values in [0, 1].
    """
    # Resize first, then collapse to 8-bit grayscale ("L" mode).
    grayscale = Image.fromarray(np.uint8(drawing)).resize(size).convert('L')
    scaled = np.array(grayscale) / 255.0
    # Append a trailing channel axis: (h, w) -> (h, w, 1).
    return scaled[..., np.newaxis]
@st.cache_resource
def _load_digit_model(model_path):
    """Load the Keras model at *model_path* once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction, so an
    ordinary module-level cache would be rebuilt each time; ``st.cache_resource``
    keeps the loaded model alive between reruns instead of re-reading the file.
    """
    return load_model(model_path)


def preprocess_and_predict(image, model_path="mnist_cnn_model.h5"):
    """Classify one preprocessed digit image with the saved CNN model.

    Args:
        image: Array with exactly 28 * 28 = 784 elements, e.g. the
            ``(28, 28, 1)`` output of ``preprocess_image``; it is reshaped to
            a batch of one, ``(1, 28, 28, 1)``. Other sizes raise ``ValueError``.
        model_path: Path to the saved Keras model; defaults to the file the
            app has always used.

    Returns:
        int: Index of the highest-scoring model output (the predicted class).
    """
    model = _load_digit_model(model_path)
    # reshape alone already adds the batch and channel axes, whether the input
    # is (28, 28, 1), (28, 28) or flat - a separate expand_dims is redundant.
    batch = np.reshape(image, (1, 28, 28, 1))
    # verbose=0 keeps Keras from printing a progress bar for every click.
    prediction = model.predict(batch, verbose=0)
    # Plain int rather than a NumPy scalar, so callers can display it directly.
    return int(np.argmax(prediction))
def main():
    """Render the digit-drawing canvas and show the model's prediction on demand."""
    st.title('Draw Digit')
    # Create a drawing canvas
    drawing = st_canvas(
        fill_color="rgb(0, 0, 0)",  # Fill color for shapes (e.g. rect); freedraw strokes use stroke_color
        # NOTE(review): preprocess_image shrinks this 168 px canvas 6x to 28x28, so a
        # 4 px stroke ends up under 1 px wide - consider a thicker stroke if
        # predictions are poor.
        stroke_width=4,  # Stroke width in canvas pixels
        stroke_color="rgb(255, 255, 255)",  # White strokes on the black background
        background_color="#000000",  # Background color of the canvas component
        height=168,  # Height of the canvas
        width=168,  # Width of the canvas
        drawing_mode="freedraw",  # Freehand drawing (other modes include "line", "rect", "transform")
        key="canvas",
    )

    predict_clicked = st.button('Predict digit')
    if not predict_clicked:
        return

    image_data = drawing.image_data
    # image_data may be None (e.g. before the component has sent any pixels).
    # With the black background and white strokes, an untouched canvas has
    # all-zero colour channels (assumes RGBA image_data), so there is nothing
    # meaningful to classify.
    if image_data is None or not np.asarray(image_data)[..., :3].any():
        st.warning("Please draw a digit before predicting.")
        return

    # Preprocess the drawn image and run the classifier
    processed_image = preprocess_image(image_data)
    digit_class = preprocess_and_predict(processed_image)
    st.title("Predicted Digit:")
    st.success(str(digit_class))
# Entry point: launch the app when this file is executed as the main script
# (which is also how `streamlit run` executes it); importing it has no side effects.
if __name__ == "__main__":
    main()