Crystalcareai committed
Commit bef64cd · verified · 1 Parent(s): cc80789

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+77 -0)
modeling_quiet.py CHANGED
@@ -2136,6 +2136,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
             attentions=outputs.attentions,
         )
 
+
+
     def compute_complexity_scores(self, input_ids, attention_mask):
         # Compute complexity scores based on input sequence characteristics
         # Example: Normalize sequence lengths and consider the presence of rare tokens
@@ -2228,6 +2230,81 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+    @torch.no_grad()
+    def generate(self, input_ids, attention_mask=None, **generate_kwargs):
+        batch_size, seq_len = input_ids.shape
+
+        max_length = generate_kwargs.get("max_length", self.config.max_length)
+        min_length = generate_kwargs.get("min_length", self.config.min_length)
+        do_sample = generate_kwargs.get("do_sample", self.config.do_sample)
+        early_stopping = generate_kwargs.get("early_stopping", self.config.early_stopping)
+        num_beams = generate_kwargs.get("num_beams", self.config.num_beams)
+        temperature = generate_kwargs.get("temperature", self.config.temperature)
+        top_k = generate_kwargs.get("top_k", self.config.top_k)
+        top_p = generate_kwargs.get("top_p", self.config.top_p)
+        repetition_penalty = generate_kwargs.get("repetition_penalty", self.config.repetition_penalty)
+        pad_token_id = generate_kwargs.get("pad_token_id", self.config.pad_token_id)
+        eos_token_id = generate_kwargs.get("eos_token_id", self.config.eos_token_id)
+
+        # Prepend the start thought token to the input sequence
+        start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]], device=input_ids.device)], dim=-1)
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)], dim=-1)
+
+        thought_embeds = self.model.embed_tokens(input_ids)
+
+        past_key_values = None
+        unfinished_sequences = input_ids.new(batch_size).fill_(1)
+        sequence_lengths = input_ids.new(batch_size).fill_(max_length)
+
+        while True:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, use_cache=True)
+            thought_outputs = self.model(**model_inputs, output_hidden_states=True, return_dict=True)
+            next_thought_embeds = self.prepare_thought_embeds(thought_outputs.hidden_states[-1], temperature=temperature)
+
+            thought_embeds = torch.cat([thought_embeds, next_thought_embeds], dim=1)
+
+            lm_logits = self.lm_head(thought_outputs.last_hidden_state)
+            lm_logits = lm_logits[:, -1, :] / temperature
+
+            if do_sample:
+                next_tokens = torch.multinomial(F.softmax(lm_logits, dim=-1), num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(lm_logits, dim=-1)
+
+            # Update input_ids, attention_mask and past_key_values
+            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1)
+            past_key_values = thought_outputs.past_key_values
+
+            # Check if generation is complete
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences & (next_tokens != eos_token_id)
+                unfinished_sequences = unfinished_sequences & (input_ids.shape[-1] < max_length)
+                if unfinished_sequences.max() == 0:
+                    break
+            elif input_ids.shape[-1] >= max_length:
+                input_ids[:, 0] = eos_token_id
+                break
+
+        return input_ids
+
+    def prepare_thought_embeds(self, hidden_states, temperature=1.0):
+        if self.use_start_thought_token:
+            start_embed = self.start_embedding[0].unsqueeze(0) * temperature
+        else:
+            start_embed = hidden_states[:, 0, :]
+
+        if self.use_end_thought_token:
+            end_embed = self.end_embedding[0].unsqueeze(0) * temperature
+            thought_embeds = torch.cat([start_embed, hidden_states[:, 1:-1, :], end_embed], dim=1)
+        else:
+            thought_embeds = torch.cat([start_embed, hidden_states[:, 1:, :]], dim=1)
+
+        return thought_embeds
 
 
     @add_start_docstrings(
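
Usage note: a minimal sketch of how the new generate() override might be exercised once the commit is in a published checkpoint. The repository id below is a placeholder, not a real model name; the sketch assumes the checkpoint's tokenizer already defines the <|startthought|> special token and that a tokenizer attribute is attached to the model instance, since the override resolves the start-thought id through self.tokenizer. It also sticks to batch size 1, because the start-thought token is concatenated as a fixed (1, 1) tensor.

# Hypothetical usage sketch -- the repo id is a placeholder, not an actual checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/quiet-star-model"  # placeholder: substitute the real model repository

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.tokenizer = tokenizer  # assumption: generate() looks up "<|startthought|>" via self.tokenizer
model.eval()

# Batch size 1, matching the (1, 1) start-thought tensor appended inside generate()
inputs = tokenizer("The capital of France is", return_tensors="pt")
output_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=64,
    do_sample=True,
    temperature=0.7,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

The arguments passed above (max_length, do_sample, temperature, eos_token_id) are the ones the committed loop actually consumes; min_length, early_stopping, num_beams, top_k, top_p, repetition_penalty and pad_token_id are fetched from generate_kwargs with config fallbacks but are not yet applied inside the loop.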