Crystalcareai committed
Commit 1e2fc7d · verified · 1 Parent(s): 13b833f

Update modeling_quiet.py

Files changed (1)
  modeling_quiet.py +33 -37
modeling_quiet.py CHANGED
@@ -1328,46 +1328,42 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     def generate(self, input_ids, attention_mask=None, **kwargs):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
-
+
         max_length = kwargs.get("max_length", 20)
         temp = kwargs.get("temperature", 1.0)
 
-        batch_size = input_ids.shape[0]
-        device = input_ids.device
-
-        finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
-
-        for cur_token_idx in range(max_length):
-            new_ids = self(
-                input_ids[~finished_generating],
-                attention_mask=attention_mask[~finished_generating]
-            )['logits']
-
-            new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
-
-            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
-                base_answer_ids = input_ids[answer_idx]
-                new_answer_ids = new_ids[list_idx]
-                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
-
-                new_ids_sampled = torch.multinomial(
-                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temp, dim=-1), 1)
-
-                if last_token_idx + 1 >= len(base_answer_ids):
-                    new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
-                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
-                    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-
-                attention_mask[answer_idx, last_token_idx + 1] = 1
-                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-
-                if new_ids_sampled in [self.tokenizer.eos_token_id, self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]:
-                    finished_generating[answer_idx] = 1
-
-            if finished_generating.all():
-                break
-
-        return input_ids
+        with torch.no_grad():
+            finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
+            for cur_token_idx in range(max_length):
+                # Sample the next token
+                new_ids = self(
+                    input_ids[~finished_generating],
+                    attention_mask=attention_mask[~finished_generating]
+                )['logits']
+                # Mask out the start and end thought tokens so we don't accidentally sample them
+                new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
+                for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
+                    # Find the index of the last token that is not padding
+                    base_answer_ids = input_ids[answer_idx]
+                    new_answer_ids = new_ids[list_idx]
+                    last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+
+                    new_ids_sampled = torch.multinomial(
+                        torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temp, dim=-1), 1)
+                    # Assign the new id to the last token
+                    if last_token_idx + 1 >= len(base_answer_ids):
+                        # Add padding everywhere
+                        new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
+                                                 device=input_ids.device)
+                        input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                        attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+                    attention_mask[answer_idx, last_token_idx + 1] = 1
+                    input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+                    if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
+                        finished_generating[answer_idx] = 1
+                if finished_generating.all():
+                    break
+            return input_ids, attention_mask
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
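For context, a minimal usage sketch of the changed method is below. The repo id, prompt, and sampling settings are illustrative placeholders, not part of this commit; the sketch assumes the checkpoint is loaded with trust_remote_code=True and that the model object exposes a tokenizer attribute, since generate() above reads self.tokenizer for the pad/eos/bos ids. Note that generate() now returns the token ids together with the grown attention mask, so callers unpack a tuple.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/Quiet-Star-Custom"  # placeholder repo id, swap in the actual checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model.tokenizer = tokenizer  # generate() relies on self.tokenizer being set

inputs = tokenizer("What is 2 + 2?", return_tensors="pt")

# The updated generate() returns (input_ids, attention_mask) rather than ids alone.
output_ids, output_mask = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=20,     # number of sampling steps in the loop above
    temperature=0.7,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))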