"""Gradio app: upload a sales CSV and forecast the next 12 months with Chronos."""

from __future__ import annotations

from typing import Iterable
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import math
import torch
from chronos import ChronosPipeline
import warnings

from seafoam import Seafoam

# Install a global "ignore" filter so no warnings are printed by this process.
warnings.filterwarnings("ignore")

# NOTE: numpy is already imported above; this second import is redundant.
import numpy as np
import matplotlib.ticker as ticker

# Create the directory that the example CSV paths below point into, if it
# does not exist yet (no error when it is already present).
os.makedirs("example_files", exist_ok=True)

def _display_name(column):
    """Turn a raw column name such as ``sold_qty`` into the label ``Sold Qty``."""
    return ' '.join(word.capitalize() for word in str(column).split('_'))


def process_csv(file):
    """Load an uploaded CSV and build the two column-selection dropdowns.

    Args:
        file: The value delivered by the file input. Either an object with a
            ``name`` attribute holding the path, or the path itself; ``None``
            when nothing is uploaded.

    Returns:
        A 3-tuple ``(dataframe, date_dropdown, target_dropdown)``. Exactly
        three values are returned on every path because the function is wired
        to three output components. When ``file`` is ``None`` the dataframe
        is ``None`` and both dropdowns are empty.

    Raises:
        gr.Error: If the uploaded file does not have a ``.csv`` extension.
    """
    if file is None:
        # One value per wired output: state + date dropdown + target dropdown.
        return None, gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

    # Accept both file wrappers (with .name) and plain path strings.
    path = getattr(file, "name", file)
    # Compare case-insensitively so "DATA.CSV" is accepted as well.
    if not str(path).lower().endswith('.csv'):
        raise gr.Error("Please upload a CSV file only")

    df = pd.read_csv(path)
    labels = [_display_name(column) for column in df.columns]
    return (
        df,
        gr.Dropdown(choices=labels, value=None),
        gr.Dropdown(choices=labels, value=None),
    )


# Populated on the first forecast and reused afterwards, so the model is not
# reloaded on every button click.
_CHRONOS_PIPELINE = None

_MONTH_ORDER = [
    'January', 'February', 'March', 'April', 'May', 'June',
    'July', 'August', 'September', 'October', 'November', 'December',
]


def _get_pipeline():
    """Return the Chronos pipeline, loading it on first use only."""
    global _CHRONOS_PIPELINE
    if _CHRONOS_PIPELINE is None:
        _CHRONOS_PIPELINE = ChronosPipeline.from_pretrained(
            "amazon/chronos-t5-base",
            device_map="cpu",
            torch_dtype=torch.float32,
        )
    return _CHRONOS_PIPELINE


def _resolve_column(df, display_name):
    """Map a dropdown label back to the real column name in *df*, or ``None``."""
    # The labels are produced by title-casing the underscore-separated parts
    # of each column name. That transform is not reversible by lower-casing,
    # so re-apply it to every real column and compare.
    for column in df.columns:
        label = ' '.join(word.capitalize() for word in str(column).split('_'))
        if label == display_name:
            return column
    # Fall back to the previous lookup rule, then to an exact match.
    fallback = display_name.lower().replace(" ", "_")
    if fallback in df.columns:
        return fallback
    return display_name if display_name in df.columns else None


def _build_pivot_table(df, median):
    """Return a year x month table of summed sales plus one forecast row."""
    df['month_name'] = pd.Categorical(
        df['date'].dt.month_name(), categories=_MONTH_ORDER, ordered=True
    )
    # observed=False keeps all 12 month columns even when some months never
    # occur in the data; the 12-value forecast row below depends on that.
    year_month_sum = (
        df.groupby(['year', 'month_name'], observed=False)['sold_qty']
        .sum()
        .reset_index()
    )
    pivot_table = year_month_sum.pivot(index='year', columns='month_name', values='sold_qty')

    # NOTE(review): the forecast is appended as the January..December row of
    # the year after the last observed one. That matches the forecast horizon
    # only when the history ends in December -- confirm the intended layout.
    next_year = pivot_table.index[-1] + 1
    pivot_table.loc[next_year] = [math.ceil(x) for x in median]
    return pivot_table


