Spaces:
Running
Running
from pandasai.llm import GoogleGemini | |
import streamlit as st | |
import os | |
import pandas as pd | |
from pandasai import SmartDataframe | |
from pandasai.responses.response_parser import ResponseParser | |
from st_on_hover_tabs import on_hover_tabs | |
from ydata_profiling import ProfileReport | |
import google.generativeai as genai | |
import json | |
class StreamLitResponse(ResponseParser):
    """Response parser that draws each result straight onto the Streamlit page.

    Every ``format_*`` hook reads ``result['value']``, hands it to the
    matching Streamlit display call and returns ``None``.
    """

    def __init__(self, context) -> None:
        super().__init__(context)

    def format_dataframe(self, result):
        """Display a dataframe result with ``st.dataframe``."""
        table = result['value']
        st.dataframe(table)
        return None

    def format_plot(self, result):
        """Display a plot result with ``st.image``."""
        image = result['value']
        st.image(image)
        return None

    def format_other(self, result):
        """Display any other result with ``st.write``."""
        payload = result['value']
        st.write(payload)
        return None
# The API key comes from the 'Gemini' environment variable; a missing
# variable raises KeyError at import time.
gemini_api_key = os.environ['Gemini']
genai.configure(api_key=gemini_api_key)

# Generation settings handed to the Gemini model created below.
generation_config = dict(
    temperature=0.2,
    top_p=0.95,
    max_output_tokens=5000,
)

model = genai.GenerativeModel(
    model_name="gemini-1.5-flash",
    generation_config=generation_config,
)
def calculate_kpis(df):
    """
    Calculate key performance indicators from a transaction dataset.

    Args:
        df: Pandas DataFrame of transaction lines. The columns read are
            'Price', 'Quantity', 'Description', 'Branch_Name',
            'Receipt No_' and 'Customer_Name'.

    Returns:
        A dict with the keys 'total_revenue', 'top_five_products',
        'average_order_value', 'customer_purchase_frequency',
        'estimated_cltv' and 'best_performing_branch'.
        'top_five_products' and 'best_performing_branch' hold a list of
        names, or an explanatory string when the dataset has fewer than
        five products / only one branch.

    Raises:
        KeyError: if any of the columns listed above is missing.
    """
    # Revenue is price times quantity per line, summed over every line.
    # The parentheses matter: without them only 'Quantity' is summed and the
    # result is a whole Series rather than a single number.
    total_revenue = (df['Price'] * df['Quantity']).sum()

    # NOTE(review): the rankings and the order value below sum 'Price' alone,
    # not Price * Quantity — confirm whether 'Price' is already a line total.

    # Top five products, ranked by summed 'Price' (five or more products needed).
    if df['Description'].nunique() >= 5:
        top_five_products = (
            df.groupby('Description')['Price'].sum().nlargest(5).index.tolist()
        )
    else:
        top_five_products = "there are less than 5 products in this dataset"

    # Best branch, ranked by summed 'Price' (only meaningful with 2+ branches).
    if df['Branch_Name'].nunique() > 1:
        best_branch = (
            df.groupby('Branch_Name')['Price'].sum().nlargest(1).index.tolist()
        )
    else:
        best_branch = "there is only one branch in this dataset"

    # Average Order Value: mean of the per-receipt 'Price' totals.
    aov = df.groupby('Receipt No_')['Price'].sum().mean()

    # Purchase frequency: mean number of distinct receipts per customer,
    # with customers identified by 'Customer_Name'.
    customer_purchase_frequency = (
        df.groupby('Customer_Name')['Receipt No_'].nunique().mean()
    )

    # Simple CLTV estimate: order value x purchase frequency x 12.
    estimated_cltv = aov * customer_purchase_frequency * 12

    return {
        "total_revenue": total_revenue,
        "top_five_products": top_five_products,
        "average_order_value": aov,
        "customer_purchase_frequency": customer_purchase_frequency,
        "estimated_cltv": estimated_cltv,
        "best_performing_branch": best_branch,
    }
