|
import pandas as pd |
|
import gradio as gr |
|
from pathlib import Path |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
import numpy as np |
|
import torch |
|
from chronos import ChronosPipeline |
|
from datetime import datetime |
|
import matplotlib.pyplot as plt |
|
import matplotlib.ticker as ticker |
|
|
|
|
|
def filter_data(start, end, df_state, select_product_column, date_column, target_column):
    """Sum the target column per date inside a time window and plot the result.

    Args:
        start: Window start as a POSIX timestamp in seconds; it is converted
            to a naive, UTC-based pandas timestamp.
        end: Window end, same convention as ``start``. Both bounds are inclusive.
        df_state: DataFrame holding the uploaded data, or ``None`` when nothing
            has been uploaded yet.
        select_product_column: Not used by the filter; accepted so the existing
            event wiring keeps working.
        date_column: Display name (e.g. ``"Order Date"``) of the date column.
        target_column: Display name of the column to aggregate.

    Returns:
        A two-item list: the aggregated DataFrame (one row per date, sorted by
        date, columns named as in ``df_state``) and a Plotly line figure of it.

    Raises:
        gr.Error: If a required selection is missing, no data is loaded, a
            selected column does not exist, or the date column cannot be parsed.
    """
    if not date_column:
        raise gr.Error("Please select a Date column")
    if not target_column:
        raise gr.Error("Please select a target column")
    if df_state is None:
        raise gr.Error("Please upload a CSV file first")
    if start is None or end is None:
        raise gr.Error("Please select a start and an end date")

    # Dropdowns show "Title Case" names; map them back to the real column names.
    column_mapping = {
        " ".join(word.capitalize() for word in col.split("_")): col
        for col in df_state.columns
    }
    original_date_column = column_mapping.get(date_column)
    original_target_column = column_mapping.get(target_column)
    if original_date_column is None:
        raise gr.Error(f"Date column '{date_column}' was not found in the data")
    if original_target_column is None:
        raise gr.Error(f"Target column '{target_column}' was not found in the data")

    # Same instants as datetime.utcfromtimestamp(), without the deprecated call.
    start_datetime = pd.to_datetime(start, unit="s")
    end_datetime = pd.to_datetime(end, unit="s")

    try:
        dates = pd.to_datetime(df_state[original_date_column])
    except (ValueError, TypeError) as err:
        raise gr.Error(f"Column '{date_column}' could not be parsed as dates") from err

    in_window = (dates >= start_datetime) & (dates <= end_datetime)

    # Work on a copy so the frame stored in the session state is never
    # modified in place.
    windowed = df_state.loc[in_window].copy()
    windowed[original_date_column] = dates[in_window]

    filtered_df = (
        windowed.groupby(original_date_column)[original_target_column]
        .sum()
        .reset_index()
        .sort_values(by=original_date_column)
    )
    fig = px.line(
        filtered_df,
        x=original_date_column,
        y=original_target_column,
        title="Historical Sales Data",
    )
    return [filtered_df, fig]