def _build_pivot_figure(pivot_table):
    """Render *pivot_table* as a styled matplotlib table figure."""
    fig, ax = plt.subplots(figsize=(18, 6))
    ax.axis('off')  # only the table is shown, no axes
    table = ax.table(
        cellText=pivot_table.values,
        colLabels=pivot_table.columns,
        rowLabels=pivot_table.index,
        loc='center',
        cellLoc='center',
    )
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1.2, 1.2)  # enlarge cells for readability

    for (row, col), cell in table.get_celld().items():
        # Row 0 is the column-label row; matplotlib places row labels in
        # column -1 (column 0 is the first data column).
        if row == 0 or col == -1:
            cell.set_text_props(weight='bold')
            cell.set_facecolor('#f2f2f2')
        else:
            cell.set_facecolor('white')
    return fig


def _build_forecast_figure(history, forecast_index, low, median, high):
    """Plot the historical series with the median forecast and 10-90% band."""
    fig, ax = plt.subplots(figsize=(30, 10))
    ax.plot(history, color="royalblue", label="Historical Data", linewidth=2)
    ax.plot(forecast_index, median, color="tomato", label="Median Forecast", linewidth=2)
    ax.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
    ax.set_title("Sales Forecasting Visualization", fontsize=16)
    ax.set_xlabel("Months", fontsize=20)
    ax.set_ylabel("Sold Qty", fontsize=20)
    ax.tick_params(axis='both', labelsize=18)

    ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(5))
    ax.grid(which='major', linestyle='--', linewidth=1.2, color='gray', alpha=0.7)

    ax.legend(fontsize=18)
    fig.tight_layout()
    return fig


def process_data(csv_file, date_column_value, target_column_value):
    """Forecast the next 12 months of sales from an uploaded CSV.

    The target column is summed per (year, month), the resulting series is
    passed to the Chronos pipeline, and two figures are produced.

    Args:
        csv_file: The uploaded file: an object with a ``name`` attribute
            holding the path, or the path itself.
        date_column_value: Dropdown label of the date column.
        target_column_value: Dropdown label of the column to forecast.

    Returns:
        ``(forecast_figure, pivot_table_figure)`` -- one matplotlib figure
        for each of the two wired plot outputs.

    Raises:
        gr.Error: On missing input, unknown columns, numeric values in the
            date column, empty data, or any failure during processing.
            Errors are raised rather than returned because a single string
            or ``None`` cannot populate the two plot outputs.
    """
    if not csv_file:
        raise gr.Error("Error: Upload Csv File")

    if not date_column_value or not target_column_value:
        raise gr.Error("Error: Both date and target columns must be selected")

    try:
        # Read the CSV file
        df = pd.read_csv(getattr(csv_file, "name", csv_file))

        date_column = _resolve_column(df, date_column_value)
        target_column = _resolve_column(df, target_column_value)
        if date_column is None or target_column is None:
            raise gr.Error("Error: Selected columns were not found in the uploaded CSV")

        numeric_mask = df[date_column].apply(lambda x: isinstance(x, (int, float)))
        if numeric_mask.any():
            raise gr.Error(
                f"Error: Found numeric values in column '{date_column}'. Please provide dates in string format like 'YYYY-MM-DD'."
            )

        # Capture the target before the helper columns are written, in case
        # the target itself is named 'date', 'month' or 'year'.
        target_values = df[target_column]
        df['date'] = pd.to_datetime(df[date_column])
        df['month'] = df['date'].dt.month
        df['year'] = df['date'].dt.year
        df['sold_qty'] = target_values

        monthly_sales = (
            df.groupby(['year', 'month'])['sold_qty']
            .sum()
            .reset_index()
            .rename(columns={'sold_qty': 'y'})
        )
        if monthly_sales.empty:
            raise gr.Error("Error: The uploaded CSV contains no rows to forecast")

        context = torch.tensor(monthly_sales["y"])
        prediction_length = 12
        forecast = _get_pipeline().predict(context, prediction_length)

        # Forecast positions continue directly after the historical ones.
        forecast_index = range(len(monthly_sales), len(monthly_sales) + prediction_length)
        low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

        pivot_table = _build_pivot_table(df, median)
        table_fig = _build_pivot_figure(pivot_table)
        forecast_fig = _build_forecast_figure(
            monthly_sales["y"], forecast_index, low, median, high
        )
    except gr.Error:
        # Already a user-facing error; pass it through unchanged.
        raise
    except Exception as e:
        print(f"Error: {str(e)}")
        raise gr.Error(f"Error: {e}") from e

    # Unregister the figures from pyplot so repeated requests do not pile up
    # open figures; the Figure objects themselves are still returned.
    plt.close(forecast_fig)
    plt.close(table_fig)
    return forecast_fig, table_fig

