AJ-Gazin committed on
Commit
06d9388
·
1 Parent(s): 960b542

Added more of main program

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # .gitignore
2
+ .env
3
+ creds.dat
PredictionGenerator.ipynb ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "2.2.1+cu121\n"
13
+ ]
14
+ }
15
+ ],
16
+ "source": [
17
+ "import sys\n",
18
+ "\n",
19
+ "sys.path.insert(0, \"./interactive_tutorials\")\n",
20
+ "\n",
21
+ "import pandas as pd\n",
22
+ "from tqdm import tqdm\n",
23
+ "import numpy as np\n",
24
+ "import matplotlib.pyplot as plt\n",
25
+ "import itertools\n",
26
+ "import requests\n",
27
+ "import sys\n",
28
+ "\n",
29
+ "import torch\n",
30
+ "import torch.nn.functional as F\n",
31
+ "from torch.nn import Linear\n",
32
+ "import torch_geometric.transforms as T\n",
33
+ "from torch_geometric.nn import SAGEConv, to_hetero\n",
34
+ "from torch_geometric.transforms import RandomLinkSplit, ToUndirected\n",
35
+ "from sentence_transformers import SentenceTransformer\n",
36
+ "from torch_geometric.data import HeteroData\n",
37
+ "import yaml\n",
38
+ "\n",
39
+ "print(torch.__version__)\n",
40
+ "\n",
41
+ "\n",
42
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 3,
48
+ "metadata": {},
49
+ "outputs": [
50
+ {
51
+ "data": {
52
+ "text/plain": [
53
+ "{('user',\n",
54
+ " 'rates',\n",
55
+ " 'movie'): tensor([[ 0, 0, 0, ..., 670, 670, 670],\n",
56
+ " [ 0, 1, 2, ..., 1327, 1329, 2941]], device='cuda:0'),\n",
57
+ " ('movie',\n",
58
+ " 'rev_rates',\n",
59
+ " 'user'): tensor([[ 0, 1, 2, ..., 1327, 1329, 2941],\n",
60
+ " [ 0, 0, 0, ..., 670, 670, 670]], device='cuda:0')}"
61
+ ]
62
+ },
63
+ "execution_count": 3,
64
+ "metadata": {},
65
+ "output_type": "execute_result"
66
+ }
67
+ ],
68
+ "source": [
69
+ "data = torch.load(\"./PyGdata.pt\")\n",
70
+ "data.edge_index_dict\n",
71
+ "\n",
72
+ "\n"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 12,
78
+ "metadata": {},
79
+ "outputs": [
80
+ {
81
+ "data": {
82
+ "text/plain": [
83
+ "{('user',\n",
84
+ " 'rates',\n",
85
+ " 'movie'): tensor([[ 0, 0, 0, ..., 670, 670, 670],\n",
86
+ " [ 0, 1, 2, ..., 1327, 1329, 2941]], device='cuda:0'),\n",
87
+ " ('movie',\n",
88
+ " 'rev_rates',\n",
89
+ " 'user'): tensor([[ 0, 1, 2, ..., 1327, 1329, 2941],\n",
90
+ " [ 0, 0, 0, ..., 670, 670, 670]], device='cuda:0')}"
91
+ ]
92
+ },
93
+ "execution_count": 12,
94
+ "metadata": {},
95
+ "output_type": "execute_result"
96
+ }
97
+ ],
98
+ "source": []
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": 4,
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "class GNNEncoder(torch.nn.Module):\n",
107
+ " def __init__(self, hidden_channels, out_channels):\n",
108
+ " super().__init__()\n",
109
+ " # these convolutions have been replicated to match the number of edge types\n",
110
+ " self.conv1 = SAGEConv((-1, -1), hidden_channels)\n",
111
+ " self.conv2 = SAGEConv((-1, -1), out_channels)\n",
112
+ "\n",
113
+ " def forward(self, x, edge_index):\n",
114
+ " x = self.conv1(x, edge_index).relu()\n",
115
+ " x = self.conv2(x, edge_index)\n",
116
+ " return x\n"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": 5,
122
+ "metadata": {},
123
+ "outputs": [],
124
+ "source": [
125
+ "class EdgeDecoder(torch.nn.Module):\n",
126
+ " def __init__(self, hidden_channels):\n",
127
+ " super().__init__()\n",
128
+ " self.lin1 = Linear(2 * hidden_channels, hidden_channels)\n",
129
+ " self.lin2 = Linear(hidden_channels, 1)\n",
130
+ "\n",
131
+ " def forward(self, z_dict, edge_label_index):\n",
132
+ " row, col = edge_label_index\n",
133
+ " # concat user and movie embeddings\n",
134
+ " z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)\n",
135
+ " # concatenated embeddings passed to linear layer\n",
136
+ " z = self.lin1(z).relu()\n",
137
+ " z = self.lin2(z)\n",
138
+ " return z.view(-1)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": 6,
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "class Model(torch.nn.Module):\n",
148
+ " def __init__(self, hidden_channels):\n",
149
+ " super().__init__()\n",
150
+ " self.encoder = GNNEncoder(hidden_channels, hidden_channels)\n",
151
+ " self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')\n",
152
+ " self.decoder = EdgeDecoder(hidden_channels)\n",
153
+ "\n",
154
+ " def forward(self, x_dict, edge_index_dict, edge_label_index):\n",
155
+ " # z_dict contains dictionary of movie and user embeddings returned from GraphSage\n",
156
+ " z_dict = self.encoder(x_dict, edge_index_dict)\n",
157
+ " return self.decoder(z_dict, edge_label_index)"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 10,
163
+ "metadata": {},
164
+ "outputs": [
165
+ {
166
+ "name": "stdout",
167
+ "output_type": "stream",
168
+ "text": [
169
+ "HeteroData(\n",
170
+ " user={ x=[671, 671] },\n",
171
+ " movie={ x=[9025, 404] },\n",
172
+ " (user, rates, movie)={\n",
173
+ " edge_index=[2, 99810],\n",
174
+ " edge_label=[99810],\n",
175
+ " },\n",
176
+ " (movie, rev_rates, user)={ edge_index=[2, 99810] }\n",
177
+ ")\n"
178
+ ]
179
+ }
180
+ ],
181
+ "source": [
182
+ "model = Model(hidden_channels=32).to(device)\n",
183
+ "model2 = Model(hidden_channels=32).to(device)\n",
184
+ "model.load_state_dict(torch.load(\"PyGTrainedModelState.pt\"))\n",
185
+ "model.eval()\n",
186
+ "\n",
187
+ "total_users = data['user'].num_nodes \n",
188
+ "total_movies = data['movie'].num_nodes \n",
189
+ "print(data)\n"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 9,
195
+ "metadata": {},
196
+ "outputs": [
197
+ {
198
+ "name": "stderr",
199
+ "output_type": "stream",
200
+ "text": [
201
+ "100%|██████████| 671/671 [00:05<00:00, 121.64it/s]\n"
202
+ ]
203
+ }
204
+ ],
205
+ "source": [
206
+ "movie_recs = []\n",
207
+ "for user_id in tqdm(range(0, total_users)):\n",
208
+ " user_row = torch.tensor([user_id] * total_movies)\n",
209
+ " all_movie_ids = torch.arange(total_movies)\n",
210
+ " edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)\n",
211
+ " pred = model(data.x_dict, data.edge_index_dict,\n",
212
+ " edge_label_index)\n",
213
+ " pred = pred.clamp(min=0, max=5)\n",
214
+ " # keep only the movies whose predicted rating, after clamping to [0, 5], equals 5\n",
215
+ " rec_movie_ids = (pred == 5).nonzero(as_tuple=True)\n",
216
+ " top_ten_recs = [rec_movies for rec_movies in rec_movie_ids[0].tolist()[:10]]\n",
217
+ " movie_recs.append({'user': user_id, 'rec_movies': top_ten_recs})"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": 15,
223
+ "metadata": {},
224
+ "outputs": [
225
+ {
226
+ "name": "stdout",
227
+ "output_type": "stream",
228
+ "text": [
229
+ "Movie not found\n"
230
+ ]
231
+ },
232
+ {
233
+ "name": "stderr",
234
+ "output_type": "stream",
235
+ "text": [
236
+ "C:\\Users\\aj\\AppData\\Local\\Temp\\ipykernel_24552\\778055959.py:2: DtypeWarning: Columns (10) have mixed types. Specify dtype option on import or set low_memory=False.\n",
237
+ " df = pd.read_csv(metadata_path)\n"
238
+ ]
239
+ }
240
+ ],
241
+ "source": [
242
+ "metadata_path = './sampled_movie_dataset/movies_metadata.csv'\n",
243
+ "df = pd.read_csv(metadata_path)\n",
244
+ "df.columns\n",
245
+ "\n",
246
+ "def get_movie_title(movie_id):\n",
247
+ " \"\"\"Looks up a movie title by its ID in the DataFrame.\"\"\"\n",
248
+ "\n",
249
+ " row = df[df['id'] == movie_id]\n",
250
+ "\n",
251
+ " if not row.empty:\n",
252
+ " return row['title'].iloc[0] # Get the title from the first matching row\n",
253
+ " else:\n",
254
+ " return \"Movie not found\"\n",
255
+ " \n",
256
+ "print(get_movie_title(14))"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": []
265
+ },
266
+ {
267
+ "cell_type": "code",
268
+ "execution_count": 26,
269
+ "metadata": {},
270
+ "outputs": [
271
+ {
272
+ "name": "stdout",
273
+ "output_type": "stream",
274
+ "text": [
275
+ " user rec_movies\n",
276
+ "0 0 [14, 85, 101, 106, 111, 131, 132, 150, 210, 216]\n",
277
+ "1 1 [13, 45, 95, 108, 109, 126, 130, 132, 213, 220]\n",
278
+ "2 2 [562, 571, 894, 1013, 1169, 1289, 1378, 1405, ...\n",
279
+ "3 3 [126, 137, 502, 571, 616, 696, 811, 966, 999, ...\n",
280
+ "4 4 [364, 436, 493, 502, 509, 706, 781, 811, 1244,...\n",
281
+ "Index(['user', 'rec_movies'], dtype='object')\n"
282
+ ]
283
+ }
284
+ ],
285
+ "source": [
286
+ "\n",
287
+ "movie_recs_df = pd.DataFrame(movie_recs)\n",
288
+ "#movie_recs_df = movie_recs_df.set_index('id').join(df[['title']].set_index('id'), how='left')\n",
289
+ "print(movie_recs_df.head()) \n",
290
+ "print(movie_recs_df.columns) "
291
+ ]
292
+ }
293
+ ],
294
+ "metadata": {
295
+ "kernelspec": {
296
+ "display_name": ".venv",
297
+ "language": "python",
298
+ "name": "python3"
299
+ },
300
+ "language_info": {
301
+ "codemirror_mode": {
302
+ "name": "ipython",
303
+ "version": 3
304
+ },
305
+ "file_extension": ".py",
306
+ "mimetype": "text/x-python",
307
+ "name": "python",
308
+ "nbconvert_exporter": "python",
309
+ "pygments_lexer": "ipython3",
310
+ "version": "3.12.2"
311
+ }
312
+ },
313
+ "nbformat": 4,
314
+ "nbformat_minor": 2
315
+ }
PyGTrainedModelState.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d31578ebb232fd63763d00f19051db44f8841e7117b9203c2f316bcae91a5deb
3
+ size 307626
PyGdata.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:27cbee2403cf96a9942394e4ad78f229b52bb8ca8c2e16839b3083bcaef877a6
3
+ size 20380556
model_def.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ import torch
5
+ from torch_geometric.nn import SAGEConv, to_hetero, Linear
6
+ from dotenv import load_dotenv
7
+
8
# Graph object whose metadata() supplies the node/edge types used by Model below.
# The tensors inside the pickle may be tagged with a CUDA device; map_location
# re-targets them so the file also loads on hosts without a GPU, while still
# landing on the GPU when one is available.
data = torch.load(
    "./PyGdata.pt",
    map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
)
9
+
10
class GNNEncoder(torch.nn.Module):
    """Two stacked SAGEConv layers with a ReLU between them.

    Args:
        hidden_channels: output width of the first convolution.
        out_channels: output width of the second convolution.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Both layers are declared with (-1, -1) as their input size.
        input_size = (-1, -1)
        self.conv1 = SAGEConv(input_size, hidden_channels)
        self.conv2 = SAGEConv(input_size, out_channels)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 and return the result of conv2."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)
20
+
21
class EdgeDecoder(torch.nn.Module):
    """Two-layer MLP producing one scalar per (user, movie) pair.

    Args:
        hidden_channels: width of each node embedding; the first linear layer
            consumes two concatenated embeddings (``2 * hidden_channels``).
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Score every pair listed in *edge_label_index*.

        Args:
            z_dict: mapping with ``'user'`` and ``'movie'`` embedding tensors.
            edge_label_index: two index sequences; the first selects rows of
                ``z_dict['user']``, the second rows of ``z_dict['movie']``.

        Returns:
            A flattened (1-D) tensor with one value per pair.
        """
        user_idx, movie_idx = edge_label_index
        user_emb = z_dict['user'][user_idx]
        movie_emb = z_dict['movie'][movie_idx]
        # Concatenate the two embeddings along the feature axis.
        pair_features = torch.cat([user_emb, movie_emb], dim=-1)
        hidden = self.lin1(pair_features).relu()
        return self.lin2(hidden).view(-1)
33
+
34
class Model(torch.nn.Module):
    """Heterogeneous encoder (GNNEncoder via ``to_hetero``) plus EdgeDecoder.

    Args:
        hidden_channels: embedding width used by both encoder layers and the
            decoder.
        metadata: optional graph metadata (as returned by
            ``HeteroData.metadata()``) used to convert the encoder with
            ``to_hetero``. When omitted, the metadata of the module-level
            ``data`` graph is used, which is the original behaviour.
    """

    def __init__(self, hidden_channels, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible default: schema of the graph loaded at import.
            metadata = data.metadata()
        homogeneous_encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(homogeneous_encoder, metadata, aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Encode all nodes, then decode the pairs in *edge_label_index*."""
        # z_dict holds the per-node-type embeddings produced by the encoder.
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)
movie_embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e80e58c0f25e6ec7fbe54e9668f233c7ddc3083f268cb21a4b6917ac09332cee
3
+ size 13863625
movie_embeddings_concat.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74b32c1f6eb41a76cc7d392a1fc709b6e5c5002cf5c91c523f1906ca753319a0
3
+ size 14585708
requirements.txt ADDED
Binary file (3.35 kB). View file
 
visualizer.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import umap.umap_ as umap
2
+ import plotly.express as px
3
+ import pandas as pd
4
+ import random
5
+ import viz_utils
6
+ import torch
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch.nn import Linear
11
+ import torch_geometric.transforms as T
12
+ from torch_geometric.nn import SAGEConv, to_hetero
13
+ from torch_geometric.transforms import RandomLinkSplit, ToUndirected
14
+ from sentence_transformers import SentenceTransformer
15
+ from torch_geometric.data import HeteroData
16
+ import yaml
17
+
18
+
19
+
20
# Use the GPU when one is available; the graph and the model both go here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The tensors inside the pickle may be tagged with a CUDA device; map_location
# re-targets them so the file also loads on hosts without a GPU.
data = torch.load("./PyGdata.pt", map_location=device)

# low_memory=False parses each column in one pass, which avoids the
# mixed-type DtypeWarning pandas raises for this CSV.
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv", low_memory=False)
26
+
27
class GNNEncoder(torch.nn.Module):
    """Two stacked SAGEConv layers with a ReLU between them.

    Args:
        hidden_channels: output width of the first convolution.
        out_channels: output width of the second convolution.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Model passes this module through to_hetero, so these two layers are
        # replicated to match the number of edge types.
        input_size = (-1, -1)
        self.conv1 = SAGEConv(input_size, hidden_channels)
        self.conv2 = SAGEConv(input_size, out_channels)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 and return the result of conv2."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)
38
+
39
class EdgeDecoder(torch.nn.Module):
    """Two-layer MLP producing one scalar per (user, movie) pair.

    Args:
        hidden_channels: width of each node embedding; the first linear layer
            consumes two concatenated embeddings (``2 * hidden_channels``).
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Score every pair listed in *edge_label_index*.

        Args:
            z_dict: mapping with ``'user'`` and ``'movie'`` embedding tensors.
            edge_label_index: two index sequences; the first selects rows of
                ``z_dict['user']``, the second rows of ``z_dict['movie']``.

        Returns:
            A flattened (1-D) tensor with one value per pair.
        """
        user_idx, movie_idx = edge_label_index
        # Look up the user and movie embedding for every requested pair ...
        user_emb = z_dict['user'][user_idx]
        movie_emb = z_dict['movie'][movie_idx]
        # ... concatenate them along the feature axis and feed the MLP.
        pair_features = torch.cat([user_emb, movie_emb], dim=-1)
        hidden = self.lin1(pair_features).relu()
        return self.lin2(hidden).view(-1)
53
+
54
class Model(torch.nn.Module):
    """Heterogeneous encoder (GNNEncoder via ``to_hetero``) plus EdgeDecoder.

    Args:
        hidden_channels: embedding width used by both encoder layers and the
            decoder.
        metadata: optional graph metadata (as returned by
            ``HeteroData.metadata()``) used to convert the encoder with
            ``to_hetero``. When omitted, the metadata of the module-level
            ``data`` graph is used, which is the original behaviour.
    """

    def __init__(self, hidden_channels, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible default: schema of the graph loaded above.
            metadata = data.metadata()
        homogeneous_encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(homogeneous_encoder, metadata, aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Encode all nodes, then decode the pairs in *edge_label_index*."""
        # z_dict contains the movie and user embeddings returned by the encoder.
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)
65
+
66
# Build the model and restore the trained weights. map_location lets a
# checkpoint that was saved from a GPU be restored when `device` is the CPU.
model = Model(hidden_channels=32).to(device)
model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device))
model.eval()

