Patch for a Hugging Face (transformers) bug that produces the wrong cache length when only `inputs_embeds` is passed to the model

#19
by tomer-nv - opened
Files changed (1) hide show
  1. modeling_decilm.py +45 -1
modeling_decilm.py CHANGED
@@ -25,7 +25,7 @@ import torch.utils.checkpoint
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
  from transformers import GenerationConfig
28
- from transformers.generation.utils import GenerationMixin, NEED_SETUP_CACHE_CLASSES_MAPPING
29
  from transformers.modeling_utils import PreTrainedModel
30
  from transformers.utils import (
31
  add_start_docstrings,
@@ -1311,6 +1311,50 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
1311
  )
1312
  return model_inputs
1313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1314
 
1315
  @add_start_docstrings(
1316
  """
 
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
  from transformers import GenerationConfig
28
+ from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
29
  from transformers.modeling_utils import PreTrainedModel
30
  from transformers.utils import (
31
  add_start_docstrings,
 
1311
  )
1312
  return model_inputs
1313
 
1314
def _maybe_initialize_input_ids_for_generation(
    self,
    inputs: Optional[torch.Tensor] = None,
    bos_token_id: Optional[torch.Tensor] = None,
    model_kwargs: Optional[dict[str, torch.Tensor]] = None,
) -> torch.LongTensor:
    """
    Build the initial ``input_ids`` for generation, padding them out when only
    ``inputs_embeds`` was supplied.

    Works around a Hugging Face bug that creates a wrong cache length if only
    ``inputs_embeds`` are passed to the model: when the parent implementation
    returns ``input_ids`` with zero columns, they are replaced by an all-zero
    placeholder with the same ``(batch, sequence)`` extent as ``inputs_embeds``.

    Args:
        inputs: Forwarded unchanged to the parent implementation.
        bos_token_id: Forwarded unchanged to the parent implementation.
        model_kwargs: Forwarded unchanged to the parent implementation; also
            inspected here for an ``"inputs_embeds"`` entry. May be ``None``.

    Returns:
        The parent's ``input_ids``, or a ``torch.long`` tensor of zeros on
        ``self.device`` shaped like the first two dimensions of
        ``inputs_embeds`` when the parent's result has zero columns.
    """
    input_ids = super()._maybe_initialize_input_ids_for_generation(
        inputs=inputs, bos_token_id=bos_token_id, model_kwargs=model_kwargs
    )

    # ``model_kwargs`` defaults to None and the entry itself may be None, so
    # read it defensively instead of using a bare membership test.
    inputs_embeds = None if model_kwargs is None else model_kwargs.get("inputs_embeds")

    if inputs_embeds is not None and input_ids is not None and input_ids.shape[1] == 0:
        # Placeholder ids: only their shape matters here, the values are zeros.
        batch_size, input_sequence_length = inputs_embeds.shape[:2]
        input_ids = torch.zeros(
            (batch_size, input_sequence_length), dtype=torch.long, device=self.device
        )
    return input_ids
1333
+
1334
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    *args,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    """
    Run generation, stripping the leading prompt-length columns from the
    output when the caller supplied only ``inputs_embeds``.

    Works around a Hugging Face bug that creates a wrong cache length if only
    ``inputs_embeds`` are passed to the model.

    Args:
        inputs: Forwarded to the parent ``generate`` as its first positional
            argument.
        *args: Extra positional arguments forwarded to the parent ``generate``.
        **kwargs: Keyword arguments forwarded to the parent ``generate``.
            ``inputs_embeds`` and ``input_ids`` are inspected here to detect
            the embeddings-only case.

    Returns:
        Whatever the parent ``generate`` returns. If only ``inputs_embeds`` was
        passed and the result is a plain tensor, the first
        ``inputs_embeds.shape[1]`` columns are sliced off.
    """
    inputs_embeds = kwargs.get("inputs_embeds")
    # A key that is present but set to None counts as "not passed".
    only_passed_inputs_embeds = (
        inputs_embeds is not None
        and kwargs.get("input_ids") is None
        and inputs is None
    )
    input_sequence_length = inputs_embeds.shape[1] if only_passed_inputs_embeds else 0

    # Pass ``inputs`` positionally: with ``inputs=inputs, *args`` any extra
    # positional argument would also bind to ``inputs`` and raise
    # "got multiple values for argument 'inputs'".
    generation_output = super().generate(inputs, *args, **kwargs)

    # NOTE(review): non-tensor outputs (e.g. a GenerateOutput) are returned
    # untouched, so their sequences are not trimmed -- confirm that is intended.
    if only_passed_inputs_embeds and isinstance(generation_output, torch.Tensor):
        generation_output = generation_output[:, input_sequence_length:]

    return generation_output
1357
+
1358
 
1359
  @add_start_docstrings(
1360
  """