# Build the Gradio interface. Components appear in the order they are created
# below: title, file upload, column dropdowns, examples, button, two plots.
with gr.Blocks(theme=Seafoam()) as demo:
    gr.Markdown("# Chronos Forecasting - Tops infosolutions Pvt Ltd")
    gr.Markdown("Upload a CSV file and click 'Forecast' to generate sales forecast for next 12 months .")

    # Receives the first value returned by process_csv. No handler in this
    # file reads it back; process_data is given the file input instead.
    df_state = gr.State()

    with gr.Row():
        file_input = gr.File(label="Upload CSV File", file_types=[".csv"])


    # Both dropdowns start empty and are filled by process_csv.
    with gr.Row():
        date_column = gr.Dropdown(
            choices=[],
            label="Select Date column",
            multiselect=False,
            value=None
        )

        target_column = gr.Dropdown(
            choices=[],
            label="Select Target column",
            multiselect=False,
            value=None
        )

    # Example CSVs that feed the same process_csv handler as a manual upload.
    # NOTE(review): cache_examples=True presumably runs process_csv on these
    # paths ahead of time -- confirm the files exist under example_files/.
    gr.Examples(
        examples=[
            ["example_files/test_tops_product_id_1.csv"],
            ["example_files/test_tops_product_id_2.csv"],
            ["example_files/test_tops_product_id_3.csv"],
            ["example_files/test_tops_product_id_4.csv"]
        ],
        inputs=file_input,
        outputs=[df_state, date_column, target_column],
        fn=process_csv,
        cache_examples=True
    )

    with gr.Row():
        visualize_btn = gr.Button("Forecast", variant="primary")

    with gr.Row():
        plot_output = gr.Plot(label="Chronos Forecasting Visualization")

    with gr.Row():
        pivot_plot_output = gr.Plot(label="Monthly Sales Pivot Table")

    # On upload: parse the CSV and populate the state plus both dropdowns.
    file_input.upload(
        process_csv,
        inputs=[file_input],
        outputs=[df_state, date_column, target_column]
    )

    # Column selection handlers. Both are wired with an empty outputs list,
    # so the value returned by the lambda is not sent to any component.
    date_column.change(
        lambda x: x if x else "",
        inputs=[date_column],
        outputs=[]
    )

    target_column.change(
        lambda x: x if x else "",
        inputs=[target_column],
        outputs=[]
    )

    # On click: run the forecast; process_data must supply one value for each
    # of the two plot outputs.
    visualize_btn.click(
        fn=process_data,
        inputs=[file_input, date_column, target_column],
        outputs=[plot_output, pivot_plot_output]
    )





# Launch the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    demo.launch()