Crystalcareai committed
Commit 42469fd · verified · 1 Parent(s): e281793

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +173 -174
modeling_quiet.py CHANGED
@@ -1315,180 +1315,179 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         return model


-    @torch.no_grad()
-    def infer(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ):
-        batch_size, seq_len = input_ids.shape
-
-        # Save the original input_ids and attention_mask for later use
-        original_input_ids = input_ids.clone()
-        original_attention_mask = attention_mask.clone() if attention_mask is not None else None
-
-        # Append the start thought token to the input sequence
-        start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
-        input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
-        seq_len += 1
-
-        # Update the attention mask
-        if attention_mask is not None:
-            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-        # Generate the continuation
-        continuation_length = self.n_ahead - 2
-        new_key_values = past_key_values
-        next_token_id_defined = False # Flag to check if next_token_id is defined
-
-        start_time = time.time()
-        for continuation_idx in range(continuation_length):
-            outputs = self.model(
-                input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_values=new_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-            new_key_values = outputs.past_key_values
-
-            hidden_states = outputs[0]
-
-            logits = self.lm_head(hidden_states)
-            logits = logits[:, -1, :] # Only consider the last token
-
-            # Apply Gumbel-Softmax to the logits
-            next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
-            next_token_id = torch.argmax(next_token_logits, dim=-1)
-
-            # Append the generated token to the input sequence
-            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
-            seq_len += 1
-
-            # Update the attention mask
-            if attention_mask is not None:
-                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-            next_token_id_defined = True # Set the flag to True after next_token_id is defined
-
-        # Check if next_token_id is defined before using it
-        if next_token_id_defined:
-            # Append the end thought token to the input sequence
-            end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
-            input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
-            seq_len += 1
-
-            # Update the attention mask
-            if attention_mask is not None:
-                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-        # Get the hidden states before and after the thought
-        outputs_before = self.model(
-            input_ids=original_input_ids,
-            attention_mask=original_attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states_before = outputs_before[0][:, -1:, :]
-
-        # Only execute if next_token_id is defined
-        if next_token_id_defined:
-            outputs_after = self.model(
-                input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_values=new_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-            hidden_states_after = outputs_after[0][:, -1:, :]
-
-            # Apply the talk head to get the mixing weight
-            mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
-
-            # Apply the mixing weight to the hidden states
-            mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
-
-            # Apply the language model head to get the final logits
-            logits = self.lm_head(mixed_hidden_states)
-
-            if not return_dict:
-                return logits
-
-            return BaseModelOutputWithPast(
-                logits=logits,
-                past_key_values=new_key_values,
-                hidden_states=outputs_after.hidden_states if output_hidden_states else None,
-                attentions=outputs_after.attentions if output_attentions else None,
-            )
-        else:
-            # Handle the case where next_token_id is not defined (e.g., continuation_length <= 0)
-            # This part of the code needs to be adapted based on how you want to handle this scenario.
-            # As a placeholder, returning the logits from the last state of the original input.
-            logits = self.lm_head(hidden_states_before)
-
-            if not return_dict:
-                return logits
-
-            return BaseModelOutputWithPast(
-                logits=logits,
-                past_key_values=past_key_values,
-                hidden_states=outputs_before.hidden_states if output_hidden_states else None,
-                attentions=outputs_before.attentions if output_attentions else None,
-            )
-
-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        **model_kwargs,
-    ) -> Union[BaseModelOutputWithPast, torch.LongTensor]:
-        return_dict = return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
-
-        output = self.infer(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        if return_dict:
-            return output
-        else:
-            return output.logits
-    @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
+    @torch.no_grad()
+    def infer(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        batch_size, seq_len = input_ids.shape
+
+        # Save the original input_ids and attention_mask for later use
+        original_input_ids = input_ids.clone()
+        original_attention_mask = attention_mask.clone() if attention_mask is not None else None
+
+        # Append the start thought token to the input sequence
+        start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+        seq_len += 1
+
+        # Update the attention mask
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Generate the continuation
+        continuation_length = self.n_ahead - 2
+        new_key_values = past_key_values
+        next_token_id_defined = False # Flag to check if next_token_id is defined
+
+        start_time = time.time()
+        for continuation_idx in range(continuation_length):
+            outputs = self.model(
+                input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=new_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            new_key_values = outputs.past_key_values
+
+            hidden_states = outputs[0]
+
+            logits = self.lm_head(hidden_states)
+            logits = logits[:, -1, :] # Only consider the last token
+
+            # Apply Gumbel-Softmax to the logits
+            next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
+            next_token_id = torch.argmax(next_token_logits, dim=-1)
+
+            # Append the generated token to the input sequence
+            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
+            seq_len += 1
+
+            # Update the attention mask
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+            next_token_id_defined = True # Set the flag to True after next_token_id is defined
+
+        # Check if next_token_id is defined before using it
+        if next_token_id_defined:
+            # Append the end thought token to the input sequence
+            end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
+            input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+            seq_len += 1
+
+            # Update the attention mask
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Get the hidden states before and after the thought
+        outputs_before = self.model(
+            input_ids=original_input_ids,
+            attention_mask=original_attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_before = outputs_before[0][:, -1:, :]
+
+        # Only execute if next_token_id is defined
+        if next_token_id_defined:
+            outputs_after = self.model(
+                input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=new_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            hidden_states_after = outputs_after[0][:, -1:, :]
+
+            # Apply the talk head to get the mixing weight
+            mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
+
+            # Apply the mixing weight to the hidden states
+            mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
+
+            # Apply the language model head to get the final logits
+            logits = self.lm_head(mixed_hidden_states)
+
+            if not return_dict:
+                return logits
+
+            return BaseModelOutputWithPast(
+                logits=logits,
+                past_key_values=new_key_values,
+                hidden_states=outputs_after.hidden_states if output_hidden_states else None,
+                attentions=outputs_after.attentions if output_attentions else None,
+            )
+        else:
+            # Handle the case where next_token_id is not defined (e.g., continuation_length <= 0)
+            # This part of the code needs to be adapted based on how you want to handle this scenario.
+            # As a placeholder, returning the logits from the last state of the original input.
+            logits = self.lm_head(hidden_states_before)
+
+            if not return_dict:
+                return logits
+
+            return BaseModelOutputWithPast(
+                logits=logits,
+                past_key_values=past_key_values,
+                hidden_states=outputs_before.hidden_states if output_hidden_states else None,
+                attentions=outputs_before.attentions if output_attentions else None,
+            )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict_in_generate: Optional[bool] = None,
+        **model_kwargs,
+    ) -> Union[BaseModelOutputWithPast, torch.LongTensor]:
+        return_dict = return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+
+        output = self.infer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if return_dict:
+            return output
+        else:
+            return output.logits
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
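
For readers skimming the hunk: infer() appends <|startthought|> to the prompt, samples a short thought continuation with hard Gumbel-Softmax, closes it with <|endthought|>, and then blends the last hidden state of the original prompt with the last hidden state after the thought using a learned mixing weight before applying the LM head. The sketch below isolates that sampling-and-mixing step with toy tensors; the sizes, the temperature value, and the sigmoid-bounded stand-in for self.talk_head[0] are illustrative assumptions, not values taken from this repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy sizes -- illustrative only, not the values used by modeling_quiet.py.
batch_size, hidden_size, vocab_size = 2, 16, 100

# One step of the thought continuation: hard Gumbel-Softmax over the logits,
# then argmax, mirroring the sampling inside the continuation loop
# (tau stands in for self.gumbel_temperature).
step_logits = torch.randn(batch_size, vocab_size)
one_hot = F.gumbel_softmax(step_logits, tau=1.0, hard=True, dim=-1)
next_token_id = torch.argmax(one_hot, dim=-1)  # shape: (batch_size,)

# Stand-ins for the model's heads. In the real class these are
# self.talk_head[0] and self.lm_head; here they are freshly initialized,
# and bounding the mixing weight with a sigmoid is an assumption.
talk_head = nn.Sequential(nn.Linear(2 * hidden_size, 1), nn.Sigmoid())
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Last-position hidden states before and after the inserted thought,
# playing the role of hidden_states_before / hidden_states_after in infer().
hidden_states_before = torch.randn(batch_size, 1, hidden_size)
hidden_states_after = torch.randn(batch_size, 1, hidden_size)

# Mixing weight from the concatenated states, then interpolation, then the
# LM head, matching the formula in the diff.
mixing_weight = talk_head(torch.cat([hidden_states_before, hidden_states_after], dim=-1))
mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
logits = lm_head(mixed_hidden_states)
print(logits.shape)  # torch.Size([2, 1, 100])

Only the hard Gumbel-Softmax sampling and the interpolation (1 - w) * before + w * after are taken from the diff; whether the real talk head bounds the weight to [0, 1] is not visible in this hunk.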