Crystalcareai committed on
Commit 560abff · verified · 1 Parent(s): d3a9a29

Update modeling_quiet.py

Files changed (1):
  modeling_quiet.py  +4 -3
modeling_quiet.py CHANGED

@@ -1261,6 +1261,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        max_length: Optional[int] = None,  # Add the max_length argument
     ):
         batch_size, seq_len = input_ids.shape

@@ -1278,11 +1279,12 @@ class QuietForCausalLM(QuietPreTrainedModel):
            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)

        # Generate the continuation
-       continuation_length = self.n_ahead - 2
+       if max_length is None:
+           max_length = self.n_ahead - 2  # Use the default value if max_length is not provided
        new_key_values = past_key_values
        start_time = time.time()

-       for continuation_idx in range(continuation_length):
+       for continuation_idx in range(max_length):
            outputs = self.model(
                input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
                attention_mask=attention_mask,

@@ -1358,7 +1360,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
        logits = self.lm_head(mixed_hidden_states)

        return logits
-
    @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
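
For context: before this commit the continuation loop always ran self.n_ahead - 2 iterations; the new max_length argument lets the caller control the loop length, with the old value kept as the default. Below is a minimal runnable sketch of that behaviour, not the repo's code; the class and method names are placeholders (the enclosing method name is not visible in the hunk headers), and only n_ahead, the fallback rule, and the loop bound are taken from the diff.

from typing import List, Optional


class ContinuationLoopSketch:
    """Placeholder stand-in for the patched generation path in QuietForCausalLM."""

    def __init__(self, n_ahead: int):
        self.n_ahead = n_ahead

    def run(self, max_length: Optional[int] = None) -> List[int]:
        if max_length is None:
            # Same default the loop used before this commit.
            max_length = self.n_ahead - 2
        steps = []
        for continuation_idx in range(max_length):
            # In the real model each iteration calls self.model(...) on the
            # prompt (step 0) or the last sampled token (later steps).
            steps.append(continuation_idx)
        return steps


sketch = ContinuationLoopSketch(n_ahead=8)
assert len(sketch.run()) == 6              # old behaviour: n_ahead - 2 steps
assert len(sketch.run(max_length=3)) == 3  # new: caller-controlled length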