Crystalcareai committed
Commit 7b223b3 · verified · 1 Parent(s): de08a5d

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py +48 -48
modeling_quiet.py CHANGED
@@ -1111,54 +1111,54 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits
 
-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        max_new_tokens: Optional[int] = None,
-        temperature: float = 1.0,
-        **kwargs,
-    ):
-        if isinstance(input_ids, str):
-            input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
-
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
-
-        batch_size, seq_len = input_ids.shape
-        max_length = seq_len + max_new_tokens if max_new_tokens is not None else self.config.max_length
-
-        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
-        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
-
-        past_key_values = None
-        hidden_states = None
-        all_hidden_states = ()
-
-        for _ in range(max_length - seq_len):
-            logits = self.infer(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_values=past_key_values,
-                inputs_embeds=hidden_states,
-                use_cache=True,
-                output_attentions=False,
-                output_hidden_states=False,
-                return_dict=False,
-            )
-
-            next_token_logits = logits[:, -1, :] / temperature
-            next_token_id = torch.argmax(next_token_logits, dim=-1)
-
-            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
-            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)], dim=-1)
-            position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1)
-
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        return input_ids, all_hidden_states
+    # @torch.no_grad()
+    # def generate(
+    #     self,
+    #     input_ids: torch.LongTensor,
+    #     attention_mask: Optional[torch.Tensor] = None,
+    #     max_new_tokens: Optional[int] = None,
+    #     temperature: float = 1.0,
+    #     **kwargs,
+    # ):
+    #     if isinstance(input_ids, str):
+    #         input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
+
+    #     if attention_mask is None:
+    #         attention_mask = torch.ones_like(input_ids)
+
+    #     batch_size, seq_len = input_ids.shape
+    #     max_length = seq_len + max_new_tokens if max_new_tokens is not None else self.config.max_length
+
+    #     position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
+    #     position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+    #     past_key_values = None
+    #     hidden_states = None
+    #     all_hidden_states = ()
+
+    #     for _ in range(max_length - seq_len):
+    #         logits = self.infer(
+    #             input_ids=input_ids,
+    #             attention_mask=attention_mask,
+    #             position_ids=position_ids,
+    #             past_key_values=past_key_values,
+    #             inputs_embeds=hidden_states,
+    #             use_cache=True,
+    #             output_attentions=False,
+    #             output_hidden_states=False,
+    #             return_dict=False,
+    #         )
+
+    #         next_token_logits = logits[:, -1, :] / temperature
+    #         next_token_id = torch.argmax(next_token_logits, dim=-1)
+
+    #         input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
+    #         attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)], dim=-1)
+    #         position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1)
+
+    #         all_hidden_states = all_hidden_states + (hidden_states,)
+
+    #     return input_ids, all_hidden_states
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
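
For reference: the block being commented out is a temperature-scaled greedy decoding loop over the model's custom infer method. Since QuietForCausalLM also inherits from GenerationMixin (visible in the hunk header), callers presumably fall back to the stock transformers generate after this change. Below is a minimal standalone sketch of the same greedy-decoding pattern written against a generic Hugging Face causal LM; the "gpt2" checkpoint and the prompt are illustrative stand-ins, not part of this repository.

# Minimal sketch of the greedy-decoding pattern from the commented-out method,
# using a generic causal LM's forward pass instead of QuietForCausalLM.infer.
# "gpt2" and the prompt are illustrative stand-ins.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

inputs = tokenizer("The quick brown fox", return_tensors="pt")
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask

max_new_tokens = 20
temperature = 1.0

with torch.no_grad():
    for _ in range(max_new_tokens):
        # Full forward pass each step (no KV cache), matching the removed loop,
        # which passed past_key_values=None on every iteration.
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

        # Temperature-scaled greedy selection of the next token.
        next_token_logits = logits[:, -1, :] / temperature
        next_token_id = torch.argmax(next_token_logits, dim=-1)

        # Append the chosen token and extend the attention mask by one position.
        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones((attention_mask.shape[0], 1), dtype=attention_mask.dtype)],
            dim=-1,
        )

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))

Like the removed loop, this sketch recomputes the full forward pass at every step rather than reusing a KV cache, which is simple but quadratic in sequence length.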