"""
Credit to Derek Thomas, derek@huggingface.co
"""
import os
import logging
from pathlib import Path
from time import perf_counter

import gradio as gr
from jinja2 import Environment, FileSystemLoader

from backend.query_llm import generate_hf, generate_openai
from backend.semantic_search import retrieve
from backend.reranker import rerank_documents


# Retrieval size read from the TOP_K environment variable (falls back to 4).
# NOTE(review): no live code visible in this file reads TOP_K — confirm it is
# still needed.
TOP_K = int(os.environ.get("TOP_K", 4))

# Directory that holds this file; the templates folder is looked up beside it.
proj_dir = Path(__file__).parent

# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Jinja2 environment that loads templates from <proj_dir>/templates.
env = Environment(
    loader=FileSystemLoader(proj_dir / 'templates'),
)

# Both templates are rendered with the same `documents` and `query` values:
# `template` produces the LLM prompt, `template_html` the HTML shown in the UI.
template, template_html = (
    env.get_template(template_name)
    for template_name in ('template.j2', 'template_html.j2')
)


def add_text(history, text):
    """Queue the user's message in the chat history and lock the input box.

    Args:
        history: Current chat history, or None when there is none yet.
        text: The message the user just submitted.

    Returns:
        A 2-tuple of the new history (the previous entries plus a
        ``(text, None)`` pair whose reply slot is still empty) and a
        ``gr.Textbox`` update that clears the box and disables interaction.
    """
    if history is None:
        history = []

    # The reply slot stays None until the bot handler fills it in.
    pending_turn = (text, None)
    locked_box = gr.Textbox(value="", interactive=False)
    return history + [pending_turn], locked_box


def _select_generate_fn(llm_model):
    """Return the generation function that serves ``llm_model``.

    Raises:
        gr.Error: If ``llm_model`` is not one of the supported identifiers.
    """
    hf_models = (
        "mistralai/Mistral-7B-Instruct-v0.2",
        "mistralai/Mistral-7B-v0.1",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
    )
    openai_models = (
        "gpt-3.5-turbo",
        "gpt-4-turbo-preview",
    )
    if llm_model in hf_models:
        return generate_hf
    if llm_model in openai_models:
        return generate_openai
    raise gr.Error(f"LLM model {llm_model} is not supported")


def bot(history, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param, rerank_topk):
    """Answer the latest user message using retrieved documents and an LLM.

    This is a generator: it retrieves documents for the last message in
    ``history``, optionally reranks them, renders the prompt templates and
    then yields the updated history as the model's output arrives.

    Args:
        history: Chat history as a sequence of ``[user_message, reply]``
            pairs; the last pair holds the message to answer.
        chunk_table: Forwarded to ``retrieve`` together with the query.
        embedding_model: Forwarded to ``retrieve``.
        llm_model: Model identifier; selects the generation backend and is
            also forwarded to it.
        cross_encoder: Reranker model name, or the string ``"None"`` to skip
            reranking.
        top_k_param: Value passed to ``retrieve`` as its second argument;
            an int or a numeric string.
        rerank_topk: Value passed to ``rerank_documents`` as
            ``top_k_rerank``; an int or a numeric string.

    Yields:
        ``(history, prompt_html)`` tuples, where the last history entry
        carries the most recent output of the generation function and
        ``prompt_html`` is the rendered HTML view of the prompt.

    Raises:
        gr.Error: If the latest user message is empty or ``llm_model`` is
            not supported.
    """
    # The UI supplies both values as strings; normalise them the same way.
    top_k_param = int(top_k_param)
    rerank_topk = int(rerank_topk)
    query = history[-1][0]

    logger.info("bot launched ...")
    logger.info(f"embedding model: {embedding_model}")
    logger.info(f"LLM model: {llm_model}")
    logger.info(f"Cross encoder model: {cross_encoder}")
    logger.info(f"TopK: {top_k_param}")
    logger.info(f"ReRank TopK: {rerank_topk}")

    if not query:
        # gr.Warning is a notification helper rather than an exception class,
        # so it cannot be raised; gr.Error is the raisable way to abort.
        raise gr.Error("Please submit a non-empty string as a prompt")

    # Resolve the backend before retrieval so an unsupported model fails fast
    # instead of surfacing later as an unbound local variable.
    generate_fn = _select_generate_fn(llm_model)

    logger.info('Retrieving documents...')
    # Retrieve documents relevant to query
    document_start = perf_counter()

    documents = retrieve(query, top_k_param, chunk_table, embedding_model)
    logger.info(f'Retrived document count: {len(documents)}')

    # Reranking is skipped when disabled or when there is at most one document.
    if cross_encoder != "None" and len(documents) > 1:
        documents = rerank_documents(cross_encoder, documents, query, top_k_rerank=rerank_topk)
        logger.info(f'ReRank done, document count: {len(documents)}')

    document_time = perf_counter() - document_start
    logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')

    # Create Prompt
    prompt = template.render(documents=documents, query=query)
    prompt_html = template_html.render(documents=documents, query=query)

    logger.info(f'Complition started. llm_model: {llm_model}, prompt: {prompt}')

    # Replace the last entry instead of assigning into it, so an immutable
    # (tuple) pair is handled as well as a list.
    history[-1] = [query, ""]
    for partial_reply in generate_fn(prompt, history[:-1], llm_model):
        # Each yielded value replaces the reply (it is assigned, not appended).
        # NOTE(review): presumably the backend yields the cumulative text so
        # far — confirm against backend.query_llm.
        history[-1][1] = partial_reply
        yield history, prompt_html