total_users = data['user'].num_nodes
total_movies = data['movie'].num_nodes

print("total users =", total_users)
print("total movies =", total_movies)

# Only the encoder is run here: it returns a dict of embeddings per node type.
with torch.no_grad():
    a = model.encoder(data.x_dict, data.edge_index_dict)

# Convert to NumPy explicitly so pandas builds plain numeric columns instead of
# having to interpret a torch tensor itself.
user = pd.DataFrame(a['user'].detach().cpu().numpy())
movie = pd.DataFrame(a['movie'].detach().cpu().numpy())
# User rows are stacked first, then movie rows.
embedding_df = pd.concat([user, movie], axis=0)


movie_index = 20
title = movies_df.iloc[movie_index]['title']
print(title)


# Each helper returns a figure, which is then written as "<name>.html".
fig_umap = viz_utils.visualize_embeddings_umap(embedding_df)
viz_utils.save_visualization(fig_umap, "./Visualizations/umap_visualization")

fig_tsne = viz_utils.visualize_embeddings_tsne(embedding_df)
viz_utils.save_visualization(fig_tsne, "./Visualizations/tsne_visualization")

fig_pca = viz_utils.visualize_embeddings_pca(embedding_df)
viz_utils.save_visualization(fig_pca, "./Visualizations/pca_visualization")
viz_utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import umap.umap_ as umap
2
+ import plotly.express as px
3
+ import pandas as pd
4
+ import random
5
+ import numpy
6
+ from sklearn.manifold import TSNE
7
+ from sklearn.decomposition import PCA
8
+ import os
9
+
10
# Switch the process working directory to the folder containing this file so
# the relative data path below resolves no matter where Python was started.
# NOTE: this runs at import time and changes the cwd for the importing script too.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Movie metadata table, read once when the module is imported.
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")


