# AJ-Gazin
# Merge branch 'main' of https://huggingface.co/spaces/sohvren/MovieRecommenderV2
# 45cab06
import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import torch
import requests
import random
from io import BytesIO
from PIL import Image
from torch_geometric.nn import SAGEConv, to_hetero, Linear
from dotenv import load_dotenv
import os
from IPython.display import HTML
import viz_utils
import model_def
# Populate os.environ from a local .env file (if one exists) before the API
# key is read below.
load_dotenv() #load environment variables from .env file
# Make this file's directory the working directory so the relative paths used
# in this script (dataset CSV, graph data, model weights, ./Visualizations)
# resolve regardless of where the app is launched from. Originally added
# because subfolders were not found on the author's laptop without it.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# Indexing os.environ (rather than calling os.getenv) raises KeyError at
# startup when HUGGINGFACE_API_KEY is not set, instead of continuing with
# API_KEY = None.
API_KEY = os.environ["HUGGINGFACE_API_KEY"]
# Inference endpoint that poster-generation requests are posted to.
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
# --- LOAD DATA AND MODEL ---
# map_location remaps the stored tensors onto `device` (CUDA when available,
# otherwise CPU), so files saved on a GPU machine still load on a CPU-only
# host such as a Hugging Face Space.
# NOTE(review): torch.load unpickles the file contents; only load trusted files.
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv") # Load your movie data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = torch.load("./PyGdata.pt", map_location=device)
# NOTE(review): hidden_channels presumably has to match the value the saved
# state dict was trained with -- confirm against the training code.
model = model_def.Model(hidden_channels=32).to(device)
model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device))
# The app only runs inference, so switch the model to evaluation mode.
model.eval()
# --- STREAMLIT APP ---
st.title("Movie Recommendation App")


# --- VISUALIZATIONS ---
def _read_html(path):
    """Return the full text of the HTML file at *path*, decoded as UTF-8."""
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, which is not UTF-8 everywhere (e.g. on Windows) and could
    # otherwise fail or garble the embedded page.
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


# Pre-rendered embedding plots, embedded as raw HTML in the first tab.
umap_html = _read_html("./Visualizations/umap_visualization.html")
tsne_html = _read_html("./Visualizations/tsne_visualization.html")
pca_html = _read_html("./Visualizations/pca_visualization.html")

tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])
# Run the encoder once without gradient tracking and gather the resulting
# node embeddings into one frame: user rows first, movie rows after.
with torch.no_grad():
    a = model.encoder(data.x_dict, data.edge_index_dict)
    user, movie = (
        pd.DataFrame(a[node_type].detach().cpu())
        for node_type in ('user', 'movie')
    )
    embedding_df = pd.concat([user, movie], axis=0)
# First tab: one collapsible section per projection, each embedding the HTML
# that was read from ./Visualizations.
# Alternative (currently disabled): build the figures live with
# viz_utils.visualize_embeddings_umap / _tsne / _pca(embedding_df) and render
# them with st.plotly_chart.
with tab1:
    for section_title, section_html in (
        ("UMAP Visualization", umap_html),
        ("TSNE Visualization", tsne_html),
        ("PCA Visualization", pca_html),
    ):
        # The same text serves as the expander label and the heading inside it.
        with st.expander(section_title):
            st.subheader(section_title)
            components.html(section_html, width=800, height=800)
def get_movie_recommendations(model, data, user_id, total_movies):
    """Return the rows of ``movies_df`` for the user's highest-scored movies.

    Builds one candidate (user, movie) pair for every movie id in
    ``range(total_movies)``, scores all pairs with ``model`` and selects the
    (up to) five best-scoring ones.

    Args:
        model: Callable invoked as
            ``model(x_dict, edge_index_dict, edge_label_index)``.
        data: Graph object exposing ``x_dict`` and ``edge_index_dict``.
        user_id: Index of the user node to recommend for.
        total_movies: Number of candidate movies to score.

    Returns:
        A DataFrame with the selected rows of the module-level ``movies_df``,
        best score first; fewer than five rows when fewer movies exist.

    Note:
        Assumes the model returns a 1-D tensor with one score per candidate
        pair and that ``movies_df`` row positions line up with movie node
        ids -- TODO confirm against model_def and the data preparation code.
    """
    # Build the index tensors directly on the target device.
    user_row = torch.full(
        (total_movies,), int(user_id), dtype=torch.long, device=device
    )
    all_movie_ids = torch.arange(total_movies, device=device)
    edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index).to('cpu')

    # topk raises when k exceeds the number of scores, so cap it.
    k = min(5, total_movies)
    # Plain Python ints are an unambiguous positional indexer for iloc.
    top_indices = pred.topk(k).indices.tolist()
    return movies_df.iloc[top_indices]
def generate_poster(movie_title, *, timeout=120):
    """Generate a poster for *movie_title* and render it in the Streamlit page.

    Posts the title as the prompt to the inference endpoint at ``API_URL``
    (authenticated with ``API_KEY``), decodes the response body as an image
    and displays it with the title as caption. Failures are reported in the
    page through ``st.error`` instead of being raised.

    Args:
        movie_title: Prompt text sent to the model; also used as the caption.
        timeout: Seconds to wait for the endpoint before giving up.

    Returns:
        None. The image, or an error message, is written to the page.
    """
    headers = {"Authorization": f"Bearer {API_KEY}"}
    # NOTE: sending a random "seed" parameter (to vary the poster between
    # refreshes for the same title) is currently disabled, so only the prompt
    # is sent.
    payload = {"inputs": movie_title}
    try:
        # Bound the wait so an unresponsive endpoint cannot hang the app.
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()  # Raise an error if the request fails
        image = Image.open(BytesIO(response.content))
    except (requests.exceptions.RequestException, OSError) as err:
        # RequestException covers HTTP errors, connection failures and
        # timeouts; OSError covers a response body that cannot be opened as
        # an image.
        st.error(f"Image generation failed: {err}")
        return
    # Display the generated image
    st.image(image, caption=movie_title)
# Second tab: ask for a user id and show a generated poster for each of that
# user's recommended movies, laid out across three columns.
with tab2:
    user_id = st.number_input("Enter the User ID:", min_value=0)
    if st.button("Get Recommendations"):
        st.write("Top 5 Recommendations:")
        try:
            total_movies = data['movie'].num_nodes
            recommended_movies = get_movie_recommendations(model, data, user_id, total_movies)
            cols = st.columns(3)
            # iterrows() yields the DataFrame index label, not the display
            # position, so count positions separately to fill the columns
            # left to right in ranking order.
            for position, (_, row) in enumerate(recommended_movies.iterrows()):
                with cols[position % 3]:
                    try:
                        generate_poster(row['title'])
                    except requests.exceptions.RequestException as err:
                        # Report the failure in this column and carry on with
                        # the remaining posters.
                        st.error(f"Image generation failed for {row['title']}: {err}")
        except Exception as e:
            # Top-level UI boundary: show any other failure (for example a
            # user id the model cannot score) instead of a stack trace.
            st.error(f"An error occurred: {e}")