"""Gradio app: upload historical sales as CSV, visualize them and forecast with Chronos."""

import pandas as pd
import gradio as gr
from pathlib import Path
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import torch
from chronos import ChronosPipeline
from datetime import datetime
from functools import lru_cache
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker


def _pretty_column_name(column):
    """Return the dropdown label shown for a raw CSV header (``sold_qty`` -> ``Sold Qty``)."""
    return " ".join(word.capitalize() for word in str(column).split("_"))


def _resolve_column(df, label):
    """Map a dropdown label back to the original column name of ``df``.

    Returns ``None`` when ``df`` is missing, ``label`` is empty, or no column
    produces that label. If two columns share a label, the last one wins.
    """
    if df is None or not label:
        return None
    return {_pretty_column_name(column): column for column in df.columns}.get(label)


def _as_naive_datetimes(values):
    """Parse ``values`` as datetimes; timezone-aware data is converted to naive UTC."""
    dates = pd.to_datetime(values)
    if dates.dt.tz is not None:
        dates = dates.dt.tz_convert(None)
    return dates


def _to_timestamp(value):
    """Convert a ``gr.DateTime`` value to a naive pandas Timestamp, or ``None`` if unset.

    Numeric values are epoch seconds (the component's default ``type="timestamp"``)
    and are interpreted as UTC, matching the former ``datetime.utcfromtimestamp``.
    """
    # NOTE(review): gr.DateTime may produce the epoch from the server's local time;
    # interpreting it as UTC (as the original code did) assumes a UTC server — confirm.
    if value is None or (isinstance(value, str) and not value):
        return None
    if isinstance(value, (int, float, np.integer, np.floating)):
        timestamp = pd.to_datetime(value, unit="s")
    else:
        timestamp = pd.Timestamp(value)
    if pd.isna(timestamp):
        return None
    if timestamp.tzinfo is not None:
        timestamp = timestamp.tz_convert(None)
    return timestamp


def _select_training_rows(df, start, end, date_label, product_label=None, product=None):
    """Return ``(rows, date_column)`` for the rows of ``df`` inside the training window.

    ``start``/``end`` are inclusive bounds; a missing bound leaves that side open.
    When both a product column and a product are selected, only that product's rows
    are kept. ``df`` itself is never modified; in the returned copy the date column
    is parsed to naive datetimes.

    Raises:
        gr.Error: if no data is loaded, the date column cannot be resolved, or the
            start date is after the end date.
    """
    if df is None:
        raise gr.Error("Please upload a CSV file first")
    date_col = _resolve_column(df, date_label)
    if date_col is None:
        raise gr.Error("Please select a Date column")

    start_ts, end_ts = _to_timestamp(start), _to_timestamp(end)
    if start_ts is not None and end_ts is not None and start_ts > end_ts:
        raise gr.Error("The training data Start date must not be after the End date")

    dates = _as_naive_datetimes(df[date_col])
    mask = pd.Series(True, index=df.index)
    if start_ts is not None:
        mask &= dates >= start_ts
    if end_ts is not None:
        mask &= dates <= end_ts

    product_col = _resolve_column(df, product_label)
    if product_col is not None and product not in (None, ""):
        # Compare as text: dropdown choices are strings while CSV values may be numeric.
        mask &= df[product_col].astype(str) == str(product)

    rows = df.loc[mask].copy()
    rows[date_col] = dates.loc[mask]
    return rows, date_col


def filter_data(start, end, df_state, select_product_column, date_column, target_column, select_product=None):
    """Sum the target per date inside the training window and plot it.

    Args:
        start, end: ``gr.DateTime`` values bounding the window (inclusive); epoch
            seconds are interpreted as UTC. ``None`` leaves that side unbounded.
        df_state: The uploaded DataFrame. It is not modified.
        select_product_column: Dropdown label of the product column, or empty.
        date_column: Dropdown label of the date column (required).
        target_column: Dropdown label of the numeric target column (required).
        select_product: Product value to keep; ignored unless a product column is set.

    Returns:
        ``[filtered_df, fig]``: the per-date sums (original column names, sorted by
        date) and a Plotly line chart of them.

    Raises:
        gr.Error: for missing selections or data, or an inverted date window.
    """
    if not date_column:
        raise gr.Error("Please select a Date column")
    if not target_column:
        raise gr.Error("Please select a target column")

    rows, original_date_column = _select_training_rows(
        df_state, start, end, date_column, select_product_column, select_product
    )
    original_target_column = _resolve_column(df_state, target_column)
    if original_target_column is None:
        raise gr.Error("Please select a target column")

    filtered_df = (
        rows.groupby(original_date_column)[original_target_column]
        .sum()
        .reset_index()
        .sort_values(by=original_date_column)
    )
    fig = px.line(
        filtered_df,
        x=original_date_column,
        y=original_target_column,
        title="Historical Sales Data",
        labels={original_date_column: date_column, original_target_column: target_column},
    )
    return [filtered_df, fig]


