Spaces:
Runtime error
Runtime error
""" | |
app.py - the main file for the app. This builds the Gradio interface for the chatbot and launches its web server.
""" | |
import argparse | |
import logging | |
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s") | |
import os | |
import sys | |
import time | |
import warnings | |
from os.path import dirname | |
from pathlib import Path | |
import gradio as gr | |
import nltk | |
import torch | |
from cleantext import clean | |
from gradio.inputs import Slider, Textbox, Radio | |
from transformers import pipeline | |
from converse import discussion | |
from grammar_improve import ( | |
build_symspell_obj, | |
detect_propers, | |
fix_punct_spacing, | |
load_ns_checker, | |
neuspell_correct, | |
remove_repeated_words, | |
remove_trailing_punctuation, | |
symspeller, | |
synthesize_grammar, | |
) | |
from utils import corr, setup_logging | |
nltk.download("stopwords")  # download stopwords
# NOTE(review): this runs *after* the local imports above (converse, grammar_improve,
# utils), so it cannot help those imports resolve; it only affects later imports.
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
# Silence warnings that mention gradient_checkpointing. `message` is a regex matched
# against the start of the warning text, hence the leading ".*". (The previous pattern
# ended in "checkpointing*", where "*" only repeated the final "g".)
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing")
import transformers

transformers.logging.set_verbosity_error()
cwd = Path.cwd()
_cwd_str = str(cwd.resolve())  # string so it can be passed to os.path() objects
def _as_bool(value) -> bool:
    """Interpret a bool, or a string such as "True" / "false", as a boolean flag."""
    if isinstance(value, str):
        # the string "False" is truthy, so strings must be parsed rather than bool()-ed
        return value.strip().lower() in {"true", "1", "yes"}
    return bool(value)


def chat(
    prompt_message,
    temperature: float = 0.5,
    top_p: float = 0.95,
    top_k: int = 20,
    constrained_generation: str = "False",
) -> str:
    """
    chat - the main function for the chatbot. Gradio calls it with the current values of
    the interface inputs each time the user submits a prompt.

    :param str prompt_message: the message to send to the model
    :param float temperature: the temperature value for the model, defaults to 0.5
    :param float top_p: the top_p value for the model, defaults to 0.95
    :param int top_k: the top_k value for the model, defaults to 20 (converted to int,
        since UI sliders may deliver a float)
    :param constrained_generation: whether to use constrained generation or not. Accepts a
        bool or its string form ("True"/"False", case-insensitive); defaults to "False",
        i.e. disabled
    :return str: an HTML snippet showing the (escaped) prompt followed by the response
    """
    from html import escape  # local import: `html` is not imported at module level

    response = ask_gpt(
        message=prompt_message,
        chat_pipe=my_chatbot,
        top_p=top_p,
        top_k=int(top_k) if top_k is not None else None,
        temperature=temperature,
        constrained_generation=_as_bool(constrained_generation),
    )
    history = [prompt_message, response]
    # Escape both turns: the prompt is raw user input and the response is model output,
    # and the result is rendered by an "html" output component.
    return "".join(f"<b>{escape(str(item))}</b> <br><br>" for item in history)
