Spaces:
Runtime error
Runtime error
File size: 3,535 Bytes
11cb781 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import os
import pickle

# stop tensorflow from printing novels to stdout
# (keep this assignment above the tensorflow import below)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import tensorflow as tf
import tensorflow_hub as hub
from sklearn.cluster import DBSCAN
def read_stops(p: str) -> pd.DataFrame:
    """
    Load the stops dataset from a CSV file.

    The rest of this file reads the `latitude`, `longitude` and `display_name`
    columns of the returned frame, so the CSV is expected to provide them.

    :param p: Path to the CSV file of stops
    :returns: A DataFrame with the parsed contents of the file
    """
    return pd.read_csv(p)
def read_encodings(p: str) -> tf.Tensor:
    """
    Unpickle the Universal Sentence Encoder v4 encodings
    and return them

    This function doesn't make any attempt to patch the security holes in `pickle`.
    Unpickling can execute arbitrary code, so only point it at files you trust.

    :param p: Path to the encodings
    :returns: A Tensor of the encodings with shape (number of sentences, 512)
    """
    with open(p, 'rb') as f:
        # SECURITY: pickle.load runs whatever the file tells it to run.
        # Never feed this a user-supplied or downloaded file.
        return pickle.load(f)
def cluster_encodings(encodings: tf.Tensor,
                      eps: float = 0.7,
                      min_samples: int = 100) -> np.ndarray:
    """
    Cluster the sentence encodings with DBSCAN.

    The defaults are the hyperparameters picked during the EDA in the notebook;
    override them only when experimenting.

    :param encodings: The encodings to cluster, passed unchanged to `DBSCAN.fit`
    :param eps: DBSCAN neighbourhood radius
    :param min_samples: DBSCAN minimum neighbourhood size for a core point
    :returns: The fitted `labels_` array, one cluster label per encoding
        (DBSCAN marks noise points with -1)
    """
    clusterer = DBSCAN(eps=eps, min_samples=min_samples).fit(encodings)
    return clusterer.labels_
def cluster_lat_lon(df: pd.DataFrame,
                    eps: float = 0.025,
                    min_samples: int = 100) -> np.ndarray:
    """
    Cluster the stops by geographic position with DBSCAN.

    Only the `latitude` and `longitude` columns are used. The defaults are the
    hyperparameters picked during the EDA in the notebook.

    :param df: The stops; must contain `latitude` and `longitude` columns
    :param eps: DBSCAN neighbourhood radius, in the raw units of the two
        columns (presumably degrees -- confirm against the dataset)
    :param min_samples: DBSCAN minimum neighbourhood size for a core point
    :returns: The fitted `labels_` array, one cluster label per row of `df`
        (DBSCAN marks noise points with -1)
    """
    coordinates = df[['latitude', 'longitude']]
    clusterer = DBSCAN(eps=eps, min_samples=min_samples).fit(coordinates)
    return clusterer.labels_
def plot_example(df: pd.DataFrame, labels: np.ndarray) -> "go.Figure":
    """
    Build a square scatter plot of the stops, coloured by cluster label.

    The return annotation is quoted on purpose: `plotly.express` has no
    `Figure` attribute, so an unquoted `px.Figure` fails as soon as this
    function is defined.

    :param df: The stops; must contain `longitude`, `latitude` and
        `display_name` columns
    :param labels: One cluster label per row of `df`
    :returns: The plotly figure (not shown; the caller decides how to render it)
    """
    plot_size = 800
    # Cast to strings so plotly colours the labels as discrete categories
    # rather than on a continuous scale.
    labels = labels.astype('str')
    fig = px.scatter(df, x='longitude', y='latitude',
                     hover_name='display_name',
                     color=labels,
                     opacity=0.5,
                     color_discrete_sequence=px.colors.qualitative.Safe,
                     template='presentation',
                     width=plot_size,
                     height=plot_size)
    return fig
def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray) -> "go.Figure":
    """
    Build a Mapbox scatter plot of the stops centred on Venice Blvd,
    coloured by cluster label.

    Reads the Mapbox access token from the Streamlit secret `mapbox_token`.

    The return annotation is quoted on purpose: `plotly.express` has no
    `Figure` attribute, so an unquoted `px.Figure` fails as soon as this
    function is defined.

    :param df: The stops; must contain `latitude`, `longitude` and
        `display_name` columns
    :param labels: One cluster label per row of `df`
    :returns: The plotly figure (not shown; the caller decides how to render it)
    """
    px.set_mapbox_access_token(st.secrets['mapbox_token'])
    venice_blvd = {'lat': 34.008350,
                   'lon': -118.425362}
    # Cast to strings so plotly colours the labels as discrete categories
    # rather than on a continuous scale.
    labels = labels.astype('str')
    fig = px.scatter_mapbox(df, lat='latitude', lon='longitude',
                            color=labels,
                            hover_name='display_name',
                            center=venice_blvd,
                            zoom=12,
                            color_discrete_sequence=px.colors.qualitative.Dark24)
    return fig
def main(data_path: str, enc_path: str) -> None:
    """
    Cluster the stops two ways and render both plots with Streamlit.

    :param data_path: Path to the CSV dataset of stops
    :param enc_path: Path to the pickled sentence encodings of the stop names
    """
    df = read_stops(data_path)

    # Cluster based on lat/lon
    example_labels = cluster_lat_lon(df)
    example_fig = plot_example(df, example_labels)

    # Cluster based on the name of the stop
    encodings = read_encodings(enc_path)
    encoding_labels = cluster_encodings(encodings)
    venice_fig = plot_venice_blvd(df, encoding_labels)

    # Display the plots with Streamlit
    st.write('# Example of what DBSCAN does')
    st.plotly_chart(example_fig, use_container_width=True)
    st.write('# Venice Blvd')
    # Show the name-based clustering here; the original plotted example_fig twice.
    st.plotly_chart(venice_fig, use_container_width=True)
if __name__ == '__main__':
    # Script entry point: parse the two optional paths and hand them to main().
    import argparse

    default_data_path = 'data/stops.csv'
    default_enc_path = 'data/encodings.pkl'

    parser = argparse.ArgumentParser()
    # nargs='?' lets a flag appear without a value; `const` makes that case fall
    # back to the default path instead of passing None on to main().
    parser.add_argument('--data_path',
                        nargs='?',
                        const=default_data_path,
                        default=default_data_path,
                        help="Path to the dataset of LA Metro stops. Defaults to 'data/stops.csv'")
    parser.add_argument('--enc_path',
                        nargs='?',
                        const=default_enc_path,
                        default=default_enc_path,
                        help="Path to the pickled encodings. Defaults to 'data/encodings.pkl'")
    args = parser.parse_args()
    main(**vars(args))
|