Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +121 -113
modeling_quiet.py
CHANGED
@@ -1024,16 +1024,14 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1024 |
# Update the attention mask
|
1025 |
if attention_mask is not None:
|
1026 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
1027 |
-
else:
|
1028 |
-
attention_mask = torch.ones((batch_size, seq_len)).to(input_ids.device)
|
1029 |
|
1030 |
# Generate the continuation
|
1031 |
continuation_length = self.n_ahead - 2
|
1032 |
new_key_values = past_key_values
|
1033 |
-
|
1034 |
# Initialize next_token_id with a default value
|
1035 |
next_token_id = torch.zeros(batch_size, dtype=torch.long).to(input_ids.device)
|
1036 |
-
|
1037 |
start_time = time.time()
|
1038 |
for continuation_idx in range(continuation_length):
|
1039 |
outputs = self.model(
|
@@ -1059,106 +1057,79 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1059 |
next_token_id = torch.argmax(next_token_logits, dim=-1)
|
1060 |
|
1061 |
# Append the generated token to the input sequence
|
1062 |
-
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
|
1063 |
seq_len += 1
|
1064 |
|
1065 |
# Update the attention mask
|
1066 |
-
|
|
|
1067 |
|
1068 |
# Append the end thought token to the input sequence
|
1069 |
-
|
1070 |
-
|
1071 |
-
|
1072 |
|
1073 |
-
|
|
|
1074 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
1075 |
|
1076 |
-
|
1077 |
-
|
1078 |
-
|
1079 |
-
|
1080 |
-
|
1081 |
-
|
1082 |
-
|
1083 |
-
|
1084 |
-
|
1085 |
-
|
1086 |
-
|
1087 |
-
|
1088 |
-
|
1089 |
|
1090 |
-
|
1091 |
-
|
1092 |
-
|
1093 |
-
|
1094 |
-
|
1095 |
-
|
1096 |
-
|
1097 |
-
|
1098 |
-
|
1099 |
-
|
1100 |
-
|
1101 |
-
|
1102 |
-
|
1103 |
|
1104 |
-
|
1105 |
-
|
1106 |
|
1107 |
-
|
1108 |
-
|
1109 |
|
1110 |
-
|
1111 |
-
|
1112 |
-
|
1113 |
|
1114 |
-
|
1115 |
-
|
1116 |
-
|
1117 |
-
|
1118 |
-
|
1119 |
-
|
1120 |
-
|
1121 |
-
|
1122 |
-
|
1123 |
-
|
1124 |
-
|
1125 |
-
|
1126 |
-
|
1127 |
-
|
1128 |
-
|
1129 |
-
|
1130 |
-
|
1131 |
-
|
1132 |
-
|
1133 |
-
# position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
1134 |
-
|
1135 |
-
# past_key_values = None
|
1136 |
-
# hidden_states = None
|
1137 |
-
# all_hidden_states = ()
|
1138 |
-
|
1139 |
-
# for _ in range(max_length - seq_len):
|
1140 |
-
# logits = self.infer(
|
1141 |
-
# input_ids=input_ids,
|
1142 |
-
# attention_mask=attention_mask,
|
1143 |
-
# position_ids=position_ids,
|
1144 |
-
# past_key_values=past_key_values,
|
1145 |
-
# inputs_embeds=hidden_states,
|
1146 |
-
# use_cache=True,
|
1147 |
-
# output_attentions=False,
|
1148 |
-
# output_hidden_states=False,
|
1149 |
-
# return_dict=False,
|
1150 |
-
# )
|
1151 |
-
|
1152 |
-
# next_token_logits = logits[:, -1, :] / temperature
|
1153 |
-
# next_token_id = torch.argmax(next_token_logits, dim=-1)
|
1154 |
-
|
1155 |
-
# input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
|
1156 |
-
# attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)], dim=-1)
|
1157 |
-
# position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1)
|
1158 |
-
|
1159 |
-
# all_hidden_states = all_hidden_states + (hidden_states,)
|
1160 |
-
|
1161 |
-
# return input_ids, all_hidden_states
|
1162 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1163 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1164 |
def forward(
|
@@ -1891,16 +1862,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1891 |
torch.cuda.empty_cache()
|
1892 |
|
1893 |
|
1894 |
-
return
|
1895 |
-
|
1896 |
-
|
1897 |
-
|
1898 |
-
|
1899 |
-
|
1900 |
-
use_cache=use_cache,
|
1901 |
-
output_attentions=output_attentions,
|
1902 |
-
output_hidden_states=output_hidden_states,
|
1903 |
-
return_dict=return_dict,
|
1904 |
)
|
1905 |
|
1906 |
|
@@ -1908,18 +1875,59 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1908 |
def prepare_inputs_for_generation(
|
1909 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
1910 |
):
|
1911 |
-
|
1912 |
-
|
1913 |
-
|
1914 |
-
|
1915 |
-
|
1916 |
-
|
1917 |
-
|
1918 |
-
|
1919 |
-
|
1920 |
-
|
1921 |
-
|
1922 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1923 |
|
1924 |
@staticmethod
|
1925 |
def _reorder_cache(past_key_values, beam_idx):
|
|
|
1024 |
# Update the attention mask
|
1025 |
if attention_mask is not None:
|
1026 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
|
|
|
|
1027 |
|
1028 |
# Generate the continuation
|
1029 |
continuation_length = self.n_ahead - 2
|
1030 |
new_key_values = past_key_values
|
1031 |
+
|
1032 |
# Initialize next_token_id with a default value
|
1033 |
next_token_id = torch.zeros(batch_size, dtype=torch.long).to(input_ids.device)
|
1034 |
+
|
1035 |
start_time = time.time()
|
1036 |
for continuation_idx in range(continuation_length):
|
1037 |
outputs = self.model(
|
|
|
1057 |
next_token_id = torch.argmax(next_token_logits, dim=-1)
|
1058 |
|
1059 |
# Append the generated token to the input sequence
|
1060 |
+
# input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
|
1061 |
seq_len += 1
|
1062 |
|
1063 |
# Update the attention mask
|
1064 |
+
if attention_mask is not None:
|
1065 |
+
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
1066 |
|
1067 |
# Append the end thought token to the input sequence
|
1068 |
+
end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
|
1069 |
+
input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
|
1070 |
+
seq_len += 1
|
1071 |
|
1072 |
+
# Update the attention mask
|
1073 |
+
if attention_mask is not None:
|
1074 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
1075 |
|
1076 |
+
# Get the hidden states before and after the thought
|
1077 |
+
outputs_before = self.model(
|
1078 |
+
input_ids=original_input_ids,
|
1079 |
+
attention_mask=original_attention_mask,
|
1080 |
+
position_ids=position_ids,
|
1081 |
+
past_key_values=past_key_values,
|
1082 |
+
inputs_embeds=inputs_embeds,
|
1083 |
+
use_cache=use_cache,
|
1084 |
+
output_attentions=output_attentions,
|
1085 |
+
output_hidden_states=output_hidden_states,
|
1086 |
+
return_dict=return_dict,
|
1087 |
+
)
|
1088 |
+
hidden_states_before = outputs_before[0][:, -1:, :]
|
1089 |
|
1090 |
+
# two new tokens: last continuation token and end thought token
|
1091 |
+
outputs_after = self.model(
|
1092 |
+
input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
|
1093 |
+
attention_mask=torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1),
|
1094 |
+
position_ids=position_ids,
|
1095 |
+
past_key_values=new_key_values,
|
1096 |
+
inputs_embeds=inputs_embeds,
|
1097 |
+
use_cache=use_cache,
|
1098 |
+
output_attentions=output_attentions,
|
1099 |
+
output_hidden_states=output_hidden_states,
|
1100 |
+
return_dict=return_dict,
|
1101 |
+
)
|
1102 |
+
hidden_states_after = outputs_after[0][:, -1:, :]
|
1103 |
|
1104 |
+
# Apply the talk head to get the mixing weight
|
1105 |
+
mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
|
1106 |
|
1107 |
+
# Apply the mixing weight to the hidden states
|
1108 |
+
mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
|
1109 |
|
1110 |
+
# Apply the language model head to get the final logits
|
1111 |
+
logits = self.lm_head(mixed_hidden_states)
|
1112 |
+
return logits
|
1113 |
|
1114 |
+
@torch.no_grad()
|
1115 |
+
def generate(
|
1116 |
+
self,
|
1117 |
+
input_ids: torch.LongTensor = torch.LongTensor(),
|
1118 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1119 |
+
max_new_tokens: Optional[int] = None,
|
1120 |
+
temperature: float = 1.1,
|
1121 |
+
**kwargs,
|
1122 |
+
):
|
1123 |
+
if isinstance(input_ids, str):
|
1124 |
+
input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
|
1125 |
+
|
1126 |
+
if attention_mask is None:
|
1127 |
+
# Create a default attention mask if not provided
|
1128 |
+
attention_mask = torch.ones_like(input_ids)
|
1129 |
+
|
1130 |
+
from .generate import generate
|
1131 |
+
return generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
|
1132 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1133 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1134 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1135 |
def forward(
|
|
|
1862 |
torch.cuda.empty_cache()
|
1863 |
|
1864 |
|
1865 |
+
return CausalLMOutputWithPast(
|
1866 |
+
loss=loss if loss is not None else None,
|
1867 |
+
logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
|
1868 |
+
past_key_values=outputs.past_key_values,
|
1869 |
+
hidden_states=outputs.hidden_states,
|
1870 |
+
attentions=outputs.attentions,
|
|
|
|
|
|
|
|
|
1871 |
)
|
1872 |
|
1873 |
|
|
|
1875 |
def prepare_inputs_for_generation(
|
1876 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
1877 |
):
|
1878 |
+
# Omit tokens covered by past_key_values
|
1879 |
+
if past_key_values is not None:
|
1880 |
+
if isinstance(past_key_values, Cache):
|
1881 |
+
cache_length = past_key_values.get_seq_length()
|
1882 |
+
past_length = past_key_values.seen_tokens
|
1883 |
+
max_cache_length = past_key_values.get_max_length()
|
1884 |
+
else:
|
1885 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
1886 |
+
max_cache_length = None
|
1887 |
+
|
1888 |
+
# Keep only the unprocessed tokens:
|
1889 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
1890 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
|
1891 |
+
# input)
|
1892 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
1893 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
1894 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
1895 |
+
# input_ids based on the past_length.
|
1896 |
+
elif past_length < input_ids.shape[1]:
|
1897 |
+
input_ids = input_ids[:, past_length:]
|
1898 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
1899 |
+
|
1900 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
1901 |
+
if (
|
1902 |
+
max_cache_length is not None
|
1903 |
+
and attention_mask is not None
|
1904 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
1905 |
+
):
|
1906 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
1907 |
+
|
1908 |
+
position_ids = kwargs.get("position_ids", None)
|
1909 |
+
if attention_mask is not None and position_ids is None:
|
1910 |
+
# create position_ids on the fly for batch generation
|
1911 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1912 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1913 |
+
if past_key_values:
|
1914 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1915 |
+
|
1916 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1917 |
+
if inputs_embeds is not None and past_key_values is None:
|
1918 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1919 |
+
else:
|
1920 |
+
model_inputs = {"input_ids": input_ids}
|
1921 |
+
|
1922 |
+
model_inputs.update(
|
1923 |
+
{
|
1924 |
+
"position_ids": position_ids,
|
1925 |
+
"past_key_values": past_key_values,
|
1926 |
+
"use_cache": kwargs.get("use_cache"),
|
1927 |
+
"attention_mask": attention_mask,
|
1928 |
+
}
|
1929 |
+
)
|
1930 |
+
return model_inputs
|
1931 |
|
1932 |
@staticmethod
|
1933 |
def _reorder_cache(past_key_values, beam_idx):
|