Spaces:
Runtime error
Runtime error
""" | |
constrained_generation.py - use constrained beam search to generate text from a model with entered constraints | |
""" | |
import copy | |
import logging | |
logging.basicConfig(level=logging.INFO) | |
import time | |
from pathlib import Path | |
import yake | |
from transformers import AutoTokenizer, PhrasalConstraint | |
def get_tokenizer(model_name="gpt2", verbose=False):
    """
    get_tokenizer - load a tokenizer via AutoTokenizer and make it pad with EOS

    :param model_name: model id or path handed to AutoTokenizer.from_pretrained, default gpt2
    :param verbose: if True, print a short message once the tokenizer is loaded
    :return: the loaded tokenizer, with pad_token set to its eos_token
    """
    load_options = dict(add_special_tokens=False, padding=True, truncation=True)
    tok = AutoTokenizer.from_pretrained(model_name, **load_options)
    # Reuse the EOS token for padding (the tokenizer may not define a pad token).
    tok.pad_token = tok.eos_token
    if verbose:
        print(f"loaded tokenizer {model_name}")
    return tok
def unique_words(list_of_strings):
    """
    unique_words - keep the strings whose words do not overlap any previously kept string.

    Strings are scanned in order. A string is kept only if it does not repeat a
    whitespace-separated word itself and none of its words already occurs in a
    string that was kept earlier. Only kept strings contribute to the set of
    "seen" words, so a rejected string never blocks a later one.

    >>> unique_words(["new york", "york city", "city hall"])
    ['new york', 'city hall']

    :param list_of_strings: iterable of strings (e.g. keyword phrases)
    :return: list of the kept strings, in their original order
    """
    seen_words = set()
    output_list = []
    for string in list_of_strings:
        words = string.split()
        word_set = set(words)
        # Reject phrases that repeat a word internally or overlap a kept phrase.
        if len(word_set) != len(words) or not seen_words.isdisjoint(word_set):
            continue
        seen_words.update(word_set)
        output_list.append(string)
    return output_list
def create_kw_extractor(
    language="en",
    max_ngram_size=3,
    deduplication_algo="seqm",
    windowSize=10,
    numOfKeywords=10,
    ddpt=0.7,
):
    """
    creates a YAKE keyword extractor object
    :param language: language of the text (passed as ``lan``)
    :param max_ngram_size: max ngram size (passed as ``n``)
    :param deduplication_algo: deduplication algorithm (passed as ``dedupFunc``)
    :param windowSize: window size (passed as ``windowsSize``)
    :param numOfKeywords: number of keywords (passed as ``top``)
    :param ddpt: Deduplication Percentage Threshold (passed as ``dedupLim``), must be in [0, 1]
    :raises ValueError: if ddpt is outside [0, 1]
    :return: keyword extractor object (yake.KeywordExtractor)
    """
    # Explicit check rather than ``assert`` so validation still runs under ``python -O``.
    if not 0 <= ddpt <= 1:
        raise ValueError(f"need 0<=thresh<=1, got {ddpt}")
    return yake.KeywordExtractor(
        lan=language,
        n=max_ngram_size,
        dedupLim=ddpt,
        dedupFunc=deduplication_algo,
        windowsSize=windowSize,
        top=numOfKeywords,
        features=None,
    )
def simple_kw(body_text: str, yake_ex=None, max_kw=15, verbose=False):
    """
    simple_kw - extract keywords from a text using yake
    Args:
        body_text (str): text to extract keywords from
        yake_ex (yake.KeywordExtractor, optional): keyword extractor to use. Defaults to None,
            in which case one is built with create_kw_extractor(max_ngram_size=2, ddpt=0.9,
            windowSize=10, deduplication_algo="seqm", numOfKeywords=max_kw).
        max_kw (int, optional): maximum number of keywords to return. Defaults to 15.
        verbose (bool, optional): also print the keywords to stdout. Defaults to False.
    Returns:
        list: at most max_kw lowercased keyword strings, in the order the extractor returned them
    """
    if yake_ex is None:
        yake_ex = create_kw_extractor(
            max_ngram_size=2,
            ddpt=0.9,
            windowSize=10,
            deduplication_algo="seqm",
            numOfKeywords=max_kw,
        )  # per optuna study
    # Each extractor result is indexed as kw[0] for the keyword text.
    keywords_list = [str(kw[0]).lower() for kw in yake_ex.extract_keywords(body_text)]
    # Slice again in case a caller-supplied extractor returns more than max_kw.
    top_keywords = keywords_list[:max_kw]
    logging.info(
        "YAKE: found %s keywords, the top %s are: %s",
        len(keywords_list),
        max_kw,
        top_keywords,
    )
    if verbose:
        print(f"found {len(keywords_list)} keywords, the top {max_kw} are:")
        print(top_keywords)
    return top_keywords
def _prompt_keywords(prompt, speaker_name, responder_name):
    """Return non-overlapping keyword phrases of `prompt`, minus any containing a name word."""
    key_prompt_phrases = simple_kw(prompt)
    try:
        name_words = responder_name.lower().split() + speaker_name.lower().split()
    except AttributeError as e:
        # e.g. a name passed as None: skip name filtering rather than fail.
        name_words = []
        logging.info("could not split names: %s", e)
    # Substring match: drop any keyword that contains a word of either name.
    key_prompt_phrases = [
        p for p in key_prompt_phrases if not any(name in p for name in name_words)
    ]
    return unique_words(key_prompt_phrases)


