# Quiet-Star-Custom / generate.py
# Author: Crystalcareai — commit f118366 ("Update generate.py", verified), 8.05 kB.
import torch
from transformers.generation.utils import GenerationMixin, validate_stopping_criteria, StoppingCriteriaList
from transformers import TextStreamer
def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    streamer=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    **kwargs,
):
    """Sample up to ``max_new_tokens`` tokens per row with a plain autoregressive loop.

    Each step re-runs the model over every still-unfinished row. It masks the
    logits at or beyond ``self.tokenizer.vocab_size`` (the added start/end
    thought tokens) and draws one token per row from
    ``softmax(logits / temperature)`` with ``torch.multinomial``. Nothing is
    cached between steps.

    Most HuggingFace-style keyword arguments (``do_sample``, ``top_k``,
    ``top_p``, ``num_beams``, ``pad_token_id``, ...) are accepted only for
    signature compatibility and are ignored here. Extra ``**kwargs`` are
    forwarded to the model's forward call.

    Args:
        self: The model. It is called as ``self(ids, attention_mask=..., **kwargs)``
            and must return a mapping with ``'logits'``. It must also expose
            ``tokenizer`` and ``device``.
        input_ids: Prompt token ids, indexed as ``(batch, seq)``. ``None`` or
            an empty tensor is replaced by a single BOS token.
        attention_mask: Optional mask aligned with ``input_ids``. It is
            extended and updated in step with ``input_ids``.
        max_new_tokens: Number of tokens to generate per row. Defaults to 128,
            the same default ``generate`` applies.
        temperature: Softmax temperature. Defaults to 1.0 (logits unscaled).
        streamer: Optional object with ``put``/``end`` methods (e.g.
            ``transformers.TextStreamer``). After each non-final step it
            receives the token sampled for the last active row, so it is
            only meaningful for batch size 1. ``end()`` is called once
            generation stops.
        **kwargs: Forwarded unchanged to the model's forward call.

    Returns:
        LongTensor of shape ``(batch, max_new_tokens)`` with the generated ids.
        Slots after a row finishes keep ``tokenizer.pad_token_id``.
    """
    if max_new_tokens is None:
        # torch.full below cannot size a dimension with None.
        max_new_tokens = 128
    if temperature is None:
        # Dividing logits by None raises TypeError; 1.0 leaves them unscaled.
        temperature = 1.0

    tokenizer = self.tokenizer
    pad_id = tokenizer.pad_token_id

    if input_ids is None or input_ids.nelement() == 0:
        # No prompt: start from a lone BOS token with a fully-on mask.
        input_ids = torch.LongTensor([[tokenizer.bos_token_id]]).to(self.device)
        attention_mask = torch.ones_like(input_ids).to(self.device)

    device = input_ids.device

    # Sampling any of these ids finishes a row. "</s>" is looked up once
    # here rather than for every sampled token.
    stop_ids = {
        tokenizer.eos_token_id,
        tokenizer.bos_token_id,
        pad_id,
        tokenizer.convert_tokens_to_ids("</s>"),
    }

    with torch.no_grad():
        batch_size = input_ids.shape[0]
        finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
        generated_token_ids = torch.full(
            (batch_size, max_new_tokens), pad_id, dtype=torch.long, device=device
        )

        for cur_token_idx in range(max_new_tokens):
            # Snapshot of the rows still generating at the start of this step.
            active = ~finished_generating

            # Sample the next token: full forward pass over the active rows.
            new_ids = self(
                input_ids[active],
                attention_mask=attention_mask[active] if attention_mask is not None else None,
                **kwargs,
            )['logits']

            # Mask out the start and end thought tokens so we don't accidentally sample them.
            new_ids[:, :, tokenizer.vocab_size:] = -float("inf")

            # list_idx indexes the sub-batch fed to the model; answer_idx is the
            # corresponding row of the full input_ids.
            for list_idx, answer_idx in enumerate(active.nonzero(as_tuple=True)[0]):
                base_answer_ids = input_ids[answer_idx]
                new_answer_ids = new_ids[list_idx]
                # Position of the last token in this row that is not padding.
                last_token_idx = (base_answer_ids != pad_id).nonzero(as_tuple=True)[0].max()

                new_ids_sampled = torch.multinomial(
                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1),
                    1,
                )

                if last_token_idx + 1 >= len(base_answer_ids):
                    # No free slot after the last token: widen every row by one pad column.
                    new_padding = torch.full(
                        (batch_size, 1), pad_id, dtype=torch.long, device=device
                    )
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    if attention_mask is not None:
                        attention_mask = torch.cat(
                            [attention_mask, torch.zeros_like(new_padding)], dim=-1
                        )

                # Place the sampled token right after the last real token.
                if attention_mask is not None:
                    attention_mask[answer_idx, last_token_idx + 1] = 1
                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
                generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled

                if int(new_ids_sampled) in stop_ids:
                    finished_generating[answer_idx] = True

            if finished_generating.all():
                break

            if streamer is not None:
                streamer.put(new_ids_sampled)

    if streamer is not None:
        # Flush whatever the streamer still buffers. Without this, a
        # TextStreamer never prints the tail of the generated text.
        streamer.end()

    return generated_token_ids
def generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.1,
    streamer=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    n_ahead=4,
    n_ahead_talk=4,
    merged_talk_heads=True,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **model_kwargs,
):
    """Configure the model's thought settings, then sample with ``custom_generate``.

    The head/thought flags are written onto ``self`` as attributes, and the
    model is put into the fixed inference configuration set below
    (``eval_mode=True``, ``n_passes=1``, start/end thought tokens enabled).
    The prompt is moved to ``self.device`` and everything else is forwarded
    to ``custom_generate``.

    Args:
        input_ids: Prompt token ids (tensor), a raw prompt string (encoded
            with ``self.tokenizer``, for text-generation-webui compatibility),
            or ``None`` (``custom_generate`` then starts from a BOS token).
        attention_mask: Optional mask. It is moved to ``self.device`` when given.
        max_new_tokens: Tokens to generate. Defaults to 128.
        temperature: Sampling temperature forwarded to ``custom_generate``.
        streamer: Optional streamer forwarded to ``custom_generate``.
        n_ahead, n_ahead_talk: Combined into ``self.max_thoughts`` as
            ``n_ahead + n_ahead_talk + 1``. Only ``n_ahead`` is also stored
            on the model. NOTE(review): confirm that not storing
            ``n_ahead_talk`` is intended.
        merged_talk_heads ... use_weighted_talk_head: Copied verbatim onto
            the model as same-named attributes.
        trust_remote_code, torch_dtype: Accepted for compatibility but not
            used by this function.
        **model_kwargs: Forwarded through ``custom_generate`` to the model's
            forward call.

    Returns:
        Whatever ``custom_generate`` returns: a tensor of generated ids.
    """
    if max_new_tokens is None:
        max_new_tokens = 128

    # Set model attributes
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
    self.merged_lm_and_think_heads = merged_lm_and_think_heads
    self.use_concat_talk_head = use_concat_talk_head
    self.use_shallow_think = use_shallow_think
    self.use_shallow_talk = use_shallow_talk
    self.use_complex_think_head = use_complex_think_head
    self.use_complex_talk_head = use_complex_talk_head
    self.use_weighted_talk_head = use_weighted_talk_head

    # Set model properties
    self.use_end_thought_token = True
    self.use_start_thought_token = True
    self.n_ahead = n_ahead
    self.n_passes = 1
    self.eval_mode = True
    self.first_run = False
    self.rm_initialized = True
    self.original_mode = False

    # Check if the input is a string (for compatibility with text-generation-webui)
    if isinstance(input_ids, str):
        input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')

    # Move tensors to the model's device. custom_generate builds its own BOS
    # prompt when input_ids is None, so only move it when it is present;
    # calling .to() on None would raise before that fallback is reached.
    if input_ids is not None:
        input_ids = input_ids.to(self.device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(self.device)

    generated_token_ids = custom_generate(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        do_sample=do_sample,
        early_stopping=early_stopping,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        bad_words_ids=bad_words_ids,
        bos_token_id=bos_token_id,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=num_return_sequences,
        decoder_start_token_id=decoder_start_token_id,
        use_cache=use_cache,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_scores=output_scores,
        return_dict_in_generate=return_dict_in_generate,
        forced_bos_token_id=forced_bos_token_id,
        forced_eos_token_id=forced_eos_token_id,
        remove_invalid_values=remove_invalid_values,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )
    return generated_token_ids