Spaces:
Sleeping
Sleeping
File size: 4,815 Bytes
960b542 56eea6d 960b542 f709b0a 960b542 d134ba2 0f19412 960b542 0f19412 960b542 8138b99 960b542 8138b99 960b542 8138b99 960b542 8138b99 960b542 8138b99 960b542 8138b99 960b542 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import torch
import requests
import random
from io import BytesIO
from PIL import Image
from torch_geometric.nn import SAGEConv, to_hetero, Linear
from dotenv import load_dotenv
import os
from IPython.display import HTML
import viz_utils
import model_def
load_dotenv()  # load environment variables from a local .env file, if present

# The relative paths below (CSV, .pt files, Visualizations/) resolve against
# the process working directory, which is wherever `streamlit run` was
# launched from -- not this file's folder. Switching to the script's own
# directory makes them resolve no matter where the app was started.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Hugging Face token for the image-generation endpoint. Indexing os.environ
# raises KeyError at startup if it is missing, instead of failing later on the
# first poster request.
API_KEY = os.environ["HUGGINGFACE_API_KEY"]
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"

# --- LOAD DATA AND MODEL ---
# Streamlit re-executes this whole script on every widget interaction, so the
# loaders are cached: the CSV, the graph and the model weights are read once
# per server process instead of on every click. (Files changed on disk are
# only picked up after the cache is cleared or the app restarts.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_data
def _load_movies():
    """Read the movie metadata table.

    NOTE(review): recommendations index this table positionally by movie node
    id, so row i presumably corresponds to movie node i -- TODO confirm.
    """
    return pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")


@st.cache_resource
def _load_graph():
    """Load the saved PyG graph with all tensors placed on ``device``."""
    # map_location moves every stored tensor onto `device` (the GPU when one
    # is available, otherwise the CPU) regardless of where it was saved.
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
    # may refuse to unpickle a full graph object; if loading fails there, pass
    # weights_only=False (acceptable only because this is a trusted local file).
    return torch.load("./PyGdata.pt", map_location=device)


@st.cache_resource
def _load_model():
    """Build the GNN, load its trained weights onto ``device`` and set eval mode."""
    net = model_def.Model(hidden_channels=32).to(device)
    net.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device))
    net.eval()
    return net


movies_df = _load_movies()
data = _load_graph()
model = _load_model()
# --- STREAMLIT APP ---
st.title("Movie Recommendation App")

# --- VISUALIZATIONS ---
# Pre-rendered embedding plots exported as standalone HTML. Every file is read
# with an explicit UTF-8 encoding so the result does not depend on the
# platform's default locale encoding (previously only the UMAP file was).
VISUALIZATIONS = (
    ("UMAP", "./Visualizations/umap_visualization.html"),
    ("TSNE", "./Visualizations/tsne_visualization.html"),
    ("PCA", "./Visualizations/pca_visualization.html"),
)
viz_html = {}
for viz_name, viz_path in VISUALIZATIONS:
    with open(viz_path, "r", encoding="utf-8") as f:
        viz_html[viz_name] = f.read()

tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])

# User and movie node embeddings from the GNN encoder, stacked row-wise
# (all user rows first, then all movie rows).
# NOTE(review): embedding_df is only needed to render the plots live, e.g.
# st.plotly_chart(viz_utils.visualize_embeddings_umap(embedding_df)) (likewise
# _tsne / _pca). The app currently shows the pre-rendered HTML instead, so this
# forward pass is unused work on every rerun -- consider removing or caching it.
with torch.no_grad():
    node_embeddings = model.encoder(data.x_dict, data.edge_index_dict)
embedding_df = pd.concat(
    [
        pd.DataFrame(node_embeddings['user'].cpu().numpy()),
        pd.DataFrame(node_embeddings['movie'].cpu().numpy()),
    ],
    axis=0,
)

with tab1:
    # One collapsible section per pre-rendered plot, in the order listed above.
    for viz_name, _ in VISUALIZATIONS:
        section_title = f"{viz_name} Visualization"
        with st.expander(section_title):
            st.subheader(section_title)
            components.html(viz_html[viz_name], width=800, height=800)
