Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +173 -174
modeling_quiet.py
CHANGED
@@ -1315,180 +1315,179 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1315 |
return model
|
1316 |
|
1317 |
|
1318 |
-
|
1319 |
-
|
1320 |
-
|
1321 |
-
|
1322 |
-
|
1323 |
-
|
1324 |
-
|
1325 |
-
|
1326 |
-
|
1327 |
-
|
1328 |
-
|
1329 |
-
|
1330 |
-
|
1331 |
-
|
1332 |
-
|
1333 |
-
|
1334 |
-
|
1335 |
-
|
1336 |
-
|
1337 |
-
|
1338 |
-
|
1339 |
-
|
1340 |
-
|
1341 |
-
|
1342 |
-
|
1343 |
-
|
1344 |
-
|
1345 |
-
|
1346 |
-
|
1347 |
-
|
1348 |
-
|
1349 |
-
|
1350 |
-
|
1351 |
-
|
1352 |
-
|
1353 |
-
|
1354 |
-
|
1355 |
-
|
1356 |
-
|
1357 |
-
|
1358 |
-
|
1359 |
-
|
1360 |
-
|
1361 |
-
|
1362 |
-
|
1363 |
-
|
1364 |
-
|
1365 |
-
|
1366 |
-
|
1367 |
-
|
1368 |
-
|
1369 |
-
|
1370 |
-
|
1371 |
-
|
1372 |
-
|
1373 |
-
|
1374 |
-
|
1375 |
-
|
1376 |
-
|
1377 |
-
|
1378 |
-
|
1379 |
-
|
1380 |
-
|
1381 |
-
|
1382 |
-
|
1383 |
-
|
1384 |
-
|
1385 |
-
|
1386 |
-
|
1387 |
-
|
1388 |
-
|
1389 |
-
|
1390 |
-
|
1391 |
-
|
1392 |
-
|
1393 |
-
|
1394 |
-
|
1395 |
-
|
1396 |
-
|
1397 |
-
|
1398 |
-
|
1399 |
-
|
1400 |
-
|
1401 |
-
|
1402 |
-
|
1403 |
-
|
1404 |
-
|
1405 |
-
|
1406 |
-
|
1407 |
-
|
1408 |
-
|
1409 |
-
|
1410 |
-
|
1411 |
-
|
1412 |
-
|
1413 |
-
|
1414 |
-
|
1415 |
-
|
1416 |
-
|
1417 |
-
|
1418 |
-
|
1419 |
-
|
1420 |
-
|
1421 |
-
|
1422 |
-
|
1423 |
-
|
1424 |
-
|
1425 |
-
|
1426 |
-
|
1427 |
-
|
1428 |
-
|
1429 |
-
|
1430 |
-
|
1431 |
-
|
1432 |
-
|
1433 |
-
|
1434 |
-
|
1435 |
-
|
1436 |
-
|
1437 |
-
|
1438 |
-
|
1439 |
-
|
1440 |
-
|
1441 |
-
|
1442 |
-
|
1443 |
-
|
1444 |
-
|
1445 |
-
|
1446 |
-
|
1447 |
-
|
1448 |
-
|
1449 |
-
|
1450 |
-
|
1451 |
-
|
1452 |
-
|
1453 |
-
|
1454 |
-
|
1455 |
-
|
1456 |
-
|
1457 |
-
|
1458 |
-
|
1459 |
-
|
1460 |
-
|
1461 |
-
|
1462 |
-
|
1463 |
-
|
1464 |
-
|
1465 |
-
|
1466 |
-
|
1467 |
-
|
1468 |
-
|
1469 |
-
|
1470 |
-
|
1471 |
-
|
1472 |
-
|
1473 |
-
|
1474 |
-
|
1475 |
-
|
1476 |
-
|
1477 |
-
|
1478 |
-
|
1479 |
-
|
1480 |
-
|
1481 |
-
|
1482 |
-
|
1483 |
-
|
1484 |
-
|
1485 |
-
|
1486 |
-
|
1487 |
-
|
1488 |
-
|
1489 |
-
|
1490 |
-
|
1491 |
-
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1492 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1493 |
def forward(
|
1494 |
self,
|
|
|
1315 |
return model
|
1316 |
|
1317 |
|
1318 |
+
@torch.no_grad()
|
1319 |
+
def infer(
|
1320 |
+
self,
|
1321 |
+
input_ids: torch.LongTensor,
|
1322 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1323 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1324 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1325 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1326 |
+
use_cache: Optional[bool] = None,
|
1327 |
+
output_attentions: Optional[bool] = None,
|
1328 |
+
output_hidden_states: Optional[bool] = None,
|
1329 |
+
return_dict: Optional[bool] = None,
|
1330 |
+
):
|
1331 |
+
batch_size, seq_len = input_ids.shape
|
1332 |
+
|
1333 |
+
# Save the original input_ids and attention_mask for later use
|
1334 |
+
original_input_ids = input_ids.clone()
|
1335 |
+
original_attention_mask = attention_mask.clone() if attention_mask is not None else None
|
1336 |
+
|
1337 |
+
# Append the start thought token to the input sequence
|
1338 |
+
start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
|
1339 |
+
input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
|
1340 |
+
seq_len += 1
|
1341 |
+
|
1342 |
+
# Update the attention mask
|
1343 |
+
if attention_mask is not None:
|
1344 |
+
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
1345 |
+
|
1346 |
+
# Generate the continuation
|
1347 |
+
continuation_length = self.n_ahead - 2
|
1348 |
+
new_key_values = past_key_values
|
1349 |
+
next_token_id_defined = False # Flag to check if next_token_id is defined
|
1350 |
+
|
1351 |
+
start_time = time.time()
|
1352 |
+
for continuation_idx in range(continuation_length):
|
1353 |
+
outputs = self.model(
|
1354 |
+
input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
|
1355 |
+
attention_mask=attention_mask,
|
1356 |
+
position_ids=position_ids,
|
1357 |
+
past_key_values=new_key_values,
|
1358 |
+
inputs_embeds=inputs_embeds,
|
1359 |
+
use_cache=True,
|
1360 |
+
output_attentions=output_attentions,
|
1361 |
+
output_hidden_states=output_hidden_states,
|
1362 |
+
return_dict=return_dict,
|
1363 |
+
)
|
1364 |
+
new_key_values = outputs.past_key_values
|
1365 |
+
|
1366 |
+
hidden_states = outputs[0]
|
1367 |
+
|
1368 |
+
logits = self.lm_head(hidden_states)
|
1369 |
+
logits = logits[:, -1, :] # Only consider the last token
|
1370 |
+
|
1371 |
+
# Apply Gumbel-Softmax to the logits
|
1372 |
+
next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
|
1373 |
+
next_token_id = torch.argmax(next_token_logits, dim=-1)
|
1374 |
+
|
1375 |
+
# Append the generated token to the input sequence
|
1376 |
+
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
|
1377 |
+
seq_len += 1
|
1378 |
+
|
1379 |
+
# Update the attention mask
|
1380 |
+
if attention_mask is not None:
|
1381 |
+
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
1382 |
+
|
1383 |
+
next_token_id_defined = True # Set the flag to True after next_token_id is defined
|
1384 |
+
|
1385 |
+
# Check if next_token_id is defined before using it
|
1386 |
+
if next_token_id_defined:
|
1387 |
+
# Append the end thought token to the input sequence
|
1388 |
+
end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
|
1389 |
+
input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
|
1390 |
+
seq_len += 1
|
1391 |
+
|
1392 |
+
# Update the attention mask
|
1393 |
+
if attention_mask is not None:
|
1394 |
+
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
1395 |
+
|
1396 |
+
# Get the hidden states before and after the thought
|
1397 |
+
outputs_before = self.model(
|
1398 |
+
input_ids=original_input_ids,
|
1399 |
+
attention_mask=original_attention_mask,
|
1400 |
+
position_ids=position_ids,
|
1401 |
+
past_key_values=past_key_values,
|
1402 |
+
inputs_embeds=inputs_embeds,
|
1403 |
+
use_cache=use_cache,
|
1404 |
+
output_attentions=output_attentions,
|
1405 |
+
output_hidden_states=output_hidden_states,
|
1406 |
+
return_dict=return_dict,
|
1407 |
+
)
|
1408 |
+
hidden_states_before = outputs_before[0][:, -1:, :]
|
1409 |
+
|
1410 |
+
# Only execute if next_token_id is defined
|
1411 |
+
if next_token_id_defined:
|
1412 |
+
outputs_after = self.model(
|
1413 |
+
input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
|
1414 |
+
attention_mask=attention_mask,
|
1415 |
+
position_ids=position_ids,
|
1416 |
+
past_key_values=new_key_values,
|
1417 |
+
inputs_embeds=inputs_embeds,
|
1418 |
+
use_cache=use_cache,
|
1419 |
+
output_attentions=output_attentions,
|
1420 |
+
output_hidden_states=output_hidden_states,
|
1421 |
+
return_dict=return_dict,
|
1422 |
+
)
|
1423 |
+
hidden_states_after = outputs_after[0][:, -1:, :]
|
1424 |
+
|
1425 |
+
# Apply the talk head to get the mixing weight
|
1426 |
+
mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
|
1427 |
+
|
1428 |
+
# Apply the mixing weight to the hidden states
|
1429 |
+
mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
|
1430 |
+
|
1431 |
+
# Apply the language model head to get the final logits
|
1432 |
+
logits = self.lm_head(mixed_hidden_states)
|
1433 |
+
|
1434 |
+
if not return_dict:
|
1435 |
+
return logits
|
1436 |
+
|
1437 |
+
return BaseModelOutputWithPast(
|
1438 |
+
logits=logits,
|
1439 |
+
past_key_values=new_key_values,
|
1440 |
+
hidden_states=outputs_after.hidden_states if output_hidden_states else None,
|
1441 |
+
attentions=outputs_after.attentions if output_attentions else None,
|
1442 |
+
)
|
1443 |
+
else:
|
1444 |
+
# Handle the case where next_token_id is not defined (e.g., continuation_length <= 0)
|
1445 |
+
# This part of the code needs to be adapted based on how you want to handle this scenario.
|
1446 |
+
# As a placeholder, returning the logits from the last state of the original input.
|
1447 |
+
logits = self.lm_head(hidden_states_before)
|
1448 |
+
|
1449 |
+
if not return_dict:
|
1450 |
+
return logits
|
1451 |
+
|
1452 |
+
return BaseModelOutputWithPast(
|
1453 |
+
logits=logits,
|
1454 |
+
past_key_values=past_key_values,
|
1455 |
+
hidden_states=outputs_before.hidden_states if output_hidden_states else None,
|
1456 |
+
attentions=outputs_before.attentions if output_attentions else None,
|
1457 |
+
)
|
1458 |
+
|
1459 |
+
@torch.no_grad()
|
1460 |
+
def generate(
|
1461 |
+
self,
|
1462 |
+
input_ids: torch.LongTensor,
|
1463 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1464 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1465 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1466 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1467 |
+
use_cache: Optional[bool] = None,
|
1468 |
+
output_attentions: Optional[bool] = None,
|
1469 |
+
output_hidden_states: Optional[bool] = None,
|
1470 |
+
return_dict_in_generate: Optional[bool] = None,
|
1471 |
+
**model_kwargs,
|
1472 |
+
) -> Union[BaseModelOutputWithPast, torch.LongTensor]:
|
1473 |
+
return_dict = return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
1474 |
+
|
1475 |
+
output = self.infer(
|
1476 |
+
input_ids=input_ids,
|
1477 |
+
attention_mask=attention_mask,
|
1478 |
+
position_ids=position_ids,
|
1479 |
+
past_key_values=past_key_values,
|
1480 |
+
inputs_embeds=inputs_embeds,
|
1481 |
+
use_cache=use_cache,
|
1482 |
+
output_attentions=output_attentions,
|
1483 |
+
output_hidden_states=output_hidden_states,
|
1484 |
+
return_dict=return_dict,
|
1485 |
+
)
|
1486 |
+
|
1487 |
+
if return_dict:
|
1488 |
+
return output
|
1489 |
+
else:
|
1490 |
+
return output.logits
|
|
|
1491 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1492 |
def forward(
|
1493 |
self,
|