Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py CHANGED (+44 -50)
@@ -373,27 +373,13 @@ class QuietAttention(nn.Module):
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
 
-            # print("Hidden states contains NaN before applying attention mask:", torch.isnan(hidden_states).any().item())
-            # print("Attention mask contains NaN:", torch.isnan(attention_mask).any().item())
-
             attn_weights = attn_weights + attention_mask
 
-        # print("Attention weights contains NaN after applying attention mask:", torch.isnan(attn_weights).any().item())
-
         # upcast attention to fp32
-        # print("Attention weights contains NaN before softmax:", torch.isnan(attn_weights).any().item())
-
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-
-        # print("Attention weights contains NaN after softmax:", torch.isnan(attn_weights).any().item())
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        # print("Attention weights contains NaN before matmul:", torch.isnan(attn_weights).any().item())
-        # print("Value states contains NaN before matmul:", torch.isnan(value_states).any().item())
-
         attn_output = torch.matmul(attn_weights, value_states)
 
-        # print("Attention output contains NaN:", torch.isnan(attn_output).any().item())
-
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
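The context lines kept above show the usual fp32-upcast softmax path. A minimal, self-contained sketch of that pattern (shapes and dtype are made up for illustration, not taken from the model config):

```python
import torch
import torch.nn as nn

# Hypothetical shapes/dtype for illustration only.
bsz, num_heads, q_len, kv_seq_len = 2, 4, 8, 8
attn_weights = torch.randn(bsz, num_heads, q_len, kv_seq_len, dtype=torch.bfloat16)
attention_mask = torch.zeros(bsz, 1, q_len, kv_seq_len, dtype=torch.bfloat16)

# Additive mask, softmax upcast to fp32 for numerical stability, then cast back,
# mirroring the retained lines in the hunk above.
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.bfloat16)
```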
@@ -1085,22 +1071,27 @@ class QuietModel(QuietPreTrainedModel):
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
             )
 
-        if …  (15 more removed lines not shown)
+        if self._attn_implementation == "flash_attention_2":
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
+        elif attention_mask is None or attention_mask.dim() == 2:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )
 
         hidden_states = inputs_embeds
 
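The added branch follows the standard Mistral-style mask handling: flash-attention keeps the 2D padding mask, while the default path expands it into a 4D additive causal mask (the SDPA branch is effectively disabled by the trailing `and False`). A minimal hand-rolled sketch of that expansion, with made-up sizes and without the transformers helpers:

```python
import torch

# Hypothetical sizes for illustration only.
batch_size, seq_length, past_key_values_length = 2, 5, 0
dtype = torch.float32
min_value = torch.finfo(dtype).min

# 2D padding mask: 1 = real token, 0 = padding.
attention_mask_2d = torch.ones(batch_size, seq_length, dtype=torch.long)

# Causal structure: query i may attend to keys 0..i (plus any cached past keys).
causal = torch.tril(
    torch.ones(seq_length, seq_length + past_key_values_length, dtype=torch.bool),
    diagonal=past_key_values_length,
)

# Combine with padding and convert to additive form:
# 0 where attention is allowed, a large negative value where it is blocked.
allowed = causal.unsqueeze(0) & attention_mask_2d.bool().unsqueeze(1)
mask_4d = torch.where(allowed, 0.0, min_value).unsqueeze(1)  # (bsz, 1, q_len, kv_len)
```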
@@ -1110,7 +1101,6 @@ class QuietModel(QuietPreTrainedModel):
         next_decoder_cache = None
 
         for decoder_layer in self.layers:
-            print(f"Hidden states contains NaN before layer {id}:", torch.isnan(hidden_states).any().item())
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
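The inline NaN print removed here (and the related ones elsewhere in this commit) can be reproduced without editing the modeling code, for example via forward hooks. A hypothetical debugging helper, not part of this commit:

```python
import torch

def attach_nan_checks(model):
    """Flag NaNs in per-module outputs via forward hooks (debugging aid only)."""
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            out = output[0] if isinstance(output, tuple) else output
            if torch.is_tensor(out) and torch.isnan(out).any():
                print(f"NaN detected in output of {name}")
        handles.append(module.register_forward_hook(hook))
    return handles  # call handle.remove() on each when finished
```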
@@ -1168,7 +1158,6 @@ def nonzero_mean(x, axis=None):
 
 def loss_mean(x):
     return x.sum() / (x != 0).sum()
-    print(f"Hidden states contains NaN after layer {id}:", torch.isnan(hidden_states).any().item())
 
 class QuietForCausalLM(QuietPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
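For reference, `loss_mean` (like `nonzero_mean` named in the hunk header) averages only over nonzero entries, so zeroed-out (e.g. masked) losses do not dilute the mean:

```python
import torch

x = torch.tensor([0.0, 2.0, 0.0, 4.0])
print(x.mean())                  # tensor(1.5000) -- plain mean counts the zeros
print(x.sum() / (x != 0).sum())  # tensor(3.)     -- loss_mean ignores them
```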
@@ -1353,8 +1342,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 return_dict=return_dict,
             )
             new_key_values = outputs.past_key_values
-            print(f"Hidden states contains NaN: {torch.isnan(hidden_states).any().item()}")
-
             hidden_states = outputs[0]
             logits = self.lm_head(hidden_states)
             logits = logits[:, -1, :]  # Only consider the last token
@@ -1896,15 +1883,29 @@ class QuietForCausalLM(QuietPreTrainedModel):
         inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
         inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
 
-        if attention_mask …  (8 more removed lines not shown)
+        if len(attention_mask.shape) == 2:
+            breakpoint()
+        else:
+            original_attention = attention_mask[..., :attention_mask.shape[-2]]
+            if self.use_upper_triangular:
+                new_attention = original_attention
+            else:
+                original_attention = original_attention == attention_mask.max()
+                # because eye isn't implemented for BF16, we need to handle the case
+                if not attention_mask.dtype == torch.bfloat16:
+                    new_attention = torch.eye(
+                        seq_len, dtype=attention_mask.dtype, device=attention_mask.device
+                    )
+                else:
+                    new_attention = torch.eye(
+                        seq_len, dtype=torch.float32, device=attention_mask.device
+                    ).to(attention_mask.dtype)
+
+            new_attention = new_attention.view(1, 1, seq_len, seq_len).repeat(input_ids.shape[0], 1, 1, 1)
+            new_attention = new_attention * original_attention
+            new_attention[new_attention == 0] = attention_mask.min()
+            new_attention[new_attention == 1] = attention_mask.max()
+            attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
         past_key_values = outputs.past_key_values
         position_ids = position_ids + 1
 
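Reading the added block: the existing additive mask is extended along the key axis with a batched identity block (intersected with the positions the original mask already allows), so that query position i may additionally attend to the i-th newly appended token; `attention_mask.max()` marks allowed positions and `attention_mask.min()` blocked ones. A toy version with stand-in mask values and made-up sizes:

```python
import torch

# Stand-ins for attention_mask.max() / attention_mask.min() in the real mask.
attend, blocked = 0.0, -1e9
bsz, seq_len = 1, 4

# Toy base mask: causal, additive form (0 = attend, -1e9 = blocked).
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
base_mask = torch.where(causal, attend, blocked).expand(bsz, 1, seq_len, seq_len)

# Identity block along the key axis: query i also sees appended token i.
eye_block = torch.where(torch.eye(seq_len, dtype=torch.bool), attend, blocked)
eye_block = eye_block.expand(bsz, 1, seq_len, seq_len)

# The commit concatenates this block onto the last (key) dimension.
extended = torch.cat([base_mask, eye_block], dim=-1)  # (bsz, 1, seq_len, 2 * seq_len)
```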
@@ -1917,16 +1918,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 else:
                     loss_logits = logits
                 shift_idx = 1 + max(0, ahead_idx - (self.n_ahead - 1))
-                # print("initial_loss_logits contains NaN:", torch.isnan(initial_loss_logits).any().item())
-                # print("logits contains NaN:", torch.isnan(logits).any().item())
-                # print("loss_logits contains NaN:", torch.isnan(loss_logits).any().item())
-
                 shift_logits = loss_logits[..., :-shift_idx, :].contiguous()
-                # print("shift_logits contains NaN:", torch.isnan(shift_logits).any().item())
                 shift_labels = labels[..., shift_idx:].contiguous()
                 # Flatten the tokens
-                # assert not torch.isnan(shift_logits).any(), "NaN values found in shift_logits"
-                # assert not torch.isnan(shift_labels).any(), "NaN values found in shift_labels"
                 loss_fct = CrossEntropyLoss(reduction="none")
                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 shift_labels = shift_labels.view(-1)
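The retained lines compute a cross-entropy against labels shifted `shift_idx` positions ahead. A self-contained toy version with hypothetical sizes:

```python
import torch
from torch.nn import CrossEntropyLoss

# Hypothetical sizes for illustration.
bsz, seq_len, vocab_size, shift_idx = 2, 6, 11, 2
loss_logits = torch.randn(bsz, seq_len, vocab_size)
labels = torch.randint(vocab_size, (bsz, seq_len))

shift_logits = loss_logits[..., :-shift_idx, :].contiguous()  # predictions at 0 .. L-1-shift
shift_labels = labels[..., shift_idx:].contiguous()           # targets at shift .. L-1

loss_fct = CrossEntropyLoss(reduction="none")
per_token_loss = loss_fct(
    shift_logits.view(-1, vocab_size), shift_labels.view(-1)
).view(bsz, -1)
```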