David Wisdom commited on
Commit
11cb781
·
1 Parent(s): 9629b6b

first draft

Browse files
Files changed (1) hide show
  1. app.py +131 -0
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # stop tensorflow from printing novels to stdout
3
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
4
+ import pickle
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import plotly.express as px
9
+ import streamlit as st
10
+ import tensorflow as tf
11
+ import tensorflow_hub as hub
12
+
13
+ from sklearn.cluster import DBSCAN
14
+
15
+
16
+ def read_stops(p: str):
17
+ """
18
+ DOCSTRING
19
+ """
20
+ return pd.read_csv(p)
21
+
22
+
23
+ def read_encodings(p: str) -> tf.Tensor:
24
+ """
25
+ Unpickle the Universal Sentence Encoder v4 encodings
26
+ and return them
27
+
28
+ This function doesn't make any attempt to patch the security holes in `pickle`.
29
+
30
+ :param p: Path to the encodings
31
+
32
+ :returns: A Tensor of the encodings with shape (number of sentences, 512)
33
+ """
34
+ with open(p, 'rb') as f:
35
+ encodings = pickle.load(f)
36
+ return encodings
37
+
38
+
39
+ def cluster_encodings(encodings: tf.Tensor) -> np.ndarray:
40
+ """
41
+ DOCSTRING
42
+ """
43
+ # I know the hyperparams I want from the EDA I did in the notebook
44
+ clusterer = DBSCAN(eps=0.7, min_samples=100).fit(encodings)
45
+ return clusterer.labels_
46
+
47
+
48
+ def cluster_lat_lon(df: pd.DataFrame) -> np.ndarray:
49
+ """
50
+ DOCSTRING
51
+ """
52
+ # I know the hyperparams I want from the EDA I did in the notebook
53
+ clusterer = DBSCAN(eps=0.025, min_samples=100).fit(df[['latitude', 'longitude']])
54
+ return clusterer.labels_
55
+
56
+
57
+ def plot_example(df: pd.DataFrame, labels: np.ndarray) -> px.Figure:
58
+ """
59
+ DOCSTRING
60
+ """
61
+ plot_size = 800
62
+ labels = labels.astype('str')
63
+
64
+ fig = px.scatter(df, x='longitude', y='latitude',
65
+ hover_name='display_name',
66
+ color=labels,
67
+ opacity=0.5,
68
+ color_discrete_sequence=px.colors.qualitative.Safe,
69
+ template='presentation',
70
+ width=plot_size,
71
+ height=plot_size)
72
+ # fig.show()
73
+ return fig
74
+
75
+
76
+ def plot_venice_blvd(df: pd.DataFrame, labels: np.ndarray) -> px.Figure:
77
+ """
78
+ DOCSTRING
79
+ """
80
+ px.set_mapbox_access_token(st.secrets['mapbox_token'])
81
+ venice_blvd = {'lat': 34.008350,
82
+ 'lon': -118.425362}
83
+ labels = labels.astype('str')
84
+
85
+ fig = px.scatter_mapbox(df, lat='latitude', lon='longitude',
86
+ color=labels,
87
+ hover_name='display_name',
88
+ center=venice_blvd,
89
+ zoom=12,
90
+ color_discrete_sequence=px.colors.qualitative.Dark24)
91
+
92
+ # fig.show()
93
+ return fig
94
+
95
+
96
+ def main(data_path: str, enc_path: str):
97
+ df = read_stops(data_path)
98
+
99
+ # Cluster based on lat/lon
100
+ example_labels = cluster_lat_lon(df)
101
+ example_fig = plot_example(df, example_labels)
102
+
103
+ # Cluster based on the name of the stop
104
+ encodings = read_encodings(enc_path)
105
+ encoding_labels = cluster_encodings(encodings)
106
+ venice_fig = plot_venice_blvd(df, encoding_labels)
107
+
108
+ # Display the plots with Streamlit
109
+ st.write('# Example of what DBSCAN does')
110
+ st.plotly_chart(example_fig, use_container_width=True)
111
+
112
+ st.write('# Venice Blvd')
113
+ st.plotly_chart(example_fig, use_container_width=True)
114
+
115
+
116
+ if __name__ == '__main__':
117
+ import argparse
118
+
119
+ p = argparse.ArgumentParser()
120
+ p.add_argument('--data_path',
121
+ nargs='?',
122
+ default='data/stops.csv',
123
+ help="Path to the dataset of LA Metro stops. Defaults to 'data/stops.csv'")
124
+ p.add_argument('--enc_path',
125
+ nargs='?',
126
+ default='data/encodings.pkl',
127
+ help="Path to the pickled encodings. Defaults to 'data/encodings.pkl'")
128
+ args = p.parse_args()
129
+
130
+ main(**vars(args))
131
+