Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py  CHANGED  (+96 -29)
@@ -2128,6 +2128,50 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
            del start_embedding
            del end_embedding
            torch.cuda.empty_cache()
+
+        if not self.training:
+            # Inference mode
+            if max_length is None:
+                max_length = self.config.max_length
+
+            finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
+            for cur_token_idx in range(max_length):
+                outputs = self.model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_values=past_key_values,
+                    inputs_embeds=inputs_embeds,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                )
+                hidden_states = outputs[0]
+                logits = self.lm_head(hidden_states)
+
+                # Mask out the start and end thought tokens
+                logits[:, :, self.start_token_id] = -float("inf")
+                logits[:, :, self.end_token_id] = -float("inf")
+
+                for batch_idx in range(batch_size):
+                    if not finished_generating[batch_idx]:
+                        last_token_idx = (input_ids[batch_idx] != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+                        new_id_sampled = torch.multinomial(
+                            torch.nn.functional.softmax(logits[batch_idx, last_token_idx] / temperature, dim=-1), 1
+                        )
+                        if last_token_idx + 1 >= input_ids.shape[1]:
+                            # Add padding
+                            new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long, device=input_ids.device)
+                            input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                            attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+                        attention_mask[batch_idx, last_token_idx + 1] = 1
+                        input_ids[batch_idx, last_token_idx + 1] = new_id_sampled
+                        if new_id_sampled == self.tokenizer.eos_token_id or new_id_sampled == self.tokenizer.bos_token_id or new_id_sampled == self.tokenizer.pad_token_id:
+                            finished_generating[batch_idx] = True
+
+                if finished_generating.all():
+                    break

        return CausalLMOutputWithPast(
            loss=loss if loss is not None else None,
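A note on the added inference path: at each step it re-runs the model over the padded batch, masks the thought-boundary tokens out of the logits, and samples one token per unfinished sequence until every row hits an end-of-text token or max_length. The standalone sketch below is not part of the commit; `sample_next_token`, the toy vocabulary size, and the special-token ids are illustrative assumptions, and it only demonstrates the mask-then-temperature-sample step in isolation.

import torch

def sample_next_token(logits_row, start_token_id, end_token_id, temperature=1.0):
    # logits_row: 1-D tensor of vocabulary logits at the position being extended.
    # Mask the thought-boundary tokens so they can never be sampled at inference,
    # mirroring the -float("inf") assignments in the diff above.
    logits_row = logits_row.clone()
    logits_row[start_token_id] = -float("inf")
    logits_row[end_token_id] = -float("inf")
    probs = torch.nn.functional.softmax(logits_row / temperature, dim=-1)
    return torch.multinomial(probs, 1)

# Toy usage: a hypothetical 10-token vocabulary with arbitrary special-token ids.
logits = torch.randn(10)
next_id = sample_next_token(logits, start_token_id=7, end_token_id=8, temperature=0.7)
print(next_id.item())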
@@ -2169,36 +2213,59 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
-
-
-
-
-
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
-
-
-
-
-
-
-
-
-
-
-
-            )
-
-        # Expand the attention mask to the correct shape
-        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        attention_mask = attention_mask.expand(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "attention_mask": attention_mask,
-            "inputs_embeds": inputs_embeds,
-        }
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
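For context, the new `past_key_values` handling mirrors the usual transformers pattern of feeding the model only the tokens the cache has not yet seen. The sketch below is a simplified illustration, not the committed method: `trim_input_ids` is a hypothetical helper and the toy tensors stand in for real generation state; it reproduces the three numbered cases from the comments above.

import torch

def trim_input_ids(input_ids, attention_mask, past_length):
    # Case 1: attention_mask is longer than input_ids, so part of the input lives
    # only in the cache (e.g. a first step driven by inputs_embeds); keep the tail.
    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        return input_ids[:, -(attention_mask.shape[1] - past_length):]
    # Case 2: the cache covers only a prefix of input_ids; drop that prefix.
    if past_length < input_ids.shape[1]:
        return input_ids[:, past_length:]
    # Case 3: otherwise assume input_ids already holds only unprocessed tokens.
    return input_ids

# Toy usage: one sequence of 6 prompt tokens, with 4 already in the cache.
ids = torch.arange(6).unsqueeze(0)
mask = torch.ones(1, 6, dtype=torch.long)
print(trim_input_ids(ids, mask, past_length=4))  # tensor([[4, 5]])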
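One more illustration, covering the `position_ids` block added in this hunk: the cumulative sum over the attention mask assigns increasing positions only to attended tokens, and padded slots are filled with a dummy value of 1, as in the masked_fill_ line above. A self-contained toy example (not part of the model code):

import torch

# A left-padded batch: 0 marks padding, 1 marks real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Positions count only attended tokens; padded slots get the dummy value 1.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])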