Spaces:
Sleeping
Sleeping
AJ-Gazin
Merge branch 'main' of https://huggingface.co/spaces/sohvren/MovieRecommenderV2
45cab06
import streamlit as st | |
import streamlit.components.v1 as components | |
import pandas as pd | |
import torch | |
import requests | |
import random | |
from io import BytesIO | |
from PIL import Image | |
from torch_geometric.nn import SAGEConv, to_hetero, Linear | |
from dotenv import load_dotenv | |
import os | |
from IPython.display import HTML | |
import viz_utils | |
import model_def | |
# Read variables from a local .env file into the process environment.
load_dotenv()

# Make the script's own directory the working directory so the relative
# data/model paths below resolve (the author needed this locally to see
# the subfolders).
_script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(_script_dir)

# Indexing os.environ (instead of os.getenv) raises KeyError right away
# when the key is not set.
API_KEY = os.environ["HUGGINGFACE_API_KEY"]
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"

# --- LOAD DATA AND MODEL ---
# Movie metadata table used to look up the recommended titles.
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")

# map_location remaps the stored tensors onto whichever device is selected
# here: CUDA when available, otherwise the CPU (the Hugging Face case).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.load("./PyGdata.pt", map_location=device)
_trained_state = torch.load("PyGTrainedModelState.pt", map_location=device)

model = model_def.Model(hidden_channels=32).to(device)
model.load_state_dict(_trained_state)
model.eval()
# --- STREAMLIT APP ---
st.title("Movie Recommendation App")


# --- VISUALIZATIONS ---
def _read_html(path):
    """Return the full text of the HTML file at *path*, decoded as UTF-8.

    All three visualization files go through this one reader so they are
    decoded the same way; previously only the UMAP file specified an
    encoding and the other two fell back to the platform default.
    """
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


umap_html = _read_html("./Visualizations/umap_visualization.html")
tsne_html = _read_html("./Visualizations/tsne_visualization.html")
pca_html = _read_html("./Visualizations/pca_visualization.html")
# The app has two tabs: the embedding plots and the recommender.
tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])

# Run the encoder once over the whole graph; no gradients are needed here.
with torch.no_grad():
    a = model.encoder(data.x_dict, data.edge_index_dict)

# One DataFrame per node type, then stacked row-wise (users first, movies after).
user, movie = (
    pd.DataFrame(a[node_type].detach().cpu()) for node_type in ("user", "movie")
)
embedding_df = pd.concat([user, movie], axis=0)
with tab1:
    # Each section embeds a pre-rendered HTML document rather than computing
    # the figure at runtime with viz_utils.visualize_embeddings_* and
    # st.plotly_chart.
    _sections = (
        ("UMAP Visualization", umap_html),
        ("TSNE Visualization", tsne_html),
        ("PCA Visualization", pca_html),
    )
    for _section_title, _section_html in _sections:
        with st.expander(_section_title):
            st.subheader(_section_title)
            components.html(_section_html, width=800, height=800)
def get_movie_recommendations(model, data, user_id, total_movies, top_k=5):
    """Return the ``movies_df`` rows the model scores highest for one user.

    Pairs ``user_id`` with every movie index in ``range(total_movies)``,
    scores all pairs in one forward pass and keeps the best ``top_k``.

    Args:
        model: Callable invoked as
            ``model(x_dict, edge_index_dict, edge_label_index)``; expected to
            return one score per candidate pair.
        data: Graph object exposing ``x_dict`` and ``edge_index_dict``.
        user_id: Index of the user to recommend for.
        total_movies: Number of candidate movie indices to score.
        top_k: Maximum number of rows to return. Defaults to 5, the value
            that used to be hard-coded.

    Returns:
        A selection of rows from the module-level ``movies_df``, highest
        score first. Fewer than ``top_k`` rows come back when fewer
        candidates exist.
        NOTE(review): assumes row ``i`` of ``movies_df`` describes movie node
        ``i`` -- confirm against how the graph was built.
    """
    if total_movies <= 0:
        # Nothing to score; topk would raise on an empty prediction vector.
        return movies_df.iloc[[]]

    # Row 0 repeats the user index, row 1 enumerates every movie index.
    user_row = torch.full(
        (total_movies,), int(user_id), dtype=torch.long, device=device
    )
    all_movie_ids = torch.arange(total_movies, device=device)
    edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index).to('cpu')

    # topk raises when k exceeds the size of the dimension it searches.
    k = max(0, min(top_k, pred.shape[-1]))
    # Plain Python ints so positional indexing does not depend on pandas
    # coercing a torch tensor.
    top_indices = pred.topk(k).indices.tolist()
    return movies_df.iloc[top_indices]
def generate_poster(movie_title, timeout=120):
    """Request an image for *movie_title* and render it in the app.

    Posts the title as the ``inputs`` prompt to ``API_URL`` (authorised with
    ``API_KEY``) and displays the returned image through ``st.image`` with
    the title as caption. Any failure of the HTTP request is reported with
    ``st.error`` instead of being raised.

    Args:
        movie_title: Prompt text for the image; also used as the caption.
        timeout: Seconds to wait for the endpoint before giving up, so a
            stalled request cannot block the app indefinitely.

    Returns:
        None. The image is drawn as a side effect.
    """
    headers = {"Authorization": f"Bearer {API_KEY}"}
    # NOTE: a random per-request "seed" under "parameters" (so the poster
    # changes on refresh even for the same title) is currently disabled, so
    # no seed is generated or sent.
    payload = {"inputs": movie_title}

    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()  # Turn 4xx/5xx answers into HTTPError.
    except requests.exceptions.RequestException as err:
        # RequestException also covers connection errors and timeouts, not
        # only the HTTPError raised by raise_for_status().
        st.error(f"Image generation failed: {err}")
        return

    # Display the generated image.
    image = Image.open(BytesIO(response.content))
    st.image(image, caption=movie_title)
with tab2:
    user_id = st.number_input("Enter the User ID:", min_value=0)
    if st.button("Get Recommendations"):
        st.write("Top 5 Recommendations:")
        try:
            total_movies = data['movie'].num_nodes
            recommended_movies = get_movie_recommendations(model, data, user_id, total_movies)
            cols = st.columns(3)
            # iterrows() yields each row's index label, which is unrelated to
            # its position in the result. Use enumerate() so the posters are
            # dealt across the columns in display order.
            for position, (_, row) in enumerate(recommended_movies.iterrows()):
                with cols[position % len(cols)]:
                    try:
                        generate_poster(row['title'])
                    except requests.exceptions.RequestException as err:
                        # Report the failure for this title and carry on with
                        # the remaining posters.
                        st.error(f"Image generation failed for {row['title']}: {err}")
        except Exception as e:
            # Last-resort boundary: show the problem in the UI instead of a
            # raw traceback.
            st.error(f"An error occurred: {e}")