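"""Batched sampling loop exposed through a `generate` wrapper with the usual
Hugging Face `generate` signature.

`custom_generate` decodes a batch token by token: at each step it runs the
unfinished sequences through the model, masks out the extra thought tokens that
live beyond the base tokenizer vocabulary, samples one token per sequence from a
temperature-scaled softmax at that sequence's last non-padding position, and
marks a sequence finished when it produces EOS, BOS, PAD, or the
`<|/assistant|>` token. `generate` is a thin wrapper that accepts the standard
generation keyword arguments (most of which the loop currently ignores) and
forwards them to `custom_generate`.
"""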
import torch
from transformers.utils import logging
from transformers.generation.utils import (
    GenerationMixin,
    validate_stopping_criteria,
    StoppingCriteriaList,
)

logger = logging.get_logger(__name__)


def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    **kwargs,
):
    with torch.no_grad():
        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
        while not finished_generating.all() and input_ids.shape[1] < max_length:
            # Forward pass over the sequences that are still generating
            new_ids = self(
                input_ids[~finished_generating],
                attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
                **kwargs
            )['logits']
            # Mask out the start and end thought tokens so we don't accidentally sample them
            new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
                # Find the index of the last token that is not padding
                base_answer_ids = input_ids[answer_idx]
                new_answer_ids = new_ids[list_idx]
                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
                # Sample the next token from the temperature-scaled distribution at that position
                new_ids_sampled = torch.multinomial(
                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
                # If the sequence has no room left, grow every row of the batch by one padding token
                if last_token_idx + 1 >= len(base_answer_ids):
                    new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
                                             device=input_ids.device)
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    if attention_mask is not None:
                        attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
                # Write the sampled token after the last non-padding token and extend the attention mask
                if attention_mask is not None:
                    attention_mask[answer_idx, last_token_idx + 1] = 1
                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
                # Stop this sequence on EOS, BOS, or PAD
                if (new_ids_sampled == self.tokenizer.eos_token_id
                        or new_ids_sampled == self.tokenizer.bos_token_id
                        or new_ids_sampled == self.tokenizer.pad_token_id):
                    finished_generating[answer_idx] = 1
                # Stop this sequence once the end-of-assistant token is generated
                if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
                    finished_generating[answer_idx] = 1
    return input_ids, attention_mask


def generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    **model_kwargs,
):
    return custom_generate(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        early_stopping=early_stopping,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        bad_words_ids=bad_words_ids,
        bos_token_id=bos_token_id,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=num_return_sequences,
        decoder_start_token_id=decoder_start_token_id,
        use_cache=use_cache,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_scores=output_scores,
        return_dict_in_generate=return_dict_in_generate,
        forced_bos_token_id=forced_bos_token_id,
        forced_eos_token_id=forced_eos_token_id,
        remove_invalid_values=remove_invalid_values,
        synced_gpus=synced_gpus,
        **model_kwargs,
    )
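

# The block below is an illustrative sketch of how this sampling loop could be
# wired up; it is not part of the original module. The checkpoint path and the
# use of `types.MethodType` are assumptions. The only hard requirements visible
# above are that the model exposes a `tokenizer` attribute and that `max_length`
# is passed, since the loop's stopping condition uses it directly.
if __name__ == "__main__":
    import types

    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "path/to/checkpoint"  # placeholder, not a real model id
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    model.tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Bind the custom loop to this instance so `model.generate(...)` uses it.
    model.generate = types.MethodType(generate, model)

    inputs = model.tokenizer(["The capital of France is"], return_tensors="pt")
    output_ids, output_mask = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=64,
        temperature=0.7,
    )
    print(model.tokenizer.decode(output_ids[0], skip_special_tokens=True))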