def ask_gpt(
    message: str,
    chat_pipe,
    speaker="person alpha",
    responder="person beta",
    min_length=12,
    max_length=48,
    top_p=0.95,
    top_k=25,
    temperature=0.5,
    constrained_generation=False,
    max_input_length=128,
) -> str:
    """
    ask_gpt - helper function that asks the GPT model a question and returns the response

    The prompt is cleaned, truncated to its last ``max_input_length`` tokens if needed, and
    passed to ``discussion()``. The raw generated text is then corrected. This uses the
    module-level ``basic_sc`` / ``basic_spell`` / ``grammarbot`` objects, which are only
    created in the ``__main__`` section.

    :param str message: the question to ask the model
    :param chat_pipe: the pipeline object for the model, created by the pipeline() function
    :param str speaker: the name of the speaker, defaults to "person alpha"
    :param str responder: the name of the responder, defaults to "person beta"
    :param int min_length: the minimum length of the response, defaults to 12
    :param int max_length: the maximum length of the response, defaults to 48
    :param float top_p: the top_p value for the model, defaults to 0.95
    :param int top_k: the top_k value for the model, defaults to 25
    :param float temperature: the temperature value for the model, defaults to 0.5
    :param bool constrained_generation: whether to use constrained generation or not, defaults to False
    :param int max_input_length: longer prompts keep only their last max_input_length
        tokens, defaults to 128
    :return str: the corrected response from the model, without trailing punctuation
    """
    st = time.perf_counter()
    prompt = clean(message).strip()  # clean user input, drop surrounding whitespace

    # tokenize once and reuse the ids for both the length check and the truncation
    input_ids = chat_pipe.tokenizer(prompt).input_ids
    in_len = len(input_ids)
    if in_len > max_input_length:
        # keep the last max_input_length tokens, i.e. the end of the prompt
        trunc_tokens = input_ids[-max_input_length:]
        prompt = chat_pipe.tokenizer.decode(trunc_tokens)
        logging.info(
            f"truncated prompt to {len(trunc_tokens)} tokens, input length: {in_len}"
        )
    logging.info(f"prompt: {prompt}")

    resp = discussion(
        prompt_text=prompt,
        pipeline=chat_pipe,
        speaker=speaker,
        responder=responder,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        max_length=max_length,
        min_length=min_length,
        constrained_beam_search=constrained_generation,
    )
    gpt_et = time.perf_counter()
    gpt_rt = round(gpt_et - st, 2)
    rawtxt = resp["out_text"]

    # spelling / grammar correction of the raw generation (symspell or neural corrector)
    if basic_sc:
        cln_resp = symspeller(rawtxt, sym_checker=basic_spell)
    else:
        cln_resp = synthesize_grammar(corrector=grammarbot, message=rawtxt)
    bot_resp = fix_punct_spacing(corr(remove_repeated_words(cln_resp)))

    corr_rt = round(time.perf_counter() - gpt_et, 4)
    # log (not print) so the -q / -vv log-level flags apply; round the float sum
    logging.info(
        f"{round(gpt_rt + corr_rt, 4)} to respond, {gpt_rt} GPT, {corr_rt} for correction"
    )
    return remove_trailing_punctuation(bot_resp)
def get_parser():
    """
    get_parser - build the command-line parser for the chatbot app.

    Options cover the chat model (name or local path), the grammar-correction model,
    the spell-checker choice, a tiny-model test mode, and verbosity. ``-q`` and ``-vv``
    both write to ``loglevel``, which stays ``None`` when neither is given.

    :return argparse.ArgumentParser: the configured parser
    """
    # (flags, add_argument keyword options), in the order they appear in --help
    flag_table = [
        (
            ("-m", "--model"),
            dict(
                required=False,
                type=str,
                default="ethzanalytics/ai-msgbot-gpt2-XL",  # default model
                help="the model to use for the chatbot on https://huggingface.co/models OR a path to a local model",
            ),
        ),
        (
            ("-gm", "--gram-model"),
            dict(
                required=False,
                type=str,
                default="pszemraj/grammar-synthesis-base",
                help="text2text generation model ID from huggingface for the model to correct grammar",
            ),
        ),
        (
            ("--basic-sc",),
            dict(
                required=False,
                default=False,
                action="store_true",
                help="use symspell (statistical spelling correction) instead of neural spell correction",
            ),
        ),
        (
            ("--test",),
            dict(
                action="store_true",
                default=False,
                help="load the smallest model for simple testing (ethzanalytics/distilgpt2-tiny-conversational)",
            ),
        ),
        (
            ("--verbose",),
            dict(action="store_true", default=False, help="turn on verbose printing"),
        ),
        (
            ("-q", "--quiet"),
            dict(
                dest="loglevel",
                help="set loglevel to WARNING (reduce output)",
                action="store_const",
                const=logging.WARNING,
            ),
        ),
        (
            ("-vv", "--very-verbose"),
            dict(
                dest="loglevel",
                help="set loglevel to DEBUG",
                action="store_const",
                const=logging.DEBUG,
            ),
        ),
    ]

    cli = argparse.ArgumentParser(description="submit a question, GPT model responds")
    for flags, options in flag_table:
        cli.add_argument(*flags, **options)
    return cli
