Crystalcareai committed · Commit c1ae7ac · verified · 1 Parent(s): 40c682f

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py +110 -39
modeling_quiet.py CHANGED
@@ -1307,47 +1307,118 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             nn.init.constant_(module.bias, 0)
         elif isinstance(module, nn.Embedding):
             nn.init.xavier_uniform_(module.weight)
-
+
     @torch.no_grad()
-    def generate(self, input_ids, attention_mask=None, streamer=None, **kwargs):
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+    def infer(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        batch_size, seq_len = input_ids.shape
+
+        # Save the original input_ids and attention_mask for later use
+        original_input_ids = input_ids.clone()
+        original_attention_mask = attention_mask.clone() if attention_mask is not None else None
+
+        # Append the start thought token to the input sequence
+        start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+        seq_len += 1
+
+        # Update the attention mask
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Generate the continuation
+        continuation_length = self.n_ahead - 2
+        new_key_values = past_key_values
 
-        max_length = kwargs.get("max_length", 20)
-        temp = kwargs.get("temperature", 1.0)
-
-        with torch.no_grad():
-            finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
-            for cur_token_idx in range(max_length):
-                # Sample the next token
-                new_ids = self(
-                    input_ids[~finished_generating],
-                    attention_mask=attention_mask[~finished_generating]
-                )['logits']
-                # Mask out the start and end thought tokens so we don't accidentally sample them
-                new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
-                for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
-                    # Find the index of the last token that is not padding
-                    base_answer_ids = input_ids[answer_idx]
-                    new_answer_ids = new_ids[list_idx]
-                    last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
-
-                    new_ids_sampled = torch.multinomial(
-                        torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temp, dim=-1), 1)
-                    # Assign the new id to the last token
-                    if last_token_idx + 1 >= len(base_answer_ids):
-                        # Add padding everywhere
-                        new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
-                                                 device=input_ids.device)
-                        input_ids = torch.cat([input_ids, new_padding], dim=-1)
-                        attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-                    attention_mask[answer_idx, last_token_idx + 1] = 1
-                    input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-                    if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
-                        finished_generating[answer_idx] = 1
-                if finished_generating.all():
-                    break
-        return input_ids, attention_mask
+        start_time = time.time()
+        for continuation_idx in range(continuation_length):
+            outputs = self.model(
+                input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=new_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            new_key_values = outputs.past_key_values
+
+            hidden_states = outputs[0]
+
+            logits = self.lm_head(hidden_states)
+            logits = logits[:, -1, :]  # Only consider the last token
+
+            # Apply Gumbel-Softmax to the logits
+            next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
+            next_token_id = torch.argmax(next_token_logits, dim=-1)
+
+            # Append the generated token to the input sequence
+            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
+            seq_len += 1
+
+            # Update the attention mask
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Append the end thought token to the input sequence
+        end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+        seq_len += 1
+
+        # Update the attention mask
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Get the hidden states before and after the thought
+        outputs_before = self.model(
+            input_ids=original_input_ids,
+            attention_mask=original_attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_before = outputs_before[0][:, -1:, :]
+
+        # two new tokens: last continuation token and end thought token
+        outputs_after = self.model(
+            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1).to(input_ids.device)], dim=-1),
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=new_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_after = outputs_after[0][:, -1:, :]
+
+        # Apply the talk head to get the mixing weight
+        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
+
+        # Apply the mixing weight to the hidden states
+        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
+
+        # Apply the language model head to get the final logits
+        logits = self.lm_head(mixed_hidden_states)
+        return logits
+
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
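
The new infer method appends a <|startthought|> token, rolls out n_ahead - 2 thought tokens with Gumbel-Softmax sampling, appends <|endthought|>, and then mixes the pre- and post-thought hidden states through the talk head before applying the LM head. Below is a minimal usage sketch, not part of the commit: it assumes the checkpoint maps AutoModelForCausalLM to QuietForCausalLM via trust_remote_code, that n_ahead, gumbel_temperature, and the thought tokens are already configured, and that the tokenizer must be attached to the model by hand; the repo id is a placeholder.

# Hypothetical usage sketch for the new infer() path (assumptions noted above).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/quiet-star-checkpoint"  # placeholder, not a real repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.tokenizer = tokenizer  # infer() reads the thought-token ids from self.tokenizer
model.eval()

inputs = tokenizer("The capital of France is", return_tensors="pt")
# infer() is wrapped in @torch.no_grad() and returns talk-head-mixed logits
# for the next token, shaped (batch, 1, vocab).
logits = model.infer(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)
next_id = logits[:, -1, :].argmax(dim=-1)
print(tokenizer.decode(next_id.tolist()))

Note that, unlike the removed generate method, infer produces logits for a single next-token prediction rather than a completed sequence, so any multi-token decoding loop has to be written by the caller.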