Crystalcareai committed (verified)
Commit aa488fc · 1 Parent(s): be7199b

Update modeling_quiet.py

Files changed (1)
  modeling_quiet.py  +29 -52
modeling_quiet.py CHANGED
@@ -2169,59 +2169,36 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        # Omit tokens covered by past_key_values
-        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
-            #     input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            #     input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
         else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs
+            attention_mask = attention_mask[:, -input_ids.shape[1]:]  # Adjust the attention mask size
+
+        if self.use_start_thought_token:
+            start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
+            input_ids = torch.cat(
+                [input_ids, torch.tensor([[start_thought_token_id]] * input_ids.shape[0], device=input_ids.device)],
+                dim=-1
+            )
+            attention_mask = torch.cat(
+                [attention_mask, torch.ones((input_ids.shape[0], 1), device=attention_mask.device)],
+                dim=-1
+            )
+
+        # Expand the attention mask to the correct shape
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+        attention_mask = attention_mask.expand(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache"),
+            "attention_mask": attention_mask,
+            "inputs_embeds": inputs_embeds,
+        }
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
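
For context, here is a minimal, standalone sketch (not part of the commit) of what the rewritten prepare_inputs_for_generation now does to its inputs on a cached decoding step. The batch size, prompt length, and start-thought token id below are invented for illustration; the actual method resolves the id via self.tokenizer.convert_tokens_to_ids("<|startthought|>") and only appends it when self.use_start_thought_token is set.

import torch

# Hypothetical sizes and token id, chosen only for this sketch.
batch_size, prompt_len = 2, 5
start_thought_token_id = 32001  # placeholder id for "<|startthought|>"

input_ids = torch.randint(0, 1000, (batch_size, prompt_len))
attention_mask = torch.ones(batch_size, prompt_len, dtype=torch.long)
past_key_values = ("dummy cache",)  # stands in for a non-empty KV cache

# With a cache present, only the most recent token is kept.
if past_key_values:
    input_ids = input_ids[:, -1:]

# The mask is trimmed to match the (now shorter) input_ids.
attention_mask = attention_mask[:, -input_ids.shape[1]:]

# A start-thought token is appended to each sequence; the mask grows by one column.
input_ids = torch.cat(
    [input_ids, torch.full((batch_size, 1), start_thought_token_id)], dim=-1
)
attention_mask = torch.cat(
    [attention_mask, torch.ones(batch_size, 1, dtype=torch.long)], dim=-1
)

# The 2-D mask is then broadcast to the 4-D [batch, 1, seq_len, seq_len] layout.
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask.expand(
    input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]
)

print(input_ids.shape)       # torch.Size([2, 2])
print(attention_mask.shape)  # torch.Size([2, 1, 2, 2])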