|
from __future__ import annotations |
|
|
|
from typing import Iterable |
|
import gradio as gr |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import os |
|
import math |
|
import torch |
|
from chronos import ChronosPipeline |
|
import warnings |
|
|
|
from seafoam import Seafoam |
|
|
|
# NOTE(review): this suppresses *every* warning for the whole process
# (pandas, matplotlib, torch, gradio, ...), which can hide real problems.
# Consider narrowing it to the specific categories/modules that were noisy.
warnings.filterwarnings("ignore")
|
|
|
import numpy as np |
|
import matplotlib.ticker as ticker |
|
|
|
# Make sure the folder referenced by the gr.Examples paths exists (no error if it already does).
os.makedirs(name="example_files", mode=0o777, exist_ok=True)
|
|
|
def process_csv(file):
    """Read an uploaded CSV and populate the date/target column dropdowns.

    Args:
        file: Value of the ``gr.File`` component. This is either an object
            exposing the temp-file path as ``.name`` or a plain path string
            (depending on the Gradio version), or ``None`` when nothing is
            uploaded.

    Returns:
        A 3-tuple ``(df, date_dropdown, target_dropdown)`` matching the
        ``[df_state, date_column, target_column]`` outputs. Dropdown labels
        are the column names split on ``_`` with each word capitalised and
        joined by spaces.

    Raises:
        gr.Error: If the file does not have a ``.csv`` extension.
    """
    if file is None:
        # One value per wired output: df_state, date_column, target_column.
        return None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

    # Older Gradio passes a tempfile wrapper, newer versions pass the path itself.
    path = getattr(file, "name", file)
    if not str(path).lower().endswith(".csv"):
        raise gr.Error("Please upload a CSV file only")

    df = pd.read_csv(path)

    labels = [
        " ".join(word.capitalize() for word in str(column).split("_"))
        for column in df.columns
    ]

    # Separate component updates for the two dropdowns.
    date_dropdown = gr.Dropdown(choices=labels, value=None)
    target_dropdown = gr.Dropdown(choices=labels, value=None)
    return df, date_dropdown, target_dropdown
|
|
|
|
|
# Chronos pipelines keyed by model name; loading is expensive, so do it once.
_PIPELINE_CACHE = {}

_MONTH_ORDER = [
    'January', 'February', 'March', 'April', 'May', 'June',
    'July', 'August', 'September', 'October', 'November', 'December'
]


def _resolve_column(df: pd.DataFrame, display_name: str):
    """Map a dropdown label back to the real column name of *df*.

    Labels are built by splitting a column name on ``_``, capitalising each
    word and joining with spaces. The same transform is applied to every
    column here and the match is returned. The lower-case/underscore
    reconstruction is kept as a fallback.
    """
    for column in df.columns:
        label = " ".join(word.capitalize() for word in str(column).split("_"))
        if label == display_name:
            return column
    fallback = display_name.lower().replace(" ", "_")
    if fallback in df.columns:
        return fallback
    raise gr.Error(f"Column '{display_name}' was not found in the uploaded CSV")


def _get_pipeline(model_name: str = "amazon/chronos-t5-base"):
    """Return a CPU/float32 Chronos pipeline, loading it only on first use."""
    pipeline = _PIPELINE_CACHE.get(model_name)
    if pipeline is None:
        pipeline = ChronosPipeline.from_pretrained(
            model_name,
            device_map="cpu",
            torch_dtype=torch.float32,
        )
        _PIPELINE_CACHE[model_name] = pipeline
    return pipeline


def _monthly_totals(dates: pd.Series, values: pd.Series) -> pd.DataFrame:
    """Sum *values* per calendar month of *dates*.

    Every month from the first to the last observed one is included, and
    months with no rows get a total of 0. This keeps the series truly monthly
    for the model. Returns a chronologically ordered DataFrame with ``year``,
    ``month`` and ``y`` columns.
    """
    frame = pd.DataFrame({"period": dates.dt.to_period("M"), "y": values})
    totals = frame.dropna(subset=["period"]).groupby("period")["y"].sum()
    if totals.empty:
        raise gr.Error("No valid dates were found in the selected date column")
    months = pd.period_range(totals.index.min(), totals.index.max(), freq="M")
    totals = totals.reindex(months, fill_value=0)
    return pd.DataFrame(
        {"year": months.year, "month": months.month, "y": totals.to_numpy()}
    )


