Crystalcareai committed
Commit a9e5703 · verified · 1 Parent(s): f874538

Update modeling_quiet.py

Files changed (1)
modeling_quiet.py +1 -76
modeling_quiet.py CHANGED
@@ -2231,82 +2231,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
             )
         return reordered_past
 
-    @torch.no_grad()
-    def generate(self, input_ids, attention_mask=None, **generate_kwargs):
-        batch_size, seq_len = input_ids.shape
-
-        max_length = generate_kwargs.get("max_length", self.config.max_length)
-        min_length = generate_kwargs.get("min_length", self.config.min_length)
-        do_sample = generate_kwargs.get("do_sample", self.config.do_sample)
-        early_stopping = generate_kwargs.get("early_stopping", self.config.early_stopping)
-        num_beams = generate_kwargs.get("num_beams", self.config.num_beams)
-        temperature = generate_kwargs.get("temperature", self.config.temperature)
-        top_k = generate_kwargs.get("top_k", self.config.top_k)
-        top_p = generate_kwargs.get("top_p", self.config.top_p)
-        repetition_penalty = generate_kwargs.get("repetition_penalty", self.config.repetition_penalty)
-        pad_token_id = generate_kwargs.get("pad_token_id", self.config.pad_token_id)
-        eos_token_id = generate_kwargs.get("eos_token_id", self.config.eos_token_id)
-
-        # Prepend the start thought token to the input sequence
-        start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
-        input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]], device=input_ids.device)], dim=-1)
-        if attention_mask is not None:
-            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)], dim=-1)
-
-        thought_embeds = self.model.embed_tokens(input_ids)
-
-        past_key_values = None
-        unfinished_sequences = input_ids.new(batch_size).fill_(1)
-        sequence_lengths = input_ids.new(batch_size).fill_(max_length)
-
-        while True:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, use_cache=True)
-            thought_outputs = self.model(**model_inputs, output_hidden_states=True, return_dict=True)
-            next_thought_embeds = self.prepare_thought_embeds(thought_outputs.hidden_states[-1], temperature=temperature)
-
-            thought_embeds = torch.cat([thought_embeds, next_thought_embeds], dim=1)
-
-            lm_logits = self.lm_head(thought_outputs.last_hidden_state)
-            lm_logits = lm_logits[:, -1, :] / temperature
-
-            if do_sample:
-                next_tokens = torch.multinomial(F.softmax(lm_logits, dim=-1), num_samples=1).squeeze(1)
-            else:
-                next_tokens = torch.argmax(lm_logits, dim=-1)
-
-            # Update input_ids, attention_mask and past_key_values
-            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
-            if attention_mask is not None:
-                attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1)
-            past_key_values = thought_outputs.past_key_values
-
-            # Check if generation is complete
-            if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences & (next_tokens != eos_token_id)
-                unfinished_sequences = unfinished_sequences & (input_ids.shape[-1] < max_length)
-                if unfinished_sequences.max() == 0:
-                    break
-            elif input_ids.shape[-1] >= max_length:
-                input_ids[:, 0] = eos_token_id
-                break
-
-        return input_ids
-
-    def prepare_thought_embeds(self, hidden_states, temperature=1.0):
-        batch_size, seq_len, hidden_size = hidden_states.shape
-
-        if self.use_start_thought_token:
-            start_embed = self.start_embedding[0].unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1) * temperature
-        else:
-            start_embed = hidden_states[:, :1, :]
-
-        if self.use_end_thought_token:
-            end_embed = self.end_embedding[0].unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1) * temperature
-            thought_embeds = torch.cat([start_embed, hidden_states[:, 1:-1, :], end_embed], dim=1)
-        else:
-            thought_embeds = torch.cat([start_embed, hidden_states[:, 1:, :]], dim=1)
-
-        return thought_embeds
+
 
 
 @add_start_docstrings(
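
With the custom generate() override deleted, generation on this repo would presumably fall back to the stock GenerationMixin.generate() that QuietForCausalLM inherits via PreTrainedModel. A minimal sketch of that standard path, assuming the checkpoint is loaded with trust_remote_code=True; the model id below is a placeholder, not taken from this commit:

# Minimal sketch: standard transformers generation after the custom
# generate() override is removed. The model id is a hypothetical placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/Quiet-Star-Custom"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    # GenerationMixin.generate handles sampling, EOS tracking, and padding,
    # replacing the hand-rolled while-loop deleted in this commit.
    output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))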
 
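For reference, the deleted prepare_thought_embeds amounted to splicing learned start/end "thought" embeddings onto the ends of a hidden-state tensor. A standalone sketch of that tensor surgery with illustrative shapes; start_embedding and end_embedding here are stand-ins for the model's learned parameters, not the actual attributes from this file:

# Standalone sketch of the splice performed by the removed
# prepare_thought_embeds: replace the first and last sequence positions
# with learned start/end thought embeddings. All shapes are illustrative.
import torch

batch_size, seq_len, hidden_size = 2, 5, 8
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
start_embedding = torch.randn(hidden_size)  # stand-in learned parameter
end_embedding = torch.randn(hidden_size)    # stand-in learned parameter

# Broadcast each embedding to (batch_size, 1, hidden_size) and concatenate
# around the interior positions, preserving the original sequence length.
start = start_embedding.view(1, 1, -1).expand(batch_size, 1, hidden_size)
end = end_embedding.view(1, 1, -1).expand(batch_size, 1, hidden_size)
thought_embeds = torch.cat([start, hidden_states[:, 1:-1, :], end], dim=1)
assert thought_embeds.shape == (batch_size, seq_len, hidden_size)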