Spaces:
Sleeping
Sleeping
vincentiusyoshuac
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -47,10 +47,11 @@ class TimeSeriesForecaster:
|
|
47 |
"""
|
48 |
Create visualization of predictions
|
49 |
"""
|
50 |
-
plt.figure(figsize=(12, 6))
|
|
|
51 |
|
52 |
# Plot original series
|
53 |
-
plt.plot(range(len(self.original_series)), self.original_series, label='Historical Data', color='
|
54 |
|
55 |
# Calculate forecast statistics
|
56 |
forecast_np = forecasts[0].numpy()
|
@@ -58,60 +59,127 @@ class TimeSeriesForecaster:
|
|
58 |
|
59 |
# Plot forecast
|
60 |
forecast_index = range(len(self.original_series), len(self.original_series) + len(median))
|
61 |
-
plt.plot(forecast_index, median, color='
|
62 |
-
plt.fill_between(forecast_index, low, high, color='
|
63 |
|
64 |
-
plt.title('Time Series Forecasting
|
65 |
-
plt.xlabel('Time Index')
|
66 |
-
plt.ylabel('Value')
|
67 |
-
plt.legend()
|
|
|
68 |
|
69 |
return plt
|
70 |
|
71 |
def main():
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
"Upload CSV File",
|
80 |
type=['csv'],
|
81 |
-
help="
|
|
|
82 |
)
|
83 |
-
|
84 |
-
|
|
|
85 |
if uploaded_file is not None:
|
86 |
# Read CSV
|
87 |
df = pd.read_csv(uploaded_file)
|
88 |
|
89 |
-
#
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
# Prediction parameters
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
-
#
|
114 |
-
if st.
|
115 |
try:
|
116 |
# Initialize forecaster
|
117 |
forecaster = TimeSeriesForecaster()
|
@@ -128,12 +196,17 @@ def main():
|
|
128 |
# Perform forecasting
|
129 |
forecasts = forecaster.forecast(context, prediction_length)
|
130 |
|
131 |
-
#
|
|
|
132 |
st.subheader('Forecast Visualization')
|
133 |
plt = forecaster.visualize_forecast(context, forecasts)
|
134 |
st.pyplot(plt)
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
-
# Display forecast details
|
137 |
forecast_np = forecasts[0].numpy()
|
138 |
forecast_mean = forecast_np.mean(axis=0)
|
139 |
forecast_lower = np.percentile(forecast_np, 10, axis=0)
|
@@ -145,11 +218,13 @@ def main():
|
|
145 |
'Upper Bound (90%)': forecast_upper
|
146 |
})
|
147 |
|
148 |
-
st.subheader('Forecast Details')
|
149 |
st.dataframe(prediction_df)
|
|
|
150 |
|
151 |
except Exception as e:
|
152 |
st.error(f"An error occurred: {str(e)}")
|
|
|
|
|
153 |
|
154 |
if __name__ == '__main__':
|
155 |
main()
|
|
|
47 |
"""
|
48 |
Create visualization of predictions
|
49 |
"""
|
50 |
+
plt.figure(figsize=(12, 6), facecolor='#f0f2f6')
|
51 |
+
plt.style.use('seaborn')
|
52 |
|
53 |
# Plot original series
|
54 |
+
plt.plot(range(len(self.original_series)), self.original_series, label='Historical Data', color='#1E88E5', linewidth=2)
|
55 |
|
56 |
# Calculate forecast statistics
|
57 |
forecast_np = forecasts[0].numpy()
|
|
|
59 |
|
60 |
# Plot forecast
|
61 |
forecast_index = range(len(self.original_series), len(self.original_series) + len(median))
|
62 |
+
plt.plot(forecast_index, median, color='#D81B60', linewidth=2, label='Median Forecast')
|
63 |
+
plt.fill_between(forecast_index, low, high, color='#D81B60', alpha=0.3, label='80% Prediction Interval')
|
64 |
|
65 |
+
plt.title('Time Series Forecasting', fontsize=16, fontweight='bold')
|
66 |
+
plt.xlabel('Time Index', fontsize=12)
|
67 |
+
plt.ylabel('Value', fontsize=12)
|
68 |
+
plt.legend(frameon=True)
|
69 |
+
plt.grid(True, linestyle='--', alpha=0.7)
|
70 |
|
71 |
return plt
|
72 |
|
73 |
def main():
|
74 |
+
# Set page configuration
|
75 |
+
st.set_page_config(
|
76 |
+
page_title="Time Series Forecaster",
|
77 |
+
page_icon="📈",
|
78 |
+
layout="wide",
|
79 |
+
initial_sidebar_state="collapsed"
|
80 |
+
)
|
81 |
+
|
82 |
+
# Custom CSS for modern look
|
83 |
+
st.markdown("""
|
84 |
+
<style>
|
85 |
+
/* Modern, clean design */
|
86 |
+
.stApp {
|
87 |
+
background-color: #f0f2f6;
|
88 |
+
font-family: 'Inter', sans-serif;
|
89 |
+
}
|
90 |
+
|
91 |
+
/* Card-like containers */
|
92 |
+
.card {
|
93 |
+
background-color: white;
|
94 |
+
border-radius: 10px;
|
95 |
+
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
|
96 |
+
padding: 20px;
|
97 |
+
margin-bottom: 20px;
|
98 |
+
}
|
99 |
|
100 |
+
/* Stylish file uploader */
|
101 |
+
.stFileUploader {
|
102 |
+
background-color: #f8f9fa;
|
103 |
+
border: 2px dashed #6c757d;
|
104 |
+
border-radius: 10px;
|
105 |
+
padding: 20px;
|
106 |
+
text-align: center;
|
107 |
+
}
|
108 |
|
109 |
+
/* Buttons */
|
110 |
+
.stButton>button {
|
111 |
+
background-color: #1E88E5;
|
112 |
+
color: white;
|
113 |
+
border-radius: 6px;
|
114 |
+
transition: all 0.3s ease;
|
115 |
+
}
|
116 |
+
.stButton>button:hover {
|
117 |
+
background-color: #1565c0;
|
118 |
+
transform: scale(1.05);
|
119 |
+
}
|
120 |
+
</style>
|
121 |
+
""", unsafe_allow_html=True)
|
122 |
+
|
123 |
+
# Title and description
|
124 |
+
st.markdown("<h1 style='text-align: center; color: #1E88E5;'>🕰️ Time Series Forecaster</h1>", unsafe_allow_html=True)
|
125 |
+
st.markdown("<p style='text-align: center; color: #6c757d;'>Predict future trends with advanced machine learning</p>", unsafe_allow_html=True)
|
126 |
+
|
127 |
+
# File upload section
|
128 |
+
st.markdown("<div class='card'>", unsafe_allow_html=True)
|
129 |
+
uploaded_file = st.file_uploader(
|
130 |
"Upload CSV File",
|
131 |
type=['csv'],
|
132 |
+
help="Upload a CSV file with time series data",
|
133 |
+
label_visibility="collapsed"
|
134 |
)
|
135 |
+
st.markdown("</div>", unsafe_allow_html=True)
|
136 |
+
|
137 |
+
# Data and forecast configuration
|
138 |
if uploaded_file is not None:
|
139 |
# Read CSV
|
140 |
df = pd.read_csv(uploaded_file)
|
141 |
|
142 |
+
# Configuration card
|
143 |
+
st.markdown("<div class='card'>", unsafe_allow_html=True)
|
144 |
+
col1, col2 = st.columns(2)
|
145 |
+
|
146 |
+
with col1:
|
147 |
+
date_column = st.selectbox(
|
148 |
+
'Select Date Column',
|
149 |
+
options=df.columns,
|
150 |
+
help="Choose the column representing timestamps"
|
151 |
+
)
|
152 |
+
|
153 |
+
with col2:
|
154 |
+
value_column = st.selectbox(
|
155 |
+
'Select Value Column',
|
156 |
+
options=[col for col in df.columns if col != date_column],
|
157 |
+
help="Choose the numeric column to forecast"
|
158 |
+
)
|
159 |
|
160 |
# Prediction parameters
|
161 |
+
col3, col4 = st.columns(2)
|
162 |
+
|
163 |
+
with col3:
|
164 |
+
context_length = st.slider(
|
165 |
+
'Context Length',
|
166 |
+
min_value=10,
|
167 |
+
max_value=100,
|
168 |
+
value=30,
|
169 |
+
help="Number of historical data points to use for prediction"
|
170 |
+
)
|
171 |
+
|
172 |
+
with col4:
|
173 |
+
prediction_length = st.slider(
|
174 |
+
'Prediction Length',
|
175 |
+
min_value=1,
|
176 |
+
max_value=30,
|
177 |
+
value=7,
|
178 |
+
help="Number of future time steps to predict"
|
179 |
+
)
|
180 |
|
181 |
+
# Forecast button
|
182 |
+
if st.button('Generate Forecast'):
|
183 |
try:
|
184 |
# Initialize forecaster
|
185 |
forecaster = TimeSeriesForecaster()
|
|
|
196 |
# Perform forecasting
|
197 |
forecasts = forecaster.forecast(context, prediction_length)
|
198 |
|
199 |
+
# Visualization card
|
200 |
+
st.markdown("<div class='card'>", unsafe_allow_html=True)
|
201 |
st.subheader('Forecast Visualization')
|
202 |
plt = forecaster.visualize_forecast(context, forecasts)
|
203 |
st.pyplot(plt)
|
204 |
+
st.markdown("</div>", unsafe_allow_html=True)
|
205 |
+
|
206 |
+
# Forecast details card
|
207 |
+
st.markdown("<div class='card'>", unsafe_allow_html=True)
|
208 |
+
st.subheader('Forecast Details')
|
209 |
|
|
|
210 |
forecast_np = forecasts[0].numpy()
|
211 |
forecast_mean = forecast_np.mean(axis=0)
|
212 |
forecast_lower = np.percentile(forecast_np, 10, axis=0)
|
|
|
218 |
'Upper Bound (90%)': forecast_upper
|
219 |
})
|
220 |
|
|
|
221 |
st.dataframe(prediction_df)
|
222 |
+
st.markdown("</div>", unsafe_allow_html=True)
|
223 |
|
224 |
except Exception as e:
|
225 |
st.error(f"An error occurred: {str(e)}")
|
226 |
+
|
227 |
+
st.markdown("</div>", unsafe_allow_html=True)
|
228 |
|
229 |
if __name__ == '__main__':
|
230 |
main()
|