File size: 7,463 Bytes
5f8097d e414fc2 5f8097d e414fc2 3d22fd4 5f8097d e414fc2 5f8097d e414fc2 5f8097d e414fc2 5f8097d e414fc2 5f8097d e414fc2 5f8097d e414fc2 5f8097d e414fc2 5f8097d 3d22fd4 5f8097d 3d22fd4 5f8097d e414fc2 3d22fd4 5f8097d e414fc2 5f8097d e414fc2 a013f3d e414fc2 5f8097d 3d22fd4 5f8097d e414fc2 5f8097d e414fc2 3d22fd4 5f8097d e414fc2 5f8097d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
from __future__ import annotations
from typing import Iterable
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import math
import torch
from chronos import ChronosPipeline
import warnings
from seafoam import Seafoam
warnings.filterwarnings("ignore")
import numpy as np
import matplotlib.ticker as ticker
# The gr.Examples entries further down reference CSVs inside this folder;
# make sure the folder itself exists (no error if it already does).
os.makedirs(name="example_files", exist_ok=True)
def process_csv(file):
    """Read an uploaded CSV and offer its columns as dropdown choices.

    Column names are shown in a "display" form: the name is split on
    underscores and each word is capitalized (``sold_qty`` -> ``Sold Qty``).

    Args:
        file: Value of the ``gr.File`` component -- an object exposing the
            uploaded path as ``.name``, or a plain path string. ``None`` when
            no file is selected.

    Returns:
        A 3-tuple ``(df, date_dropdown, target_dropdown)``, one value per
        output (``df_state``, ``date_column``, ``target_column``) this
        handler is wired to.

    Raises:
        gr.Error: If the file name does not end in ``.csv``.
    """
    if file is None:
        # One value per output component; both dropdowns are emptied.
        return None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[])
    # Support both file-like objects (``.name``) and plain path strings.
    path = getattr(file, "name", file)
    if not str(path).lower().endswith(".csv"):
        raise gr.Error("Please upload a CSV file only")
    df = pd.read_csv(path)
    display_names = [
        " ".join(word.capitalize() for word in str(column).split("_"))
        for column in df.columns
    ]
    # Separate component instances for the two dropdowns (same choices).
    return (
        df,
        gr.Dropdown(choices=display_names, value=None),
        gr.Dropdown(choices=display_names, value=None),
    )
# The Chronos model is loaded lazily on the first forecast and then reused,
# instead of being reloaded on every button click. Two concurrent first
# requests may both load it; the later assignment simply wins.
_CHRONOS_PIPELINE = None

_MONTH_ORDER = [
    'January', 'February', 'March', 'April', 'May', 'June',
    'July', 'August', 'September', 'October', 'November', 'December',
]


def _get_pipeline():
    """Return the shared Chronos pipeline, loading it on first use."""
    global _CHRONOS_PIPELINE
    if _CHRONOS_PIPELINE is None:
        _CHRONOS_PIPELINE = ChronosPipeline.from_pretrained(
            "amazon/chronos-t5-base",
            device_map="cpu",
            torch_dtype=torch.float32,
        )
    return _CHRONOS_PIPELINE


def _resolve_column(df, display_value):
    """Map a dropdown display name (e.g. "Sold Qty") back to a column of ``df``.

    The legacy reverse transform (lower-case, spaces -> underscores) is tried
    first. Otherwise the column whose capitalized display form matches is
    used, which also covers names such as ``Date`` or ``SoldQty``.
    """
    legacy = display_value.lower().replace(" ", "_")
    if legacy in df.columns:
        return legacy
    for column in df.columns:
        if " ".join(w.capitalize() for w in str(column).split("_")) == display_value:
            return column
    raise gr.Error(f"Error: Column '{display_value}' was not found in the CSV file")


def _format_cell(value):
    """Render a table cell: blank for missing, integer text when whole, else 2 decimals."""
    if pd.isna(value):
        return ""
    value = float(value)
    return str(int(value)) if value.is_integer() else f"{value:.2f}"