if __name__ == "__main__":
    # Script entry point: parse CLI flags, load the chat and correction models, then
    # build and launch the Gradio interface.
    # NOTE: my_chatbot, basic_sc, basic_spell and grammarbot below are module-level
    # globals read by chat() / ask_gpt(); they only exist when this file runs as a script.
    args = get_parser().parse_args()
    # -q / -vv set args.loglevel; when neither is passed it is None -> fall back to INFO
    loglevel = args.loglevel or logging.INFO
    setup_logging(loglevel)
    logging.info("\n\n\nStarting app.py\n\n\n")
    logging.info(f"args: {args}")
    default_model = str(args.model)
    if args.test:
        # --test overrides --model with a tiny checkpoint
        logging.info("loading the smallest model for testing")
        default_model = "ethzanalytics/distilgpt2-tiny-conversational"
    model_loc = Path(default_model)  # if the model is a path, use it
    basic_sc = args.basic_sc  # whether to use the baseline spellchecker
    gram_model = str(args.gram_model)
    # pipeline device index: 0 when CUDA is available, otherwise -1 (CPU)
    device = 0 if torch.cuda.is_available() else -1
    logging.info(f"CUDA avail is {torch.cuda.is_available()}")
    # load from a local directory when the model string names an existing dir,
    # otherwise treat it as a Hugging Face model ID
    my_chatbot = (
        pipeline("text-generation", model=model_loc.resolve(), device=device)
        if model_loc.exists() and model_loc.is_dir()
        else pipeline("text-generation", model=default_model, device=device)
    )  # if the model is a name, use it. stays on CPU if no GPU available
    logging.info(f"using model {my_chatbot.model}")
    # only one corrector is created: basic_spell when --basic-sc, grammarbot otherwise
    if basic_sc:
        logging.info("Using the baseline spellchecker")
        basic_spell = build_symspell_obj()
    else:
        logging.info("using neural spell checker")
        grammarbot = pipeline("text2text-generation", gram_model, device=device)
    logging.debug(f"using model stored here: \n {model_loc} \n")
    # NOTE(review): gradio.inputs components, their `default=` kwarg, theme="dark" and
    # launch(enable_queue=...) belong to older Gradio releases; confirm the pinned
    # gradio version supports them (newer releases removed gradio.inputs).
    iface = gr.Interface(
        chat,
        # inputs map positionally onto chat(prompt_message, temperature, top_p, top_k,
        # constrained_generation)
        inputs=[
            Textbox(
                default="Why is everyone here eating chocolate cake?",
                label="prompt_message",
                placeholder="Start a conversation with the bot",
                lines=2,
            ),
            # NOTE(review): a temperature of 0.0 may be rejected by transformers'
            # sampling, which expects a strictly positive value; confirm how
            # discussion() generates.
            Slider(
                minimum=0.0, maximum=1.0, step=0.05, default=0.4, label="temperature"
            ),
            Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.95, label="top_p"),
            Slider(minimum=0, maximum=100, step=5, default=20, label="top_k"),
            # NOTE(review): confirm whether Gradio hands chat() a bool or its string
            # form for this selection.
            Radio(
                choices=[True, False],
                default=False,
                label="constrained_generation",
            ),
        ],
        outputs="html",
        examples_per_page=8,
        # each example row: [prompt, temperature, top_p, top_k, constrained_generation]
        examples=[
            ["Point Break or Bad Boys II?", 0.75, 0.95, 50, False],
            ["So... you're saying this wasn't an accident?", 0.6, 0.95, 40, False],
            ["Hi, my name is Reginald", 0.6, 0.95, 100, False],
            ["Happy birthday!", 0.9, 0.95, 50, False],
            ["I have a question, can you help me?", 0.6, 0.95, 50, False],
            ["Do you know a joke?", 0.8, 0.85, 50, False],
            ["Will you marry me?", 0.9, 0.95, 100, False],
            ["Are you single?", 0.95, 0.95, 100, False],
            ["Do you like people?", 0.7, 0.95, 25, False],
            ["You never took a shortcut before?", 0.7, 0.95, 100, False],
        ],
        title=f"GPT Chatbot Demo: {default_model} Model",
        # NOTE(review): the text below always says "Size XL= 1.5B parameters", even when
        # --test or --model selects a different checkpoint.
        description=f"A Demo of a Chatbot trained for conversation with humans. Size XL= 1.5B parameters.\n\n"
        "**Important Notes & About:**\n\n"
        "You can find a link to the model card **[here](https://huggingface.co/ethzanalytics/ai-msgbot-gpt2-XL-dialogue)**\n\n"
        "1. responses can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
        "2. the model was trained on several different datasets. fact-check responses instead of regarding as a true statement.\n"
        "3. Try adjusting the **[generation parameters](https://huggingface.co/blog/how-to-generate)** to get a better understanding of how they work!\n"
        "4. New - try using [constrained beam search](https://huggingface.co/blog/constrained-beam-search) decoding to generate more coherent responses. _(experimental, feedback welcome!)_\n",
        css="""
.chatbox {display:flex;flex-direction:row}
.user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.user_msg {background-color:cornflowerblue;color:white;align-self:start}
.resp_msg {background-color:lightgray;align-self:self-end}
""",
        allow_flagging="never",
        theme="dark",
    )
    # launch the gradio interface and start the server
    iface.launch(
        enable_queue=True,
    )