Spaces:
Sleeping
Sleeping
AJ-Gazin
committed on
Commit
·
8138b99
1
Parent(s):
0f19412
Saved visualizations to HTML files; the app now loads them directly
Browse files
- .gitattributes +1 -0
- Visualizations/pca_visualization.html +3 -0
- Visualizations/tsne_visualization.html +3 -0
- Visualizations/umap_visualization.html +3 -0
- __pycache__/model_def.cpython-312.pyc +0 -0
- __pycache__/viz_utils.cpython-312.pyc +0 -0
- app.py +15 -15
- visualizer.py +9 -41
.gitattributes
CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
*.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
*.csv filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.html filter=lfs diff=lfs merge=lfs -text
|
Visualizations/pca_visualization.html
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3e654ab0d04dd190f2165a6c89a720606c80e2211f4bc5350022b5179574af41
|
3 |
+
size 4290216
|
Visualizations/tsne_visualization.html
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d6ff7dda2c472758b59e110be2720024dcd3f38d5d8fdc4a1779ebcfd74f4632
|
3 |
+
size 4266707
|
Visualizations/umap_visualization.html
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:68bdda27d0f4924077ee91d49ddf7855f313dc0a40bf9dd6f51fbef1accd92c0
|
3 |
+
size 4251158
|
__pycache__/model_def.cpython-312.pyc
ADDED
Binary file (3.37 kB). View file
|
|
__pycache__/viz_utils.cpython-312.pyc
ADDED
Binary file (4.16 kB). View file
|
|
app.py
CHANGED
@@ -41,14 +41,14 @@ st.title("Movie Recommendation App")
|
|
41 |
|
42 |
|
43 |
# --- VISUALIZATIONS ---
|
44 |
-
|
45 |
-
|
46 |
|
47 |
-
|
48 |
-
|
49 |
|
50 |
-
|
51 |
-
|
52 |
|
53 |
tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])
|
54 |
|
@@ -63,23 +63,23 @@ with tab1:
|
|
63 |
umap_expander = st.expander("UMAP Visualization")
|
64 |
with umap_expander:
|
65 |
st.subheader('UMAP Visualization')
|
66 |
-
umap_fig = viz_utils.visualize_embeddings_umap(embedding_df)
|
67 |
-
st.plotly_chart(umap_fig)
|
68 |
-
|
69 |
|
70 |
tsne_expander = st.expander("TSNE Visualization")
|
71 |
with tsne_expander:
|
72 |
st.subheader('TSNE Visualization')
|
73 |
-
tsne_fig = viz_utils.visualize_embeddings_tsne(embedding_df)
|
74 |
-
st.plotly_chart(tsne_fig)
|
75 |
-
|
76 |
|
77 |
pca_expander = st.expander("PCA Visualization")
|
78 |
with pca_expander:
|
79 |
st.subheader('PCA Visualization')
|
80 |
-
pca_fig = viz_utils.visualize_embeddings_pca(embedding_df)
|
81 |
-
st.plotly_chart(pca_fig)
|
82 |
-
|
83 |
|
84 |
|
85 |
|
|
|
41 |
|
42 |
|
43 |
# --- VISUALIZATIONS ---
|
44 |
+
with open("./Visualizations/umap_visualization.html", "r", encoding='utf-8') as f:
|
45 |
+
umap_html = f.read()
|
46 |
|
47 |
+
with open("./Visualizations/tsne_visualization.html", "r") as f:
|
48 |
+
tsne_html = f.read()
|
49 |
|
50 |
+
with open("./Visualizations/pca_visualization.html", "r") as f:
|
51 |
+
pca_html = f.read()
|
52 |
|
53 |
tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])
|
54 |
|
|
|
63 |
umap_expander = st.expander("UMAP Visualization")
|
64 |
with umap_expander:
|
65 |
st.subheader('UMAP Visualization')
|
66 |
+
#umap_fig = viz_utils.visualize_embeddings_umap(embedding_df)
|
67 |
+
#st.plotly_chart(umap_fig)
|
68 |
+
components.html(umap_html, width=800, height=800)
|
69 |
|
70 |
tsne_expander = st.expander("TSNE Visualization")
|
71 |
with tsne_expander:
|
72 |
st.subheader('TSNE Visualization')
|
73 |
+
#tsne_fig = viz_utils.visualize_embeddings_tsne(embedding_df)
|
74 |
+
#st.plotly_chart(tsne_fig)
|
75 |
+
components.html(tsne_html, width=800, height=800)
|
76 |
|
77 |
pca_expander = st.expander("PCA Visualization")
|
78 |
with pca_expander:
|
79 |
st.subheader('PCA Visualization')
|
80 |
+
#pca_fig = viz_utils.visualize_embeddings_pca(embedding_df)
|
81 |
+
#st.plotly_chart(pca_fig)
|
82 |
+
components.html(pca_html, width=800, height=800)
|
83 |
|
84 |
|
85 |
|
visualizer.py
CHANGED
@@ -14,6 +14,8 @@ from torch_geometric.transforms import RandomLinkSplit, ToUndirected
|
|
14 |
from sentence_transformers import SentenceTransformer
|
15 |
from torch_geometric.data import HeteroData
|
16 |
import yaml
|
|
|
|
|
17 |
|
18 |
|
19 |
|
@@ -24,48 +26,10 @@ data = torch.load("./PyGdata.pt", map_location=device)
|
|
24 |
|
25 |
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")
|
26 |
|
27 |
-
|
28 |
-
def __init__(self, hidden_channels, out_channels):
|
29 |
-
super().__init__()
|
30 |
-
# these convolutions have been replicated to match the number of edge types
|
31 |
-
self.conv1 = SAGEConv((-1, -1), hidden_channels)
|
32 |
-
self.conv2 = SAGEConv((-1, -1), out_channels)
|
33 |
-
|
34 |
-
def forward(self, x, edge_index):
|
35 |
-
x = self.conv1(x, edge_index).relu()
|
36 |
-
x = self.conv2(x, edge_index)
|
37 |
-
return x
|
38 |
-
|
39 |
-
class EdgeDecoder(torch.nn.Module):
|
40 |
-
def __init__(self, hidden_channels):
|
41 |
-
super().__init__()
|
42 |
-
self.lin1 = Linear(2 * hidden_channels, hidden_channels)
|
43 |
-
self.lin2 = Linear(hidden_channels, 1)
|
44 |
-
|
45 |
-
def forward(self, z_dict, edge_label_index):
|
46 |
-
row, col = edge_label_index
|
47 |
-
# concat user and movie embeddings
|
48 |
-
z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)
|
49 |
-
# concatenated embeddings passed to linear layer
|
50 |
-
z = self.lin1(z).relu()
|
51 |
-
z = self.lin2(z)
|
52 |
-
return z.view(-1)
|
53 |
-
|
54 |
-
class Model(torch.nn.Module):
|
55 |
-
def __init__(self, hidden_channels):
|
56 |
-
super().__init__()
|
57 |
-
self.encoder = GNNEncoder(hidden_channels, hidden_channels)
|
58 |
-
self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
|
59 |
-
self.decoder = EdgeDecoder(hidden_channels)
|
60 |
-
|
61 |
-
def forward(self, x_dict, edge_index_dict, edge_label_index):
|
62 |
-
# z_dict contains dictionary of movie and user embeddings returned from GraphSage
|
63 |
-
z_dict = self.encoder(x_dict, edge_index_dict)
|
64 |
-
return self.decoder(z_dict, edge_label_index)
|
65 |
|
66 |
-
model = Model(hidden_channels=32).to(device)
|
67 |
-
|
68 |
-
model.load_state_dict(torch.load("PyGTrainedModelState.pt"), map_location=device)
|
69 |
model.eval()
|
70 |
|
71 |
total_users = data['user'].num_nodes
|
@@ -87,12 +51,16 @@ movie_index = 20
|
|
87 |
title = movies_df.iloc[movie_index]['title']
|
88 |
print(title)
|
89 |
|
|
|
90 |
|
91 |
fig_umap = viz_utils.visualize_embeddings_umap(embedding_df)
|
92 |
viz_utils.save_visualization(fig_umap, "./Visualizations/umap_visualization")
|
|
|
93 |
|
94 |
fig_tsne = viz_utils.visualize_embeddings_tsne(embedding_df)
|
95 |
viz_utils.save_visualization(fig_tsne, "./Visualizations/tsne_visualization")
|
|
|
96 |
|
97 |
fig_pca = viz_utils.visualize_embeddings_pca(embedding_df)
|
98 |
viz_utils.save_visualization(fig_pca, "./Visualizations/pca_visualization")
|
|
|
|
14 |
from sentence_transformers import SentenceTransformer
|
15 |
from torch_geometric.data import HeteroData
|
16 |
import yaml
|
17 |
+
import os
|
18 |
+
import model_def
|
19 |
|
20 |
|
21 |
|
|
|
26 |
|
27 |
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")
|
28 |
|
29 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
model = model_def.Model(hidden_channels=32).to(device)
|
32 |
+
model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device)),
|
|
|
33 |
model.eval()
|
34 |
|
35 |
total_users = data['user'].num_nodes
|
|
|
51 |
title = movies_df.iloc[movie_index]['title']
|
52 |
print(title)
|
53 |
|
54 |
+
os.makedirs("Visualizations", exist_ok=True)
|
55 |
|
56 |
fig_umap = viz_utils.visualize_embeddings_umap(embedding_df)
|
57 |
viz_utils.save_visualization(fig_umap, "./Visualizations/umap_visualization")
|
58 |
+
print("UMAP visualization saved")
|
59 |
|
60 |
fig_tsne = viz_utils.visualize_embeddings_tsne(embedding_df)
|
61 |
viz_utils.save_visualization(fig_tsne, "./Visualizations/tsne_visualization")
|
62 |
+
print("TSNE visualization saved")
|
63 |
|
64 |
fig_pca = viz_utils.visualize_embeddings_pca(embedding_df)
|
65 |
viz_utils.save_visualization(fig_pca, "./Visualizations/pca_visualization")
|
66 |
+
print("PCA visualization saved")
|