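"""A minimal custom autoregressive decoding loop for Hugging Face causal LMs,
supporting temperature sampling, greedy decoding, and early stopping on EOS."""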
import torch
from transformers import LogitsProcessorList
def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    temperature=1.0,
    do_sample=True,
    pad_token_id=None,
    eos_token_id=None,
    **kwargs,
):
    device = input_ids.device
    # Hook point for token-level constraints (e.g. top-k/top-p); an empty list is a no-op
    logits_processor = LogitsProcessorList()
    if max_new_tokens is None:
        max_new_tokens = 20  # fall back to a small default generation budget
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    # Track which sequences are still generating (1 = unfinished, 0 = hit EOS)
    unfinished_sents = torch.ones(input_ids.shape[0], dtype=torch.long, device=device)
    cur_len = input_ids.shape[1]
    # The budget counts *new* tokens, so the stop point is prompt length + max_new_tokens
    max_length = cur_len + max_new_tokens
    while cur_len < max_length:
        # Full-prefix forward pass each step; for long generations, threading
        # past_key_values back in would avoid recomputing the prompt every iteration
        model_outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            return_dict=True,
        )
        # Logits at the last position score the next token
        next_token_logits = model_outputs.logits[:, -1, :]
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # Never sample the pad token; EOS must stay reachable, otherwise the
        # early-stopping check below could never fire
        if pad_token_id is not None:
            next_token_logits[:, pad_token_id] = -float('inf')
        # Temperature scaling then softmax turns logits into a sampling distribution
        if do_sample:
            probabilities = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
            next_token = torch.multinomial(probabilities, num_samples=1).squeeze(1)
        else:
            # Greedy decoding ignores temperature and takes the argmax
            next_token = next_token_logits.argmax(dim=-1)
        # Sequences that already finished keep emitting pad instead of fresh tokens
        if pad_token_id is not None:
            next_token = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
        # Append the new token and extend the attention mask by one position
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=1)
        new_attention = attention_mask.new_ones((attention_mask.shape[0], 1))
        attention_mask = torch.cat([attention_mask, new_attention], dim=1)
        # Mark sequences that just produced EOS as finished; stop once all are done
        if eos_token_id is not None:
            unfinished_sents.mul_(next_token.ne(eos_token_id).long())
        if unfinished_sents.max() == 0:
            break
        cur_len += 1
    # Optionally return a dict shaped like transformers' generate output
    if kwargs.get('return_dict_in_generate', False):
        output = {
            "sequences": input_ids,
            "scores": None,  # placeholder until per-step score collection is implemented
            # These fields are None unless the forward pass was called with
            # output_attentions=True / output_hidden_states=True
            "attentions": model_outputs.attentions,
            "hidden_states": model_outputs.hidden_states,
        }
    else:
        output = input_ids
    return output
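
# Usage sketch, not part of the original file: assumes a GPT-2 checkpoint from
# Hugging Face transformers, and calls custom_generate as a free function with
# the model passed in as `self`.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    with torch.no_grad():
        generated = custom_generate(
            model,
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=20,
            temperature=0.8,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; EOS is reused
        )
    print(tokenizer.decode(generated[0], skip_special_tokens=True))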