TOPSInfosol's picture
Update app.py
9b8c0e9 verified
import pandas as pd
import gradio as gr
from pathlib import Path
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import torch
from chronos import ChronosPipeline
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
def filter_data(start, end, df_state, select_product_column, date_column, target_column):
    """Aggregate the target column per date inside a date window and plot it.

    Args:
        start: Window start as a POSIX timestamp in seconds (interpreted as UTC).
        end: Window end as a POSIX timestamp in seconds (interpreted as UTC).
        df_state: DataFrame holding the uploaded data.
        select_product_column: Display label of the product column. Accepted
            for interface compatibility; it is not used by the filtering.
        date_column: Display label ("Title Case") of the date column.
        target_column: Display label ("Title Case") of the column to sum.

    Returns:
        A two-item list: the aggregated DataFrame (one row per date, sorted by
        date) and a Plotly line figure of that data.

    Raises:
        gr.Error: If a required selection is missing, no data is loaded, or a
            selected label does not match any column of ``df_state``.
    """
    if not date_column:
        raise gr.Error("Please select a Date column")
    if not target_column:
        raise gr.Error("Please select a target column")
    if df_state is None:
        raise gr.Error("Please upload a CSV file first")
    if start is None or end is None:
        raise gr.Error("Please select a start and end date")

    # Dropdowns show "Title Case" labels; map them back to the real column names.
    column_mapping = {
        " ".join(word.capitalize() for word in col.split("_")): col
        for col in df_state.columns
    }
    original_date_column = column_mapping.get(date_column)
    original_target_column = column_mapping.get(target_column)
    if original_date_column is None:
        raise gr.Error(f"Date column '{date_column}' was not found in the uploaded data")
    if original_target_column is None:
        raise gr.Error(f"Target column '{target_column}' was not found in the uploaded data")

    # Naive UTC timestamps, matching what datetime.utcfromtimestamp produced.
    start_datetime = pd.to_datetime(start, unit="s")
    end_datetime = pd.to_datetime(end, unit="s")

    # Work on a new frame so the caller's DataFrame is not modified in place.
    working = df_state.assign(
        **{original_date_column: pd.to_datetime(df_state[original_date_column])}
    )
    dates = working[original_date_column]
    in_window = (dates >= start_datetime) & (dates <= end_datetime)

    # groupby sorts by its key by default, so the result is already in date order.
    filtered_df = (
        working[in_window]
        .groupby(original_date_column)[original_target_column]
        .sum()
        .reset_index()
    )

    fig = px.line(
        filtered_df,
        x=original_date_column,
        y=original_target_column,
        title="Historical Sales Data",
    )
    return [filtered_df, fig]
def _to_display_name(column_name):
    """Turn a snake_case column name into the Title Case label shown in dropdowns."""
    return " ".join(word.capitalize() for word in column_name.split("_"))


def upload_file(filepath):
    """Load an uploaded CSV and build the column-selection dropdowns.

    Args:
        filepath: The uploaded file, either a path string or an object whose
            ``name`` attribute is the path. ``None`` (upload cleared) is
            accepted and resets the outputs.

    Returns:
        A four-item list: the loaded DataFrame (date-like text columns parsed
        to datetime), a dropdown of all columns (with a leading blank choice),
        a dropdown of datetime columns, and a dropdown of numeric columns.
    """
    if filepath is None:
        # The file was cleared: drop the data and empty every dropdown.
        return [
            None,
            gr.Dropdown(choices=[""], value=None),
            gr.Dropdown(choices=[], value=None),
            gr.Dropdown(choices=[], value=None),
        ]

    # Accept both plain path strings and file-like wrappers exposing `.name`.
    csv_path = getattr(filepath, "name", filepath)
    df = pd.read_csv(csv_path)

    # Decide on the column dtype rather than sampling the first few rows.
    numeric_columns = [
        col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])
    ]

    datetime_columns = []
    for col in df.columns:
        is_datetime = pd.api.types.is_datetime64_any_dtype(df[col])
        if not is_datetime and not pd.api.types.is_numeric_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except (ValueError, TypeError, OverflowError):
                # Not a date-like column; keep the original values.
                pass
        # Any datetime resolution or timezone counts as a date column.
        if pd.api.types.is_datetime64_any_dtype(df[col]):
            datetime_columns.append(col)

    all_labels = [""] + [_to_display_name(col) for col in df.columns]
    date_labels = [_to_display_name(col) for col in datetime_columns]
    target_labels = [_to_display_name(col) for col in numeric_columns]

    data_columns = gr.Dropdown(choices=all_labels, value=None)
    date_columns = gr.Dropdown(choices=date_labels, value=None)
    target_columns = gr.Dropdown(choices=target_labels, value=None)
    return [df, data_columns, date_columns, target_columns]