def constrained_generation(
    prompt: str,
    pipeline,
    tokenizer=None,
    no_repeat_ngram_size=2,
    length_penalty=0.7,
    repetition_penalty=3.5,
    num_beams=4,
    max_generated_tokens=48,
    min_generated_tokens=2,
    timeout=300,
    num_return_sequences=1,
    verbose=False,
    full_text=False,
    force_word: str = None,
    speaker_name: str = "Person Alpha",
    responder_name: str = "Person Beta",
    **kwargs,
):
    """
    constrained_generation - generate text based on prompt and constraints

    Keywords are extracted from the prompt (simple_kw). Keywords containing any
    word of the speaker/responder names are dropped, and the remaining
    non-overlapping phrases (unique_words) are tokenized and passed to the
    pipeline as a single ``force_words_ids`` entry. If ``force_word`` is given,
    it is passed as a ``PhrasalConstraint``. Decoding uses ``do_sample=False``
    with ``num_beams`` beams.

    USAGE
    -----
    response = constrained_generation("hey man - how have you been lately?",
                                      my_chatbot, tokenizer, verbose=True,
                                      force_word=" meme", num_beams=32)

    Parameters
    ----------
    prompt : str, prompt to use for generation. If ``responder_name`` does not
        appear in it (case-insensitive), "\\n\\n{responder_name}:\\n" is appended.
    pipeline : transformers.pipeline, called with the prompt and the generation
        kwargs below; must be compatible with the tokenizer
    tokenizer : transformers.PreTrainedTokenizer, optional, default=None
        (uses ``pipeline.tokenizer``). A private deep copy is always used, so the
        caller's tokenizer is never modified.
    no_repeat_ngram_size : int, optional, default=2
    length_penalty : float, optional, default=0.7
    repetition_penalty : float, optional, default=3.5
    num_beams : int, optional, default=4
    max_generated_tokens : int, optional, default=48, forwarded as ``max_new_tokens``
    min_generated_tokens : int, optional, default=2, forwarded as ``min_length``
        (plus the prompt's token count when ``full_text`` is True)
    timeout : optional, default=300, forwarded as ``max_time``
    num_return_sequences : int, optional, default=1 (only the first is returned)
    verbose : bool, optional, default=False, print keywords, prompt and response
    full_text : bool, optional, default=False, forwarded as ``return_full_text``
    force_word : str, optional, default=None, phrase forced into the generation
    speaker_name : str, optional, default="Person Alpha", name of speaker
    responder_name : str, optional, default="Person Beta", name of responder
    **kwargs : forwarded to the pipeline call

    Returns
    -------
    response : str, ``generated_text`` of the first result, or a fixed apology
        string if the pipeline call fails
    """
    logging.debug(" constraining generation with %s", locals())
    st = time.perf_counter()
    # Always work on a private copy so the attribute change below does not leak
    # into the caller's tokenizer or the pipeline's own tokenizer.
    tokenizer = copy.deepcopy(tokenizer if tokenizer is not None else pipeline.tokenizer)
    tokenizer.add_prefix_space = True

    if responder_name.lower() not in prompt.lower():
        prompt = f"{prompt}\n\n{responder_name}:\n"
    # Measure the prompt *after* the responder cue is appended: that is the text
    # actually sent to the pipeline, and min_length below relies on it.
    prompt_length = len(tokenizer(prompt, truncation=True).input_ids)

    force_flexible = _prompt_keywords(prompt, speaker_name, responder_name)
    logging.info("found the following keywords: %s", force_flexible)
    if verbose:
        print(f"found keywords: {force_flexible}")
        if force_word is not None:
            logging.info("forcing the word: %s", force_word)
    if not force_flexible:
        force_flexible = None

    constraints = (
        [PhrasalConstraint(tokenizer(force_word, add_special_tokens=False).input_ids)]
        if force_word is not None
        else None
    )
    # One entry holding a list of token-id lists (one per keyword phrase);
    # special tokens are excluded, matching the force_word path above.
    force_words_ids = (
        [tokenizer(force_flexible, add_special_tokens=False).input_ids]
        if force_flexible is not None
        else None
    )
    # NOTE(review): for decoder-only models min_length may also count prompt
    # tokens, in which case the full_text=False branch enforces little — confirm.
    min_length = (
        min_generated_tokens + prompt_length if full_text else min_generated_tokens
    )

    try:
        logging.info("generating text..")
        result = pipeline(
            prompt,
            constraints=constraints,
            force_words_ids=force_words_ids,
            max_length=None,
            max_new_tokens=max_generated_tokens,
            min_length=min_length,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            max_time=timeout,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            return_full_text=full_text,
            clean_up_tokenization_spaces=True,
            early_stopping=True,
            do_sample=False,
            **kwargs,
        )
        response = result[0]["generated_text"]
    except Exception as e:
        # Deliberate best-effort fallback; keep the traceback for debugging.
        logging.exception("could not generate response: %s", e)
        return "Sorry, I don't know how to respond to that."

    rt = round((time.perf_counter() - st) / 60, 3)
    logging.info("generated response in %s minutes", rt)
    if verbose:
        print(f"input prompt:\n\t{prompt}")
        print(f"response:\n\t{response}")
    return response