import torch
from transformers import LogitsProcessorList, StoppingCriteriaList


def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    temperature=1.0,
    do_sample=True,
    pad_token_id=None,
    eos_token_id=None,
    **kwargs
):
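    """Minimal token-by-token generation loop for a Hugging Face causal LM.

    Runs a full forward pass every step (the KV cache is not threaded back
    in), then samples or greedily picks the next token until every sequence
    has emitted `eos_token_id` or the `max_new_tokens` budget is exhausted.
    """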
    # Empty lists are no-op extension points; callers can append custom
    # processors or stopping criteria before the loop runs.
    logits_processor = LogitsProcessorList()
    stopping_criteria = StoppingCriteriaList()

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # 1 marks a sequence that has not yet produced EOS.
    unfinished_sents = input_ids.new_ones(input_ids.shape[0])

    cur_len = input_ids.shape[1]
    # The budget counts *new* tokens, so the stopping length is the prompt
    # length plus max_new_tokens (falling back to a small default if unset).
    max_length = cur_len + (max_new_tokens if max_new_tokens is not None else 20)

    while cur_len < max_length:
        # Full forward pass each step; use_cache=True computes past_key_values
        # but they are not fed back in, so the prefix is recomputed every time.
        model_outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            return_dict=True,
        )

        next_token_logits = model_outputs.logits[:, -1, :]
        next_token_logits = logits_processor(input_ids, next_token_logits)

        # Mask only the pad token; EOS must stay reachable or the
        # finished-sequence check below can never fire. Skip the mask when
        # pad and EOS share an id (e.g. GPT-2).
        if pad_token_id is not None and pad_token_id != eos_token_id:
            next_token_logits[:, pad_token_id] = -float("inf")

        if do_sample:
            # Temperature-scaled multinomial sampling.
            probabilities = torch.nn.functional.softmax(
                next_token_logits / temperature, dim=-1
            )
            next_token = torch.multinomial(probabilities, num_samples=1).squeeze(1)
        else:
            # Greedy decoding.
            next_token = next_token_logits.argmax(dim=-1)

        # Finished sequences emit pad instead of a freshly generated token.
        if eos_token_id is not None and pad_token_id is not None:
            next_token = next_token * unfinished_sents + pad_token_id * (
                1 - unfinished_sents
            )

        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=1)
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
            dim=1,
        )

        # Mark sequences that just emitted EOS; stop once all are finished.
        if eos_token_id is not None:
            unfinished_sents.mul_(next_token.ne(eos_token_id).long())
            if unfinished_sents.max() == 0:
                break

        cur_len += 1

    if kwargs.get("return_dict_in_generate", False):
        # attentions/hidden_states are only populated when the forward pass
        # is run with output_attentions/output_hidden_states enabled.
        output = {
            "sequences": input_ids,
            "scores": None,
            "attentions": model_outputs.attentions,
            "hidden_states": model_outputs.hidden_states,
        }
    else:
        output = input_ids

    return output
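

# Example usage, a minimal sketch: it assumes `custom_generate` is called as a
# free function with the model passed as `self`, and uses a Hugging Face
# causal LM (the "gpt2" checkpoint and the prompt below are illustrative).
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    output_ids = custom_generate(
        model,
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=20,
        temperature=0.8,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 reuses EOS as pad
        eos_token_id=tokenizer.eos_token_id,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))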