# gpt2-xl-conversational / constrained_generation.py
# pszemraj — commit a738f02: "🔊 add logs"
"""
constrained_generation.py - use constrained beam search to generate text from a model, steering it with keyword and phrase constraints
"""
import copy
import logging
logging.basicConfig(level=logging.INFO)
import time
from pathlib import Path
import yake
from transformers import AutoTokenizer, PhrasalConstraint
def get_tokenizer(model_name="gpt2", verbose=False):
    """
    get_tokenizer - load a tokenizer whose padding token is set to its EOS token
    :param model_name: model id or path handed to AutoTokenizer.from_pretrained, default gpt2
    :param verbose: if True, print a confirmation once loading finishes
    :return: the loaded tokenizer object
    """
    load_options = dict(add_special_tokens=False, padding=True, truncation=True)
    tok = AutoTokenizer.from_pretrained(model_name, **load_options)
    # reuse EOS as the pad token — presumably because the model ships without one
    tok.pad_token = tok.eos_token
    if verbose:
        print(f"loaded tokenizer {model_name}")
    return tok
def unique_words(list_of_strings):
    """
    unique_words - keep only the strings whose words have not been used yet.

    Strings are scanned in order. A string is kept when none of its
    whitespace-separated words appeared in a previously *kept* string and it
    does not repeat a word itself. Words of rejected strings are not recorded,
    so a rejected phrase never blocks a later one.
    :param list_of_strings: iterable of strings (e.g. keyword phrases)
    :return: list of the kept strings, in their original order
    """
    kept_words = set()  # set gives O(1) membership instead of a list scan
    output_list = []
    for string in list_of_strings:
        words = string.split()
        # reject internal repeats ("new new york") and overlap with kept phrases
        if len(set(words)) != len(words) or not kept_words.isdisjoint(words):
            continue
        kept_words.update(words)
        output_list.append(string)
    return output_list
def create_kw_extractor(
    language="en",
    max_ngram_size=3,
    deduplication_algo="seqm",
    windowSize=10,
    numOfKeywords=10,
    ddpt=0.7,
):
    """
    creates a YAKE keyword extractor object
    :param language: language of the text (yake ``lan``)
    :param max_ngram_size: max ngram size of a keyword (yake ``n``)
    :param deduplication_algo: deduplication algorithm name (yake ``dedupFunc``)
    :param windowSize: window size (yake ``windowsSize``)
    :param numOfKeywords: number of keywords to return (yake ``top``)
    :param ddpt: Deduplication Percentage Threshold, must lie in [0, 1] (yake ``dedupLim``)
    :raises ValueError: if ddpt is outside [0, 1] (NaN included)
    :return: keyword extractor object
    """
    # explicit check rather than `assert`, which is stripped under `python -O`;
    # the chained comparison is False for NaN, so NaN is rejected too
    if not 0 <= ddpt <= 1:
        raise ValueError(f"need 0 <= ddpt <= 1, got {ddpt}")
    return yake.KeywordExtractor(
        lan=language,
        n=max_ngram_size,
        dedupLim=ddpt,
        dedupFunc=deduplication_algo,
        windowsSize=windowSize,
        top=numOfKeywords,
        features=None,
    )
def simple_kw(body_text: str, yake_ex=None, max_kw=15, verbose=False):
    """
    simple_kw - extract keywords from a text using yake
    Args:
        body_text (str): text to extract keywords from
        yake_ex (yake.KeywordExtractor, optional): yake keyword extractor. If None,
            one is built with create_kw_extractor using tuned settings. Defaults to None.
        max_kw (int, optional): maximum number of keywords to return. Defaults to 15.
        verbose (bool, optional): also print the keywords to stdout. Defaults to False.
    Returns:
        list: up to max_kw lowercased keywords, in the order the extractor returned them
    """
    if yake_ex is None:
        yake_ex = create_kw_extractor(
            max_ngram_size=2,
            ddpt=0.9,
            windowSize=10,
            deduplication_algo="seqm",
            numOfKeywords=max_kw,
        )  # per optuna study
    keywords = yake_ex.extract_keywords(body_text)
    # element 0 of each result item is the keyword text
    keywords_list = [str(kw[0]).lower() for kw in keywords]
    top_keywords = keywords_list[:max_kw]
    # a single log record; the verbose branch only adds stdout output
    logging.info(
        f"YAKE: found {len(keywords_list)} keywords, the top {max_kw} are: {top_keywords}"
    )
    if verbose:
        print(f"found {len(keywords_list)} keywords, the top {max_kw} are:")
        print(top_keywords)
    return top_keywords
def _prompt_keywords(prompt, speaker_name, responder_name):
    """Return non-overlapping YAKE keywords of prompt, minus phrases naming either speaker."""
    key_prompt_phrases = simple_kw(prompt)
    try:
        name_words = responder_name.lower().split() + speaker_name.lower().split()
    except Exception as e:
        name_words = []
        logging.info(f"could not split names: {e}")
    # drop phrases containing (as a substring) any word of either participant's name
    key_prompt_phrases = [
        p for p in key_prompt_phrases if not any(name in p for name in name_words)
    ]
    return unique_words(key_prompt_phrases)