def download_file():
    """Return updates that show the upload button and hide the download button."""
    upload_button = gr.UploadButton(visible=True)
    download_button = gr.DownloadButton(visible=False)
    return [upload_button, download_button]
def set_products(selected_column, df_state):
    """List the distinct non-null values of the column behind a display label.

    Args:
        selected_column: "Title Case" label of the chosen column.
        df_state: DataFrame holding the uploaded data.

    Returns:
        The unique non-null values of the matching column, in order of first
        appearance, or an empty list when the label matches no column.
    """
    # Build the label -> real column name lookup.
    label_to_column = {}
    for column in df_state.columns:
        label = " ".join(part.capitalize() for part in column.split("_"))
        label_to_column[label] = column

    if selected_column not in label_to_column:
        return []

    values = df_state[label_to_column[selected_column]].dropna()
    return values.unique().tolist()
def set_dates(selected_column, df_state):
    """Find the smallest and largest value of the column behind a display label.

    Args:
        selected_column: "Title Case" label of the chosen column.
        df_state: DataFrame holding the uploaded data.

    Returns:
        A ``(min, max)`` tuple for the matching column, or ``(None, None)``
        when the label matches no column.
    """
    # Build the label -> real column name lookup.
    label_to_column = {}
    for column in df_state.columns:
        label = " ".join(part.capitalize() for part in column.split("_"))
        label_to_column[label] = column

    if selected_column not in label_to_column:
        return None, None

    series = df_state[label_to_column[selected_column]]
    return series.min(), series.max()
# Loaded Chronos pipelines keyed by device, so repeated forecasts reuse the model
# instead of reloading it on every button click.
_CHRONOS_PIPELINES = {}


def _get_chronos_pipeline(device):
    """Return the Chronos pipeline for *device*, loading it on first use."""
    if device not in _CHRONOS_PIPELINES:
        _CHRONOS_PIPELINES[device] = ChronosPipeline.from_pretrained(
            "amazon/chronos-t5-base",
            device_map=device,
            torch_dtype=torch.float32,
        )
    return _CHRONOS_PIPELINES[device]


def forecast_chronos_data(df_state, date_column, target_column, select_period, forecasting_type="monthly"):
    """Forecast monthly totals of the target column with Chronos and plot them.

    The data is summed per (year, month), the resulting series is fed to the
    Chronos pipeline, and the 10%/50%/90% quantiles of the sampled forecasts
    are drawn after the historical line.

    Args:
        df_state: DataFrame holding the data to forecast from.
        date_column: Display label ("Title Case") of the date column.
        target_column: Display label ("Title Case") of the column to forecast.
        select_period: Number of future months to predict.
        forecasting_type: Forecast granularity. Only monthly aggregation is
            implemented, so the value is currently not used; it defaults to
            "monthly" so callers may omit it.

    Returns:
        A Plotly figure with the historical series, the median forecast and
        the shaded band between the 10% and 90% quantiles.

    Raises:
        gr.Error: If a required selection is missing, no data is loaded, a
            selected label matches no column, or there is nothing to forecast.
    """
    if not date_column:
        raise gr.Error("Please select a Date column")
    if not target_column:
        raise gr.Error("Please select a target column")
    if df_state is None:
        raise gr.Error("Please upload a CSV file first")

    # Dropdowns show "Title Case" labels; map them back to the real column names.
    column_mapping = {
        " ".join(word.capitalize() for word in col.split("_")): col
        for col in df_state.columns
    }
    original_date_column = column_mapping.get(date_column)
    original_target_column = column_mapping.get(target_column)
    if original_date_column is None:
        raise gr.Error(f"Date column '{date_column}' was not found in the uploaded data")
    if original_target_column is None:
        raise gr.Error(f"Target column '{target_column}' was not found in the uploaded data")

    # Sum the target per calendar month; the model sees one value per month.
    dates = pd.to_datetime(df_state[original_date_column])
    df_forecast = pd.DataFrame()
    df_forecast['date'] = dates
    df_forecast['month'] = df_forecast['date'].dt.month
    df_forecast['year'] = df_forecast['date'].dt.year
    df_forecast['sold_qty'] = df_state[original_target_column]
    monthly_sales = (
        df_forecast.groupby(['year', 'month'])['sold_qty']
        .sum()
        .reset_index()
        .rename(columns={'sold_qty': 'y'})
    )
    if monthly_sales.empty:
        raise gr.Error("No data available to forecast")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline = _get_chronos_pipeline(device)

    context = torch.tensor(monthly_sales["y"].to_numpy(), dtype=torch.float32)
    # The slider may deliver a float; range() below needs an int.
    prediction_length = int(select_period)
    forecast = pipeline.predict(context, prediction_length)

    # Forecast points continue the integer month index of the history.
    forecast_index = list(range(len(monthly_sales), len(monthly_sales) + prediction_length))
    low, median, high = np.quantile(forecast[0].cpu().numpy(), [0.1, 0.5, 0.9], axis=0)
    low, median, high = np.ceil(low).astype(int), np.ceil(median).astype(int), np.ceil(high).astype(int)

    fig = px.line(
        x=monthly_sales.index,
        y=monthly_sales["y"],
        title="Sales Forecasting Visualization",
        labels={"x": "Months", "y": f"{target_column}"},
    )
    # Style the historical line now, while it is the only trace in the figure,
    # instead of selecting it afterwards by a trace name it may not carry.
    fig.update_traces(
        name="Historical Data",
        showlegend=True,
        line=dict(color="royalblue", width=2),
    )
    fig.add_trace(
        go.Scatter(
            x=forecast_index,
            y=median,
            name="Median Forecast",
            line=dict(color="tomato", width=2)
        )
    )
    fig.add_trace(
        go.Scatter(
            x=forecast_index,
            y=high,
            name="80% Prediction Interval",
            mode='lines',
            line=dict(width=2, color='rgba(50, 205, 50, 1)'),
            showlegend=False
        )
    )
    # fill='tonexty' shades from this trace to the previous one (the upper
    # bound), so the two bound traces must stay adjacent and in this order.
    fig.add_trace(
        go.Scatter(
            x=forecast_index,
            y=low,
            name="10% Prediction Interval",
            mode='lines',
            line=dict(width=1, color='rgba(255, 255, 0, 1)'),
            showlegend=False,
            fillcolor='rgba(255, 99, 71, 0.3)',
            fill='tonexty',
        )
    )
    fig.update_layout(
        title_font_size=20,
        xaxis_title_font_size=16,
        yaxis_title_font_size=16,
        legend_font_size=16,
        xaxis_tickfont_size=14,
        yaxis_tickfont_size=14,
        showlegend=True,
        width=1600,  # Equivalent to figsize=(30, 10)
        height=400,
        xaxis=dict(
            title="Months",
            tickfont=dict(size=14),
            gridcolor='rgba(128, 128, 128, 0.7)',
            gridwidth=1.2,
            dtick=3,
            griddash='dash',
            rangeslider=dict(visible=True),
            # NOTE(review): the x values are integer month indexes, not dates;
            # confirm these month-based selector buttons behave as intended.
            rangeselector=dict(
                buttons=list([
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(count=12, label="1y", step="month", stepmode="backward"),
                    dict(count=24, label="2y", step="month", stepmode="backward"),
                    dict(step="all", label="All")
                ])
            )
        ),
        yaxis=dict(
            gridcolor='rgba(128, 128, 128, 0.7)',
            gridwidth=1.2,
            dtick=5,  # Set tick interval to 5 units
            griddash='dash'
        ),
        plot_bgcolor='white'
    )
    return fig
# plt.figure(figsize=(30, 10))
# plt.plot(monthly_sales["y"], color="royalblue", label="Historical Data", linewidth=2)
# plt.plot(forecast_index, median, color="tomato", label="Median Forecast", linewidth=2)
# plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
# plt.title("Sales Forecasting Visualization", fontsize=16)
# plt.xlabel("Months", fontsize=20)
# plt.ylabel("Sold Qty", fontsize=20)
# plt.xticks(fontsize=18)
# plt.yticks(fontsize=18)
# ax = plt.gca()
# ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
# ax.yaxis.set_major_locator(ticker.MultipleLocator(5))
# ax.grid(which='major', linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
# plt.legend(fontsize=18)
# plt.grid(linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
# plt.tight_layout()
# return plt.gcf()
def home_page():
    """Return the Markdown text shown at the top of the Home tab."""
    # The text lines stay at column 0: indented Markdown would render as a code block.
    content = """
### **Sales Forecasting with Chronos**
Welcome to the future of sales optimization with **Chronos**.
Say goodbye to guesswork and unlock the power of **data-driven insights** with our advanced forecasting platform.
- **Seamless CSV Upload**: Quickly upload your sales data in CSV format—no technical expertise needed.
- **AI-Powered Predictions**: Harness the power of state-of-the-art machine learning models to uncover trends and forecast future sales performance.
- **Interactive Visualizations**: Gain actionable insights with intuitive charts and graphs that make data easy to understand.
Start making smarter, data-backed business decisions today with **Chronos**!
"""
    return content