def get_pandas_profile(df):
    """Profile *df* and return a trimmed-down version of the report as a dict.

    A ProfileReport titled "Profiling Report" is built for the frame,
    serialised to JSON and parsed back into a dict. Only the 'analysis',
    'table', 'correlations', 'alerts' and 'sample' sections are kept; a
    KeyError is raised if any of them is absent from the report.
    """
    report = ProfileReport(df, title="Profiling Report")
    full_profile = json.loads(report.to_json())

    wanted_sections = ('analysis', 'table', 'correlations', 'alerts', 'sample')
    return {section: full_profile[section] for section in wanted_sections}
def generateResponse(dataFrame, prompt):
    """Send *prompt* to a Gemini-backed SmartDataframe wrapped around *dataFrame*.

    The SmartDataframe is configured with a GoogleGemini LLM (using the
    module-level API key) and StreamLitResponse as its response parser.
    Returns whatever ``chat()`` returns.
    """
    agent_config = {
        "llm": GoogleGemini(api_key=gemini_api_key),
        "response_parser": StreamLitResponse,
    }
    return SmartDataframe(dataFrame, config=agent_config).chat(prompt)
# Page header.
st.write("# Brave Retail Insights")

# Inject the custom stylesheet. The context manager closes the file handle
# once the CSS has been read and passed to Streamlit.
with open('./style.css') as css_file:
    st.markdown('<style>' + css_file.read() + '</style>', unsafe_allow_html=True)

st.write("##### Engage in insightful conversations with your data through powerful visualizations")

# Sidebar: title, logo and the hover-tab navigation that picks the active view.
with st.sidebar:
    st.title("Brave Retail Insights")
    st.sidebar.image("IMG_1181.jpeg", use_column_width=True)
    tabs = on_hover_tabs(tabName=['Chat', 'Reports'],
                         iconName=['chat', 'dashboard'], default_choice=0)

# Path of the CSV dataset the tabs load.
# Alternative dataset previously used here: "healthcare_dataset.csv".
uploaded_file = "bon_marche.csv"
# Render whichever view the sidebar tabs selected.
if tabs == 'Chat':
    df = pd.read_csv(uploaded_file)
    st.subheader("Brave Retail Chat")
    st.write("Get visualizations and analysis from our Gemini powered agent")

    # Collapsible look at the first rows of the dataset.
    with st.expander("Preview"):
        st.write(df.head())

    user_input = st.text_input("Type your message here", placeholder="Ask me about your data")
    if user_input:
        answer = generateResponse(dataFrame=df, prompt=user_input)
        # StreamLitResponse's format_* hooks draw the result themselves and
        # return None, so chat() presumably hands back None on success.
        # Skip that case instead of printing a literal "None" on the page.
        # NOTE(review): confirm against the installed pandasai version.
        if answer is not None:
            st.write(answer)

elif tabs == 'Reports':
    df = pd.read_csv(uploaded_file)
    st.subheader("Reports")
    st.write("Filter by Branch Name or Product to generate report")

    # Branch filter; every branch is pre-selected.
    st.write("Filtering Options")
    branch_names = df['Branch_Name'].unique().tolist()
    selected_branches = st.multiselect('Select Branch(es) Name', branch_names, default=branch_names)

    if st.button('Apply Filters and Generate report'):
        # df was loaded a few lines up in this same script run, so it is
        # reused here rather than reading the CSV a second time.
        # An empty selection means "no filter": the whole dataset is used.
        if selected_branches:
            filtered_df = df[df['Branch_Name'].isin(selected_branches)]
        else:
            filtered_df = df.copy()

        st.write("Filtered DataFrame")
        with st.expander("Preview"):
            st.write(filtered_df.head())

        with st.spinner("Generating Report, Please Wait...."):
            # The prompt is the instruction text followed by the stringified
            # KPI dict and the stringified profiling summary.
            prompt = (
                "\n"
                "You are an expert business analyst. Analyze the following data and generate a comprehensive and insightful business report, including appropriate key perfomance indicators and reccomendations.\n"
                "data:\n"
                + str(calculate_kpis(filtered_df))
                + str(get_pandas_profile(filtered_df))
            )
            response = model.generate_content(prompt)
            report = response.text
            st.markdown(report)
            st.success("Report Generated!")
    else:
        # Nothing is generated until the button is pressed.
        st.write("Filtered DataFrame")
        st.write("Click 'Apply Filters' to see the filtered data.")