Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +43 -43
modeling_quiet.py
CHANGED
@@ -1024,14 +1024,16 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1024 |
# Update the attention mask
|
1025 |
if attention_mask is not None:
|
1026 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
|
|
|
|
1027 |
|
1028 |
# Generate the continuation
|
1029 |
continuation_length = self.n_ahead - 2
|
1030 |
new_key_values = past_key_values
|
1031 |
-
|
1032 |
# Initialize next_token_id with a default value
|
1033 |
next_token_id = torch.zeros(batch_size, dtype=torch.long).to(input_ids.device)
|
1034 |
-
|
1035 |
start_time = time.time()
|
1036 |
for continuation_idx in range(continuation_length):
|
1037 |
outputs = self.model(
|
@@ -1057,59 +1059,57 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1057 |
next_token_id = torch.argmax(next_token_logits, dim=-1)
|
1058 |
|
1059 |
# Append the generated token to the input sequence
|
1060 |
-
|
1061 |
seq_len += 1
|
1062 |
|
1063 |
# Update the attention mask
|
1064 |
-
|
1065 |
-
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
1066 |
|
1067 |
# Append the end thought token to the input sequence
|
1068 |
-
|
1069 |
-
|
1070 |
-
|
1071 |
|
1072 |
-
|
1073 |
-
if attention_mask is not None:
|
1074 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
1075 |
|
1076 |
-
|
1077 |
-
|
1078 |
-
|
1079 |
-
|
1080 |
-
|
1081 |
-
|
1082 |
-
|
1083 |
-
|
1084 |
-
|
1085 |
-
|
1086 |
-
|
1087 |
-
|
1088 |
-
|
1089 |
|
1090 |
-
|
1091 |
-
|
1092 |
-
|
1093 |
-
|
1094 |
-
|
1095 |
-
|
1096 |
-
|
1097 |
-
|
1098 |
-
|
1099 |
-
|
1100 |
-
|
1101 |
-
|
1102 |
-
|
1103 |
|
1104 |
-
|
1105 |
-
|
1106 |
|
1107 |
-
|
1108 |
-
|
1109 |
|
1110 |
-
|
1111 |
-
|
1112 |
-
|
1113 |
|
1114 |
@torch.no_grad()
|
1115 |
def generate(
|
|
|
1024 |
# Update the attention mask
|
1025 |
if attention_mask is not None:
|
1026 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
1027 |
+
else:
|
1028 |
+
attention_mask = torch.ones((batch_size, seq_len)).to(input_ids.device)
|
1029 |
|
1030 |
# Generate the continuation
|
1031 |
continuation_length = self.n_ahead - 2
|
1032 |
new_key_values = past_key_values
|
1033 |
+
|
1034 |
# Initialize next_token_id with a default value
|
1035 |
next_token_id = torch.zeros(batch_size, dtype=torch.long).to(input_ids.device)
|
1036 |
+
|
1037 |
start_time = time.time()
|
1038 |
for continuation_idx in range(continuation_length):
|
1039 |
outputs = self.model(
|
|
|
1059 |
next_token_id = torch.argmax(next_token_logits, dim=-1)
|
1060 |
|
1061 |
# Append the generated token to the input sequence
|
1062 |
+
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
|
1063 |
seq_len += 1
|
1064 |
|
1065 |
# Update the attention mask
|
1066 |
+
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
|
|
1067 |
|
1068 |
# Append the end thought token to the input sequence
|
1069 |
+
end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
|
1070 |
+
input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
|
1071 |
+
seq_len += 1
|
1072 |
|
1073 |
+
# Update the attention mask
|
|
|
1074 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
1075 |
|
1076 |
+
# Get the hidden states before and after the thought
|
1077 |
+
outputs_before = self.model(
|
1078 |
+
input_ids=original_input_ids,
|
1079 |
+
attention_mask=original_attention_mask,
|
1080 |
+
position_ids=position_ids,
|
1081 |
+
past_key_values=past_key_values,
|
1082 |
+
inputs_embeds=inputs_embeds,
|
1083 |
+
use_cache=use_cache,
|
1084 |
+
output_attentions=output_attentions,
|
1085 |
+
output_hidden_states=output_hidden_states,
|
1086 |
+
return_dict=return_dict,
|
1087 |
+
)
|
1088 |
+
hidden_states_before = outputs_before[0][:, -1:, :]
|
1089 |
|
1090 |
+
# two new tokens: last continuation token and end thought token
|
1091 |
+
outputs_after = self.model(
|
1092 |
+
input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
|
1093 |
+
attention_mask=attention_mask[:, -2:],
|
1094 |
+
position_ids=position_ids,
|
1095 |
+
past_key_values=new_key_values,
|
1096 |
+
inputs_embeds=inputs_embeds,
|
1097 |
+
use_cache=use_cache,
|
1098 |
+
output_attentions=output_attentions,
|
1099 |
+
output_hidden_states=output_hidden_states,
|
1100 |
+
return_dict=return_dict,
|
1101 |
+
)
|
1102 |
+
hidden_states_after = outputs_after[0][:, -1:, :]
|
1103 |
|
1104 |
+
# Apply the talk head to get the mixing weight
|
1105 |
+
mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
|
1106 |
|
1107 |
+
# Apply the mixing weight to the hidden states
|
1108 |
+
mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
|
1109 |
|
1110 |
+
# Apply the language model head to get the final logits
|
1111 |
+
logits = self.lm_head(mixed_hidden_states)
|
1112 |
+
return logits
|
1113 |
|
1114 |
@torch.no_grad()
|
1115 |
def generate(
|