TOPSInfosol
commited on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import gradio as gr
|
3 |
+
from pathlib import Path
|
4 |
+
import plotly.express as px
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from chronos import ChronosPipeline
|
8 |
+
from datetime import datetime
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import matplotlib.ticker as ticker
|
11 |
+
|
12 |
+
|
13 |
+
def filter_data(start, end, df_state, select_product_column, date_column, target_column):
|
14 |
+
|
15 |
+
if not date_column:
|
16 |
+
raise gr.Error("Please select a Date column")
|
17 |
+
|
18 |
+
if not target_column:
|
19 |
+
raise gr.Error("Please select a target column")
|
20 |
+
|
21 |
+
start_datetime = pd.to_datetime(datetime.utcfromtimestamp(start))
|
22 |
+
end_datetime = pd.to_datetime(datetime.utcfromtimestamp(end))
|
23 |
+
|
24 |
+
original_date_column = None
|
25 |
+
original_target_column = None
|
26 |
+
|
27 |
+
column_mapping = {
|
28 |
+
' '.join([word.capitalize() for word in col.split('_')]): col
|
29 |
+
for col in df_state.columns
|
30 |
+
}
|
31 |
+
if date_column in column_mapping:
|
32 |
+
original_date_column = column_mapping[date_column]
|
33 |
+
|
34 |
+
if target_column in column_mapping:
|
35 |
+
original_target_column = column_mapping[target_column]
|
36 |
+
|
37 |
+
df_state[original_date_column] = pd.to_datetime(df_state[original_date_column])
|
38 |
+
filtered_df = df_state[(df_state[original_date_column] >= start_datetime) & (df_state[original_date_column] <= end_datetime)]
|
39 |
+
|
40 |
+
|
41 |
+
fig = px.line(filtered_df, x=original_date_column, y=original_target_column, title="Historical Sales Data")
|
42 |
+
return [filtered_df, fig]
|
43 |
+
|
44 |
+
|
45 |
+
def upload_file(filepath):
|
46 |
+
name = Path(filepath).name
|
47 |
+
df = pd.read_csv(filepath.name)
|
48 |
+
datetime_columns = []
|
49 |
+
numeric_columns = []
|
50 |
+
|
51 |
+
for col in df.columns:
|
52 |
+
try:
|
53 |
+
if all(isinstance(float(x), float) for x in df[col].head(3)):
|
54 |
+
numeric_columns.append(col)
|
55 |
+
except ValueError:
|
56 |
+
continue
|
57 |
+
|
58 |
+
for col in df.columns:
|
59 |
+
if df[col].dtype == 'object':
|
60 |
+
try:
|
61 |
+
df[col] = pd.to_datetime(df[col])
|
62 |
+
except:
|
63 |
+
pass
|
64 |
+
if df[col].dtype == 'datetime64[ns]':
|
65 |
+
datetime_columns.append(col)
|
66 |
+
datetime_columns = list(map(lambda x: ' '.join([word.capitalize() for word in x.split('_')]), datetime_columns))
|
67 |
+
columns = df.columns.tolist()
|
68 |
+
transformed_columns = list(map(lambda x: ' '.join([word.capitalize() for word in x.split('_')]), columns))
|
69 |
+
target_col = list(map(lambda x: ' '.join([word.capitalize() for word in x.split('_')]), numeric_columns))
|
70 |
+
|
71 |
+
transformed_columns.insert(0, "")
|
72 |
+
data_columns = gr.Dropdown(choices=transformed_columns, value=None)
|
73 |
+
date_columns = gr.Dropdown(choices=datetime_columns, value=None)
|
74 |
+
target_columns = gr.Dropdown(choices=target_col, value=None)
|
75 |
+
|
76 |
+
return [df, data_columns, date_columns, target_columns]
|
77 |
+
|
78 |
+
def download_file():
|
79 |
+
return [gr.UploadButton(visible=True), gr.DownloadButton(visible=False)]
|
80 |
+
|
81 |
+
|
82 |
+
def set_products(selected_column, df_state):
|
83 |
+
column_mapping = {
|
84 |
+
' '.join([word.capitalize() for word in col.split('_')]): col
|
85 |
+
for col in df_state.columns
|
86 |
+
}
|
87 |
+
if selected_column in column_mapping:
|
88 |
+
original_column = column_mapping[selected_column]
|
89 |
+
unique_values = df_state[original_column].dropna().unique().tolist()
|
90 |
+
return unique_values
|
91 |
+
return []
|
92 |
+
|
93 |
+
|
94 |
+
def set_dates(selected_column, df_state):
|
95 |
+
column_mapping = {
|
96 |
+
' '.join([word.capitalize() for word in col.split('_')]): col
|
97 |
+
for col in df_state.columns
|
98 |
+
}
|
99 |
+
|
100 |
+
if selected_column in column_mapping:
|
101 |
+
original_column = column_mapping[selected_column]
|
102 |
+
min_date = df_state[original_column].min()
|
103 |
+
max_date = df_state[original_column].max()
|
104 |
+
return min_date, max_date
|
105 |
+
return None, None
|
106 |
+
|
107 |
+
|
108 |
+
def forecast_chronos_data(df_state, date_column, target_column, select_period, forecasting_type):
|
109 |
+
if not date_column:
|
110 |
+
raise gr.Error("Please select a Date column")
|
111 |
+
|
112 |
+
if not target_column:
|
113 |
+
raise gr.Error("Please select a target column")
|
114 |
+
|
115 |
+
original_date_column = None
|
116 |
+
original_target_column = None
|
117 |
+
|
118 |
+
column_mapping = {
|
119 |
+
' '.join([word.capitalize() for word in col.split('_')]): col
|
120 |
+
for col in df_state.columns
|
121 |
+
}
|
122 |
+
if date_column in column_mapping:
|
123 |
+
original_date_column = column_mapping[date_column]
|
124 |
+
|
125 |
+
if target_column in column_mapping:
|
126 |
+
original_target_column = column_mapping[target_column]
|
127 |
+
|
128 |
+
df_forecast = pd.DataFrame()
|
129 |
+
df_forecast['date'] = df_state[original_date_column]
|
130 |
+
df_forecast['month'] = df_forecast['date'].dt.month
|
131 |
+
df_forecast['year'] = df_forecast['date'].dt.year
|
132 |
+
df_forecast['sold_qty'] = df_state[original_target_column]
|
133 |
+
|
134 |
+
monthly_sales = df_forecast.groupby(['year', 'month'])['sold_qty'].sum().reset_index()
|
135 |
+
monthly_sales = monthly_sales.rename(columns={'year': 'year', 'month': 'month', 'sold_qty': 'y'})
|
136 |
+
|
137 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
138 |
+
|
139 |
+
pipeline = ChronosPipeline.from_pretrained(
|
140 |
+
"amazon/chronos-t5-base",
|
141 |
+
device_map=device,
|
142 |
+
torch_dtype=torch.float32,
|
143 |
+
)
|
144 |
+
context = torch.tensor(monthly_sales["y"])
|
145 |
+
prediction_length = select_period
|
146 |
+
forecast = pipeline.predict(context, prediction_length)
|
147 |
+
|
148 |
+
forecast_index = range(len(monthly_sales), len(monthly_sales) + prediction_length)
|
149 |
+
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
|
150 |
+
|
151 |
+
plt.figure(figsize=(30, 10))
|
152 |
+
plt.plot(monthly_sales["y"], color="royalblue", label="Historical Data", linewidth=2)
|
153 |
+
plt.plot(forecast_index, median, color="tomato", label="Median Forecast", linewidth=2)
|
154 |
+
plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
|
155 |
+
plt.title("Sales Forecasting Visualization", fontsize=16)
|
156 |
+
plt.xlabel("Months", fontsize=20)
|
157 |
+
plt.ylabel("Sold Qty", fontsize=20)
|
158 |
+
|
159 |
+
plt.xticks(fontsize=18)
|
160 |
+
plt.yticks(fontsize=18)
|
161 |
+
|
162 |
+
ax = plt.gca()
|
163 |
+
ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
|
164 |
+
ax.yaxis.set_major_locator(ticker.MultipleLocator(5))
|
165 |
+
ax.grid(which='major', linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
|
166 |
+
|
167 |
+
plt.legend(fontsize=18)
|
168 |
+
plt.grid(linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
|
169 |
+
plt.tight_layout()
|
170 |
+
return plt.gcf()
|
171 |
+
|
172 |
+
|
173 |
+
def home_page():
|
174 |
+
content = """
|
175 |
+
### **Sales Forecasting with Chronos**
|
176 |
+
|
177 |
+
Welcome to the future of sales optimization with **Chronos**.
|
178 |
+
Say goodbye to guesswork and unlock the power of **data-driven insights** with our advanced forecasting platform.
|
179 |
+
|
180 |
+
- **Seamless CSV Upload**: Quickly upload your sales data in CSV formatโno technical expertise needed.
|
181 |
+
- **AI-Powered Predictions**: Harness the power of state-of-the-art machine learning models to uncover trends and forecast future sales performance.
|
182 |
+
- **Interactive Visualizations**: Gain actionable insights with intuitive charts and graphs that make data easy to understand.
|
183 |
+
|
184 |
+
Start making smarter, data-backed business decisions today with **Chronos**!
|
185 |
+
|
186 |
+
"""
|
187 |
+
return content
|
188 |
+
|
189 |
+
def about_page():
|
190 |
+
content = """
|
191 |
+
### ๐ง **Contact Us:**
|
192 |
+
- **Email**: [email protected] โ๏ธ
|
193 |
+
- **Website**: [https://www.topsinfosolutions.com/](https://www.topsinfosolutions.com/) ๐
|
194 |
+
|
195 |
+
### ๐ **What We Offer:**
|
196 |
+
- **Custom AI Solutions**: Tailored to your business needs ๐ค
|
197 |
+
- **Chatbot Development**: Build intelligent conversational agents ๐ฌ
|
198 |
+
- **Vision Models**: Computer vision solutions for various applications ๐ผ๏ธ
|
199 |
+
- **AI Agents**: Personalized agents powered by advanced LLMs ๐ค
|
200 |
+
|
201 |
+
### ๐ค **How We Can Help:**
|
202 |
+
Reach out to us for bespoke AI services. Whether you need chatbots, vision models, or AI-powered agents, weโre here to build solutions that make a difference! ๐
|
203 |
+
|
204 |
+
### ๐ฌ **Get in Touch:**
|
205 |
+
If you have any questions or need a custom solution, click the button below to schedule a consultation with us. ๐
|
206 |
+
"""
|
207 |
+
return content
|
208 |
+
|
209 |
+
|
210 |
+
with gr.Blocks() as demo:
|
211 |
+
with gr.Tabs():
|
212 |
+
with gr.TabItem("Home"):
|
213 |
+
df_state = gr.State()
|
214 |
+
|
215 |
+
gr.Image("/content/chronos-logo.png", interactive=False)
|
216 |
+
home_output = gr.Markdown(value=home_page(), label="Playground")
|
217 |
+
|
218 |
+
gr.Markdown("## Step 1: Historical/Training Data (currently supports *.csv only)")
|
219 |
+
|
220 |
+
with gr.Row():
|
221 |
+
file_input = gr.File(label="Upload Historical (Training Data) Sales Data", file_types=[".csv"])
|
222 |
+
|
223 |
+
with gr.Row():
|
224 |
+
date_column = gr.Dropdown(choices=[], label="Select Date column (*Required)", multiselect=False, value=None)
|
225 |
+
target_column = gr.Dropdown(choices=[], label="Select Target column (*Required)", multiselect=False, value=None)
|
226 |
+
select_product_column = gr.Dropdown(choices=[], label="Select Product column (Optional)", multiselect=False, value=None)
|
227 |
+
select_product = gr.Dropdown(choices=[], label="Select Product (Optional)", multiselect=False, value=None)
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
with gr.Row():
|
232 |
+
start = gr.DateTime("2021-01-01 00:00:00", label="Training data Start date")
|
233 |
+
end = gr.DateTime("2021-01-05 00:00:00", label="Training data End date")
|
234 |
+
apply_btn = gr.Button("Visualize Data", scale=0)
|
235 |
+
|
236 |
+
with gr.Row():
|
237 |
+
historical_data_plot = gr.Plot()
|
238 |
+
|
239 |
+
apply_btn.click(
|
240 |
+
filter_data,
|
241 |
+
inputs=[start, end, df_state, select_product_column, date_column, target_column],
|
242 |
+
outputs=[df_state, historical_data_plot]
|
243 |
+
)
|
244 |
+
|
245 |
+
gr.Markdown("## Step 2: Forecast")
|
246 |
+
with gr.Row():
|
247 |
+
forecasting_type = gr.Radio(["day", "monthly", "year"], value="monthly", label="Forecasting Type", interactive=False)
|
248 |
+
select_period = gr.Slider(2, 60, value=12, label="Select Period", info="Check Selected Forecast Type", interactive =True, step=1)
|
249 |
+
forecast_btn = gr.Button("Forecast")
|
250 |
+
|
251 |
+
|
252 |
+
with gr.Row():
|
253 |
+
plot_forecast_output = gr.Plot(label="Chronos Forecasting Visualization")
|
254 |
+
|
255 |
+
forecast_btn.click(
|
256 |
+
forecast_chronos_data,
|
257 |
+
inputs=[df_state, date_column, target_column, select_period],
|
258 |
+
outputs=[plot_forecast_output]
|
259 |
+
)
|
260 |
+
|
261 |
+
|
262 |
+
|
263 |
+
file_input.upload(
|
264 |
+
upload_file,
|
265 |
+
inputs=[file_input],
|
266 |
+
outputs=[df_state, select_product_column, date_column, target_column]
|
267 |
+
)
|
268 |
+
|
269 |
+
select_product_column.change(
|
270 |
+
set_products,
|
271 |
+
inputs=[select_product_column, df_state],
|
272 |
+
outputs=[]
|
273 |
+
)
|
274 |
+
|
275 |
+
date_column.change(
|
276 |
+
set_dates,
|
277 |
+
inputs=[date_column, df_state],
|
278 |
+
outputs=[start, end]
|
279 |
+
)
|
280 |
+
|
281 |
+
target_column.change(
|
282 |
+
lambda x: x if x else [],
|
283 |
+
inputs=[target_column],
|
284 |
+
outputs=[]
|
285 |
+
)
|
286 |
+
|
287 |
+
with gr.TabItem("About Tops"):
|
288 |
+
df_state = gr.State()
|
289 |
+
|
290 |
+
gr.Image("/content/chronos-logo.png", interactive=False)
|
291 |
+
about_output = gr.Markdown(value=about_page(), label="About Tops")
|
292 |
+
|
293 |
+
|
294 |
+
if __name__ == "__main__":
|
295 |
+
demo.launch()
|
296 |
+
|
297 |
+
|