from threading import Thread
import logging
import time

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)

import torch
import gradio as gr
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# Small instruction-tuned T5 checkpoint; swap in your own model id to serve a different model.
model_id = "pszemraj/nanoT5-mid-2k-instruct"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Running on device: {torch_device}")
logging.info(f"CPU threads: {torch.get_num_threads()}")


# On GPU, load the model quantized to 8-bit (requires bitsandbytes) to cut memory use;
# on CPU, load it in full precision.
if torch_device == "cuda":
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
    )
else:
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# torch.compile needs PyTorch >= 2.0; fall back to the uncompiled model if it fails.
try:
    model = torch.compile(model)
except Exception as e:
    logging.error(f"Unable to compile model: {e}")

tokenizer = AutoTokenizer.from_pretrained(model_id)


def run_generation(
    user_text,
    top_p,
    temperature,
    top_k,
    max_new_tokens,
    repetition_penalty=1.1,
    length_penalty=1.0,
    no_repeat_ngram_size=4,
):
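    """Stream text from the model for the given prompt and sampling settings.

    Generation runs on a background thread; this generator yields the
    accumulated output each time the streamer produces a new chunk.
    """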
    st = time.perf_counter()
    # Tokenize the user text and move it to the model's device.
    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)

    # Start generation on a separate thread so the UI is not blocked; the text is pulled from
    # the streamer in the main thread. The timeout lets exceptions in the generation thread
    # surface instead of hanging the read loop.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
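    # model_inputs is a mapping (input_ids, attention_mask) merged with the sampling options below.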
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        num_beams=1,
        top_p=top_p,
        temperature=float(temperature),
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer and update the model output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    t.join()
    logging.info(f"Total runtime: {round(time.perf_counter() - st, 3)} sec")


def reset_textbox():
    # Helper to clear the input textbox (e.g., wire it to a Clear button or a .then() callback).
    return gr.update(value="")


with gr.Blocks() as demo:
    duplicate_link = (
        "https://huggingface.co/spaces/joaogante/transformers_streaming?duplicate=true"
    )
    gr.Markdown(
        "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n"
        "This demo showcases the use of the "
        "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
        "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
        f"[{model_id}](https://huggingface.co/{model_id}) and the Spaces free compute tier.\n\n"
        f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a "
        "template! 💛"
    )
    gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                value="How to become a polar bear tamer?",
                label="User input",
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit", variant="primary")

        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=32,
                maximum=1024,
                value=256,
                step=32,
                interactive=True,
                label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05,
                maximum=1.0,
                value=0.95,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.4,
                value=0.3,
                step=0.05,
                interactive=True,
                label="Temperature",
            )
            repetition_penalty = gr.Slider(
                minimum=0.9,
                maximum=2.5,
                value=1.1,
                step=0.1,
                interactive=True,
                label="Repetition Penalty",
            )
            length_penalty = gr.Slider(
                minimum=0.8,
                maximum=1.5,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Length Penalty",
            )
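    # Both pressing Enter in the textbox and clicking Submit stream from run_generation.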
    generation_inputs = [
        user_text,
        top_p,
        temperature,
        top_k,
        max_new_tokens,
        repetition_penalty,
        length_penalty,
    ]
    user_text.submit(run_generation, generation_inputs, model_output)
    button_submit.click(run_generation, generation_inputs, model_output)

    # A queue is required for streaming generator outputs; cap pending requests at 32.
    demo.queue(max_size=32).launch()