import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import torch
import requests
import random
from io import BytesIO
from PIL import Image
from torch_geometric.nn import SAGEConv, to_hetero, Linear
from dotenv import load_dotenv
import os

from IPython.display import HTML

import viz_utils
import model_def

load_dotenv()  # load environment variables from a .env file, if one exists

# Relative paths are resolved against the process's current working directory,
# which is wherever `streamlit run` was launched from -- not necessarily this
# script's folder. Switching to the script's own directory makes the relative
# data/model/visualization paths below work regardless of the launch location.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Read with os.environ[...] so a missing token fails immediately (KeyError)
# instead of producing unauthenticated requests later.
API_KEY = os.environ["HUGGINGFACE_API_KEY"]
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"

# --- LOAD DATA AND MODEL ---
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so everything below is reloaded on each rerun; consider caching
# these loads (e.g. st.cache_resource / st.cache_data) if startup is slow.
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")  # movie metadata (used for titles)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location remaps every stored tensor onto `device` (GPU when available,
# otherwise CPU), so files saved on a GPU machine still load on CPU-only hosts.
# PyGdata.pt appears to be a pickled PyG graph object (it is used via x_dict /
# edge_index_dict below), not a plain tensor dict, so it needs
# weights_only=False: torch >= 2.6 defaults to weights_only=True and would
# refuse to unpickle it. Only load this file from a trusted source -- full
# unpickling can execute arbitrary code.
data = torch.load("./PyGdata.pt", map_location=device, weights_only=False)
model = model_def.Model(hidden_channels=32).to(device)
model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device))
model.eval()  # inference mode (affects layers such as dropout/batch-norm, if present)

# --- STREAMLIT APP ---
st.title("Movie Recommendation App")


# --- VISUALIZATIONS ---
# Pre-rendered embedding plots saved as standalone HTML documents. All three are
# read explicitly as UTF-8: without it, open() falls back to the platform's
# default encoding (e.g. cp1252 on Windows), which can fail or mis-decode
# non-ASCII content in these files.
with open("./Visualizations/umap_visualization.html", "r", encoding='utf-8') as f:
    umap_html = f.read()

with open("./Visualizations/tsne_visualization.html", "r", encoding='utf-8') as f:
    tsne_html = f.read()

with open("./Visualizations/pca_visualization.html", "r", encoding='utf-8') as f:
    pca_html = f.read()

tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])


# Run the GNN encoder once (no autograd) to get embeddings for both node types,
# then stack them into a single DataFrame: user rows first, then movie rows.
# Both frames keep their own 0-based index, so index labels repeat across the
# two halves. NOTE(review): embedding_df is currently only referenced by the
# commented-out live-plot calls in the Visualizations tab.
with torch.no_grad():
    node_embeddings = model.encoder(data.x_dict, data.edge_index_dict)
    user = pd.DataFrame(node_embeddings['user'].detach().cpu())
    movie = pd.DataFrame(node_embeddings['movie'].detach().cpu())
    embedding_df = pd.concat([user, movie], axis=0)

with tab1:
    # One collapsible section per projection: (label used for both the expander
    # and its subheader, pre-rendered HTML document to embed).
    # The live alternatives -- viz_utils.visualize_embeddings_{umap,tsne,pca}
    # (embedding_df) rendered via st.plotly_chart -- are commented out in
    # favour of showing these pre-rendered files.
    for plot_title, plot_html in (
        ("UMAP Visualization", umap_html),
        ("TSNE Visualization", tsne_html),
        ("PCA Visualization", pca_html),
    ):
        with st.expander(plot_title):
            st.subheader(plot_title)
            components.html(plot_html, width=800, height=800)




