HwwwH commited on
Commit
18005e7
·
verified ·
1 Parent(s): 77225ff

Avoid duplicate input kwargs in `_decode`

Browse files
Files changed (1) hide show
  1. modeling_minicpmo.py +6 -1
modeling_minicpmo.py CHANGED
@@ -649,6 +649,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
649
  return outputs
650
 
651
  def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
 
652
  terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
653
  streamer = TextIteratorStreamer(tokenizer=tokenizer)
654
  generation_kwargs = {
@@ -777,6 +778,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
777
  tokenizer=None,
778
  vision_hidden_states=None,
779
  stream=False,
 
780
  **kwargs,
781
  ):
782
  assert input_ids is not None
@@ -814,7 +816,10 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
814
  outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
815
 
816
  result = self._decode_text(outputs.sequences, tokenizer)
817
-
 
 
 
818
  return result, outputs
819
 
820
  def chat(
 
649
  return outputs
650
 
651
  def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
652
+ kwargs.pop("output_hidden_states", None)
653
  terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
654
  streamer = TextIteratorStreamer(tokenizer=tokenizer)
655
  generation_kwargs = {
 
778
  tokenizer=None,
779
  vision_hidden_states=None,
780
  stream=False,
781
+ return_dict_in_generate=False,
782
  **kwargs,
783
  ):
784
  assert input_ids is not None
 
816
  outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
817
 
818
  result = self._decode_text(outputs.sequences, tokenizer)
819
+
820
+ if return_dict_in_generate is True:
821
+ return outputs
822
+
823
  return result, outputs
824
 
825
  def chat(