Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +1 -76
modeling_quiet.py
CHANGED
@@ -2231,82 +2231,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
2231 |
)
|
2232 |
return reordered_past
|
2233 |
|
2234 |
-
|
2235 |
-
def generate(self, input_ids, attention_mask=None, **generate_kwargs):
|
2236 |
-
batch_size, seq_len = input_ids.shape
|
2237 |
-
|
2238 |
-
max_length = generate_kwargs.get("max_length", self.config.max_length)
|
2239 |
-
min_length = generate_kwargs.get("min_length", self.config.min_length)
|
2240 |
-
do_sample = generate_kwargs.get("do_sample", self.config.do_sample)
|
2241 |
-
early_stopping = generate_kwargs.get("early_stopping", self.config.early_stopping)
|
2242 |
-
num_beams = generate_kwargs.get("num_beams", self.config.num_beams)
|
2243 |
-
temperature = generate_kwargs.get("temperature", self.config.temperature)
|
2244 |
-
top_k = generate_kwargs.get("top_k", self.config.top_k)
|
2245 |
-
top_p = generate_kwargs.get("top_p", self.config.top_p)
|
2246 |
-
repetition_penalty = generate_kwargs.get("repetition_penalty", self.config.repetition_penalty)
|
2247 |
-
pad_token_id = generate_kwargs.get("pad_token_id", self.config.pad_token_id)
|
2248 |
-
eos_token_id = generate_kwargs.get("eos_token_id", self.config.eos_token_id)
|
2249 |
-
|
2250 |
-
# Prepend the start thought token to the input sequence
|
2251 |
-
start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
|
2252 |
-
input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]], device=input_ids.device)], dim=-1)
|
2253 |
-
if attention_mask is not None:
|
2254 |
-
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)], dim=-1)
|
2255 |
-
|
2256 |
-
thought_embeds = self.model.embed_tokens(input_ids)
|
2257 |
-
|
2258 |
-
past_key_values = None
|
2259 |
-
unfinished_sequences = input_ids.new(batch_size).fill_(1)
|
2260 |
-
sequence_lengths = input_ids.new(batch_size).fill_(max_length)
|
2261 |
-
|
2262 |
-
while True:
|
2263 |
-
model_inputs = self.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, use_cache=True)
|
2264 |
-
thought_outputs = self.model(**model_inputs, output_hidden_states=True, return_dict=True)
|
2265 |
-
next_thought_embeds = self.prepare_thought_embeds(thought_outputs.hidden_states[-1], temperature=temperature)
|
2266 |
-
|
2267 |
-
thought_embeds = torch.cat([thought_embeds, next_thought_embeds], dim=1)
|
2268 |
-
|
2269 |
-
lm_logits = self.lm_head(thought_outputs.last_hidden_state)
|
2270 |
-
lm_logits = lm_logits[:, -1, :] / temperature
|
2271 |
-
|
2272 |
-
if do_sample:
|
2273 |
-
next_tokens = torch.multinomial(F.softmax(lm_logits, dim=-1), num_samples=1).squeeze(1)
|
2274 |
-
else:
|
2275 |
-
next_tokens = torch.argmax(lm_logits, dim=-1)
|
2276 |
-
|
2277 |
-
# Update input_ids, attention_mask and past_key_values
|
2278 |
-
input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
|
2279 |
-
if attention_mask is not None:
|
2280 |
-
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1)
|
2281 |
-
past_key_values = thought_outputs.past_key_values
|
2282 |
-
|
2283 |
-
# Check if generation is complete
|
2284 |
-
if eos_token_id is not None:
|
2285 |
-
unfinished_sequences = unfinished_sequences & (next_tokens != eos_token_id)
|
2286 |
-
unfinished_sequences = unfinished_sequences & (input_ids.shape[-1] < max_length)
|
2287 |
-
if unfinished_sequences.max() == 0:
|
2288 |
-
break
|
2289 |
-
elif input_ids.shape[-1] >= max_length:
|
2290 |
-
input_ids[:, 0] = eos_token_id
|
2291 |
-
break
|
2292 |
-
|
2293 |
-
return input_ids
|
2294 |
-
|
2295 |
-
def prepare_thought_embeds(self, hidden_states, temperature=1.0):
|
2296 |
-
batch_size, seq_len, hidden_size = hidden_states.shape
|
2297 |
-
|
2298 |
-
if self.use_start_thought_token:
|
2299 |
-
start_embed = self.start_embedding[0].unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1) * temperature
|
2300 |
-
else:
|
2301 |
-
start_embed = hidden_states[:, :1, :]
|
2302 |
-
|
2303 |
-
if self.use_end_thought_token:
|
2304 |
-
end_embed = self.end_embedding[0].unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1) * temperature
|
2305 |
-
thought_embeds = torch.cat([start_embed, hidden_states[:, 1:-1, :], end_embed], dim=1)
|
2306 |
-
else:
|
2307 |
-
thought_embeds = torch.cat([start_embed, hidden_states[:, 1:, :]], dim=1)
|
2308 |
-
|
2309 |
-
return thought_embeds
|
2310 |
|
2311 |
|
2312 |
@add_start_docstrings(
|
|
|
2231 |
)
|
2232 |
return reordered_past
|
2233 |
|
2234 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2235 |
|
2236 |
|
2237 |
@add_start_docstrings(
|