|
|
|
|
|
def upload_file(filepath):
    """Load an uploaded CSV and build the three column-selection dropdowns.

    Text columns that parse as dates are converted to datetime in the returned
    frame. Column names are shown in "Title Case" (``sold_qty`` -> ``Sold Qty``).

    Args:
        filepath: Path of the uploaded CSV (``str`` or ``Path``), an upload
            object exposing the path as ``.name``, or ``None`` when the file
            input has been cleared.

    Returns:
        A four-item list: the loaded DataFrame, a dropdown of all columns
        (with a leading empty choice), a dropdown of datetime columns and a
        dropdown of numeric columns. When ``filepath`` is ``None`` the frame
        is ``None`` and every dropdown is empty.
    """
    if filepath is None:
        # The file input was cleared: reset the state and empty every dropdown.
        return [None] + [gr.Dropdown(choices=[], value=None) for _ in range(3)]

    # A str/Path is the path itself; other upload objects carry it in `.name`.
    csv_path = filepath if isinstance(filepath, (str, Path)) else filepath.name
    df = pd.read_csv(csv_path)

    def to_display(column):
        """Turn ``snake_case`` into the "Snake Case" label shown in the UI."""
        return " ".join(word.capitalize() for word in column.split("_"))

    # Judge the whole column by dtype rather than probing the first few rows,
    # so a column with stray text further down is not offered as a target.
    numeric_columns = [
        col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])
    ]

    datetime_columns = []
    for col in df.columns:
        series = df[col]
        if pd.api.types.is_object_dtype(series) or pd.api.types.is_string_dtype(series):
            try:
                df[col] = pd.to_datetime(series)
            except (ValueError, TypeError, OverflowError):
                # Not a date column: leave the values untouched.
                pass
        # Timezone-naive datetimes of any resolution qualify.
        if pd.api.types.is_datetime64_dtype(df[col]):
            datetime_columns.append(col)

    # Leading "" lets the user leave the optional product column unselected.
    all_choices = [""] + [to_display(col) for col in df.columns]
    date_choices = [to_display(col) for col in datetime_columns]
    target_choices = [to_display(col) for col in numeric_columns]

    data_columns = gr.Dropdown(choices=all_choices, value=None)
    date_columns = gr.Dropdown(choices=date_choices, value=None)
    target_columns = gr.Dropdown(choices=target_choices, value=None)

    return [df, data_columns, date_columns, target_columns]
|
|
|
def download_file():
    """Return a visible upload button and a hidden download button.

    Returns:
        A two-item list: ``gr.UploadButton(visible=True)`` followed by
        ``gr.DownloadButton(visible=False)``.
    """
    upload_button = gr.UploadButton(visible=True)
    download_button = gr.DownloadButton(visible=False)
    return [upload_button, download_button]
|
|
|
|
|
def set_products(selected_column, df_state):
    """List the distinct non-missing values of the selected column.

    Args:
        selected_column: Display name (e.g. ``"Product Id"``) of the column.
        df_state: DataFrame holding the uploaded data, or ``None`` when nothing
            has been uploaded yet.

    Returns:
        The unique non-null values in order of first appearance, or an empty
        list when no data is loaded or the column is not found.
    """
    if df_state is None or not selected_column:
        return []

    # Dropdowns show "Title Case" names; map them back to the real column names.
    column_mapping = {
        " ".join(word.capitalize() for word in col.split("_")): col
        for col in df_state.columns
    }
    original_column = column_mapping.get(selected_column)
    if original_column is None:
        return []

    return df_state[original_column].dropna().unique().tolist()
|
|
|
|
|
def set_dates(selected_column, df_state):
    """Return the earliest and latest value of the selected date column.

    Args:
        selected_column: Display name (e.g. ``"Order Date"``) of the column.
        df_state: DataFrame holding the uploaded data, or ``None`` when nothing
            has been uploaded yet.

    Returns:
        A ``(min, max)`` tuple. Either element is ``None`` when it is missing
        (e.g. the column holds no valid dates); both are ``None`` when no data
        is loaded or the column is not found.
    """
    if df_state is None or not selected_column:
        return None, None

    # Dropdowns show "Title Case" names; map them back to the real column names.
    column_mapping = {
        " ".join(word.capitalize() for word in col.split("_")): col
        for col in df_state.columns
    }
    original_column = column_mapping.get(selected_column)
    if original_column is None:
        return None, None

    min_date = df_state[original_column].min()
    max_date = df_state[original_column].max()

    # A missing bound (NaT/NaN) cannot be shown in a date picker; report None.
    if pd.isna(min_date):
        min_date = None
    if pd.isna(max_date):
        max_date = None
    return min_date, max_date
|
|
|
|
|
# Loaded Chronos pipelines keyed by device, so the model weights are read
# once per process instead of on every forecast request.
_CHRONOS_PIPELINES = {}


