Spaces:
Runtime error
Runtime error
"""
app.py - the main file for the app. This creates the Gradio interface for the chatbot, wires its inputs to the response pipeline, and launches the server.
"""
import argparse | |
import logging | |
import os | |
import sys | |
import time | |
import warnings | |
from os.path import dirname | |
from pathlib import Path | |
import gradio as gr | |
import nltk | |
import torch | |
from cleantext import clean | |
from gradio.inputs import Slider, Textbox | |
from transformers import pipeline | |
from converse import discussion | |
from grammar_improve import ( | |
build_symspell_obj, | |
detect_propers, | |
fix_punct_spacing, | |
load_ns_checker, | |
neuspell_correct, | |
remove_repeated_words, | |
remove_trailing_punctuation, | |
symspeller, | |
synthesize_grammar, | |
) | |
from utils import corr | |
nltk.download("stopwords")  # download the NLTK stopwords corpus (not referenced directly in this file)
# NOTE(review): this path tweak runs after the local imports above (converse,
# grammar_improve, utils), so it cannot influence how those modules resolve.
# Confirm whether it is still needed, or move it above those imports.
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
# Ignore warnings whose message matches this regex (matched from the start of the
# message, hence the leading ".*").
# NOTE(review): the trailing "*" is a regex quantifier on the final "g", not a glob
# wildcard; the filter still matches any message containing "gradient_checkpointin".
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
import transformers

transformers.logging.set_verbosity_error()  # only surface error-level messages from transformers
logging.basicConfig()  # give the root logger its default handler/format
cwd = Path.cwd()
my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects
def chat(prompt_message, temperature=0.7, top_p=0.95, top_k=50):
    """
    chat - helper function that connects the gradio interface to the chatbot.

    Args:
        prompt_message (str): the message/question to send to the bot
        temperature (float, optional): sampling temperature. Defaults to 0.7.
        top_p (float, optional): nucleus-sampling (top-p) threshold. Defaults to 0.95.
        top_k (int, optional): top-k sampling threshold. Defaults to 50.

    Returns:
        str: an HTML snippet with the prompt and the bot's response, each in bold
            and followed by two line breaks. Both texts are HTML-escaped.
    """
    from html import escape  # stdlib; local import keeps this block self-contained

    response = ask_gpt(
        message=prompt_message,
        chat_pipe=my_chatbot,  # global pipeline created in the __main__ block
        top_p=top_p,
        # coerce in case the slider delivers a float such as 20.0; top-k is a count
        top_k=int(top_k),
        temperature=temperature,
    )
    history = [prompt_message, response]
    # The prompt is raw user input and the response is model output; both are
    # rendered as HTML, so escape them instead of injecting them verbatim.
    return "".join(f"<b>{escape(str(item))}</b> <br><br>" for item in history)
def ask_gpt(
    message: str,
    chat_pipe,
    speaker="person alpha",
    responder="person beta",
    max_len=96,
    top_p=0.95,
    top_k=25,
    temperature=0.6,
):
    """
    ask_gpt - generate a response to a prompt with the chat pipeline, then correct its spelling/grammar.

    Processing steps:
        1. Clean the prompt.
        2. Truncate it to its last 512 characters if it is longer.
        3. Pass it to `discussion()`.
        4. Correct the raw generated text, using one of the checkers created in
           this file's `__main__` block:
           - symspell (`schnellspell`) when the module-level `basic_sc` flag is set;
           - otherwise the `grammarbot` text2text pipeline.

    Parameters:
        message (str): the question to ask the bot
        chat_pipe: the text-generation pipeline to use for the bot
        speaker (str): the name of the speaker (default: "person alpha")
        responder (str): the name of the responder (default: "person beta")
        max_len (int): the maximum length of the response, capped at 512 (default: 96)
        top_p (float): the top probability threshold (default: 0.95)
        top_k (int): the top k threshold (default: 25)
        temperature (float): the temperature of the response (default: 0.6)

    Returns:
        str: the corrected response with trailing punctuation removed
    """
    st = time.perf_counter()
    prompt = clean(message).strip()  # clean user input, then drop surrounding whitespace
    in_len = len(prompt)
    if in_len > 512:
        prompt = prompt[-512:]  # keep only the most recent 512 chars
        print(f"Truncated prompt to last 512 chars: started with {in_len} chars")
    max_len = min(max_len, 512)
    resp = discussion(
        prompt_text=prompt,
        pipeline=chat_pipe,
        speaker=speaker,
        responder=responder,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        max_length=max_len,
    )
    gpt_et = time.perf_counter()
    gpt_rt = round(gpt_et - st, 2)
    rawtxt = resp["out_text"]
    # Correct the raw generated text. Both branches must read `rawtxt`; `cln_resp`
    # does not exist until one of them assigns it.
    if basic_sc:
        cln_resp = symspeller(rawtxt, sym_checker=schnellspell)
    else:
        cln_resp = synthesize_grammar(corrector=grammarbot, message=rawtxt)
    bot_resp_a = corr(remove_repeated_words(cln_resp))
    bot_resp = fix_punct_spacing(bot_resp_a)
    print(f"the prompt was:\n\t{message}\nand the response was:\n\t{bot_resp}\n")
    corr_rt = round(time.perf_counter() - gpt_et, 4)
    print(
        f"took {gpt_rt + corr_rt} sec to respond, {gpt_rt} for GPT, {corr_rt} for correction\n"
    )
    return remove_trailing_punctuation(bot_resp)
