# gpt2-xl-conversational / converse.py
# author: pszemraj
# commit a738f02: "🔊 add logs"
"""
converse.py - this script has functions for handling the conversation between the user and the bot.
https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/model#transformers.generation_utils.GenerationMixin.generate.no_repeat_ngram_size
"""
import logging
# NOTE(review): basicConfig configures the *root* logger at import time, so
# importing this module affects logging output process-wide (if no handlers
# have been configured yet).
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
import pprint as pp
import time
from grammar_improve import remove_trailing_punctuation
from constrained_generation import constrained_generation
def discussion(
    prompt_text: str,
    speaker: str,
    responder: str,
    pipeline,
    timeout=45,
    min_length=8,
    max_length=64,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
    full_text=False,
    length_penalty=0.8,
    no_repeat_ngram_size=2,
    num_return_sequences=1,
    device=-1,
    verbose=False,
    constrained_beam_search=False,
):
    """
    discussion - a function that takes in a prompt and generates a response. This function is meant to be used in a conversation loop, and is the main function for the bot.

    The prompt sent to the model is a lowercased two-turn transcript::

        <speaker>:
        <prompt_text>

        <responder>:

    Parameters
    ----------
    prompt_text : str, the prompt to ask the bot, usually the user's question
    speaker : str, the name of the person who is speaking the prompt
    responder : str, the name of the person who is responding to the prompt
    pipeline : transformers.Pipeline, the pipeline to use for generating the response
    timeout : int, optional, the number of seconds to wait before timing out, by default 45
    min_length : int, optional, the minimum number of tokens to generate, defaults to 8
    max_length : int, optional, the maximum number of tokens to generate, defaults to 64
    top_p : float, optional, the top probability to use for sampling, defaults to 0.95
    top_k : int, optional, the top k to use for sampling, defaults to 50
    temperature : float, optional, the temperature to use for sampling, defaults to 0.7
    full_text : bool, optional, whether to return the full text or just the generated text, defaults to False
    length_penalty : float, optional, length penalty forwarded to generation, defaults to 0.8
    no_repeat_ngram_size : int, optional, size of n-grams that may not repeat in the output, defaults to 2
    num_return_sequences : int, optional, the number of sequences to request from the pipeline
        (gen_response only reads the first one), defaults to 1
    device : int, optional, forwarded to gen_response, defaults to -1 (CPU)
    verbose : bool, optional, whether to print the prompt and generated text, defaults to False
    constrained_beam_search : bool, optional, if True generate with constrained_generation
        (beam search, 4 beams) instead of sampling via gen_response; top_p, top_k,
        temperature, num_return_sequences and device are not used in that mode, defaults to False

    Returns
    -------
    dict with two keys:
        "out_text" : str, the cleaned bot response (stripped, trailing punctuation removed)
        "full_conv" : list of str, the transcript lines used as the prompt, followed by
            the bot response line and a blank line
    """
    logging.debug(f"input args: {locals()}")
    # transcript lines; also returned to the caller as "full_conv"
    p_list = [
        speaker.lower() + ":" + "\n",
        prompt_text.lower() + "\n",
        "\n",
        responder.lower() + ":" + "\n",
    ]
    this_prompt = "".join(p_list)
    if verbose:
        print("overall prompt:\n")
        pp.pprint(this_prompt, indent=4)

    if constrained_beam_search:
        logging.info("generating using constrained beam search ...")
        response = constrained_generation(
            prompt=this_prompt,
            pipeline=pipeline,
            min_generated_tokens=min_length,
            max_generated_tokens=max_length,
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=length_penalty,
            repetition_penalty=1.0,
            num_beams=4,
            timeout=timeout,
            verbose=False,
            full_text=full_text,
            speaker_name=speaker,
            responder_name=responder,
        )
        # NOTE(review): assumes constrained_generation returns a str
        bot_dialogue = consolidate_texts(
            name_resp=responder,
            model_resp=response.split("\n"),
            name_spk=speaker,
            verbose=verbose,
            print_debug=True,
        )
    else:
        logging.info("generating using sampling ...")
        bot_dialogue = gen_response(
            this_prompt,
            pipeline,
            speaker,
            responder,
            timeout=timeout,
            min_length=min_length,
            max_length=max_length,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            full_text=full_text,
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=length_penalty,
            num_return_sequences=num_return_sequences,
            device=device,
            verbose=verbose,
        )
    logging.debug(f"generation done. bot_dialogue: {bot_dialogue}")

    # consolidate_texts yields either a single str or a list of lines; several
    # lines are merged into one comma-separated reply (an empty list gives "")
    if isinstance(bot_dialogue, list):
        bot_resp = ", ".join(bot_dialogue)
    else:
        bot_resp = bot_dialogue
    # strip whitespace, then drop trailing ',' / '.' via remove_trailing_punctuation
    bot_resp = remove_trailing_punctuation(bot_resp.strip())

    if verbose:
        print("\nfinished!")
        print("\n... bot response:\n")
        pp.pprint(bot_resp)
    p_list.append(bot_resp + "\n")
    p_list.append("\n")
    logging.info(f"finished generating response:\n\t{bot_resp}")
    # return the bot response and the full conversation
    return {"out_text": bot_resp, "full_conv": p_list}