##all_genres = movies_df['genres'].unique().tolist() # Adjust the column name if needed
# NOTE(review): this drops the first 671 rows of the metadata table. 671 is the
# user count used for colouring in this module, but movies_df is read from the
# movie metadata file only, so the offset may be skipping real movies -- TODO
# confirm. `genres` is not used elsewhere in this module.
genres = movies_df['genres'].tolist()[671:]
17
+
18
+
19
+
20
+ ##can't get to work for coloring by genre
21
def get_genre_for_movie(movie_index):
    """Return the raw ``genres`` cell of the row at position *movie_index*.

    The value is returned exactly as stored in ``movies_df``; no parsing is
    done here.
    """
    movie_row = movies_df.iloc[movie_index]
    return movie_row['genres']


print(get_genre_for_movie(20))
27
+
28
+
29
+
30
def visualize_embeddings_umap(embedding_df, n_neighbors=15, min_dist=0.1, n_components=3, n_users=671):
    """Reduce the embeddings with UMAP and return a 3-D Plotly scatter figure.

    Args:
        embedding_df: DataFrame with one embedding per row. Its ``.values`` are
            fed to UMAP and its index becomes the hover ``Label``.
        n_neighbors: forwarded to ``umap.UMAP``.
        min_dist: forwarded to ``umap.UMAP``.
        n_components: number of UMAP output dimensions. The figure plots the
            first three, so this must be at least 3.
        n_users: number of leading rows coloured 0; every later row is
            coloured 1. The default reproduces the previously hard-coded
            671 / 9025 split.

    Returns:
        The figure created by ``plotly.express.scatter_3d``.
    """
    umap_embedded = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=n_components,
        random_state=42,  # fixed seed for a repeatable layout
    ).fit_transform(embedding_df.values)

    # Name the columns from n_components so non-default values do not break
    # the DataFrame construction.
    dimension_names = [f'UMAP Dimension {i+1}' for i in range(n_components)]
    umap_df = pd.DataFrame(umap_embedded, columns=dimension_names)
    umap_df['Label'] = embedding_df.index

    # Colour the leading rows 0 and the rest 1, sized from the actual row
    # count so frames of any length work.
    n_rows = len(umap_df)
    n_leading = min(n_users, n_rows)
    umap_df['color'] = [0] * n_leading + [1] * (n_rows - n_leading)

    fig = px.scatter_3d(
        umap_df,
        x='UMAP Dimension 1',
        y='UMAP Dimension 2',
        z='UMAP Dimension 3',
        color='color',
        hover_data=['Label'],
        title='UMAP Visualization of Embeddings',
    )
    return fig
