import torch
from transformers.generation.utils import (
    GenerationMixin,
    validate_stopping_criteria,
    StoppingCriteriaList,
)
from transformers import TextStreamer


def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    streamer=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    **kwargs,
):
    """Sample continuations token by token with temperature-scaled multinomial sampling.

    Only ``input_ids``, ``attention_mask``, ``max_new_tokens``, ``temperature``,
    ``streamer`` and ``**kwargs`` are used. Every other keyword is accepted for
    signature compatibility with ``generate`` and is ignored: ``do_sample``,
    ``top_k``, ``top_p``, beam/penalty settings, etc. have no effect, and
    sampling is always used.

    Each step runs the model on the rows that have not finished yet. It then
    samples one token per row from the logits at that row's last non-padding
    position and writes it into the slot right after that position. When a row
    has no free slot left, a padding column is appended to the whole batch, and
    a zero column is appended to ``attention_mask``.

    Args:
        self: The model. It is called as
            ``self(input_ids, attention_mask=..., **kwargs)`` and must return a
            mapping with a ``'logits'`` entry. It must also expose ``self.tokenizer``.
        input_ids: Token-id tensor, indexed here as (batch, seq). Rows are
            presumably right-padded with ``tokenizer.pad_token_id`` — TODO confirm.
        attention_mask: Optional mask aligned with ``input_ids``. It is extended
            and filled in as tokens are written.
        max_new_tokens: Maximum number of sampling steps. Required.
        temperature: Softmax temperature. ``None`` means 1.0 (unscaled logits).
        streamer: Optional streamer object. At each step it receives (via
            ``put``) the token sampled for the last active row. Its ``end()``
            is called once generation stops.
        **kwargs: Forwarded unchanged to the model's forward call.

    Returns:
        Tuple ``(generated_token_ids, attention_mask)``:

        - ``generated_token_ids`` is ``input_ids.tolist()`` after generation
          (prompt plus sampled tokens).
        - ``attention_mask`` is the updated mask, or ``None`` if none was given.

    Raises:
        ValueError: If ``max_new_tokens`` is ``None``.
    """
    if max_new_tokens is None:
        # range(None) would otherwise fail with an opaque TypeError.
        raise ValueError("custom_generate requires `max_new_tokens` to be set.")
    if temperature is None:
        # Dividing logits by None would raise; fall back to unscaled sampling.
        temperature = 1.0

    device = input_ids.device
    tokenizer = self.tokenizer

    # NOTE(review): the token string below is empty in this source. That looks
    # like a lost token name (the original comment calls it "the end token").
    # Confirm the intended token. The lookup is loop-invariant, so it is done once.
    end_token_id = tokenizer.convert_tokens_to_ids("")

    # A row is marked finished as soon as it samples any of these ids.
    # None entries never match, so missing special tokens are harmless.
    stop_token_ids = (
        tokenizer.eos_token_id,
        tokenizer.bos_token_id,
        tokenizer.pad_token_id,
        end_token_id,
    )

    with torch.no_grad():
        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)

        for _ in range(max_new_tokens):
            # Snapshot of the rows still generating at the start of this step.
            # `~` builds a new tensor, so updates to finished_generating below
            # do not change it.
            active = ~finished_generating

            # Sample the next token
            new_ids = self(
                input_ids[active],
                attention_mask=attention_mask[active] if attention_mask is not None else None,
                **kwargs
            )['logits']

            # Never sample ids >= tokenizer.vocab_size. The original comment
            # describes these as the start/end thought tokens.
            new_ids[:, :, tokenizer.vocab_size:] = -float("inf")

            for list_idx, answer_idx in enumerate(active.nonzero(as_tuple=True)[0]):
                # Find the index of the last token that is not padding
                base_answer_ids = input_ids[answer_idx]
                new_answer_ids = new_ids[list_idx]
                last_token_idx = (base_answer_ids != tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()

                new_ids_sampled = torch.multinomial(
                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)

                # Write the new id right after the last non-padding token
                if last_token_idx + 1 >= len(base_answer_ids):
                    # No free slot in this row: add a padding column to every row
                    new_padding = torch.full(
                        (len(input_ids), 1), tokenizer.pad_token_id, dtype=torch.long, device=device)
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    if attention_mask is not None:
                        attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)

                if attention_mask is not None:
                    attention_mask[answer_idx, last_token_idx + 1] = 1
                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled

                # Stop this row on eos/bos/pad or the end token
                if new_ids_sampled.item() in stop_token_ids:
                    finished_generating[answer_idx] = True

            if finished_generating.all():
                break

            if streamer is not None:
                # Only the token sampled for the last active row in this step is
                # streamed. With batch size > 1, tokens from different rows would
                # be interleaved, so streaming is only meaningful for one row.
                streamer.put(new_ids_sampled)

    if streamer is not None:
        # Signal the end of generation. TextStreamer needs this to flush its
        # buffered trailing text; TextIteratorStreamer needs it to stop its
        # consumer's iteration.
        streamer.end()

    generated_token_ids = input_ids.tolist()

    return generated_token_ids, attention_mask


def generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.1,
    streamer=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    n_ahead=12,
    n_ahead_talk=4,
    merged_talk_heads=True,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **model_kwargs,
):
    """Configure the model's thought/talk-head attributes, then sample with ``custom_generate``.

    The head-configuration keywords and a fixed set of mode flags are written
    onto ``self`` as attributes. These assignments persist on the model after
    the call returns. Generation is then delegated to ``custom_generate``, which
    receives every generation keyword plus ``**model_kwargs``.

    ``trust_remote_code`` and ``torch_dtype`` are accepted but not used here.
    ``n_ahead_talk`` only contributes to ``self.max_thoughts`` and is not stored
    on the model itself.
    NOTE(review): confirm that ``n_ahead_talk`` not being stored is intended.

    Returns:
        The ``(generated_token_ids, attention_mask)`` tuple from ``custom_generate``.
    """
    # Set model attributes
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
    self.merged_lm_and_think_heads = merged_lm_and_think_heads
    self.use_concat_talk_head = use_concat_talk_head
    self.use_shallow_think = use_shallow_think
    self.use_shallow_talk = use_shallow_talk
    self.use_complex_think_head = use_complex_think_head
    self.use_complex_talk_head = use_complex_talk_head
    self.use_weighted_talk_head = use_weighted_talk_head

    # Set model properties
    self.use_end_thought_token = True
    self.use_start_thought_token = True
    self.n_ahead = n_ahead
    self.n_passes = 1
    self.eval_mode = True
    self.first_run = False
    self.rm_initialized = True
    self.original_mode = False

    generated_token_ids, attention_mask = custom_generate(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        do_sample=do_sample,
        early_stopping=early_stopping,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        bad_words_ids=bad_words_ids,
        bos_token_id=bos_token_id,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=num_return_sequences,
        decoder_start_token_id=decoder_start_token_id,
        use_cache=use_cache,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_scores=output_scores,
        return_dict_in_generate=return_dict_in_generate,
        forced_bos_token_id=forced_bos_token_id,
        forced_eos_token_id=forced_eos_token_id,
        remove_invalid_values=remove_invalid_values,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )

    return generated_token_ids, attention_mask