duzx16 commited on
Commit
74d61a6
·
1 Parent(s): 4d01789

Add gradient checkpointing

Browse files
config.json CHANGED
@@ -36,5 +36,5 @@
36
  "transformers_version": "4.27.1",
37
  "tie_word_embeddings": false,
38
  "eos_token_id": 2,
39
- "pad_token_id": 2
40
  }
 
36
  "transformers_version": "4.27.1",
37
  "tie_word_embeddings": false,
38
  "eos_token_id": 2,
39
+ "pad_token_id": 0
40
  }
configuration_chatglm.py CHANGED
@@ -28,9 +28,12 @@ class ChatGLMConfig(PretrainedConfig):
28
  attention_softmax_in_fp32=True,
29
  fp32_residual_connection=False,
30
  quantization_bit=0,
 
 
31
  **kwargs
32
  ):
33
  self.num_layers = num_layers
 
34
  self.padded_vocab_size = padded_vocab_size
35
  self.hidden_size = hidden_size
36
  self.ffn_hidden_size = ffn_hidden_size
@@ -52,4 +55,6 @@ class ChatGLMConfig(PretrainedConfig):
52
  self.attention_softmax_in_fp32 = attention_softmax_in_fp32
53
  self.fp32_residual_connection = fp32_residual_connection
54
  self.quantization_bit = quantization_bit
 
 
55
  super().__init__(**kwargs)
 
28
  attention_softmax_in_fp32=True,
29
  fp32_residual_connection=False,
30
  quantization_bit=0,
31
+ pre_seq_len=None,
32
+ prefix_projection=False,
33
  **kwargs
34
  ):
35
  self.num_layers = num_layers
36
+ self.vocab_size = padded_vocab_size
37
  self.padded_vocab_size = padded_vocab_size
38
  self.hidden_size = hidden_size
39
  self.ffn_hidden_size = ffn_hidden_size
 
55
  self.attention_softmax_in_fp32 = attention_softmax_in_fp32
56
  self.fp32_residual_connection = fp32_residual_connection
57
  self.quantization_bit = quantization_bit
58
+ self.pre_seq_len = pre_seq_len
59
+ self.prefix_projection = prefix_projection
60
  super().__init__(**kwargs)
modeling_chatglm.py CHANGED
@@ -56,6 +56,36 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
56
  return scores
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def split_tensor_along_last_dim(
60
  tensor: torch.Tensor,
61
  num_partitions: int,
@@ -566,6 +596,8 @@ class GLMTransformer(torch.nn.Module):
566
  self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
567
  dtype=config.torch_dtype)
568
 
 
 
569
  def _get_layer(self, layer_number):
570
  return self.layers[layer_number]
571
 
@@ -577,6 +609,13 @@ class GLMTransformer(torch.nn.Module):
577
  if not kv_caches:
578
  kv_caches = [None for _ in range(self.num_layers)]
579
  presents = () if use_cache else None
 
 
 
 
 
 
 
580
  all_self_attentions = None
581
  all_hidden_states = () if output_hidden_states else None
582
  for index in range(self.num_layers):
@@ -584,14 +623,24 @@ class GLMTransformer(torch.nn.Module):
584
  all_hidden_states = all_hidden_states + (hidden_states,)
585
 
586
  layer = self._get_layer(index)
587
-
588
- hidden_states, kv_cache = layer(
589
- hidden_states,
590
- attention_mask,
591
- rotary_pos_emb,
592
- kv_cache=kv_caches[index],
593
- use_cache=use_cache
594
- )
 
 
 
 
 
 
 
 
 
 
595
  if use_cache:
596
  presents = presents + (kv_cache,)
597
 
@@ -645,7 +694,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
645
  return position_ids
646
 
647
  def _set_gradient_checkpointing(self, module, value=False):
648
- if isinstance(module, ChatGLMModel):
649
  module.gradient_checkpointing = value
650
 
651
 
@@ -700,11 +749,33 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
700
  self.encoder = init_method(GLMTransformer, config, **init_kwargs)
701
  self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
702
  dtype=config.torch_dtype, **init_kwargs)
703
- self.gradient_checkpointing = False
 
 
 
 
 
 
 
704
 
705
  def get_input_embeddings(self):
706
  return self.embedding.word_embeddings
