import torch
from transformers import LogitsProcessorList, StoppingCriteriaList


def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    temperature=1.0,
    do_sample=True,
    pad_token_id=None,
    eos_token_id=None,
    **kwargs,
):
    """Autoregressively generate up to ``max_new_tokens`` new tokens.

    ``self`` is expected to be a causal-LM-style callable returning an object
    with ``.logits`` of shape (batch, seq, vocab) — TODO confirm against the
    actual model class this is mixed into.

    Args:
        input_ids: (batch, prompt_len) prompt token ids.
        attention_mask: optional (batch, prompt_len) mask; defaults to all ones.
        max_new_tokens: number of NEW tokens to generate (required).
        temperature: softmax temperature used when ``do_sample`` is True.
        do_sample: multinomial sampling if True, greedy argmax otherwise.
        pad_token_id: id never sampled and used to pad finished sequences.
        eos_token_id: id that marks a sequence as finished (enables early stop).
        **kwargs: ``return_dict_in_generate=True`` switches to a dict result.

    Returns:
        The full (batch, prompt_len + generated) tensor of token ids, or a
        dict with ``sequences`` / ``scores`` / ``attentions`` / ``hidden_states``
        when ``return_dict_in_generate`` is truthy.

    Raises:
        ValueError: if ``max_new_tokens`` is not provided.
    """
    if max_new_tokens is None:
        raise ValueError("max_new_tokens must be a positive integer")

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    batch_size = input_ids.shape[0]
    # 1 = still generating, 0 = emitted EOS. Same dtype/device as input_ids.
    unfinished_sents = input_ids.new_ones(batch_size)

    # BUGFIX: the original looped `while cur_len < max_new_tokens` where
    # cur_len started at the PROMPT length, so prompts longer than
    # max_new_tokens generated nothing. Bound by prompt_len + max_new_tokens
    # so max_new_tokens counts newly generated tokens, as its name says.
    cur_len = input_ids.shape[1]
    max_length = cur_len + max_new_tokens

    model_outputs = None  # defined even if the loop body never runs

    while cur_len < max_length:
        # Pure generation needs no autograd graph; saves memory on long runs.
        with torch.no_grad():
            model_outputs = self(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=True,
                return_dict=True,
            )
        next_token_logits = model_outputs.logits[:, -1, :]

        # Never sample the padding token. Guarded: indexing with a None id
        # in the original (`logits[:, None] = -inf`) silently wiped the
        # entire logits tensor through an unsqueezed view.
        if pad_token_id is not None:
            next_token_logits[:, pad_token_id] = -float('inf')
        # BUGFIX: the original also forced eos_token_id to -inf, which made
        # EOS unreachable and turned the early-stopping logic below into
        # dead code. EOS must remain sampleable for early stopping to work.

        if do_sample:
            # Temperature-scaled multinomial sampling.
            probabilities = torch.nn.functional.softmax(
                next_token_logits / temperature, dim=-1
            )
            next_token = torch.multinomial(probabilities, num_samples=1).squeeze(1)
        else:
            next_token = next_token_logits.argmax(dim=-1)

        # Sequences that already finished keep emitting pad_token_id.
        if eos_token_id is not None and pad_token_id is not None:
            next_token = (
                next_token * unfinished_sents
                + pad_token_id * (1 - unfinished_sents)
            )

        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=1)
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((batch_size, 1))], dim=1
        )

        # Early stop once every sequence in the batch has produced EOS.
        if eos_token_id is not None:
            unfinished_sents.mul_(next_token.ne(eos_token_id).long())
            if unfinished_sents.max() == 0:
                break
        cur_len += 1

    if kwargs.get('return_dict_in_generate', False):
        # Guards fix a NameError in the original when the loop never ran.
        return {
            "sequences": input_ids,
            "scores": None,  # Placeholder for when score calculations are implemented
            "attentions": model_outputs.attentions if model_outputs is not None else None,
            "hidden_states": model_outputs.hidden_states if model_outputs is not None else None,
        }
    return input_ids