def _build_forecast_figure(history, low, median, high):
    """Line chart of the monthly history followed by the forecast median and 10-90% band."""
    # A standalone Figure (not pyplot state) is not registered globally, so it
    # is freed once Gradio has rendered it and is safe under concurrent requests.
    fig = plt.Figure(figsize=(30, 10))
    ax = fig.subplots()
    history_index = range(len(history))
    forecast_index = range(len(history), len(history) + len(median))
    ax.plot(history_index, history, color="royalblue", label="Historical Data", linewidth=2)
    ax.plot(forecast_index, median, color="tomato", label="Median Forecast", linewidth=2)
    ax.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
    ax.set_title("Sales Forecasting Visualization", fontsize=16)
    ax.set_xlabel("Months", fontsize=20)
    ax.set_ylabel("Sold Qty", fontsize=20)
    ax.tick_params(axis="both", labelsize=18)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
    # A fixed step of 5 produced thousands of y ticks for realistic sales
    # volumes; let matplotlib choose about 10 readable ticks instead.
    ax.yaxis.set_major_locator(ticker.MaxNLocator(nbins=10))
    ax.grid(which='major', linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
    ax.legend(fontsize=18)
    fig.tight_layout()
    return fig


def _build_pivot_figure(sales, forecast_periods, forecast_values):
    """Table of monthly totals per year, with forecast months filled in and tinted.

    ``forecast_periods`` holds ``(year, month_index)`` pairs (month_index
    0 = January), aligned with ``forecast_values``.
    """
    # observed=False keeps all 12 month categories for every year (empty
    # months sum to 0), independent of the pandas version's default.
    year_month_sum = (
        sales.groupby(['year', 'month_name'], observed=False)['sold_qty']
        .sum()
        .reset_index()
    )
    pivot_table = year_month_sum.pivot(index='year', columns='month_name', values='sold_qty')
    pivot_table.columns = pivot_table.columns.astype(str)
    years = sorted(set(pivot_table.index).union(year for year, _ in forecast_periods))
    pivot_table = pivot_table.reindex(index=years, columns=_MONTH_ORDER).astype(float)

    # Place each forecast at its real calendar month. Table row 0 is the header,
    # so year k of the index sits on table row k + 1.
    forecast_cells = set()
    for (year, month_idx), value in zip(forecast_periods, forecast_values):
        pivot_table.at[year, _MONTH_ORDER[month_idx]] = value
        forecast_cells.add((years.index(year) + 1, month_idx))

    fig = plt.Figure(figsize=(18, 6))
    ax = fig.subplots()
    ax.axis('off')
    table = ax.table(
        cellText=[[_format_cell(v) for v in row] for row in pivot_table.to_numpy()],
        colLabels=_MONTH_ORDER,
        rowLabels=[str(year) for year in years],
        loc='center',
        cellLoc='center',
    )
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1.2, 1.2)
    for (row, col), cell in table.get_celld().items():
        # matplotlib puts column labels on row 0 and row labels on column -1.
        if row == 0 or col == -1:
            cell.set_text_props(weight='bold')
            cell.set_facecolor('#f2f2f2')
        elif (row, col) in forecast_cells:
            cell.set_facecolor('#fde0dc')
        else:
            cell.set_facecolor('white')
    return fig


