Spaces:
Sleeping
Sleeping
intial commit
Browse files- app.py +72 -0
- cnn_model_train.py +64 -0
- mnist_cnn_model.h5 +3 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from tensorflow.keras.models import load_model
|
5 |
+
|
6 |
+
|
7 |
+
# Install the streamlit_drawable_canvas package if you haven't already
|
8 |
+
# !pip install streamlit_drawable_canvas
|
9 |
+
|
10 |
+
# Import the st_canvas function
|
11 |
+
from streamlit_drawable_canvas import st_canvas
|
12 |
+
|
13 |
+
# Function to preprocess the drawn image
|
14 |
+
def preprocess_image(drawing, size=(28, 28)):
|
15 |
+
# Convert the drawing to a PIL Image
|
16 |
+
img = Image.fromarray(np.uint8(drawing))
|
17 |
+
# Resize the image to the desired size
|
18 |
+
img = img.resize(size)
|
19 |
+
# Convert the image to grayscale
|
20 |
+
img = img.convert('L')
|
21 |
+
# Convert the image to a numpy array
|
22 |
+
img_array = np.array(img)
|
23 |
+
# Normalize the pixel values to be between 0 and 1
|
24 |
+
img_array = img_array / 255.0
|
25 |
+
# Add a channel dimension (1 channel for grayscale)
|
26 |
+
img_array = np.expand_dims(img_array, axis=-1)
|
27 |
+
return img_array
|
28 |
+
|
29 |
+
def preprocess_and_predict(img):
|
30 |
+
model = load_model("mnist_cnn_model.h5")
|
31 |
+
image = np.array(img) / 255.0
|
32 |
+
# Expand dimensions to match the input shape expected by the model
|
33 |
+
image = np.expand_dims(image, axis=0)
|
34 |
+
# Reshape to match the input shape expected by the model
|
35 |
+
image = np.reshape(image, (1, 28, 28, 1))
|
36 |
+
prediction = model.predict(image)
|
37 |
+
predicted_class = np.argmax(prediction)
|
38 |
+
|
39 |
+
return predicted_class
|
40 |
+
|
41 |
+
# Main code
|
42 |
+
def main():
|
43 |
+
st.title('Draw Digit')
|
44 |
+
|
45 |
+
# Create a drawing canvas
|
46 |
+
drawing = st_canvas(
|
47 |
+
fill_color="rgb(0, 0, 0)", # Background color of the canvas
|
48 |
+
stroke_width=10, # Stroke width
|
49 |
+
stroke_color="rgb(255, 255, 255)", # Stroke color
|
50 |
+
background_color="#000000", # Background color of the canvas component
|
51 |
+
height=150, # Height of the canvas
|
52 |
+
width=150, # Width of the canvas
|
53 |
+
drawing_mode="freedraw", # Drawing mode: "freedraw" or "transform"
|
54 |
+
key="canvas",
|
55 |
+
)
|
56 |
+
|
57 |
+
# Check if the user has drawn anything
|
58 |
+
if drawing is not None:
|
59 |
+
st.image(drawing.image_data)
|
60 |
+
|
61 |
+
# Preprocess the drawn image
|
62 |
+
processed_image = preprocess_image(drawing.image_data)
|
63 |
+
st.write("Processed Image Shape:", processed_image.shape)
|
64 |
+
|
65 |
+
# Save the processed image
|
66 |
+
np.save("processed_image.npy", processed_image)
|
67 |
+
print(processed_image.shape)
|
68 |
+
digit_class = preprocess_and_predict(processed_image)
|
69 |
+
st.success(digit_class)
|
70 |
+
|
71 |
+
if __name__ == "__main__":
|
72 |
+
main()
|
cnn_model_train.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Importing necessary libraries
|
2 |
+
import tensorflow as tf
|
3 |
+
from tensorflow.keras import layers, models, datasets
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
# Load the MNIST dataset
|
7 |
+
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
|
8 |
+
|
9 |
+
# Normalize pixel values to be between 0 and 1
|
10 |
+
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
|
11 |
+
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
|
12 |
+
|
13 |
+
# Convert labels to categorical one-hot encoding
|
14 |
+
train_labels = tf.keras.utils.to_categorical(train_labels, 10)
|
15 |
+
test_labels = tf.keras.utils.to_categorical(test_labels, 10)
|
16 |
+
|
17 |
+
# Define the CNN model
|
18 |
+
def create_cnn_model(input_shape, num_classes):
|
19 |
+
model = models.Sequential()
|
20 |
+
|
21 |
+
# Convolutional layers
|
22 |
+
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape))
|
23 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
24 |
+
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
|
25 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
26 |
+
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
|
27 |
+
|
28 |
+
# Flatten layer to transition from convolutional layers to fully connected layers
|
29 |
+
model.add(layers.Flatten())
|
30 |
+
|
31 |
+
# Dense (fully connected) layers
|
32 |
+
model.add(layers.Dense(64, activation='relu'))
|
33 |
+
model.add(layers.Dense(num_classes, activation='softmax')) # Output layer with softmax activation for multiclass classification
|
34 |
+
|
35 |
+
return model
|
36 |
+
|
37 |
+
# Define input shape and number of classes
|
38 |
+
input_shape = (28, 28, 1) # Input shape for MNIST images
|
39 |
+
num_classes = 10 # Number of classes for digit classification (0-9)
|
40 |
+
|
41 |
+
# Create an instance of the model
|
42 |
+
model = create_cnn_model(input_shape, num_classes)
|
43 |
+
|
44 |
+
# Print model summary
|
45 |
+
model.summary()
|
46 |
+
|
47 |
+
# Compile the model
|
48 |
+
model.compile(optimizer='adam',
|
49 |
+
loss='categorical_crossentropy',
|
50 |
+
metrics=['accuracy'])
|
51 |
+
|
52 |
+
# Train the model
|
53 |
+
model.fit(train_images, train_labels, epochs=5, batch_size=64, validation_data=(test_images, test_labels))
|
54 |
+
# Save the trained model to disk
|
55 |
+
model.save("mnist_cnn_model.h5")
|
56 |
+
print("Model saved to disk.")
|
57 |
+
|
58 |
+
# Load the saved model
|
59 |
+
loaded_model = models.load_model("mnist_cnn_model.h5")
|
60 |
+
print("Model loaded from disk.")
|
61 |
+
|
62 |
+
# Evaluate the loaded model
|
63 |
+
test_loss, test_accuracy = loaded_model.evaluate(test_images, test_labels)
|
64 |
+
print(f"Test Accuracy: {test_accuracy}")
|
mnist_cnn_model.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6bd33517798011fbd1717bd9d757cfd2be2b1f286ce9a886d2b6b6a38c9ae865
|
3 |
+
size 1172032
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tensorflow
|
2 |
+
numpy
|
3 |
+
streamlit
|
4 |
+
streamlit-drawable-canvas
|
5 |
+
pillow
|