from __future__ import annotations from typing import Iterable import gradio as gr import pandas as pd import matplotlib.pyplot as plt import numpy as np import os import math import torch from chronos import ChronosPipeline import warnings from seafoam import Seafoam warnings.filterwarnings("ignore") import numpy as np import matplotlib.ticker as ticker os.makedirs("example_files", exist_ok=True) def process_csv(file): if file is None: return None, gr.Dropdown(choices=[]) if not file.name.endswith('.csv'): raise gr.Error("Please upload a CSV file only") df = pd.read_csv(file.name) columns = df.columns.tolist() transformed_columns = list(map(lambda x: ' '.join([word.capitalize() for word in x.split('_')]), columns)) data_columns = gr.Dropdown(choices=transformed_columns, value=None) return df, data_columns, data_columns def process_data(csv_file, date_column_value, target_column_value): try: if not csv_file: return "Error: Upload Csv File" if not date_column_value or not target_column_value: return "Error: Both date and target columns must be selected" date_column = date_column_value.lower().replace(" ", "_") target_column = target_column_value.lower().replace(" ", "_") # Read the CSV file df = pd.read_csv(csv_file.name) numeric_mask = df[date_column].apply(lambda x: isinstance(x, (int, float))) if numeric_mask.any(): return "Error: Found numeric values in column '{date_column}'. Please provide dates in string format like 'YYYY-MM-DD'." df['date'] = pd.to_datetime(df[date_column]) df['month'] = df['date'].dt.month df['year'] = df['date'].dt.year df['sold_qty'] = df[target_column] monthly_sales = df.groupby(['year', 'month'])['sold_qty'].sum().reset_index() monthly_sales = monthly_sales.rename(columns={'year': 'year', 'month': 'month', 'sold_qty': 'y'}) pipeline = ChronosPipeline.from_pretrained( "amazon/chronos-t5-base", device_map="cpu", torch_dtype=torch.float32, ) context = torch.tensor(monthly_sales["y"]) prediction_length = 12 forecast = pipeline.predict(context, prediction_length) # Prepare forecast data forecast_index = range(len(monthly_sales), len(monthly_sales) + prediction_length) low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0) df['month_name'] = df['date'].dt.month_name() month_order = [ 'January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December' ] df['month_name'] = pd.Categorical(df['month_name'], categories=month_order, ordered=True) expanded_df = df.copy() year_month_sum = expanded_df.groupby(['year', 'month_name'])['sold_qty'].sum().reset_index() # Create a pivot table: sum of units sold per year and month pivot_table = year_month_sum.pivot(index='year', columns='month_name', values='sold_qty') new_data_list = [math.ceil(x) for x in median] # Add the new data list for the next year (incrementing the year by 1) next_year = pivot_table.index[-1] + 1 # Increment the year by 1 pivot_table.loc[next_year] = new_data_list # Add the new row for the next year # Visualization: Pivot Table Data (Second Plot) fig3, ax3 = plt.subplots(figsize=(18, 6)) # Create a table inside the plot ax3.axis('off') # Turn off the axis table = ax3.table(cellText=pivot_table.values, colLabels=pivot_table.columns, rowLabels=pivot_table.index, loc='center', cellLoc='center') # Style the table table.auto_set_font_size(False) table.set_fontsize(12) table.scale(1.2, 1.2) # Scale the table for better visibility # Adjust table colors (optional) for (i, j), cell in table.get_celld().items(): if i == 0: cell.set_text_props(weight='bold') cell.set_facecolor('#f2f2f2') elif j == 0: cell.set_text_props(weight='bold') cell.set_facecolor('#f2f2f2') else: cell.set_facecolor('white') # Visualization plt.figure(figsize=(30, 10)) plt.plot(monthly_sales["y"], color="royalblue", label="Historical Data", linewidth=2) plt.plot(forecast_index, median, color="tomato", label="Median Forecast", linewidth=2) plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval") plt.title("Sales Forecasting Visualization", fontsize=16) plt.xlabel("Months", fontsize=20) plt.ylabel("Sold Qty", fontsize=20) plt.xticks(fontsize=18) plt.yticks(fontsize=18) ax = plt.gca() ax.xaxis.set_major_locator(ticker.MultipleLocator(3)) ax.yaxis.set_major_locator(ticker.MultipleLocator(5)) ax.grid(which='major', linestyle='--', linewidth=1.2, color='gray', alpha=0.7) plt.legend(fontsize=18) plt.grid(linestyle='--', linewidth=1.2, color='gray', alpha=0.7) plt.tight_layout() return plt.gcf(), fig3 except Exception as e: print(f"Error: {str(e)}") return None # Create Gradio interface with gr.Blocks(theme=Seafoam()) as demo: gr.Markdown("# Chronos Forecasting - Tops infosolutions Pvt Ltd") gr.Markdown("Upload a CSV file and click 'Forecast' to generate sales forecast for next 12 months .") df_state = gr.State() with gr.Row(): file_input = gr.File(label="Upload CSV File", file_types=[".csv"]) with gr.Row(): date_column = gr.Dropdown( choices=[], label="Select Date column", multiselect=False, value=None ) target_column = gr.Dropdown( choices=[], label="Select Target column", multiselect=False, value=None ) gr.Examples( examples=[ ["example_files/13dec_product_id96airaco.csv"], ["example_files/13dec_product_id346airaco.csv"], ["example_files/13dec_product_id567airaco.csv"], ["example_files/13dec_product_id856airaco.csv"], ["example_files/airaco_product_id215.csv"] ], inputs=file_input, outputs=[df_state, date_column, target_column], fn=process_csv, cache_examples=True ) with gr.Row(): visualize_btn = gr.Button("Forecast", variant="primary") with gr.Row(): plot_output = gr.Plot(label="Chronos Forecasting Visualization") with gr.Row(): pivot_plot_output = gr.Plot(label="Monthly Sales Pivot Table") file_input.upload( process_csv, inputs=[file_input], outputs=[df_state, date_column, target_column] ) # Column selection handler date_column.change( lambda x: x if x else "", inputs=[date_column], outputs=[] ) target_column.change( lambda x: x if x else "", inputs=[target_column], outputs=[] ) visualize_btn.click( fn=process_data, inputs=[file_input, date_column, target_column], outputs=[plot_output, pivot_plot_output] ) # Launch the app if __name__ == "__main__": demo.launch()