Crystalcareai commited on
Commit
5ce32b8
·
verified ·
1 Parent(s): 07bd82c

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +21 -7
modeling_quiet.py CHANGED
@@ -1425,15 +1425,29 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1425
  logits = self.lm_head(mixed_hidden_states)
1426
  return logits
1427
 
1428
- def prepare_inputs_for_generation(self, input_ids, **kwargs):
1429
- return {"input_ids": input_ids}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1430
 
1431
- def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
1432
- return model_kwargs
1433
 
1434
- @torch.no_grad()
1435
- def generate(self, **kwargs):
1436
- return self.infer(**kwargs)
1437
 
1438
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1439
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
1425
  logits = self.lm_head(mixed_hidden_states)
1426
  return logits
1427
 
1428
+
1429
+ def generate(self, *args, **kwargs):
1430
+ # Save the original input_ids and attention_mask
1431
+ original_input_ids = kwargs.pop("input_ids", None)
1432
+ original_attention_mask = kwargs.pop("attention_mask", None)
1433
+
1434
+ # Call the infer method to get the logits
1435
+ logits = self.infer(
1436
+ input_ids=original_input_ids,
1437
+ attention_mask=original_attention_mask,
1438
+ position_ids=kwargs.pop("position_ids", None),
1439
+ past_key_values=kwargs.pop("past_key_values", None),
1440
+ inputs_embeds=kwargs.pop("inputs_embeds", None),
1441
+ use_cache=kwargs.pop("use_cache", None),
1442
+ output_attentions=kwargs.pop("output_attentions", None),
1443
+ output_hidden_states=kwargs.pop("output_hidden_states", None),
1444
+ return_dict=kwargs.pop("return_dict", None),
1445
+ )
1446
 
1447
+ # Generate output using the logits
1448
+ output_ids = torch.argmax(logits, dim=-1)
1449
 
1450
+ return output_ids
 
 
1451
 
1452
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1453
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)