duzx16 commited on
Commit
66ecaf1
·
1 Parent(s): 833de79

Fix use_cache=False

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +95 -3
modeling_chatglm.py CHANGED
@@ -11,12 +11,14 @@ import torch.utils.checkpoint
11
  import torch.nn.functional as F
12
  from torch import nn
13
  from torch.nn import CrossEntropyLoss, LayerNorm
 
14
  from torch.nn.utils import skip_init
15
  from typing import Optional, Tuple, Union, List, Callable, Dict, Any
16
 
17
  from transformers.modeling_outputs import (
18
  BaseModelOutputWithPast,
19
  CausalLMOutputWithPast,
 
20
  )
21
  from transformers.modeling_utils import PreTrainedModel
22
  from transformers.utils import logging
@@ -895,6 +897,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
895
  past_key_values: Optional[torch.Tensor] = None,
896
  attention_mask: Optional[torch.Tensor] = None,
897
  position_ids: Optional[torch.Tensor] = None,
 
898
  is_first_forward: bool = True,
899
  **kwargs
900
  ) -> dict:
@@ -902,14 +905,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
902
  if position_ids is None:
903
  position_ids = self.get_position_ids(input_ids, device=input_ids.device)
904
  if not is_first_forward:
905
- position_ids = position_ids[..., -1:]
906
- input_ids = input_ids[:, -1:]
 
907
  return {
908
  "input_ids": input_ids,
909
  "past_key_values": past_key_values,
910
  "position_ids": position_ids,
911
  "attention_mask": attention_mask,
912
- "return_last_logit": True
 
913
  }
914
 
915
  def forward(
@@ -1086,6 +1091,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1086
  generation_config = self.generation_config
1087
  generation_config = copy.deepcopy(generation_config)
1088
  model_kwargs = generation_config.update(**kwargs)
 
1089
  bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1090
 
1091
  if isinstance(eos_token_id, int):
@@ -1191,3 +1197,89 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1191
  self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
1192
  **kwargs)
1193
  return self
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  import torch.nn.functional as F
12
  from torch import nn
13
  from torch.nn import CrossEntropyLoss, LayerNorm
14
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
15
  from torch.nn.utils import skip_init
16
  from typing import Optional, Tuple, Union, List, Callable, Dict, Any
17
 
18
  from transformers.modeling_outputs import (
19
  BaseModelOutputWithPast,
20
  CausalLMOutputWithPast,
21
+ SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
  from transformers.utils import logging
 
897
  past_key_values: Optional[torch.Tensor] = None,
898
  attention_mask: Optional[torch.Tensor] = None,
899
  position_ids: Optional[torch.Tensor] = None,
900
+ use_cache: Optional[bool] = None,
901
  is_first_forward: bool = True,
902
  **kwargs
903
  ) -> dict:
 
905
  if position_ids is None:
906
  position_ids = self.get_position_ids(input_ids, device=input_ids.device)
907
  if not is_first_forward:
908
+ if past_key_values is not None:
909
+ position_ids = position_ids[..., -1:]
910
+ input_ids = input_ids[:, -1:]
911
  return {
912
  "input_ids": input_ids,
913
  "past_key_values": past_key_values,
914
  "position_ids": position_ids,
915
  "attention_mask": attention_mask,
916
+ "return_last_logit": True,
917
+ "use_cache": use_cache
918
  }
919
 
920
  def forward(
 
1091
  generation_config = self.generation_config
1092
  generation_config = copy.deepcopy(generation_config)
1093
  model_kwargs = generation_config.update(**kwargs)
1094
+ model_kwargs["use_cache"] = generation_config.use_cache
1095
  bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1096
 
1097
  if isinstance(eos_token_id, int):
 
1197
  self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
1198
  **kwargs)
1199
  return self
1200
+
1201
+
1202
+ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1203
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1204
+ super().__init__(config)
1205
+
1206
+ self.num_labels = config.num_labels
1207
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1208
+
1209
+ self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
1210
+ if config.classifier_dropout is not None:
1211
+ self.dropout = nn.Dropout(config.classifier_dropout)
1212
+ else:
1213
+ self.dropout = None
1214
+ self.config = config
1215
+
1216
+ if self.config.quantization_bit:
1217
+ self.quantize(self.config.quantization_bit, empty_init=True)
1218
+
1219
+ def forward(
1220
+ self,
1221
+ input_ids: Optional[torch.LongTensor] = None,
1222
+ position_ids: Optional[torch.LongTensor] = None,
1223
+ attention_mask: Optional[torch.Tensor] = None,
1224
+ full_attention_mask: Optional[torch.Tensor] = None,
1225
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1226
+ inputs_embeds: Optional[torch.LongTensor] = None,
1227
+ labels: Optional[torch.LongTensor] = None,
1228
+ use_cache: Optional[bool] = None,
1229
+ output_hidden_states: Optional[bool] = None,
1230
+ return_dict: Optional[bool] = None,
1231
+ ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
1232
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1233
+
1234
+ transformer_outputs = self.transformer(
1235
+ input_ids=input_ids,
1236
+ position_ids=position_ids,
1237
+ attention_mask=attention_mask,
1238
+ full_attention_mask=full_attention_mask,
1239
+ past_key_values=past_key_values,
1240
+ inputs_embeds=inputs_embeds,
1241
+ use_cache=use_cache,
1242
+ output_hidden_states=output_hidden_states,
1243
+ return_dict=return_dict,
1244
+ )
1245
+
1246
+ hidden_states = transformer_outputs[0]
1247
+ pooled_hidden_states = hidden_states[-1]
1248
+ if self.dropout is not None:
1249
+ pooled_hidden_states = self.dropout(pooled_hidden_states)
1250
+ logits = self.classifier_head(pooled_hidden_states)
1251
+
1252
+ loss = None
1253
+ if labels is not None:
1254
+ if self.config.problem_type is None:
1255
+ if self.num_labels == 1:
1256
+ self.config.problem_type = "regression"
1257
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1258
+ self.config.problem_type = "single_label_classification"
1259
+ else:
1260
+ self.config.problem_type = "multi_label_classification"
1261
+
1262
+ if self.config.problem_type == "regression":
1263
+ loss_fct = MSELoss()
1264
+ if self.num_labels == 1:
1265
+ loss = loss_fct(logits.squeeze().float(), labels.squeeze())
1266
+ else:
1267
+ loss = loss_fct(logits.float(), labels)
1268
+ elif self.config.problem_type == "single_label_classification":
1269
+ loss_fct = CrossEntropyLoss()
1270
+ loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
1271
+ elif self.config.problem_type == "multi_label_classification":
1272
+ loss_fct = BCEWithLogitsLoss()
1273
+ loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
1274
+
1275
+ if not return_dict:
1276
+ output = (logits,) + transformer_outputs[1:]
1277
+ return ((loss,) + output) if loss is not None else output
1278
+
1279
+ return SequenceClassifierOutputWithPast(
1280
+ loss=loss,
1281
+ logits=logits,
1282
+ past_key_values=transformer_outputs.past_key_values,
1283
+ hidden_states=transformer_outputs.hidden_states,
1284
+ attentions=transformer_outputs.attentions,
1285
+ )