|
import torch |
|
from transformers.utils import logging |
|
from transformers.generation.utils import ( |
|
GenerationMixin, |
|
validate_stopping_criteria, |
|
StoppingCriteriaList, |
|
) |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    **kwargs,
):
    """Sample one token per step for every unfinished row of a batch.

    At each step the model is run on the rows that are still generating, a
    token is drawn from the temperature-scaled softmax at each row's last
    non-pad position, and the token is written into the slot right after
    that position. A pad column is appended to the whole batch whenever a
    row has no free slot left. A row stops once it samples the tokenizer's
    EOS, BOS or pad id, or the id of ``"<|/assistant|>"``.

    Only ``input_ids``, ``attention_mask``, ``max_length``, ``temperature``
    and ``**kwargs`` are read. Every other named parameter is accepted so the
    signature lines up with ``generate`` and is otherwise ignored.

    Args:
        self: Callable model returning a mapping with a ``"logits"`` entry;
            must expose ``self.tokenizer`` with ``pad_token_id``,
            ``eos_token_id``, ``bos_token_id``, ``vocab_size`` and
            ``convert_tokens_to_ids``.
        input_ids: 2-D tensor of token ids. Assumed to be right-padded with
            ``tokenizer.pad_token_id`` and to contain at least one non-pad
            token per row -- TODO confirm with callers.
        attention_mask: Optional mask aligned with ``input_ids``; it is
            extended and updated alongside the ids.
        max_length: Generation stops once the sequence dimension reaches this
            width. When ``None`` the value falls back to
            ``self.generation_config.max_length`` if that attribute exists;
            if no limit can be resolved, generation runs until every row has
            sampled a stop token.
        temperature: Divisor applied to the logits before the softmax.
        **kwargs: Forwarded verbatim to the model call.

    Returns:
        A ``(input_ids, attention_mask)`` tuple holding the extended tensors.
        The tensors passed in by the caller are left unmodified;
        ``attention_mask`` is ``None`` when none was given.
    """
    tokenizer = self.tokenizer
    pad_id = tokenizer.pad_token_id

    # A ``None`` limit used to crash on ``int < None``; resolve it instead.
    if max_length is None:
        max_length = getattr(
            getattr(self, "generation_config", None), "max_length", None
        )

    # Loop-invariant stop ids, looked up once instead of per row per step.
    stop_ids = (
        tokenizer.eos_token_id,
        tokenizer.bos_token_id,
        pad_id,
        tokenizer.convert_tokens_to_ids("<|/assistant|>"),
    )

    with torch.no_grad():
        # Work on copies: the loop writes in place, and the caller's tensors
        # would otherwise be mutated only until the first pad column is added.
        input_ids = input_ids.clone()
        if attention_mask is not None:
            attention_mask = attention_mask.clone()

        finished = torch.zeros(
            input_ids.shape[0], dtype=torch.bool, device=input_ids.device
        )

        while not finished.all() and (
            max_length is None or input_ids.shape[1] < max_length
        ):
            active = ~finished

            # Forward pass restricted to the rows that are still generating.
            logits = self(
                input_ids[active],
                attention_mask=(
                    attention_mask[active] if attention_mask is not None else None
                ),
                **kwargs,
            )["logits"]

            # Ids at or beyond tokenizer.vocab_size must never be sampled.
            # NOTE(review): Hugging Face tokenizers usually report vocab_size
            # without added tokens; if "<|/assistant|>" is an added token it
            # is masked here and its stop condition can never fire -- confirm.
            logits[:, :, tokenizer.vocab_size:] = -float("inf")

            for active_pos, row in enumerate(active.nonzero(as_tuple=True)[0]):
                row_ids = input_ids[row]
                # Position of the last real (non-pad) token in this row.
                last_idx = (row_ids != pad_id).nonzero(as_tuple=True)[0].max()

                probs = torch.nn.functional.softmax(
                    logits[active_pos][last_idx] / temperature, dim=-1
                )
                sampled = torch.multinomial(probs, 1)

                # No free slot after the last token: widen the whole batch.
                if last_idx + 1 >= len(row_ids):
                    pad_column = torch.full(
                        (input_ids.shape[0], 1),
                        pad_id,
                        dtype=torch.long,
                        device=input_ids.device,
                    )
                    input_ids = torch.cat([input_ids, pad_column], dim=-1)
                    if attention_mask is not None:
                        attention_mask = torch.cat(
                            [attention_mask, torch.zeros_like(pad_column)], dim=-1
                        )

                if attention_mask is not None:
                    attention_mask[row, last_idx + 1] = 1
                input_ids[row, last_idx + 1] = sampled

                # One host read of the sampled id covers all stop checks.
                if sampled.item() in stop_ids:
                    finished[row] = True

        return input_ids, attention_mask
|
|
|
def generate(
    self,
    input_ids,
    attention_mask=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    **model_kwargs,
):
    """Forward every argument unchanged to ``custom_generate``.

    All named parameters are passed through by keyword under the same name,
    followed by any extra ``**model_kwargs``.

    Returns:
        Whatever ``custom_generate`` returns for these arguments.
    """
    # Collect the named options first, in signature order.
    forwarded = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        early_stopping=early_stopping,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        bad_words_ids=bad_words_ids,
        bos_token_id=bos_token_id,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=num_return_sequences,
        decoder_start_token_id=decoder_start_token_id,
        use_cache=use_cache,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_scores=output_scores,
        return_dict_in_generate=return_dict_in_generate,
        forced_bos_token_id=forced_bos_token_id,
        forced_eos_token_id=forced_eos_token_id,
        remove_invalid_values=remove_invalid_values,
        synced_gpus=synced_gpus,
    )
    # Extra keys cannot collide with the named parameters above, because
    # Python binds those names before anything reaches **model_kwargs.
    forwarded.update(model_kwargs)
    return custom_generate(self, **forwarded)