Crystalcareai committed: Update modeling_quiet.py

modeling_quiet.py CHANGED (+13 -31)
@@ -136,22 +136,17 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
     c.save()
 
 
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-
-
-
-
-
-    max_seqlen_in_batch
-
-    # Ensure seqlens_in_batch has the correct shape before cumulative sum
-    seqlens_in_batch = seqlens_in_batch.view(-1)
-    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device=attention_mask.device), seqlens_in_batch.cumsum(dim=0)])
-
-    return indices, cu_seqlens, max_seqlen_in_batch
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Quiet
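For reference, a minimal sketch of what the rewritten _get_unpad_data returns for a toy right-padded attention mask; the toy mask and the printed values are illustrative assumptions, not taken from the repository:

import torch
import torch.nn.functional as F

def _get_unpad_data(attention_mask):
    # Mirrors the committed function: per-row token counts, flat indices of real
    # tokens, longest sequence length, and cumulative sequence lengths.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

# Two sequences padded to length 4: the first has 3 real tokens, the second 2.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
indices, cu_seqlens, max_len = _get_unpad_data(mask)
print(indices)     # tensor([0, 1, 2, 4, 5])  -> flat positions of the real tokens
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)  -> cumulative lengths
print(max_len)     # 3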
@@ -545,7 +540,7 @@ class QuietFlashAttention2(QuietAttention):
             value_states = value_states.to(target_dtype)
 
         # Compute the causal mask
-        causal = self.config.
+        causal = self.config.causal
         if causal:
             if self._flash_attn_uses_top_left_mask:
                 # Compute the causal mask
@@ -583,7 +578,7 @@ class QuietFlashAttention2(QuietAttention):
                 indices_q,
                 cu_seq_lens,
                 max_seq_lens,
-            ) = self.
+            ) = self._upad_input(query_states, key_states, value_states, attention_mask, q_len)
 
             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
@@ -656,7 +651,7 @@ class QuietFlashAttention2(QuietAttention):
 
         return attn_output
 
-    def
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
 
        # On the first iteration we need to properly re-create the padding mask
@@ -665,23 +660,10 @@ class QuietFlashAttention2(QuietAttention):
            attention_mask_num_tokens = attention_mask.shape[-1]
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
 
-        # Check if attention_mask is empty or all zeros
-        if attention_mask.numel() == 0 or attention_mask.sum() == 0:
-            # Return the original query_layer, key_layer, and value_layer without modifications
-            return (
-                query_layer,
-                key_layer,
-                value_layer,
-                torch.arange(batch_size, device=query_layer.device),
-                (torch.arange(batch_size + 1, device=query_layer.device), torch.arange(batch_size + 1, device=key_layer.device)),
-                (query_layer.shape[1], key_layer.shape[1]),
-            )
-
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
 
        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer
-
+        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
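The restored value_layer line performs the same gather as the key_layer line above it. Below is a minimal sketch of that unpadding step, using torch.index_select as a stand-in for flash_attn's index_first_axis; the stand-in and the toy shapes are assumptions, not code from the repository:

import torch

batch_size, kv_seq_len, num_heads, head_dim = 2, 4, 1, 8
value_layer = torch.randn(batch_size, kv_seq_len, num_heads, head_dim)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# indices_k as _get_unpad_data would produce them: flat positions of real tokens.
indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

# Flatten (batch, seq) into one axis, then keep only the rows of real tokens.
flat = value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim)
unpadded_value = torch.index_select(flat, 0, indices_k)  # stand-in for index_first_axis
print(unpadded_value.shape)  # torch.Size([5, 1, 8]) -> only the 5 real tokens remain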
@@ -786,7 +768,7 @@ class QuietSdpaAttention(QuietAttention):
             attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
             dropout_p=self.attention_dropout if self.training else 0.0,
             # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-
+            causal=self.is_causal and attention_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
|