Crystalcareai
committed on
Update modeling_quiet.py
modeling_quiet.py +48 -48
modeling_quiet.py
CHANGED
@@ -1111,54 +1111,54 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits
 
-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        max_new_tokens: Optional[int] = None,
-        temperature: float = 1.0,
-        **kwargs,
-    ):
-        if isinstance(input_ids, str):
-            input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
-
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
-
-        batch_size, seq_len = input_ids.shape
-        max_length = seq_len + max_new_tokens if max_new_tokens is not None else self.config.max_length
-
-        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
-        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
-
-        past_key_values = None
-        hidden_states = None
-        all_hidden_states = ()
-
-        for _ in range(max_length - seq_len):
-            logits = self.infer(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_values=past_key_values,
-                inputs_embeds=hidden_states,
-                use_cache=True,
-                output_attentions=False,
-                output_hidden_states=False,
-                return_dict=False,
-            )
-
-            next_token_logits = logits[:, -1, :] / temperature
-            next_token_id = torch.argmax(next_token_logits, dim=-1)
-
-            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
-            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)], dim=-1)
-            position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1)
-
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        return input_ids, all_hidden_states
+    # @torch.no_grad()
+    # def generate(
+    #     self,
+    #     input_ids: torch.LongTensor,
+    #     attention_mask: Optional[torch.Tensor] = None,
+    #     max_new_tokens: Optional[int] = None,
+    #     temperature: float = 1.0,
+    #     **kwargs,
+    # ):
+    #     if isinstance(input_ids, str):
+    #         input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
+
+    #     if attention_mask is None:
+    #         attention_mask = torch.ones_like(input_ids)
+
+    #     batch_size, seq_len = input_ids.shape
+    #     max_length = seq_len + max_new_tokens if max_new_tokens is not None else self.config.max_length
+
+    #     position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
+    #     position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+    #     past_key_values = None
+    #     hidden_states = None
+    #     all_hidden_states = ()
+
+    #     for _ in range(max_length - seq_len):
+    #         logits = self.infer(
+    #             input_ids=input_ids,
+    #             attention_mask=attention_mask,
+    #             position_ids=position_ids,
+    #             past_key_values=past_key_values,
+    #             inputs_embeds=hidden_states,
+    #             use_cache=True,
+    #             output_attentions=False,
+    #             output_hidden_states=False,
+    #             return_dict=False,
+    #         )
+
+    #         next_token_logits = logits[:, -1, :] / temperature
+    #         next_token_id = torch.argmax(next_token_logits, dim=-1)
+
+    #         input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
+    #         attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)], dim=-1)
+    #         position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1)
+
+    #         all_hidden_states = all_hidden_states + (hidden_states,)
+
+    #     return input_ids, all_hidden_states
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
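Since QuietForCausalLM inherits from GenerationMixin (visible in the hunk header), commenting out the custom generate() presumably leaves text generation to the standard transformers generation loop, which drives the model's forward() pass. The commented-out loop did greedy decoding (argmax over temperature-scaled logits) via self.infer; do_sample=False in the standard API is the closest equivalent, modulo whatever the custom infer() does internally. Below is a minimal usage sketch, not part of this commit; the checkpoint id is a hypothetical placeholder and loading with trust_remote_code=True is an assumption about how this repo's custom modeling code is meant to be used.

# Minimal sketch, not part of the commit. Assumptions: placeholder checkpoint id,
# custom code loaded via trust_remote_code=True.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/quiet-star-checkpoint"  # hypothetical placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
)

inputs = tokenizer("The capital of France is", return_tensors="pt")

# With the custom generate() commented out, this call presumably dispatches to
# GenerationMixin.generate(), i.e. the stock generation loop over forward().
output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))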