import copy
import warnings
import logging
from typing import List, Tuple, Optional, Callable

import torch
from torch import nn
# NOTE: this import rebinds the name `logging` to the transformers logging helper,
# shadowing the stdlib module imported above.
from transformers.utils import logging
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig

from .modeling_chatglm import ChatGLMForConditionalGeneration, InvalidScoreLogitsProcessor
from .characterglm_generation_utils import CharacterGLMGenerationUtils, SessionMeta

logger = logging.get_logger(__name__)

# Fallback decoding parameters. `chat` and `stream_chat` only copy a key from here
# when the caller has not already supplied it (explicitly or through **kwargs).
default_generation_config = {
    "do_sample": True,
    "top_k": 100,
    "top_p": 0.9,
    "no_repeat_ngram_size": 0,
    "temperature": 0.9,
    "num_beams": 1,
    "length_penalty": 1.6,
    "repetition_penalty": 1.3,
    "eos_token_id": 13
}


class CharacterGLMForConditionalGeneration(ChatGLMForConditionalGeneration):
    """
    CharacterGLM's prompt format differs from ChatGLM's.

    CharacterGLMForConditionalGeneration reuses the forward method of
    ChatGLMForConditionalGeneration, re-implements `build_inputs` and `build_stream_inputs`,
    adjusts the signatures of `chat` and `stream_chat` by adding a `session_meta` parameter,
    and changes the default values of the decoding parameters.
    """

    def build_inputs(self, tokenizer, session_meta: SessionMeta, query: str,
                     history: Optional[List[Tuple[str, str]]] = None):
        """Build the full-prompt model inputs for a conversation turn.

        The ChatGLM-style ``history`` (list of ``(query, response)`` pairs) plus the new
        ``query`` is converted to the CharacterGLM history format, rendered to a prompt
        with ``session_meta``, tokenized as a batch of one, and moved to ``self.device``.

        Args:
            tokenizer: Tokenizer used to encode the prompt.
            session_meta: Session description passed to the prompt builder.
            query: The new user message.
            history: Previous ``(query, response)`` pairs; ``None`` is treated as empty.

        Returns:
            The tokenizer output (PyTorch tensors) on ``self.device``.
        """
        character_glm_history = CharacterGLMGenerationUtils.convert_chatglm_history_to_characterglm_history(
            query, history or []
        )
        prompt = CharacterGLMGenerationUtils.build_inputs(session_meta, character_glm_history)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        return inputs

    def build_stream_inputs(self, tokenizer, session_meta: SessionMeta, query: str,
                            history: Optional[List[Tuple[str, str]]] = None):
        """Build incremental model inputs containing only the new turn.

        Used by `stream_chat` when a key/value cache is supplied, so only the new user
        message and the bot-name prefix are encoded.

        Args:
            tokenizer: Tokenizer used to encode the prompt.
            session_meta: Mapping read for ``'user_name'`` and ``'bot_name'``.
            query: The new user message; newlines are replaced by spaces.
            history: Accepted for signature parity with `build_inputs`; not used here.

        Returns:
            The tokenizer output (PyTorch tensors) on ``self.device``.
        """
        prompt = "\n[{}]{}\n[{}]".format(
            session_meta['user_name'],
            query.replace('\n', ' '),
            session_meta['bot_name']
        )
        input_ids = tokenizer.encode(prompt, add_special_tokens=False)
        # The first encoded token is discarded.
        # NOTE(review): presumably an artifact token the tokenizer emits at the start of the
        # text -- confirm against the tokenizer.
        input_ids = input_ids[1:]
        inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
        inputs = inputs.to(self.device)
        return inputs

    @torch.inference_mode()
    def chat(self, tokenizer, session_meta: SessionMeta, query: str, history: Optional[List[Tuple[str, str]]] = None,
             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.9, temperature=0.9, repetition_penalty=1.6,
             logits_processor=None, **kwargs):
        """Generate one complete reply for ``query``.

        Decoding parameters are resolved in this order: ``**kwargs`` override the named
        parameters, and `default_generation_config` only fills in keys that are still
        missing afterwards.

        Args:
            tokenizer: Tokenizer used to encode the prompt and decode the reply.
            session_meta: Session description passed to `build_inputs`.
            query: The new user message.
            history: Previous ``(query, response)`` pairs; ``None`` is treated as empty.
            max_length, num_beams, do_sample, top_p, temperature, repetition_penalty:
                Forwarded to ``self.generate``.
            logits_processor: Optional processors to apply. The argument is copied, so the
                caller's list is not modified; an `InvalidScoreLogitsProcessor` is appended
                to the copy.
            **kwargs: Extra generation arguments forwarded to ``self.generate``.

        Returns:
            A ``(response, history)`` tuple where ``history`` is a new list with
            ``(query, response)`` appended.
        """
        if history is None:
            history = []
        # Copy instead of appending to the caller's list, which would otherwise gain one
        # more InvalidScoreLogitsProcessor on every call.
        logits_processor = LogitsProcessorList(logits_processor or [])
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor,
                      "repetition_penalty": repetition_penalty, **kwargs}
        gen_kwargs.update({k: v for k, v in default_generation_config.items() if k not in gen_kwargs})
        inputs = self.build_inputs(tokenizer, session_meta, query, history=history)
        outputs = self.generate(**inputs, **gen_kwargs)
        # Keep only the newly generated tokens (strip the prompt).
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs)
        response = self.process_response(response)
        history = history + [(query, response)]
        return response, history

    @torch.inference_mode()
    def stream_chat(self, tokenizer, session_meta: SessionMeta, query: str,
                    history: Optional[List[Tuple[str, str]]] = None, past_key_values=None,
                    max_length: int = 8192, do_sample=True, top_p=0.9, temperature=0.9, repetition_penalty=1.0,
                    logits_processor=None, return_past_key_values=False, **kwargs):
        """Generate a reply for ``query`` incrementally.

        Args:
            tokenizer: Tokenizer used to encode the prompt and decode the reply.
            session_meta: Session description used to build the prompt.
            query: The new user message.
            history: Previous ``(query, response)`` pairs; ``None`` is treated as empty.
            past_key_values: Optional cache from a previous call. When given, only the new
                turn is encoded (`build_stream_inputs`) and position ids / attention mask
                are extended by the cached length.
            max_length, do_sample, top_p, temperature: Forwarded to `stream_generate`.
            repetition_penalty: Accepted for interface compatibility, but removed from the
                generation arguments before decoding (see the note in the body).
            logits_processor: Optional processors to apply. The argument is copied, so the
                caller's list is not modified; an `InvalidScoreLogitsProcessor` is appended
                to the copy.
            return_past_key_values: When true, each yielded tuple also carries the cache.
            **kwargs: Extra generation arguments forwarded to `stream_generate`.

        Yields:
            ``(response, new_history)`` or, if ``return_past_key_values`` is set,
            ``(response, new_history, past_key_values)``. Nothing is yielded for a step
            whose decoded text is empty or ends with the replacement character.
        """
        if history is None:
            history = []
        # Copy instead of appending to the caller's list, which would otherwise gain one
        # more InvalidScoreLogitsProcessor on every call.
        logits_processor = LogitsProcessorList(logits_processor or [])
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor,
                      "repetition_penalty": repetition_penalty, **kwargs}
        gen_kwargs.update({k: v for k, v in default_generation_config.items() if k not in gen_kwargs})
        # The repetition penalty is always dropped for streaming, whatever the caller passed.
        # NOTE(review): presumably because with a cache `input_ids` holds only the new turn,
        # so the penalty would not see earlier tokens -- confirm the intent.
        gen_kwargs.pop('repetition_penalty', None)

        if past_key_values is None:
            inputs = self.build_inputs(tokenizer, session_meta, query, history=history)
        else:
            inputs = self.build_stream_inputs(tokenizer, session_meta, query, history=history)
            # NOTE(review): assumes the cached tensors carry the sequence length on dim 0 -- confirm.
            past_length = past_key_values[0][0].shape[0]
            if self.transformer.pre_seq_len is not None:
                past_length -= self.transformer.pre_seq_len
            # Shift positions past the cached tokens and let the new tokens attend to them.
            inputs.position_ids += past_length
            attention_mask = inputs.attention_mask
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
            inputs['attention_mask'] = attention_mask

        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            return_past_key_values=return_past_key_values, **gen_kwargs):
            if return_past_key_values:
                outputs, past_key_values = outputs
            # Keep only the newly generated tokens (strip the prompt).
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
            response = tokenizer.decode(outputs)
            # Skip steps whose text is empty or ends in the replacement character.
            if response and response[-1] != "�":
                response = self.process_response(response)
                new_history = history + [(query, response)]
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

    @torch.inference_mode()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            return_past_key_values=False,
            **kwargs,
    ):
        """Token-by-token generation loop that yields after every decoding step.

        Args:
            input_ids: Prompt token ids; the first dimension is the batch.
            generation_config: Base configuration; defaults to ``self.generation_config``.
                A deep copy is updated with ``**kwargs``, so the original is not modified.
            logits_processor: Extra logits processors merged with the configured ones.
            stopping_criteria: Extra stopping criteria merged with the configured ones.
            prefix_allowed_tokens_fn: Forwarded to ``self._get_logits_processor``.
            return_past_key_values: When true, yield ``(input_ids, past_key_values)``.
            **kwargs: Generation-config overrides; entries the config does not consume are
                passed on as model keyword arguments.

        Yields:
            The running ``input_ids`` (prompt plus all tokens generated so far) after each
            step, optionally paired with ``outputs.past_key_values``.
        """
        input_ids_seq_length = input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        model_kwargs["use_cache"] = generation_config.use_cache

        # Normalise the EOS ids to an iterable; `None` means "no EOS token".
        eos_token_id = generation_config.eos_token_id
        if eos_token_id is None:
            eos_token_id = []
        elif isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                # `Logger.warning` takes %-format arguments after the message, so no warning
                # category is passed here (unlike `warnings.warn` above).
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        # One flag per batch row: 1 while the row is still generating, 0 once it hit EOS.
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            # A row stays unfinished only if its new token differs from *every* EOS id.
            # Multiplying the per-id masks expresses that; summing them would stay non-zero
            # whenever more than one EOS id is configured.
            not_eos = torch.ones_like(unfinished_sequences)
            for eos_id in eos_token_id:
                not_eos = not_eos * (next_tokens != eos_id).long()
            unfinished_sequences = unfinished_sequences.mul(not_eos)
            if return_past_key_values:
                yield input_ids, outputs.past_key_values
            else:
                yield input_ids
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break