import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import torch
import requests
import random
from io import BytesIO
from PIL import Image
from torch_geometric.nn import SAGEConv, to_hetero, Linear
from dotenv import load_dotenv
import os
from IPython.display import HTML
import viz_utils
import model_def

load_dotenv()  # load environment variables (e.g. HUGGINGFACE_API_KEY) from a .env file

# Relative paths resolve against the process's current working directory, which
# is wherever `streamlit run` was launched from. Switching to this script's own
# folder makes the relative data/model/visualization paths below resolve no
# matter where the app was started.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

API_KEY = os.getenv("HUGGINGFACE_API_KEY")
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
# Seconds to wait for the image-generation endpoint before giving up, so a
# stalled request cannot hang the app indefinitely.
REQUEST_TIMEOUT = 120
# Number of recommendations to show, and how many columns to lay them out in.
TOP_K = 5
NUM_COLUMNS = 3

# --- LOAD DATA AND MODEL ---
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so this loading (and the encoder pass below) repeats on each
# rerun; consider caching it (e.g. st.cache_resource) if the installed
# Streamlit version supports it.
# Movie metadata; row i is assumed to correspond to movie node i in the graph — TODO confirm.
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location remaps the saved tensors onto `device` (GPU if available, else
# CPU), so checkpoints saved on a GPU still load on CPU-only hosts.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
# refuse to unpickle the full PyG data object; if so, pass weights_only=False
# (acceptable only because these are trusted local files).
data = torch.load("./PyGdata.pt", map_location=device)
model = model_def.Model(hidden_channels=32).to(device)
model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device))
model.eval()


def get_movie_recommendations(model, data, user_id, total_movies, k=TOP_K):
    """Return the ``k`` movies the model scores highest for ``user_id``.

    Builds a candidate (user, movie) edge from ``user_id`` to every movie node
    ``0 .. total_movies - 1``, scores all of them with the model, and keeps the
    best ``k``.

    Args:
        model: Trained link-prediction model, called as
            ``model(x_dict, edge_index_dict, edge_label_index)``.
        data: Graph whose ``x_dict`` / ``edge_index_dict`` are fed to the model.
        user_id: Index of the user node to recommend for.
        total_movies: Number of movie nodes to score.
        k: How many recommendations to return; capped at ``total_movies``.

    Returns:
        The rows of ``movies_df`` at the top-scoring movie indices, best first.
    """
    # Pair the same user with every movie: row 0 = user ids, row 1 = movie ids.
    user_row = torch.full((total_movies,), user_id, dtype=torch.long, device=device)
    all_movie_ids = torch.arange(total_movies, device=device)
    edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)

    # Inference only: no_grad avoids building an autograd graph over every movie.
    with torch.no_grad():
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index).to('cpu')

    # topk raises if k exceeds the number of scores, so cap it.
    top_indices = pred.topk(min(k, total_movies)).indices
    # Plain ints keep pandas positional indexing independent of torch types.
    return movies_df.iloc[top_indices.tolist()]


def generate_poster(movie_title):
    """Generate a poster for ``movie_title`` via the Hugging Face Inference API and show it.

    The title is sent as the text-to-image prompt. On success the image is
    displayed with ``st.image`` (captioned with the title) and returned. If the
    request fails (HTTP error, connection error, timeout) or the response body
    cannot be decoded as an image, an error is shown in the app and ``None`` is
    returned, so one failed poster does not stop the others from rendering.

    Args:
        movie_title: Title used both as the prompt and as the image caption.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` if generation failed.
    """
    headers = {"Authorization": f"Bearer {API_KEY}"}
    # NOTE: no seed is sent, so posters are not deliberately varied between
    # refreshes; a "parameters": {"seed": random.randint(0, 2**32 - 1)} entry
    # could be added if the endpoint accepts it.
    payload = {"inputs": movie_title}
    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=REQUEST_TIMEOUT
        )
        response.raise_for_status()  # turn 4xx/5xx responses into exceptions
        image = Image.open(BytesIO(response.content))
    except (requests.exceptions.RequestException, OSError) as err:
        # RequestException covers HTTP errors, connection failures and timeouts;
        # OSError covers PIL's UnidentifiedImageError for non-image bodies.
        st.error(f"Image generation failed for {movie_title}: {err}")
        return None
    st.image(image, caption=movie_title)
    return image


# --- STREAMLIT APP ---
st.title("Movie Recommendation App")

# --- VISUALIZATIONS ---
# Pre-rendered embedding plots, keyed by the label shown in the app.
VISUALIZATION_FILES = {
    "UMAP Visualization": "./Visualizations/umap_visualization.html",
    "TSNE Visualization": "./Visualizations/tsne_visualization.html",
    "PCA Visualization": "./Visualizations/pca_visualization.html",
}
visualization_html = {}
for label, path in VISUALIZATION_FILES.items():
    # Explicit UTF-8 so decoding does not depend on the platform's default encoding.
    with open(path, "r", encoding='utf-8') as f:
        visualization_html[label] = f.read()

tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])

# Encoder node embeddings, users stacked above movies. Only consumed by the live
# viz_utils.visualize_embeddings_{umap,tsne,pca}(embedding_df) plots, which are
# currently replaced by the pre-rendered HTML above.
with torch.no_grad():
    a = model.encoder(data.x_dict, data.edge_index_dict)
user = pd.DataFrame(a['user'].detach().cpu())
movie = pd.DataFrame(a['movie'].detach().cpu())
embedding_df = pd.concat([user, movie], axis=0)

with tab1:
    for label, html in visualization_html.items():
        with st.expander(label):
            st.subheader(label)
            components.html(html, width=800, height=800)

with tab2:
    # Bound the input to existing user nodes so the model is never asked to
    # score an out-of-range user index.
    user_id = st.number_input(
        "Enter the User ID:",
        min_value=0,
        max_value=data['user'].num_nodes - 1,
        step=1,
    )
    if st.button("Get Recommendations"):
        st.write(f"Top {TOP_K} Recommendations:")
        try:
            total_movies = data['movie'].num_nodes
            recommended_movies = get_movie_recommendations(model, data, user_id, total_movies)
            cols = st.columns(NUM_COLUMNS)
            # Place posters by rank (0, 1, 2, ...): iterrows() yields movies_df
            # index labels, which would scatter posters across columns arbitrarily.
            # generate_poster reports its own request/decoding errors.
            for rank, (_, row) in enumerate(recommended_movies.iterrows()):
                with cols[rank % NUM_COLUMNS]:
                    generate_poster(row['title'])
        except Exception as e:
            st.error(f"An error occurred: {e}")