V8055 commited on
Commit
a6a2244
·
verified ·
1 Parent(s): 0106f84

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -0
app.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ from sklearn.datasets import load_iris
6
+ from sklearn.model_selection import train_test_split
7
+ from sklearn.preprocessing import StandardScaler
8
+ from sklearn.linear_model import LogisticRegression
9
+ from sklearn.tree import DecisionTreeClassifier
10
+ from sklearn.ensemble import RandomForestClassifier
11
+ from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
12
+ import plotly.express as px
13
+ import plotly.graph_objects as go
14
+ import seaborn as sns
15
+ import matplotlib.pyplot as plt
16
+
17
+ class IrisPredictor:
18
+ def __init__(self):
19
+ self.iris = load_iris()
20
+ self.df = pd.DataFrame(data=np.c_[self.iris['data'], self.iris['target']],
21
+ columns=self.iris['feature_names'] + ['target'])
22
+ self.models = {
23
+ 'Logistic Regression': LogisticRegression(),
24
+ 'Decision Tree': DecisionTreeClassifier(),
25
+ 'Random Forest': RandomForestClassifier()
26
+ }
27
+ self.X = self.df.drop('target', axis=1)
28
+ self.y = self.df['target']
29
+ self.scaler = StandardScaler()
30
+
31
+ def preprocess_data(self):
32
+ # Split the data
33
+ X_train, X_test, y_train, y_test = train_test_split(
34
+ self.X, self.y, test_size=0.2, random_state=42
35
+ )
36
+
37
+ # Scale the features
38
+ X_train_scaled = self.scaler.fit_transform(X_train)
39
+ X_test_scaled = self.scaler.transform(X_test)
40
+
41
+ return X_train_scaled, X_test_scaled, y_train, y_test
42
+
43
+ def train_model(self, model_name):
44
+ X_train_scaled, X_test_scaled, y_train, y_test = self.preprocess_data()
45
+
46
+ # Train model
47
+ model = self.models[model_name]
48
+ model.fit(X_train_scaled, y_train)
49
+
50
+ # Make predictions
51
+ y_pred = model.predict(X_test_scaled)
52
+
53
+ # Calculate metrics
54
+ accuracy = accuracy_score(y_test, y_pred)
55
+ conf_matrix = confusion_matrix(y_test, y_pred)
56
+ class_report = classification_report(y_test, y_pred)
57
+
58
+ return {
59
+ 'model': model,
60
+ 'accuracy': accuracy,
61
+ 'confusion_matrix': conf_matrix,
62
+ 'classification_report': class_report,
63
+ 'X_test': X_test_scaled,
64
+ 'y_test': y_test,
65
+ 'y_pred': y_pred
66
+ }
67
+
68
+ def plot_confusion_matrix(self, conf_matrix):
69
+ fig = px.imshow(conf_matrix,
70
+ labels=dict(x="Predicted", y="Actual"),
71
+ x=['Setosa', 'Versicolor', 'Virginica'],
72
+ y=['Setosa', 'Versicolor', 'Virginica'],
73
+ title="Confusion Matrix")
74
+ return fig
75
+
76
+ def plot_feature_importance(self, model_name, model):
77
+ if model_name == 'Logistic Regression':
78
+ importance = abs(model.coef_[0])
79
+ else:
80
+ importance = model.feature_importances_
81
+
82
+ fig = px.bar(x=self.X.columns, y=importance,
83
+ title=f"Feature Importance - {model_name}",
84
+ labels={'x': 'Features', 'y': 'Importance'})
85
+ return fig
86
+
87
+ def predict_single_sample(self, model, features):
88
+ # Scale features
89
+ scaled_features = self.scaler.transform([features])
90
+ # Make prediction
91
+ prediction = model.predict(scaled_features)
92
+ probabilities = model.predict_proba(scaled_features)
93
+ return prediction[0], probabilities[0]
94
+
95
+ def main():
96
+ st.title("🌸 Iris Flower Prediction App")
97
+ st.write("""
98
+ This app predicts the Iris flower type based on its features.
99
+ Choose a model and see how it performs!
100
+ """)
101
+
102
+ # Initialize predictor
103
+ predictor = IrisPredictor()
104
+
105
+ # Model selection
106
+ st.sidebar.header("Model Selection")
107
+ model_name = st.sidebar.selectbox(
108
+ "Choose a model",
109
+ list(predictor.models.keys())
110
+ )
111
+
112
+ # Train model and show results
113
+ if st.sidebar.button("Train Model"):
114
+ with st.spinner("Training model..."):
115
+ results = predictor.train_model(model_name)
116
+
117
+ # Display metrics
118
+ st.header("Model Performance")
119
+ st.metric("Accuracy", f"{results['accuracy']:.2%}")
120
+
121
+ # Display confusion matrix
122
+ st.subheader("Confusion Matrix")
123
+ conf_matrix_fig = predictor.plot_confusion_matrix(results['confusion_matrix'])
124
+ st.plotly_chart(conf_matrix_fig)
125
+
126
+ # Display feature importance
127
+ st.subheader("Feature Importance")
128
+ importance_fig = predictor.plot_feature_importance(model_name, results['model'])
129
+ st.plotly_chart(importance_fig)
130
+
131
+ # Display classification report
132
+ st.subheader("Classification Report")
133
+ st.text(results['classification_report'])
134
+
135
+ # Store the trained model in session state
136
+ st.session_state['trained_model'] = results['model']
137
+
138
+ # Make predictions
139
+ st.header("Make Predictions")
140
+ col1, col2 = st.columns(2)
141
+
142
+ with col1:
143
+ sepal_length = st.slider("Sepal Length", 4.0, 8.0, 5.4)
144
+ sepal_width = st.slider("Sepal Width", 2.0, 4.5, 3.4)
145
+
146
+ with col2:
147
+ petal_length = st.slider("Petal Length", 1.0, 7.0, 4.7)
148
+ petal_width = st.slider("Petal Width", 0.1, 2.5, 1.4)
149
+
150
+ if st.button("Predict"):
151
+ if 'trained_model' in st.session_state:
152
+ features = [sepal_length, sepal_width, petal_length, petal_width]
153
+ prediction, probabilities = predictor.predict_single_sample(
154
+ st.session_state['trained_model'], features
155
+ )
156
+
157
+ # Display prediction
158
+ iris_types = ['Setosa', 'Versicolor', 'Virginica']
159
+ st.success(f"Predicted Iris Type: {iris_types[int(prediction)]}")
160
+
161
+ # Display probability distribution
162
+ st.subheader("Prediction Probabilities")
163
+ prob_fig = px.bar(x=iris_types, y=probabilities,
164
+ title="Prediction Probability Distribution",
165
+ labels={'x': 'Iris Type', 'y': 'Probability'})
166
+ st.plotly_chart(prob_fig)
167
+ else:
168
+ st.warning("Please train a model first!")
169
+
170
+ if __name__ == "__main__":
171
+ main()