"""Streamlit app exploring which spices are used across world cuisines.

The input CSV is indexed by spice (first column). It has one column per
cuisine, whose cells equal 1 when that cuisine uses the spice, plus a
free-text ``Flavor Description`` column that is not part of the incidence
matrix.

The page shows:

* the spices of a selected cuisine,
* the cuisines that use a selected spice,
* the cuisines whose spice sets are most similar (cosine similarity) to a
  selected cuisine, and
* a network graph linking every cuisine to its spices.
"""
import random

import networkx as nx
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st

DATA_PATH = 'spices_by_cuisine_with_all_flavors.csv'
DESCRIPTION_COL = 'Flavor Description'
SEED = 42
N_SIMILAR = 10

random.seed(SEED)
np.random.seed(SEED)

st.set_page_config(layout="wide")

# Load and process data. After dropping the description column, `pivot` is a
# pure spice x cuisine incidence matrix (rows sorted by spice name).
df = pd.read_csv(DATA_PATH, index_col=0)
pivot = df.drop(columns=[DESCRIPTION_COL]).sort_index()

# cuisine -> spices it uses (spices in sorted order).
cuisines = {
    cuisine: pivot.index[pivot[cuisine] == 1].to_list()
    for cuisine in pivot.columns
}

# spice -> cuisines that use it (cuisines in sorted order).
pivot_t = pivot.T.sort_index()
spices = {
    spice: pivot_t.index[pivot_t[spice] == 1].to_list()
    for spice in pivot_t.columns
}


def similarity(ratings, kind='user', epsilon=1e-9):
    """Return the pairwise cosine-similarity matrix of a 2-D ratings array.

    Parameters
    ----------
    ratings : numpy.ndarray
        2-D array whose rows (``kind='user'``) or columns (``kind='item'``)
        are compared as vectors.
    kind : {'user', 'item'}
        ``'user'`` compares rows, ``'item'`` compares columns.
    epsilon : float
        Added to every dot product so that an all-zero vector produces a
        finite similarity instead of a division by zero.

    Returns
    -------
    numpy.ndarray
        Square matrix whose ``[i, j]`` entry is the similarity of vectors
        ``i`` and ``j``.

    Raises
    ------
    ValueError
        If *kind* is neither ``'user'`` nor ``'item'``.
    """
    if kind == 'user':
        sim = ratings.dot(ratings.T) + epsilon
    elif kind == 'item':
        sim = ratings.T.dot(ratings) + epsilon
    else:
        # Previously fell through to an UnboundLocalError on `sim`.
        raise ValueError(f"kind must be 'user' or 'item', got {kind!r}")
    # The diagonal holds each vector's squared norm (plus epsilon).
    norms = np.array([np.sqrt(np.diagonal(sim))])
    # Divide column j by ||v_j|| and row i by ||v_i||.
    return sim / norms / norms.T


# Cuisine x cuisine similarity, labelled on both axes by cuisine name.
cuisine_similarity = pd.DataFrame(
    similarity(np.array(pivot_t), kind='user'),
    index=pivot_t.index,
    columns=pivot_t.index,
)

st.title('Spices Across Cuisines')
col1, col2, col3 = st.columns(3)

with col1:
    st.subheader('By Cuisine')
    select_cuisine = st.selectbox(
        'Select a cuisine to view the top 10 spices', list(cuisines))
    st.write(f'The top 10 ingredients in {select_cuisine} are:',
             cuisines[select_cuisine])

with col2:
    st.subheader('By Spice')
    select_spice = st.selectbox(
        'Select a spice to view which cuisines it is present in', list(spices))
    st.write(f'{select_spice} is part of the following cuisines:',
             spices[select_spice])

with col3:
    st.subheader("Similar Cuisines")
    select_cuisine_sim = st.selectbox(
        'Select a cuisine to view the 10 most similar cuisines by spices',
        list(cuisines))
    # Drop the cuisine itself explicitly instead of assuming it sorts first:
    # cuisines with identical spice sets tie at 1.0 and could displace it.
    ranked = (
        cuisine_similarity[select_cuisine_sim]
        .drop(index=select_cuisine_sim)
        .sort_values(ascending=False)
    )
    st.write(f'{select_cuisine_sim} is most similar to:',
             ranked.index[:N_SIMILAR].to_list())

# Build a bipartite cuisine/spice graph. Nodes are keyed by (kind, name) so a
# spice and a cuisine that happen to share a label stay distinct nodes.
G = nx.Graph()
for cuisine, cuisine_spices in cuisines.items():
    G.add_node(('cuisine', cuisine), type='cuisine')
    for spice in cuisine_spices:
        G.add_node(('spice', spice), type='spice')
        G.add_edge(('cuisine', cuisine), ('spice', spice))

# Explicit seed keeps the layout stable across Streamlit reruns.
# NOTE(review): the layout is recomputed on every rerun; consider caching it
# if the graph grows large.
pos = nx.spring_layout(G, seed=SEED)

# Edges: one polyline with None separators between segments. The coordinates
# are collected into lists first and the trace is built once, rather than
# growing the trace's tuples on every edge.
edge_x, edge_y = [], []
for u, v in G.edges():
    (x0, y0), (x1, y1) = pos[u], pos[v]
    edge_x += [x0, x1, None]
    edge_y += [y0, y1, None]

edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=0.5, color='#888'),
    hoverinfo='none',
    mode='lines')


def _node_trace(layout, kind, names, hover_texts, marker):
    """Build a labelled marker trace for the graph nodes of one kind."""
    return go.Scatter(
        x=[layout[(kind, name)][0] for name in names],
        y=[layout[(kind, name)][1] for name in names],
        text=names,
        hovertext=hover_texts,
        mode='markers+text',
        hoverinfo='text',
        marker=marker)


cuisine_names = [name for kind, name in G.nodes if kind == 'cuisine']
spice_names = [name for kind, name in G.nodes if kind == 'spice']

# Spread cuisine hues evenly around the colour wheel. True division keeps the
# step non-zero even with more than 360 cuisines.
n_cuisines = len(cuisine_names)
cuisine_colors = {
    cuisine: f"hsl({round(i * 360 / n_cuisines)}, 80%, 50%)"
    for i, cuisine in enumerate(cuisine_names)
}

node_trace_cuisines = _node_trace(
    pos, 'cuisine', cuisine_names,
    [f"{name} uses: {', '.join(cuisines[name])}" for name in cuisine_names],
    dict(showscale=False, size=20,
         color=[cuisine_colors[name] for name in cuisine_names],
         line=dict(width=0)))

node_trace_spices = _node_trace(
    pos, 'spice', spice_names,
    [f"{name} is used in: {', '.join(spices[name])}" for name in spice_names],
    dict(showscale=False, color='grey', size=10, line=dict(width=0)))

# `titlefont_size` was removed in plotly 6; title styling lives on layout.title.
fig = go.Figure(
    data=[edge_trace, node_trace_cuisines, node_trace_spices],
    layout=go.Layout(
        title=dict(text="Network Graph of Cuisines and their Spices",
                   font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)),
)
st.plotly_chart(fig, use_container_width=True)