Quiet-Star-Custom / generate.py
import torch
from transformers import LogitsProcessorList, StoppingCriteriaList

def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    temperature=1.0,
    do_sample=True,
    pad_token_id=None,
    eos_token_id=None,
    **kwargs
):
    device = input_ids.device
    # Placeholders for future logits processing and custom stopping rules (currently unused).
    logits_processor = LogitsProcessorList()
    stopping_criteria = StoppingCriteriaList()

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # Track which sequences are still generating, to allow early stopping.
    unfinished_sents = input_ids.new_ones(input_ids.shape[0])

    cur_len = input_ids.shape[1]
    # Generate at most max_new_tokens tokens beyond the current prompt length.
    max_length = cur_len + max_new_tokens

    while cur_len < max_length:
        model_outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            return_dict=True,
        )
        next_token_logits = model_outputs.logits[:, -1, :]

        # Mask undesired tokens. Note: while eos is masked out here, the early-stopping
        # check below cannot trigger, so generation runs for the full max_new_tokens.
        if pad_token_id is not None:
            next_token_logits[:, pad_token_id] = -float('inf')  # Never select pad
        if eos_token_id is not None:
            next_token_logits[:, eos_token_id] = -float('inf')  # Avoid generating end token prematurely

        # Apply temperature scaling and sample, or take the greedy argmax.
        if do_sample:
            probabilities = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
            next_token = torch.multinomial(probabilities, num_samples=1).squeeze(1)
        else:
            next_token = next_token_logits.argmax(dim=-1)

        # Append the new token and extend the attention mask by one position.
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=1)
        new_attention = attention_mask.new_ones((attention_mask.shape[0], 1))
        attention_mask = torch.cat([attention_mask, new_attention], dim=1)

        # Mark sequences that emitted eos as finished; stop once every sequence is done.
        if eos_token_id is not None:
            unfinished_sents.mul_(next_token.ne(eos_token_id).long())
            if unfinished_sents.max() == 0:
                break

        cur_len += 1

    # Optionally return additional information alongside the generated sequences.
    if kwargs.get('return_dict_in_generate', False):
        output = {
            "sequences": input_ids,
            "scores": None,  # Placeholder until score accumulation is implemented
            "attentions": model_outputs.attentions,
            "hidden_states": model_outputs.hidden_states,
        }
    else:
        output = input_ids

    return output
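

# Usage sketch (not part of the original file): shows one way to bind custom_generate
# onto a Hugging Face causal LM and call it. The "gpt2" checkpoint and the generation
# settings below are illustrative assumptions, not values taken from this repo; the
# function itself only requires a model whose forward returns .logits.
if __name__ == "__main__":
    import types
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    # Bind the free function as a method so `self` resolves to the model.
    model.custom_generate = types.MethodType(custom_generate, model)

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    with torch.no_grad():
        output_ids = model.custom_generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=20,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse eos
            eos_token_id=tokenizer.eos_token_id,
        )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))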