707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
708
  def forward(
709
  self,
710
  input_ids,
@@ -740,6 +811,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
740
  rotary_pos_emb = rotary_pos_emb[None, :seq_length]
741
  rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
742
 
 
 
 
 
 
743
  # Run encoder.
744
  hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
745
  inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
@@ -913,10 +989,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
913
  return response
914
 
915
  def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
916
- prompt = ""
917
- for i, (old_query, response) in enumerate(history):
918
- prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
919
- prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
920
  inputs = tokenizer([prompt], return_tensors="pt")
921
  inputs = inputs.to(self.device)
922
  return inputs
@@ -933,7 +1006,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
933
  inputs = inputs.to(self.device)
934
  return inputs
935
 
936
-
937
  @torch.no_grad()
938
  def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
939
  do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
@@ -969,6 +1041,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
969
  inputs = self.build_stream_inputs(tokenizer, query, history=history)
970
  if past_key_values is not None:
971
  past_length = past_key_values[0][0].shape[0]
 
 
972
  inputs.position_ids += past_length
973
  attention_mask = inputs.attention_mask
974
  attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
 
56
  return scores
57
 
58
 
59
+ class PrefixEncoder(torch.nn.Module):
60
+ """
61
+ The torch.nn model to encode the prefix
62
+ Input shape: (batch-size, prefix-length)
63
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
64
+ """
65
+
66
+ def __init__(self, config):
67
+ super().__init__()
68
+ self.prefix_projection = config.prefix_projection
69
+ if self.prefix_projection:
70
+ # Use a two-layer MLP to encode the prefix
71
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
72
+ self.trans = torch.nn.Sequential(
73
+ torch.nn.Linear(config.hidden_size, config.hidden_size),
74
+ torch.nn.Tanh(),
75
+ torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
76
+ )
77
+ else:
78
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
79
+
80
+ def forward(self, prefix: torch.Tensor):
81
+ if self.prefix_projection:
82
+ prefix_tokens = self.embedding(prefix)
83
+ past_key_values = self.trans(prefix_tokens)
84
+ else:
85
+ past_key_values = self.embedding(prefix)
86
+ return past_key_values
87
+
88
+
89
  def split_tensor_along_last_dim(
90
  tensor: torch.Tensor,
91
  num_partitions: int,
 
596
  self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
597
  dtype=config.torch_dtype)
598
 
599
+ self.gradient_checkpointing = False
600
+
601
  def _get_layer(self, layer_number):
602
  return self.layers[layer_number]
603
 
 
609
  if not kv_caches:
610
  kv_caches = [None for _ in range(self.num_layers)]
611
  presents = () if use_cache else None
612
+ if self.gradient_checkpointing and self.training:
613
+ if use_cache:
614
+ logger.warning_once(
615
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
616
+ )
617
+ use_cache = False
618
+
619
  all_self_attentions = None
620
  all_hidden_states = () if output_hidden_states else None
621
  for index in range(self.num_layers):
 
623
  all_hidden_states = all_hidden_states + (hidden_states,)
624
 
625
  layer = self._get_layer(index)
626
+ if self.gradient_checkpointing and self.training:
627
+ layer_ret = torch.utils.checkpoint.checkpoint(
628
+ layer,
629
+ hidden_states,
630
+ attention_mask,
631
+ rotary_pos_emb,
632
+ kv_cache=kv_caches[index],
633
+ use_cache=use_cache
634
+ )
635
+ else:
636
+ layer_ret = layer(
637
+ hidden_states,
638
+ attention_mask,
639
+ rotary_pos_emb,
640
+ kv_cache=kv_caches[index],
641
+ use_cache=use_cache
642
+ )
643
+ hidden_states, kv_cache = layer_ret
644
  if use_cache:
645
  presents = presents + (kv_cache,)
646
 
 
694
  return position_ids
695
 
696
  def _set_gradient_checkpointing(self, module, value=False):
697
+ if isinstance(module, GLMTransformer):
698
  module.gradient_checkpointing = value
699
 
700
 
 
749
  self.encoder = init_method(GLMTransformer, config, **init_kwargs)
750
  self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
751
  dtype=config.torch_dtype, **init_kwargs)
752
+ self.pre_seq_len = config.pre_seq_len
753
+ self.prefix_projection = config.prefix_projection
754
+ if self.pre_seq_len is not None:
755
+ for param in self.parameters():
756
+ param.requires_grad = False
757
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
758
+ self.prefix_encoder = PrefixEncoder(config)
759
+ self.dropout = torch.nn.Dropout(0.1)
760
 
761
  def get_input_embeddings(self):
762
  return self.embedding.word_embeddings
763
 