def gen_response(
    query: str,
    pipeline,
    speaker: str,
    responder: str,
    timeout=45,
    min_length=12,
    max_length=48,
    top_p=0.95,
    top_k=20,
    temperature=0.5,
    full_text=False,
    num_return_sequences=1,
    length_penalty: float = 0.8,
    repetition_penalty: float = 3.5,
    no_repeat_ngram_size=2,
    device=-1,
    verbose=False,
    **kwargs,
):
    """
    gen_response - a function that takes in a prompt and generates a response using the pipeline. This operates underneath the discussion function.

    Parameters
    ----------
    query : str, the full prompt text passed to the pipeline
    pipeline : transformers.Pipeline, the pipeline to use for generating the response
    speaker : str, the name of the person who is speaking the prompt
    responder : str, the name of the person who is responding to the prompt
    timeout : int, optional, passed to the pipeline as max_time (seconds), by default 45
    min_length : int, optional, the minimum number of new tokens to generate, defaults to 12
    max_length : int, optional, the maximum number of new tokens to generate, defaults to 48;
        reduced so prompt + generation fits in 1024 tokens (never below 8)
    top_p : float, optional, the top probability to use for sampling, defaults to 0.95
    top_k : int, optional, the top k to use for sampling, defaults to 20
    temperature : float, optional, the temperature to use for sampling, defaults to 0.5
    full_text : bool, optional, whether to return the full text or just the generated text, defaults to False
    num_return_sequences : int, optional, the number of sequences to request; only the first is used, defaults to 1
    length_penalty : float, optional, forwarded to the pipeline, defaults to 0.8
        (NOTE: in transformers this only affects beam-based generation)
    repetition_penalty : float, optional, forwarded to the pipeline, defaults to 3.5
    no_repeat_ngram_size : int, optional, size of n-grams that may not repeat, defaults to 2
    device : int, optional, accepted for API compatibility but not passed to the pipeline call, defaults to -1
    verbose : bool, optional, whether to print timing and debug output, defaults to False
    **kwargs : forwarded unchanged to the pipeline call

    Returns
    -------
    str or list of str, the result of consolidate_texts on the generated text split into lines
    """
    logging.debug(f"input args - gen_response() : {locals()}")
    input_len = len(pipeline.tokenizer(query).input_ids)
    # assumes a 1024-token context window (GPT-2 family) -- TODO confirm for other models
    if max_length + input_len > 1024:
        max_length = max(1024 - input_len, 8)
        logging.warning(f"max_length too large, setting to {max_length}")
        if max_length + input_len > 1024:
            # the prompt alone (nearly) fills the window; the 8-token floor overflows it
            logging.warning(
                f"prompt is {input_len} tokens; prompt + {max_length} new tokens "
                "exceeds the 1024-token limit and generation may fail"
            )
    # keep the requested minimum consistent with a (possibly clamped) maximum
    min_length = min(min_length, max_length)

    st = time.perf_counter()
    # lengths are given to the pipeline as totals, i.e. prompt tokens + new tokens
    response = pipeline(
        query,
        min_length=min_length + input_len,
        max_length=max_length + input_len,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        num_return_sequences=num_return_sequences,
        max_time=timeout,
        return_full_text=full_text,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        clean_up_tokenization_spaces=True,
        remove_invalid_values=True,
        **kwargs,
    )  # the likely better beam-less method
    rt = round(time.perf_counter() - st, 2)
    if verbose:
        print(f"took {rt} sec to respond")
        print("\n[DEBUG] generated:\n")
        pp.pprint(response)  # for debugging

    # process the full result to get the ~bot response~ piece
    # TODO: only the first sequence is used; handle num_return_sequences > 1
    this_result = str(response[0]["generated_text"]).split("\n")
    bot_dialogue = consolidate_texts(
        name_resp=responder,
        model_resp=this_result,
        name_spk=speaker,
        verbose=verbose,
        print_debug=True,
    )
    if verbose:
        print(f"DEBUG: {bot_dialogue} was original response pre-SC")
    return bot_dialogue