53
+
54
def visualize_embeddings_tsne(embedding_df, n_components=3, perplexity=30.0, early_exaggeration=12.0, learning_rate=200.0, n_users=671):
    """Reduce the embeddings with t-SNE and return a 3-D Plotly scatter figure.

    Args:
        embedding_df: DataFrame with one embedding per row. Its ``.values`` are
            fed to t-SNE and its index becomes the hover ``Label``.
        n_components: number of t-SNE output dimensions. The figure plots the
            first three, so this must be at least 3.
        perplexity: forwarded to ``sklearn.manifold.TSNE``.
        early_exaggeration: forwarded to ``sklearn.manifold.TSNE``.
        learning_rate: forwarded to ``sklearn.manifold.TSNE``.
        n_users: number of leading rows coloured 0; every later row is
            coloured 1. The default reproduces the previously hard-coded
            671 / 9025 split.

    Returns:
        The figure created by ``plotly.express.scatter_3d``.
    """
    tsne_embedded = TSNE(
        n_components=n_components,
        perplexity=perplexity,
        early_exaggeration=early_exaggeration,
        learning_rate=learning_rate,
        random_state=42,  # fixed seed for a repeatable layout
    ).fit_transform(embedding_df.values)

    dimension_names = [f't-SNE Dimension {i+1}' for i in range(n_components)]
    tsne_df = pd.DataFrame(tsne_embedded, columns=dimension_names)
    tsne_df['Label'] = embedding_df.index

    # Colour the leading rows 0 and the rest 1, sized from the actual row
    # count so frames of any length work.
    n_rows = len(tsne_df)
    n_leading = min(n_users, n_rows)
    tsne_df['color'] = [0] * n_leading + [1] * (n_rows - n_leading)

    fig = px.scatter_3d(
        tsne_df,
        x='t-SNE Dimension 1',
        y='t-SNE Dimension 2',
        z='t-SNE Dimension 3',
        color='color',
        hover_data=['Label'],
        title='t-SNE Visualization of Embeddings',
    )
    return fig
