AJ-Gazin commited on
Commit
8138b99
·
1 Parent(s): 0f19412

Saved viz's to HTML files, directly loading now

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.csv filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.csv filter=lfs diff=lfs merge=lfs -text
37
+ *.html filter=lfs diff=lfs merge=lfs -text
Visualizations/pca_visualization.html ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e654ab0d04dd190f2165a6c89a720606c80e2211f4bc5350022b5179574af41
3
+ size 4290216
Visualizations/tsne_visualization.html ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6ff7dda2c472758b59e110be2720024dcd3f38d5d8fdc4a1779ebcfd74f4632
3
+ size 4266707
Visualizations/umap_visualization.html ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68bdda27d0f4924077ee91d49ddf7855f313dc0a40bf9dd6f51fbef1accd92c0
3
+ size 4251158
__pycache__/model_def.cpython-312.pyc ADDED
Binary file (3.37 kB). View file
 
__pycache__/viz_utils.cpython-312.pyc ADDED
Binary file (4.16 kB). View file
 
app.py CHANGED
@@ -41,14 +41,14 @@ st.title("Movie Recommendation App")
41
 
42
 
43
  # --- VISUALIZATIONS ---
44
- #with open("umap_visualization.html", "r", encoding='utf-8') as f:
45
- # umap_html = f.read()
46
 
47
- #with open("tsne_visualization.html", "r") as f:
48
- # tsne_html = f.read()
49
 
50
- #with open("pca_visualization.html", "r") as f:
51
- # pca_html = f.read()
52
 
53
  tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])
54
 
@@ -63,23 +63,23 @@ with tab1:
63
  umap_expander = st.expander("UMAP Visualization")
64
  with umap_expander:
65
  st.subheader('UMAP Visualization')
66
- umap_fig = viz_utils.visualize_embeddings_umap(embedding_df)
67
- st.plotly_chart(umap_fig)
68
- #components.html(umap_html, width=800, height=800)
69
 
70
  tsne_expander = st.expander("TSNE Visualization")
71
  with tsne_expander:
72
  st.subheader('TSNE Visualization')
73
- tsne_fig = viz_utils.visualize_embeddings_tsne(embedding_df)
74
- st.plotly_chart(tsne_fig)
75
- #components.html(tsne_html, width=800, height=800)
76
 
77
  pca_expander = st.expander("PCA Visualization")
78
  with pca_expander:
79
  st.subheader('PCA Visualization')
80
- pca_fig = viz_utils.visualize_embeddings_pca(embedding_df)
81
- st.plotly_chart(pca_fig)
82
- #components.html(pca_html, width=800, height=800)
83
 
84
 
85
 
 
41
 
42
 
43
  # --- VISUALIZATIONS ---
44
+ with open("./Visualizations/umap_visualization.html", "r", encoding='utf-8') as f:
45
+ umap_html = f.read()
46
 
47
+ with open("./Visualizations/tsne_visualization.html", "r") as f:
48
+ tsne_html = f.read()
49
 
50
+ with open("./Visualizations/pca_visualization.html", "r") as f:
51
+ pca_html = f.read()
52
 
53
  tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])
54
 
 
63
  umap_expander = st.expander("UMAP Visualization")
64
  with umap_expander:
65
  st.subheader('UMAP Visualization')
66
+ #umap_fig = viz_utils.visualize_embeddings_umap(embedding_df)
67
+ #st.plotly_chart(umap_fig)
68
+ components.html(umap_html, width=800, height=800)
69
 
70
  tsne_expander = st.expander("TSNE Visualization")
71
  with tsne_expander:
72
  st.subheader('TSNE Visualization')
73
+ #tsne_fig = viz_utils.visualize_embeddings_tsne(embedding_df)
74
+ #st.plotly_chart(tsne_fig)
75
+ components.html(tsne_html, width=800, height=800)
76
 
77
  pca_expander = st.expander("PCA Visualization")
78
  with pca_expander:
79
  st.subheader('PCA Visualization')
80
+ #pca_fig = viz_utils.visualize_embeddings_pca(embedding_df)
81
+ #st.plotly_chart(pca_fig)
82
+ components.html(pca_html, width=800, height=800)
83
 
84
 
85
 
visualizer.py CHANGED
@@ -14,6 +14,8 @@ from torch_geometric.transforms import RandomLinkSplit, ToUndirected
14
  from sentence_transformers import SentenceTransformer
