Crystalcareai committed on
Commit
a00ce27
·
verified ·
1 Parent(s): c066ef6

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +31 -1
modeling_quiet.py CHANGED
@@ -1665,7 +1665,37 @@ class QuietForCausalLM(QuietPreTrainedModel):
1665
  else:
1666
  with torch.set_grad_enabled(not self.train_only_thinking_embedding):
1667
  inputs_embeds = self.model.embed_tokens(input_ids)
1668
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1669
  if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
1670
  if attention_mask is None:
1671
  base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
 
1665
  else:
1666
  with torch.set_grad_enabled(not self.train_only_thinking_embedding):
1667
  inputs_embeds = self.model.embed_tokens(input_ids)
1668
+
1669
+ def _update_inputs_for_thought_tokens(
1670
+ self, input_ids, attention_mask, contains_start, contains_end
1671
+ ):
1672
+ batch_size = input_ids.size(0)
1673
+ seq_len = input_ids.size(1)
1674
+
1675
+ if contains_start:
1676
+ start_token_ids = torch.tensor(
1677
+ [[self.start_token_id]] * batch_size, device=input_ids.device
1678
+ )
1679
+ input_ids = torch.cat([input_ids, start_token_ids], dim=1)
1680
+ if attention_mask is not None:
1681
+ start_attention_mask = torch.ones(
1682
+ (batch_size, 1), device=attention_mask.device
1683
+ )
1684
+ attention_mask = torch.cat([attention_mask, start_attention_mask], dim=1)
1685
+
1686
+ if contains_end:
1687
+ end_token_ids = torch.tensor(
1688
+ [[self.end_token_id]] * batch_size, device=input_ids.device
1689
+ )
1690
+ input_ids = torch.cat([input_ids, end_token_ids], dim=1)
1691
+ if attention_mask is not None:
1692
+ end_attention_mask = torch.ones(
1693
+ (batch_size, 1), device=attention_mask.device
1694
+ )
1695
+ attention_mask = torch.cat([attention_mask, end_attention_mask], dim=1)
1696
+
1697
+ return input_ids, attention_mask
1698
+
1699
  if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
1700
  if attention_mask is None:
1701
  base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)