""" converse.py - this script has functions for handling the conversation between the user and the bot. https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/model#transformers.generation_utils.GenerationMixin.generate.no_repeat_ngram_size """ import logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) import pprint as pp import time from grammar_improve import remove_trailing_punctuation from constrained_generation import constrained_generation def discussion( prompt_text: str, speaker: str, responder: str, pipeline, timeout=45, min_length=8, max_length=64, top_p=0.95, top_k=50, temperature=0.7, full_text=False, length_penalty=0.8, no_repeat_ngram_size=2, num_return_sequences=1, device=-1, verbose=False, constrained_beam_search=False, ): """ discussion - a function that takes in a prompt and generates a response. This function is meant to be used in a conversation loop, and is the main function for the bot. Parameters ---------- prompt_text : str, the prompt to ask the bot, usually the user's question speaker : str, the name of the person who is speaking the prompt responder : str, the name of the person who is responding to the prompt pipeline : transformers.Pipeline, the pipeline to use for generating the response timeout : int, optional, the number of seconds to wait before timing out, by default 45 max_length : int, optional, the maximum number of tokens to generate, defaults to 128 top_p : float, optional, the top probability to use for sampling, defaults to 0.95 top_k : int, optional, the top k to use for sampling, defaults to 50 temperature : float, optional, the temperature to use for sampling, defaults to 0.7 full_text : bool, optional, whether to return the full text or just the generated text, defaults to False num_return_sequences : int, optional, the number of sequences to return, defaults to 1 device : int, optional, the device to use for generation, defaults to -1 (CPU) verbose : bool, optional, whether to print the generated text, defaults to False Returns ------- str, the generated text """ logging.debug(f"input args: {locals()}") p_list = [] # track conversation p_list.append(speaker.lower() + ":" + "\n") p_list.append(prompt_text.lower() + "\n") p_list.append("\n") p_list.append(responder.lower() + ":" + "\n") this_prompt = "".join(p_list) if verbose: print("overall prompt:\n") pp.pprint(this_prompt, indent=4) if constrained_beam_search: logging.info("generating using constrained beam search ...") response = constrained_generation( prompt=this_prompt, pipeline=pipeline, min_generated_tokens=min_length, max_generated_tokens=max_length, no_repeat_ngram_size=no_repeat_ngram_size, length_penalty=length_penalty, repetition_penalty=1.0, num_beams=4, timeout=timeout, verbose=False, full_text=full_text, speaker_name=speaker, responder_name=responder, ) bot_dialogue = consolidate_texts( name_resp=responder, model_resp=response.split("\n"), name_spk=speaker, verbose=verbose, print_debug=True, ) else: logging.info("generating using sampling ...") bot_dialogue = gen_response( this_prompt, pipeline, speaker, responder, timeout=timeout, min_length=min_length, max_length=max_length, top_p=top_p, top_k=top_k, temperature=temperature, full_text=full_text, no_repeat_ngram_size=no_repeat_ngram_size, length_penalty=length_penalty, num_return_sequences=num_return_sequences, device=device, verbose=verbose, ) logging.debug(f"generation done. bot_dialogue: {bot_dialogue}") if isinstance(bot_dialogue, list) and len(bot_dialogue) > 1: bot_resp = ", ".join(bot_dialogue) elif isinstance(bot_dialogue, list) and len(bot_dialogue) == 1: bot_resp = bot_dialogue[0] else: bot_resp = bot_dialogue bot_resp = " ".join(bot_resp) if isinstance(bot_resp, list) else bot_resp bot_resp = bot_resp.strip() # remove the last ',' '.' chars bot_resp = remove_trailing_punctuation(bot_resp) if verbose: print("\nfinished!") print("\n... bot response:\n") pp.pprint(bot_resp) p_list.append(bot_resp + "\n") p_list.append("\n") logging.info(f"finished generating response:\n\t{bot_resp}") # return the bot response and the full conversation return {"out_text": bot_resp, "full_conv": p_list} def gen_response( query: str, pipeline, speaker: str, responder: str, timeout=45, min_length=12, max_length=48, top_p=0.95, top_k=20, temperature=0.5, full_text=False, num_return_sequences=1, length_penalty: float = 0.8, repetition_penalty: float = 3.5, no_repeat_ngram_size=2, device=-1, verbose=False, **kwargs, ): """ gen_response - a function that takes in a prompt and generates a response using the pipeline. This operates underneath the discussion function. Parameters ---------- query : str, the prompt to ask the bot, usually the user's question speaker : str, the name of the person who is speaking the prompt responder : str, the name of the person who is responding to the prompt pipeline : transformers.Pipeline, the pipeline to use for generating the response timeout : int, optional, the number of seconds to wait before timing out, by default 45 min_length : int, optional, the minimum number of tokens to generate, defaults to 4 max_length : int, optional, the maximum number of tokens to generate, defaults to 64 top_p : float, optional, the top probability to use for sampling, defaults to 0.95 top_k : int, optional, the top k to use for sampling, defaults to 50 temperature : float, optional, the temperature to use for sampling, defaults to 0.7 full_text : bool, optional, whether to return the full text or just the generated text, defaults to False num_return_sequences : int, optional, the number of sequences to return, defaults to 1 device : int, optional, the device to use for generation, defaults to -1 (CPU) verbose : bool, optional, whether to print the generated text, defaults to False Returns ------- str, the generated text """ logging.debug(f"input args - gen_response() : {locals()}") input_len = len(pipeline.tokenizer(query).input_ids) if max_length + input_len > 1024: max_length = max(1024 - input_len, 8) print(f"max_length too large, setting to {max_length}") st = time.perf_counter() response = pipeline( query, min_length=min_length + input_len, max_length=max_length + input_len, temperature=temperature, top_k=top_k, top_p=top_p, num_return_sequences=num_return_sequences, max_time=timeout, return_full_text=full_text, no_repeat_ngram_size=no_repeat_ngram_size, repetition_penalty=repetition_penalty, length_penalty=length_penalty, clean_up_tokenization_spaces=True, remove_invalid_values=True, **kwargs, ) # the likely better beam-less method rt = round(time.perf_counter() - st, 2) if verbose: print(f"took {rt} sec to respond") if verbose: print("\n[DEBUG] generated:\n") pp.pprint(response) # for debugging # process the full result to get the ~bot response~ piece this_result = str(response[0]["generated_text"]).split( "\n" ) # TODO: adjust hardcoded value for index to dynamic (if n>1) bot_dialogue = consolidate_texts( name_resp=responder, model_resp=this_result, name_spk=speaker, verbose=verbose, print_debug=True, ) if verbose: print(f"DEBUG: {bot_dialogue} was original response pre-SC") return bot_dialogue # def consolidate_texts( model_resp: list, name_resp: str = None, name_spk: str = None, verbose=False, print_debug=False, ): """ consolidate_texts - given a list with speaker name followed by speaker text, returns all consecutive values of the first speaker name Parameters: name_resp (str): the name of the person who is responding model_resp (list): the list of strings to consolidate (usually from the model) name_spk (str): the name of the person who is speaking verbose (bool): whether to print the results print_debug (bool): whether to print the debug info during looping Returns: list, a list of all the consecutive messages of the first speaker name """ assert len(model_resp) > 0, "model_resp is empty" if len(model_resp) == 1: return model_resp[0] name_resp = "person beta" if name_resp is None else name_resp name_spk = "person alpha" if name_spk is None else name_spk if verbose: print("====" * 10) print( f"\n[DEBUG] initial model_resp has {len(model_resp)} lines: \n\t{model_resp}" ) print( f" the first element is \n\t{model_resp[0]} and it is {type(model_resp[0])}" ) fn_resp = [] name_counter = 0 break_safe = False for resline in model_resp: if name_resp.lower() in resline: name_counter += 1 break_safe = True # know the line is from bot as this line starts with the name of the bot continue # don't add this line to the list if name_spk.lower() in resline.lower(): if print_debug: print(f"\nDEBUG: \n\t{resline}\ncaused the break") break # the name of the speaker is in the line, so we're done if ( any([": " in resline, ":\n" in resline]) and name_resp.lower() not in resline.lower() ): if print_debug: print(f"\nDEBUG: \n\t{resline}\ncaused the break") break else: fn_resp.append(resline) break_safe = False if verbose: print("--" * 10) print("\nthe full response is:\n") print("\n".join(fn_resp)) print("--" * 10) return fn_resp