Crystalcareai committed: Update modeling_quiet.py

Files changed: modeling_quiet.py (+4 -3)
@@ -1261,6 +1261,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        max_length: Optional[int] = None,  # Add the max_length argument
     ):
         batch_size, seq_len = input_ids.shape
 
@@ -1278,11 +1279,12 @@ class QuietForCausalLM(QuietPreTrainedModel):
         attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
 
         # Generate the continuation
-
+        if max_length is None:
+            max_length = self.n_ahead - 2  # Use the default value if max_length is not provided
         new_key_values = past_key_values
         start_time = time.time()
 
-        for continuation_idx in range(
+        for continuation_idx in range(max_length):
             outputs = self.model(
                 input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
                 attention_mask=attention_mask,
@@ -1358,7 +1360,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
         logits = self.lm_head(mixed_hidden_states)
 
         return logits
-
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
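In short, this commit makes the length of the continuation loop configurable. The patched method (its name is not visible in these hunks) gains an optional max_length argument; when the caller omits it, the loop falls back to self.n_ahead - 2 iterations, otherwise it runs exactly max_length decoding steps. Below is a minimal sketch of the new control flow, with the model call reduced to a stand-in; the function and variable names are illustrative and not taken from modeling_quiet.py.

from typing import Optional

def continuation_steps(n_ahead: int, max_length: Optional[int] = None) -> int:
    # Mirror the diff's fallback: default to n_ahead - 2 when no max_length is given.
    if max_length is None:
        max_length = n_ahead - 2
    steps = 0
    for continuation_idx in range(max_length):
        # The real loop calls self.model(...) here and feeds next_token_id back in.
        steps += 1
    return steps

assert continuation_steps(n_ahead=10) == 8                 # default behaviour: n_ahead - 2
assert continuation_steps(n_ahead=10, max_length=4) == 4   # caller-supplied cap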