Crystalcareai committed
Commit 195f100 · verified · 1 parent: 88eec50

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+121, -50)
modeling_quiet.py CHANGED
@@ -37,7 +37,7 @@ import transformers
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask # _prepare_4d_causal_attention_mask_for_sdpa
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
@@ -1110,7 +1110,126 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         # Apply the language model head to get the final logits
         logits = self.lm_head(mixed_hidden_states)
         return logits
+
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        return {"input_ids": input_ids}
+
+    def _generate_no_beam_search(
+        self,
+        input_ids,
+        cur_len,
+        max_length,
+        min_length,
+        do_sample,
+        temperature,
+        top_k,
+        top_p,
+        repetition_penalty,
+        no_repeat_ngram_size,
+        bad_words_ids,
+        pad_token_id,
+        eos_token_id,
+        batch_size,
+        attention_mask,
+        use_cache,
+        model_kwargs,
+    ):
+        if input_ids is None or input_ids.nelement() == 0:
+            input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
+            attention_mask = torch.ones_like(input_ids).to(self.device)
+
+        device = input_ids.device
+        with torch.no_grad():
+            batch_size = input_ids.shape[0]
+            finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
+            generated_token_ids = torch.full((batch_size, max_length - cur_len), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
+
+            for cur_token_idx in range(max_length - cur_len):
+                new_ids = self(
+                    input_ids[~finished_generating],
+                    attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
+                    **model_kwargs
+                )['logits']
+
+                new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
+
+                for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
+                    base_answer_ids = input_ids[answer_idx]
+                    new_answer_ids = new_ids[list_idx]
+                    last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+
+                    new_ids_sampled = torch.multinomial(
+                        torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
+
+                    if last_token_idx + 1 >= len(base_answer_ids):
+                        new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
+                                                 device=device)
+                        input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                        if attention_mask is not None:
+                            attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+
+                    if attention_mask is not None:
+                        attention_mask[answer_idx, last_token_idx + 1] = 1
+                    input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+                    generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
+
+                    if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
+                        finished_generating[answer_idx] = 1
+
+                    if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
+                        finished_generating[answer_idx] = 1
+
+                if finished_generating.all():
+                    break
+
+        return generated_token_ids
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: torch.LongTensor = torch.LongTensor(),
+        attention_mask: Optional[torch.Tensor] = None,
+        max_new_tokens: Optional[int] = None,
+        temperature: float = 1.1,
+        **kwargs,
+    ):
+        if isinstance(input_ids, str):
+            input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        max_length = max_new_tokens + input_ids.shape[1] if max_new_tokens is not None else None
+
+        # Set model attributes
+        self.max_thoughts = kwargs.get('n_ahead', 4) + kwargs.get('n_ahead_talk', 4) + 1
+        self.merged_talk_heads = kwargs.get('merged_talk_heads', True)
+        self.merged_lm_and_talk_heads = kwargs.get('merged_lm_and_talk_heads', False)
+        self.merged_lm_and_think_heads = kwargs.get('merged_lm_and_think_heads', True)
+        self.use_concat_talk_head = kwargs.get('use_concat_talk_head', True)
+        self.use_shallow_think = kwargs.get('use_shallow_think', True)
+        self.use_shallow_talk = kwargs.get('use_shallow_talk', False)
+        self.use_complex_think_head = kwargs.get('use_complex_think_head', False)
+        self.use_complex_talk_head = kwargs.get('use_complex_talk_head', True)
+        self.use_weighted_talk_head = kwargs.get('use_weighted_talk_head', True)
+
+        # Set model properties
+        self.use_end_thought_token = True
+        self.use_start_thought_token = True
+        self.n_ahead = kwargs.get('n_ahead', 4)
+        self.n_passes = 1
+        self.eval_mode = True
+        self.first_run = False
+        self.rm_initialized = True
+        self.original_mode = False
 
+        return super().generate(
+            input_ids,
+            attention_mask=attention_mask,
+            max_length=max_length,
+            temperature=temperature,
+            **kwargs,
+        )
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1137,7 +1256,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         top_p: Optional[float] = None,
         min_p: Optional[float] = None,
         top_k: Optional[int] = None,
-        cache_position: Optional[bool] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
@@ -1412,17 +1530,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             past_key_values_length,
             sliding_window=self.config.sliding_window,
         )
-        if attention_mask is not None:
-            if attention_mask.dim() == 2:
-                # Expand the attention mask to have dimensions (batch_size, 1, 1, seq_length)
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-            elif attention_mask.dim() == 3:
-                # Expand the attention mask to have dimensions (batch_size, 1, seq_length, seq_length)
-                attention_mask = attention_mask.unsqueeze(1)
-            else:
-                raise ValueError(
-                    f"Attention mask should have 2 or 3 dimensions, but got {attention_mask.dim()} dimensions."
-                )
+
         outputs = self.model(
             # input_ids=input_ids,
             attention_mask=attention_mask,
@@ -1861,43 +1969,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
-
-
-    from .generate import custom_generate
-
-    def prepare_inputs_for_generation(self, input_ids, **kwargs):
-        return {"input_ids": input_ids, **kwargs}
-
-    def _generate_no_beam_search(
-        self,
-        input_ids,
-        cur_len,
-        max_length,
-        min_length,
-        do_sample,
-        temperature,
-        top_k,
-        top_p,
-        repetition_penalty,
-        no_repeat_ngram_size,
-        bad_words_ids,
-        pad_token_id,
-        eos_token_id,
-        batch_size,
-        attention_mask,
-        use_cache,
-        model_kwargs,
-    ):
-        generated_token_ids = custom_generate(
-            self,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            max_new_tokens=max_length - cur_len,
-            temperature=temperature,
-            **model_kwargs,
-        )
-
-        return generated_token_ids
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
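For reference, the `generate` override added in this commit accepts token ids (or, per the added guard, a raw prompt string that it tokenizes itself) together with `max_new_tokens`, `temperature`, and the thought/talk-head flags it reads from `kwargs` (`n_ahead`, `n_ahead_talk`, `merged_talk_heads`, and so on). A minimal usage sketch follows; the checkpoint id is a placeholder, and attaching the tokenizer to the model object is an assumption based on the code's reliance on `self.tokenizer`:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/quiet-star-checkpoint"  # placeholder, not an actual repo name
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,        # loads modeling_quiet.py from the repo
    torch_dtype=torch.bfloat16,
)
model.tokenizer = tokenizer        # assumption: generate/_generate_no_beam_search read self.tokenizer

inputs = tokenizer("What is 2 + 2?", return_tensors="pt")
output_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,
    temperature=1.1,               # default declared by the new generate signature
    n_ahead=4,                     # thought-length settings read from kwargs by the override
    n_ahead_talk=4,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Whether extra flags such as n_ahead survive the hand-off to super().generate() depends on how the base class validates model kwargs, so treat this as an illustrative sketch rather than a tested recipe.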