764
+ def get_prompt(self, batch_size, device, dtype=torch.half):
765
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
766
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
767
+ past_key_values = past_key_values.view(
768
+ batch_size,
769
+ self.pre_seq_len,
770
+ self.num_layers * 2,
771
+ self.num_attention_heads,
772
+ self.hidden_size // self.num_attention_heads
773
+ )
774
+ # seq_len, b, nh, hidden_size
775
+ past_key_values = self.dropout(past_key_values)
776
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
777
+ return past_key_values
778
+
779
  def forward(
780
  self,
781
  input_ids,
 
811
  rotary_pos_emb = rotary_pos_emb[None, :seq_length]
812
  rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
813
 
814
+ if past_key_values is None:
815
+ if self.pre_seq_len is not None:
816
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
817
+ dtype=inputs_embeds.dtype)
818
+
819
  # Run encoder.
820
  hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
821
  inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
 
989
  return response
990
 
991
  def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
992
+ prompt = tokenizer.build_prompt(query, history=history)
 
 
 
993
  inputs = tokenizer([prompt], return_tensors="pt")
994
  inputs = inputs.to(self.device)
995
  return inputs
 
1006
  inputs = inputs.to(self.device)
1007
  return inputs
1008
 
 
1009
  @torch.no_grad()
1010
  def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
1011
  do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
 
1041
  inputs = self.build_stream_inputs(tokenizer, query, history=history)
1042
  if past_key_values is not None:
1043
  past_length = past_key_values[0][0].shape[0]
1044
+ if self.transformer.pre_seq_len is not None:
1045
+ past_length -= self.transformer.pre_seq_len
1046
  inputs.position_ids += past_length
1047
  attention_mask = inputs.attention_mask
1048
  attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
tokenization_chatglm.py CHANGED
@@ -17,7 +17,7 @@ class SPTokenizer:
17
  self.n_words: int = self.sp_model.vocab_size()
18
  self.bos_id: int = self.sp_model.bos_id()
19
  self.eos_id: int = self.sp_model.eos_id()
20
- self.pad_id: int = self.sp_model.eos_id()
21
  assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
22
 
23
  special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
@@ -55,7 +55,7 @@ class SPTokenizer:
55
 
56
  def convert_id_to_token(self, index):
57
  """Converts an index (integer) in a token (str) using the vocab."""
58
- if index in self.index_special_tokens:
59
  return ""
60
  return self.sp_model.IdToPiece(index)
61
 
@@ -85,12 +85,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
85
 
86
  @property
87
  def pad_token(self) -> str:
88
- return "</s>"
89
 
90
  @property
91
  def pad_token_id(self):
92
  return self.get_command("<pad>")
93
 
 
 
 
 
 
 
 
 
94
  @property
95
  def vocab_size(self):
96
  return self.tokenizer.n_words
@@ -147,6 +155,15 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
147
  prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
148
  return prefix_tokens
149
 
 
 
 
 
 
 
 
 
 
150
  def build_inputs_with_special_tokens(
151
  self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
152
  ) -> List[int]:
 
17
  self.n_words: int = self.sp_model.vocab_size()
18
  self.bos_id: int = self.sp_model.bos_id()
19
  self.eos_id: int = self.sp_model.eos_id()
20
+ self.pad_id: int = self.sp_model.unk_id()
21
  assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
22
 
23
  special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
 
55
 
56
  def convert_id_to_token(self, index):
57
  """Converts an index (integer) in a token (str) using the vocab."""
58
+ if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
59
  return ""
60
  return self.sp_model.IdToPiece(index)
61
 
 
85
 
86
  @property
87
  def pad_token(self) -> str:
88
+ return "<unk>"
89
 
90
  @property
91
  def pad_token_id(self):
92
  return self.get_command("<pad>")
93
 
94
+ @property
95
+ def eos_token(self) -> str:
96
+ return "</s>"
97
+
98
+ @property
99
+ def eos_token_id(self):
100
+ return self.get_command("<eos>")
101
+
102
  @property
103
  def vocab_size(self):
104
  return self.tokenizer.n_words
 
155
  prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
156
  return prefix_tokens
157
 
158
+ def build_prompt(self, query, history=None):
159
+ if history is None:
160
+ history = []
161
+ prompt = ""
162
+ for i, (old_query, response) in enumerate(history):
163
+ prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
164
+ prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
165
+ return prompt
166
+
167
  def build_inputs_with_special_tokens(
168
  self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
169
  ) -> List[int]: