# Spaces: Sleeping
import streamlit as st | |
import numpy as np | |
from PIL import Image | |
from tensorflow.keras.models import load_model | |
# Requires the streamlit-drawable-canvas package; install it if needed:
#   pip install streamlit-drawable-canvas
# Import the st_canvas function | |
from streamlit_drawable_canvas import st_canvas | |
def preprocess_image(drawing, size=(28, 28)):
    """Turn raw canvas pixels into a normalized single-channel array.

    Args:
        drawing: Array-like pixel data; it is cast to ``uint8`` before being
            handed to PIL. (Presumably the RGBA ``image_data`` of the drawing
            canvas -- confirm against the caller.)
        size: Target size passed to ``Image.resize``.

    Returns:
        A float array with intensities scaled from 0..255 down to 0..1 and a
        trailing channel axis of length 1.
    """
    pixels = np.uint8(drawing)
    # Resize first, then collapse to 8-bit grayscale ('L' mode).
    grayscale = Image.fromarray(pixels).resize(size).convert('L')
    # Scale the 0..255 intensities into the 0..1 range.
    scaled = np.array(grayscale) / 255.0
    # Append the single grayscale channel as the last axis.
    return scaled[..., np.newaxis]
def _load_digit_model(path="mnist_cnn_model.h5"):
    """Load the Keras model stored at *path*."""
    return load_model(path)


def _get_digit_model():
    """Return the digit model, reusing a cached instance when possible.

    Streamlit re-executes the script on every interaction, so an uncached
    ``load_model`` call would read the model file again for each prediction.
    ``st.cache_resource`` keeps the loaded model alive across reruns.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is None:
        # Streamlit release without the resource cache: load on every call.
        return _load_digit_model()
    # Re-wrapping on each call is fine: the cache entry is looked up by the
    # wrapped function, not by the wrapper object.
    return cache_resource(_load_digit_model)()


def preprocess_and_predict(image):
    """Classify a single preprocessed digit image.

    Args:
        image: Array-like holding exactly 28 * 28 values, for example with
            shape ``(28, 28)`` or ``(28, 28, 1)``.

    Returns:
        The index of the largest value in the model's prediction, as returned
        by ``np.argmax``.

    Raises:
        ValueError: If *image* cannot be reshaped to ``(1, 28, 28, 1)``.
    """
    model = _get_digit_model()
    # One reshape produces the (1, 28, 28, 1) input; no separate
    # expand_dims step is needed.
    batch = np.reshape(image, (1, 28, 28, 1))
    prediction = model.predict(batch)
    return np.argmax(prediction)
def main():
    """Render the digit-drawing page and show a prediction on request."""
    st.title('Draw Digit')

    # White strokes on a black canvas. The canvas is 168 x 168 pixels, i.e.
    # 6 x the 28 x 28 size the drawing is later resized to.
    drawing = st_canvas(
        fill_color="rgb(0, 0, 0)",  # Fill color for drawn shapes
        stroke_width=4,
        stroke_color="rgb(255, 255, 255)",
        background_color="#000000",
        height=168,
        width=168,
        drawing_mode="freedraw",
        key="canvas",
    )

    # Only run the model on the rerun triggered by the button click.
    if not st.button('Predict digit'):
        return

    # The canvas may not have reported any pixel data yet; preprocessing
    # would fail on None, so ask the user to draw first.
    if drawing.image_data is None:
        st.warning("Draw a digit on the canvas before predicting.")
        return

    processed_image = preprocess_image(drawing.image_data)
    digit_class = preprocess_and_predict(processed_image)
    st.title("Predicted Digit:")
    st.success(digit_class)


if __name__ == "__main__":
    main()