15
  from torch_geometric.data import HeteroData
16
  import yaml
 
 
17
 
18
 
19
 
@@ -24,48 +26,10 @@ data = torch.load("./PyGdata.pt", map_location=device)
24
 
25
  movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")
26
 
27
- class GNNEncoder(torch.nn.Module):
28
- def __init__(self, hidden_channels, out_channels):
29
- super().__init__()
30
- # these convolutions have been replicated to match the number of edge types
31
- self.conv1 = SAGEConv((-1, -1), hidden_channels)
32
- self.conv2 = SAGEConv((-1, -1), out_channels)
33
-
34
- def forward(self, x, edge_index):
35
- x = self.conv1(x, edge_index).relu()
36
- x = self.conv2(x, edge_index)
37
- return x
38
-
39
- class EdgeDecoder(torch.nn.Module):
40
- def __init__(self, hidden_channels):
41
- super().__init__()
42
- self.lin1 = Linear(2 * hidden_channels, hidden_channels)
43
- self.lin2 = Linear(hidden_channels, 1)
44
-
45
- def forward(self, z_dict, edge_label_index):
46
- row, col = edge_label_index
47
- # concat user and movie embeddings
48
- z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)
49
- # concatenated embeddings passed to linear layer
50
- z = self.lin1(z).relu()
51
- z = self.lin2(z)
52
- return z.view(-1)
53
-
54
- class Model(torch.nn.Module):
55
- def __init__(self, hidden_channels):
56
- super().__init__()
57
- self.encoder = GNNEncoder(hidden_channels, hidden_channels)
58
- self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
59
- self.decoder = EdgeDecoder(hidden_channels)
60
-
61
- def forward(self, x_dict, edge_index_dict, edge_label_index):
62
- # z_dict contains dictionary of movie and user embeddings returned from GraphSage
63
- z_dict = self.encoder(x_dict, edge_index_dict)
64
- return self.decoder(z_dict, edge_label_index)
65
 
66
- model = Model(hidden_channels=32).to(device)
67
- model2 = Model(hidden_channels=32).to(device)
68
- model.load_state_dict(torch.load("PyGTrainedModelState.pt"), map_location=device)
69
  model.eval()
70
 
71
  total_users = data['user'].num_nodes
@@ -87,12 +51,16 @@ movie_index = 20
87
  title = movies_df.iloc[movie_index]['title']
88
  print(title)
89
 
 
90
 
91
  fig_umap = viz_utils.visualize_embeddings_umap(embedding_df)
92
  viz_utils.save_visualization(fig_umap, "./Visualizations/umap_visualization")
 
93
 
94
  fig_tsne = viz_utils.visualize_embeddings_tsne(embedding_df)
95
  viz_utils.save_visualization(fig_tsne, "./Visualizations/tsne_visualization")
 
96
 
97
  fig_pca = viz_utils.visualize_embeddings_pca(embedding_df)
98
  viz_utils.save_visualization(fig_pca, "./Visualizations/pca_visualization")
 
 
14
  from sentence_transformers import SentenceTransformer
15
  from torch_geometric.data import HeteroData
16
  import yaml
17
+ import os
18
+ import model_def
19
 
20
 
21
 
 
26
 
27
  movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")
28
 
29
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ model = model_def.Model(hidden_channels=32).to(device)
32
+ model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device)),
 
33
  model.eval()
34
 
35
  total_users = data['user'].num_nodes
 
51
  title = movies_df.iloc[movie_index]['title']
52
  print(title)
53
 
54
+ os.makedirs("Visualizations", exist_ok=True)
55
 
56
  fig_umap = viz_utils.visualize_embeddings_umap(embedding_df)
57
  viz_utils.save_visualization(fig_umap, "./Visualizations/umap_visualization")
58
+ print("UMAP visualization saved")
59
 
60
  fig_tsne = viz_utils.visualize_embeddings_tsne(embedding_df)
61
  viz_utils.save_visualization(fig_tsne, "./Visualizations/tsne_visualization")
62
+ print("TSNE visualization saved")
63
 
64
  fig_pca = viz_utils.visualize_embeddings_pca(embedding_df)
65
  viz_utils.save_visualization(fig_pca, "./Visualizations/pca_visualization")
66
+ print("PCA visualization saved")