mgoin's picture
Update app.py
5e19c4f
raw
history blame
8.53 kB
import deepsparse
import gradio as gr
from typing import Tuple, List
deepsparse.cpu.print_hardware_capability()
MODEL_ID = "zoo:llama2-7b-gsm8k_llama2_pretrain-pruned60_quantized"
DESCRIPTION = f"""
# Llama 2 Sparse Finetuned on GSM8k with DeepSparse
![NM Logo](https://files.slack.com/files-pri/T020WGRLR8A-F05TXD28BBK/neuralmagic-logo.png?pub_secret=54e8db19db)
Model ID: {MODEL_ID}
🚀 **Experience the power of LLM mathematical reasoning** through [our Llama 2 sparse finetuned](https://arxiv.org/abs/2310.06927) on the [GSM8K dataset](https://huggingface.co/datasets/gsm8k).
GSM8K, short for Grade School Math 8K, is a collection of 8.5K high-quality linguistically diverse grade school math word problems, designed to challenge question-answering systems with multi-step reasoning.
Observe the model's performance in deciphering complex math questions and offering detailed step-by-step solutions.
## Accelerated Inferenced on CPUs
The Llama 2 model runs purely on CPU courtesy of [sparse software execution by DeepSparse](https://github.com/neuralmagic/deepsparse/tree/main/research/mpt).
DeepSparse provides accelerated inference by taking advantage of the model's weight sparsity to deliver tokens fast!
![Speedup](https://cdn-uploads.huggingface.co/production/uploads/60466e4b4f40b01b66151416/2XjSvMtX1DO3WY5Rx-L-1.png)
"""
MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 200
# Setup the engine
pipe = deepsparse.TextGeneration(model=MODEL_ID, sequence_length=MAX_MAX_NEW_TOKENS)
def clear_and_save_textbox(message: str) -> Tuple[str, str]:
return "", message
def display_input(
message: str, history: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
history.append((message, ""))
return history
def delete_prev_fn(history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
try:
message, _ = history.pop()
except IndexError:
message = ""
return history, message or ""
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
gr.Markdown(DESCRIPTION)
with gr.Column():
gr.Markdown("""### MPT GSM Sparse Finetuned Demo""")
with gr.Group():
chatbot = gr.Chatbot(label="Chatbot")
with gr.Row():
textbox = gr.Textbox(
container=False,
placeholder="Type a message...",
scale=10,
)
submit_button = gr.Button(
"Submit", variant="primary", scale=1, min_width=0
)
with gr.Row():
retry_button = gr.Button("🔄 Retry", variant="secondary")
undo_button = gr.Button("↩️ Undo", variant="secondary")
clear_button = gr.Button("🗑️ Clear", variant="secondary")
saved_input = gr.State()
gr.Examples(
examples=[
"James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?",
"Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?",
"Gretchen has 110 coins. There are 30 more gold coins than silver coins. How many gold coins does Gretchen have?",
],
inputs=[textbox],
)
max_new_tokens = gr.Slider(
label="Max new tokens",
value=DEFAULT_MAX_NEW_TOKENS,
minimum=0,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
interactive=True,
info="The maximum numbers of new tokens",
)
temperature = gr.Slider(
label="Temperature",
value=0.3,
minimum=0.05,
maximum=1.0,
step=0.05,
interactive=True,
info="Higher values produce more diverse outputs",
)
top_p = gr.Slider(
label="Top-p (nucleus) sampling",
value=0.40,
minimum=0.0,
maximum=1,
step=0.05,
interactive=True,
info="Higher values sample more low-probability tokens",
)
top_k = gr.Slider(
label="Top-k sampling",
value=20,
minimum=1,
maximum=100,
step=1,
interactive=True,
info="Sample from the top_k most likely tokens",
)
repetition_penalty = gr.Slider(
label="Repetition penalty",
value=1.2,
minimum=1.0,
maximum=2.0,
step=0.05,
interactive=True,
info="Penalize repeated tokens",
)
# Generation inference
def generate(
message,
history,
max_new_tokens: int,
temperature: float,
top_p: float,
top_k: int,
repetition_penalty: float,
):
generation_config = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
}
inference = pipe(sequences=message, streaming=True, **generation_config)
history[-1][1] += message
for token in inference:
history[-1][1] += token.generations[0].text
yield history
print(pipe.timer_manager)
textbox.submit(
fn=clear_and_save_textbox,
inputs=textbox,
outputs=[textbox, saved_input],
api_name=False,
queue=False,
).then(
fn=display_input,
inputs=[saved_input, chatbot],
outputs=chatbot,
api_name=False,
queue=False,
).success(
generate,
inputs=[
saved_input,
chatbot,
max_new_tokens,
temperature,
top_p,
top_k,
repetition_penalty,
],
outputs=[chatbot],
api_name=False,
)
submit_button.click(
fn=clear_and_save_textbox,
inputs=textbox,
outputs=[textbox, saved_input],
api_name=False,
queue=False,
).then(
fn=display_input,
inputs=[saved_input, chatbot],
outputs=chatbot,
api_name=False,
queue=False,
).success(
generate,
inputs=[saved_input, chatbot, max_new_tokens, temperature],
outputs=[chatbot],
api_name=False,
)
retry_button.click(
fn=delete_prev_fn,
inputs=chatbot,
outputs=[chatbot, saved_input],
api_name=False,
queue=False,
).then(
fn=display_input,
inputs=[saved_input, chatbot],
outputs=chatbot,
api_name=False,
queue=False,
).then(
generate,
inputs=[saved_input, chatbot, max_new_tokens, temperature],
outputs=[chatbot],
api_name=False,
)
undo_button.click(
fn=delete_prev_fn,
inputs=chatbot,
outputs=[chatbot, saved_input],
api_name=False,
queue=False,
).then(
fn=lambda x: x,
inputs=[saved_input],
outputs=textbox,
api_name=False,
queue=False,
)
clear_button.click(
fn=lambda: ([], ""),
outputs=[chatbot, saved_input],
queue=False,
api_name=False,
)
demo.queue().launch()