Crystalcareai
committed on
Update modeling_quiet.py
Browse files - modeling_quiet.py +19 -19
modeling_quiet.py
CHANGED
@@ -869,7 +869,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
869 |
self.n_tokens_print = 1
|
870 |
self.gradient_accumulation_steps = 1
|
871 |
self.training_steps = 0
|
872 |
-
self.tokenizer =
|
873 |
self.start_token_id = None
|
874 |
self.end_token_id = None
|
875 |
self.rm_initialized = False
|
@@ -1111,24 +1111,24 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1111 |
logits = self.lm_head(mixed_hidden_states)
|
1112 |
return logits
|
1113 |
|
1114 |
-
|
1115 |
-
|
1116 |
-
|
1117 |
-
|
1118 |
-
|
1119 |
-
|
1120 |
-
|
1121 |
-
|
1122 |
-
|
1123 |
-
|
1124 |
-
|
1125 |
-
|
1126 |
-
|
1127 |
-
|
1128 |
-
|
1129 |
-
|
1130 |
-
|
1131 |
-
|
1132 |
|
1133 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1134 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
|
869 |
self.n_tokens_print = 1
|
870 |
self.gradient_accumulation_steps = 1
|
871 |
self.training_steps = 0
|
872 |
+
self.tokenizer = AutoTokenizer.from_pretrained("Crystalcareai/Quiet-Star-Custom")
|
873 |
self.start_token_id = None
|
874 |
self.end_token_id = None
|
875 |
self.rm_initialized = False
|
|
|
1111 |
logits = self.lm_head(mixed_hidden_states)
|
1112 |
return logits
|
1113 |
|
1114 |
+
@torch.no_grad()
|
1115 |
+
def generate(
|
1116 |
+
self,
|
1117 |
+
input_ids: torch.LongTensor = torch.LongTensor(),
|
1118 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1119 |
+
max_new_tokens: Optional[int] = None,
|
1120 |
+
temperature: float = 1.1,
|
1121 |
+
**kwargs,
|
1122 |
+
):
|
1123 |
+
if isinstance(input_ids, str):
|
1124 |
+
input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
|
1125 |
+
|
1126 |
+
if attention_mask is None:
|
1127 |
+
# Create a default attention mask if not provided
|
1128 |
+
attention_mask = torch.ones_like(input_ids)
|
1129 |
+
|
1130 |
+
from .generate import generate
|
1131 |
+
return generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
|
1132 |
|
1133 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1134 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|