Crystalcareai
committed on
Update modeling_quiet.py
modeling_quiet.py CHANGED (+121 -50)
@@ -37,7 +37,7 @@ import transformers
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
@@ -1110,7 +1110,126 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         # Apply the language model head to get the final logits
         logits = self.lm_head(mixed_hidden_states)
         return logits
+
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        return {"input_ids": input_ids}
+
+    def _generate_no_beam_search(
+        self,
+        input_ids,
+        cur_len,
+        max_length,
+        min_length,
+        do_sample,
+        temperature,
+        top_k,
+        top_p,
+        repetition_penalty,
+        no_repeat_ngram_size,
+        bad_words_ids,
+        pad_token_id,
+        eos_token_id,
+        batch_size,
+        attention_mask,
+        use_cache,
+        model_kwargs,
+    ):
+        if input_ids is None or input_ids.nelement() == 0:
+            input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
+            attention_mask = torch.ones_like(input_ids).to(self.device)
+
+        device = input_ids.device
+        with torch.no_grad():
+            batch_size = input_ids.shape[0]
+            finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
+            generated_token_ids = torch.full((batch_size, max_length - cur_len), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
+
+            for cur_token_idx in range(max_length - cur_len):
+                new_ids = self(
+                    input_ids[~finished_generating],
+                    attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
+                    **model_kwargs
+                )['logits']
+
+                new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
+
+                for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
+                    base_answer_ids = input_ids[answer_idx]
+                    new_answer_ids = new_ids[list_idx]
+                    last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+
+                    new_ids_sampled = torch.multinomial(
+                        torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
+
+                    if last_token_idx + 1 >= len(base_answer_ids):
+                        new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
+                                                 device=device)
+                        input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                        if attention_mask is not None:
+                            attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+
+                    if attention_mask is not None:
+                        attention_mask[answer_idx, last_token_idx + 1] = 1
+                    input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+                    generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
+
+                    if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
+                        finished_generating[answer_idx] = 1
+
+                    if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
+                        finished_generating[answer_idx] = 1
+
+                if finished_generating.all():
+                    break
+
+        return generated_token_ids
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: torch.LongTensor = torch.LongTensor(),
+        attention_mask: Optional[torch.Tensor] = None,
+        max_new_tokens: Optional[int] = None,
+        temperature: float = 1.1,
+        **kwargs,
+    ):
+        if isinstance(input_ids, str):
+            input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        max_length = max_new_tokens + input_ids.shape[1] if max_new_tokens is not None else None
+
+        # Set model attributes
+        self.max_thoughts = kwargs.get('n_ahead', 4) + kwargs.get('n_ahead_talk', 4) + 1
+        self.merged_talk_heads = kwargs.get('merged_talk_heads', True)
+        self.merged_lm_and_talk_heads = kwargs.get('merged_lm_and_talk_heads', False)
+        self.merged_lm_and_think_heads = kwargs.get('merged_lm_and_think_heads', True)
+        self.use_concat_talk_head = kwargs.get('use_concat_talk_head', True)
+        self.use_shallow_think = kwargs.get('use_shallow_think', True)
+        self.use_shallow_talk = kwargs.get('use_shallow_talk', False)
+        self.use_complex_think_head = kwargs.get('use_complex_think_head', False)
+        self.use_complex_talk_head = kwargs.get('use_complex_talk_head', True)
+        self.use_weighted_talk_head = kwargs.get('use_weighted_talk_head', True)
+
+        # Set model properties
+        self.use_end_thought_token = True
+        self.use_start_thought_token = True
+        self.n_ahead = kwargs.get('n_ahead', 4)
+        self.n_passes = 1
+        self.eval_mode = True
+        self.first_run = False
+        self.rm_initialized = True
+        self.original_mode = False
 
+        return super().generate(
+            input_ids,
+            attention_mask=attention_mask,
+            max_length=max_length,
+            temperature=temperature,
+            **kwargs,
+        )
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1137,7 +1256,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         top_p: Optional[float] = None,
         min_p: Optional[float] = None,
         top_k: Optional[int] = None,
-        cache_position: Optional[bool] = None,
         repetition_penalty: Optional[float] = None,
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
@@ -1412,17 +1530,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             past_key_values_length,
             sliding_window=self.config.sliding_window,
         )
-
-        if attention_mask.dim() == 2:
-            # Expand the attention mask to have dimensions (batch_size, 1, 1, seq_length)
-            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        elif attention_mask.dim() == 3:
-            # Expand the attention mask to have dimensions (batch_size, 1, seq_length, seq_length)
-            attention_mask = attention_mask.unsqueeze(1)
-        else:
-            raise ValueError(
-                f"Attention mask should have 2 or 3 dimensions, but got {attention_mask.dim()} dimensions."
-            )
+
        outputs = self.model(
            # input_ids=input_ids,
            attention_mask=attention_mask,
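Note on the hunk above: the removed block expanded a 2D or 3D attention mask by hand, directly after what appears to be the call to _prepare_4d_causal_attention_mask (the helper re-imported at the top of the file in this commit), which already returns a 4D mask. A minimal sketch of that helper's output shape, with all names, shapes, and values assumed for illustration rather than taken from this file:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_len = 2, 5
padding_mask_2d = torch.ones(batch_size, seq_len)   # (batch, seq_len) mask of 1s/0s
dummy_embeds = torch.zeros(batch_size, seq_len, 8)  # only dtype/device are read

# The helper builds an additive causal mask of shape (batch, 1, seq_len, kv_len),
# so a further unsqueeze of the 2D padding mask is not needed afterwards.
mask_4d = _prepare_4d_causal_attention_mask(
    padding_mask_2d, (batch_size, seq_len), dummy_embeds, past_key_values_length=0
)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 5])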
@@ -1861,43 +1969,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
-
-
-    from .generate import custom_generate
-
-    def prepare_inputs_for_generation(self, input_ids, **kwargs):
-        return {"input_ids": input_ids, **kwargs}
-
-    def _generate_no_beam_search(
-        self,
-        input_ids,
-        cur_len,
-        max_length,
-        min_length,
-        do_sample,
-        temperature,
-        top_k,
-        top_p,
-        repetition_penalty,
-        no_repeat_ngram_size,
-        bad_words_ids,
-        pad_token_id,
-        eos_token_id,
-        batch_size,
-        attention_mask,
-        use_cache,
-        model_kwargs,
-    ):
-        generated_token_ids = custom_generate(
-            self,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            max_new_tokens=max_length - cur_len,
-            temperature=temperature,
-            **model_kwargs,
-        )
-
-        return generated_token_ids
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
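For context, the generate() added in this commit sets the thought-related attributes from kwargs (n_ahead, n_ahead_talk, and the head flags all default inside generate() when not passed) and then defers to GenerationMixin.generate(), while _generate_no_beam_search() samples token by token through self.tokenizer. A minimal usage sketch under assumptions not part of this commit: the repo id and prompt are placeholders, and the tokenizer is attached manually in case the checkpoint does not do so itself.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/quiet-star-checkpoint"  # placeholder; any checkpoint shipping modeling_quiet.py
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # loads the custom QuietForCausalLM class from the repo
)
model.tokenizer = tokenizer  # the sampling loop above reads self.tokenizer directly (assumption)

inputs = tokenizer("What is 7 * 6?", return_tensors="pt")
output_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,
    temperature=1.1,  # default used by the new generate()
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

The new generate() also accepts a raw string for input_ids and tokenizes it itself, so model.generate("What is 7 * 6?", max_new_tokens=64) should behave the same under the same assumptions.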