def process_data(csv_file, date_column_value, target_column_value, prediction_length=12):
    """Forecast monthly totals of a CSV column with Chronos and visualize them.

    Rows are summed per calendar month of the date column. The monthly series
    is fed to ``amazon/chronos-t5-base``, and ``prediction_length`` further
    months are forecast.

    Args:
        csv_file: Value of the ``gr.File`` component (object with ``.name``,
            or a path string); falsy when nothing was uploaded.
        date_column_value: Display name of the date column (from ``process_csv``).
        target_column_value: Display name of the numeric column to forecast.
        prediction_length: Number of months to forecast (default 12).

    Returns:
        ``(forecast_figure, pivot_figure)``: a matplotlib line chart of history
        plus forecast, and a per-year/per-month table including the forecast
        (rounded up to integers).

    Raises:
        gr.Error: On missing input, invalid columns, or any processing failure,
            so the message is shown in the UI.
    """
    try:
        if not csv_file:
            raise gr.Error("Error: Upload Csv File")
        if not date_column_value or not target_column_value:
            raise gr.Error("Error: Both date and target columns must be selected")
        df = pd.read_csv(getattr(csv_file, "name", csv_file))
        date_column = _resolve_column(df, date_column_value)
        target_column = _resolve_column(df, target_column_value)

        raw_dates = df[date_column]
        # Reject bare numbers in the date column. Blank cells are NaN (a float),
        # but they are skipped below rather than reported as numeric.
        numeric_mask = raw_dates.apply(lambda x: isinstance(x, (int, float)) and not pd.isna(x))
        if numeric_mask.any():
            raise gr.Error(f"Error: Found numeric values in column '{date_column}'. Please provide dates in string format like 'YYYY-MM-DD'.")
        if not pd.api.types.is_numeric_dtype(df[target_column]):
            raise gr.Error(f"Error: Target column '{target_column}' must contain numeric values.")

        dates = pd.to_datetime(raw_dates)
        valid = dates.notna()
        dates = dates[valid]
        sales = pd.DataFrame({
            'year': dates.dt.year,
            'month': dates.dt.month,
            'month_name': dates.dt.month_name().astype(pd.CategoricalDtype(_MONTH_ORDER, ordered=True)),
            'sold_qty': df.loc[valid, target_column],
        })
        # groupby sorts by (year, month), so the last row is the latest month.
        monthly_sales = (
            sales.groupby(['year', 'month'])['sold_qty']
            .sum()
            .reset_index()
            .rename(columns={'sold_qty': 'y'})
        )
        if monthly_sales.empty:
            raise gr.Error("Error: No rows with a valid date were found")

        context = torch.tensor(monthly_sales['y'].to_numpy(dtype=np.float32))
        forecast = _get_pipeline().predict(context, prediction_length)
        # Quantiles over the sampled trajectories along axis 0 of forecast[0]
        # (the layout the original code used).
        low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

        # Calendar months of the forecast, starting right after the last observed
        # month. This assumes the history has no missing months -- TODO confirm.
        start = int(monthly_sales['year'].iloc[-1]) * 12 + int(monthly_sales['month'].iloc[-1])
        forecast_periods = [divmod(start + step, 12) for step in range(prediction_length)]
        forecast_values = [math.ceil(x) for x in median]

        return (
            _build_forecast_figure(monthly_sales['y'].to_numpy(), low, median, high),
            _build_pivot_figure(sales, forecast_periods, forecast_values),
        )
    except gr.Error:
        raise
    except Exception as e:
        print(f"Error: {str(e)}")
        # Surface the failure in the UI instead of silently returning nothing.
        raise gr.Error(f"Error: {e}") from e
# Create Gradio interface
with gr.Blocks(theme=Seafoam()) as demo:
    gr.Markdown("# Chronos Forecasting - Tops infosolutions Pvt Ltd")
    gr.Markdown("Upload a CSV file and click 'Forecast' to generate sales forecast for next 12 months .")
    # Receives the DataFrame returned by process_csv; process_data re-reads
    # the uploaded file itself.
    df_state = gr.State()
    with gr.Row():
        file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
    with gr.Row():
        date_column = gr.Dropdown(
            choices=[],
            label="Select Date column",
            multiselect=False,
            value=None,
        )
        target_column = gr.Dropdown(
            choices=[],
            label="Select Target column",
            multiselect=False,
            value=None,
        )
    # Sample CSVs. cache_examples=True asks Gradio to precompute process_csv
    # outputs for these paths, so the files presumably must exist at launch.
    gr.Examples(
        examples=[
            ["example_files/test_tops_product_id_1.csv"],
            ["example_files/test_tops_product_id_2.csv"],
            ["example_files/test_tops_product_id_3.csv"],
            ["example_files/test_tops_product_id_4.csv"],
        ],
        inputs=file_input,
        outputs=[df_state, date_column, target_column],
        fn=process_csv,
        cache_examples=True,
    )
    with gr.Row():
        visualize_btn = gr.Button("Forecast", variant="primary")
    with gr.Row():
        plot_output = gr.Plot(label="Chronos Forecasting Visualization")
    with gr.Row():
        pivot_plot_output = gr.Plot(label="Monthly Sales Pivot Table")

    # Uploading a file fills both column dropdowns.
    file_input.upload(
        process_csv,
        inputs=[file_input],
        outputs=[df_state, date_column, target_column],
    )
    # No change handlers on the dropdowns are needed: the former ones had no
    # outputs. The selected values are passed straight to process_data below.
    visualize_btn.click(
        fn=process_data,
        inputs=[file_input, date_column, target_column],
        outputs=[plot_output, pivot_plot_output],
    )
# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()