vincentiusyoshuac commited on
Commit
468610e
·
verified ·
1 Parent(s): a1d8a9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -42
app.py CHANGED
@@ -47,10 +47,11 @@ class TimeSeriesForecaster:
47
  """
48
  Create visualization of predictions
49
  """
50
- plt.figure(figsize=(12, 6))
 
51
 
52
  # Plot original series
53
- plt.plot(range(len(self.original_series)), self.original_series, label='Historical Data', color='blue')
54
 
55
  # Calculate forecast statistics
56
  forecast_np = forecasts[0].numpy()
@@ -58,60 +59,127 @@ class TimeSeriesForecaster:
58
 
59
  # Plot forecast
60
  forecast_index = range(len(self.original_series), len(self.original_series) + len(median))
61
- plt.plot(forecast_index, median, color='red', label='Median Forecast')
62
- plt.fill_between(forecast_index, low, high, color='red', alpha=0.3, label='80% Prediction Interval')
63
 
64
- plt.title('Time Series Forecasting with Amazon Chronos')
65
- plt.xlabel('Time Index')
66
- plt.ylabel('Value')
67
- plt.legend()
 
68
 
69
  return plt
70
 
71
  def main():
72
- st.title('🕰️ Time Series Forecasting with Amazon Chronos')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- # Sidebar for upload and configuration
75
- st.sidebar.header('Forecast Settings')
 
 
 
 
 
 
76
 
77
- # Upload CSV file
78
- uploaded_file = st.sidebar.file_uploader(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  "Upload CSV File",
80
  type=['csv'],
81
- help="Ensure CSV file has date and numeric columns"
 
82
  )
83
-
84
- # Column selection and prediction settings
 
85
  if uploaded_file is not None:
86
  # Read CSV
87
  df = pd.read_csv(uploaded_file)
88
 
89
- # Select columns
90
- date_column = st.sidebar.selectbox(
91
- 'Select Date Column',
92
- options=df.columns
93
- )
94
- value_column = st.sidebar.selectbox(
95
- 'Select Value Column',
96
- options=[col for col in df.columns if col != date_column]
97
- )
 
 
 
 
 
 
 
 
98
 
99
  # Prediction parameters
100
- context_length = st.sidebar.slider(
101
- 'Context Length',
102
- min_value=10,
103
- max_value=100,
104
- value=30
105
- )
106
- prediction_length = st.sidebar.slider(
107
- 'Prediction Length',
108
- min_value=1,
109
- max_value=30,
110
- value=7
111
- )
 
 
 
 
 
 
 
112
 
113
- # Process button
114
- if st.sidebar.button('Perform Forecast'):
115
  try:
116
  # Initialize forecaster
117
  forecaster = TimeSeriesForecaster()
@@ -128,12 +196,17 @@ def main():
128
  # Perform forecasting
129
  forecasts = forecaster.forecast(context, prediction_length)
130
 
131
- # Visualize results
 
132
  st.subheader('Forecast Visualization')
133
  plt = forecaster.visualize_forecast(context, forecasts)
134
  st.pyplot(plt)
 
 
 
 
 
135
 
136
- # Display forecast details
137
  forecast_np = forecasts[0].numpy()
138
  forecast_mean = forecast_np.mean(axis=0)
139
  forecast_lower = np.percentile(forecast_np, 10, axis=0)
@@ -145,11 +218,13 @@ def main():
145
  'Upper Bound (90%)': forecast_upper
146
  })
147
 
148
- st.subheader('Forecast Details')
149
  st.dataframe(prediction_df)
 
150
 
151
  except Exception as e:
152
  st.error(f"An error occurred: {str(e)}")
 
 
153
 
154
  if __name__ == '__main__':
155
  main()
 
47
  """
48
  Create visualization of predictions
49
  """
50
+ plt.figure(figsize=(12, 6), facecolor='#f0f2f6')
51
+ plt.style.use('seaborn')
52
 
53
  # Plot original series
54
+ plt.plot(range(len(self.original_series)), self.original_series, label='Historical Data', color='#1E88E5', linewidth=2)
55
 
56
  # Calculate forecast statistics
57
  forecast_np = forecasts[0].numpy()
 
59
 
60
  # Plot forecast
61
  forecast_index = range(len(self.original_series), len(self.original_series) + len(median))
62
+ plt.plot(forecast_index, median, color='#D81B60', linewidth=2, label='Median Forecast')
63
+ plt.fill_between(forecast_index, low, high, color='#D81B60', alpha=0.3, label='80% Prediction Interval')
64
 
65
+ plt.title('Time Series Forecasting', fontsize=16, fontweight='bold')
66
+ plt.xlabel('Time Index', fontsize=12)
67
+ plt.ylabel('Value', fontsize=12)
68
+ plt.legend(frameon=True)
69
+ plt.grid(True, linestyle='--', alpha=0.7)
70
 
71
  return plt
72
 
73
  def main():