def constrained_generation(
    prompt: str,
    pipeline,
    tokenizer=None,
    no_repeat_ngram_size=2,
    length_penalty=0.7,
    repetition_penalty=3.5,
    num_beams=4,
    max_generated_tokens=48,
    min_generated_tokens=2,
    timeout=300,
    num_return_sequences=1,
    verbose=False,
    full_text=False,
    force_word: str = None,
    speaker_name: str = "Person Alpha",
    responder_name: str = "Person Beta",
    **kwargs,
):
    """
    constrained_generation - generate a reply to a prompt with constrained beam search

    Keywords extracted from the prompt (excluding the participants' names) are
    passed to the pipeline as one disjunctive ``force_words_ids`` constraint, and
    ``force_word`` (if non-empty) as a ``PhrasalConstraint``.

    USAGE
    -----
    response = constrained_generation("hey man - how have you been lately?",
                                      my_chatbot, tokenizer, verbose=True,
                                      force_word=" meme", num_beams=32)

    Parameters
    ----------
    prompt : str, prompt to use for generation. If it does not mention
        responder_name, a "<responder_name>:" turn header is appended.
    pipeline : transformers text-generation pipeline (accepts return_full_text)
    tokenizer : tokenizer compatible with the pipeline's model, optional,
        default=None (use pipeline.tokenizer). A private copy is used; the
        object passed in is never modified.
    no_repeat_ngram_size : int, optional, default=2
    length_penalty : float, optional, default=0.7
    repetition_penalty : float, optional, default=3.5
    num_beams : int, optional, default=4
    max_generated_tokens : int, optional, default=48, max new tokens
    min_generated_tokens : int, optional, default=2, min new tokens
    timeout : optional, default=300, passed to the pipeline as max_time
    num_return_sequences : int, optional, default=1, only the first is returned
    verbose : bool, optional, default=False, print output
    full_text : bool, optional, default=False, return prompt + response
    force_word : str, optional, default=None, text that must appear in the output
    speaker_name : str, optional, default="Person Alpha", name of speaker
    responder_name : str, optional, default="Person Beta", name of responder
    **kwargs : forwarded unchanged to the pipeline call

    Returns
    -------
    response : str, generated text, or a fixed apology if generation fails
    """
    logging.debug(f" constraining generation with {locals()}")
    st = time.perf_counter()
    # always work on a copy so the attribute change below never leaks into the
    # caller's tokenizer or the pipeline's own tokenizer. (The original also set
    # `add_special_tokens = False`, which shadowed the tokenizer method of that
    # name; special tokens are now disabled per call instead.)
    tokenizer = copy.deepcopy(tokenizer or pipeline.tokenizer)
    tokenizer.add_prefix_space = True

    if responder_name.lower() not in prompt.lower():
        prompt = f"{prompt}\n\n{responder_name}:\n"
    # measured after the turn header is appended: this is the text the pipeline sees
    prompt_length = len(tokenizer(prompt, truncation=True).input_ids)

    force_flexible = _prompt_keywords(prompt, speaker_name, responder_name)
    logging.info(f"found the following keywords: {force_flexible}")
    if verbose:
        print(f"found keywords: {force_flexible}")
        if force_word:
            logging.info(f"forcing the word: {force_word}")

    # an empty force_word would yield an empty PhrasalConstraint, so treat it as absent
    constraints = (
        [PhrasalConstraint(tokenizer(force_word, add_special_tokens=False).input_ids)]
        if force_word
        else None
    )
    # wrapping the per-keyword id lists in one more list makes them a single
    # disjunctive constraint (any one keyword satisfies it), per HF generate docs
    force_words_ids = (
        [tokenizer(force_flexible, add_special_tokens=False).input_ids]
        if force_flexible
        else None
    )

    try:
        logging.info("generating text..")
        result = pipeline(
            prompt,
            constraints=constraints,
            force_words_ids=force_words_ids,
            max_length=None,
            max_new_tokens=max_generated_tokens,
            # for text-generation, min_length counts the prompt tokens whether or
            # not the prompt is echoed back, so the prompt length is always added
            min_length=prompt_length + min_generated_tokens,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            max_time=timeout,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            return_full_text=full_text,
            clean_up_tokenization_spaces=True,
            early_stopping=True,
            do_sample=False,
            **kwargs,
        )
        response = result[0]["generated_text"]
        rt = round((time.perf_counter() - st) / 60, 3)
        logging.info(f"generated response in {rt} minutes")
        if verbose:
            print(f"input prompt:\n\t{prompt}")
            print(f"response:\n\t{response}")
    except Exception as e:
        # deliberate best-effort fallback; keep the traceback in the log
        logging.exception(f"could not generate response: {e}")
        response = "Sorry, I don't know how to respond to that."
    return response