Update modeling_quiet.py
Browse files- modeling_quiet.py +31 -1
modeling_quiet.py
CHANGED
@@ -1665,7 +1665,37 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1665 |
else:
|
1666 |
with torch.set_grad_enabled(not self.train_only_thinking_embedding):
|
1667 |
inputs_embeds = self.model.embed_tokens(input_ids)
|
1668 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1669 |
if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
|
1670 |
if attention_mask is None:
|
1671 |
base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
|
|
|
1665 |
else:
|
1666 |
with torch.set_grad_enabled(not self.train_only_thinking_embedding):
|
1667 |
inputs_embeds = self.model.embed_tokens(input_ids)
|
1668 |
+
|
1669 |
+
def _update_inputs_for_thought_tokens(
|
1670 |
+
self, input_ids, attention_mask, contains_start, contains_end
|
1671 |
+
):
|
1672 |
+
batch_size = input_ids.size(0)
|
1673 |
+
seq_len = input_ids.size(1)
|
1674 |
+
|
1675 |
+
if contains_start:
|
1676 |
+
start_token_ids = torch.tensor(
|
1677 |
+
[[self.start_token_id]] * batch_size, device=input_ids.device
|
1678 |
+
)
|
1679 |
+
input_ids = torch.cat([input_ids, start_token_ids], dim=1)
|
1680 |
+
if attention_mask is not None:
|
1681 |
+
start_attention_mask = torch.ones(
|
1682 |
+
(batch_size, 1), device=attention_mask.device
|
1683 |
+
)
|
1684 |
+
attention_mask = torch.cat([attention_mask, start_attention_mask], dim=1)
|
1685 |
+
|
1686 |
+
if contains_end:
|
1687 |
+
end_token_ids = torch.tensor(
|
1688 |
+
[[self.end_token_id]] * batch_size, device=input_ids.device
|
1689 |
+
)
|
1690 |
+
input_ids = torch.cat([input_ids, end_token_ids], dim=1)
|
1691 |
+
if attention_mask is not None:
|
1692 |
+
end_attention_mask = torch.ones(
|
1693 |
+
(batch_size, 1), device=attention_mask.device
|
1694 |
+
)
|
1695 |
+
attention_mask = torch.cat([attention_mask, end_attention_mask], dim=1)
|
1696 |
+
|
1697 |
+
return input_ids, attention_mask
|
1698 |
+
|
1699 |
if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
|
1700 |
if attention_mask is None:
|
1701 |
base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
|