Crystalcareai committed on
Commit b9e81b4 · verified · 1 Parent(s): 27bbda7

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py  +80 -123
modeling_quiet.py CHANGED
@@ -1022,9 +1022,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         seq_len += 1
 
         # Update the attention mask
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
-        else:
+        if attention_mask is not None:
             attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
 
         # Generate the continuation
@@ -1059,11 +1057,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         next_token_id = torch.argmax(next_token_logits, dim=-1)
 
         # Append the generated token to the input sequence
-        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
+        # input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
         seq_len += 1
 
         # Update the attention mask
-        attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
 
         # Append the end thought token to the input sequence
         end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
@@ -1071,7 +1070,8 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         seq_len += 1
 
         # Update the attention mask
-        attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
 
         # Get the hidden states before and after the thought
         outputs_before = self.model(
@@ -1090,7 +1090,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         # two new tokens: last continuation token and end thought token
         outputs_after = self.model(
             input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
-            attention_mask=torch.cat([attention_mask[:, -2:], torch.ones((batch_size, 2)).to(attention_mask.device)], dim=-1),
+            attention_mask=torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1),
             position_ids=position_ids,
             past_key_values=new_key_values,
             inputs_embeds=inputs_embeds,
@@ -1110,127 +1110,25 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         # Apply the language model head to get the final logits
         logits = self.lm_head(mixed_hidden_states)
         return logits
-
-    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
-        return {"input_ids": input_ids, "attention_mask": attention_mask}
-
-    def _generate_no_beam_search(
-        self,
-        input_ids,
-        cur_len,
-        max_length,
-        min_length,
-        do_sample,
-        temperature,
-        top_k,
-        top_p,
-        repetition_penalty,
-        no_repeat_ngram_size,
-        bad_words_ids,
-        pad_token_id,
-        eos_token_id,
-        batch_size,
-        attention_mask,
-        use_cache,
-        model_kwargs,
-    ):
-        finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
-        for cur_token_idx in range(max_length):
-            # Sample the next token
-            new_ids = self(
-                input_ids[~finished_generating],
-                attention_mask=attention_mask[~finished_generating]
-            )['logits']
-            # Mask out the start and end thought tokens so we don't accidentally sample them
-            new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
-            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
-                # Find the index of the last token that is not padding
-                base_answer_ids = input_ids[answer_idx]
-                new_answer_ids = new_ids[list_idx]
-                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
-
-                new_ids_sampled = torch.multinomial(
-                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
-                # Assign the new id to the last token
-                if last_token_idx + 1 >= len(base_answer_ids):
-                    # Add padding everywhere
-                    new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
-                                             device=input_ids.device)
-                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
-                    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-                attention_mask[answer_idx, last_token_idx + 1] = 1
-                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-                if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
-                    finished_generating[answer_idx] = 1
-                # Check if the end token is generated
-                if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
-                    finished_generating[answer_idx] = 1
-            if finished_generating.all():
-                break
-        return input_ids
-
+
     @torch.no_grad()
     def generate(
         self,
-        input_ids=None,
-        max_length=None,
-        min_length=None,
-        do_sample=None,
-        early_stopping=None,
-        num_beams=None,
-        temperature=None,
-        top_k=None,
-        top_p=None,
-        repetition_penalty=None,
-        bad_words_ids=None,
-        bos_token_id=None,
-        pad_token_id=None,
-        eos_token_id=None,
-        length_penalty=None,
-        no_repeat_ngram_size=None,
-        num_return_sequences=None,
-        attention_mask=None,
-        decoder_start_token_id=None,
-        use_cache=None,
-        **model_kwargs,
-    ):
-        max_length = max_length if max_length is not None else self.config.max_length
-        min_length = min_length if min_length is not None else self.config.min_length
-        do_sample = do_sample if do_sample is not None else self.config.do_sample
-        temperature = temperature if temperature is not None else self.config.temperature
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-
-        # if input_ids is None:
-        #     raise ValueError("You have to specify either input_ids")
+        input_ids: torch.LongTensor = torch.LongTensor(),
+        attention_mask: Optional[torch.Tensor] = None,
+        max_new_tokens: Optional[int] = None,
+        temperature: float = 1.1,
+        **kwargs,
+    ):
+        if isinstance(input_ids, str):
+            input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
 
-        # batch_size = input_ids.shape[0]
-        # cur_len = input_ids.shape[-1]
+        if attention_mask is None:
+            # Create a default attention mask if not provided
+            attention_mask = torch.ones_like(input_ids)
 
-        # if attention_mask is None:
-        #     attention_mask = torch.ones(batch_size, cur_len, device=input_ids.device)
-
-        return self._generate_no_beam_search(
-            input_ids,
-            cur_len=cur_len,
-            max_length=max_length,
-            min_length=min_length,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            bad_words_ids=bad_words_ids,
-            pad_token_id=pad_token_id,
-            eos_token_id=eos_token_id,
-            batch_size=batch_size,
-            attention_mask=attention_mask,
-            use_cache=use_cache,
-            model_kwargs=model_kwargs,
-        )
+        from .generate import custom_generate
+        return custom_generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -1971,6 +1869,65 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
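For reference, the position_ids handling introduced in prepare_inputs_for_generation above can be sanity-checked in isolation. The sketch below is illustrative only (the toy attention mask values are made up); it reproduces just the cumsum-and-fill step from the added method, not the cache-cropping logic.

# Standalone sketch of the position_ids computation used by the new
# prepare_inputs_for_generation: positions are the cumulative sum of the
# attention mask minus one, with padded slots filled with 1.
import torch

# Toy left-padded batch (illustrative values): 0 = padding, 1 = real token.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],
    [1, 1, 1, 1, 1],
])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])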
 
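A minimal usage sketch for the slimmed-down generate entry point added in this commit. The checkpoint id and prompt below are placeholders, and custom_generate (imported from the repo's generate.py) is assumed to return token ids based on the wrapper's call signature; its implementation is not part of this diff.

# Hypothetical usage of the new generate() wrapper; the model id below is a
# placeholder, and trust_remote_code=True is required so that this
# modeling_quiet.py (and the bundled generate.py) are actually used.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-namespace/your-quiet-star-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("What is 12 * 7?", return_tensors="pt")

# The new signature only exposes input_ids, attention_mask, max_new_tokens and
# temperature, and forwards everything else to custom_generate via **kwargs.
output = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,
    temperature=1.1,
)

# Assuming custom_generate returns token ids (as the wrapper's signature suggests):
print(tokenizer.decode(output[0], skip_special_tokens=True))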