Spaces:
Sleeping
Sleeping
AJ-Gazin
commited on
Commit
·
0f19412
1
Parent(s):
e4f5e4c
Still figuring out this CPU thing.
Browse files- app.py +2 -2
- model_def.py +1 -1
- visualizer.py +2 -2
app.py
CHANGED
@@ -30,9 +30,9 @@ movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv") # Load y
|
|
30 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
31 |
|
32 |
|
33 |
-
data = torch.load("./PyGdata.pt", map_location=device
|
34 |
model = model_def.Model(hidden_channels=32).to(device)
|
35 |
-
model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device
|
36 |
model.eval()
|
37 |
|
38 |
# --- STREAMLIT APP ---
|
|
|
30 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
31 |
|
32 |
|
33 |
+
data = torch.load("./PyGdata.pt", map_location=device)
|
34 |
model = model_def.Model(hidden_channels=32).to(device)
|
35 |
+
model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device))
|
36 |
model.eval()
|
37 |
|
38 |
# --- STREAMLIT APP ---
|
model_def.py
CHANGED
@@ -6,7 +6,7 @@ from torch_geometric.nn import SAGEConv, to_hetero, Linear
|
|
6 |
from dotenv import load_dotenv
|
7 |
|
8 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
9 |
-
data = torch.load("./PyGdata.pt", map_location=device
|
10 |
|
11 |
class GNNEncoder(torch.nn.Module):
|
12 |
def __init__(self, hidden_channels, out_channels):
|
|
|
6 |
from dotenv import load_dotenv
|
7 |
|
8 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
9 |
+
data = torch.load("./PyGdata.pt", map_location=device)
|
10 |
|
11 |
class GNNEncoder(torch.nn.Module):
|
12 |
def __init__(self, hidden_channels, out_channels):
|
visualizer.py
CHANGED
@@ -19,7 +19,7 @@ import yaml
|
|
19 |
|
20 |
|
21 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
22 |
-
data = torch.load("./PyGdata.pt", map_location=device
|
23 |
|
24 |
|
25 |
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")
|
@@ -65,7 +65,7 @@ class Model(torch.nn.Module):
|
|
65 |
|
66 |
model = Model(hidden_channels=32).to(device)
|
67 |
model2 = Model(hidden_channels=32).to(device)
|
68 |
-
model.load_state_dict(torch.load("PyGTrainedModelState.pt"), map_location=device
|
69 |
model.eval()
|
70 |
|
71 |
total_users = data['user'].num_nodes
|
|
|
19 |
|
20 |
|
21 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
22 |
+
data = torch.load("./PyGdata.pt", map_location=device)
|
23 |
|
24 |
|
25 |
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")
|
|
|
65 |
|
66 |
model = Model(hidden_channels=32).to(device)
|
67 |
model2 = Model(hidden_channels=32).to(device)
|
68 |
+
model.load_state_dict(torch.load("PyGTrainedModelState.pt"), map_location=device)
|
69 |
model.eval()
|
70 |
|
71 |
total_users = data['user'].num_nodes
|