Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py CHANGED (+77 -0)
@@ -2136,6 +2136,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
             attentions=outputs.attentions,
         )
 
+
+
     def compute_complexity_scores(self, input_ids, attention_mask):
         # Compute complexity scores based on input sequence characteristics
         # Example: Normalize sequence lengths and consider the presence of rare tokens
@@ -2228,6 +2230,81 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+    @torch.no_grad()
+    def generate(self, input_ids, attention_mask=None, **generate_kwargs):
+        batch_size, seq_len = input_ids.shape
+
+        max_length = generate_kwargs.get("max_length", self.config.max_length)
+        min_length = generate_kwargs.get("min_length", self.config.min_length)
+        do_sample = generate_kwargs.get("do_sample", self.config.do_sample)
+        early_stopping = generate_kwargs.get("early_stopping", self.config.early_stopping)
+        num_beams = generate_kwargs.get("num_beams", self.config.num_beams)
+        temperature = generate_kwargs.get("temperature", self.config.temperature)
+        top_k = generate_kwargs.get("top_k", self.config.top_k)
+        top_p = generate_kwargs.get("top_p", self.config.top_p)
+        repetition_penalty = generate_kwargs.get("repetition_penalty", self.config.repetition_penalty)
+        pad_token_id = generate_kwargs.get("pad_token_id", self.config.pad_token_id)
+        eos_token_id = generate_kwargs.get("eos_token_id", self.config.eos_token_id)
+
+        # Prepend the start thought token to the input sequence
+        start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]], device=input_ids.device)], dim=-1)
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)], dim=-1)
+
+        thought_embeds = self.model.embed_tokens(input_ids)
+
+        past_key_values = None
+        unfinished_sequences = input_ids.new(batch_size).fill_(1)
+        sequence_lengths = input_ids.new(batch_size).fill_(max_length)
+
+        while True:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, use_cache=True)
+            thought_outputs = self.model(**model_inputs, output_hidden_states=True, return_dict=True)
+            next_thought_embeds = self.prepare_thought_embeds(thought_outputs.hidden_states[-1], temperature=temperature)
+
+            thought_embeds = torch.cat([thought_embeds, next_thought_embeds], dim=1)
+
+            lm_logits = self.lm_head(thought_outputs.last_hidden_state)
+            lm_logits = lm_logits[:, -1, :] / temperature
+
+            if do_sample:
+                next_tokens = torch.multinomial(F.softmax(lm_logits, dim=-1), num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(lm_logits, dim=-1)
+
+            # Update input_ids, attention_mask and past_key_values
+            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1)
+            past_key_values = thought_outputs.past_key_values
+
+            # Check if generation is complete
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences & (next_tokens != eos_token_id)
+                unfinished_sequences = unfinished_sequences & (input_ids.shape[-1] < max_length)
+                if unfinished_sequences.max() == 0:
+                    break
+            elif input_ids.shape[-1] >= max_length:
+                input_ids[:, 0] = eos_token_id
+                break
+
+        return input_ids
+
+    def prepare_thought_embeds(self, hidden_states, temperature=1.0):
+        if self.use_start_thought_token:
+            start_embed = self.start_embedding[0].unsqueeze(0) * temperature
+        else:
+            start_embed = hidden_states[:, 0, :]
+
+        if self.use_end_thought_token:
+            end_embed = self.end_embedding[0].unsqueeze(0) * temperature
+            thought_embeds = torch.cat([start_embed, hidden_states[:, 1:-1, :], end_embed], dim=1)
+        else:
+            thought_embeds = torch.cat([start_embed, hidden_states[:, 1:, :]], dim=1)
+
+        return thought_embeds
 
 
 @add_start_docstrings(
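The new generate() override reads self.tokenizer and several self.config defaults, none of which are touched by this diff. Below is a minimal usage sketch, not part of the commit, assuming the module is loaded with trust_remote_code and the tokenizer is attached to the model by hand; the checkpoint id is a hypothetical placeholder rather than anything named in this commit.

# Minimal usage sketch (not part of the commit). The checkpoint id below is a
# hypothetical placeholder, and model.tokenizer is attached manually because the
# new generate() looks up "<|startthought|>" via self.tokenizer, which this
# diff does not initialize.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/Quiet-Star"  # placeholder checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model.eval()
model.tokenizer = tokenizer  # generate() needs self.tokenizer to resolve the start-thought token

prompt = "What is 2 + 2?"
inputs = tokenizer(prompt, return_tensors="pt")  # batch size 1: the prepended thought token has a fixed batch dim of 1
output_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=64,
    do_sample=False,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

As written, the prepended start-thought tensor is created with a batch dimension of 1, so the override effectively assumes one input sequence per call; generation stops when eos_token_id is produced or input_ids reaches max_length.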