def _get_chronos_pipeline(device):
    """Return the Chronos pipeline for ``device``, loading it on first use."""
    pipeline = _CHRONOS_PIPELINES.get(device)
    if pipeline is None:
        pipeline = ChronosPipeline.from_pretrained(
            "amazon/chronos-t5-base",
            device_map=device,
            torch_dtype=torch.float32,
        )
        _CHRONOS_PIPELINES[device] = pipeline
    return pipeline


def forecast_chronos_data(df_state, date_column, target_column, select_period, forecasting_type="monthly"):
    """Forecast monthly totals of the target column with Chronos and plot them.

    The target column is summed per (year, month); the resulting series is the
    model context. The figure shows the history plus the 10%, 50% and 90%
    quantiles of the sampled forecasts, rounded up to whole numbers.

    Args:
        df_state: DataFrame holding the data to forecast from, or ``None`` when
            nothing has been uploaded yet.
        date_column: Display name (e.g. ``"Order Date"``) of the date column.
        target_column: Display name of the column to forecast.
        select_period: Number of future months to predict.
        forecasting_type: Currently not used; aggregation is always monthly.
            Defaults to ``"monthly"`` so callers may omit it.

    Returns:
        A Plotly figure with the historical series and the forecast traces.

    Raises:
        gr.Error: If a required selection is missing, no data is loaded, a
            selected column does not exist, the date column cannot be parsed,
            or there is nothing to forecast from.
    """
    if not date_column:
        raise gr.Error("Please select a Date column")
    if not target_column:
        raise gr.Error("Please select a target column")
    if df_state is None:
        raise gr.Error("Please upload a CSV file first")

    # Dropdowns show "Title Case" names; map them back to the real column names.
    column_mapping = {
        " ".join(word.capitalize() for word in col.split("_")): col
        for col in df_state.columns
    }
    original_date_column = column_mapping.get(date_column)
    original_target_column = column_mapping.get(target_column)
    if original_date_column is None:
        raise gr.Error(f"Date column '{date_column}' was not found in the data")
    if original_target_column is None:
        raise gr.Error(f"Target column '{target_column}' was not found in the data")

    try:
        # No-op for datetime columns; makes `.dt` safe for text dates as well.
        dates = pd.to_datetime(df_state[original_date_column])
    except (ValueError, TypeError) as err:
        raise gr.Error(f"Column '{date_column}' could not be parsed as dates") from err

    df_forecast = pd.DataFrame(
        {
            "year": dates.dt.year,
            "month": dates.dt.month,
            "y": df_state[original_target_column],
        }
    )
    # groupby sorts by (year, month), so rows are in chronological order.
    monthly_sales = df_forecast.groupby(["year", "month"])["y"].sum().reset_index()
    if monthly_sales.empty:
        raise gr.Error("No data available to forecast")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline = _get_chronos_pipeline(device)

    context = torch.tensor(monthly_sales["y"].to_numpy(), dtype=torch.float32)
    prediction_length = int(select_period)
    forecast = pipeline.predict(context, prediction_length)

    # The forecast continues the positional month index of the history.
    forecast_index = list(
        range(len(monthly_sales), len(monthly_sales) + prediction_length)
    )
    # forecast[0] holds the samples for the single input series; take the
    # quantiles across samples for each future step.
    low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
    low = np.ceil(low).astype(int)
    median = np.ceil(median).astype(int)
    high = np.ceil(high).astype(int)

    fig = px.line(
        x=monthly_sales.index,
        y=monthly_sales["y"],
        title="Sales Forecasting Visualization",
        labels={"x": "Months", "y": f"{target_column}"},
    )

    fig.add_trace(
        go.Scatter(
            x=forecast_index,
            y=median,
            name="Median Forecast",
            line=dict(color="tomato", width=2),
        )
    )

    fig.add_trace(
        go.Scatter(
            x=forecast_index,
            y=high,
            name="80% Prediction Interval",
            mode='lines',
            line=dict(width=2, color='rgba(50, 205, 50, 1)'),
            showlegend=False,
        )
    )

    # 'tonexty' shades the area between this trace and the one added just
    # before it (the upper bound).
    fig.add_trace(
        go.Scatter(
            x=forecast_index,
            y=low,
            name="10% Prediction Interval",
            mode='lines',
            line=dict(width=1, color='rgba(255, 255, 0, 1)'),
            showlegend=False,
            fillcolor='rgba(255, 99, 71, 0.3)',
            fill='tonexty',
        )
    )

    fig.update_layout(
        title_font_size=20,
        xaxis_title_font_size=16,
        yaxis_title_font_size=16,
        legend_font_size=16,
        xaxis_tickfont_size=14,
        yaxis_tickfont_size=14,
        showlegend=True,
        width=1600,
        height=400,
        xaxis=dict(
            title="Months",
            tickfont=dict(size=14),
            gridcolor='rgba(128, 128, 128, 0.7)',
            gridwidth=1.2,
            dtick=3,
            griddash='dash',
            rangeslider=dict(visible=True),
            rangeselector=dict(
                buttons=list([
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(count=12, label="1y", step="month", stepmode="backward"),
                    dict(count=24, label="2y", step="month", stepmode="backward"),
                    dict(step="all", label="All"),
                ])
            ),
        ),
        yaxis=dict(
            gridcolor='rgba(128, 128, 128, 0.7)',
            gridwidth=1.2,
            dtick=5,
            griddash='dash',
        ),
        plot_bgcolor='white',
    )

    # NOTE(review): this only restyles traces whose name is exactly "y";
    # confirm the historical trace created by px.line carries that name.
    fig.update_traces(
        line=dict(color="royalblue", width=2),
        selector=dict(name="y"),
    )
    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def home_page():
    """Return the Markdown shown at the top of the Home tab.

    Returns:
        A Markdown string with the product headline, a short pitch and a
        bullet list of features.
    """
    lines = [
        "### **Sales Forecasting with Chronos**",
        "",
        "Welcome to the future of sales optimization with **Chronos**.",
        "Say goodbye to guesswork and unlock the power of **data-driven insights** with our advanced forecasting platform.",
        "",
        # The dash below was stored mis-encoded ("formatβno"); restored to an em dash.
        "- **Seamless CSV Upload**: Quickly upload your sales data in CSV format—no technical expertise needed.",
        "- **AI-Powered Predictions**: Harness the power of state-of-the-art machine learning models to uncover trends and forecast future sales performance.",
        "- **Interactive Visualizations**: Gain actionable insights with intuitive charts and graphs that make data easy to understand.",
        "",
        "Start making smarter, data-backed business decisions today with **Chronos**!",
        "",
    ]
    content = "\n".join(lines)
    return content