def _forecast_quantiles(series: pd.Series, prediction_length: int):
    """Return the 10%/50%/90% quantiles of Chronos sample paths, one value per step."""
    pipeline = _get_pipeline()
    context = torch.tensor(series.to_numpy(dtype="float32"))
    forecast = pipeline.predict(context, prediction_length)
    # forecast[0]: samples x prediction_length for the single input series.
    return np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)


def _format_cell(value) -> str:
    """Render a table value: blank if missing, integers without a trailing '.0'."""
    if pd.isna(value):
        return ""
    number = float(value)
    return str(int(number)) if number.is_integer() else str(number)


def _build_pivot_figure(monthly: pd.DataFrame, forecast_values):
    """Draw a year x month table of historical totals with the forecast months filled in."""
    from matplotlib.figure import Figure

    grid = monthly.pivot(index="year", columns="month", values="y")
    # Always show all 12 months for every year; months outside the data range show 0.
    grid = grid.reindex(columns=range(1, 13)).fillna(0)

    # Place each forecast step at its real calendar position after the last
    # observed month, which need not be December.
    last_year = int(monthly["year"].iloc[-1])
    last_month = int(monthly["month"].iloc[-1])
    for step, value in enumerate(forecast_values, start=1):
        offset = last_month - 1 + step
        grid.loc[last_year + offset // 12, offset % 12 + 1] = value
    grid = grid.rename(columns=dict(enumerate(_MONTH_ORDER, start=1)))

    fig = Figure(figsize=(18, 6))
    ax = fig.subplots()
    ax.axis('off')
    cell_text = [[_format_cell(v) for v in row] for row in grid.to_numpy()]
    table = ax.table(
        cellText=cell_text,
        colLabels=list(grid.columns),
        rowLabels=list(grid.index),
        loc='center',
        cellLoc='center',
    )
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1.2, 1.2)

    for (row, col), cell in table.get_celld().items():
        # Row 0 holds the column labels; row labels live in column -1 (not 0).
        if row == 0 or col == -1:
            cell.set_text_props(weight='bold')
            cell.set_facecolor('#f2f2f2')
        else:
            cell.set_facecolor('white')
    return fig


