Crystalcareai committed on
Commit
bffb9c8
·
verified ·
1 Parent(s): b47da4b

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +40 -0
modeling_quiet.py CHANGED
@@ -1424,6 +1424,46 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1424
  # Apply the language model head to get the final logits
1425
  logits = self.lm_head(mixed_hidden_states)
1426
  return logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1427
 
1428
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1429
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
1424
  # Apply the language model head to get the final logits
1425
  logits = self.lm_head(mixed_hidden_states)
1426
  return logits
1427
+
1428
@torch.no_grad()
def generate(
    self,
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    streamer: Optional["TextStreamer"] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
):
    """Autoregressively sample tokens using the custom ``infer`` forward pass.

    Args:
        input_ids: Prompt token ids of shape (batch, seq).
        attention_mask: Optional mask over ``input_ids``; extended with ones
            for every newly generated token.
        position_ids / past_key_values / inputs_embeds / use_cache /
        output_attentions / output_hidden_states / return_dict:
            Forwarded verbatim to ``self.infer``.
        streamer: Optional token streamer; each sampled token is pushed via
            ``streamer.put`` as it is produced.
        **kwargs: Only ``max_new_tokens`` (default 1, matching the previous
            single-step behavior) is honored; other keys are ignored, as before.

    Returns:
        ``input_ids`` with the sampled token(s) appended along the last dim.
    """
    # Default of 1 preserves the original one-token-per-call behavior.
    max_new_tokens = int(kwargs.pop("max_new_tokens", 1))

    for _ in range(max_new_tokens):
        logits = self.infer(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
        )

        # ``infer`` may return per-position logits (batch, seq, vocab);
        # torch.multinomial only accepts 1-D/2-D input, and sampling must use
        # the final position's distribution anyway.
        if logits.dim() == 3:
            logits = logits[:, -1, :]

        # Sample the next token from the softmax distribution.
        next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

        # Append the sampled token and keep the attention mask in sync so the
        # next iteration attends to it.
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if attention_mask is not None:
            attention_mask = torch.cat(
                [attention_mask, torch.ones_like(next_token)], dim=-1
            )

        # Stream the generated token if a streamer is provided.
        if streamer is not None:
            streamer.put(next_token)

    return input_ids
1467
 
1468
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1469
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)