Crystalcareai committed
Commit 3328933 · verified · 1 Parent(s): 5d9b936

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +113 -114
modeling_quiet.py CHANGED
@@ -1311,120 +1311,119 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         elif isinstance(module, nn.Embedding):
             nn.init.xavier_uniform_(module.weight)
 
-    @torch.no_grad()
-
-    def infer(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ):
-        batch_size, seq_len = input_ids.shape
-
-        # Save the original input_ids and attention_mask for later use
-        original_input_ids = input_ids.clone()
-        original_attention_mask = attention_mask.clone() if attention_mask is not None else None
-
-        # Append the start thought token to the input sequence
-        start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
-        input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
-        seq_len += 1
-
-        # Update the attention mask
-        if attention_mask is not None:
-            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-        # Generate the continuation
-        continuation_length = self.n_ahead - 2
-        new_key_values = past_key_values
-
-        # Initialize next_token_id with a default value
-        next_token_id = torch.zeros(batch_size, dtype=torch.long).to(input_ids.device)
-
-        start_time = time.time()
-        for continuation_idx in range(continuation_length):
-            outputs = self.model(
-                input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_values=new_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-            new_key_values = outputs.past_key_values
-
-            hidden_states = outputs[0]
-
-            logits = self.lm_head(hidden_states)
-            logits = logits[:, -1, :]  # Only consider the last token
-
-            # Apply Gumbel-Softmax to the logits
-            next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
-            next_token_id = torch.argmax(next_token_logits, dim=-1)
-
-            # Append the generated token to the input sequence
-            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
-            seq_len += 1
-
-            # Update the attention mask
-            if attention_mask is not None:
-                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-        # Append the end thought token to the input sequence
-        end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
-        input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
-        seq_len += 1
-
-        # Update the attention mask
-        if attention_mask is not None:
-            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-        # Get the hidden states before and after the thought
-        outputs_before = self.model(
-            input_ids=original_input_ids,
-            attention_mask=original_attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states_before = outputs_before[0][:, -1:, :]
-
-        # two new tokens: last continuation token and end thought token
-        outputs_after = self.model(
-            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]]).to(input_ids.device)], dim=-1),
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=new_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states_after = outputs_after[0][:, -1:, :]
-
-        # Apply the talk head to get the mixing weight
-        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
-
-        # Apply the mixing weight to the hidden states
-        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
-
-        # Apply the language model head to get the final logits
-        logits = self.lm_head(mixed_hidden_states)
-        return logits
+    @torch.no_grad()
+    def infer(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        batch_size, seq_len = input_ids.shape
+
+        # Save the original input_ids and attention_mask for later use
+        original_input_ids = input_ids.clone()
+        original_attention_mask = attention_mask.clone() if attention_mask is not None else None
+
+        # Append the start thought token to the input sequence
+        start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+        seq_len += 1
+
+        # Update the attention mask
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Generate the continuation
+        continuation_length = self.n_ahead - 2
+        new_key_values = past_key_values
+
+        # Initialize next_token_id with a default value
+        next_token_id = torch.zeros(batch_size, dtype=torch.long).to(input_ids.device)
+
+        start_time = time.time()
+        for continuation_idx in range(continuation_length):
+            outputs = self.model(
+                input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=new_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            new_key_values = outputs.past_key_values
+
+            hidden_states = outputs[0]
+
+            logits = self.lm_head(hidden_states)
+            logits = logits[:, -1, :]  # Only consider the last token
+
+            # Apply Gumbel-Softmax to the logits
+            next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
+            next_token_id = torch.argmax(next_token_logits, dim=-1)
+
+            # Append the generated token to the input sequence
+            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
+            seq_len += 1
+
+            # Update the attention mask
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Append the end thought token to the input sequence
+        end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+        seq_len += 1
+
+        # Update the attention mask
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Get the hidden states before and after the thought
+        outputs_before = self.model(
+            input_ids=original_input_ids,
+            attention_mask=original_attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_before = outputs_before[0][:, -1:, :]
+
+        # two new tokens: last continuation token and end thought token
+        outputs_after = self.model(
+            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]]).to(input_ids.device)], dim=-1),
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=new_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_after = outputs_after[0][:, -1:, :]
+
+        # Apply the talk head to get the mixing weight
+        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
+
+        # Apply the mixing weight to the hidden states
+        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
+
+        # Apply the language model head to get the final logits
+        logits = self.lm_head(mixed_hidden_states)
+        return logits
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
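
For reference, a minimal sketch of how the infer() method shown in this diff might be called. It is not part of the commit: the checkpoint path is a placeholder, and attaching the tokenizer to the model instance is an assumption based on the self.tokenizer lookups inside infer().

# Hedged usage sketch (illustrative only, not from the commit).
# "path/to/quiet-star-checkpoint" is a placeholder, not a real model id.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "path/to/quiet-star-checkpoint"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
model.tokenizer = tokenizer  # assumed: infer() resolves <|startthought|>/<|endthought|> via self.tokenizer
model.eval()

inputs = tokenizer("2 + 2 =", return_tensors="pt")
logits = model.infer(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)
# infer() returns logits of shape (batch, 1, vocab_size) for the token that
# follows the hidden thought; greedy-pick a single next token from them.
next_token_id = logits[:, -1, :].argmax(dim=-1)
print(tokenizer.decode(next_token_id))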