Quiet-Star-Custom / generate.py
Crystalcareai's picture
Update generate.py
8bac0a3 verified
raw
history blame
5.62 kB
import torch
from transformers.utils import logging
from transformers.generation.utils import (
GenerationMixin,
validate_stopping_criteria,
StoppingCriteriaList,
)
# Module-level logger from transformers' logging utilities.
# NOTE(review): not referenced by the functions visible in this section.
logger = logging.get_logger(__name__)
def _sample_next_token(logits, temperature):
    """Choose one token id from a 1-D vector of next-token logits.

    Args:
        logits: 1-D float tensor of scores over the vocabulary.
        temperature: softmax temperature. Values <= 0 select the argmax
            (greedy) instead of dividing the logits by zero.

    Returns:
        The chosen token id as a Python ``int``.
    """
    if temperature <= 0:
        return int(torch.argmax(logits))
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, 1))


def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    **kwargs,
):
    """Extend each prompt in ``input_ids`` token by token until every row stops.

    Requirements on ``self``:
        - It must be callable as ``self(ids, attention_mask=..., **kwargs)``.
        - The call must return a mapping whose ``"logits"`` entry is indexed
          ``[row, position, token]``.
        - It must expose a ``tokenizer`` attribute with ``vocab_size``,
          ``pad_token_id``, ``eos_token_id``, ``bos_token_id`` and
          ``convert_tokens_to_ids``.

    How generation proceeds:
        - Each iteration runs the model on the rows that have not finished.
        - For each such row, one token is sampled from the logits at the
          row's last non-padding position.
        - The token is written immediately after that position.
        - When a row has no room left, one pad column is appended to the
          whole batch (and a zero column to ``attention_mask``).
        - A row finishes when it samples the tokenizer's EOS, BOS or pad id,
          or the ``"<|/assistant|>"`` token.
        - The loop ends when every row has finished or the batch width
          reaches ``max_length``.

    Which arguments matter:
        - Only ``input_ids``, ``attention_mask``, ``max_length``,
          ``temperature`` and ``kwargs["max_new_tokens"]`` affect generation.
        - The other named parameters are accepted for signature compatibility
          with Hugging Face ``generate`` and are ignored.
        - This includes the ``*_token_id`` parameters; those ids are read
          from ``self.tokenizer`` instead.
        - Remaining ``kwargs`` are passed through to the model call.

    Args:
        input_ids: 2-D integer tensor of prompts. Assumed right-padded with
            ``self.tokenizer.pad_token_id`` (NOTE(review): confirm callers
            never left-pad).
        attention_mask: Optional mask with the same shape as ``input_ids``.
            A copy is extended and updated alongside the generated tokens.
        max_length: Generation stops once the padded width reaches this.
            If None, the fallback order is:
            ``input_ids.shape[1] + max_new_tokens`` when ``max_new_tokens``
            is given, then ``self.generation_config.max_length`` if present,
            then 20.
        temperature: Softmax temperature. None means 1.0; values <= 0 mean
            greedy (argmax) decoding.

    Returns:
        ``(input_ids, attention_mask)`` as new tensors. The caller's tensors
        are not modified. ``attention_mask`` is None if none was passed.
    """
    # ``max_new_tokens`` is a generation option, not a forward() argument,
    # so it must not be forwarded to the model with the rest of ``kwargs``.
    max_new_tokens = kwargs.pop("max_new_tokens", None)
    if max_length is None:
        if max_new_tokens is not None:
            # Measured from the padded width, so shorter (right-padded) rows
            # may receive more than ``max_new_tokens`` tokens.
            max_length = input_ids.shape[1] + max_new_tokens
        else:
            generation_config = getattr(self, "generation_config", None)
            max_length = getattr(generation_config, "max_length", None) or 20
    if temperature is None:
        temperature = 1.0

    tokenizer = self.tokenizer
    vocab_size = tokenizer.vocab_size
    # Named pad_id so it does not shadow the (ignored) pad_token_id parameter.
    pad_id = tokenizer.pad_token_id
    # Sampling any of these ids ends a row; looked up once, not per token.
    stop_ids = {
        tokenizer.eos_token_id,
        tokenizer.bos_token_id,
        pad_id,
        tokenizer.convert_tokens_to_ids("<|/assistant|>"),
    }
    stop_ids.discard(None)

    # Work on copies so the caller's tensors are never modified in place.
    input_ids = input_ids.clone()
    if attention_mask is not None:
        attention_mask = attention_mask.clone()

    with torch.no_grad():
        finished = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
        while not finished.all() and input_ids.shape[1] < max_length:
            active = ~finished
            logits = self(
                input_ids[active],
                attention_mask=attention_mask[active] if attention_mask is not None else None,
                **kwargs,
            )["logits"]
            # Mask out ids at or above the tokenizer's base vocab_size (the
            # start/end thought tokens) so they are never sampled.
            # NOTE(review): if "<|/assistant|>" is also an added id above
            # vocab_size, it is masked here too and cannot trigger a stop.
            logits[:, :, vocab_size:] = -float("inf")
            active_rows = active.nonzero(as_tuple=True)[0].tolist()
            for batch_pos, row in enumerate(active_rows):
                # Right padding: the new token goes right after the last
                # non-pad token of this row.
                last_idx = int((input_ids[row] != pad_id).nonzero(as_tuple=True)[0].max())
                next_token = _sample_next_token(logits[batch_pos, last_idx], temperature)
                write_idx = last_idx + 1
                if write_idx >= input_ids.shape[1]:
                    # No room left in this row: widen the whole batch by one pad column.
                    new_padding = torch.full(
                        (input_ids.shape[0], 1), pad_id, dtype=torch.long, device=input_ids.device
                    )
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    if attention_mask is not None:
                        attention_mask = torch.cat(
                            [attention_mask, torch.zeros_like(new_padding)], dim=-1
                        )
                if attention_mask is not None:
                    attention_mask[row, write_idx] = 1
                input_ids[row, write_idx] = next_token
                if next_token in stop_ids:
                    finished[row] = True
    return input_ids, attention_mask
def generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    **model_kwargs,
):
    """Hugging Face-style ``generate`` entry point that delegates to ``custom_generate``.

    Every named argument and every extra keyword argument is passed through
    unchanged, by keyword. The result of ``custom_generate`` is returned as-is.
    """
    # Named generation options, gathered so the delegating call stays short.
    forwarded_options = dict(
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        early_stopping=early_stopping,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        bad_words_ids=bad_words_ids,
        bos_token_id=bos_token_id,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=num_return_sequences,
        decoder_start_token_id=decoder_start_token_id,
        use_cache=use_cache,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_scores=output_scores,
        return_dict_in_generate=return_dict_in_generate,
        forced_bos_token_id=forced_bos_token_id,
        forced_eos_token_id=forced_eos_token_id,
        remove_invalid_values=remove_invalid_values,
        synced_gpus=synced_gpus,
    )
    return custom_generate(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        **forwarded_options,
        **model_kwargs,
    )