Crystalcareai committed
Commit 49663aa · verified · 1 Parent(s): aa488fc

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +96 -29
modeling_quiet.py CHANGED
@@ -2128,6 +2128,50 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             del start_embedding
             del end_embedding
             torch.cuda.empty_cache()
+
+        if not self.training:
+            # Inference mode
+            if max_length is None:
+                max_length = self.config.max_length
+
+            finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
+            for cur_token_idx in range(max_length):
+                outputs = self.model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_values=past_key_values,
+                    inputs_embeds=inputs_embeds,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                )
+                hidden_states = outputs[0]
+                logits = self.lm_head(hidden_states)
+
+                # Mask out the start and end thought tokens
+                logits[:, :, self.start_token_id] = -float("inf")
+                logits[:, :, self.end_token_id] = -float("inf")
+
+                for batch_idx in range(batch_size):
+                    if not finished_generating[batch_idx]:
+                        last_token_idx = (input_ids[batch_idx] != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+                        new_id_sampled = torch.multinomial(
+                            torch.nn.functional.softmax(logits[batch_idx, last_token_idx] / temperature, dim=-1), 1
+                        )
+                        if last_token_idx + 1 >= input_ids.shape[1]:
+                            # Add padding
+                            new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long, device=input_ids.device)
+                            input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                            attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+                        attention_mask[batch_idx, last_token_idx + 1] = 1
+                        input_ids[batch_idx, last_token_idx + 1] = new_id_sampled
+                        if new_id_sampled == self.tokenizer.eos_token_id or new_id_sampled == self.tokenizer.bos_token_id or new_id_sampled == self.tokenizer.pad_token_id:
+                            finished_generating[batch_idx] = True
+
+                if finished_generating.all():
+                    break
 
         return CausalLMOutputWithPast(
             loss=loss if loss is not None else None,
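Note on the added inference path above: each step runs the current ids through the base model, masks the start/end thought tokens in the logits, and samples one new id per unfinished sequence with temperature-scaled multinomial sampling, stopping on EOS/BOS/PAD. A minimal standalone sketch of that sampling step (the function name, the banned-id tuple, and the dummy vocabulary size are illustrative assumptions, not part of this commit):

import torch

def sample_next_token(last_token_logits, temperature=1.0, banned_ids=()):
    # last_token_logits: (vocab_size,) logits at the last real (non-pad) position of one sequence
    logits = last_token_logits.clone()
    for token_id in banned_ids:
        logits[token_id] = -float("inf")  # e.g. the start/end thought token ids
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # shape (1,)

# Illustrative usage with random logits and a made-up vocab size
next_id = sample_next_token(torch.randn(32000), temperature=0.7, banned_ids=(3, 4))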
 
@@ -2169,36 +2213,59 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
-
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
+            #     input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            #     input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            attention_mask = attention_mask[:, -input_ids.shape[1]:]  # Adjust the attention mask size
-
-        if self.use_start_thought_token:
-            start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
-            input_ids = torch.cat(
-                [input_ids, torch.tensor([[start_thought_token_id]] * input_ids.shape[0], device=input_ids.device)],
-                dim=-1
-            )
-            attention_mask = torch.cat(
-                [attention_mask, torch.ones((input_ids.shape[0], 1), device=attention_mask.device)],
-                dim=-1
-            )
-
-        # Expand the attention mask to the correct shape
-        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        attention_mask = attention_mask.expand(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "attention_mask": attention_mask,
-            "inputs_embeds": inputs_embeds,
-        }
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
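Note on the rewritten prepare_inputs_for_generation: it follows the cache-aware pattern used elsewhere in transformers, dropping tokens already covered by past_key_values and deriving position_ids from the attention mask. A small sketch of those two steps on toy tensors (the concrete values are illustrative; the real method also handles Cache objects, the maximum cache length, and inputs_embeds as shown in the hunk above):

import torch

input_ids = torch.tensor([[5, 6, 7, 8]])   # full sequence so far, shape (1, 4)
attention_mask = torch.ones(1, 4, dtype=torch.long)
past_length = 3                            # tokens already held in past_key_values

# Keep only the unprocessed tail of input_ids
if past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]  # -> tensor([[8]])

# Create position_ids on the fly from the attention mask
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids[:, -input_ids.shape[1]:]  # -> tensor([[3]])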