thomas.dang committed
Commit 4921b85 · 1 Parent(s): cbb9c95
Files changed (1):
  1. tokenization_qwen.py +233 -126
tokenization_qwen.py CHANGED
@@ -13,15 +13,35 @@ import itertools
 
 import requests
 import unicodedata
-from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
+from typing import (
+    Collection,
+    Dict,
+    List,
+    Set,
+    Tuple,
+    Union,
+    Any,
+    Callable,
+    Optional,
+)
 
 import tiktoken
 import numpy as np
 
 from transformers import PreTrainedTokenizer, AddedToken
 from transformers.utils import try_to_load_from_cache
-from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TruncationStrategy, \
-    TextInput, TextInputPair, PreTokenizedInput, PreTokenizedInputPair, TensorType, EncodedInput, EncodedInputPair
+from transformers.tokenization_utils_base import (
+    BatchEncoding,
+    PaddingStrategy,
+    TruncationStrategy,
+    TextInput,
+    TextInputPair,
+    PreTokenizedInput,
+    PreTokenizedInputPair,
+    TensorType,
+    EncodedInput,
+    EncodedInputPair,
+)
 
 import matplotlib.colors as mcolors
 from matplotlib.font_manager import FontProperties
@@ -40,10 +60,10 @@ IMEND = "<|im_end|>"
 # as different as possible to minimize the impact
 EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
 SPECIAL_TOKENS = (
-    ENDOFTEXT,
-    IMSTART,
-    IMEND,
-) + EXTRAS
+    ENDOFTEXT,
+    IMSTART,
+    IMEND,
+) + EXTRAS
 
 LANGUAGES = {
     "en": "english",
@@ -62,14 +82,16 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
     contents = f.read()
     return {
         base64.b64decode(token): int(rank)
-        for token, rank in (line.split() for line in contents.splitlines() if line)
+        for token, rank in (
+            line.split() for line in contents.splitlines() if line
+        )
     }
 
 
 def _list_find(
-    input_list: List[Any],
-    candidates: Tuple[Any],
-    start: int = 0,
+    input_list: List[Any],
+    candidates: Tuple[Any],
+    start: int = 0,
 ):
     for i in range(start, len(input_list)):
         if input_list[i] in candidates:
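
A note on the vocab file parsed by `_load_tiktoken_bpe` above: each non-empty line holds a base64-encoded token and its integer rank, separated by whitespace. A minimal sketch of that format (the two entries are made up):

    import base64

    # Hypothetical two-line vocab in tiktoken BPE format: "<base64(token)> <rank>".
    sample = b"SGVsbG8= 0\nIHdvcmxk 1\n"  # b"Hello" -> 0, b" world" -> 1
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in sample.splitlines() if line)
    }
    assert ranks[b"Hello"] == 0 and ranks[b" world"] == 1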
@@ -78,12 +100,12 @@ def _list_find(
 
 
 def _replace_closed_tag(
-    input_tokens: List[Any],
-    start_tags: Union[Any, Tuple[Any]],
-    end_tags: Union[Any, Tuple[Any]],
-    inclusive_replace_func: Callable,
-    exclusive_replace_func: Callable = lambda x: x,
-    audio_info: Dict = None
+    input_tokens: List[Any],
+    start_tags: Union[Any, Tuple[Any]],
+    end_tags: Union[Any, Tuple[Any]],
+    inclusive_replace_func: Callable,
+    exclusive_replace_func: Callable = lambda x: x,
+    audio_info: Dict = None,
 ):
     if isinstance(start_tags, (str, int)):
         start_tags = (start_tags,)
@@ -98,12 +120,16 @@ def _replace_closed_tag(
         start = _list_find(input_tokens, start_tags, end)
         if start == -1:
             break
-        output_tokens.extend(exclusive_replace_func(input_tokens[end: start]))
+        output_tokens.extend(exclusive_replace_func(input_tokens[end:start]))
         tag_idx = start_tags.index(input_tokens[start])
         end = _list_find(input_tokens, (end_tags[tag_idx],), start)
         if end == -1:
             raise ValueError("Unclosed audio token")
-        output_tokens.extend(inclusive_replace_func(input_tokens[start: end + 1], audio_info, audio_idx))
+        output_tokens.extend(
+            inclusive_replace_func(
+                input_tokens[start : end + 1], audio_info, audio_idx
+            )
+        )
         end += 1
         audio_idx += 1
     output_tokens.extend(exclusive_replace_func(input_tokens[end:]))
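
For orientation: `_replace_closed_tag` scans the token list, hands every start-tag ... end-tag span (tags included) to `inclusive_replace_func` together with `audio_info` and a running span index, and passes the stretches in between through `exclusive_replace_func`. A toy run of that contract, with made-up tokens:

    # mask_span stands in for inclusive_replace_func; it receives the full
    # tagged span plus (audio_info, audio_idx) and returns its replacement.
    def mask_span(span, audio_info, audio_idx):
        return [span[0], "[PAD]", span[-1]]

    out = _replace_closed_tag(
        ["a", "<audio>", "x", "</audio>", "b"], "<audio>", "</audio>", mask_span
    )
    assert out == ["a", "<audio>", "[PAD]", "</audio>", "b"]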
@@ -116,12 +142,12 @@ class QWenTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
 
     def __init__(
-        self,
-        vocab_file,
-        errors="replace",
-        audio_start_tag='<audio>',
-        audio_end_tag='</audio>',
-        **kwargs,
+        self,
+        vocab_file,
+        errors="replace",
+        audio_start_tag="<audio>",
+        audio_end_tag="</audio>",
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.audio_start_tag = audio_start_tag
@@ -129,7 +155,7 @@ class QWenTokenizer(PreTrainedTokenizer):
         self.audio_pad_tag = "[[[AUDIO:modality]]]"
 
         self.AUDIO_ST = (
-            '[[[AUDIO:modality]]]',
+            "[[[AUDIO:modality]]]",
             # Transcription Tag
             "<|startoftranscript|>",  # Transcription
             "<|startofanalysis|>",  # Analysis
@@ -146,7 +172,9 @@ class QWenTokenizer(PreTrainedTokenizer):
             "<|notimestamps|>",
             "<|sil|>",
             "<|timestamps|>",
-            *[f"<|{i * 0.01:.2f}|>" for i in range(3001)],  # timestamps 0.00-30.00
+            *[
+                f"<|{i * 0.01:.2f}|>" for i in range(3001)
+            ],  # timestamps 0.00-30.00
             # Output Instruction
             "<|caption_audiocaps|>",  # Audiocaps caption style
             "<|caption_clotho|>",  # Clotho caption style
@@ -164,12 +192,15 @@ class QWenTokenizer(PreTrainedTokenizer):
             "<|endofword|>",
             "<|delim|>",  # delimiter of timestamps pair in audio grounding
             "<|emotion_recognition|>",  # emotion recognition
+            "<|emotion_transcription|>",
             "<|music_description|>",  # music description
             "<|note_analysis|>",  # note analysis
             "<|pitch|>",  # note analysis: pitch
             *[f"<|midi_pitch_{i}|>" for i in range(128)],  # midi pitch 0-127
             "<|velocity|>",  # note analysis: velocity
-            *[f"<|midi_velocity_{i}|>" for i in range(128)],  # midi velocity 0-127
+            *[
+                f"<|midi_velocity_{i}|>" for i in range(128)
+            ],  # midi velocity 0-127
             "<|sonic|>",  # note analysis: sonic
             "<|instrument|>",  # note analysis: instrument
             "<|speaker_meta|>",  # meta information of speaker
@@ -186,25 +217,28 @@ class QWenTokenizer(PreTrainedTokenizer):
             "<|entities|>",  # speech language understanding: entities
             "<|speech_edit|>",  # speech edit
             audio_start_tag,
-            audio_end_tag
+            audio_end_tag,
         )
 
         self.errors = errors  # how to handle errors in decoding
 
-        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: dict[bytes, int]
+        self.mergeable_ranks = _load_tiktoken_bpe(
+            vocab_file
+        )  # type: dict[bytes, int]
         self.special_tokens = {
             token: index
             for index, token in enumerate(
                 SPECIAL_TOKENS + self.AUDIO_ST, start=len(self.mergeable_ranks)
-
             )
         }
         self.audio_start_id = self.special_tokens[self.audio_start_tag]
         self.audio_end_id = self.special_tokens[self.audio_end_tag]
         self.audio_pad_id = self.special_tokens[self.audio_pad_tag]
-        print(f"audio_start_id: {self.audio_start_id}, "
-              f"audio_end_id: {self.audio_end_id}, "
-              f"audio_pad_id: {self.audio_pad_id}.")
+        print(
+            f"audio_start_id: {self.audio_start_id}, "
+            f"audio_end_id: {self.audio_end_id}, "
+            f"audio_pad_id: {self.audio_pad_id}."
+        )
 
         enc = tiktoken.Encoding(
             "Qwen",
@@ -213,7 +247,7 @@ class QWenTokenizer(PreTrainedTokenizer):
             special_tokens=self.special_tokens,
         )
         assert (
-            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
+            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
         ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
 
         self.decoder = {
@@ -230,7 +264,7 @@ class QWenTokenizer(PreTrainedTokenizer):
     def __getstate__(self):
         # for pickle lovers
         state = self.__dict__.copy()
-        del state['tokenizer']
+        del state["tokenizer"]
         return state
 
     def __setstate__(self, state):
@@ -251,7 +285,7 @@ class QWenTokenizer(PreTrainedTokenizer):
         return self.mergeable_ranks
 
     def convert_tokens_to_ids(
-        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
+        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
     ) -> List[int]:
         ids = []
         if isinstance(tokens, (str, bytes)):
@@ -266,13 +300,21 @@ class QWenTokenizer(PreTrainedTokenizer):
             ids.append(self.mergeable_ranks.get(token))
         return ids
 
-    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
+    def _add_tokens(
+        self,
+        new_tokens: Union[List[str], List[AddedToken]],
+        special_tokens: bool = False,
+    ) -> int:
         if not special_tokens and new_tokens:
-            raise ValueError('Adding regular tokens is not supported')
+            raise ValueError("Adding regular tokens is not supported")
         for token in new_tokens:
-            surface_form = token.content if isinstance(token, AddedToken) else token
-            if surface_form not in SPECIAL_TOKENS + self.AUDIO_ST:
-                raise ValueError('Adding unknown special tokens is not supported')
+            surface_form = (
+                token.content if isinstance(token, AddedToken) else token
+            )
+            if surface_form not in SPECIAL_TOKENS + self.AUDIO_ST:
+                raise ValueError(
+                    "Adding unknown special tokens is not supported"
+                )
         return 0
 
     def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
@@ -290,12 +332,12 @@ class QWenTokenizer(PreTrainedTokenizer):
         return (file_path,)
 
     def tokenize(
-        self,
-        text: str,
-        allowed_special: Union[Set, str] = "all",
-        disallowed_special: Union[Collection, str] = (),
-        audio_info: Dict = None,
-        **kwargs,
+        self,
+        text: str,
+        allowed_special: Union[Set, str] = "all",
+        disallowed_special: Union[Collection, str] = (),
+        audio_info: Dict = None,
+        **kwargs,
     ) -> List[Union[bytes, str]]:
         """
         Converts a string in a sequence of tokens.
@@ -321,61 +363,89 @@ class QWenTokenizer(PreTrainedTokenizer):
 
         # this implementation takes a detour: text -> token id -> token surface forms
         for t in self.tokenizer.encode(
-            text, allowed_special=allowed_special, disallowed_special=disallowed_special
+            text,
+            allowed_special=allowed_special,
+            disallowed_special=disallowed_special,
         ):
             tokens.append(self.decoder[t])
 
         def _encode_audiourl(audio_tokens, audio_info, audio_idx):
-            assert audio_tokens[0] == self.audio_start_tag and audio_tokens[-1] == self.audio_end_tag
-            audio_token_span = audio_info['audio_span_tokens'][audio_idx]
-            out_audio_tokens = [self.audio_start_tag] + [self.audio_pad_tag] * (audio_token_span - 2) + [
-                self.audio_end_tag]
+            assert (
+                audio_tokens[0] == self.audio_start_tag
+                and audio_tokens[-1] == self.audio_end_tag
+            )
+            audio_token_span = audio_info["audio_span_tokens"][audio_idx]
+            out_audio_tokens = (
+                [self.audio_start_tag]
+                + [self.audio_pad_tag] * (audio_token_span - 2)
+                + [self.audio_end_tag]
+            )
             return out_audio_tokens
 
-        return _replace_closed_tag(tokens, self.audio_start_tag, self.audio_end_tag, _encode_audiourl,
-                                   audio_info=audio_info)
+        return _replace_closed_tag(
+            tokens,
+            self.audio_start_tag,
+            self.audio_end_tag,
+            _encode_audiourl,
+            audio_info=audio_info,
+        )
 
     def _batch_encode_plus(
-        self,
-        batch_text_or_text_pairs: Union[
-            List[TextInput],
-            List[TextInputPair],
-            List[PreTokenizedInput],
-            List[PreTokenizedInputPair],
-            List[EncodedInput],
-            List[EncodedInputPair],
-        ],
-        add_special_tokens: bool = True,
-        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
-        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        is_split_into_words: bool = False,
-        pad_to_multiple_of: Optional[int] = None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        return_token_type_ids: Optional[bool] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        **kwargs,
+        self,
+        batch_text_or_text_pairs: Union[
+            List[TextInput],
+            List[TextInputPair],
+            List[PreTokenizedInput],
+            List[PreTokenizedInputPair],
+            List[EncodedInput],
+            List[EncodedInputPair],
+        ],
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
     ) -> BatchEncoding:
 
         def get_input_ids(text):
             if isinstance(text, str):
                 tokens = self.tokenize(text, **kwargs)
                 return self.convert_tokens_to_ids(tokens)
-            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
+            elif (
+                isinstance(text, (list, tuple))
+                and len(text) > 0
+                and isinstance(text[0], str)
+            ):
                 if is_split_into_words:
                     tokens = list(
-                        itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
+                        itertools.chain(
+                            *(
+                                self.tokenize(
+                                    t, is_split_into_words=True, **kwargs
+                                )
+                                for t in text
+                            )
+                        )
                     )
                     return self.convert_tokens_to_ids(tokens)
                 else:
                     return self.convert_tokens_to_ids(text)
-            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+            elif (
+                isinstance(text, (list, tuple))
+                and len(text) > 0
+                and isinstance(text[0], int)
+            ):
                 return text
             else:
                 raise ValueError(
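
In `_encode_audiourl` above, the URL text between the tags is discarded and replaced by a fixed-width run of pad tokens, sized from `audio_info["audio_span_tokens"]` so the text side lines up with the audio features. The expansion in isolation (a sketch, using the real tag strings):

    def expand_audio_span(span: int) -> list:
        # start tag + (span - 2) pads + end tag == span tokens total
        return ["<audio>"] + ["[[[AUDIO:modality]]]"] * (span - 2) + ["</audio>"]

    assert len(expand_audio_span(5)) == 5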
@@ -392,18 +462,22 @@ class QWenTokenizer(PreTrainedTokenizer):
         input_ids = []
         audio_info = kwargs.pop("audio_info", None)
         for pair_id in range(len(batch_text_or_text_pairs)):
-            kwargs['audio_info'] = audio_info[pair_id]
+            kwargs["audio_info"] = audio_info[pair_id]
             ids_or_pair_ids = batch_text_or_text_pairs[pair_id]
             # for ids_or_pair_ids in batch_text_or_text_pairs:
             if not isinstance(ids_or_pair_ids, (list, tuple)):
                 ids, pair_ids = ids_or_pair_ids, None
-            elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
+            elif is_split_into_words and not isinstance(
+                ids_or_pair_ids[0], (list, tuple)
+            ):
                 ids, pair_ids = ids_or_pair_ids, None
             else:
                 ids, pair_ids = ids_or_pair_ids
 
             first_ids = get_input_ids(ids)
-            second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
+            second_ids = (
+                get_input_ids(pair_ids) if pair_ids is not None else None
+            )
             input_ids.append((first_ids, second_ids))
 
         batch_outputs = self._batch_prepare_for_model(
@@ -473,23 +547,35 @@ class QWenTokenizer(PreTrainedTokenizer):
         raise NotImplementedError
 
     def _decode(
-        self,
-        token_ids: Union[int, List[int]],
-        skip_special_tokens: bool = False,
-        errors: str = None,
-        **kwargs,
+        self,
+        token_ids: Union[int, List[int]],
+        skip_special_tokens: bool = False,
+        errors: str = None,
+        **kwargs,
     ) -> str:
         if isinstance(token_ids, int):
             token_ids = [token_ids]
         audio_info = kwargs.pop("audio_info", None)
 
         def _decode_audiourl(audio_token_ids, audio_info, audio_idx):
-            assert audio_token_ids[0] == self.audio_start_id and audio_token_ids[-1] == self.audio_end_id
+            assert (
+                audio_token_ids[0] == self.audio_start_id
+                and audio_token_ids[-1] == self.audio_end_id
+            )
             audio_url = audio_info["audio_urls"][audio_idx]
-            return [self.audio_start_id] + self.tokenizer.encode(audio_url) + [self.audio_end_id]
+            return (
+                [self.audio_start_id]
+                + self.tokenizer.encode(audio_url)
+                + [self.audio_end_id]
+            )
 
-        token_ids = _replace_closed_tag(token_ids, self.audio_start_id, self.audio_end_id, _decode_audiourl,
-                                        audio_info=audio_info)
+        token_ids = _replace_closed_tag(
+            token_ids,
+            self.audio_start_id,
+            self.audio_end_id,
+            _decode_audiourl,
+            audio_info=audio_info,
+        )
 
         if skip_special_tokens:
             token_ids = [i for i in token_ids if i < self.eod_id]
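
`_decode` inverts the substitution made in `tokenize`: every audio_start_id ... audio_end_id run of pad ids is swapped for the re-encoded URL from `audio_info["audio_urls"]`, so decoding recovers the original `<audio>url</audio>` text. Schematically, with invented ids:

    def restore_span(ids, url_ids, start_id=100, end_id=101):
        # stand-in for _decode_audiourl: drop the pads, splice the url ids back
        assert ids[0] == start_id and ids[-1] == end_id
        return [start_id] + url_ids + [end_id]

    assert restore_span([100, 99, 99, 101], url_ids=[7, 8]) == [100, 7, 8, 101]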
@@ -498,18 +584,32 @@ class QWenTokenizer(PreTrainedTokenizer):
     def to_list_format(self, text: str):
         text = unicodedata.normalize("NFC", text)
         token_ids = self.tokenizer.encode(
-            text, allowed_special=set(self.AUDIO_ST + (ENDOFTEXT,)))
+            text, allowed_special=set(self.AUDIO_ST + (ENDOFTEXT,))
+        )
 
         def _encode_audio_info(tokens):
             if len(tokens) == 0:
                 return []
-            if tokens[0] == self.audio_start_id and tokens[-1] == self.audio_end_id:
-                key = 'audio'
+            if (
+                tokens[0] == self.audio_start_id
+                and tokens[-1] == self.audio_end_id
+            ):
+                key = "audio"
             else:
-                _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
-                return [{'text': b''.join(map(_tobytes, map(self.decoder.get, tokens))).decode('utf-8')}]
-            _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
-            val = b''.join(map(_tobytes, map(self.decoder.get, tokens[1:-1]))).decode('utf-8')
+                _tobytes = lambda x: (
+                    x.encode("utf-8") if isinstance(x, str) else x
+                )
+                return [
+                    {
+                        "text": b"".join(
+                            map(_tobytes, map(self.decoder.get, tokens))
+                        ).decode("utf-8")
+                    }
+                ]
+            _tobytes = lambda x: x.encode("utf-8") if isinstance(x, str) else x
+            val = b"".join(
+                map(_tobytes, map(self.decoder.get, tokens[1:-1]))
+            ).decode("utf-8")
            return [{key: val}]
 
         return _replace_closed_tag(
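
The byte-level detour in `_encode_audio_info` matters because `self.decoder` maps ids to token byte strings, and one UTF-8 character can straddle two tokens, so the bytes must be joined before decoding. A minimal demonstration with a hypothetical two-token decoder:

    decoder = {0: b"\xe4\xb8", 1: b"\xad"}  # U+4E2D (e4 b8 ad) split across tokens
    _tobytes = lambda x: x.encode("utf-8") if isinstance(x, str) else x
    text = b"".join(map(_tobytes, map(decoder.get, [0, 1]))).decode("utf-8")
    assert text == "\u4e2d"  # decoding each token alone would raise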
@@ -521,21 +621,25 @@ class QWenTokenizer(PreTrainedTokenizer):
         )
 
     def from_list_format(self, list_format: List[Dict]):
-        text = ''
+        text = ""
         num_audios = 0
         for ele in list_format:
-            if 'audio' in ele:
+            if "audio" in ele:
                 num_audios += 1
-                text += f'Audio {num_audios}:'
-                text += self.audio_start_tag + ele['audio'] + self.audio_end_tag
-                text += '\n'
-            elif 'text' in ele:
-                text += ele['text']
-            elif 'box' in ele:
-                if 'ref' in ele:
-                    text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
-                for box in ele['box']:
-                    text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
+                text += f"Audio {num_audios}:"
+                text += self.audio_start_tag + ele["audio"] + self.audio_end_tag
+                text += "\n"
+            elif "text" in ele:
+                text += ele["text"]
+            elif "box" in ele:
+                if "ref" in ele:
+                    text += self.ref_start_tag + ele["ref"] + self.ref_end_tag
+                for box in ele["box"]:
+                    text += (
+                        self.box_start_tag
+                        + "(%d,%d),(%d,%d)" % (box[0], box[1], box[2], box[3])
+                        + self.box_end_tag
+                    )
             else:
                 raise ValueError("Unsupport element: " + str(ele))
         return text
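
Given the branches above, an audio element renders as `Audio N:` plus the tagged URL on its own line, and text elements are appended verbatim. A usage sketch (the tokenizer instance and the path are hypothetical):

    query = tokenizer.from_list_format([
        {"audio": "assets/example.wav"},
        {"text": "What is being said?"},
    ])
    assert query == "Audio 1:<audio>assets/example.wav</audio>\nWhat is being said?"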
@@ -549,12 +653,16 @@ class QWenTokenizer(PreTrainedTokenizer):
         if len(audio_urls) > 0:
             audios, audio_lens, audio_span_tokens = [], [], []
             for audio_path in audio_urls:
-                if audio_path.startswith("http://") or audio_path.startswith("https://"):  # http
+                if audio_path.startswith("http://") or audio_path.startswith(
+                    "https://"
+                ):  # http
                     data = bytes(requests.get(audio_path, stream=True).content)
                     audio = load_bytesio_audio(data)
                 else:
                     audio = load_audio(audio_path)
-                L = (audio.shape[0] if audio.shape[0] <= 480000 else 480000)  # max_length < 30s
+                L = (
+                    audio.shape[0] if audio.shape[0] <= 480000 else 480000
+                )  # max_length < 30s
                 mel_len = L // 160
                 audio = pad_or_trim(audio.flatten())
                 mel = log_mel_spectrogram(audio)
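
The constants in `process_audio` follow from a 16 kHz input rate: 480000 samples caps a clip at 30 s, and the `// 160` hop yields one mel frame per 10 ms. Checking that arithmetic under the 16 kHz assumption:

    SAMPLE_RATE = 16000               # assumed input rate (Whisper-style frontend)
    MAX_SAMPLES = 30 * SAMPLE_RATE    # the clamp applied to L above
    HOP = 160                         # 10 ms per mel frame at 16 kHz
    assert MAX_SAMPLES == 480000 and MAX_SAMPLES // HOP == 3000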
@@ -563,17 +671,16 @@ class QWenTokenizer(PreTrainedTokenizer):
                 audio_len = [audio_len_after_cnn, audio_token_num]
                 audios.append(mel)
                 audio_lens.append(audio_len)
-                audio_span_tokens.append(audio_token_num + 2)  # add audio bos eos
+                audio_span_tokens.append(
+                    audio_token_num + 2
+                )  # add audio bos eos
             input_audio_lengths = torch.IntTensor(audio_lens)
             input_audios = torch.stack(audios, dim=0)
-            return {"input_audios": input_audios,
-                    "input_audio_lengths": input_audio_lengths,
-                    "audio_span_tokens": audio_span_tokens,
-                    "audio_urls": audio_urls}
+            return {
+                "input_audios": input_audios,
+                "input_audio_lengths": input_audio_lengths,
+                "audio_span_tokens": audio_span_tokens,
+                "audio_urls": audio_urls,
+            }
         else:
             return None
-
-
-
-
-