Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py CHANGED (+75 -49)
@@ -23,6 +23,7 @@ import math
 import copy
 import os
 import time
+import pandas as pd
 import seaborn as sns
 import matplotlib.pyplot as plt
 import wandb
@@ -68,6 +69,73 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "QuietConfig"
 
+from reportlab.pdfgen import canvas
+from reportlab.lib.pagesizes import letter
+from reportlab.lib.colors import HexColor
+
+def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="text.pdf", eps=0.2, eps2=0.5):
+    c = canvas.Canvas(output_file, pagesize=letter)
+    c.setFont("Courier", 8)
+    x, y = 50, 750
+    previous_text = ""
+    current_text = ""
+    for token_idx, reward in enumerate(token_rewards):
+        current_text = tokenizer.decode(input_ids[: token_idx + 1])
+        if current_text != previous_text:
+            diff_text = current_text[len(previous_text) :]
+            if "\n" in diff_text:
+                lines = diff_text.split("\n")
+                for line_idx, line in enumerate(lines):
+                    if line_idx > 0:
+                        x = 50
+                        y -= 12
+                    if abs(reward) < eps:
+                        opacity = 0
+                    elif abs(reward) > eps2:
+                        opacity = 0.8
+                    else:
+                        opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
+                    text_width = c.stringWidth(line)
+                    if reward > 0:
+                        highlight_color = HexColor("#4CCD99")
+                    else:
+                        highlight_color = HexColor("#FFC700")
+                    highlight_color.alpha = opacity
+                    c.setFillColor(highlight_color)
+                    c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
+                    c.setFillColor(HexColor("#000000"))
+                    c.drawString(x, y, line)
+                    x += text_width
+            else:
+                if abs(reward) < eps:
+                    opacity = 0
+                elif abs(reward) > eps2:
+                    opacity = 0.8
+                else:
+                    opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
+                text_width = c.stringWidth(diff_text)
+                if reward > 0:
+                    highlight_color = HexColor("#4CCD99")
+                else:
+                    highlight_color = HexColor("#FFC700")
+                highlight_color.alpha = opacity
+                c.setFillColor(highlight_color)
+                c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
+                c.setFillColor(HexColor("#000000"))
+                c.drawString(x, y, diff_text)
+                x += text_width
+            if x > 550:
+                x = 50
+                y -= 12
+                if y < 50:
+                    c.showPage()
+                    y = 750
+                    x = 50
+        previous_text = current_text
+    c.showPage()
+    c.save()
+
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -257,13 +325,6 @@ class QuietAttention(nn.Module):
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
-        if past_key_value is not None:
-            expected_attention_mask_size = (bsz, 1, q_len, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
-            if attention_mask.size() != expected_attention_mask_size:
-                # Assuming the attention mask is larger than expected, slice it to match the expected size
-                attention_mask = attention_mask[:, :, :, -expected_attention_mask_size[-1]:]
-
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
@@ -307,16 +368,11 @@ class QuietAttention(nn.Module):
         )
 
         if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention mask should be of size (
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
-
+
             attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
@@ -693,21 +749,11 @@ class QuietSdpaAttention(QuietAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-            if attention_mask is not None:
-                if attention_mask.dim() == 3:
-                    attention_mask = attention_mask.unsqueeze(1)
-                elif attention_mask.dim() == 2:
-                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention mask should be of size (
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
+
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and attention_mask is not None:
@@ -1281,27 +1327,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
-
-        if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
-            if attention_mask is None:
-                base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
-                base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
-                base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
-                attention_mask = base_attention_mask
-            elif attention_mask.dim() == 2:
-                if seq_len + past_key_values_length != attention_mask.shape[-1]:
-                    attention_mask = torch.cat(
-                        [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
-                        dim=-1
-                    )
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask,
-                    (batch_size, seq_len),
-                    inputs_embeds,
-                    past_key_values_length,
-                    sliding_window=self.config.sliding_window,
-                )
-
+
         start_time = time.time()
         for continuation_idx in range(continuation_length):
             outputs = self.model(