TOPSInfosol's picture
Update app.py
a013f3d verified
raw
history blame
7.46 kB
from __future__ import annotations
from typing import Iterable
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import math
import torch
from chronos import ChronosPipeline
import warnings
from seafoam import Seafoam
warnings.filterwarnings("ignore")
import numpy as np
import matplotlib.ticker as ticker
# Ensure the folder that holds the bundled example CSVs (used by gr.Examples) exists.
os.makedirs(name="example_files", exist_ok=True)
def process_csv(file):
    """Read an uploaded CSV and offer its columns in the two column dropdowns.

    Column names are shown in a prettified form: the name is split on
    underscores and each word is capitalized (``sold_qty`` -> ``Sold Qty``).

    Args:
        file: Value of the ``gr.File`` component -- either a path string or a
            tempfile-like object exposing ``.name`` (depends on the Gradio
            version), or ``None`` when nothing is uploaded.

    Returns:
        A 3-tuple ``(df, date_dropdown, target_dropdown)`` matching the three
        outputs ``[df_state, date_column, target_column]`` this handler is
        wired to. ``df`` is ``None`` and both dropdowns are emptied when no
        file is given.

    Raises:
        gr.Error: If the uploaded file does not have a ``.csv`` extension.
    """
    if file is None:
        # Must return one value per output (three), otherwise the event fails.
        return (
            None,
            gr.Dropdown(choices=[], value=None),
            gr.Dropdown(choices=[], value=None),
        )
    # Older Gradio hands over a tempfile wrapper, newer versions a path string.
    path = getattr(file, "name", file)
    if not str(path).lower().endswith(".csv"):
        raise gr.Error("Please upload a CSV file only")
    df = pd.read_csv(path)
    labels = [
        " ".join(word.capitalize() for word in str(column).split("_"))
        for column in df.columns
    ]
    # One update object per output component rather than a shared instance.
    return (
        df,
        gr.Dropdown(choices=labels, value=None),
        gr.Dropdown(choices=labels, value=None),
    )
# Chronos checkpoint, forecast horizon (in months) and calendar labels used by process_data.
_CHRONOS_MODEL_ID = "amazon/chronos-t5-base"
_PREDICTION_LENGTH = 12
_MONTH_ORDER = [
    'January', 'February', 'March', 'April', 'May', 'June',
    'July', 'August', 'September', 'October', 'November', 'December',
]
# Loaded pipelines keyed by model id. Loading a checkpoint is expensive, so it
# happens once per process instead of on every "Forecast" click.
_PIPELINE_CACHE = {}


def _get_chronos_pipeline(model_id=_CHRONOS_MODEL_ID):
    """Return a CPU/float32 Chronos pipeline for ``model_id``, loading it on first use."""
    pipeline = _PIPELINE_CACHE.get(model_id)
    if pipeline is None:
        pipeline = ChronosPipeline.from_pretrained(
            model_id,
            device_map="cpu",
            torch_dtype=torch.float32,
        )
        _PIPELINE_CACHE[model_id] = pipeline
    return pipeline


def _resolve_column(df, label):
    """Map a dropdown label back to the real column name in ``df``.

    Tries, in order: the historical ``label.lower().replace(" ", "_")``
    conversion, the label itself, and finally any column whose prettified
    form (split on ``_``, words capitalized, joined with spaces) equals the
    label. The last step covers names such as ``Date`` or ``Sold Qty`` that
    the lower-casing conversion cannot round-trip.

    Raises:
        gr.Error: If no column matches.
    """
    for candidate in (label.lower().replace(" ", "_"), label):
        if candidate in df.columns:
            return candidate
    for column in df.columns:
        pretty = " ".join(word.capitalize() for word in str(column).split("_"))
        if pretty == label:
            return column
    raise gr.Error(f"Error: Column '{label}' was not found in the uploaded CSV file.")


