Crystalcareai committed (verified)
Commit 3d962f0 · 1 Parent(s): 5af661b

Update modeling_quiet.py

Files changed (1):
  modeling_quiet.py  +6 -322
modeling_quiet.py CHANGED
@@ -47,20 +47,11 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-import transformers
 from .configuration_quiet import QuietConfig
 
 import time
 from typing import Optional, List
 
-
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
-    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
-
-
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "QuietConfig"
@@ -408,312 +399,6 @@ class QuietAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
-class QuietFlashAttention2(QuietAttention):
-    """
-    Quiet flash attention module. This module inherits from `QuietAttention` as the weights of the module stays
-    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
-    flash attention and deal with padding tokens in case the input contains any of them.
-    """
-
-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
-        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
-    ):
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-
-            # overwrite attention_mask with padding_mask
-            attention_mask = kwargs.pop("padding_mask")
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-        # Because the input can be padded, the absolute sequence length depends on the max position id.
-        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
-        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        use_sliding_windows = (
-            _flash_supports_window_size
-            and getattr(self.config, "sliding_window", None) is not None
-            and kv_seq_len > self.config.sliding_window
-        )
-
-        if not _flash_supports_window_size:
-            logger.warning_once(
-                "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
-                " make sure to upgrade flash-attn library."
-            )
-
-        if past_key_value is not None:
-            # Activate slicing cache only if the config has a value `sliding_windows` attribute
-            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
-            if (
-                getattr(self.config, "sliding_window", None) is not None
-                and kv_seq_len > self.config.sliding_window
-                and cache_has_contents
-            ):
-                slicing_tokens = 1 - self.config.sliding_window
-
-                past_key = past_key_value[self.layer_idx][0]
-                past_value = past_key_value[self.layer_idx][1]
-
-                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
-                past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
-                if past_key.shape[-2] != self.config.sliding_window - 1:
-                    raise ValueError(
-                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
-                        f" {past_key.shape}"
-                    )
-
-                if attention_mask is not None:
-                    attention_mask = attention_mask[:, slicing_tokens:]
-                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
-
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        dropout_rate = 0.0 if not self.training else self.attention_dropout
-
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in float16 just to be sure everything works as expected.
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                f"The input hidden states seems to be silently casted in float32, this might be related to"
-                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
-        # Reashape to the expected shape for Flash Attention
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        attn_output = self._flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            q_len,
-            dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
-        )
-
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-    def _flash_attention_forward(
-        self,
-        query_states,
-        key_states,
-        value_states,
-        attention_mask,
-        query_length,
-        dropout=0.0,
-        softmax_scale=None,
-        use_sliding_windows=False,
-    ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-            use_sliding_windows (`bool`, *optional*):
-                Whether to activate sliding window attention.
-        """
-        if not self._flash_attn_uses_top_left_mask:
-            causal = self.is_causal
-        else:
-            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
-            causal = self.is_causal and query_length != 1
-
-        # Ensure attention_mask has the correct shape and values
-        if attention_mask is not None:
-            if attention_mask.dim() == 4:
-                # Convert 4D attention mask to 2D
-                attention_mask = attention_mask.squeeze(1).squeeze(1)
-            elif attention_mask.dim() != 2:
-                raise ValueError(
-                    f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
-                )
-
-            # Ensure attention_mask has values of 0 and 1
-            attention_mask = attention_mask.to(torch.bool).to(torch.int32)
-
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            if not use_sliding_windows:
-                attn_output_unpad = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=max_seqlen_in_batch_q,
-                    max_seqlen_k=max_seqlen_in_batch_k,
-                    dropout_p=dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                )
-            else:
-                attn_output_unpad = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=max_seqlen_in_batch_q,
-                    max_seqlen_k=max_seqlen_in_batch_k,
-                    dropout_p=dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                    window_size=(self.config.sliding_window, self.config.sliding_window),
-                )
-
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-        else:
-            if not use_sliding_windows:
-                attn_output = flash_attn_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                )
-            else:
-                attn_output = flash_attn_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                    window_size=(self.config.sliding_window, self.config.sliding_window),
-                )
-
-        return attn_output
-
-    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
-
-        # On the first iteration we need to properly re-create the padding mask
-        # by slicing it on the proper place
-        if kv_seq_len != attention_mask.shape[-1]:
-            attention_mask_num_tokens = attention_mask.shape[-1]
-            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
-
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-
-        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
-            )
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q,
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
-
-
 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
 class QuietSdpaAttention(QuietAttention):
     """
@@ -1567,16 +1252,15 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
                 return x
             return x.repeat_interleave(n, dim=0)
 
-        if self.n_passes > 1 and input_ids is not None:
+        if self.n_passes > 1:
             input_ids = none_repeat_interleave(input_ids, self.n_passes)
-            attention_mask = none_repeat_interleave(attention_mask, self.n_passes) if attention_mask is not None else None
-            position_ids = none_repeat_interleave(position_ids, self.n_passes) if position_ids is not None else None
-            inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes) if inputs_embeds is not None else None
-            labels = none_repeat_interleave(labels, self.n_passes) if labels is not None else None
+            attention_mask = none_repeat_interleave(attention_mask, self.n_passes)
+            position_ids = none_repeat_interleave(position_ids, self.n_passes)
+            inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes)
+            labels = none_repeat_interleave(labels, self.n_passes)
             if past_key_values is not None:
                 past_key_values = [none_repeat_interleave(p, self.n_passes) for p in past_key_values]
-            if input_ids is not None:
-                cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device)
+            cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device)
 
         self.tokenizer_has_start_thought_token = True
         self.tokenizer_has_end_thought_token = True
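
The hunk above simplifies the multi-pass batching in QuietForCausalLM: the None-guards are dropped and every provided input is repeated along the batch dimension when n_passes > 1. A small standalone sketch (toy values, assuming input_ids is provided) of what none_repeat_interleave does here:

import torch

def none_repeat_interleave(x, n):
    # Mirrors the helper shown in the context lines above: pass None through,
    # otherwise repeat each batch row n times along dim 0.
    if x is None:
        return x
    return x.repeat_interleave(n, dim=0)

n_passes = 2
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])   # (batch=2, seq_len=3)
attention_mask = torch.ones_like(input_ids)

if n_passes > 1:
    input_ids = none_repeat_interleave(input_ids, n_passes)            # -> (4, 3)
    attention_mask = none_repeat_interleave(attention_mask, n_passes)  # -> (4, 3)

cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device)
print(input_ids.shape, cur_token_indices)  # torch.Size([4, 3]) tensor([0, 1, 2])

Note that cur_token_indices now reads input_ids.shape[1] without the previous input_ids-is-not-None guard, so this code path assumes input_ids is supplied.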
 