Crystalcareai committed (verified)
Commit de08a5d · 1 parent: ea46013

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py (+43 -43)
modeling_quiet.py CHANGED
@@ -1024,14 +1024,16 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         # Update the attention mask
         if attention_mask is not None:
             attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+        else:
+            attention_mask = torch.ones((batch_size, seq_len)).to(input_ids.device)

         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
-
+
         # Initialize next_token_id with a default value
         next_token_id = torch.zeros(batch_size, dtype=torch.long).to(input_ids.device)
-
+
         start_time = time.time()
         for continuation_idx in range(continuation_length):
             outputs = self.model(
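Note on this hunk: when the caller passes no attention mask, the code now falls back to an all-ones mask covering the existing tokens, so the thought-generation loop below can grow the mask by one column per generated token without re-checking for None (the bare -/+ pairs appear to be whitespace-only changes to blank lines). A minimal sketch of that pattern outside the model; the tensor sizes and the loop count of 3 are illustrative, not taken from the repository:

import torch

batch_size, seq_len = 2, 5
input_ids = torch.randint(0, 100, (batch_size, seq_len))
attention_mask = None  # caller supplied no mask

# Default to an all-ones mask covering the existing tokens
if attention_mask is None:
    attention_mask = torch.ones((batch_size, seq_len)).to(input_ids.device)

# Each generated continuation token appends one column of ones
for _ in range(3):
    attention_mask = torch.cat(
        [attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1
    )

print(attention_mask.shape)  # torch.Size([2, 8])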
@@ -1057,59 +1059,57 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             next_token_id = torch.argmax(next_token_logits, dim=-1)

             # Append the generated token to the input sequence
-            # input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
+            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
             seq_len += 1

             # Update the attention mask
-            if attention_mask is not None:
-                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)

         # Append the end thought token to the input sequence
-        end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
-        input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
-        seq_len += 1
+        end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+        seq_len += 1

-        # Update the attention mask
-        if attention_mask is not None:
+        # Update the attention mask
         attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)

-        # Get the hidden states before and after the thought
-        outputs_before = self.model(
-            input_ids=original_input_ids,
-            attention_mask=original_attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states_before = outputs_before[0][:, -1:, :]
+        # Get the hidden states before and after the thought
+        outputs_before = self.model(
+            input_ids=original_input_ids,
+            attention_mask=original_attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_before = outputs_before[0][:, -1:, :]

-        # two new tokens: last continuation token and end thought token
-        outputs_after = self.model(
-            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
-            attention_mask=torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1),
-            position_ids=position_ids,
-            past_key_values=new_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states_after = outputs_after[0][:, -1:, :]
+        # two new tokens: last continuation token and end thought token
+        outputs_after = self.model(
+            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
+            attention_mask=attention_mask[:, -2:],
+            position_ids=position_ids,
+            past_key_values=new_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_after = outputs_after[0][:, -1:, :]

-        # Apply the talk head to get the mixing weight
-        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
+        # Apply the talk head to get the mixing weight
+        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))

-        # Apply the mixing weight to the hidden states
-        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
+        # Apply the mixing weight to the hidden states
+        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after

-        # Apply the language model head to get the final logits
-        logits = self.lm_head(mixed_hidden_states)
-        return logits
+        # Apply the language model head to get the final logits
+        logits = self.lm_head(mixed_hidden_states)
+        return logits

     @torch.no_grad()
     def generate(
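Note on this hunk: several of the paired removals and re-additions appear to differ only in whitespace. The substantive edits are that the generated continuation token is now actually appended to input_ids (the line was previously commented out), the per-step and post-loop mask updates drop their `if attention_mask is not None` guards (the default mask added in the first hunk makes them unnecessary), and the post-thought forward pass receives attention_mask[:, -2:] instead of rebuilding a two-column mask. Since the last two mask columns were just appended as ones, the slice selects the same values the old concatenation built. A small illustration of that equivalence; the shapes here are made up for the example:

import torch

batch_size = 2
attention_mask = torch.ones((batch_size, 7))  # mask after the end-thought column was appended

# Old: take the last column and append a fresh column of ones
old_mask = torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)

# New: the end-thought column is already present, so just slice the last two columns
new_mask = attention_mask[:, -2:]

assert torch.equal(old_mask, new_mask)  # both are a (batch_size, 2) tensor of ones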
 
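The tail of the second hunk is where the "thought" actually affects prediction: the base model is run once on the original inputs and once on the two new tokens (the last continuation token and <|endthought|>, reusing the cached key/values), and the talk head turns the two last-token hidden states into a mixing weight that interpolates between them before the LM head produces logits. A self-contained sketch of that mixing step, assuming for illustration a talk head that maps the concatenated states to a weight in [0, 1]; the real talk_head and lm_head shapes come from the model's config:

import torch
import torch.nn as nn

hidden_size, vocab_size, batch_size = 16, 100, 2

# Stand-ins for self.talk_head[0] and self.lm_head (shapes assumed for the example)
talk_head = nn.Sequential(nn.Linear(2 * hidden_size, 1), nn.Sigmoid())
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Last-token hidden states before and after the thought, shape (batch, 1, hidden)
hidden_states_before = torch.randn(batch_size, 1, hidden_size)
hidden_states_after = torch.randn(batch_size, 1, hidden_size)

# The talk head predicts how much the post-thought state should contribute
mixing_weight = talk_head(torch.cat([hidden_states_before, hidden_states_after], dim=-1))

# Interpolate the two states, then project to vocabulary logits
mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
logits = lm_head(mixed_hidden_states)

print(logits.shape)  # torch.Size([2, 1, 100])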