zxdu20 committed on
Commit
cc96a22
·
1 Parent(s): 11c270c

Implement batch generation

Browse files
Files changed (2) hide show
  1. modeling_chatglm.py +100 -83
  2. tokenization_chatglm.py +103 -12
modeling_chatglm.py CHANGED
@@ -13,7 +13,7 @@ import torch.nn.functional as F
13
  from torch import nn
14
  from torch.nn import CrossEntropyLoss, LayerNorm
15
  from torch.nn.utils import skip_init
16
- from typing import Optional, Tuple, Union, List, Callable
17
 
18
  from transformers.utils import (
19
  add_code_sample_docstrings,
@@ -28,7 +28,7 @@ from transformers.modeling_outputs import (
28
  from transformers.modeling_utils import PreTrainedModel
29
  from transformers.utils import logging
30
  from transformers.generation.logits_process import LogitsProcessor
31
- from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
32
 
33
  from .configuration_chatglm import ChatGLMConfig
34
 
@@ -664,6 +664,39 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
664
  """Initialize the weights."""
665
  return
666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
667
  def _set_gradient_checkpointing(self, module, value=False):
668
  if isinstance(module, ChatGLMModel):
669
  module.gradient_checkpointing = value
@@ -828,39 +861,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
828
  # past_key_values = [(v[0], v[1]) for v in past_key_values]
829
  return past_key_values
830
 
831
- def get_masks(self, input_ids, device):
832
- batch_size, seq_length = input_ids.shape
833
- context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
834
- attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
835
- attention_mask.tril_()
836
- for i, context_length in enumerate(context_lengths):
837
- attention_mask[i, :, :context_length] = 1
838
- attention_mask.unsqueeze_(1)
839
- attention_mask = (attention_mask < 0.5).bool()
840
-
841
- return attention_mask
842
-
843
- def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
844
- batch_size, seq_length = input_ids.shape
845
- context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
846
- if self.position_encoding_2d:
847
- position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
848
- for i, context_length in enumerate(context_lengths):
849
- position_ids[i, context_length:] = mask_positions[i]
850
- block_position_ids = [torch.cat((
851
- torch.zeros(context_length, dtype=torch.long, device=device),
852
- torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
853
- )) for context_length in context_lengths]
854
- block_position_ids = torch.stack(block_position_ids, dim=0)
855
- position_ids = torch.stack((position_ids, block_position_ids), dim=1)
856
- else:
857
- position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
858
- if not gmask:
859
- for i, context_length in enumerate(context_lengths):
860
- position_ids[context_length:] = mask_positions[i]
861
-
862
- return position_ids
863
-
864
  @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
865
  @add_code_sample_docstrings(
866
  checkpoint=_CHECKPOINT_FOR_DOC,
@@ -1038,35 +1038,39 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1038
  def set_output_embeddings(self, new_embeddings):
1039
  self.lm_head = new_embeddings
1040
 
1041
- def get_masks_and_position_ids(self, input_ids, mask_positions, device, gmask=False):
1042
- batch_size, seq_length = input_ids.shape
1043
- context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
1044
- attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
1045
- attention_mask.tril_()
1046
- for i, context_length in enumerate(context_lengths):
1047
- attention_mask[i, :, :context_length] = 1
1048
- attention_mask.unsqueeze_(1)
1049
- attention_mask = (attention_mask < 0.5).bool()
 
 
1050
 
1051
- batch_size, seq_length = input_ids.shape
1052
- context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
1053
- if self.position_encoding_2d:
1054
- position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
1055
- for i, context_length in enumerate(context_lengths):
1056
- position_ids[i, context_length:] = mask_positions[i]
1057
- block_position_ids = [torch.cat((
1058
- torch.zeros(context_length, dtype=torch.long, device=device),
1059
- torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
1060
- )) for context_length in context_lengths]
1061
- block_position_ids = torch.stack(block_position_ids, dim=0)
1062
- position_ids = torch.stack((position_ids, block_position_ids), dim=1)
1063
- else:
1064
- position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
1065
- if not gmask:
1066
- for i, context_length in enumerate(context_lengths):
1067
- position_ids[context_length:] = mask_positions[i]
1068
 
1069
- return attention_mask, position_ids
 
 
 
 
 
 
 
 
 
1070
 
1071
  def prepare_inputs_for_generation(
1072
  self,
@@ -1074,6 +1078,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1074
  past: Optional[torch.Tensor] = None,
1075
  past_key_values: Optional[torch.Tensor] = None,
1076
  attention_mask: Optional[torch.Tensor] = None,
 
1077
  **kwargs
1078
  ) -> dict:
1079
  batch_size, seq_length = input_ids.shape
@@ -1085,15 +1090,20 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1085
 
1086
  # only last token for input_ids if past is not None
1087
  if past is not None or past_key_values is not None:
1088
- context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
1089
  last_token = input_ids[:, -1].unsqueeze(-1)
1090
- if self.position_encoding_2d:
1091
- position_ids = torch.tensor(
1092
- [[mask_position, seq_length - context_length] for mask_position, context_length in
1093
- zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
1094
  else:
1095
- position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
1096
- device=input_ids.device).unsqueeze(-1)
 
 
 
 
 
 
1097
 
1098
  if past is None:
1099
  past = past_key_values
@@ -1101,14 +1111,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1101
  "input_ids": last_token,
1102
  "past_key_values": past,
1103
  "position_ids": position_ids,
 
1104
  }
1105
  else:
1106
- attention_mask, position_ids = self.get_masks_and_position_ids(
1107
- input_ids,
1108
- mask_positions=mask_positions,
1109
- device=input_ids.device,
1110
- gmask=use_gmask
1111
- )
 
 
 
 
 
 
1112
 
1113
  return {
1114
  "input_ids": input_ids,
@@ -1226,10 +1243,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1226
  for i, (old_query, response) in enumerate(history):
1227
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1228
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1229
- input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1230
- input_ids = input_ids.to(self.device)
1231
- outputs = self.generate(**input_ids, **gen_kwargs)
1232
- outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
1233
  response = tokenizer.decode(outputs)
1234
  response = self.process_response(response)
1235
  history = history + [(query, response)]
@@ -1252,10 +1269,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1252
  for i, (old_query, response) in enumerate(history):
1253
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1254
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1255
- input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1256
- input_ids = input_ids.to(self.device)
1257
- for outputs in self.stream_generate(**input_ids, **gen_kwargs):
1258
- outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
1259
  response = tokenizer.decode(outputs)
1260
  response = self.process_response(response)
1261
  new_history = history + [(query, response)]
 
13
  from torch import nn
14
  from torch.nn import CrossEntropyLoss, LayerNorm
15
  from torch.nn.utils import skip_init
16
+ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
17
 
18
  from transformers.utils import (
19
  add_code_sample_docstrings,
 
28
  from transformers.modeling_utils import PreTrainedModel
29
  from transformers.utils import logging
30
  from transformers.generation.logits_process import LogitsProcessor
31
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
32
 
33
  from .configuration_chatglm import ChatGLMConfig
34
 
 
664
  """Initialize the weights."""
665
  return
666
 
667
def get_masks(self, input_ids, device):
    """Build the boolean attention mask for a batch of token sequences.

    For every sequence a 0/1 matrix is created that is lower-triangular,
    except that the columns before the first ``bos_token_id`` are set to 1
    in every row. The returned tensor is True exactly where that matrix is 0.

    Args:
        input_ids: 2D tensor of token ids, one row per sequence.
        device: Device on which the mask is allocated.

    Returns:
        Bool tensor of shape ``(batch, 1, seq, seq)``.

    Raises:
        ValueError: If a row does not contain ``self.config.bos_token_id``.
    """
    n_rows, n_tokens = input_ids.shape
    bos_id = self.config.bos_token_id
    # Locate the prompt/generation boundary of every row first.
    prefix_lengths = [row.tolist().index(bos_id) for row in input_ids]

    visible = torch.ones((n_rows, n_tokens, n_tokens), device=device).tril_()
    for row_idx, prefix_len in enumerate(prefix_lengths):
        # Columns of the prefix are set to 1 for every query position.
        visible[row_idx, :, :prefix_len] = 1

    # Insert the singleton dimension at index 1 and flip 0/1 into True/False.
    return (visible.unsqueeze(1) < 0.5).bool()
678
+
679
def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
    """Compute position ids for a batch of token sequences.

    The "context" of a sequence is everything before its first
    ``self.config.bos_token_id``.

    With ``self.position_encoding_2d`` the result stacks two channels per
    sequence: the absolute position (frozen at ``mask_positions[i]`` from the
    bos token onward) and a block position (0 over the context, then
    1, 2, ... from the bos token onward).

    Without 2D encoding the result is the plain ``0..seq-1`` range; unless
    ``gmask`` is true, positions from the bos token onward are replaced by
    ``mask_positions[i]``.

    Args:
        input_ids: 2D tensor of token ids, one row per sequence.
        mask_positions: Per-sequence position written over the generated part.
        device: Device on which the ids are allocated.
        gmask: Only read in the 1D case; when true the range is left as is.

    Returns:
        Long tensor of shape ``(batch, 2, seq)`` for 2D encoding, otherwise
        ``(batch, seq)``.

    Raises:
        ValueError: If a row does not contain ``self.config.bos_token_id``.
    """
    batch_size, seq_length = input_ids.shape
    bos_token_id = self.config.bos_token_id
    context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]

    # repeat() gives every sequence its own row. expand() would only create
    # aliased views of a single row, so the per-row writes below would
    # overwrite the positions of every sequence in the batch.
    position_ids = torch.arange(
        seq_length, dtype=torch.long, device=device
    ).unsqueeze(0).repeat(batch_size, 1)

    if self.position_encoding_2d:
        for i, context_length in enumerate(context_lengths):
            position_ids[i, context_length:] = mask_positions[i]
        block_position_ids = torch.stack([
            torch.cat((
                torch.zeros(context_length, dtype=torch.long, device=device),
                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
            ))
            for context_length in context_lengths
        ], dim=0)
        position_ids = torch.stack((position_ids, block_position_ids), dim=1)
    elif not gmask:
        for i, context_length in enumerate(context_lengths):
            # Index row i explicitly; slicing only the first dimension would
            # select batch rows instead of token positions.
            position_ids[i, context_length:] = mask_positions[i]

    return position_ids
699
+
700
  def _set_gradient_checkpointing(self, module, value=False):
701
  if isinstance(module, ChatGLMModel):
702
  module.gradient_checkpointing = value
 
861
  # past_key_values = [(v[0], v[1]) for v in past_key_values]
862
  return past_key_values
863
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
864
  @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
865
  @add_code_sample_docstrings(
866
  checkpoint=_CHECKPOINT_FOR_DOC,
 
1038
  def set_output_embeddings(self, new_embeddings):
1039
  self.lm_head = new_embeddings
1040
 
1041
def _update_model_kwargs_for_generation(
    self,
    outputs: "ModelOutput",
    model_kwargs: Dict[str, Any],
    is_encoder_decoder: bool = False,
    standardize_cache_format: bool = False,
) -> Dict[str, Any]:
    """Extend the generation kwargs by one step after a forward pass.

    Stores the cache extracted from ``outputs`` under ``"past_key_values"``
    and, when present, grows ``"attention_mask"`` and ``"position_ids"`` by
    one entry for the token that was just produced.

    Args:
        outputs: Model output passed to ``_extract_past_from_model_output``.
        model_kwargs: Keyword arguments carried between steps; updated in
            place and returned.
        is_encoder_decoder: Accepted for signature compatibility; not read
            here.
        standardize_cache_format: Forwarded to
            ``_extract_past_from_model_output``.

    Returns:
        The updated ``model_kwargs`` dictionary.
    """
    # update past_key_values
    model_kwargs["past_key_values"] = self._extract_past_from_model_output(
        outputs, standardize_cache_format=standardize_cache_format
    )

    # update attention mask
    attention_mask = model_kwargs.get("attention_mask")
    # Only the 4D boolean mask layout can be grown here; a missing/None mask
    # or a mask of another dtype or rank is passed through untouched instead
    # of raising.
    if (
        attention_mask is not None
        and attention_mask.dtype == torch.bool
        and attention_mask.dim() == 4
    ):
        # New key column: True for every existing query row.
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3
        )
        # New query row: copy of the last row, with the new token's own
        # column set to False.
        new_attention_mask = attention_mask[:, :, -1:].clone()
        new_attention_mask[..., -1] = False
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, new_attention_mask], dim=2
        )

    # update position ids
    position_ids = model_kwargs.get("position_ids")
    if position_ids is not None:
        new_position_id = position_ids[..., -1:].clone()
        if new_position_id.dim() == 3:
            # 2D layout (batch, 2, seq): channel 1 advances by one per step.
            new_position_id[:, 1, :] += 1
        # NOTE(review): for a 1D (batch, seq) layout the last id is carried
        # forward unchanged rather than indexing a channel that does not
        # exist -- confirm this is the intended behaviour for gMASK prompts.
        model_kwargs["position_ids"] = torch.cat(
            [position_ids, new_position_id], dim=-1
        )

    return model_kwargs
1074
 
1075
  def prepare_inputs_for_generation(
1076
  self,
 
1078
  past: Optional[torch.Tensor] = None,
1079
  past_key_values: Optional[torch.Tensor] = None,
1080
  attention_mask: Optional[torch.Tensor] = None,
1081
+ position_ids: Optional[torch.Tensor] = None,
1082
  **kwargs
1083
  ) -> dict:
1084
  batch_size, seq_length = input_ids.shape
 
1090
 
1091
  # only last token for input_ids if past is not None
1092
  if past is not None or past_key_values is not None:
 
1093
  last_token = input_ids[:, -1].unsqueeze(-1)
1094
+ if attention_mask is not None:
1095
+ attention_mask = attention_mask[:, :, -1:]
1096
+ if position_ids is not None:
1097
+ position_ids = position_ids[..., -1:]
1098
  else:
1099
+ context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
1100
+ if self.position_encoding_2d:
1101
+ position_ids = torch.tensor(
1102
+ [[mask_position, seq_length - context_length] for mask_position, context_length in
1103
+ zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
1104
+ else:
1105
+ position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
1106
+ device=input_ids.device).unsqueeze(-1)
1107
 
1108
  if past is None:
1109
  past = past_key_values
 
1111
  "input_ids": last_token,
1112
  "past_key_values": past,
1113
  "position_ids": position_ids,
1114
+ "attention_mask": attention_mask
1115
  }
1116
  else:
1117
+ if attention_mask is None:
1118
+ attention_mask = self.get_masks(
1119
+ input_ids,
1120
+ device=input_ids.device
1121
+ )
1122
+ if position_ids is None:
1123
+ position_ids = self.get_position_ids(
1124
+ input_ids,
1125
+ device=input_ids.device,
1126
+ mask_positions=mask_positions,
1127
+ gmask=use_gmask
1128
+ )
1129
 
1130
  return {
1131
  "input_ids": input_ids,
 
1243
  for i, (old_query, response) in enumerate(history):
1244
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1245
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1246
+ inputs = tokenizer([prompt], return_tensors="pt", padding=True)
1247
+ inputs = inputs.to(self.device)
1248
+ outputs = self.generate(**inputs, **gen_kwargs)
1249
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1250
  response = tokenizer.decode(outputs)
1251
  response = self.process_response(response)
1252
  history = history + [(query, response)]
 
1269
  for i, (old_query, response) in enumerate(history):
1270
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1271
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1272
+ inputs = tokenizer([prompt], return_tensors="pt", padding=True)
1273
+ inputs = inputs.to(self.device)
1274
+ for outputs in self.stream_generate(**inputs, **gen_kwargs):
1275
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1276
  response = tokenizer.decode(outputs)
1277
  response = self.process_response(response)
1278
  new_history = history + [(query, response)]
tokenization_chatglm.py CHANGED
@@ -1,17 +1,14 @@
1
  """Tokenization classes for ChatGLM."""
2
- import sys
3
- import unicodedata
4
  from typing import List, Optional, Union
5
- from functools import lru_cache
6
  import os
7
- import collections
8
- import re
9
 
10
  from transformers.tokenization_utils import PreTrainedTokenizer
11
  from icetk.text_tokenizer import TextTokenizer
12
- from icetk.utils import auto_create
13
  import icetk.sentencepiece_model_pb2 as sp_model
14
- from transformers.utils import logging
 
 
 
15
 
16
  logger = logging.get_logger(__name__)
17
 
@@ -192,7 +189,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
192
  eop_token='eop',
193
  mask_token='[MASK]',
194
  gmask_token='[gMASK]',
195
- padding_side="right",
196
  **kwargs
197
  ) -> None:
198
  super().__init__(
@@ -210,7 +207,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
210
  self.eos_token = eos_token
211
  self.eop_token = eop_token
212
  self.mask_token = mask_token
213
- self.gMASK_token = gmask_token
214
 
215
  self.sp_tokenizer = SPTokenizer(vocab_file)
216
 
@@ -331,10 +328,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
331
  Returns:
332
  `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
333
  """
334
- if token_ids_1 is not None:
335
- token_ids_0 += token_ids_1
336
  mask_ids = self.sp_tokenizer[self.mask_token]
337
- gmask_ids = self.sp_tokenizer[self.gMASK_token]
 
338
  if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
339
  token_ids_0 += [gmask_ids]
340
 
@@ -343,4 +339,99 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
343
 
344
  token_ids_0 += [self.sp_tokenizer[self.bos_token]]
345
 
 
 
 
 
 
346
  return token_ids_0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Tokenization classes for ChatGLM."""
 
 
2
  from typing import List, Optional, Union
 
3
  import os
 
 
4
 
5
  from transformers.tokenization_utils import PreTrainedTokenizer
6
  from icetk.text_tokenizer import TextTokenizer
 
7
  import icetk.sentencepiece_model_pb2 as sp_model
8
+ from transformers.utils import logging, PaddingStrategy
9
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
10
+ from typing import Dict
11
+ import numpy as np
12
 
13
  logger = logging.get_logger(__name__)
14
 
 
189
  eop_token='eop',
190
  mask_token='[MASK]',
191
  gmask_token='[gMASK]',
192
+ padding_side="left",
193
  **kwargs
194
  ) -> None:
195
  super().__init__(
 
207
  self.eos_token = eos_token
208
  self.eop_token = eop_token
209
  self.mask_token = mask_token
210
+ self.gmask_token = gmask_token
211
 
212
  self.sp_tokenizer = SPTokenizer(vocab_file)
213
 
 
328
  Returns:
329
  `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
330
  """
 
 
331
  mask_ids = self.sp_tokenizer[self.mask_token]
332
+ gmask_ids = self.sp_tokenizer[self.gmask_token]
333
+ eop_id = self.sp_tokenizer[self.eop_token]
334
  if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
335
  token_ids_0 += [gmask_ids]
336
 
 
339
 
340
  token_ids_0 += [self.sp_tokenizer[self.bos_token]]
341
 
342
+ if token_ids_1 is not None:
343
+ if token_ids_1[-1] != eop_id:
344
+ token_ids_1 += [eop_id]
345
+ token_ids_0 += token_ids_1
346
+
347
  return token_ids_0
348
+
349
def _pad(
    self,
    encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
    max_length: Optional[int] = None,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
) -> dict:
    """
    Left-pad one encoded input and attach its attention mask and position ids.

    The attention mask is a boolean array of shape `(1, seq, seq)` and the
    position ids are an int64 array of shape `(2, seq)` (positions, block
    positions); both are computed from the ids stored under
    `self.model_input_names[0]`.

    Args:
        encoded_inputs:
            Dictionary holding the tokenized input (`List[int]`) under
            `self.model_input_names[0]`. It is modified in place and returned.
        max_length: Target length when padding. Ignored for
            `PaddingStrategy.LONGEST`, where the input's own length is used.
        padding_strategy: PaddingStrategy to use for padding.

            - PaddingStrategy.LONGEST: Use the length of this input
            - PaddingStrategy.MAX_LENGTH: Pad to `max_length`
            - PaddingStrategy.DO_NOT_PAD: Do not pad (default)
        pad_to_multiple_of: (optional) If set, `max_length` is rounded up to a
            multiple of this value.
        return_attention_mask:
            (optional) Whether to build the mask and position ids even when no
            padding is needed. Defaults to whether `"attention_mask"` is listed
            in `self.model_input_names`.

    Returns:
        The same `encoded_inputs` mapping, updated.

    Raises:
        ValueError: If `self.padding_side` is not `"left"`.
    """
    # Load from model defaults
    bos_token_id = self.sp_tokenizer[self.bos_token]
    mask_token_id = self.sp_tokenizer[self.mask_token]
    gmask_token_id = self.sp_tokenizer[self.gmask_token]
    # Raised explicitly (instead of asserted) so the check survives `python -O`.
    if self.padding_side != "left":
        raise ValueError(
            f"Only left padding is supported, got padding_side={self.padding_side!r}"
        )
    if return_attention_mask is None:
        return_attention_mask = "attention_mask" in self.model_input_names

    required_input = encoded_inputs[self.model_input_names[0]]
    seq_length = len(required_input)

    if padding_strategy == PaddingStrategy.LONGEST:
        max_length = seq_length

    if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

    needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and seq_length != max_length

    if needs_to_be_padded or return_attention_mask:
        # Everything before the bos token is context. An input without a bos
        # token is treated as context only instead of failing in list.index().
        if bos_token_id in required_input:
            context_length = required_input.index(bos_token_id)
        else:
            context_length = seq_length

        # Lower-triangular 0/1 matrix with the context columns set to 1 in
        # every row; the stored mask is True where that matrix is 0.
        attention_mask = np.tril(np.ones((1, seq_length, seq_length)))
        attention_mask[:, :, :context_length] = 1
        encoded_inputs["attention_mask"] = np.bool_(attention_mask < 0.5)

        # Positions freeze at the mask token's index from the bos token
        # onward. With no [MASK]/[gMASK] in the input the plain range is kept.
        position_ids = np.arange(seq_length, dtype=np.int64)
        mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
        if mask_token in required_input:
            position_ids[context_length:] = required_input.index(mask_token)
        block_position_ids = np.concatenate(
            [np.zeros(context_length, dtype=np.int64),
             np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
        encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)

    if needs_to_be_padded:
        difference = max_length - seq_length

        if "attention_mask" in encoded_inputs:
            # Padded rows and columns are filled with True.
            encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
                                                      pad_width=[(0, 0), (difference, 0), (difference, 0)],
                                                      mode='constant', constant_values=True)
        if "token_type_ids" in encoded_inputs:
            encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
                "token_type_ids"
            ]
        if "special_tokens_mask" in encoded_inputs:
            encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
        if "position_ids" in encoded_inputs:
            # np.pad's default mode fills the padded positions with 0.
            encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
                                                    pad_width=[(0, 0), (difference, 0)])
        encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

    return encoded_inputs