74
+ # Set page configuration
75
+ st.set_page_config(
76
+ page_title="Time Series Forecaster",
77
+ page_icon="📈",
78
+ layout="wide",
79
+ initial_sidebar_state="collapsed"
80
+ )
81
+
82
+ # Custom CSS for modern look
83
+ st.markdown("""
84
+ <style>
85
+ /* Modern, clean design */
86
+ .stApp {
87
+ background-color: #f0f2f6;
88
+ font-family: 'Inter', sans-serif;
89
+ }
90
+
91
+ /* Card-like containers */
92
+ .card {
93
+ background-color: white;
94
+ border-radius: 10px;
95
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
96
+ padding: 20px;
97
+ margin-bottom: 20px;
98
+ }
99
 
100
+ /* Stylish file uploader */
101
+ .stFileUploader {
102
+ background-color: #f8f9fa;
103
+ border: 2px dashed #6c757d;
104
+ border-radius: 10px;
105
+ padding: 20px;
106
+ text-align: center;
107
+ }
108
 
109
+ /* Buttons */
110
+ .stButton>button {
111
+ background-color: #1E88E5;
112
+ color: white;
113
+ border-radius: 6px;
114
+ transition: all 0.3s ease;
115
+ }
116
+ .stButton>button:hover {
117
+ background-color: #1565c0;
118
+ transform: scale(1.05);
119
+ }
120
+ </style>
121
+ """, unsafe_allow_html=True)
122
+
123
+ # Title and description
124
+ st.markdown("<h1 style='text-align: center; color: #1E88E5;'>🕰️ Time Series Forecaster</h1>", unsafe_allow_html=True)
125
+ st.markdown("<p style='text-align: center; color: #6c757d;'>Predict future trends with advanced machine learning</p>", unsafe_allow_html=True)
126
+
127
+ # File upload section
128
+ st.markdown("<div class='card'>", unsafe_allow_html=True)
129
+ uploaded_file = st.file_uploader(
130
  "Upload CSV File",
131
  type=['csv'],
132
+ help="Upload a CSV file with time series data",
133
+ label_visibility="collapsed"
134
  )
135
+ st.markdown("</div>", unsafe_allow_html=True)
136
+
137
+ # Data and forecast configuration
138
  if uploaded_file is not None:
139
  # Read CSV
140
  df = pd.read_csv(uploaded_file)
141
 
142
+ # Configuration card
143
+ st.markdown("<div class='card'>", unsafe_allow_html=True)
144
+ col1, col2 = st.columns(2)
145
+
146
+ with col1:
147
+ date_column = st.selectbox(
148
+ 'Select Date Column',
149
+ options=df.columns,
150
+ help="Choose the column representing timestamps"
151
+ )
152
+
153
+ with col2:
154
+ value_column = st.selectbox(
155
+ 'Select Value Column',
156
+ options=[col for col in df.columns if col != date_column],
157
+ help="Choose the numeric column to forecast"
158
+ )
159
 
160
  # Prediction parameters
161
+ col3, col4 = st.columns(2)
162
+
163
+ with col3:
164
+ context_length = st.slider(
165
+ 'Context Length',
166
+ min_value=10,
167
+ max_value=100,
168
+ value=30,
169
+ help="Number of historical data points to use for prediction"
170
+ )
171
+
172
+ with col4:
173
+ prediction_length = st.slider(
174
+ 'Prediction Length',
175
+ min_value=1,
176
+ max_value=30,
177
+ value=7,
178
+ help="Number of future time steps to predict"
179
+ )
180
 
181
+ # Forecast button
182
+ if st.button('Generate Forecast'):
183
  try:
184
  # Initialize forecaster
185
  forecaster = TimeSeriesForecaster()
 
196
  # Perform forecasting
197
  forecasts = forecaster.forecast(context, prediction_length)
198
 
199
+ # Visualization card
200
+ st.markdown("<div class='card'>", unsafe_allow_html=True)
201
  st.subheader('Forecast Visualization')
202
  plt = forecaster.visualize_forecast(context, forecasts)
203
  st.pyplot(plt)
204
+ st.markdown("</div>", unsafe_allow_html=True)
205
+
206
+ # Forecast details card
207
+ st.markdown("<div class='card'>", unsafe_allow_html=True)
208
+ st.subheader('Forecast Details')
209
 
 
210
  forecast_np = forecasts[0].numpy()
211
  forecast_mean = forecast_np.mean(axis=0)
212
  forecast_lower = np.percentile(forecast_np, 10, axis=0)
 
218
  'Upper Bound (90%)': forecast_upper
219
  })
220
 
 
221
  st.dataframe(prediction_df)
222
+ st.markdown("</div>", unsafe_allow_html=True)
223
 
224
  except Exception as e:
225
  st.error(f"An error occurred: {str(e)}")
226
+
227
+ st.markdown("</div>", unsafe_allow_html=True)
228
 
229
  if __name__ == '__main__':
230
  main()