# Hugging Face Spaces — Running on A10G
from __future__ import annotations | |
from typing import Iterable | |
import gradio as gr | |
from gradio.themes.base import Base | |
from gradio.themes.utils import colors, fonts, sizes | |
import time | |
import torch | |
from transformers import pipeline | |
import pandas as pd | |
# Build the Dolly v2 (7B) instruction pipeline once, at import time, so every
# request reuses the same loaded weights.
#   - torch_dtype=torch.bfloat16: 2 bytes per weight instead of float32's 4.
#   - trust_remote_code=True: allow code shipped in the model repo to run
#     (presumably Dolly's custom instruction pipeline — confirm on the hub).
#   - device_map="auto": let Transformers/accelerate place the weights on the
#     available device(s).
instruct_pipeline = pipeline(
    model="databricks/dolly-v2-7b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
def run_pipeline(prompt, pipe=None):
    """Generate Dolly's reply to *prompt* and return it as plain text.

    Args:
        prompt: Text passed unchanged to the pipeline.
        pipe: Callable used instead of the module-level ``instruct_pipeline``.
            ``None`` (the default) means ``instruct_pipeline``; mainly useful
            for exercising this function without loading the model.

    Returns:
        The generated reply. If the pipeline returns a non-empty list, its
        first element is taken; a dict carrying a ``"generated_text"`` key is
        reduced to that text. Any other result is returned unchanged.
    """
    if pipe is None:
        pipe = instruct_pipeline
    response = pipe(prompt)
    # NOTE(review): depending on the revision of the model's remote pipeline
    # code, the result may be a bare str or the usual Transformers
    # [{"generated_text": ...}] list. The Chatbot message slot needs text, so
    # unwrap the latter — confirm against the model repo's pipeline code.
    if isinstance(response, list) and response:
        response = response[0]
    if isinstance(response, dict) and "generated_text" in response:
        response = response["generated_text"]
    return response
def get_user_input(input_question, history):
    """Append the submitted text as a new chat turn and clear the textbox.

    Args:
        input_question: Text the user just submitted.
        history: Current Chatbot value, a list of ``[user, bot]`` pairs.
            ``None`` is treated as an empty history.

    Returns:
        ``("", new_history)``. The empty string clears the input textbox.
        ``new_history`` is a new list ending in ``[input_question, None]``;
        the ``None`` reply slot is filled by the follow-up generation step.
        The caller's list is not mutated.
    """
    return "", (history or []) + [[input_question, None]]


def get_qa_user_input(input_question, history):
    """Append a question to the Q&A history; same contract as ``get_user_input``."""
    return get_user_input(input_question, history)
def dolly_chat(history):
    """Answer the newest chat turn with Dolly.

    ``history`` is the Chatbot's list of ``[user, bot]`` pairs whose last pair
    holds the pending prompt. That pair's reply slot is filled in place with
    the output of ``run_pipeline``, and the same list object is returned so
    Gradio re-renders the chat.
    """
    pending_turn = history[-1]
    pending_turn[1] = run_pipeline(pending_turn[0])
    return history
def qa_bot(context, history):
    """Answer the newest Q&A question using the user-supplied context.

    The question is the user half of the last ``[user, bot]`` pair in
    ``history``. It is combined with ``context`` into a single
    "instruction: ... / context: ..." prompt for ``run_pipeline``. The reply is
    written into that pair in place, and ``history`` itself is returned.
    """
    pending_turn = history[-1]
    question = pending_turn[0]
    pending_turn[1] = run_pipeline(f'instruction: {question} \ncontext: {context}')
    return history
def reset_chatbot():
    """Clear the Q&A chat history.

    Returns:
        A Gradio update that sets the Chatbot's value to an empty list.
        A Chatbot value is a list of message pairs, so ``[]`` is the
        well-typed "no messages" value. The previous ``""`` was a string
        standing in for a list.
    """
    return gr.update(value=[])
# Default location of the example (doc, question) pairs, resolved relative to
# the process's working directory.
EXAMPLES_CSV = "examples.csv"


def load_example(index, path=EXAMPLES_CSV):
    """Return the ``(doc, question)`` pair stored in row *index* of the CSV.

    Args:
        index: Zero-based row position in the CSV.
        path: CSV file to read; it must have ``doc`` and ``question`` columns.

    Returns:
        ``(doc, question)``, used to fill the context box and question box.

    Raises:
        IndexError: if the CSV has fewer than ``index + 1`` rows.
        KeyError: if either column is missing.
    """
    df = pd.read_csv(path)
    return df["doc"].iloc[index], df["question"].iloc[index]


def load_customer_support_example(path=EXAMPLES_CSV):
    """Load the customer-support example (row 0 of the examples CSV)."""
    return load_example(0, path)


def load_databricks_doc_example(path=EXAMPLES_CSV):
    """Load the Databricks-documentation example (row 1 of the examples CSV)."""
    return load_example(1, path)
# Referred & modified from https://gradio.app/theming-guide/
class SeafoamCustom(Base):
    """Custom Gradio theme: emerald/blue gradient buttons, Quicksand font, drop shadows.

    Every constructor argument is forwarded unchanged to ``Base.__init__``.
    A fixed set of theme-variable overrides is then applied with ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.emerald,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.blue,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        font: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        base_kwargs = dict(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            font=font,
            font_mono=font_mono,
        )
        super().__init__(**base_kwargs)

        # Values written as "*name" reference other theme variables
        # (e.g. *primary_300), per the Gradio theming guide linked above.
        overrides = {
            "button_primary_background_fill": "linear-gradient(90deg, *primary_300, *secondary_400)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *primary_200, *secondary_300)",
            "button_primary_text_color": "white",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *primary_600, *secondary_800)",
            "block_shadow": "*shadow_drop_lg",
            "button_shadow": "*shadow_drop_lg",
            # NOTE(review): "zinc" is not a CSS named color, so browsers will
            # likely discard this value. A zinc shade from
            # gradio.themes.utils.colors may have been intended — confirm.
            "input_background_fill": "zinc",
            "input_border_color": "*secondary_300",
            "input_shadow": "*shadow_drop",
            "input_shadow_focus": "*shadow_drop_lg",
        }
        super().set(**overrides)


# Single theme instance shared by the Blocks UI.
seafoam = SeafoamCustom()
with gr.Blocks(theme=seafoam) as demo:
    # Header row: logo on the left, title and model description on the right.
    with gr.Row(variant='panel'):
        with gr.Column():
            # HTML attributes must be whitespace-separated; the earlier commas
            # between them were invalid markup. The "file/..." URL presumably
            # relies on Gradio serving dolly.jpg from the app directory —
            # confirm the image ships alongside this script.
            gr.HTML(
                """<img src='file/dolly.jpg' alt='dolly logo' width='150' height='150' /><br>"""
            )
        with gr.Column():
            gr.Markdown("# **<p align='center'>Dolly 2.0: World's First Truly Open Instruction-Tuned LLM</p>**")
            # The model loaded by this app is databricks/dolly-v2-7b, so the
            # blurb names that variant instead of claiming 12B parameters.
            gr.Markdown(
                "Dolly 2.0 is the first open source, instruction-following LLM fine-tuned on a "
                "human-generated instruction dataset licensed for research and commercial use. "
                "The Dolly 2.0 family goes up to 12B parameters; this demo runs the 7B variant "
                "(databricks/dolly-v2-7b). The models are based on the EleutherAI pythia model "
                "family and fine-tuned exclusively on a new, high-quality, human-generated "
                "instruction-following dataset, crowdsourced among Databricks employees."
            )
    # NOTE(review): no event below reads or writes qa_bot_state; it looks like
    # a leftover and could probably be removed — kept to avoid changing the
    # component set.
    qa_bot_state = gr.State(value=[])
    with gr.Tabs():
        with gr.TabItem("Dolly Chat"):
            with gr.Row():
                with gr.Column():
                    chatbot = gr.Chatbot(label="Chat History")
                    input_question = gr.Text(
                        label="Instruction",
                        placeholder="Type prompt and hit enter.",
                    )
                    clear = gr.Button("Clear", variant="primary")
            with gr.Row():
                with gr.Accordion("Show example inputs I can load:", open=False):
                    # Examples only fill the textbox (no fn, nothing cached);
                    # generation starts when the user submits.
                    gr.Examples(
                        examples=[
                            ["Explain to me the difference between nuclear fission and fusion."],
                            ["Give me a list of 5 science fiction books I should read next."],
                            ["I'm selling my Nikon D-750, write a short blurb for my ad."],
                            ["Write a song about sour donuts"],
                            ["Write a tweet about a new book launch by J.K. Rowling."],
                        ],
                        inputs=[input_question],
                        outputs=[],
                        fn=None,
                        cache_examples=False,
                    )
        with gr.TabItem("Q&A with Context"):
            with gr.Row():
                with gr.Column():
                    input_context = gr.Text(label="Add context here", lines=10)
                with gr.Column():
                    qa_chatbot = gr.Chatbot(label="Q&A History")
                    qa_input_question = gr.Text(
                        label="Input Question",
                        placeholder="Type question here and hit enter.",
                    )
                    qa_clear = gr.Button("Clear", variant="primary")
            with gr.Row():
                with gr.Accordion("Show example inputs I can load:", open=False):
                    example_1 = gr.Button("Load Customer support example")
                    example_2 = gr.Button("Load Databricks documentation example")

    # Chat tab: first append the prompt to the history and clear the textbox,
    # then generate the reply in a chained step.
    input_question.submit(
        get_user_input,
        [input_question, chatbot],
        [input_question, chatbot],
    ).then(dolly_chat, [chatbot], chatbot)
    clear.click(lambda: None, None, chatbot)

    # Q&A tab: same two-step flow, with the context box fed to the model.
    qa_input_question.submit(
        get_qa_user_input,
        [qa_input_question, qa_chatbot],
        [qa_input_question, qa_chatbot],
    ).then(qa_bot, [input_context, qa_chatbot], qa_chatbot)
    qa_clear.click(lambda: None, None, qa_chatbot)
    # reset the chatbot Q&A history when input context changes
    input_context.change(fn=reset_chatbot, inputs=[], outputs=qa_chatbot)

    # Example buttons fill both the context box and the question box.
    example_1.click(
        load_customer_support_example,
        [],
        [input_context, qa_input_question],
    )
    example_2.click(
        load_databricks_doc_example,
        [],
        [input_context, qa_input_question],
    )
if __name__ == "__main__":
    # Queue requests: concurrency_count=1 handles one event at a time (model
    # generation is the expensive part); at most 100 requests may wait.
    queued_demo = demo.queue(concurrency_count=1, max_size=100)
    queued_demo.launch(max_threads=5, debug=True)