Crystalcareai
committed on
Update modeling_quiet.py
modeling_quiet.py  CHANGED  (+68 -21)
@@ -270,10 +270,22 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class QuietAttention(nn.Module):
+    """
+    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+    and "Generating Long Sequences with Sparse Transformers".
+    """
+
     def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -284,6 +296,11 @@ class QuietAttention(nn.Module):
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
 
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@@ -295,6 +312,9 @@ class QuietAttention(nn.Module):
             base=self.rope_theta,
         )
 
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -304,7 +324,11 @@ class QuietAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
-    ):
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -318,30 +342,50 @@ class QuietAttention(nn.Module):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
-                raise ValueError(
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
         if attention_mask is not None:
-            if attention_mask.size(
-
-
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+
             attn_weights = attn_weights + attention_mask
 
+        # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
@@ -1017,12 +1061,6 @@ class QuietModel(QuietPreTrainedModel):
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.unsqueeze(-1)
-            inputs_embeds = inputs_embeds * attention_mask
-
-        hidden_states = inputs_embeds
 
         if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
@@ -1045,7 +1083,16 @@ class QuietModel(QuietPreTrainedModel):
                 inputs_embeds,
                 past_key_values_length,
             )
-
+        elif attention_mask is None or attention_mask.dim() == 2:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )
+
         hidden_states = inputs_embeds
 
         # decoder layers
@@ -1261,7 +1308,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        max_length: Optional[int] = None,  # Add the max_length argument
     ):
         batch_size, seq_len = input_ids.shape
 
@@ -1279,12 +1325,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
         attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
 
         # Generate the continuation
-
-        max_length = self.n_ahead - 2  # Use the default value if max_length is not provided
+        continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
+
         start_time = time.time()
-
-        for continuation_idx in range(max_length):
+        for continuation_idx in range(continuation_length):
             outputs = self.model(
                 input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
                 attention_mask=attention_mask,
@@ -1297,7 +1342,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 return_dict=return_dict,
            )
             new_key_values = outputs.past_key_values
+
             hidden_states = outputs[0]
+
             logits = self.lm_head(hidden_states)
             logits = logits[:, -1, :]  # Only consider the last token
 
@@ -1336,9 +1383,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
             )
             hidden_states_before = outputs_before[0][:, -1:, :]
 
-            #
+            # two new tokens: last continuation token and end thought token
            outputs_after = self.model(
-                input_ids=input_ids,
+                input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1).to(input_ids.device)], dim=-1),
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_values=new_key_values,
@@ -1358,8 +1405,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         # Apply the language model head to get the final logits
         logits = self.lm_head(mixed_hidden_states)
-
         return logits
+
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
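For readers following the size checks added in QuietAttention.forward, the short sketch below is not part of the commit; it only illustrates, with made-up tensor sizes, how repeat_kv (referenced in the first hunk header) expands grouped key/value heads so that attn_weights comes out as (bsz, num_heads, q_len, kv_seq_len), the shape the new ValueError guards assert.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (bsz, num_kv_heads, seq_len, head_dim) to (bsz, num_kv_heads * n_rep, seq_len, head_dim),
    # mirroring the grouped-query-attention helper named in the hunk header above.
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

# Illustrative sizes (assumptions, not values taken from QuietConfig).
bsz, num_heads, num_kv_heads, head_dim = 2, 8, 2, 16
q_len = kv_seq_len = 5

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_kv_heads, kv_seq_len, head_dim)

key_states = repeat_kv(key_states, num_heads // num_kv_heads)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / (head_dim ** 0.5)

# This matches the shape asserted by the new check in QuietAttention.forward.
assert attn_weights.size() == (bsz, num_heads, q_len, kv_seq_len)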