Crystalcareai committed (verified)
Commit dc09b15 · 1 Parent(s): 560abff

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +68 -21
modeling_quiet.py CHANGED
@@ -270,10 +270,22 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class QuietAttention(nn.Module):
+    """
+    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+    and "Generating Long Sequences with Sparse Transformers".
+    """
+
     def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
+                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -284,6 +296,11 @@ class QuietAttention(nn.Module):
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
 
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@@ -295,6 +312,9 @@ class QuietAttention(nn.Module):
             base=self.rope_theta,
         )
 
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -304,7 +324,11 @@ class QuietAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
-    ):
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -318,30 +342,50 @@ class QuietAttention(nn.Module):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
-                raise ValueError("Layer index must be provided when using past key values.")
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
         if attention_mask is not None:
-            if attention_mask.size(-1) != kv_seq_len:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-                attention_mask = attention_mask.expand(-1, 1, q_len, kv_seq_len)
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+
             attn_weights = attn_weights + attention_mask
 
+        # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
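The new size checks assert the shapes the eager attention path produces once the grouped-query key/value heads have been expanded by repeat_kv. A minimal, self-contained sketch of that flow with toy dimensions (not the model's real configuration), using repeat_interleave as a stand-in for repeat_kv:

import math
import torch

# Toy sizes for illustration only.
bsz, q_len, kv_seq_len = 2, 5, 5
num_heads, num_key_value_heads, head_dim = 8, 2, 16
n_rep = num_heads // num_key_value_heads

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_key_value_heads, kv_seq_len, head_dim)
value_states = torch.randn(bsz, num_key_value_heads, kv_seq_len, head_dim)

# Stand-in for repeat_kv: each k/v head is repeated n_rep times so it lines up with the query heads.
key_states = key_states.repeat_interleave(n_rep, dim=1)
value_states = value_states.repeat_interleave(n_rep, dim=1)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
assert attn_weights.size() == (bsz, num_heads, q_len, kv_seq_len)   # first check in the diff

attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
assert attn_output.size() == (bsz, num_heads, q_len, head_dim)      # last check in the diff

attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)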
@@ -1017,12 +1061,6 @@ class QuietModel(QuietPreTrainedModel):
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.unsqueeze(-1)
-            inputs_embeds = inputs_embeds * attention_mask
-
-        hidden_states = inputs_embeds
 
         if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
@@ -1045,7 +1083,16 @@ class QuietModel(QuietPreTrainedModel):
                 inputs_embeds,
                 past_key_values_length,
             )
-
+        elif attention_mask is None or attention_mask.dim() == 2:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )
+
         hidden_states = inputs_embeds
 
         # decoder layers
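For context, _prepare_4d_causal_attention_mask is the helper from transformers.modeling_attn_mask_utils (present since transformers v4.36); it expands a 2D padding mask into the additive 4D causal mask that the eager attention path above adds to the attention scores. A small usage sketch under that assumption, with made-up sizes:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size = 2, 4, 8
past_key_values_length = 0

# 2D padding mask: 1 = real token, 0 = padding (second sequence ends in one pad token).
attention_mask_2d = torch.tensor([[1, 1, 1, 1],
                                  [1, 1, 1, 0]])
inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)  # only dtype/device are used

attention_mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask_2d,
    (batch_size, seq_length),
    inputs_embeds,
    past_key_values_length,
    sliding_window=4096,  # stand-in value; the model passes config.sliding_window
)

# Shape (batch_size, 1, seq_length, seq_length): 0 where attention is allowed,
# a large negative value at future and padded positions.
print(attention_mask_4d.shape)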
@@ -1261,7 +1308,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        max_length: Optional[int] = None,  # Add the max_length argument
     ):
         batch_size, seq_len = input_ids.shape
 
@@ -1279,12 +1325,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
         attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
 
         # Generate the continuation
-        if max_length is None:
-            max_length = self.n_ahead - 2  # Use the default value if max_length is not provided
+        continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
+
         start_time = time.time()
-
-        for continuation_idx in range(max_length):
+        for continuation_idx in range(continuation_length):
             outputs = self.model(
                 input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
                 attention_mask=attention_mask,
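The loop above now always runs for self.n_ahead - 2 steps instead of a caller-supplied max_length. Stripped of the thought-token bookkeeping, the pattern is plain cache-reusing greedy decoding; a rough sketch with a generic Hugging Face causal LM, not the model's actual method:

import torch

@torch.no_grad()
def generate_continuation(model, input_ids, attention_mask, continuation_length):
    # Feed the full prompt once, then one new token per step, reusing the k/v cache.
    past_key_values = None
    next_token_id = None
    for continuation_idx in range(continuation_length):
        outputs = model(
            input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1),
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        logits = outputs.logits[:, -1, :]        # only the last position matters
        next_token_id = logits.argmax(dim=-1)    # greedy pick, shape (batch_size,)
        # Grow the mask by one column for the token that was just generated.
        attention_mask = torch.cat(
            [attention_mask, torch.ones((attention_mask.shape[0], 1), device=attention_mask.device)],
            dim=-1,
        )
    return next_token_id, past_key_values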
@@ -1297,7 +1342,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 return_dict=return_dict,
             )
             new_key_values = outputs.past_key_values
+
             hidden_states = outputs[0]
+
             logits = self.lm_head(hidden_states)
             logits = logits[:, -1, :]  # Only consider the last token
 
@@ -1336,9 +1383,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
         )
         hidden_states_before = outputs_before[0][:, -1:, :]
 
-        # Get the hidden states after the thought
+        # two new tokens: last continuation token and end thought token
         outputs_after = self.model(
-            input_ids=input_ids,
+            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1).to(input_ids.device)], dim=-1),
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=new_key_values,
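The rewritten input_ids for outputs_after is just a two-column tensor: the last continuation token followed by the end-of-thought token. A shape-only illustration with made-up token ids and batch size 1 (the unsqueeze pattern mirrors the expression in the diff):

import torch

next_token_id = torch.tensor([42])   # last sampled continuation token, shape (1,)
end_thought_token_id = 7             # placeholder id for the end-of-thought marker

two_token_input = torch.cat(
    [
        next_token_id.unsqueeze(-1),                                     # (1, 1)
        torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1),  # (1, 1)
    ],
    dim=-1,
)
print(two_token_input)        # tensor([[42,  7]])
print(two_token_input.shape)  # torch.Size([1, 2])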
@@ -1358,8 +1405,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         # Apply the language model head to get the final logits
         logits = self.lm_head(mixed_hidden_states)
-
         return logits
+
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
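Because this file is custom modeling code shipped with the model repository, exercising the changes above requires loading the repo with trust_remote_code=True. A hedged sketch; the repository id below is a placeholder for whichever repo carries this modeling_quiet.py:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Crystalcareai/<model-repo>"  # placeholder, not a real repository id

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,        # required so modeling_quiet.py is actually used
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # (batch_size, seq_len, vocab_size)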