Crystalcareai committed · verified · Commit 0d44b2a · 1 Parent(s): b94c1e7

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py +19 -19
modeling_quiet.py CHANGED

@@ -869,7 +869,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         self.n_tokens_print = 1
         self.gradient_accumulation_steps = 1
         self.training_steps = 0
-        self.tokenizer = None #AutoTokenizer.from_pretrained("Crystalcareai/Quiet-Star-Custom")
+        self.tokenizer = AutoTokenizer.from_pretrained("Crystalcareai/Quiet-Star-Custom")
         self.start_token_id = None
         self.end_token_id = None
         self.rm_initialized = False
@@ -1111,24 +1111,24 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits

-    # @torch.no_grad()
-    # def generate(
-    #     self,
-    #     input_ids: torch.LongTensor = torch.LongTensor(),
-    #     attention_mask: Optional[torch.Tensor] = None,
-    #     max_new_tokens: Optional[int] = None,
-    #     temperature: float = 1.1,
-    #     **kwargs,
-    # ):
-    #     if isinstance(input_ids, str):
-    #         input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
-
-    #     if attention_mask is None:
-    #         # Create a default attention mask if not provided
-    #         attention_mask = torch.ones_like(input_ids)
-
-    #     from .generate import generate
-    #     return generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: torch.LongTensor = torch.LongTensor(),
+        attention_mask: Optional[torch.Tensor] = None,
+        max_new_tokens: Optional[int] = None,
+        temperature: float = 1.1,
+        **kwargs,
+    ):
+        if isinstance(input_ids, str):
+            input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
+
+        if attention_mask is None:
+            # Create a default attention mask if not provided
+            attention_mask = torch.ones_like(input_ids)
+
+        from .generate import generate
+        return generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)

     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
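
For context, a minimal usage sketch (not part of the commit) of the generate() override this diff enables. It assumes the checkpoint is loaded with trust_remote_code=True so the repo's modeling_quiet.py is used, that "Crystalcareai/Quiet-Star-Custom" is the intended model id, and that the custom generate() returns token ids like the standard Transformers API; dtype, sampling settings, and the decode step are illustrative assumptions.

# Sketch only: exercises the generate() method un-commented by this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/Quiet-Star-Custom"  # repo id taken from the diff
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,       # assumed dtype, not specified by the diff
    trust_remote_code=True,           # required so the repo's custom generate() is used
)

prompt = "Explain why the sky is blue."
inputs = tokenizer(prompt, return_tensors="pt")

# The override builds a default attention mask when none is given and forwards
# to the repo's generate.py helper; defaults mirror the signature in the diff.
output_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,
    temperature=1.1,
)

# Assumes the helper returns token ids; adjust if the custom code returns more.
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))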