def _build_forecast_figure(history: pd.Series, low, median, high):
    """Plot the historical monthly series followed by the median forecast and 80% band."""
    from matplotlib.figure import Figure

    fig = Figure(figsize=(30, 10))
    ax = fig.subplots()
    forecast_index = range(len(history), len(history) + len(median))

    ax.plot(range(len(history)), history.to_numpy(), color="royalblue", label="Historical Data", linewidth=2)
    ax.plot(forecast_index, median, color="tomato", label="Median Forecast", linewidth=2)
    ax.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
    ax.set_title("Sales Forecasting Visualization", fontsize=16)
    ax.set_xlabel("Months", fontsize=20)
    ax.set_ylabel("Sold Qty", fontsize=20)
    ax.tick_params(axis="both", labelsize=18)

    ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
    # A fixed step of 5 yields thousands of ticks for realistic sales volumes;
    # let matplotlib choose a readable integer step instead.
    ax.yaxis.set_major_locator(ticker.MaxNLocator(nbins=10, integer=True))
    ax.grid(which='major', linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
    ax.legend(fontsize=18)
    fig.tight_layout()
    return fig


def process_data(csv_file, date_column_value, target_column_value):
    """Forecast the next 12 months of the selected target column with Chronos.

    Args:
        csv_file: ``gr.File`` value (object with ``.name`` or a path string).
        date_column_value: Dropdown label of the date column.
        target_column_value: Dropdown label of the numeric target column.

    Returns:
        ``(forecast_figure, pivot_table_figure)`` for the two ``gr.Plot`` outputs.

    Raises:
        gr.Error: For missing input, unknown columns, numeric/unparseable dates,
            a non-numeric target, or any failure while forecasting. Gradio
            shows the message in the UI.
    """
    if not csv_file:
        raise gr.Error("Error: Upload Csv File")

    if not date_column_value or not target_column_value:
        raise gr.Error("Error: Both date and target columns must be selected")

    prediction_length = 12

    try:
        df = pd.read_csv(getattr(csv_file, "name", csv_file))

        date_column = _resolve_column(df, date_column_value)
        target_column = _resolve_column(df, target_column_value)

        # Ignore missing cells: NaN is a float and would otherwise be flagged.
        non_null_dates = df[date_column].dropna()
        if non_null_dates.apply(lambda x: isinstance(x, (int, float))).any():
            raise gr.Error(f"Error: Found numeric values in column '{date_column}'. Please provide dates in string format like 'YYYY-MM-DD'.")

        dates = pd.to_datetime(df[date_column])
        values = pd.to_numeric(df[target_column])

        monthly = _monthly_totals(dates, values)
        low, median, high = _forecast_quantiles(monthly["y"], prediction_length)

        forecast_values = [math.ceil(x) for x in median]
        pivot_fig = _build_pivot_figure(monthly, forecast_values)
        forecast_fig = _build_forecast_figure(monthly["y"], low, median, high)
    except gr.Error:
        raise
    except Exception as e:
        print(f"Error: {str(e)}")
        raise gr.Error(f"Error: {e}") from e

    return forecast_fig, pivot_fig
|
|
|
|
|
with gr.Blocks(theme=Seafoam()) as demo:
    # Layout: upload -> column pickers -> (optional) examples -> Forecast button -> two plots.

    gr.Markdown("# Chronos Forecasting - Tops infosolutions Pvt Ltd")

    gr.Markdown("Upload a CSV file and click 'Forecast' to generate sales forecast for next 12 months .")

    # Holds the DataFrame returned by process_csv for the latest upload.
    df_state = gr.State()

    with gr.Row():
        file_input = gr.File(label="Upload CSV File", file_types=[".csv"])

    with gr.Row():
        date_column = gr.Dropdown(
            choices=[],
            label="Select Date column",
            multiselect=False,
            value=None
        )

        target_column = gr.Dropdown(
            choices=[],
            label="Select Target column",
            multiselect=False,
            value=None
        )

    example_paths = [
        "example_files/test_tops_product_id_1.csv",
        "example_files/test_tops_product_id_2.csv",
        "example_files/test_tops_product_id_3.csv",
        "example_files/test_tops_product_id_4.csv",
    ]
    # cache_examples=True runs process_csv on each example while the app is
    # built. Skip missing files so a fresh checkout (the directory is only
    # created empty above) does not fail at startup.
    available_examples = [[path] for path in example_paths if os.path.isfile(path)]
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=file_input,
            outputs=[df_state, date_column, target_column],
            fn=process_csv,
            cache_examples=True
        )

    with gr.Row():
        visualize_btn = gr.Button("Forecast", variant="primary")

    with gr.Row():
        plot_output = gr.Plot(label="Chronos Forecasting Visualization")

    with gr.Row():
        pivot_plot_output = gr.Plot(label="Monthly Sales Pivot Table")

    file_input.upload(
        process_csv,
        inputs=[file_input],
        outputs=[df_state, date_column, target_column]
    )

    # The former date_column/target_column .change handlers were no-ops that
    # returned a value to zero outputs; they have been dropped.

    visualize_btn.click(
        fn=process_data,
        inputs=[file_input, date_column, target_column],
        outputs=[plot_output, pivot_plot_output]
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly, not on import.
    launch_app = demo.launch
    launch_app()