TOPSInfosol commited on
Commit
dbae9bc
·
verified ·
1 Parent(s): 1f0dbf4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -20
app.py CHANGED
@@ -2,6 +2,7 @@ import pandas as pd
2
  import gradio as gr
3
  from pathlib import Path
4
  import plotly.express as px
 
5
  import numpy as np
6
  import torch
7
  from chronos import ChronosPipeline
@@ -149,26 +150,86 @@ def forecast_chronos_data(df_state, date_column, target_column, select_period, f
149
  forecast_index = range(len(monthly_sales), len(monthly_sales) + prediction_length)
150
  low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
151
 
152
- plt.figure(figsize=(30, 10))
153
- plt.plot(monthly_sales["y"], color="royalblue", label="Historical Data", linewidth=2)
154
- plt.plot(forecast_index, median, color="tomato", label="Median Forecast", linewidth=2)
155
- plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
156
- plt.title("Sales Forecasting Visualization", fontsize=16)
157
- plt.xlabel("Months", fontsize=20)
158
- plt.ylabel("Sold Qty", fontsize=20)
159
-
160
- plt.xticks(fontsize=18)
161
- plt.yticks(fontsize=18)
162
-
163
- ax = plt.gca()
164
- ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
165
- ax.yaxis.set_major_locator(ticker.MultipleLocator(5))
166
- ax.grid(which='major', linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
167
-
168
- plt.legend(fontsize=18)
169
- plt.grid(linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
170
- plt.tight_layout()
171
- return plt.gcf()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
 
174
  def home_page():
 
2
  import gradio as gr
3
  from pathlib import Path
4
  import plotly.express as px
5
+ import plotly.graph_objects as go
6
  import numpy as np
7
  import torch
8
  from chronos import ChronosPipeline
 
150
  forecast_index = range(len(monthly_sales), len(monthly_sales) + prediction_length)
151
  low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
152
 
153
+ forecast_index = list(forecast_index)
154
+ fig = px.line(
155
+ x=monthly_sales.index,
156
+ y=monthly_sales["y"],
157
+ title="Sales Forecasting Visualization",
158
+ labels={"x": "Months", "y": f"{target_column}"},
159
+ )
160
+
161
+ fig.add_trace(
162
+ go.Scatter(
163
+ x=forecast_index,
164
+ y=median,
165
+ name="Median Forecast",
166
+ line=dict(color="tomato", width=2)
167
+ )
168
+ )
169
+
170
+ fig.add_trace(
171
+ go.Scatter(
172
+ x=forecast_index,
173
+ y=high,
174
+ name="80% Prediction Interval",
175
+ mode='lines',
176
+ line=dict(width=0),
177
+ showlegend=False
178
+ )
179
+ )
180
+
181
+ fig.update_layout(
182
+ title_font_size=20,
183
+ xaxis_title_font_size=16,
184
+ yaxis_title_font_size=16,
185
+ legend_font_size=16,
186
+ xaxis_tickfont_size=14,
187
+ yaxis_tickfont_size=14,
188
+ showlegend=True,
189
+ width=1200, # Equivalent to figsize=(30, 10)
190
+ height=400,
191
+ xaxis=dict(
192
+ gridcolor='rgba(128, 128, 128, 0.7)',
193
+ gridwidth=1.2,
194
+ dtick=3, # Set tick interval to 3 months
195
+ griddash='dash'
196
+ ),
197
+ yaxis=dict(
198
+ gridcolor='rgba(128, 128, 128, 0.7)',
199
+ gridwidth=1.2,
200
+ dtick=5, # Set tick interval to 5 units
201
+ griddash='dash'
202
+ ),
203
+ plot_bgcolor='white'
204
+ margin=dict(l=50, r=50, t=50, b=50)
205
+ )
206
+
207
+ fig.update_traces(
208
+ line=dict(color="royalblue", width=2),
209
+ selector=dict(name="y") # Updates only the historical data line
210
+ )
211
+ return fig
212
+
213
+ # plt.figure(figsize=(30, 10))
214
+ # plt.plot(monthly_sales["y"], color="royalblue", label="Historical Data", linewidth=2)
215
+ # plt.plot(forecast_index, median, color="tomato", label="Median Forecast", linewidth=2)
216
+ # plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
217
+ # plt.title("Sales Forecasting Visualization", fontsize=16)
218
+ # plt.xlabel("Months", fontsize=20)
219
+ # plt.ylabel("Sold Qty", fontsize=20)
220
+
221
+ # plt.xticks(fontsize=18)
222
+ # plt.yticks(fontsize=18)
223
+
224
+ # ax = plt.gca()
225
+ # ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
226
+ # ax.yaxis.set_major_locator(ticker.MultipleLocator(5))
227
+ # ax.grid(which='major', linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
228
+
229
+ # plt.legend(fontsize=18)
230
+ # plt.grid(linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
231
+ # plt.tight_layout()
232
+ # return plt.gcf()
233
 
234
 
235
  def home_page():