AJ-Gazin
Saved viz's to HTML files, directly loading now
8138b99
raw
history blame
4.77 kB
import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import torch
import requests
import random
from io import BytesIO
from PIL import Image
from torch_geometric.nn import SAGEConv, to_hetero, Linear
from dotenv import load_dotenv
import os
from IPython.display import HTML
import viz_utils
import model_def
# Pull configuration such as HUGGINGFACE_API_KEY from a local .env file into
# the process environment.
load_dotenv()

# Every data/model/visualization path used below is relative, so anchor the
# working directory to the folder containing this script; otherwise those
# paths resolve against whatever directory the app was launched from.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Hugging Face Inference API credentials and the text-to-image endpoint used
# for poster generation.
API_KEY = os.getenv("HUGGINGFACE_API_KEY")
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"

# --- LOAD DATA AND MODEL ---
# Movie metadata table that recommendation indices are looked up in.
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")

# Use the GPU when one is available, otherwise the CPU. `map_location` below
# remaps the saved tensors onto this device, so checkpoints written on a
# different device still load here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pickled graph data object consumed by the model (x_dict / edge_index_dict).
# NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
# which refuses arbitrary pickled objects; if this load fails with an
# UnpicklingError, pass weights_only=False (only for files you trust).
data = torch.load("./PyGdata.pt", map_location=device)

# Rebuild the architecture, move it to the target device, then restore the
# trained weights into it.
model = model_def.Model(hidden_channels=32).to(device)
model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device))

# Switch layers such as dropout to inference behaviour.
model.eval()
# --- STREAMLIT APP ---
st.title("Movie Recommendation App")

# --- VISUALIZATIONS ---
# The embedding projections are pre-rendered to standalone HTML files and
# embedded as-is. Each file is decoded explicitly as UTF-8 so reading does not
# depend on the platform's default text encoding.
with open("./Visualizations/umap_visualization.html", "r", encoding="utf-8") as f:
    umap_html = f.read()
with open("./Visualizations/tsne_visualization.html", "r", encoding="utf-8") as f:
    tsne_html = f.read()
with open("./Visualizations/pca_visualization.html", "r", encoding="utf-8") as f:
    pca_html = f.read()

tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])

# Encoder output per node type, stacked into one DataFrame (user rows first,
# then movie rows).
# NOTE(review): embedding_df is only needed to regenerate the plots live via
# viz_utils.visualize_embeddings_{umap,tsne,pca}(embedding_df) + st.plotly_chart;
# with the pre-rendered HTML above this forward pass produces no visible output
# on each rerun — consider removing or caching it.
with torch.no_grad():
    a = model.encoder(data.x_dict, data.edge_index_dict)
    # Hand pandas a plain ndarray so the columns get a numeric dtype.
    user = pd.DataFrame(a["user"].detach().cpu().numpy())
    movie = pd.DataFrame(a["movie"].detach().cpu().numpy())
embedding_df = pd.concat([user, movie], axis=0)

with tab1:
    # One collapsible panel per projection; label doubles as the heading.
    for _viz_label, _viz_html in (
        ("UMAP Visualization", umap_html),
        ("TSNE Visualization", tsne_html),
        ("PCA Visualization", pca_html),
    ):
        with st.expander(_viz_label):
            st.subheader(_viz_label)
            components.html(_viz_html, width=800, height=800)
def get_movie_recommendations(model, data, user_id, total_movies, k=5):
    """Return the ``k`` movies the model scores highest for ``user_id``.

    Builds one candidate (user_id, movie) edge for every movie index in
    ``range(total_movies)``, scores them all in a single forward pass, and
    selects the best-scoring rows of the module-level ``movies_df``.
    NOTE(review): assumes row i of movies_df corresponds to movie node i in
    ``data`` and that the model returns one score per candidate edge (1-D) —
    confirm.

    Args:
        model: Trained link-prediction model called as
            ``model(x_dict, edge_index_dict, edge_label_index)``.
        data: Graph object providing ``x_dict`` and ``edge_index_dict``.
        user_id: Index of the user node to recommend for.
        total_movies: Number of movie nodes to score.
        k: How many recommendations to return (default 5). Clamped to the
            number of scored movies, so small catalogues do not raise.

    Returns:
        pandas.DataFrame with up to ``k`` rows of ``movies_df``, ordered by
        descending predicted score.
    """
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        # Integer indices are required for node lookups, even if the caller
        # passed a float (e.g. from a numeric input widget).
        user_row = torch.full(
            (total_movies,), int(user_id), dtype=torch.long, device=device
        )
        all_movie_ids = torch.arange(total_movies, device=device)
        edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index).cpu()

    top_indices = pred.topk(min(k, pred.size(-1))).indices
    return movies_df.iloc[top_indices.tolist()]
def generate_poster(movie_title, timeout=60):
    """Generate a poster image for ``movie_title`` and render it in the app.

    Sends the title as the text prompt to the Hugging Face Inference API
    endpoint ``API_URL`` and displays the returned image with ``st.image``.
    Any request or decoding failure is reported with ``st.error`` instead of
    being raised, so one failed poster does not abort the rest of the page.

    Args:
        movie_title: Prompt text; also used as the image caption.
        timeout: Seconds to wait for the API before giving up. Without a
            timeout ``requests`` can block the Streamlit script indefinitely.

    Returns:
        The decoded ``PIL.Image.Image`` on success, otherwise ``None``.
    """
    # Only authenticate when a key is configured; sending "Bearer None" is
    # never a valid credential.
    headers = {"Authorization": f"Bearer {API_KEY}"} if API_KEY else {}
    # NOTE: a random per-call seed (to vary posters across refreshes) was
    # prototyped but never sent; the request carries only the prompt.
    payload = {"inputs": movie_title}
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()  # turn 4xx/5xx responses into HTTPError
        image = Image.open(BytesIO(response.content))
    except (requests.exceptions.RequestException, OSError) as err:
        # RequestException covers HTTP errors, timeouts and connection
        # failures; OSError covers PIL.UnidentifiedImageError when the body is
        # not an image (e.g. a JSON error/status payload).
        st.error(f"Image generation failed: {err}")
        return None
    st.image(image, caption=movie_title)
    return image
with tab2:
    # Bound the input by the number of user nodes so the model is never asked
    # to score a user index that does not exist in the graph.
    num_users = data["user"].num_nodes
    user_id = st.number_input(
        "Enter the User ID:",
        min_value=0,
        max_value=num_users - 1 if num_users else None,
    )
    if st.button("Get Recommendations"):
        st.write("Top 5 Recommendations:")
        try:
            total_movies = data['movie'].num_nodes
            recommended_movies = get_movie_recommendations(model, data, user_id, total_movies)
            cols = st.columns(3)
            # iterrows() yields movies_df index labels, which are arbitrary
            # after the top-k selection; use the rank (0, 1, 2, ...) instead so
            # posters fill the three columns left to right.
            for rank, (_, row) in enumerate(recommended_movies.iterrows()):
                with cols[rank % 3]:
                    try:
                        generate_poster(row['title'])
                    except requests.exceptions.RequestException as err:
                        # Keep going: one failed poster should not hide the rest.
                        st.error(f"Image generation failed for {row['title']}: {err}")
        except Exception as e:
            # Top-level UI boundary: surface any failure instead of crashing the page.
            st.error(f"An error occurred: {e}")