import deepsparse
import gradio as gr
from typing import Tuple, List

# Report the CPU capabilities DeepSparse detects on this host, once at startup.
deepsparse.cpu.print_hardware_capability()

# DeepSparse model stub ("zoo:" prefix). Judging by its name: Llama 2 7B
# fine-tuned on GSM8K, 60% pruned and quantized. Also interpolated into DESCRIPTION.
MODEL_ID: str = "zoo:llama2-7b-gsm8k_llama2_pretrain-pruned60_quantized"

DESCRIPTION = f"""
# Llama 2 Sparse Finetuned on GSM8K with DeepSparse
![NM Logo](https://files.slack.com/files-pri/T020WGRLR8A-F05TXD28BBK/neuralmagic-logo.png?pub_secret=54e8db19db)
Model ID: {MODEL_ID}

🚀 **Experience the power of LLM mathematical reasoning** through [our sparse Llama 2 model](https://arxiv.org/abs/2310.06927), fine-tuned on the [GSM8K dataset](https://huggingface.co/datasets/gsm8k). 
GSM8K, short for Grade School Math 8K, is a collection of 8.5K high-quality linguistically diverse grade school math word problems, designed to challenge question-answering systems with multi-step reasoning. 
Observe the model's performance in deciphering complex math questions and offering detailed step-by-step solutions.
## Accelerated Inference on CPUs
The Llama 2 model runs purely on CPU courtesy of [sparse software execution by DeepSparse](https://github.com/neuralmagic/deepsparse/tree/main/research/mpt). 
DeepSparse provides accelerated inference by taking advantage of the model's weight sparsity to deliver tokens fast!

![Speedup](https://cdn-uploads.huggingface.co/production/uploads/60466e4b4f40b01b66151416/2XjSvMtX1DO3WY5Rx-L-1.png)
"""

# Upper bound of the "Max new tokens" slider; also reused as the pipeline's
# sequence_length.
# NOTE(review): if sequence_length bounds prompt + generated tokens together,
# a long prompt leaves fewer than MAX_MAX_NEW_TOKENS tokens for the answer — confirm.
MAX_MAX_NEW_TOKENS: int = 1024
# Initial position of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS: int = 200

# Build the DeepSparse text-generation engine once, at import time; every
# generation request in this app goes through this single instance.
pipe = deepsparse.TextGeneration(
    model=MODEL_ID,
    sequence_length=MAX_MAX_NEW_TOKENS,
)


def clear_and_save_textbox(message: str) -> Tuple[str, str]:
    """Reset the textbox while stashing what the user typed.

    Returns:
        ``(new_textbox_value, saved_value)`` — an empty string for the textbox
        and the untouched ``message`` for the ``saved_input`` state.
    """
    cleared_text = ""
    return cleared_text, message