73
+
74
+
75
def visualize_embeddings_pca(embedding_df, n_components=3, n_users=671):
    """Reduce the embeddings with PCA and return a 3-D Plotly scatter figure.

    Args:
        embedding_df: DataFrame with one embedding per row. Its ``.values`` are
            fed to PCA and its index becomes the hover ``Label``.
        n_components: number of principal components to keep. The figure plots
            the first three, so this must be at least 3.
        n_users: number of leading rows coloured 0; every later row is
            coloured 1. The default reproduces the previously hard-coded
            671 / 9025 split.

    Returns:
        The figure created by ``plotly.express.scatter_3d``.
    """
    pca = PCA(n_components=n_components, random_state=42)
    pca_embedded = pca.fit_transform(embedding_df.values)

    dimension_names = [f'PCA Dimension {i+1}' for i in range(n_components)]
    pca_df = pd.DataFrame(pca_embedded, columns=dimension_names)
    pca_df['Label'] = embedding_df.index

    # Colour the leading rows 0 and the rest 1, sized from the actual row
    # count so frames of any length work.
    n_rows = len(pca_df)
    n_leading = min(n_users, n_rows)
    pca_df['color'] = [0] * n_leading + [1] * (n_rows - n_leading)

    fig = px.scatter_3d(
        pca_df,
        x='PCA Dimension 1',
        y='PCA Dimension 2',
        z='PCA Dimension 3',
        color='color',
        hover_data=['Label'],
        title='PCA Visualization of Embeddings',
    )
    return fig
89
+
90
+
91
+
92
+
93
def save_visualization(fig, filename):
    """Write *fig* to ``"<filename>.html"``, creating the folder if needed.

    Args:
        fig: object exposing ``write_html(path)`` (e.g. a Plotly figure).
        filename: target path without the ``.html`` suffix.
    """
    target = f"{filename}.html"
    parent = os.path.dirname(target)
    if parent:
        # The output folder (e.g. ./Visualizations) may not exist yet.
        os.makedirs(parent, exist_ok=True)
    fig.write_html(target)
95
+
96
+