AJ-Gazin commited on
Commit
0f19412
·
1 Parent(s): e4f5e4c

Still figuring out this CPU thing.

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. model_def.py +1 -1
  3. visualizer.py +2 -2
app.py CHANGED
@@ -30,9 +30,9 @@ movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv") # Load y
30
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
 
32
 
33
- data = torch.load("./PyGdata.pt", map_location=device('cpu'))
34
  model = model_def.Model(hidden_channels=32).to(device)
35
- model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device('cpu')))
36
  model.eval()
37
 
38
  # --- STREAMLIT APP ---
 
30
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
 
32
 
33
+ data = torch.load("./PyGdata.pt", map_location=device)
34
  model = model_def.Model(hidden_channels=32).to(device)
35
+ model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device))
36
  model.eval()
37
 
38
  # --- STREAMLIT APP ---
model_def.py CHANGED
@@ -6,7 +6,7 @@ from torch_geometric.nn import SAGEConv, to_hetero, Linear
6
  from dotenv import load_dotenv
7
 
8
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
- data = torch.load("./PyGdata.pt", map_location=device('cpu'))
10
 
11
  class GNNEncoder(torch.nn.Module):
12
  def __init__(self, hidden_channels, out_channels):
 
6
  from dotenv import load_dotenv
7
 
8
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+ data = torch.load("./PyGdata.pt", map_location=device)
10
 
11
  class GNNEncoder(torch.nn.Module):
12
  def __init__(self, hidden_channels, out_channels):
visualizer.py CHANGED
@@ -19,7 +19,7 @@ import yaml
19
 
20
 
21
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
- data = torch.load("./PyGdata.pt", map_location=device('cpu'))
23
 
24
 
25
  movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")
@@ -65,7 +65,7 @@ class Model(torch.nn.Module):
65
 
66
  model = Model(hidden_channels=32).to(device)
67
  model2 = Model(hidden_channels=32).to(device)
68
- model.load_state_dict(torch.load("PyGTrainedModelState.pt"), map_location=device('cpu'))
69
  model.eval()
70
 
71
  total_users = data['user'].num_nodes
 
19
 
20
 
21
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
+ data = torch.load("./PyGdata.pt", map_location=device)
23
 
24
 
25
  movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")
 
65
 
66
  model = Model(hidden_channels=32).to(device)
67
  model2 = Model(hidden_channels=32).to(device)
68
+ model.load_state_dict(torch.load("PyGTrainedModelState.pt"), map_location=device)
69
  model.eval()
70
 
71
  total_users = data['user'].num_nodes