import gradio as gr
import random
import time

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Pinecone
from langchain.chains import LLMChain
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
import pinecone

import os
# NOTE(review): presumably set to keep the HuggingFace tokenizers library from
# warning about (or deadlocking on) parallelism after a process fork -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ---- OpenAI ----------------------------------------------------------------
# Pre-filled into the API-key textbox; empty string when the env var is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_TEMP = 0

# ---- Pinecone --------------------------------------------------------------
# Required at import time: a missing PINECONE_KEY raises KeyError right here.
PINECONE_KEY = os.environ["PINECONE_KEY"]
PINECONE_ENV = "asia-northeast1-gcp"
PINECONE_INDEX = "3gpp"

# ---- Retrieval -------------------------------------------------------------
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"

# Number of text chunks requested from the vector store per query.
VECTOR_SEARCH_TOP_K = 10

# Chat-history length for LLM input.
LLM_HISTORY_LEN = 3

# ---- UI --------------------------------------------------------------------
BUTTON_MIN_WIDTH = 150

# shields.io static-badge path segments: "<label>-<message>-<color>".
STATUS_NOK = "404-MODEL UNREADY-red"
STATUS_OK = "200-MODEL LOADED-brightgreen"

def get_status(inputs) -> str:
    """Return an HTML ``<img>`` tag that displays a shields.io status badge.

    Args:
        inputs: Badge path segment in shields.io ``<label>-<message>-<color>``
            form, e.g. ``"200-MODEL LOADED-brightgreen"``.

    Returns:
        A standalone ``<img>`` element suitable for a ``gr.HTML`` component.
    """
    from urllib.parse import quote

    # Percent-encode the segment so labels containing spaces ("MODEL LOADED")
    # still form a valid URL. The tag is self-contained: there is no <a>
    # wrapper, so no closing </a> is emitted.
    return f"""<img src="https://img.shields.io/badge/{quote(str(inputs))}?style=flat">"""
    

# Pre-rendered badge HTML for the "not loaded" and "loaded" model states,
# swapped into the status component by init_model.
MODEL_NULL, MODEL_DONE = map(get_status, (STATUS_NOK, STATUS_OK))

# Bot reply used when a question is submitted before the model is initialized.
MODEL_WARNING = "Please paste your OpenAI API Key from openai.com and press 'Enter' to initialize this application!"


# Page heading, rendered with gr.Markdown.
webui_title = """
# 3GPP OpenAI Chatbot for Hackathon Demo

"""

# Button captions; also interpolated into the usage instructions below.
KEY_INIT   = "Initialize Model"
KEY_SUBMIT = "Submit"
KEY_CLEAR  = "Clear"

# Usage instructions, rendered with gr.Markdown. The numbered list must be
# separated from the opening sentence by a blank line and must not be indented
# by 4+ spaces; otherwise CommonMark treats those lines as a continuation of
# the first paragraph and no list is shown.
init_message = f"""Welcome to the 3GPP Chatbot. This demo toolkit is based on OpenAI with LangChain and Pinecone.

1. Insert your OpenAI API key and click `{KEY_INIT}`
2. Insert your question and click `{KEY_SUBMIT}`
"""

