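# Quiet-Star-Custom / generate.py
# Source: Hugging Face file view (uploader avatar text: "Crystalcareai's picture")
# Commit message: "Update generate.py"
# Commit: b3f07c3 (verified)
# Page links: raw | history | blame
# File size: 7.32 kB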
import torch
from transformers.utils import logging
from transformers.generation.utils import (
GenerationMixin,
validate_stopping_criteria,
StoppingCriteriaList,
)
# Module-level logger, created through the transformers logging utilities
# imported above and named after this module.
logger = logging.get_logger(__name__)
def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    **kwargs,
):
    """Sample one token per step for each unfinished row until all rows stop.

    Each step runs the model on the unfinished rows, masks every logit at or
    beyond ``self.tokenizer.vocab_size``, and samples the next token from the
    logits at each row's last non-padding position. The sampled token is
    written right after that position; the batch is widened by one padding
    column whenever a row has no room left.

    A row is finished once it samples the tokenizer's EOS, BOS or PAD id, or
    the id of the ``"<|/assistant|>"`` token.

    Args:
        self: Callable model exposing ``self.tokenizer``; calling it must
            return a mapping with a ``"logits"`` entry.
        input_ids: 2-D tensor of token ids (rows are sequences). It is not
            modified; generation works on a copy.
        attention_mask: Optional mask aligned with ``input_ids``. It is not
            modified; generation works on a copy.
        max_length: Maximum total sequence length. When ``None``, falls back
            to ``self.generation_config.max_length`` and then
            ``self.config.max_length``.
        temperature: Divisor applied to the logits before the softmax.
            ``0`` selects the highest-scoring token instead of sampling.
        **kwargs: Forwarded unchanged to every model call.

    All other named parameters are accepted only for signature compatibility
    and are not read by this function.

    Returns:
        The final token ids as a nested Python list (``input_ids.tolist()``).

    Raises:
        ValueError: If no maximum length is given and none can be found on
            the model.
    """
    if max_length is None:
        # The loop condition compares against max_length, so None would fail
        # with an opaque TypeError; use the model's configured limit instead.
        for holder_name in ("generation_config", "config"):
            max_length = getattr(getattr(self, holder_name, None), "max_length", None)
            if max_length is not None:
                break
        if max_length is None:
            raise ValueError(
                "max_length must be provided or be set on the model's "
                "generation_config/config"
            )

    with torch.no_grad():
        tokenizer = self.tokenizer
        pad_id = tokenizer.pad_token_id
        # Stop ids never change during generation; resolve them once instead
        # of once per row per step.
        stop_ids = {
            tokenizer.eos_token_id,
            tokenizer.bos_token_id,
            pad_id,
            tokenizer.convert_tokens_to_ids("<|/assistant|>"),
        }

        # Tokens are written in place below; copy so the caller's tensors
        # are left untouched.
        input_ids = input_ids.clone()
        if attention_mask is not None:
            attention_mask = attention_mask.clone()

        finished_generating = torch.zeros(
            len(input_ids), dtype=torch.bool, device=input_ids.device
        )

        while not finished_generating.all() and input_ids.shape[1] < max_length:
            active = ~finished_generating
            logits = self(
                input_ids[active],
                attention_mask=attention_mask[active] if attention_mask is not None else None,
                **kwargs,
            )["logits"]

            # Mask out the start and end thought tokens so we don't
            # accidentally sample them.
            logits[:, :, tokenizer.vocab_size:] = -float("inf")

            for list_idx, answer_idx in enumerate(active.nonzero(as_tuple=True)[0]):
                # Re-read the row every time: an earlier row in this step may
                # have widened input_ids.
                row_ids = input_ids[answer_idx]
                # Index of the last token that is not padding.
                last_token_idx = (row_ids != pad_id).nonzero(as_tuple=True)[0].max()
                next_logits = logits[list_idx][last_token_idx]

                if temperature == 0:
                    # Dividing by zero would poison the softmax; a zero
                    # temperature is treated as greedy decoding.
                    next_token = next_logits.argmax(dim=-1)
                else:
                    next_token = torch.multinomial(
                        torch.nn.functional.softmax(next_logits / temperature, dim=-1), 1
                    )[0]

                write_pos = last_token_idx + 1
                if write_pos >= len(row_ids):
                    # No free slot left in this row: add one padding column
                    # to every row (masked out in the attention mask).
                    pad_column = torch.full(
                        (len(input_ids), 1),
                        pad_id,
                        dtype=torch.long,
                        device=input_ids.device,
                    )
                    input_ids = torch.cat([input_ids, pad_column], dim=-1)
                    if attention_mask is not None:
                        attention_mask = torch.cat(
                            [attention_mask, torch.zeros_like(pad_column)], dim=-1
                        )

                if attention_mask is not None:
                    attention_mask[answer_idx, write_pos] = 1
                input_ids[answer_idx, write_pos] = next_token

                if next_token.item() in stop_ids:
                    finished_generating[answer_idx] = True

        return input_ids.tolist()
def generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    n_ahead=4,
    n_ahead_talk=4,
    merged_talk_heads=True,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **model_kwargs,
):
    """Configure the model's thought/talk settings, then run ``custom_generate``.

    The head-related keyword arguments (``n_ahead`` through
    ``use_weighted_talk_head``) are stored as attributes on ``self``, together
    with a fixed set of evaluation-time flags. All standard generation
    arguments and ``**model_kwargs`` are then forwarded unchanged to
    ``custom_generate``.

    ``trust_remote_code`` and ``torch_dtype`` are accepted but not read here;
    naming them in the signature keeps them out of ``model_kwargs`` and
    therefore out of the forwarded arguments.

    Returns:
        Whatever ``custom_generate`` returns for the given inputs.
    """
    # Thought/talk head configuration supplied by the caller.
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
    self.merged_lm_and_think_heads = merged_lm_and_think_heads
    self.use_concat_talk_head = use_concat_talk_head
    self.use_shallow_think = use_shallow_think
    self.use_shallow_talk = use_shallow_talk
    self.use_complex_think_head = use_complex_think_head
    self.use_complex_talk_head = use_complex_talk_head
    self.use_weighted_talk_head = use_weighted_talk_head

    # Fixed settings applied on every call.
    self.use_end_thought_token = True
    self.use_start_thought_token = True
    self.wandb_enabled = True
    self.n_ahead = n_ahead
    self.n_passes = 1
    self.eval_mode = True
    self.first_run = False
    self.kill_after = 100
    self.rm_initialized = True
    self.original_mode = False

    # NOTE: an earlier revision called validate_stopping_criteria() here with
    # keyword arguments that function does not accept, which raised TypeError
    # before any token was generated, and its result was never used. The
    # length limit is enforced inside custom_generate via max_length, so the
    # call has been removed.

    return custom_generate(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        early_stopping=early_stopping,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        bad_words_ids=bad_words_ids,
        bos_token_id=bos_token_id,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=num_return_sequences,
        decoder_start_token_id=decoder_start_token_id,
        use_cache=use_cache,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_scores=output_scores,
        return_dict_in_generate=return_dict_in_generate,
        forced_bos_token_id=forced_bos_token_id,
        forced_eos_token_id=forced_eos_token_id,
        remove_invalid_values=remove_invalid_values,
        synced_gpus=synced_gpus,
        **model_kwargs,
    )