Update modeling_quiet.py
modeling_quiet.py (+2 -2)
CHANGED
@@ -1674,10 +1674,10 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
         base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
         attention_mask = base_attention_mask
-        breakpoint()
+        # breakpoint()
     elif attention_mask.dim() == 2:
         if seq_len + past_key_values_length != attention_mask.shape[-1]:
-            breakpoint()
+            # breakpoint()
             attention_mask = torch.cat(
                 [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
                 dim=-1
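For context, the 2D branch of this hunk left-pads the attention mask with ones so it also covers positions already held in the KV cache (the commit only comments out the debugging breakpoint() calls; the padding logic is unchanged). Below is a minimal standalone sketch of that padding step; the helper name pad_attention_mask and the example shapes are illustrative, not part of modeling_quiet.py.

    import torch

    def pad_attention_mask(attention_mask: torch.Tensor, seq_len: int,
                           past_key_values_length: int) -> torch.Tensor:
        # If the 2D mask only covers the new tokens, prepend ones for the
        # positions already stored in the KV cache so the total width is
        # seq_len + past_key_values_length.
        if seq_len + past_key_values_length != attention_mask.shape[-1]:
            attention_mask = torch.cat(
                [torch.ones((attention_mask.shape[0], past_key_values_length),
                            dtype=attention_mask.dtype,
                            device=attention_mask.device),
                 attention_mask],
                dim=-1,
            )
        return attention_mask

    # Example: batch of 2, 3 new tokens, 5 cached positions -> mask of width 8.
    mask = torch.ones(2, 3, dtype=torch.long)
    padded = pad_attention_mask(mask, seq_len=3, past_key_values_length=5)
    assert padded.shape == (2, 8)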