def _months_after(year, month, count):
    """Return the ``count`` (year, month) pairs that follow ``year``/``month``.

    >>> _months_after(2023, 11, 3)
    [(2023, 12), (2024, 1), (2024, 2)]
    """
    slots = []
    for step in range(1, count + 1):
        offset = month - 1 + step
        slots.append((year + offset // 12, offset % 12 + 1))
    return slots


def _format_cell(value):
    """Render a table value: blank when missing, no trailing '.0' for whole numbers."""
    if pd.isna(value):
        return ""
    value = float(value)
    return str(int(value)) if value.is_integer() else f"{value:.2f}"


def _build_forecast_figure(history, forecast_index, low, median, high):
    """Plot the monthly history, the median forecast and its 10%-90% band."""
    # A bare Figure is not registered with pyplot, so it is not kept alive by
    # pyplot's figure manager and does not touch pyplot's global "current
    # figure" state shared between requests.
    fig = plt.Figure(figsize=(30, 10))
    ax = fig.subplots()
    ax.plot(range(len(history)), history, color="royalblue", label="Historical Data", linewidth=2)
    ax.plot(forecast_index, median, color="tomato", label="Median Forecast", linewidth=2)
    ax.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
    ax.set_title("Sales Forecasting Visualization", fontsize=16)
    ax.set_xlabel("Months", fontsize=20)
    ax.set_ylabel("Sold Qty", fontsize=20)
    ax.tick_params(axis="both", labelsize=18)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
    # A fixed step (e.g. every 5 units) explodes into thousands of ticks for
    # realistic sales volumes; let matplotlib pick a sensible integer step.
    ax.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax.grid(which="major", linestyle="--", linewidth=1.2, color="gray", alpha=0.7)
    ax.legend(fontsize=18)
    fig.tight_layout()
    return fig


def _build_pivot_figure(monthly_sales, forecast_values):
    """Draw a year x month table of historical totals with the forecast filled in.

    Forecast values are placed in the calendar months that actually follow the
    last observed month, so a series ending mid-year continues in the same row
    rather than being shifted to January of the next year. Months without data
    are left blank.
    """
    pivot = monthly_sales.pivot(index="year", columns="month", values="y")
    # monthly_sales comes from a sorted groupby, so the last row is the latest month.
    last_year, last_month = (int(v) for v in monthly_sales.iloc[-1][["year", "month"]])
    slots = _months_after(last_year, last_month, len(forecast_values))
    years = sorted(set(pivot.index) | {year for year, _ in slots})
    pivot = pivot.reindex(index=years, columns=range(1, 13))
    for (year, month), value in zip(slots, forecast_values):
        pivot.loc[year, month] = value

    fig = plt.Figure(figsize=(18, 6))
    ax = fig.subplots()
    ax.axis("off")
    table = ax.table(
        cellText=[[_format_cell(v) for v in row] for row in pivot.to_numpy()],
        colLabels=_MONTH_ORDER,
        rowLabels=[str(year) for year in pivot.index],
        loc="center",
        cellLoc="center",
    )
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1.2, 1.2)
    # Header cells are in row 0; row-label cells live in column -1 (column 0
    # is the first data column, January).
    for (row, col), cell in table.get_celld().items():
        if row == 0 or col == -1:
            cell.set_text_props(weight="bold")
            cell.set_facecolor("#f2f2f2")
        else:
            cell.set_facecolor("white")
    return fig


def process_data(csv_file, date_column_value, target_column_value):
    """Forecast the next 12 months of the selected target column with Chronos.

    The uploaded CSV is aggregated to monthly totals of the target column. The
    totals are fed to the ``amazon/chronos-t5-base`` pipeline, and two figures
    are produced.

    Args:
        csv_file: Value of the ``gr.File`` component -- a path string or a
            tempfile-like object exposing ``.name`` (depends on the Gradio
            version).
        date_column_value: Dropdown label of the date column.
        target_column_value: Dropdown label of the numeric column to forecast.

    Returns:
        ``(forecast_figure, pivot_figure)``: a line chart of history plus the
        median forecast and 10%-90% band, and a year x month table of totals
        with the forecast filled in.

    Raises:
        gr.Error: For invalid input or any failure. Gradio shows the message to
            the user instead of receiving a value that does not match the two
            plot outputs.
    """
    try:
        if not csv_file:
            raise gr.Error("Error: Upload Csv File")
        if not date_column_value or not target_column_value:
            raise gr.Error("Error: Both date and target columns must be selected")

        df = pd.read_csv(getattr(csv_file, "name", csv_file))
        date_column = _resolve_column(df, date_column_value)
        target_column = _resolve_column(df, target_column_value)

        # Empty cells are read as NaN (a float); don't mistake them for numeric
        # dates -- rows without a parseable date are dropped further below.
        is_numeric = df[date_column].map(
            lambda v: isinstance(v, (int, float, np.number)) and not pd.isna(v)
        )
        if is_numeric.any():
            raise gr.Error(
                f"Error: Found numeric values in column '{date_column}'. "
                "Please provide dates in string format like 'YYYY-MM-DD'."
            )
        if not pd.api.types.is_numeric_dtype(df[target_column]):
            raise gr.Error(f"Error: Target column '{target_column}' must contain numeric values.")

        dates = pd.to_datetime(df[date_column])
        valid = dates.notna()
        sales = pd.DataFrame({
            "year": dates[valid].dt.year.astype(int),
            "month": dates[valid].dt.month.astype(int),
            "y": df.loc[valid, target_column],
        })
        monthly_sales = sales.groupby(["year", "month"], as_index=False)["y"].sum()
        if monthly_sales.empty:
            raise gr.Error(f"Error: No valid dates found in column '{date_column}'.")

        context = torch.tensor(monthly_sales["y"].to_numpy(dtype=float), dtype=torch.float32)
        forecast = _get_chronos_pipeline().predict(context, _PREDICTION_LENGTH)
        # forecast[0] is presumably (num_samples, prediction_length) for our single
        # series; quantiles are taken across the samples (axis 0).
        low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
        forecast_values = [math.ceil(x) for x in median]
        forecast_index = range(len(monthly_sales), len(monthly_sales) + _PREDICTION_LENGTH)

        forecast_fig = _build_forecast_figure(
            monthly_sales["y"].to_numpy(), forecast_index, low, median, high
        )
        pivot_fig = _build_pivot_figure(monthly_sales, forecast_values)
        return forecast_fig, pivot_fig
    except gr.Error:
        raise
    except Exception as e:
        # Keep the server-side log line, then surface the failure in the UI.
        print(f"Error: {str(e)}")
        raise gr.Error(f"Error: {e}") from e
# Create Gradio interface
with gr.Blocks(theme=Seafoam()) as demo:
    gr.Markdown("# Chronos Forecasting - Tops infosolutions Pvt Ltd")
    gr.Markdown("Upload a CSV file and click 'Forecast' to generate sales forecast for next 12 months .")
    # Receives the DataFrame returned by process_csv. The Forecast handler does
    # not take it as input; it re-reads the uploaded file itself.
    df_state = gr.State()
    with gr.Row():
        file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
    with gr.Row():
        date_column = gr.Dropdown(
            choices=[],
            label="Select Date column",
            multiselect=False,
            value=None,
        )
        target_column = gr.Dropdown(
            choices=[],
            label="Select Target column",
            multiselect=False,
            value=None,
        )
    # cache_examples=True precomputes process_csv's outputs for each example
    # (per Gradio's Examples docs), so these files must exist under example_files/.
    gr.Examples(
        examples=[
            ["example_files/test_tops_product_id_1.csv"],
            ["example_files/test_tops_product_id_2.csv"],
            ["example_files/test_tops_product_id_3.csv"],
            ["example_files/test_tops_product_id_4.csv"],
        ],
        inputs=file_input,
        outputs=[df_state, date_column, target_column],
        fn=process_csv,
        cache_examples=True,
    )
    with gr.Row():
        visualize_btn = gr.Button("Forecast", variant="primary")
    with gr.Row():
        plot_output = gr.Plot(label="Chronos Forecasting Visualization")
    with gr.Row():
        pivot_plot_output = gr.Plot(label="Monthly Sales Pivot Table")

    # Populate both column dropdowns from the uploaded file's header.
    file_input.upload(
        process_csv,
        inputs=[file_input],
        outputs=[df_state, date_column, target_column],
    )
    # The dropdowns need no .change handlers: their current values are read
    # directly as inputs when the Forecast button is clicked.
    visualize_btn.click(
        fn=process_data,
        inputs=[file_input, date_column, target_column],
        outputs=[plot_output, pivot_plot_output],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()