def init_model(api_key):
    """Build the QA chain and the Pinecone vector store for a session.

    Wired to both the key textbox's submit event and the init button; the
    returned tuple maps onto
    ``[api_textbox, model_statusbox, llm_chain, vector_db, chatbot]``.

    Args:
        api_key: OpenAI key pasted into the UI (may be None, empty, or carry
            surrounding whitespace).

    Returns:
        ``(api_key, MODEL_DONE, chain, db, None)`` on success, where
        ``api_key`` is the whitespace-stripped key, or
        ``(None, MODEL_NULL, None, None, None)`` when the key fails the format
        check or any setup step raises. The trailing ``None`` resets the
        chatbot component in both cases.
    """
    failure = (None, MODEL_NULL, None, None, None)

    # Pasted keys often carry stray spaces/newlines; strip them so a valid key
    # is not rejected (or forwarded with whitespace) by the checks below.
    api_key = api_key.strip() if isinstance(api_key, str) else ""

    # Format sanity check only; nothing in this function verifies the key
    # with OpenAI, so a well-formed but invalid key surfaces on first use.
    if not (api_key.startswith("sk-") and len(api_key) > 50):
        return failure

    try:
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

        pinecone.init(api_key=PINECONE_KEY,
                      environment=PINECONE_ENV)

        llm = ChatOpenAI(temperature=OPENAI_TEMP,
                         openai_api_key=api_key)

        chain = load_qa_chain(llm, chain_type="stuff")

        db = Pinecone.from_existing_index(index_name=PINECONE_INDEX,
                                          embedding=embeddings)
    except Exception:
        # UI boundary: keep the app alive and show the "not loaded" badge, but
        # report the full traceback rather than just str(e).
        import traceback
        traceback.print_exc()
        return failure

    return api_key, MODEL_DONE, chain, db, None


def get_chat_history(inputs) -> str:
    """Flatten completed chat turns into a plain-text transcript.

    Each ``(human, ai)`` pair becomes a "Human: ..." line followed by an
    "AI: ..." line, and consecutive pairs are joined by a newline. An empty
    history produces an empty string.
    """
    return "\n".join(
        "Human: {}\nAI: {}".format(human, ai) for human, ai in inputs
    )

def remove_duplicates(documents):
    """Drop documents whose ``page_content`` repeats an earlier document's.

    Input order is preserved and the first document carrying each distinct
    content is the one kept.
    """
    # dicts preserve insertion order and setdefault never overwrites, so the
    # first document seen for each content string wins.
    first_by_content = {}
    for document in documents:
        first_by_content.setdefault(document.page_content, document)
    return list(first_by_content.values())

def doc_similarity(query, db, top_k):
    """Retrieve up to ``top_k`` chunks relevant to ``query`` from ``db``.

    Duplicate chunks (same ``page_content``) are removed after retrieval, so
    the result may contain fewer than ``top_k`` documents.
    """
    retriever = db.as_retriever(search_kwargs=dict(k=top_k))
    hits = retriever.get_relevant_documents(query)
    return remove_duplicates(hits)

def user(user_message, history):
    """Append the submitted question to the chat as a pending turn.

    Args:
        user_message: Text from the question box.
        history: Current chatbot value as ``[question, answer]`` pairs;
            ``None`` (an empty or cleared chatbot) is treated as no history.

    Returns:
        ``("", new_history)``. The empty string clears the question box, and
        ``new_history`` is a fresh list ending with ``[user_message, None]``
        for ``bot`` to fill in. ``history`` itself is not mutated.
    """
    return "", list(history or []) + [[user_message, None]]

def _format_sources(docs) -> str:
    """Render retrieved documents as collapsible HTML ``<details>`` sections."""
    # Not every chunk is guaranteed a "source" metadata field; fall back to a
    # placeholder instead of raising KeyError.
    return "".join(
        f"""<details> <summary>{doc.metadata.get("source", "unknown")}</summary>
{doc.page_content}

</details>"""
        for doc in docs
    )


