# Spaces: Sleeping
import streamlit as st | |
import plotly.graph_objects as go | |
import networkx as nx | |
import pandas as pd | |
# Read the spice-by-cuisine table. The first CSV column becomes the row index
# (its values are used as spice names when the graph is built); every other
# column except "Flavor Description" is treated as a cuisine.
# NOTE(review): the path is anchored at the filesystem root - confirm the CSV
# is really deployed there rather than next to this script.
csv_path = '/spices_by_cuisine_with_all_flavors.csv'
df = pd.read_csv(csv_path, index_col=0)
# Build an undirected graph linking every cuisine to the spices it uses.
# Each column other than "Flavor Description" becomes a cuisine node; each row
# label whose cell in that column equals 1 becomes a spice node joined to it.
G = nx.Graph()
for cuisine in df.columns:
    if cuisine == "Flavor Description":
        continue  # descriptive column, not a cuisine
    G.add_node(cuisine, type='cuisine')
    uses_spice = df[cuisine] == 1
    used_spices = df.loc[uses_spice].index.tolist()
    # Tag the spice nodes with their type first, then connect them.
    G.add_nodes_from(used_spices, type='spice')
    G.add_edges_from((cuisine, spice) for spice in used_spices)
# Compute (x, y) coordinates for every node with the spring layout.
# The layout is initialised randomly, so a fixed seed is supplied to keep the
# drawing identical between script reruns instead of reshuffling each time.
pos = nx.spring_layout(G, seed=42)
# Build one line trace that draws every edge of the graph.
# Coordinates are gathered in plain lists and handed to the trace once;
# growing the trace's own x/y with `+=` would copy and re-assign the whole
# array for every edge.
edge_x = []
edge_y = []
for source, target in G.edges():
    x0, y0 = pos[source]
    x1, y1 = pos[target]
    # A None entry breaks the line so consecutive edges are not joined.
    edge_x.extend((x0, x1, None))
    edge_y.extend((y0, y1, None))

edge_trace = go.Scatter(
    x=edge_x,
    y=edge_y,
    line=dict(width=0.5, color='#888'),
    hoverinfo='none',
    mode='lines',
)
# Assign each cuisine its own hue, spread evenly around the HSL colour wheel.
# The cuisine list is built first so that both the hue count and the running
# index ignore "Flavor Description" wherever it sits (or if it is absent);
# counting it would yield duplicate hues or a division by zero.
cuisine_names = [name for name in df.columns if name != "Flavor Description"]
cuisine_colors = {
    name: f"hsl({(rank * 360) // len(cuisine_names)}, 80%, 50%)"
    for rank, name in enumerate(cuisine_names)
}
# Build the two marker traces: large coloured markers for cuisines and small
# grey markers for spices. Each node is labelled with its name and gets a hover
# text listing its counterparts (the spices of a cuisine, or the cuisines that
# use a spice).
# Values are gathered in lists and each trace is constructed once; appending to
# the trace properties with `+=` would re-copy every array for every node.
cuisine_x, cuisine_y = [], []
cuisine_labels, cuisine_hover, cuisine_fill = [], [], []
spice_x, spice_y = [], []
spice_labels, spice_hover = [], []

for node in G.nodes():
    x, y = pos[node]
    if G.nodes[node]['type'] == 'cuisine':
        cuisine_x.append(x)
        cuisine_y.append(y)
        cuisine_labels.append(node)
        cuisine_fill.append(cuisine_colors[node])
        # Collect all spices associated with this cuisine
        spices_associated = df[df[node] == 1].index.tolist()
        cuisine_hover.append(f"{node} uses: {', '.join(spices_associated)}")
    else:
        spice_x.append(x)
        spice_y.append(y)
        spice_labels.append(node)
        # Collect all cuisines that use this spice
        cuisines_using_spice = df.columns[df.loc[node] == 1].tolist()
        spice_hover.append(
            f"{node} is used in: {', '.join(cuisines_using_spice)}"
        )

# Cuisine markers: one colour per cuisine, taken from cuisine_colors.
node_trace_cuisines = go.Scatter(
    x=cuisine_x,
    y=cuisine_y,
    text=cuisine_labels,
    hovertext=cuisine_hover,
    mode='markers+text',
    hoverinfo='text',
    marker=dict(
        showscale=False,
        size=20,
        color=cuisine_fill,
        line=dict(width=0),
    ),
)

# Spice markers: uniform grey, half the size of the cuisine markers.
node_trace_spices = go.Scatter(
    x=spice_x,
    y=spice_y,
    text=spice_labels,
    hovertext=spice_hover,
    mode='markers+text',
    hoverinfo='text',
    marker=dict(
        showscale=False,
        color='grey',
        size=10,
        line=dict(width=0),
    ),
)
# Assemble the figure (edges underneath, then cuisine and spice markers) and
# render it in the Streamlit page at the full container width.
# The title size is set through `title.font`; the older `titlefont` layout
# attribute is deprecated and rejected by newer plotly releases.
fig = go.Figure(
    data=[edge_trace, node_trace_cuisines, node_trace_spices],
    layout=go.Layout(
        title=dict(
            text="Network Graph of Cuisines and their Spices",
            font=dict(size=16),
        ),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        # Node coordinates are layout positions, so hide all axis decoration.
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    ),
)
st.plotly_chart(fig, use_container_width=True)