def _visualize_training_data(start, end, df_state, select_product_column, date_column, target_column, select_product):
    """Return only the historical chart so the uploaded dataset state is never overwritten."""
    _, fig = filter_data(start, end, df_state, select_product_column, date_column, target_column, select_product)
    return fig


def upload_file(filepath):
    """Load an uploaded CSV and offer its columns in the selection dropdowns.

    Args:
        filepath: The uploaded file as given by ``gr.File`` (a path string, or an
            object exposing the path as ``.name``); ``None`` when the file is cleared.

    Returns:
        ``[df, product_columns, date_columns, target_columns]``. ``df`` has every text
        column that parses as dates converted to datetimes. The other three items are
        ``gr.Dropdown`` updates listing, respectively: all columns plus an empty
        "no product" choice, the datetime columns, and the numeric columns.

    Raises:
        gr.Error: if the file cannot be parsed as CSV.
    """
    if filepath is None:
        return [
            None,
            gr.Dropdown(choices=[], value=None),
            gr.Dropdown(choices=[], value=None),
            gr.Dropdown(choices=[], value=None),
        ]

    csv_path = Path(getattr(filepath, "name", filepath))
    try:
        df = pd.read_csv(csv_path)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        raise gr.Error(f"Could not read the uploaded CSV file: {err}") from err

    # Detected before any text column is converted to datetimes.
    numeric_columns = [col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])]

    datetime_columns = []
    for col in df.columns:
        if pd.api.types.is_object_dtype(df[col]) or isinstance(df[col].dtype, pd.StringDtype):
            try:
                df[col] = pd.to_datetime(df[col])
            except (ValueError, TypeError, OverflowError):
                pass  # Not a date column; keep the values unchanged.
        if pd.api.types.is_datetime64_any_dtype(df[col]):
            datetime_columns.append(col)

    column_labels = [""] + [_pretty_column_name(col) for col in df.columns]
    date_labels = [_pretty_column_name(col) for col in datetime_columns]
    target_labels = [_pretty_column_name(col) for col in numeric_columns]
    return [
        df,
        gr.Dropdown(choices=column_labels, value=None),
        gr.Dropdown(choices=date_labels, value=None),
        gr.Dropdown(choices=target_labels, value=None),
    ]


def download_file():
    """Show the upload button and hide the download button (not wired in this UI)."""
    return [gr.UploadButton(visible=True), gr.DownloadButton(visible=False)]


def set_products(selected_column, df_state):
    """Return the distinct non-null values of the selected product column (``[]`` if none)."""
    original_column = _resolve_column(df_state, selected_column)
    if original_column is None:
        return []
    return df_state[original_column].dropna().unique().tolist()


def _refresh_product_choices(selected_column, df_state):
    """Build a dropdown update listing the products of the chosen column as strings."""
    products = set_products(selected_column, df_state)
    return gr.Dropdown(choices=[str(product) for product in products], value=None)


def set_dates(selected_column, df_state):
    """Return ``(min_date, max_date)`` of the selected date column, or ``(None, None)``."""
    original_column = _resolve_column(df_state, selected_column)
    if original_column is None:
        return None, None
    dates = _as_naive_datetimes(df_state[original_column])
    min_date, max_date = dates.min(), dates.max()
    if pd.isna(min_date) or pd.isna(max_date):
        return None, None
    return min_date, max_date


@lru_cache(maxsize=2)
def _load_chronos_pipeline(device):
    """Load the Chronos pipeline once per device instead of on every forecast click."""
    return ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-base",
        device_map=device,
        torch_dtype=torch.float32,
    )


def _monthly_totals(dates, values):
    """Sum ``values`` per calendar month of ``dates``; rows without a date are dropped.

    Returns a Series indexed by monthly ``Period`` in ascending order. Months with
    no rows are absent rather than zero.
    """
    frame = pd.DataFrame({"date": _as_naive_datetimes(dates), "y": values})
    frame = frame.dropna(subset=["date"])
    return frame.groupby(frame["date"].dt.to_period("M"))["y"].sum()