def display_input(
    message: str, history: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
    """Show the user's message as a new chat turn with an empty reply.

    ``history`` is extended in place and the same list object is returned.
    The reply half of the new turn is ``""`` until generation fills it in.
    """
    pending_turn = (message, "")
    history.extend([pending_turn])
    return history


def delete_prev_fn(history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
    """Drop the newest chat turn and hand back its user message.

    Wired to the Retry and Undo buttons. ``history`` is mutated in place.

    Returns:
        ``(history, message)`` where ``message`` is the removed turn's user
        text, or ``""`` when the history was empty or that text was falsy.
    """
    try:
        removed_user_text, _removed_reply = history.pop()
    except IndexError:
        # Empty chat: nothing to remove, nothing to restore.
        return history, ""
    return history, removed_user_text if removed_user_text else ""


def generate(
    message: str,
    history: List[Tuple[str, str]],
    max_new_tokens: int,
    temperature: float,
) -> List[Tuple[str, str]]:
    """Answer ``message`` with the sparse model and fill in the last chat turn.

    Args:
        message: The user's question, as stashed in the ``saved_input`` state.
        history: Chat history whose last entry is the pending turn appended by
            ``display_input`` for ``message``.
        max_new_tokens: Value of the "Max new tokens" slider.
        temperature: Value of the "Temperature" slider.

    Returns:
        ``history`` with the generated text appended to the last turn's reply,
        or ``history`` unchanged when there is no message or no turn to fill.
    """
    # Pressing Retry on an empty chat (delete_prev_fn then yields "") or
    # submitting an empty textbox produces an empty prompt: nothing to answer.
    if not message or not history:
        return history

    # NOTE(review): DeepSparse may ignore ``temperature`` unless sampling is
    # enabled (``do_sample=True``); confirm greedy decoding is intended here.
    inference = pipe(
        sequences=message,
        streaming=False,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    answer = inference.generations[0].text

    # Rebuild the last turn instead of assigning into ``history[-1][1]``:
    # display_input appends tuples, which do not support item assignment, so
    # this works for tuple and list turns alike (and tolerates a None reply).
    user_text, bot_text = history[-1]
    history[-1] = [user_text, (bot_text or "") + answer]

    # Print the pipeline's timer manager to the server log (presumably timing stats).
    print(pipe.timer_manager)
    return history


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(DESCRIPTION)
        with gr.Column():
            gr.Markdown("""### Sparse Finetuned Llama Demo""")

            with gr.Group():
                chatbot = gr.Chatbot(label="Chatbot")
                with gr.Row():
                    textbox = gr.Textbox(
                        container=False,
                        placeholder="Type a message...",
                        scale=10,
                    )
                    submit_button = gr.Button(
                        "Submit", variant="primary", scale=1, min_width=0
                    )

            with gr.Row():
                retry_button = gr.Button("🔄  Retry", variant="secondary")
                undo_button = gr.Button("↩️ Undo", variant="secondary")
                clear_button = gr.Button("🗑️  Clear", variant="secondary")

            # Carries the most recent user message between chained event steps.
            saved_input = gr.State()

            gr.Examples(
                examples=[
                    "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?",
                    "Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?",
                    "Gretchen has 110 coins. There are 30 more gold coins than silver coins. How many gold coins does Gretchen have?",
                ],
                inputs=[textbox],
            )

            max_new_tokens = gr.Slider(
                label="Max new tokens",
                value=DEFAULT_MAX_NEW_TOKENS,
                minimum=0,
                maximum=MAX_MAX_NEW_TOKENS,
                step=1,
                interactive=True,
                info="The maximum number of new tokens",
            )
            temperature = gr.Slider(
                label="Temperature",
                value=0.3,
                minimum=0.05,
                maximum=1.0,
                step=0.05,
                interactive=True,
                info="Higher values produce more diverse outputs",
            )

            # Pressing Enter in the textbox and clicking Submit run the same
            # chain: stash and clear the input, show the user turn, then answer.
            for trigger in (textbox.submit, submit_button.click):
                trigger(
                    fn=clear_and_save_textbox,
                    inputs=textbox,
                    outputs=[textbox, saved_input],
                    api_name=False,
                    queue=False,
                ).then(
                    fn=display_input,
                    inputs=[saved_input, chatbot],
                    outputs=chatbot,
                    api_name=False,
                    queue=False,
                ).success(
                    generate,
                    inputs=[saved_input, chatbot, max_new_tokens, temperature],
                    outputs=[chatbot],
                    api_name=False,
                )

            # Retry: drop the last turn, re-show its question, and regenerate.
            retry_button.click(
                fn=delete_prev_fn,
                inputs=chatbot,
                outputs=[chatbot, saved_input],
                api_name=False,
                queue=False,
            ).then(
                fn=display_input,
                inputs=[saved_input, chatbot],
                outputs=chatbot,
                api_name=False,
                queue=False,
            ).success(
                generate,
                inputs=[saved_input, chatbot, max_new_tokens, temperature],
                outputs=[chatbot],
                api_name=False,
            )
            # Undo: drop the last turn and put its question back in the textbox.
            undo_button.click(
                fn=delete_prev_fn,
                inputs=chatbot,
                outputs=[chatbot, saved_input],
                api_name=False,
                queue=False,
            ).then(
                fn=lambda x: x,
                inputs=[saved_input],
                outputs=textbox,
                api_name=False,
                queue=False,
            )
            # Clear: empty the chat and the stashed message.
            clear_button.click(
                fn=lambda: ([], ""),
                outputs=[chatbot, saved_input],
                queue=False,
                api_name=False,
            )


demo.queue().launch()