Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +7 -6
modeling_quiet.py
CHANGED
@@ -1567,15 +1567,16 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1567 |
return x
|
1568 |
return x.repeat_interleave(n, dim=0)
|
1569 |
|
1570 |
-
if self.n_passes > 1:
|
1571 |
input_ids = none_repeat_interleave(input_ids, self.n_passes)
|
1572 |
-
attention_mask = none_repeat_interleave(attention_mask, self.n_passes)
|
1573 |
-
position_ids = none_repeat_interleave(position_ids, self.n_passes)
|
1574 |
-
inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes)
|
1575 |
-
labels = none_repeat_interleave(labels, self.n_passes)
|
1576 |
if past_key_values is not None:
|
1577 |
past_key_values = [none_repeat_interleave(p, self.n_passes) for p in past_key_values]
|
1578 |
-
|
|
|
1579 |
|
1580 |
self.tokenizer_has_start_thought_token = True
|
1581 |
self.tokenizer_has_end_thought_token = True
|
|
|
1567 |
return x
|
1568 |
return x.repeat_interleave(n, dim=0)
|
1569 |
|
1570 |
+
if self.n_passes > 1 and input_ids is not None:
|
1571 |
input_ids = none_repeat_interleave(input_ids, self.n_passes)
|
1572 |
+
attention_mask = none_repeat_interleave(attention_mask, self.n_passes) if attention_mask is not None else None
|
1573 |
+
position_ids = none_repeat_interleave(position_ids, self.n_passes) if position_ids is not None else None
|
1574 |
+
inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes) if inputs_embeds is not None else None
|
1575 |
+
labels = none_repeat_interleave(labels, self.n_passes) if labels is not None else None
|
1576 |
if past_key_values is not None:
|
1577 |
past_key_values = [none_repeat_interleave(p, self.n_passes) for p in past_key_values]
|
1578 |
+
if input_ids is not None:
|
1579 |
+
cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device)
|
1580 |
|
1581 |
self.tokenizer_has_start_thought_token = True
|
1582 |
self.tokenizer_has_end_thought_token = True
|