with gr.Blocks() as demo:
    # --- Conversation pane -------------------------------------------------
    avatar_urls = (
        'https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
        'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg',
    )
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=avatar_urls,
        bubble_full_width=False,
        show_copy_button=True,
        show_share_button=True,
    )

    # --- Prompt entry: text box plus an explicit submit button -------------
    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)

    # --- Pipeline settings; every value below is passed to `bot` -----------
    chunk_table_choices = [
        "BGE_CharacterTextSplitter",
        "BGE_FixedSizeSplitter",
        "BGE_RecursiveCharacterTextSplitter",
        "MiniLM_CharacterTextSplitter",
        "MiniLM_FixedSizeSplitter",
        "MiniLM_RecursiveCharacterSplitter",
    ]
    chunk_table = gr.Radio(
        choices=chunk_table_choices,
        value="MiniLM_CharacterTextSplitter",
        label="Chunk table",
    )

    embedding_model_choices = [
        "BAAI/bge-large-en-v1.5",
        "sentence-transformers/all-MiniLM-L6-v2",
    ]
    embedding_model = gr.Radio(
        choices=embedding_model_choices,
        value="sentence-transformers/all-MiniLM-L6-v2",
        label='Embedding model',
    )

    llm_model_choices = [
        "mistralai/Mistral-7B-Instruct-v0.2",
        "gpt-3.5-turbo",
        "gpt-4-turbo-preview",
        "mistralai/Mistral-7B-v0.1",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
    ]
    llm_model = gr.Radio(
        choices=llm_model_choices,
        value="mistralai/Mistral-7B-Instruct-v0.2",
        label='LLM',
    )

    # The literal string "None" is the option that turns reranking off.
    cross_encoder_choices = [
        "None",
        "BAAI/bge-reranker-large",
        "cross-encoder/ms-marco-MiniLM-L-6-v2",
    ]
    cross_encoder = gr.Radio(
        choices=cross_encoder_choices,
        value="None",
        label='Cross-encoder model',
    )

    # Both top-K selectors offer the same string-valued options.
    top_k_choices = ("5", "10", "20", "50")
    top_k_param = gr.Radio(
        choices=list(top_k_choices),
        value="5",
        label='top-K',
    )
    rerank_topk = gr.Radio(
        choices=list(top_k_choices),
        value="5",
        label='rerank-top-K',
    )

    # Shows the HTML rendering of the prompt produced by `bot`.
    prompt_html = gr.HTML()

    def _wire_submission(trigger):
        """Attach the add_text -> bot -> re-enable chain to one UI trigger."""
        # add_text disables the text box; bot then produces the reply.
        pending = trigger(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot,
            [chatbot, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param, rerank_topk],
            [chatbot, prompt_html],
        )
        # Turn the text box back on once the chain above has finished.
        pending.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
        return pending

    # Same behaviour whether the user clicks the button or presses enter.
    txt_msg = _wire_submission(txt_btn.click)
    txt_msg = _wire_submission(txt.submit)

demo.queue()
demo.launch(debug=True)