def bot(box_message, ref_message, chain, db, top_k):
    """Answer the newest chat question using chunks retrieved from ``db``.

    Args:
        box_message: Chat history as ``[question, answer]`` pairs. The last
            pair is the pending question, and its answer slot is filled in
            place.
        ref_message: Optional text used as the retrieval query instead of the
            question. If empty, the question itself is used for retrieval.
        chain: QA chain produced by ``init_model``; falsy until initialized.
        db: Vector store produced by ``init_model``; falsy until initialized.
        top_k: Number of unique chunks to retrieve. Coerced to ``int`` in case
            the slider delivers a float.

    Returns:
        ``(box_message, "", details)``: the updated chat, an empty string
        that clears the reference box, and the value for the "Related Docs"
        panel.
    """
    # Index 0 of each pair is the user question, index 1 the bot response.
    question = box_message[-1][0]
    history = box_message[:-1]

    if (not chain) or (not db):
        box_message[-1][1] = MODEL_WARNING
        return box_message, "", ""

    if ref_message:
        details = f"Q:  {question}\nR: {ref_message}"
    else:
        ref_message = question
        details = f"Q:  {question}"

    top_k = int(top_k)
    docs = doc_similarity(ref_message, db, top_k)

    # Duplicates were dropped after retrieval; over-fetch by the shortfall to
    # try to refill up to top_k unique chunks.
    shortfall = top_k - len(docs)
    if shortfall > 0:
        docs = doc_similarity(ref_message, db, top_k + shortfall)
    # The over-fetch can yield more than top_k unique chunks; cap it so the
    # "stuff" prompt never receives more context than requested.
    docs = docs[:top_k]

    # Only the most recent LLM_HISTORY_LEN turns are passed as history.
    # NOTE(review): the default "stuff" QA prompt may not reference
    # chat_history at all; confirm the chain actually consumes this key.
    recent_history = history[-LLM_HISTORY_LEN:] if LLM_HISTORY_LEN > 0 else []

    all_output = chain({"input_documents": docs,
                        "question": question,
                        "chat_history": get_chat_history(recent_history)})

    box_message[-1][1] = all_output['output_text']
    return box_message, "", [[details, _format_sources(docs)]]

# UI layout: an API-key row, a chat tab, and a "Details" tab that shows the
# retrieved source chunks for the latest question.
with gr.Blocks(css=""".bigbox {
    min-height:200px;
}""") as demo:
    # Per-session state populated by init_model.
    llm_chain = gr.State()
    vector_db = gr.State()
    gr.Markdown(webui_title)
    gr.Markdown(init_message)

    with gr.Row():
        with gr.Column(scale=9):
            api_textbox = gr.Textbox(
                label="OpenAI API Key",
                value=OPENAI_API_KEY,
                placeholder="Paste Your OpenAI API Key (sk-...) and Hit ENTER",
                lines=1,
                type='password',
            )

        with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
            init = gr.Button(KEY_INIT).style(full_width=False)
            model_statusbox = gr.HTML(MODEL_NULL)

    with gr.Tab("3GPP-Chatbot"):
        with gr.Row():
            with gr.Column(scale=10):
                # "bigbox" is styled by the css passed to gr.Blocks above.
                chatbot = gr.Chatbot(elem_classes="bigbox")
        with gr.Row():
            with gr.Column(scale=10):
                query = gr.Textbox(label="Question:",
                                   lines=2)
                ref = gr.Textbox(label="Reference(optional):")
            with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
                clear = gr.Button(KEY_CLEAR)
                submit = gr.Button(KEY_SUBMIT, variant="primary")

    with gr.Tab("Details"):
        top_k = gr.Slider(1,
                          20,
                          value=VECTOR_SEARCH_TOP_K,
                          step=1,
                          label="Vector similarity top_k",
                          interactive=True)
        detail_panel = gr.Chatbot(label="Related Docs")

    # Pressing Enter in the key box and clicking the init button are equivalent.
    init_outputs = [api_textbox, model_statusbox, llm_chain, vector_db, chatbot]
    for trigger in (api_textbox.submit, init.click):
        trigger(init_model, api_textbox, list(init_outputs))

    # Echo the question immediately (unqueued), then generate the answer.
    submit.click(user,
                 [query, chatbot],
                 [query, chatbot],
                 queue=False).then(
        bot,
        [chatbot, ref, llm_chain, vector_db, top_k],
        [chatbot, ref, detail_panel]
    )

    # Reset the inputs, the conversation, and the related-docs panel together
    # so sources from a cleared conversation are not left behind.
    clear.click(lambda: (None, None, None, None),
                None,
                [query, ref, chatbot, detail_panel],
                queue=False)

if __name__ == "__main__":
    # Serve locally without a public share link and open a browser tab.
    demo.launch(inbrowser=True, share=False)