Spaces:
Runtime error
Runtime error
import os
import pickle

# stop tensorflow from printing novels to stdout
# (must be set before tensorflow is imported below)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import tensorflow as tf
import tensorflow_hub as hub
from sklearn.cluster import DBSCAN
def read_stops(p: str):
    """
    Load the stops dataset from a CSV file.

    :param p: Path to the CSV file
    :returns: The DataFrame that `pd.read_csv` parses from the file
    """
    stops = pd.read_csv(p)
    return stops
def read_encodings(p: str) -> tf.Tensor:
    """
    Load the pickled Universal Sentence Encoder v4 encodings from disk.

    WARNING: unpickling can execute arbitrary code and nothing here guards
    against that, so only point this at files you trust.

    :param p: Path to the pickled encodings
    :returns: The unpickled object; expected to be a Tensor of the encodings
        with shape (number of sentences, 512)
    """
    with open(p, 'rb') as handle:
        return pickle.load(handle)
def cluster_encodings(encodings: tf.Tensor) -> np.ndarray:
    """
    Cluster the sentence encodings with DBSCAN.

    :param encodings: The encodings to cluster
    :returns: The fitted model's `labels_` array of cluster labels
    """
    # eps and min_samples were settled during the EDA in the notebook
    model = DBSCAN(eps=0.7, min_samples=100)
    model.fit(encodings)
    return model.labels_
def cluster_lat_lon(df: pd.DataFrame) -> np.ndarray:
    """
    Cluster the stops by geographic position with DBSCAN.

    :param df: DataFrame with `latitude` and `longitude` columns
    :returns: The fitted model's `labels_` array of cluster labels
    """
    coords = df[['latitude', 'longitude']]
    # eps and min_samples were settled during the EDA in the notebook
    model = DBSCAN(eps=0.025, min_samples=100)
    model.fit(coords)
    return model.labels_
def plot_example(df: pd.DataFrame, labels: np.ndarray) -> go.Figure:
    """
    Build a square scatter plot of the stops, colored by cluster label.

    The return annotation is `go.Figure` (`plotly.graph_objects.Figure`):
    `plotly.express` does not export a `Figure` name, so annotating with
    `px.Figure` raises AttributeError as soon as this module is imported.

    :param df: DataFrame with `longitude`, `latitude` and `display_name` columns
    :param labels: Cluster labels used to color the points
    :returns: The Plotly figure; it is not displayed here
    """
    plot_size = 800
    # Cast to strings so Plotly colors the labels as discrete categories
    # (color_discrete_sequence) instead of using a continuous color scale
    labels = labels.astype('str')
    fig = px.scatter(df, x='longitude', y='latitude',
                     hover_name='display_name',
                     color=labels,
                     opacity=0.5,
                     color_discrete_sequence=px.colors.qualitative.Safe,
                     template='presentation',
                     width=plot_size,
                     height=plot_size)
    return fig
def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray) -> go.Figure:
    """
    Build a Mapbox scatter plot of the stops centered on Venice Blvd,
    colored by cluster label.

    Reads the Mapbox access token from `st.secrets['mapbox_token']` and
    registers it with Plotly Express before building the figure.

    The return annotation is `go.Figure` (`plotly.graph_objects.Figure`):
    `plotly.express` does not export a `Figure` name, so annotating with
    `px.Figure` raises AttributeError as soon as this module is imported.

    :param df: DataFrame with `latitude`, `longitude` and `display_name` columns
    :param labels: Cluster labels used to color the points
    :returns: The Plotly figure; it is not displayed here
    """
    px.set_mapbox_access_token(st.secrets['mapbox_token'])
    venice_blvd = {'lat': 34.008350,
                   'lon': -118.425362}
    # Cast to strings so Plotly colors the labels as discrete categories
    # (color_discrete_sequence) instead of using a continuous color scale
    labels = labels.astype('str')
    fig = px.scatter_mapbox(df, lat='latitude', lon='longitude',
                            color=labels,
                            hover_name='display_name',
                            center=venice_blvd,
                            zoom=12,
                            color_discrete_sequence=px.colors.qualitative.Dark24)
    return fig
def main(data_path: str, enc_path: str):
    """
    Cluster the stops two ways and render both plots with Streamlit.

    First the stops are clustered on latitude/longitude as a plain DBSCAN
    example, then on the pickled encodings of the stop names, which is
    plotted on a map centered on Venice Blvd.

    :param data_path: Path to the stops CSV
    :param enc_path: Path to the pickled encodings
    """
    df = read_stops(data_path)

    # Cluster based on lat/lon
    example_labels = cluster_lat_lon(df)
    example_fig = plot_example(df, example_labels)

    # Cluster based on the name of the stop
    encodings = read_encodings(enc_path)
    encoding_labels = cluster_encodings(encodings)
    venice_fig = plot_venice_blvd(df, encoding_labels)

    # Display the plots with Streamlit
    st.write('# Example of what DBSCAN does')
    st.plotly_chart(example_fig, use_container_width=True)
    st.write('# Venice Blvd')
    # Show the Venice Blvd map here; previously example_fig was rendered a
    # second time and venice_fig was never displayed.
    st.plotly_chart(venice_fig, use_container_width=True)
if __name__ == '__main__':
    import argparse

    # Both options may be omitted, in which case the bundled data files are used
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_path',
        nargs='?',
        default='data/stops.csv',
        help="Path to the dataset of LA Metro stops. Defaults to 'data/stops.csv'",
    )
    parser.add_argument(
        '--enc_path',
        nargs='?',
        default='data/encodings.pkl',
        help="Path to the pickled encodings. Defaults to 'data/encodings.pkl'",
    )
    cli_args = parser.parse_args()
    main(data_path=cli_args.data_path, enc_path=cli_args.enc_path)