Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py  (+26 -54)

@@ -41,6 +41,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import GenerationMixin
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache

@@ -1166,7 +1167,7 @@ def nonzero_mean(x, axis=None):
 def loss_mean(x):
     return x.sum() / (x != 0).sum()
 
-class QuietForCausalLM(QuietPreTrainedModel):
+class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):

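The hunk above mixes GenerationMixin into QuietForCausalLM, presumably so that the standard generate() API stays available on the custom class in recent transformers releases. A minimal usage sketch, assuming the custom modeling code is loaded with trust_remote_code=True; the repo id below is a placeholder, not taken from this commit:

# Hypothetical usage sketch (placeholder repo id; assumes trust_remote_code=True).
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Crystalcareai/quiet-star-model"  # placeholder, not from this commit
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# With GenerationMixin among the class bases, generate() works as usual.
inputs = tokenizer("The capital of France is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
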
@@ -2168,59 +2169,30 @@ class QuietForCausalLM(QuietPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        # ... 24 removed lines (content not captured here) that compute cache_length and max_cache_length ...
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if self.use_start_thought_token:
+            start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
+            input_ids = torch.cat(
+                [input_ids, torch.tensor([[start_thought_token_id]] * input_ids.shape[0], device=input_ids.device)],
+                dim=-1
+            )
+            attention_mask = torch.cat(
+                [attention_mask, torch.ones((input_ids.shape[0], 1), device=attention_mask.device)],
+                dim=-1
+            )
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache"),
+            "attention_mask": attention_mask,
+            "inputs_embeds": inputs_embeds,
+        }
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
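The rewritten prepare_inputs_for_generation drops the old position_ids and cache-length bookkeeping: when a cache is present it feeds only the last token, and when use_start_thought_token is set it appends the <|startthought|> token (plus a matching attention-mask column) before each forward pass. Below is a standalone sketch of that control flow; the function name, the token id 32001, and the trimmed return dict are illustrative assumptions standing in for the real self.tokenizer and self.use_start_thought_token attributes.

# Standalone sketch of the new input-preparation behavior (simplified, pure PyTorch;
# the id 32001 for "<|startthought|>" is made up for illustration).
import torch

def prepare_inputs_sketch(input_ids, attention_mask=None, past_key_values=None,
                          use_start_thought_token=True, start_thought_token_id=32001):
    if past_key_values:
        # With a populated cache, only the newest token is fed to the model.
        input_ids = input_ids[:, -1:]
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    if use_start_thought_token:
        # Append the start-of-thought token to every sequence in the batch
        # and extend the attention mask to cover it.
        batch_size = input_ids.shape[0]
        thought_col = torch.full((batch_size, 1), start_thought_token_id, dtype=input_ids.dtype)
        input_ids = torch.cat([input_ids, thought_col], dim=-1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones((batch_size, 1), dtype=attention_mask.dtype)], dim=-1
        )
    return {"input_ids": input_ids, "attention_mask": attention_mask,
            "past_key_values": past_key_values}

prepared = prepare_inputs_sketch(torch.tensor([[1, 2, 3]]))
print(prepared["input_ids"])       # tensor([[    1,     2,     3, 32001]])
print(prepared["attention_mask"])  # tensor([[1, 1, 1, 1]])

The sketch only illustrates shapes; the real method also threads past_key_values, use_cache, and inputs_embeds through to the model's forward pass, as shown in the returned dict in the diff.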