def get_movie_recommendations(model, data, user_id, total_movies, k=5):
    """Return the ``k`` movies the model scores highest for a single user.

    Every movie node ``0..total_movies-1`` is paired with ``user_id`` to form one
    candidate edge per movie, all edges are scored in a single forward pass, and
    the best-scoring movie indices are looked up positionally in the
    module-level ``movies_df``.

    Args:
        model: Trained link-prediction model, called as
            ``model(x_dict, edge_index_dict, edge_label_index)``.
        data: Graph object exposing ``x_dict``, ``edge_index_dict`` and
            ``data['user'].num_nodes``.
        user_id: Index of the user node to recommend for.
        total_movies: Number of movie nodes to score.
        k: Number of recommendations to return (default 5); clamped to the
            number of scored movies.

    Returns:
        The ``movies_df`` rows of the recommended movies, highest score first.

    Raises:
        ValueError: If ``user_id`` is outside ``[0, number of user nodes)``.
    """
    user_id = int(user_id)
    num_users = data['user'].num_nodes
    # Reject invalid IDs up front: an out-of-range index inside the model would
    # otherwise surface as an opaque IndexError (or a CUDA device-side assert).
    if num_users is not None and not 0 <= user_id < num_users:
        raise ValueError(
            f"User ID must be between 0 and {num_users - 1}, got {user_id}."
        )

    # Row 0 repeats the user, row 1 enumerates every movie -> shape (2, total_movies).
    user_row = torch.full((total_movies,), user_id, dtype=torch.long, device=device)
    all_movie_ids = torch.arange(total_movies, device=device)
    edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)

    # Pure inference: skip building an autograd graph over the whole catalogue.
    with torch.no_grad():
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index)

    # Flatten in case the model returns (N, 1) rather than (N,) scores.
    scores = pred.reshape(-1).cpu()
    top_indices = scores.topk(min(k, scores.numel())).indices.tolist()

    # NOTE(review): assumes movie node i corresponds to row i of movies_df --
    # confirm against how PyGdata.pt was built.
    return movies_df.iloc[top_indices]

def generate_poster(movie_title, timeout=60):
    """Generate a poster for ``movie_title`` and display it in the app.

    The title is sent as the text prompt to the Hugging Face Inference API
    endpoint ``API_URL``; the returned image is rendered with ``st.image``.
    Failures are reported in the app via ``st.error`` instead of being raised.

    Args:
        movie_title: Prompt text, also used as the image caption.
        timeout: Seconds to wait for the API before giving up (default 60).
            Without a timeout, ``requests`` can block the app indefinitely.

    Returns:
        The generated ``PIL.Image.Image``, or ``None`` if generation failed.
    """
    headers = {"Authorization": f"Bearer {API_KEY}"}
    # NOTE: no random "seed" parameter is sent, so the poster for a given title
    # is not deliberately varied between refreshes. To enable that, add
    # {"parameters": {"seed": random.randint(0, 2**32 - 1)}} to the payload.
    payload = {"inputs": movie_title}

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()  # turn 4xx/5xx responses into HTTPError
        image = Image.open(BytesIO(response.content))
        image.load()  # force full decode now so truncated data fails inside the try
    except (requests.exceptions.RequestException, OSError) as err:
        # RequestException covers HTTP errors, connection failures and timeouts;
        # OSError covers bodies PIL cannot decode (UnidentifiedImageError is an
        # OSError subclass).
        st.error(f"Image generation failed: {err}")
        return None

    st.image(image, caption=movie_title)
    return image

with tab2:
    user_id = st.number_input("Enter the User ID:", min_value=0)
    if st.button("Get Recommendations"):
        st.write("Top 5 Recommendations:")
        try:
            total_movies = data['movie'].num_nodes
            recommended_movies = get_movie_recommendations(model, data, user_id, total_movies)
            cols = st.columns(3)

            # Place posters left-to-right by rank. Use the position from
            # enumerate rather than the DataFrame index: iloc keeps movies_df's
            # original row labels, so `label % 3` would scatter posters across
            # columns arbitrarily (and fail for non-integer labels).
            for rank, (_, row) in enumerate(recommended_movies.iterrows()):
                with cols[rank % 3]:
                    try:
                        generate_poster(row['title'])
                    except requests.exceptions.RequestException as err:
                        # Catch any request failure here so one bad poster does
                        # not abort the remaining ones.
                        st.error(f"Image generation failed for {row['title']}: {err}")

        except Exception as e:
            # Top-level UI boundary: surface any failure (e.g. invalid user ID)
            # in the app instead of crashing the script run.
            st.error(f"An error occurred: {e}")