def get_movie_recommendations(model, data, user_id, total_movies, k=5):
    """Return the ``k`` movies the model scores highest for ``user_id``.

    Every (user_id, movie) pair for movie indices ``0 .. total_movies - 1`` is
    scored in one forward pass, and the best-scoring positions are looked up
    in the module-level ``movies_df``.

    Args:
        model: trained link-prediction model, called as
            ``model(x_dict, edge_index_dict, edge_label_index)``.
            NOTE(review): assumes it returns one score per column of
            ``edge_label_index`` (a 1-D tensor) -- TODO confirm.
        data: graph object exposing ``x_dict`` and ``edge_index_dict``.
        user_id: index of the user node to recommend for.
        total_movies: number of movie nodes to score.
        k: number of recommendations (default 5, the previous fixed value).
            Clamped to the number of scores produced.

    Returns:
        The selected rows of ``movies_df``, highest score first. The index
        keeps ``movies_df``'s original row labels (it is not ``0 .. k-1``).
    """
    # (2, total_movies) edge list pairing the user with every movie. Explicit
    # long dtype keeps both rows the same integer type for torch.stack.
    user_row = torch.full((total_movies,), int(user_id), dtype=torch.long, device=device)
    all_movie_ids = torch.arange(total_movies, dtype=torch.long, device=device)
    edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)

    # Inference only: without no_grad every request would build (and keep
    # alive) an autograd graph over the whole GNN.
    with torch.no_grad():
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index).cpu()

    # topk raises if k exceeds the number of scores, so clamp it.
    k = min(k, pred.numel())
    top_indices = pred.topk(k).indices  # sorted, highest score first
    return movies_df.iloc[top_indices.tolist()]
def generate_poster(movie_title, timeout=120):
    """Generate a poster for ``movie_title`` with the Hugging Face API and display it.

    The title is sent verbatim as the text prompt to ``API_URL``. The returned
    image is shown with ``st.image``, captioned with the title.

    Args:
        movie_title: prompt text, also used as the image caption.
        timeout: seconds ``requests`` waits to connect and between received
            bytes before giving up. Without it, a stalled request would block
            the app indefinitely.

    Returns:
        The PIL image on success. Returns None on failure, after reporting the
        error in the UI with ``st.error``.
    """
    headers = {"Authorization": f"Bearer {API_KEY}"}
    # NOTE(review): a random seed used to be generated here "so the poster
    # changes on refresh", but it was never sent (the "parameters" entry was
    # commented out). To enable it, add
    # "parameters": {"seed": random.randint(0, 2**32 - 1)} to the payload --
    # TODO confirm the endpoint honors it.
    payload = {"inputs": movie_title}

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()  # turn 4xx/5xx into HTTPError
    except requests.exceptions.RequestException as err:
        # Covers HTTP errors plus connection failures and timeouts.
        st.error(f"Image generation failed: {err}")
        return None

    try:
        image = Image.open(BytesIO(response.content))
    except OSError as err:
        # A successful response whose body is not an image makes PIL raise
        # UnidentifiedImageError (an OSError subclass).
        st.error(f"Image generation failed: {err}")
        return None

    st.image(image, caption=movie_title)
    return image
with tab2:
    user_id = st.number_input("Enter the User ID:", min_value=0)
    if st.button("Get Recommendations"):
        st.write("Top 5 Recommendations:")
        try:
            num_users = data['user'].num_nodes
            if user_id >= num_users:
                st.error(f"User ID must be between 0 and {num_users - 1}.")
            else:
                total_movies = data['movie'].num_nodes
                recommended_movies = get_movie_recommendations(model, data, user_id, total_movies)
                cols = st.columns(3)
                # iterrows() yields movies_df's row labels (movie ids), not
                # 0..4, so use the enumeration position to spread the posters
                # evenly across the three columns.
                for position, (_, row) in enumerate(recommended_movies.iterrows()):
                    with cols[position % 3]:
                        try:
                            generate_poster(row['title'])
                        except requests.exceptions.RequestException as err:
                            # Report per poster so one failed request does not
                            # abort the remaining ones.
                            st.error(f"Image generation failed for {row['title']}: {err}")
        except Exception as e:
            # Top-level UI boundary: show any unexpected failure instead of crashing the page.
            st.error(f"An error occurred: {e}")
|