Crystalcareai committed
Commit c0c99ee · verified · 1 parent: 16f10b2

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py (+44, -50)
modeling_quiet.py CHANGED
@@ -373,27 +373,13 @@ class QuietAttention(nn.Module):
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
 
-            # print("Hidden states contains NaN before applying attention mask:", torch.isnan(hidden_states).any().item())
-            # print("Attention mask contains NaN:", torch.isnan(attention_mask).any().item())
-
             attn_weights = attn_weights + attention_mask
 
-            # print("Attention weights contains NaN after applying attention mask:", torch.isnan(attn_weights).any().item())
-
         # upcast attention to fp32
-        # print("Attention weights contains NaN before softmax:", torch.isnan(attn_weights).any().item())
-
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-
-        # print("Attention weights contains NaN after softmax:", torch.isnan(attn_weights).any().item())
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        # print("Attention weights contains NaN before matmul:", torch.isnan(attn_weights).any().item())
-        # print("Value states contains NaN before matmul:", torch.isnan(value_states).any().item())
-
         attn_output = torch.matmul(attn_weights, value_states)
 
-        # print("Attention output contains NaN:", torch.isnan(attn_output).any().item())
-
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
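For context on the hunk above: the eager attention path sums an additive mask into the raw scores, runs the softmax in fp32 before casting back, applies dropout, and multiplies by the values. The sketch below is a self-contained, hypothetical reproduction of that pattern with toy shapes (not the actual method body); if the removed NaN checks are still wanted, they can be gated behind a debug flag rather than left as inline prints.

```python
import torch
import torch.nn as nn

# Toy dimensions (hypothetical, for illustration only).
bsz, num_heads, q_len, kv_len, head_dim = 2, 4, 5, 5, 8

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_len, head_dim)

# Additive mask: 0 where attention is allowed, dtype-min where it is not.
attention_mask = torch.zeros(bsz, 1, q_len, kv_len)
attention_mask[..., -1] = torch.finfo(attention_mask.dtype).min  # e.g. hide the last key

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / head_dim**0.5
attn_weights = attn_weights + attention_mask

# Upcast to fp32 for a numerically stabler softmax, then cast back.
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=0.0, training=False)
attn_output = torch.matmul(attn_weights, value_states)  # (bsz, num_heads, q_len, head_dim)
```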
@@ -1085,22 +1071,27 @@ class QuietModel(QuietPreTrainedModel):
                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                 )
 
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=inputs_embeds.device)
-
-        if attention_mask.dim() == 2:
-            attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
-            attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
-        elif attention_mask.dim() == 3:
-            attention_mask = attention_mask.unsqueeze(1)
-        elif attention_mask.dim() != 4:
-            raise ValueError(f"Attention mask should be of shape (batch_size, 1, seq_len, seq_len) or (batch_size, 1, 1, seq_len), but got {attention_mask.shape}")
-
-        attention_mask = attention_mask.to(dtype=torch.bool)
-
-
-        print("Attention mask shape after expansion:", attention_mask.shape)
-        print("Attention mask contains NaN:", torch.isnan(attention_mask).any().item())
+        if self._attn_implementation == "flash_attention_2":
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
+        elif attention_mask is None or attention_mask.dim() == 2:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )
 
         hidden_states = inputs_embeds
 
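The new branch above keeps a 2D mask for flash_attention_2 and otherwise builds a 4D additive mask via the `_prepare_4d_causal_attention_mask*` helpers from transformers (the SDPA path and the `sliding_window` argument belong to those helpers and are not reproduced here). As a rough illustration of the shape and semantics of such a mask, here is a hand-rolled stand-in, not the library implementation:

```python
import torch

def build_4d_causal_mask(padding_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Expand a (batch, seq_len) padding mask into a (batch, 1, seq_len, seq_len) additive
    causal mask: 0.0 where attending is allowed, dtype-min where it is not.
    Illustrative stand-in only, not the transformers helper."""
    bsz, seq_len = padding_mask.shape
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))         # query i sees keys <= i
    allowed = causal[None, None, :, :] & padding_mask[:, None, None, :].bool()  # also hide padded keys
    mask = torch.full((bsz, 1, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    return mask.masked_fill(allowed, 0.0)

mask = build_4d_causal_mask(torch.tensor([[1, 1, 1, 0]]))
print(mask.shape)  # torch.Size([1, 1, 4, 4]); added to attn_weights as in the attention hunk above
```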
@@ -1110,7 +1101,6 @@ class QuietModel(QuietPreTrainedModel):
         next_decoder_cache = None
 
         for decoder_layer in self.layers:
-            print(f"Hidden states contains NaN before layer {id}:", torch.isnan(hidden_states).any().item())
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -1168,7 +1158,6 @@ def nonzero_mean(x, axis=None):
 
 def loss_mean(x):
     return x.sum() / (x != 0).sum()
-    print(f"Hidden states contains NaN after layer {id}:", torch.isnan(hidden_states).any().item())
 
 class QuietForCausalLM(QuietPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
@@ -1353,8 +1342,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
             return_dict=return_dict,
         )
         new_key_values = outputs.past_key_values
-        print(f"Hidden states contains NaN: {torch.isnan(hidden_states).any().item()}")
-
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits[:, -1, :] # Only consider the last token
@@ -1896,15 +1883,29 @@ class QuietForCausalLM(QuietPreTrainedModel):
             inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
             inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
 
-            if attention_mask is not None:
-                if attention_mask.dim() == 2:
-                    attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
-                    attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
-                elif attention_mask.dim() != 4:
-                    raise ValueError(f"Attention mask should be of shape (batch_size, 1, seq_len, seq_len), but got {attention_mask.shape}")
-
-                attention_mask = attention_mask.to(dtype=torch.bool)
-
+            if len(attention_mask.shape) == 2:
+                breakpoint()
+            else:
+                original_attention = attention_mask[..., :attention_mask.shape[-2]]
+                if self.use_upper_triangular:
+                    new_attention = original_attention
+                else:
+                    original_attention = original_attention == attention_mask.max()
+                    # because eye isn't implemented for BF16, we need to handle the case
+                    if not attention_mask.dtype == torch.bfloat16:
+                        new_attention = torch.eye(
+                            seq_len, dtype=attention_mask.dtype, device=attention_mask.device
+                        )
+                    else:
+                        new_attention = torch.eye(
+                            seq_len, dtype=torch.float32, device=attention_mask.device
+                        ).to(attention_mask.dtype)
+
+                    new_attention = new_attention.view(1, 1, seq_len, seq_len).repeat(input_ids.shape[0], 1, 1, 1)
+                    new_attention = new_attention * original_attention
+                    new_attention[new_attention == 0] = attention_mask.min()
+                    new_attention[new_attention == 1] = attention_mask.max()
+                attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
             past_key_values = outputs.past_key_values
             position_ids = position_ids + 1
 
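The added block above widens the 4D additive mask along the key axis each time a thought token is appended: an identity matrix lets every query position see only its own new token, scaled into the mask's min/max values before being concatenated. A toy reproduction with illustrative shapes, assuming `use_upper_triangular` is False:

```python
import torch

batch, seq_len = 2, 4
dtype = torch.float32
min_val = torch.finfo(dtype).min

# Existing additive causal mask: (batch, 1, seq_len, seq_len); 0 = visible, dtype-min = hidden.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
attention_mask = torch.full((batch, 1, seq_len, seq_len), min_val).masked_fill(causal, 0.0)

# Which of the original key positions are visible (mirrors `original_attention == attention_mask.max()`).
original_attention = attention_mask[..., :attention_mask.shape[-2]] == attention_mask.max()

# Diagonal block: each query position may additionally attend only to its own new thought token.
new_attention = torch.eye(seq_len, dtype=dtype).view(1, 1, seq_len, seq_len).repeat(batch, 1, 1, 1)
new_attention = new_attention * original_attention
new_attention[new_attention == 0] = attention_mask.min()
new_attention[new_attention == 1] = attention_mask.max()

# The mask grows along the key axis on every pass: (batch, 1, seq_len, 2 * seq_len) here.
attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
print(attention_mask.shape)  # torch.Size([2, 1, 4, 8])
```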
@@ -1917,16 +1918,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 else:
                     loss_logits = logits
                 shift_idx = 1 + max(0, ahead_idx - (self.n_ahead - 1))
-                # print("initial_loss_logits contains NaN:", torch.isnan(initial_loss_logits).any().item())
-                # print("logits contains NaN:", torch.isnan(logits).any().item())
-                # print("loss_logits contains NaN:", torch.isnan(loss_logits).any().item())
-
                 shift_logits = loss_logits[..., :-shift_idx, :].contiguous()
-                # print("shift_logits contains NaN:", torch.isnan(shift_logits).any().item())
                 shift_labels = labels[..., shift_idx:].contiguous()
                 # Flatten the tokens
-                # assert not torch.isnan(shift_logits).any(), "NaN values found in shift_logits"
-                # assert not torch.isnan(shift_labels).any(), "NaN values found in shift_labels"
                 loss_fct = CrossEntropyLoss(reduction="none")
                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 shift_labels = shift_labels.view(-1)
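For the loss code above: each logit is scored against the label `shift_idx` positions ahead (with `shift_idx` equal to 1 this reduces to ordinary next-token prediction), and both tensors are flattened for a per-token CrossEntropyLoss. A minimal sketch with made-up sizes and hypothetical loop state:

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy values (illustrative only): vocab of 10, batch of 2, sequence of 6.
vocab_size, batch, seq_len = 10, 2, 6
loss_logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# With thought tokens ahead of the prediction, position t is scored against the token
# shift_idx steps later; ahead_idx and n_ahead here stand in for the loop state.
ahead_idx, n_ahead = 3, 4
shift_idx = 1 + max(0, ahead_idx - (n_ahead - 1))

shift_logits = loss_logits[..., :-shift_idx, :].contiguous()
shift_labels = labels[..., shift_idx:].contiguous()

loss_fct = CrossEntropyLoss(reduction="none")
per_token_loss = loss_fct(
    shift_logits.view(-1, vocab_size),
    shift_labels.view(-1),
)  # shape: (batch * (seq_len - shift_idx),)
print(per_token_loss.shape)
```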