Update modeling_quiet.py

modeling_quiet.py (+21 -1)
@@ -1281,7 +1281,27 @@ class QuietForCausalLM(QuietPreTrainedModel):
         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
-
+
+        if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
+            if attention_mask is None:
+                base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
+                base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
+                base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
+                attention_mask = base_attention_mask
+            elif attention_mask.dim() == 2:
+                if seq_len + past_key_values_length != attention_mask.shape[-1]:
+                    attention_mask = torch.cat(
+                        [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
+                        dim=-1
+                    )
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, seq_len),
+                    inputs_embeds,
+                    past_key_values_length,
+                    sliding_window=self.config.sliding_window,
+                )
+
         start_time = time.time()
         for continuation_idx in range(continuation_length):
             outputs = self.model(
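For readers who want to see what the new branch actually produces, here is a minimal, self-contained sketch of its two paths. The toy sizes, the print statements, and the import location of `_prepare_4d_causal_attention_mask` (from `transformers.modeling_attn_mask_utils`, available in transformers >= 4.36) are assumptions for illustration, not part of the commit:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_len, past_key_values_length, hidden = 2, 4, 3, 8
inputs_embeds = torch.zeros(batch_size, seq_len, hidden)

# Path 1: no mask supplied. The commit builds a dense 4D mask over the
# current window (torch.triu with diagonal=0, copied verbatim here).
base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0)
base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
base_attention_mask = base_attention_mask.repeat(batch_size, 1, 1, 1)
print(base_attention_mask.shape)  # torch.Size([2, 1, 4, 4])

# Path 2: a 2D mask covering only the new tokens. It is left-padded with
# ones for the cached positions (cached tokens stay visible), then expanded
# to the 4D causal form expected by the attention layers.
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
if seq_len + past_key_values_length != attention_mask.shape[-1]:
    attention_mask = torch.cat(
        [torch.ones((batch_size, past_key_values_length), dtype=attention_mask.dtype), attention_mask],
        dim=-1,
    )
mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_len), inputs_embeds, past_key_values_length
)
print(mask_4d.shape)  # torch.Size([2, 1, 4, 7])

The left-padding step is what matters during continuation generation: each step appends tokens to the KV cache, so a caller-supplied 2D mask that only covers the freshly generated positions has to be extended with ones for every cached position before the 4D conversion.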