def about_page():
    """Return the Markdown text shown in the "About Tops" tab."""
    # NOTE(review): "[email protected]" looks like an obfuscation placeholder
    # rather than a real address; replace it with the actual contact e-mail.
    # The text lines stay at column 0: indented Markdown would render as a code block.
    content = """
### 📧 **Contact Us:**
- **Email**: [email protected] ✉️
- **Website**: [https://www.topsinfosolutions.com/](https://www.topsinfosolutions.com/) 🌐
### 🛠 **What We Offer:**
- **Custom AI Solutions**: Tailored to your business needs 🤖
- **Chatbot Development**: Build intelligent conversational agents 💬
- **Vision Models**: Computer vision solutions for various applications 🖼️
- **AI Agents**: Personalized agents powered by advanced LLMs 🤖
### 🤔 **How We Can Help:**
Reach out to us for bespoke AI services. Whether you need chatbots, vision models, or AI-powered agents, we’re here to build solutions that make a difference! 🌟
### 💬 **Get in Touch:**
If you have any questions or need a custom solution, click the button below to schedule a consultation with us. 📅
"""
    return content
# Gradio UI: a "Home" tab with the upload -> visualize -> forecast workflow and
# an "About Tops" tab with static contact information.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    with gr.Tabs():
        with gr.TabItem("Home"):
            # Holds the working DataFrame between event handler calls.
            df_state = gr.State()
            # gr.Image("/content/chronos-logo.png", interactive=False)
            home_output = gr.Markdown(value=home_page(), label="Playground")

            gr.Markdown("## Step 1: Historical/Training Data (currently supports *.csv only)")
            with gr.Row():
                file_input = gr.File(label="Upload Historical (Training Data) Sales Data", file_types=[".csv"])
            with gr.Row():
                date_column = gr.Dropdown(choices=[], label="Select Date column (*Required)", multiselect=False, value=None)
                target_column = gr.Dropdown(choices=[], label="Select Target column (*Required)", multiselect=False, value=None)
                select_product_column = gr.Dropdown(choices=[], label="Select Product column (Optional)", multiselect=False, value=None)
                select_product = gr.Dropdown(choices=[], label="Select Product (Optional)", multiselect=False, value=None)
            with gr.Row():
                start = gr.DateTime("2021-01-01 00:00:00", label="Training data Start date")
                end = gr.DateTime("2021-01-05 00:00:00", label="Training data End date")
                apply_btn = gr.Button("Visualize Data", scale=0)
            gr.Examples(
                examples=[
                    ["example_files/test_tops_product_id_1.csv"],
                    ["example_files/test_tops_product_id_2.csv"],
                    ["example_files/test_tops_product_id_3.csv"],
                    ["example_files/test_tops_product_id_4.csv"]
                ],
                inputs=file_input,
                outputs=[df_state, select_product_column, date_column, target_column],
                fn=upload_file,
            )
            with gr.Row():
                historical_data_plot = gr.Plot()
            # The aggregated frame returned by filter_data replaces the stored
            # state, so the forecast step works on the filtered data.
            apply_btn.click(
                filter_data,
                inputs=[start, end, df_state, select_product_column, date_column, target_column],
                outputs=[df_state, historical_data_plot]
            )

            gr.Markdown("## Step 2: Forecast")
            with gr.Row():
                forecasting_type = gr.Radio(["day", "monthly", "year"], value="monthly", label="Forecasting Type", interactive=False)
                select_period = gr.Slider(2, 60, value=12, label="Select Period", info="Check Selected Forecast Type", interactive=True, step=1)
                forecast_btn = gr.Button("Forecast")
            with gr.Row():
                plot_forecast_output = gr.Plot(label="Chronos Forecasting Visualization")
            # forecast_chronos_data takes five arguments; forecasting_type must
            # be passed too, otherwise the handler is called one argument short.
            forecast_btn.click(
                forecast_chronos_data,
                inputs=[df_state, date_column, target_column, select_period, forecasting_type],
                outputs=[plot_forecast_output]
            )

            # Loading a file fills the state and the column dropdowns.
            file_input.change(
                upload_file,
                inputs=[file_input],
                outputs=[df_state, select_product_column, date_column, target_column]
            )
            select_product_column.change(
                set_products,
                inputs=[select_product_column, df_state],
                outputs=[]
            )
            # Picking a date column moves the date pickers to that column's range.
            date_column.change(
                set_dates,
                inputs=[date_column, df_state],
                outputs=[start, end]
            )
            target_column.change(
                lambda x: x if x else [],
                inputs=[target_column],
                outputs=[]
            )
        with gr.TabItem("About Tops"):
            # gr.Image("/content/chronos-logo.png", interactive=False)
            about_output = gr.Markdown(value=about_page(), label="About Tops")

if __name__ == "__main__":
    demo.launch()