Crystalcareai committed (verified)
Commit 1f8e662 · 1 Parent(s): 38e552e

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py +75 -49
modeling_quiet.py CHANGED
@@ -23,6 +23,7 @@ import math
 import copy
 import os
 import time
+import pandas as pd
 import seaborn as sns
 import matplotlib.pyplot as plt
 import wandb
@@ -68,6 +69,73 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "QuietConfig"
 
+from reportlab.pdfgen import canvas
+from reportlab.lib.pagesizes import letter
+from reportlab.lib.colors import HexColor
+
+def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="text.pdf", eps=0.2, eps2=0.5):
+    c = canvas.Canvas(output_file, pagesize=letter)
+    c.setFont("Courier", 8)
+    x, y = 50, 750
+    previous_text = ""
+    current_text = ""
+    for token_idx, reward in enumerate(token_rewards):
+        current_text = tokenizer.decode(input_ids[: token_idx + 1])
+        if current_text != previous_text:
+            diff_text = current_text[len(previous_text) :]
+            if "\n" in diff_text:
+                lines = diff_text.split("\n")
+                for line_idx, line in enumerate(lines):
+                    if line_idx > 0:
+                        x = 50
+                        y -= 12
+                    if abs(reward) < eps:
+                        opacity = 0
+                    elif abs(reward) > eps2:
+                        opacity = 0.8
+                    else:
+                        opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
+                    text_width = c.stringWidth(line)
+                    if reward > 0:
+                        highlight_color = HexColor("#4CCD99")
+                    else:
+                        highlight_color = HexColor("#FFC700")
+                    highlight_color.alpha = opacity
+                    c.setFillColor(highlight_color)
+                    c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
+                    c.setFillColor(HexColor("#000000"))
+                    c.drawString(x, y, line)
+                    x += text_width
+            else:
+                if abs(reward) < eps:
+                    opacity = 0
+                elif abs(reward) > eps2:
+                    opacity = 0.8
+                else:
+                    opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
+                text_width = c.stringWidth(diff_text)
+                if reward > 0:
+                    highlight_color = HexColor("#4CCD99")
+                else:
+                    highlight_color = HexColor("#FFC700")
+                highlight_color.alpha = opacity
+                c.setFillColor(highlight_color)
+                c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
+                c.setFillColor(HexColor("#000000"))
+                c.drawString(x, y, diff_text)
+                x += text_width
+            if x > 550:
+                x = 50
+                y -= 12
+            if y < 50:
+                c.showPage()
+                y = 750
+                x = 50
+        previous_text = current_text
+    c.showPage()
+    c.save()
+
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
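
The save_tokens_with_rewards_to_pdf helper added above renders each decoded token with a highlight whose opacity scales with |reward| (one colour for positive rewards, another for negative) and writes the result to a PDF via reportlab. A minimal usage sketch, not part of the commit; the tokenizer checkpoint and the reward values below are placeholders chosen for illustration:

    from transformers import AutoTokenizer

    # Placeholder checkpoint; any tokenizer with a .decode() method works here.
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    input_ids = tokenizer("Thinking before answering helps.", return_tensors="pt").input_ids[0]
    # One reward per token; these values are illustrative, not model outputs.
    token_rewards = [0.6 if i % 3 == 0 else -0.3 for i in range(len(input_ids))]
    save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="rewards.pdf")
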
@@ -257,13 +325,6 @@ class QuietAttention(nn.Module):
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
-        if past_key_value is not None:
-            expected_attention_mask_size = (bsz, 1, q_len, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
-            if attention_mask.size() != expected_attention_mask_size:
-                # Assuming the attention mask is larger than expected, slice it to match the expected size
-                attention_mask = attention_mask[:, :, :, -expected_attention_mask_size[-1]:]
-
         if "padding_mask" in kwargs:
             warnings.warn(
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
@@ -307,16 +368,11 @@
             )
 
         if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention mask should be of size ({bsz}, 1, q_len, {kv_seq_len}), but is {attention_mask.size()}"
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
-
+
         attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
@@ -693,21 +749,11 @@ class QuietSdpaAttention(QuietAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-        if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention mask should be of size ({bsz}, 1, q_len, {kv_seq_len}), but is {attention_mask.size()}"
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
+
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and attention_mask is not None:
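
Both attention paths validate that attention_mask already has the 4-D shape (bsz, 1, q_len, kv_seq_len) and raise a ValueError otherwise. A standalone sketch of that shape contract, assuming transformers' _prepare_4d_causal_attention_mask helper (the same utility referenced elsewhere in this file); the sizes are arbitrary illustration values:

    import torch
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

    bsz, q_len, past_len, hidden = 2, 5, 0, 32                  # arbitrary sizes for illustration
    padding_mask = torch.ones(bsz, q_len, dtype=torch.long)     # 2-D (bsz, seq_len) padding mask
    inputs_embeds = torch.zeros(bsz, q_len, hidden)
    mask_4d = _prepare_4d_causal_attention_mask(padding_mask, (bsz, q_len), inputs_embeds, past_len)
    # The updated checks accept exactly this shape: (bsz, 1, q_len, q_len + past_len).
    assert mask_4d.size() == (bsz, 1, q_len, q_len + past_len)
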
@@ -1281,27 +1327,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
-
-        if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
-            if attention_mask is None:
-                base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
-                base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
-                base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
-                attention_mask = base_attention_mask
-            elif attention_mask.dim() == 2:
-                if seq_len + past_key_values_length != attention_mask.shape[-1]:
-                    attention_mask = torch.cat(
-                        [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
-                        dim=-1
-                    )
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_len),
-                inputs_embeds,
-                past_key_values_length,
-                sliding_window=self.config.sliding_window,
-            )
-
+
         start_time = time.time()
         for continuation_idx in range(continuation_length):
             outputs = self.model(
 