def get_parser():
    """
    get_parser - build the command-line interface for the chatbot app.

    Returns:
        argparse.ArgumentParser: a parser accepting
            -m / --model   chatbot model id on the Hugging Face hub, or a local model path
            --gram-model   text2text model id used for grammar correction
            --basic-sc     use symspell (baseline) correction instead of the neural models
            --verbose      turn on verbose logging
    """
    option_table = (
        (
            ("-m", "--model"),
            {
                "required": False,
                "type": str,
                "default": "ethzanalytics/ai-msgbot-gpt2-XL",
                "help": "the model to use for the chatbot on https://huggingface.co/models OR a path to a local model",
            },
        ),
        (
            ("--gram-model",),
            {
                "required": False,
                "type": str,
                "default": "pszemraj/t5-v1_1-base-ft-jflAUG",
                "help": "text2text generation model ID from huggingface for the model to correct grammar",
            },
        ),
        (
            ("--basic-sc",),
            {
                "required": False,
                "default": False,
                "action": "store_true",
                "help": "turn on symspell (baseline) correction instead of the more advanced neural net models",
            },
        ),
        (
            ("--verbose",),
            {
                "action": "store_true",
                "default": False,
                "help": "turn on verbose logging",
            },
        ),
    )
    cli = argparse.ArgumentParser(
        description="submit a question, GPT model responds"
    )
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli
if __name__ == "__main__":
    args = get_parser().parse_args()
    if args.verbose:
        # honour --verbose: let INFO-level records through the root logger
        logging.getLogger().setLevel(logging.INFO)
    default_model = str(args.model)
    model_loc = Path(default_model)  # if the model is a path, use it
    basic_sc = args.basic_sc  # whether to use the baseline spellchecker (read by ask_gpt)
    gram_model = str(args.gram_model)
    device = 0 if torch.cuda.is_available() else -1  # first GPU if available, else CPU
    print(f"CUDA avail is {torch.cuda.is_available()}")
    # A local model directory takes precedence over a hub model id. The resolved
    # path is passed as a str because transformers.pipeline only loads models from
    # str ids/paths; a Path object would be treated as an already-loaded model.
    my_chatbot = (
        pipeline("text-generation", model=str(model_loc.resolve()), device=device)
        if model_loc.is_dir()
        else pipeline("text-generation", model=default_model, device=device)
    )
    print(f"using model {my_chatbot.model}")
    # exactly one of these globals is created; ask_gpt picks the matching one via basic_sc
    if basic_sc:
        print("Using the baseline spellchecker")
        schnellspell = build_symspell_obj()
    else:
        print("using neural spell checker")
        grammarbot = pipeline("text2text-generation", gram_model, device=device)
    print(f"using model stored here: \n {model_loc} \n")
    iface = gr.Interface(
        chat,
        inputs=[
            Textbox(
                default="Why is everyone here eating chocolate cake?",
                label="prompt_message",
                placeholder="Enter a question",
                lines=2,
            ),
            Slider(
                minimum=0.0, maximum=1.0, step=0.01, default=0.6, label="temperature"
            ),
            Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.95, label="top_p"),
            Slider(minimum=0, maximum=100, step=5, default=20, label="top_k"),
        ],
        outputs="html",
        examples_per_page=8,
        examples=[
            ["Point Break or Bad Boys II?", 0.75, 0.95, 50],
            ["So... you're saying this wasn't an accident?", 0.6, 0.95, 50],
            ["Hi, my name is Reginald", 0.6, 0.95, 100],
            ["Happy birthday!", 0.9, 0.95, 50],
            ["I have a question, can you help me?", 0.6, 0.95, 50],
            ["Do you know a joke?", 0.8, 0.85, 50],
            ["Will you marry me?", 0.9, 0.95, 100],
            ["Are you single?", 0.6, 0.95, 100],
            ["Do you like people?", 0.7, 0.95, 25],
            ["You never took a short cut before?", 0.7, 0.95, 100],
        ],
        title=f"GPT Chatbot Demo: {default_model} Model",
        description=f"A Demo of a Chatbot trained for conversation with humans. Size XL= 1.5B parameters.\n\n"
        "**Important Notes & About:**\n\n"
        "You can find a link to the model card **[here](https://huggingface.co/ethzanalytics/ai-msgbot-gpt2-XL-dialogue)**\n\n"
        "1. responses can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
        "2. the model was trained on several different datasets. fact-check responses instead of regarding as a true statement.\n"
        "3. Try adjusting the **[generation parameters](https://huggingface.co/blog/how-to-generate)** to get a better understanding of how they work!\n",
        css="""
        .chatbox {display:flex;flex-direction:row}
        .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
        .user_msg {background-color:cornflowerblue;color:white;align-self:start}
        .resp_msg {background-color:lightgray;align-self:self-end}
        """,
        allow_screenshot=True,
        allow_flagging="never",
        theme="dark",
    )
    # launch the gradio interface and start the server
    iface.launch(
        enable_queue=True,  # also allows for dealing with multiple users simultaneously (per newer gradio version)
    )