Crystalcareai committed on
Commit 20b3935 · verified · 1 Parent(s): c945236

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +48 -106
modeling_quiet.py CHANGED

@@ -23,16 +23,16 @@ import math
 import pdb
 import warnings
 from collections import defaultdict
-from typing import List, Optional, Tuple, Union, Iterable, Callable
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.generation.utils import GenerationMixin, GenerationConfig
+from transformers.generation.utils import GenerationMixin
 from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
-from transformers import TextStreamer, AutoTokenizer
+from transformers import TextStreamer
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
@@ -143,6 +143,7 @@ class QuietRMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
+
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
@@ -150,6 +151,7 @@ class QuietRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)
 
+
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Quiet
 class QuietRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -235,8 +237,7 @@ class QuietMLP(nn.Module):
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
-        hidden_states = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return hidden_states
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
@@ -848,7 +849,7 @@ class QuietDecoderLayer(nn.Module):
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
-
+
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
@@ -1022,8 +1023,6 @@ class QuietModel(QuietPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-
-        # print("Hidden states shape after embedding:", inputs_embeds.shape)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1072,27 +1071,32 @@ class QuietModel(QuietPreTrainedModel):
         if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask is not None and attention_mask.dim() == 2 and False:
+        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length,
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
             )
         elif attention_mask is None or attention_mask.dim() == 2:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length,
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
                 sliding_window=self.config.sliding_window,
             )
 
-
         hidden_states = inputs_embeds
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
-        # print("Hidden states shape before decoder layers:", hidden_states.shape)
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -1116,15 +1120,15 @@ class QuietModel(QuietPreTrainedModel):
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
+
            hidden_states = layer_outputs[0]
-            # print(f"Hidden states shape after decoder layer {decoder_layer}:", hidden_states.shape)
-            # print("Hidden states shape after decoder layers:", hidden_states.shape)
 
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
+
        hidden_states = self.norm(hidden_states)
 
        # add hidden states from the last decoder layer
@@ -1155,7 +1159,7 @@ def loss_mean(x):
 class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
-    def __init__(self, config,tokenizer=None):
+    def __init__(self, config):
         super().__init__(config)
         self.model = QuietModel(config)
         self.vocab_size = config.vocab_size
@@ -1178,7 +1182,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         self.n_tokens_print = 1
         self.gradient_accumulation_steps = 1
         self.training_steps = 0
-        self.tokenizer = tokenizer
+        self.tokenizer = None
         self.start_token_id = None
         self.end_token_id = None
         self.rm_initialized = False
@@ -1306,14 +1310,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
                 nn.init.constant_(module.bias, 0)
         elif isinstance(module, nn.Embedding):
             nn.init.xavier_uniform_(module.weight)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
-        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        model.tokenizer = tokenizer
-        return model
-
 
     @torch.no_grad()
     def infer(
@@ -1347,10 +1343,13 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
 
+        # Initialize generated_ids with input_ids
+        generated_ids = input_ids.clone()
+
         start_time = time.time()
         for continuation_idx in range(continuation_length):
             outputs = self.model(
-                input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
+                input_ids=generated_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(generated_ids.device),
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_values=new_key_values,
@@ -1371,86 +1370,33 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
             next_token_id = torch.argmax(next_token_logits, dim=-1)
 
-            # Append the generated token to the input sequence
-            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
-            seq_len += 1
+            # Append the generated token to the generated_ids
+            generated_ids = torch.cat([generated_ids, next_token_id.unsqueeze(-1).to(generated_ids.device)], dim=-1)
 
-            # Update the attention mask
-            if attention_mask is not None:
-                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-        # Append the end thought token to the input sequence
+        # Append the end thought token to the generated_ids
         end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
-        input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
-        seq_len += 1
+        generated_ids = torch.cat([generated_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(generated_ids.device)], dim=-1)
 
-        # Update the attention mask
-        if attention_mask is not None:
-            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-        # Get the hidden states before and after the thought
-        outputs_before = self.model(
-            input_ids=original_input_ids,
-            attention_mask=original_attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states_before = outputs_before[0][:, -1:, :]
-
-        outputs_after = self.model(
-            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=new_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+        return generated_ids
+
+
+    @torch.no_grad()
+    def generate(self, *args, **kwargs):
+        # Call the infer method to generate the token ids
+        generated_ids = self.infer(
+            input_ids=kwargs.pop("input_ids", None),
+            attention_mask=kwargs.pop("attention_mask", None),
+            position_ids=kwargs.pop("position_ids", None),
+            past_key_values=kwargs.pop("past_key_values", None),
+            inputs_embeds=kwargs.pop("inputs_embeds", None),
+            use_cache=kwargs.pop("use_cache", None),
+            output_attentions=kwargs.pop("output_attentions", None),
+            output_hidden_states=kwargs.pop("output_hidden_states", None),
+            return_dict=kwargs.pop("return_dict", None),
         )
-        hidden_states_after = outputs_after[0][:, -1:, :]
-
-        # Apply the talk head to get the mixing weight
-        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
 
-        # Apply the mixing weight to the hidden states
-        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
+        return generated_ids
 
-        # Apply the language model head to get the final logits
-        logits = self.lm_head(mixed_hidden_states)
-
-        return logits
-
-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        **model_kwargs,
-    ) -> Union[BaseModelOutputWithPast, torch.LongTensor]:
-        return self.infer(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict_in_generate,
-        )
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1641,7 +1587,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
 
         complexity_scores = self.compute_complexity_scores(input_ids, attention_mask)
         temperature = self.temperature * complexity_scores.unsqueeze(-1)
-        # pdb.set_trace()
+
         if self.use_end_thought_token or self.use_start_thought_token:
             if not self.use_reparam_for_thought_embeddings:
                 start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
@@ -1671,10 +1617,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
         else:
             position_ids = position_ids.view(-1, seq_len).long()
-
-
-        # print("Input IDs shape:", input_ids.shape)
-        # print("Inputs embeds shape before embedding:", inputs_embeds.shape if inputs_embeds is not None else None)
+
         if inputs_embeds is None:
             contains_start = self.use_start_thought_token and (input_ids == self.start_token_id).any()
             contains_end = self.use_end_thought_token and (input_ids == self.end_token_id).any()
@@ -1694,7 +1637,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         else:
             with torch.set_grad_enabled(not self.train_only_thinking_embedding):
                 inputs_embeds = self.model.embed_tokens(input_ids)
-                # print("Inputs embeds shape after embedding:", inputs_embeds.shape)
 
         if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
             if attention_mask is None:
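
Below is a minimal usage sketch of the generation path after this commit, for illustration only. Because the custom from_pretrained override is removed and __init__ now sets self.tokenizer = None, the caller has to attach a tokenizer to the model before calling generate(), which now delegates to infer() and returns raw token ids. The repository id below is a placeholder, and trust_remote_code=True is assumed because the model class ships with the checkpoint.

# Illustrative sketch only; "your-org/quiet-star-checkpoint" is a placeholder id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/quiet-star-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# The from_pretrained override that attached a tokenizer was removed in this
# commit, so set it manually: infer() resolves "<|endthought|>" via self.tokenizer.
model.tokenizer = tokenizer

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )

# generate() now returns the ids produced by infer(): the prompt, n_ahead - 2
# sampled thought tokens, and a trailing <|endthought|> token.
print(tokenizer.decode(generated_ids[0], skip_special_tokens=False))

Note that, as of this commit, generate() bypasses the usual GenerationMixin sampling loop: kwargs other than the ones popped above (for example max_new_tokens) are ignored, and the length of the generated thought is controlled by n_ahead.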