Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py CHANGED (+6 -322)
@@ -47,20 +47,11 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-import transformers
 from .configuration_quiet import QuietConfig
 
 import time
 from typing import Optional, List
 
-
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
-    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
-
-
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "QuietConfig"
@@ -408,312 +399,6 @@ class QuietAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
-class QuietFlashAttention2(QuietAttention):
-    """
-    Quiet flash attention module. This module inherits from `QuietAttention` as the weights of the module stays
-    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
-    flash attention and deal with padding tokens in case the input contains any of them.
-    """
-
-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
-        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
-    ):
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-
-            # overwrite attention_mask with padding_mask
-            attention_mask = kwargs.pop("padding_mask")
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-        # Because the input can be padded, the absolute sequence length depends on the max position id.
-        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
-        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        use_sliding_windows = (
-            _flash_supports_window_size
-            and getattr(self.config, "sliding_window", None) is not None
-            and kv_seq_len > self.config.sliding_window
-        )
-
-        if not _flash_supports_window_size:
-            logger.warning_once(
-                "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
-                " make sure to upgrade flash-attn library."
-            )
-
-        if past_key_value is not None:
-            # Activate slicing cache only if the config has a value `sliding_windows` attribute
-            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
-            if (
-                getattr(self.config, "sliding_window", None) is not None
-                and kv_seq_len > self.config.sliding_window
-                and cache_has_contents
-            ):
-                slicing_tokens = 1 - self.config.sliding_window
-
-                past_key = past_key_value[self.layer_idx][0]
-                past_value = past_key_value[self.layer_idx][1]
-
-                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
-                past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
-                if past_key.shape[-2] != self.config.sliding_window - 1:
-                    raise ValueError(
-                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
-                        f" {past_key.shape}"
-                    )
-
-                if attention_mask is not None:
-                    attention_mask = attention_mask[:, slicing_tokens:]
-                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
-
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        dropout_rate = 0.0 if not self.training else self.attention_dropout
-
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in float16 just to be sure everything works as expected.
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                f"The input hidden states seems to be silently casted in float32, this might be related to"
-                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
-        # Reashape to the expected shape for Flash Attention
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        attn_output = self._flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            q_len,
-            dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
-        )
-
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-    def _flash_attention_forward(
-        self,
-        query_states,
-        key_states,
-        value_states,
-        attention_mask,
-        query_length,
-        dropout=0.0,
-        softmax_scale=None,
-        use_sliding_windows=False,
-    ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-            use_sliding_windows (`bool`, *optional*):
-                Whether to activate sliding window attention.
-        """
-        if not self._flash_attn_uses_top_left_mask:
-            causal = self.is_causal
-        else:
-            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
-            causal = self.is_causal and query_length != 1
-
-        # Ensure attention_mask has the correct shape and values
-        if attention_mask is not None:
-            if attention_mask.dim() == 4:
-                # Convert 4D attention mask to 2D
-                attention_mask = attention_mask.squeeze(1).squeeze(1)
-            elif attention_mask.dim() != 2:
-                raise ValueError(
-                    f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
-                )
-
-            # Ensure attention_mask has values of 0 and 1
-            attention_mask = attention_mask.to(torch.bool).to(torch.int32)
-
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            if not use_sliding_windows:
-                attn_output_unpad = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=max_seqlen_in_batch_q,
-                    max_seqlen_k=max_seqlen_in_batch_k,
-                    dropout_p=dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                )
-            else:
-                attn_output_unpad = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=max_seqlen_in_batch_q,
-                    max_seqlen_k=max_seqlen_in_batch_k,
-                    dropout_p=dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                    window_size=(self.config.sliding_window, self.config.sliding_window),
-                )
-
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-        else:
-            if not use_sliding_windows:
-                attn_output = flash_attn_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                )
-            else:
-                attn_output = flash_attn_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                    window_size=(self.config.sliding_window, self.config.sliding_window),
-                )
-
-        return attn_output
-
-    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
-
-        # On the first iteration we need to properly re-create the padding mask
-        # by slicing it on the proper place
-        if kv_seq_len != attention_mask.shape[-1]:
-            attention_mask_num_tokens = attention_mask.shape[-1]
-            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
-
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-
-        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
-            )
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q,
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
-
-
 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
 class QuietSdpaAttention(QuietAttention):
     """
@@ -1567,16 +1252,15 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
                 return x
             return x.repeat_interleave(n, dim=0)
 
-        if self.n_passes > 1
+        if self.n_passes > 1:
             input_ids = none_repeat_interleave(input_ids, self.n_passes)
-        attention_mask = none_repeat_interleave(attention_mask, self.n_passes)
-        position_ids = none_repeat_interleave(position_ids, self.n_passes)
-        inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes)
-        labels = none_repeat_interleave(labels, self.n_passes)
+            attention_mask = none_repeat_interleave(attention_mask, self.n_passes)
+            position_ids = none_repeat_interleave(position_ids, self.n_passes)
+            inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes)
+            labels = none_repeat_interleave(labels, self.n_passes)
             if past_key_values is not None:
                 past_key_values = [none_repeat_interleave(p, self.n_passes) for p in past_key_values]
-
-            cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device)
+        cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device)
 
         self.tokenizer_has_start_thought_token = True
         self.tokenizer_has_end_thought_token = True
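For context, a minimal usage sketch (not part of the commit): with `QuietFlashAttention2` and the `flash_attn` imports removed, the attention backends left in modeling_quiet.py are the eager `QuietAttention` and `QuietSdpaAttention`, so a caller would request one of those. The repo id below is a placeholder, and this assumes the custom class wires up the standard `attn_implementation` switch.

# Hypothetical sketch; "org/quiet-star-model" is a placeholder repo id, not taken from this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "org/quiet-star-model"  # placeholder: replace with the actual checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,       # modeling_quiet.py ships with the checkpoint as custom code
    attn_implementation="sdpa",   # flash_attention_2 is no longer available in this file; "eager" also works
)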