def _build_forecast_figure(history_x, history_y, forecast_x, low, median, high, target_label):
    """Plot the monthly history with the median forecast and its 10th-90th percentile band."""
    grid = dict(gridcolor="rgba(128, 128, 128, 0.7)", gridwidth=1.2, griddash="dash")
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=history_x, y=history_y, name="Historical Data", mode="lines",
            line=dict(color="royalblue", width=2),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=forecast_x, y=median, name="Median Forecast", mode="lines",
            line=dict(color="tomato", width=2),
        )
    )
    # The upper bound must be added right before the lower bound: fill="tonexty" on the
    # lower trace shades the area between the two (the 80% prediction interval).
    fig.add_trace(
        go.Scatter(
            x=forecast_x, y=high, name="Upper Bound (90th percentile)", mode="lines",
            line=dict(width=2, color="rgba(50, 205, 50, 1)"), showlegend=False,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=forecast_x, y=low, name="Lower Bound (10th percentile)", mode="lines",
            line=dict(width=1, color="rgba(255, 255, 0, 1)"), showlegend=False,
            fill="tonexty", fillcolor="rgba(255, 99, 71, 0.3)",
        )
    )
    fig.update_layout(
        title=dict(text="Sales Forecasting Visualization", font=dict(size=20)),
        showlegend=True,
        legend=dict(font=dict(size=16)),
        width=1600,
        height=400,
        plot_bgcolor="white",
        xaxis=dict(
            title=dict(text="Months", font=dict(size=16)),
            tickfont=dict(size=14),
            dtick="M3",  # One tick every three months on the date axis.
            tickformat="%b %Y",
            rangeslider=dict(visible=True),
            rangeselector=dict(
                buttons=[
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(count=12, label="1y", step="month", stepmode="backward"),
                    dict(count=24, label="2y", step="month", stepmode="backward"),
                    dict(step="all", label="All"),
                ]
            ),
            **grid,
        ),
        yaxis=dict(
            title=dict(text=target_label, font=dict(size=16)),
            tickfont=dict(size=14),
            **grid,
        ),
    )
    return fig


def forecast_chronos_data(df_state, date_column, target_column, select_period, forecasting_type="monthly"):
    """Forecast monthly totals of the target column with Chronos and plot them.

    The target is summed per calendar month, and that monthly series is passed as
    context to ``amazon/chronos-t5-base``. The 10th, 50th and 90th percentiles of the
    sampled forecast paths are rounded up to integers.

    Args:
        df_state: DataFrame holding the training rows (not modified).
        date_column: Dropdown label of the date column.
        target_column: Dropdown label of the numeric target column.
        select_period: Number of months to forecast (coerced to ``int``).
        forecasting_type: Forecast granularity; only ``"monthly"`` is implemented.

    Returns:
        A Plotly figure of the monthly history, the median forecast and the
        10th-90th percentile band.

    Raises:
        gr.Error: for missing selections or data, an unsupported forecasting type,
            or a non-positive period.
    """
    if not date_column:
        raise gr.Error("Please select a Date column")
    if not target_column:
        raise gr.Error("Please select a target column")
    if df_state is None:
        raise gr.Error("Please upload a CSV file first")
    if forecasting_type not in (None, "monthly"):
        raise gr.Error("Only monthly forecasting is currently supported")

    original_date_column = _resolve_column(df_state, date_column)
    if original_date_column is None:
        raise gr.Error("Please select a Date column")
    original_target_column = _resolve_column(df_state, target_column)
    if original_target_column is None:
        raise gr.Error("Please select a target column")

    # Sliders may deliver floats such as 12.0; range() and Chronos need an int.
    prediction_length = int(select_period)
    if prediction_length < 1:
        raise gr.Error("Please select a forecast period of at least one month")

    monthly_sales = _monthly_totals(df_state[original_date_column], df_state[original_target_column])
    if monthly_sales.empty:
        raise gr.Error("There is no data in the selected training window to forecast from")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline = _load_chronos_pipeline(device)
    context = torch.tensor(monthly_sales.to_numpy(dtype=np.float32))
    forecast = pipeline.predict(context, prediction_length)

    low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
    low, median, high = (np.ceil(q).astype(int) for q in (low, median, high))

    history_x = monthly_sales.index.to_timestamp()
    forecast_x = pd.period_range(
        monthly_sales.index[-1] + 1, periods=prediction_length, freq="M"
    ).to_timestamp()
    return _build_forecast_figure(
        history_x, monthly_sales.to_numpy(), forecast_x, low, median, high, target_column
    )


def _forecast_training_window(df_state, start, end, select_product_column, select_product,
                              date_column, target_column, select_period, forecasting_type="monthly"):
    """Forecast from the rows currently selected by the Step 1 controls (dates, product)."""
    if not date_column:
        raise gr.Error("Please select a Date column")
    if not target_column:
        raise gr.Error("Please select a target column")
    rows, _ = _select_training_rows(df_state, start, end, date_column, select_product_column, select_product)
    return forecast_chronos_data(rows, date_column, target_column, select_period, forecasting_type)


