File size: 4,815 Bytes
960b542
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56eea6d
 
960b542
 
 
f709b0a
960b542
 
d134ba2
 
0f19412
960b542
0f19412
960b542
 
 
 
 
 
 
 
8138b99
 
960b542
8138b99
 
960b542
8138b99
 
960b542
 
 
 
 
 
 
 
 
 
 
 
 
 
8138b99
 
 
960b542
 
 
 
8138b99
 
 
960b542
 
 
 
8138b99
 
 
960b542
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142

import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import torch
import requests
import random
from io import BytesIO
from PIL import Image
from torch_geometric.nn import SAGEConv, to_hetero, Linear
from dotenv import load_dotenv
import os

from IPython.display import HTML

import viz_utils
import model_def

load_dotenv()  # Load environment variables from a .env file, if one is present.

# The data, model and visualization files below are opened with relative
# paths, so make this script's own directory the working directory first.
# (Original author's note: without this the subfolders were not found on
# their laptop; the underlying reason is unknown.)
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Indexing os.environ (instead of os.getenv) raises KeyError right here when
# the variable is missing, rather than carrying None into the request header.
API_KEY =  os.environ["HUGGINGFACE_API_KEY"]
# Hugging Face Inference API endpoint for Stable Diffusion XL (base 1.0).
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"

# --- LOAD DATA AND MODEL ---
# Movie metadata; recommendations are looked up in this frame by row position.
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")
# Use the GPU when one is available, otherwise the CPU. Passing this as
# map_location below remaps the saved tensors onto that device, so the files
# also load on CPU-only hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Saved graph object; its x_dict / edge_index_dict are fed to the model below.
data = torch.load("./PyGdata.pt", map_location=device)
# hidden_channels presumably has to match the value used when the state dict
# was trained -- TODO confirm against the training script.
model = model_def.Model(hidden_channels=32).to(device) 
model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device))
model.eval()  # Inference only: put the model in evaluation mode.

# --- STREAMLIT APP ---
st.title("Movie Recommendation App")



# --- VISUALIZATIONS ---
# Pre-rendered embedding plots stored as standalone HTML documents.
# Every file is decoded as UTF-8 explicitly: without an encoding argument
# open() uses the platform default (e.g. cp1252 on Windows), which can fail
# or mis-decode non-ASCII characters in the markup.
# NOTE(review): assumes the t-SNE and PCA files were exported the same way as
# the UMAP file, which was already being read as UTF-8.
with open("./Visualizations/umap_visualization.html", "r", encoding="utf-8") as f:
    umap_html = f.read()

with open("./Visualizations/tsne_visualization.html", "r", encoding="utf-8") as f:
    tsne_html = f.read()

with open("./Visualizations/pca_visualization.html", "r", encoding="utf-8") as f:
    pca_html = f.read()

# Two top-level tabs: the static embedding plots and the recommender UI.
tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])

# Run the encoder once over the whole graph with gradients disabled and
# collect the user and movie node embeddings as CPU-backed DataFrames.
with torch.no_grad():
    a = model.encoder(data.x_dict, data.edge_index_dict)
    user, movie = (
        pd.DataFrame(a[node_type].detach().cpu())
        for node_type in ("user", "movie")
    )
    # Stack row-wise: user embeddings first, movie embeddings after.
    embedding_df = pd.concat([user, movie], axis=0)

with tab1:
    # One collapsible section per projection, each embedding the pre-rendered
    # HTML plot loaded earlier (the plots are not recomputed here).
    for heading, markup in (
        ("UMAP Visualization", umap_html),
        ("TSNE Visualization", tsne_html),
        ("PCA Visualization", pca_html),
    ):
        with st.expander(heading):
            st.subheader(heading)
            components.html(markup, width=800, height=800)




def get_movie_recommendations(model, data, user_id, total_movies):
    """Return the ``movies_df`` rows of the movies scored highest for a user.

    Builds one candidate (user, movie) edge per movie, scores all of them
    with ``model`` and keeps the best-scoring ones.

    Args:
        model: Callable as ``model(x_dict, edge_index_dict, edge_label_index)``;
            presumably returns a 1-D tensor with one score per candidate edge.
        data: Graph object providing ``x_dict`` and ``edge_index_dict``.
        user_id: Index of the user node to recommend for.
        total_movies: Number of movie nodes; ids ``0 .. total_movies - 1``
            are scored.

    Returns:
        A DataFrame with up to five rows of ``movies_df``, best score first
        (fewer than five only when ``total_movies`` is below five).
    """
    all_movie_ids = torch.arange(total_movies, device=device)
    # Same user id repeated once per movie, created directly on the device
    # instead of going through a Python list of length total_movies.
    user_row = torch.full_like(all_movie_ids, int(user_id))
    edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index).to('cpu')

    # topk raises when k exceeds the number of scores, so cap it.
    top_k = min(5, total_movies)
    top_indices = pred.topk(top_k).indices.tolist()

    # NOTE(review): assumes movie node ids coincide with row positions in
    # movies_df -- confirm against the code that built the graph.
    return movies_df.iloc[top_indices]

def generate_poster(movie_title, timeout=120):
    """Request a generated poster for a title and render it in the app.

    Posts ``movie_title`` as the prompt to ``API_URL`` and shows the returned
    image with the title as its caption. Request and decoding failures are
    reported through ``st.error`` instead of being raised.

    Args:
        movie_title: Prompt text sent to the API; also used as the caption.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded PIL image, or ``None`` when the request or the decoding
        failed.
    """
    headers = {"Authorization": f"Bearer {API_KEY}"}

    # Only the prompt is sent. A random seed used to be computed here, but the
    # "parameters" entry that would have carried it was disabled, so the
    # unused value has been dropped.
    payload = {"inputs": movie_title}

    try:
        # Without a timeout a stalled connection would hang the app forever.
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()  # Raise an error if the request fails
    except requests.exceptions.RequestException as err:
        # RequestException also covers HTTPError, connection errors and
        # timeouts.
        st.error(f"Image generation failed: {err}")
        return None

    try:
        image = Image.open(BytesIO(response.content))
    except OSError as err:
        # Raised when the response body is not a decodable image.
        st.error(f"Image generation failed: {err}")
        return None

    # Display the generated image
    st.image(image, caption=movie_title)
    return image

with tab2:
    user_id = st.number_input("Enter the User ID:", min_value=0)
    if st.button("Get Recommendations"):
        st.write("Top 5 Recommendations:")
        try:
            total_movies = data['movie'].num_nodes
            recommended_movies = get_movie_recommendations(model, data, user_id, total_movies)
            cols = st.columns(3)

            # Lay posters out by their rank in the result (0, 1, 2, ...).
            # The DataFrame index label is the movie's row label, so using it
            # for the column choice would scatter posters arbitrarily.
            for position, title in enumerate(recommended_movies['title']):
                with cols[position % 3]:
                    try:
                        generate_poster(title)
                    except requests.exceptions.RequestException as err:
                        # Keep going: one failed poster must not abort the
                        # remaining ones.
                        st.error(f"Image generation failed for {title}: {err}")

        except Exception as e:
            # Top-level UI boundary: show the failure instead of a traceback.
            st.error(f"An error occurred: {e}")