Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py  +40  -0
modeling_quiet.py CHANGED

@@ -1424,6 +1424,46 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         # Apply the language model head to get the final logits
         logits = self.lm_head(mixed_hidden_states)
         return logits
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        streamer: Optional[TextStreamer] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ):
+        # Call your custom infer function
+        logits = self.infer(
+            input_ids,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            inputs_embeds,
+            use_cache,
+            output_attentions,
+            output_hidden_states,
+            return_dict,
+        )
+
+        # Sample the next token using the logits
+        next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
+
+        # Append the generated token to the input sequence
+        input_ids = torch.cat([input_ids, next_token], dim=-1)
+
+        # Stream the generated token if a streamer is provided
+        if streamer is not None:
+            streamer.put(next_token)
+
+        return input_ids
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
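For reference, here is a minimal sketch of how this single-step `generate` might be driven to produce a longer completion. Everything below is an assumption, not part of the commit: the checkpoint name is a placeholder, and the loop assumes `self.infer` returns per-sequence next-token logits of shape `[batch, vocab_size]` (torch.multinomial accepts at most 2-D input), so that each call appends exactly one sampled token.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Placeholder checkpoint id; substitute the actual repo name.
checkpoint = "Crystalcareai/quiet-star"
tok = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
model.eval()

input_ids = tok("The capital of France is", return_tensors="pt").input_ids
streamer = TextStreamer(tok, skip_prompt=True)  # prints tokens as they arrive

# The custom generate() returns the input extended by one sampled token,
# so call it repeatedly to build up a completion.
for _ in range(32):
    input_ids = model.generate(input_ids, streamer=streamer)
    if input_ids[0, -1].item() == tok.eos_token_id:
        break  # stop once EOS is sampled

print(tok.decode(input_ids[0], skip_special_tokens=True))

Note the design choice in the committed method: sampling is a plain multinomial draw over the raw softmax, so there is no temperature, top-k, or top-p control; a caller wanting greedy decoding would replace the multinomial draw with an argmax.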