def home_page():
    """Return the Markdown introduction shown at the top of the Home tab."""
    content = """
    ### **Sales Forecasting with Chronos**

    Welcome to the future of sales optimization with **Chronos**. Say goodbye to guesswork and unlock the power of **data-driven insights** with our advanced forecasting platform.

    - **Seamless CSV Upload**: Quickly upload your sales data in CSV format—no technical expertise needed.
    - **AI-Powered Predictions**: Harness the power of state-of-the-art machine learning models to uncover trends and forecast future sales performance.
    - **Interactive Visualizations**: Gain actionable insights with intuitive charts and graphs that make data easy to understand.

    Start making smarter, data-backed business decisions today with **Chronos**!
    """
    return content


def about_page():
    """Return the Markdown contact and services text shown on the About tab."""
    content = """
    ### 📧 **Contact Us:**
    - **Email**: contact@topsinfosolutions.com ✉️
    - **Website**: [https://www.topsinfosolutions.com/](https://www.topsinfosolutions.com/) 🌐

    ### 🛠 **What We Offer:**
    - **Custom AI Solutions**: Tailored to your business needs 🤖
    - **Chatbot Development**: Build intelligent conversational agents 💬
    - **Vision Models**: Computer vision solutions for various applications 🖼️
    - **AI Agents**: Personalized agents powered by advanced LLMs 🤖

    ### 🤔 **How We Can Help:**
    Reach out to us for bespoke AI services. Whether you need chatbots, vision models, or AI-powered agents, we’re here to build solutions that make a difference! 🌟

    ### 💬 **Get in Touch:**
    If you have any questions or need a custom solution, click the button below to schedule a consultation with us. 📅
    """
    return content


with gr.Blocks(theme=gr.themes.Default()) as demo:
    with gr.Tabs():
        with gr.TabItem("Home"):
            # Full dataset from the latest upload. Visualizing and forecasting only read it,
            # so the date window, product and target can be changed freely afterwards.
            df_state = gr.State()
            home_output = gr.Markdown(value=home_page(), label="Playground")

            gr.Markdown("## Step 1: Historical/Training Data (currently supports *.csv only)")
            with gr.Row():
                file_input = gr.File(label="Upload Historical (Training Data) Sales Data", file_types=[".csv"])
            with gr.Row():
                date_column = gr.Dropdown(choices=[], label="Select Date column (*Required)", multiselect=False, value=None)
                target_column = gr.Dropdown(choices=[], label="Select Target column (*Required)", multiselect=False, value=None)
                select_product_column = gr.Dropdown(choices=[], label="Select Product column (Optional)", multiselect=False, value=None)
                select_product = gr.Dropdown(choices=[], label="Select Product (Optional)", multiselect=False, value=None)
            with gr.Row():
                start = gr.DateTime("2021-01-01 00:00:00", label="Training data Start date")
                end = gr.DateTime("2021-01-05 00:00:00", label="Training data End date")
                apply_btn = gr.Button("Visualize Data", scale=0)

            gr.Examples(
                examples=[
                    ["example_files/test_tops_product_id_1.csv"],
                    ["example_files/test_tops_product_id_2.csv"],
                    ["example_files/test_tops_product_id_3.csv"],
                    ["example_files/test_tops_product_id_4.csv"],
                ],
                inputs=file_input,
                outputs=[df_state, select_product_column, date_column, target_column],
                fn=upload_file,
            )

            with gr.Row():
                historical_data_plot = gr.Plot()

            apply_btn.click(
                _visualize_training_data,
                inputs=[start, end, df_state, select_product_column, date_column, target_column, select_product],
                outputs=[historical_data_plot],
            )

            gr.Markdown("## Step 2: Forecast")
            with gr.Row():
                forecasting_type = gr.Radio(["day", "monthly", "year"], value="monthly", label="Forecasting Type", interactive=False)
                select_period = gr.Slider(2, 60, value=12, label="Select Period", info="Check Selected Forecast Type", interactive=True, step=1)
                forecast_btn = gr.Button("Forecast")
            with gr.Row():
                plot_forecast_output = gr.Plot(label="Chronos Forecasting Visualization")

            forecast_btn.click(
                _forecast_training_window,
                inputs=[
                    df_state, start, end, select_product_column, select_product,
                    date_column, target_column, select_period, forecasting_type,
                ],
                outputs=[plot_forecast_output],
            )
            file_input.change(
                upload_file,
                inputs=[file_input],
                outputs=[df_state, select_product_column, date_column, target_column],
            )
            select_product_column.change(
                _refresh_product_choices,
                inputs=[select_product_column, df_state],
                outputs=[select_product],
            )
            date_column.change(
                set_dates,
                inputs=[date_column, df_state],
                outputs=[start, end],
            )

        with gr.TabItem("About Tops"):
            about_output = gr.Markdown(value=about_page(), label="About Tops")

if __name__ == "__main__":
    demo.launch()