def consolidate_texts(
    model_resp: list,
    name_resp: str = None,
    name_spk: str = None,
    verbose=False,
    print_debug=False,
):
    """
    consolidate_texts - extract the responder's reply from a list of generated lines.

    Lines are scanned in order:
      * a line containing the responder's name (case-insensitive) is treated as
        the responder's label and skipped;
      * a line containing the speaker's name (case-insensitive), or any other line
        that looks like a "name: ..." label (contains ": " or ":\\n"), ends the scan;
      * every other line is kept.

    Parameters:
        model_resp (list): the list of strings to consolidate (usually the model output split on newlines)
        name_resp (str): the name of the person who is responding, defaults to "person beta"
        name_spk (str): the name of the person who is speaking, defaults to "person alpha"
        verbose (bool): whether to print the input and the consolidated result
        print_debug (bool): whether to print the line that stopped the scan
    Returns:
        str if model_resp has exactly one element (that element, unfiltered);
        otherwise list of str, the kept lines (possibly empty)
    Raises:
        AssertionError: if model_resp is empty (unless Python runs with -O)
    """
    assert len(model_resp) > 0, "model_resp is empty"
    if len(model_resp) == 1:
        return model_resp[0]
    name_resp = "person beta" if name_resp is None else name_resp
    name_spk = "person alpha" if name_spk is None else name_spk
    # compare case-insensitively for BOTH names (model output may be capitalised)
    resp_key = name_resp.lower()
    spk_key = name_spk.lower()
    if verbose:
        print("====" * 10)
        print(
            f"\n[DEBUG] initial model_resp has {len(model_resp)} lines: \n\t{model_resp}"
        )
        print(
            f" the first element is \n\t{model_resp[0]} and it is {type(model_resp[0])}"
        )

    fn_resp = []
    for resline in model_resp:
        line_lower = resline.lower()
        if resp_key in line_lower:
            continue  # the responder's own label line; don't keep it
        if spk_key in line_lower:
            if print_debug:
                print(f"\nDEBUG: \n\t{resline}\ncaused the break")
            break  # the name of the speaker is in the line, so we're done
        if ": " in resline or ":\n" in resline:
            # some other "name: ..." label => another turn has started
            if print_debug:
                print(f"\nDEBUG: \n\t{resline}\ncaused the break")
            break
        fn_resp.append(resline)

    if verbose:
        print("--" * 10)
        print("\nthe full response is:\n")
        print("\n".join(fn_resp))
        print("--" * 10)
    return fn_resp