Crystalcareai committed
Commit eaa31dc · verified · 1 Parent(s): a9e5703

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +26 -54
modeling_quiet.py CHANGED
@@ -41,6 +41,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import GenerationMixin
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
@@ -1166,7 +1167,7 @@ def nonzero_mean(x, axis=None):
 def loss_mean(x):
     return x.sum() / (x != 0).sum()
 
-class QuietForCausalLM(QuietPreTrainedModel):
+class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -2168,59 +2169,30 @@ class QuietForCausalLM(QuietPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        # Omit tokens covered by past_key_values
-        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
-            # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if self.use_start_thought_token:
+            start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
+            input_ids = torch.cat(
+                [input_ids, torch.tensor([[start_thought_token_id]] * input_ids.shape[0], device=input_ids.device)],
+                dim=-1
+            )
+            attention_mask = torch.cat(
+                [attention_mask, torch.ones((input_ids.shape[0], 1), device=attention_mask.device)],
+                dim=-1
+            )
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache"),
+            "attention_mask": attention_mask,
+            "inputs_embeds": inputs_embeds,
+        }
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
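
Usage note (not part of the commit): with QuietForCausalLM now also inheriting from GenerationMixin, generate() routes through the simplified prepare_inputs_for_generation above on every decoding step; when the cache is in use, input_ids is trimmed to the last token, and if use_start_thought_token is set the <|startthought|> id is appended to both input_ids and attention_mask before the forward pass. The sketch below is a minimal, hedged example of exercising that path; the repository id, prompt, and generation settings are placeholders rather than values taken from this repository.

# Minimal sketch (assumptions: placeholder repo id, arbitrary prompt and settings).
# Loading with trust_remote_code=True makes transformers use this modeling_quiet.py;
# generate() then calls the prepare_inputs_for_generation defined above at each step.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/quiet-star-model"  # placeholder, not a real repository id

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model.eval()

inputs = tokenizer("What is 2 + 2?", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=64,
        use_cache=True,  # exercises the past_key_values branch added in this commit
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))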