|
|
|
def about_page():
    """Return the Markdown shown in the "About Tops" tab.

    The stored text contained mis-encoded emoji. Glyphs are restored where the
    surviving characters and the surrounding text identify them; residue that
    could not be identified is removed.

    Returns:
        A Markdown string with contact details and a summary of services.
    """
    lines = [
        "### 📧 **Contact Us:**",
        # NOTE(review): "[email protected]" looks like a redacted placeholder,
        # not a real address; replace it with the actual contact e-mail.
        "- **Email**: [email protected] ✉️",
        "- **Website**: [https://www.topsinfosolutions.com/](https://www.topsinfosolutions.com/)",
        "",
        "### **What We Offer:**",
        "- **Custom AI Solutions**: Tailored to your business needs 🤖",
        "- **Chatbot Development**: Build intelligent conversational agents 💬",
        "- **Vision Models**: Computer vision solutions for various applications 🖼️",
        "- **AI Agents**: Personalized agents powered by advanced LLMs 🤖",
        "",
        "### 🤝 **How We Can Help:**",
        "Reach out to us for bespoke AI services. Whether you need chatbots, vision models, or AI-powered agents, we’re here to build solutions that make a difference!",
        "",
        "### 💬 **Get in Touch:**",
        "If you have any questions or need a custom solution, click the button below to schedule a consultation with us.",
        "",
    ]
    content = "\n".join(lines)
    return content
