Crystalcareai committed on
Commit
5bc3c10
·
verified ·
1 Parent(s): 43927de

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +7 -6
modeling_quiet.py CHANGED
@@ -1567,15 +1567,16 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1567
  return x
1568
  return x.repeat_interleave(n, dim=0)
1569
 
1570
- if self.n_passes > 1:
1571
  input_ids = none_repeat_interleave(input_ids, self.n_passes)
1572
- attention_mask = none_repeat_interleave(attention_mask, self.n_passes)
1573
- position_ids = none_repeat_interleave(position_ids, self.n_passes)
1574
- inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes)
1575
- labels = none_repeat_interleave(labels, self.n_passes)
1576
  if past_key_values is not None:
1577
  past_key_values = [none_repeat_interleave(p, self.n_passes) for p in past_key_values]
1578
- cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device)
 
1579
 
1580
  self.tokenizer_has_start_thought_token = True
1581
  self.tokenizer_has_end_thought_token = True
 
1567
  return x
1568
  return x.repeat_interleave(n, dim=0)
1569
 
1570
+ if self.n_passes > 1 and input_ids is not None:
1571
  input_ids = none_repeat_interleave(input_ids, self.n_passes)
1572
+ attention_mask = none_repeat_interleave(attention_mask, self.n_passes) if attention_mask is not None else None
1573
+ position_ids = none_repeat_interleave(position_ids, self.n_passes) if position_ids is not None else None
1574
+ inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes) if inputs_embeds is not None else None
1575
+ labels = none_repeat_interleave(labels, self.n_passes) if labels is not None else None
1576
  if past_key_values is not None:
1577
  past_key_values = [none_repeat_interleave(p, self.n_passes) for p in past_key_values]
1578
+ if input_ids is not None:
1579
+ cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device)
1580
 
1581
  self.tokenizer_has_start_thought_token = True
1582
  self.tokenizer_has_end_thought_token = True