Crystalcareai committed (verified) · Commit 38e552e · Parent(s): b66816f

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py  +21 -1
modeling_quiet.py CHANGED
@@ -1281,7 +1281,27 @@ class QuietForCausalLM(QuietPreTrainedModel):
         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
-
+
+        if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
+            if attention_mask is None:
+                base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
+                base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
+                base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
+                attention_mask = base_attention_mask
+            elif attention_mask.dim() == 2:
+                if seq_len + past_key_values_length != attention_mask.shape[-1]:
+                    attention_mask = torch.cat(
+                        [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
+                        dim=-1
+                    )
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, seq_len),
+                    inputs_embeds,
+                    past_key_values_length,
+                    sliding_window=self.config.sliding_window,
+                )
+
         start_time = time.time()
         for continuation_idx in range(continuation_length):
             outputs = self.model(
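
For context, the added block normalizes attention_mask before the continuation loop runs: when Quiet-STaR lookahead is active (n_ahead or n_ahead_talk != 1, or comparison mode), a missing mask is replaced with a broadcastable square base mask, while a 2D padding mask is first widened to cover any cached key/value positions and then expanded into the 4D additive form the decoder layers consume. Below is a minimal, self-contained sketch of that 2D-to-4D path. The toy shapes (batch_size=2, seq_len=5, past_key_values_length=3) are illustrative assumptions, not values from the commit, and _prepare_4d_causal_attention_mask is a private helper in transformers.modeling_attn_mask_utils, so its exact location can shift between transformers versions.

    import torch
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

    # Toy shapes (assumptions for illustration, not from the commit).
    batch_size, seq_len, past_key_values_length, hidden_size = 2, 5, 3, 8
    inputs_embeds = torch.zeros(batch_size, seq_len, hidden_size)

    # A 2D padding mask covering only the current tokens (1 = attend).
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    # As in the diff: if the mask does not yet cover the cached positions,
    # left-pad it with ones so its width is seq_len + past_key_values_length.
    if seq_len + past_key_values_length != attention_mask.shape[-1]:
        attention_mask = torch.cat(
            [torch.ones((attention_mask.shape[0], past_key_values_length),
                        dtype=attention_mask.dtype, device=attention_mask.device),
             attention_mask],
            dim=-1,
        )

    # Expand to the 4D additive mask the decoder layers expect:
    # (batch, 1, seq_len, seq_len + past_key_values_length), with 0.0 where
    # attention is allowed and a large negative value where it is blocked.
    mask_4d = _prepare_4d_causal_attention_mask(
        attention_mask,
        (batch_size, seq_len),
        inputs_embeds,
        past_key_values_length,
        sliding_window=None,  # the model passes self.config.sliding_window here
    )
    print(mask_4d.shape)  # torch.Size([2, 1, 5, 8])

The 4D form lets a single tensor carry both causality and padding, which is why the loop that follows can hand attention_mask straight to self.model(...) regardless of which branch produced it.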