|
|
|
|
|
# Build the two-tab UI: "Home" (upload -> visualize -> forecast) and "About Tops".
with gr.Blocks(theme=gr.themes.Default()) as demo:
    with gr.Tabs():
        with gr.TabItem("Home"):
            # Per-session DataFrame shared by the handlers wired below.
            df_state = gr.State()

            home_output = gr.Markdown(value=home_page(), label="Playground")

            gr.Markdown("## Step 1: Historical/Training Data (currently supports *.csv only)")

            with gr.Row():
                file_input = gr.File(
                    label="Upload Historical (Training Data) Sales Data",
                    file_types=[".csv"],
                )

            with gr.Row():
                date_column = gr.Dropdown(
                    choices=[],
                    label="Select Date column (*Required)",
                    multiselect=False,
                    value=None,
                )
                target_column = gr.Dropdown(
                    choices=[],
                    label="Select Target column (*Required)",
                    multiselect=False,
                    value=None,
                )
                select_product_column = gr.Dropdown(
                    choices=[],
                    label="Select Product column (Optional)",
                    multiselect=False,
                    value=None,
                )
                select_product = gr.Dropdown(
                    choices=[],
                    label="Select Product (Optional)",
                    multiselect=False,
                    value=None,
                )

            with gr.Row():
                start = gr.DateTime("2021-01-01 00:00:00", label="Training data Start date")
                end = gr.DateTime("2021-01-05 00:00:00", label="Training data End date")
                apply_btn = gr.Button("Visualize Data", scale=0)

            gr.Examples(
                examples=[
                    ["example_files/test_tops_product_id_1.csv"],
                    ["example_files/test_tops_product_id_2.csv"],
                    ["example_files/test_tops_product_id_3.csv"],
                    ["example_files/test_tops_product_id_4.csv"],
                ],
                inputs=file_input,
                outputs=[df_state, select_product_column, date_column, target_column],
                fn=upload_file,
            )

            with gr.Row():
                historical_data_plot = gr.Plot()

            # NOTE(review): the first output replaces df_state with the frame
            # filter_data returns (date + summed target only), so the other
            # uploaded columns are gone after "Visualize Data" — confirm intended.
            apply_btn.click(
                filter_data,
                inputs=[start, end, df_state, select_product_column, date_column, target_column],
                outputs=[df_state, historical_data_plot],
            )

            gr.Markdown("## Step 2: Forecast")
            with gr.Row():
                forecasting_type = gr.Radio(
                    ["day", "monthly", "year"],
                    value="monthly",
                    label="Forecasting Type",
                    interactive=False,
                )
                select_period = gr.Slider(
                    2,
                    60,
                    value=12,
                    label="Select Period",
                    info="Check Selected Forecast Type",
                    interactive=True,
                    step=1,
                )
                forecast_btn = gr.Button("Forecast")

            with gr.Row():
                plot_forecast_output = gr.Plot(label="Chronos Forecasting Visualization")

            # forecast_chronos_data takes five arguments; pass all of them,
            # including the forecasting type.
            forecast_btn.click(
                forecast_chronos_data,
                inputs=[df_state, date_column, target_column, select_period, forecasting_type],
                outputs=[plot_forecast_output],
            )

            file_input.change(
                upload_file,
                inputs=[file_input],
                outputs=[df_state, select_product_column, date_column, target_column],
            )

            # Fill the product dropdown with the distinct values of the chosen
            # product column (the list set_products returns).
            select_product_column.change(
                lambda column, df: gr.Dropdown(choices=set_products(column, df), value=None),
                inputs=[select_product_column, df_state],
                outputs=[select_product],
            )

            date_column.change(
                set_dates,
                inputs=[date_column, df_state],
                outputs=[start, end],
            )

        with gr.TabItem("About Tops"):
            about_output = gr.Markdown(value=about_page(), label="About Tops")


if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|