Crystalcareai committed (verified)
Commit b11137f · 1 Parent(s): 1ec7ec7

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py +8 -48
modeling_quiet.py CHANGED
@@ -54,61 +54,21 @@ _CONFIG_FOR_DOC = "QuietConfig"
 
 
 def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length):
-    # Compute the attention mask correctly
     bsz, tgt_len = input_shape
 
-    # Create a 4D attention mask from a 2D tensor mask.
-    # The shape of the output attention mask is (batch_size, 1, tgt_len, src_len)
-    # The values are either 0 or 1, where 0 means padding and 1 means non-padding.
-    combined_attention_mask = None
     if attention_mask is not None:
-        # What if attention_mask is not None and has a shape of (batch_size, 1, tgt_len, src_len)
-        # In this case, we can just use it directly.
-        if attention_mask.dim() == 4:
-            combined_attention_mask = attention_mask
-        # What if attention_mask is not None and has a shape of (batch_size, 1, tgt_len)
-        # In this case, we need to expand it to (batch_size, 1, tgt_len, src_len)
-        elif attention_mask.dim() == 3:
-            expanded_attn_mask = attention_mask[:, None, :, :]
-            combined_attention_mask = expanded_attn_mask
-        # What if attention_mask is not None and has a shape of (batch_size, tgt_len)
-        # In this case, we need to expand it to (batch_size, 1, tgt_len, src_len)
+        if attention_mask.dim() == 3:
+            # Expanding from [batch_size, 1, tgt_len] to [batch_size, 1, tgt_len, tgt_len]
+            attention_mask = attention_mask.expand(bsz, 1, tgt_len, tgt_len)
         elif attention_mask.dim() == 2:
-            # Provided a padding mask of dimensions [batch_size, seq_length]
-            # - if the model is a decoder, apply a causal mask in addition to the padding mask
-            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-            if past_key_values_length > 0:
-                attention_mask = attention_mask.to(dtype=torch.long)
-                attention_mask = attention_mask[:, past_key_values_length:]
-            expanded_attn_mask = attention_mask[:, None, None, :]
-            combined_attention_mask = expanded_attn_mask
+            # Expanding from [batch_size, tgt_len] to [batch_size, 1, tgt_len, tgt_len]
+            attention_mask = attention_mask.unsqueeze(1).expand(bsz, 1, tgt_len, tgt_len)
         else:
-            raise ValueError(
-                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
-                    input_shape, attention_mask.shape
-                )
-            )
+            raise ValueError(f"Unexpected attention mask shape: {attention_mask.shape}, expected 2 or 3 dimensions.")
 
-    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-    # masked positions, this operation will create a tensor which is 0.0 for
-    # positions we want to attend and -10000.0 for masked positions.
-    # Since we are adding it to the raw scores before the softmax, this is
-    # effectively the same as removing these entirely.
-    if combined_attention_mask is not None:
-        # Ensure the attention mask values are within a reasonable range
-        combined_attention_mask = combined_attention_mask.clamp(min=0, max=1)
-
-        # Convert the attention mask to bfloat16
-        combined_attention_mask = combined_attention_mask.to(torch.bfloat16)
-
-        # Normalize the attention mask values to be between 0 and 1
-        combined_attention_mask = (1.0 - combined_attention_mask) * -10000.0
-    else:
-        combined_attention_mask = torch.zeros(
-            (bsz, 1, tgt_len, tgt_len), dtype=torch.bfloat16, device=inputs_embeds.device
-        )
+    attention_mask = (1.0 - attention_mask) * -10000.0  # Masking operation for softmax
 
-    return combined_attention_mask
+    return attention_mask
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
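
For reference, below is a minimal standalone sketch of the reworked mask preparation, copied out of the hunk above so it can be exercised in isolation. The toy mask, the batch size of 1, and the print calls are illustrative assumptions, not part of the repository; torch.Tensor.expand only broadcasts size-1 dimensions, which is why the sketch keeps the batch dimension at 1.

import torch


def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length):
    bsz, tgt_len = input_shape

    if attention_mask is not None:
        if attention_mask.dim() == 3:
            # Expanding from [batch_size, 1, tgt_len] to [batch_size, 1, tgt_len, tgt_len]
            attention_mask = attention_mask.expand(bsz, 1, tgt_len, tgt_len)
        elif attention_mask.dim() == 2:
            # Expanding from [batch_size, tgt_len] to [batch_size, 1, tgt_len, tgt_len]
            attention_mask = attention_mask.unsqueeze(1).expand(bsz, 1, tgt_len, tgt_len)
        else:
            raise ValueError(f"Unexpected attention mask shape: {attention_mask.shape}, expected 2 or 3 dimensions.")

    attention_mask = (1.0 - attention_mask) * -10000.0  # Masking operation for softmax

    return attention_mask


# Toy input (not from the repository): one sequence of length 4 with the last position padded.
mask_2d = torch.tensor([[1, 1, 1, 0]])   # [batch_size, tgt_len]
inputs_embeds = torch.zeros(1, 4, 8)     # placeholder hidden states
additive_mask = _prepare_4d_causal_attention_mask_for_sdpa(mask_2d, (1, 4), inputs_embeds, past_key_values_length=0)
print(additive_mask.shape)       # torch.Size([1, 1, 4, 4])
print(additive_mask[0, 0, 0])    # 0.0 at attended positions, -10000.0 at the padded one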