Crystalcareai committed on
Commit 1eb16fa · verified · 1 Parent(s): 9b4bbbb

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +110 -112
modeling_quiet.py CHANGED
@@ -1311,119 +1311,117 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         elif isinstance(module, nn.Embedding):
             nn.init.xavier_uniform_(module.weight)
 
-    @torch.no_grad()
-    def infer(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ):
-        batch_size, seq_len = input_ids.shape
-
-        # Save the original input_ids and attention_mask for later use
-        original_input_ids = input_ids.clone()
-        original_attention_mask = attention_mask.clone() if attention_mask is not None else None
-
-        # Append the start thought token to the input sequence
-        start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
-        input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
-        seq_len += 1
-
-        # Update the attention mask
-        if attention_mask is not None:
-            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-        # Generate the continuation
-        continuation_length = self.n_ahead - 2
-        new_key_values = past_key_values
-
-        # Initialize next_token_id with a default value
-        next_token_id = torch.zeros(batch_size, dtype=torch.long).to(input_ids.device)
-
-        start_time = time.time()
-        for continuation_idx in range(continuation_length):
-            outputs = self.model(
-                input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_values=new_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-            new_key_values = outputs.past_key_values
-
-            hidden_states = outputs[0]
-
-            logits = self.lm_head(hidden_states)
-            logits = logits[:, -1, :]  # Only consider the last token
-
-            # Apply Gumbel-Softmax to the logits
-            next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
-            next_token_id = torch.argmax(next_token_logits, dim=-1)
-
-            # Append the generated token to the input sequence
-            # input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
-            seq_len += 1
-
-            # Update the attention mask
-            if attention_mask is not None:
-                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-        # Append the end thought token to the input sequence
-        end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
-        input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
-        seq_len += 1
-
-        # Update the attention mask
-        if attention_mask is not None:
-            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-        # Get the hidden states before and after the thought
-        outputs_before = self.model(
-            input_ids=original_input_ids,
-            attention_mask=original_attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states_before = outputs_before[0][:, -1:, :]
-
-        # two new tokens: last continuation token and end thought token
-        outputs_after = self.model(
-            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
-            attention_mask=torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1),
-            position_ids=position_ids,
-            past_key_values=new_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states_after = outputs_after[0][:, -1:, :]
-
-        # Apply the talk head to get the mixing weight
-        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
-
-        # Apply the mixing weight to the hidden states
-        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
-
-        # Apply the language model head to get the final logits
-        logits = self.lm_head(mixed_hidden_states)
-        return logits
+    @torch.no_grad()
+    def infer(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        batch_size, seq_len = input_ids.shape
+
+        # Save the original input_ids and attention_mask for later use
+        original_input_ids = input_ids.clone()
+        original_attention_mask = attention_mask.clone() if attention_mask is not None else None
+
+        # Append the start thought token to the input sequence
+        start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+        seq_len += 1
+
+        # Update the attention mask
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Generate the continuation
+        continuation_length = self.n_ahead - 2
+        new_key_values = past_key_values
+
+        start_time = time.time()
+        for continuation_idx in range(continuation_length):
+            outputs = self.model(
+                input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=new_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            new_key_values = outputs.past_key_values
+
+            hidden_states = outputs[0]
+
+            logits = self.lm_head(hidden_states)
+            logits = logits[:, -1, :]  # Only consider the last token
+
+            # Apply Gumbel-Softmax to the logits
+            next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
+            next_token_id = torch.argmax(next_token_logits, dim=-1)
+
+            # Append the generated token to the input sequence
+            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
+            seq_len += 1
+
+            # Update the attention mask
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Append the end thought token to the input sequence
+        end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+        seq_len += 1
+
+        # Update the attention mask
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Get the hidden states before and after the thought
+        outputs_before = self.model(
+            input_ids=original_input_ids,
+            attention_mask=original_attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_before = outputs_before[0][:, -1:, :]
+
+        # two new tokens: last continuation token and end thought token
+        outputs_after = self.model(
+            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1).to(input_ids.device)], dim=-1),
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=new_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_after = outputs_after[0][:, -1:, :]
+
+        # Apply the talk head to get the mixing weight
+        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
+
+        # Apply the mixing weight to the hidden states
+        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
+
+        # Apply the language model head to get the final logits
+        logits = self.lm_head(mixed_hidden_states)
+        return logits
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
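
For reference, a minimal sketch of how the updated infer method might be called once this custom modeling code is loaded with trust_remote_code. The repo id, prompt, and the explicit model.tokenizer assignment are illustrative assumptions and are not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "path/to/quiet-star-checkpoint"  # placeholder: substitute the actual repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# infer() looks up "<|startthought|>" / "<|endthought|>" through self.tokenizer,
# so attach the tokenizer to the model if it is not already set (assumption).
model.tokenizer = tokenizer

inputs = tokenizer("The capital of France is", return_tensors="pt")

# Returns logits of shape (batch, 1, vocab_size) after mixing the pre- and
# post-thought hidden states with the talk head.
logits = model.infer(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)

next_token_id = logits[:, -1, :].argmax(dim=-1)
print(tokenizer.decode(next_token_id))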