thomas.dang committed · Commit 4921b85 · Parent(s): cbb9c95
Commit message: 1st cm
Browse files: tokenization_qwen.py (+233 -126)

tokenization_qwen.py CHANGED
@@ -13,15 +13,35 @@ import itertools
 
 import requests
 import unicodedata
-from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
+from typing import (
+    Collection,
+    Dict,
+    List,
+    Set,
+    Tuple,
+    Union,
+    Any,
+    Callable,
+    Optional,
+)
 
 import tiktoken
 import numpy as np
 
 from transformers import PreTrainedTokenizer, AddedToken
 from transformers.utils import try_to_load_from_cache
-from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TruncationStrategy, \
-    TextInput, TextInputPair, PreTokenizedInput, PreTokenizedInputPair, TensorType, EncodedInput, EncodedInputPair
+from transformers.tokenization_utils_base import (
+    BatchEncoding,
+    PaddingStrategy,
+    TruncationStrategy,
+    TextInput,
+    TextInputPair,
+    PreTokenizedInput,
+    PreTokenizedInputPair,
+    TensorType,
+    EncodedInput,
+    EncodedInputPair,
+)
 
 import matplotlib.colors as mcolors
 from matplotlib.font_manager import FontProperties
@@ -40,10 +60,10 @@ IMEND = "<|im_end|>"
 # as different as possible to minimize the impact
 EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
 SPECIAL_TOKENS = (
-
-
-
-
+    ENDOFTEXT,
+    IMSTART,
+    IMEND,
+) + EXTRAS
 
 LANGUAGES = {
     "en": "english",
@@ -62,14 +82,16 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
         contents = f.read()
     return {
         base64.b64decode(token): int(rank)
-        for token, rank in (line.split() for line in contents.splitlines() if line)
+        for token, rank in (
+            line.split() for line in contents.splitlines() if line
+        )
     }
 
 
 def _list_find(
-
-
-
+    input_list: List[Any],
+    candidates: Tuple[Any],
+    start: int = 0,
 ):
     for i in range(start, len(input_list)):
         if input_list[i] in candidates:
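Reviewer's note on the hunk above (not part of the commit): _load_tiktoken_bpe parses tiktoken's plain rank-file format, one base64-encoded token and its integer rank per line. A minimal sketch with two made-up entries:

    import base64

    # Fabricated two-line stand-in for qwen.tiktoken ("<base64-token> <rank>").
    sample = "aGVsbG8= 0\nIHdvcmxk 1"
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in sample.splitlines() if line)
    }
    assert ranks == {b"hello": 0, b" world": 1}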
@@ -78,12 +100,12 @@
 
 
 def _replace_closed_tag(
-
-
-
-
-
-
+    input_tokens: List[Any],
+    start_tags: Union[Any, Tuple[Any]],
+    end_tags: Union[Any, Tuple[Any]],
+    inclusive_replace_func: Callable,
+    exclusive_replace_func: Callable = lambda x: x,
+    audio_info: Dict = None,
 ):
     if isinstance(start_tags, (str, int)):
         start_tags = (start_tags,)
@@ -98,12 +120,16 @@ def _replace_closed_tag(
         start = _list_find(input_tokens, start_tags, end)
         if start == -1:
             break
-        output_tokens.extend(exclusive_replace_func(input_tokens[end:
+        output_tokens.extend(exclusive_replace_func(input_tokens[end:start]))
         tag_idx = start_tags.index(input_tokens[start])
         end = _list_find(input_tokens, (end_tags[tag_idx],), start)
         if end == -1:
             raise ValueError("Unclosed audio token")
-        output_tokens.extend(inclusive_replace_func(input_tokens[start : end + 1], audio_info, audio_idx))
+        output_tokens.extend(
+            inclusive_replace_func(
+                input_tokens[start : end + 1], audio_info, audio_idx
+            )
+        )
         end += 1
         audio_idx += 1
     output_tokens.extend(exclusive_replace_func(input_tokens[end:]))
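To summarize the loop being reformatted above (my sketch, not the commit's): every closed start/end span goes through the inclusive callback, and the stretches between spans go through the exclusive one. With toy tags and a padding callback that ignores the audio metadata:

    tokens = ["a", "<s>", "x", "y", "</s>", "b"]
    out = _replace_closed_tag(
        tokens,
        "<s>",
        "</s>",
        lambda span, info, idx: ["<pad>"] * len(span),
        audio_info=None,
    )
    # out == ["a", "<pad>", "<pad>", "<pad>", "<pad>", "b"]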
@@ -116,12 +142,12 @@ class QWenTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
 
     def __init__(
-
-
-
-
-
-
+        self,
+        vocab_file,
+        errors="replace",
+        audio_start_tag="<audio>",
+        audio_end_tag="</audio>",
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.audio_start_tag = audio_start_tag
@@ -129,7 +155,7 @@ class QWenTokenizer(PreTrainedTokenizer):
         self.audio_pad_tag = "[[[AUDIO:modality]]]"
 
         self.AUDIO_ST = (
-
+            "[[[AUDIO:modality]]]",
             # Transcription Tag
             "<|startoftranscript|>",  # Transcription
             "<|startofanalysis|>",  # Analysis
@@ -146,7 +172,9 @@ class QWenTokenizer(PreTrainedTokenizer):
             "<|notimestamps|>",
             "<|sil|>",
             "<|timestamps|>",
-            *[f"<|{i * 0.01:.2f}|>" for i in range(3001)],  # timestamps 0.00-30.00
+            *[
+                f"<|{i * 0.01:.2f}|>" for i in range(3001)
+            ],  # timestamps 0.00-30.00
             # Output Instruction
             "<|caption_audiocaps|>",  # Audiocaps caption style
             "<|caption_clotho|>",  # Clotho caption style
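A quick sanity check of the comprehension rewrapped above (reviewer's note, not from the commit): it emits 3001 timestamp tokens in 10 ms steps:

    stamps = [f"<|{i * 0.01:.2f}|>" for i in range(3001)]
    assert stamps[0] == "<|0.00|>" and stamps[-1] == "<|30.00|>"  # 0.00-30.00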
@@ -164,12 +192,15 @@ class QWenTokenizer(PreTrainedTokenizer):
             "<|endofword|>",
             "<|delim|>",  # delimiter of timestamps pair in audio grounding
             "<|emotion_recognition|>",  # emotion recognition
+            "<|emotion_transcription|>",
             "<|music_description|>",  # music description
             "<|note_analysis|>",  # note analysis
             "<|pitch|>",  # note analysis: pitch
             *[f"<|midi_pitch_{i}|>" for i in range(128)],  # midi pitch 0-127
             "<|velocity|>",  # note analysis: velocity
-            *[f"<|midi_velocity_{i}|>" for i in range(128)],  # midi velocity 0-127
+            *[
+                f"<|midi_velocity_{i}|>" for i in range(128)
+            ],  # midi velocity 0-127
             "<|sonic|>",  # note analysis: sonic
             "<|instrument|>",  # note analysis: instrument
             "<|speaker_meta|>",  # meta information of speaker
@@ -186,25 +217,28 @@ class QWenTokenizer(PreTrainedTokenizer):
             "<|entities|>",  # speech language understanding: entities
             "<|speech_edit|>",  # speech edit
             audio_start_tag,
-            audio_end_tag
+            audio_end_tag,
         )
 
         self.errors = errors  # how to handle errors in decoding
 
-        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: dict[bytes, int]
+        self.mergeable_ranks = _load_tiktoken_bpe(
+            vocab_file
+        )  # type: dict[bytes, int]
         self.special_tokens = {
             token: index
             for index, token in enumerate(
                 SPECIAL_TOKENS + self.AUDIO_ST, start=len(self.mergeable_ranks)
-
             )
         }
         self.audio_start_id = self.special_tokens[self.audio_start_tag]
         self.audio_end_id = self.special_tokens[self.audio_end_tag]
         self.audio_pad_id = self.special_tokens[self.audio_pad_tag]
-        print(
-
-
+        print(
+            f"audio_start_id: {self.audio_start_id}, "
+            f"audio_end_id: {self.audio_end_id}, "
+            f"audio_pad_id: {self.audio_pad_id}."
+        )
 
         enc = tiktoken.Encoding(
             "Qwen",
@@ -213,7 +247,7 @@ class QWenTokenizer(PreTrainedTokenizer):
             special_tokens=self.special_tokens,
         )
         assert (
-
+            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
         ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
 
         self.decoder = {
@@ -230,7 +264,7 @@ class QWenTokenizer(PreTrainedTokenizer):
     def __getstate__(self):
         # for pickle lovers
         state = self.__dict__.copy()
-        del state['tokenizer']
+        del state["tokenizer"]
         return state
 
     def __setstate__(self, state):
@@ -251,7 +285,7 @@ class QWenTokenizer(PreTrainedTokenizer):
         return self.mergeable_ranks
 
     def convert_tokens_to_ids(
-
+        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
     ) -> List[int]:
         ids = []
         if isinstance(tokens, (str, bytes)):
@@ -266,13 +300,21 @@ class QWenTokenizer(PreTrainedTokenizer):
             ids.append(self.mergeable_ranks.get(token))
         return ids
 
-    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
+    def _add_tokens(
+        self,
+        new_tokens: Union[List[str], List[AddedToken]],
+        special_tokens: bool = False,
+    ) -> int:
         if not special_tokens and new_tokens:
-            raise ValueError('Adding regular tokens is not supported')
+            raise ValueError("Adding regular tokens is not supported")
         for token in new_tokens:
-            surface_form = token.content if isinstance(token, AddedToken) else token
-
-
+            surface_form = (
+                token.content if isinstance(token, AddedToken) else token
+            )
+            if surface_form not in SPECIAL_TOKENS + self.AUDIO_ST:
+                raise ValueError(
+                    "Adding unknown special tokens is not supported"
+                )
         return 0
 
     def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
@@ -290,12 +332,12 @@ class QWenTokenizer(PreTrainedTokenizer):
         return (file_path,)
 
     def tokenize(
-
-
-
-
-
-
+        self,
+        text: str,
+        allowed_special: Union[Set, str] = "all",
+        disallowed_special: Union[Collection, str] = (),
+        audio_info: Dict = None,
+        **kwargs,
     ) -> List[Union[bytes, str]]:
         """
         Converts a string in a sequence of tokens.
@@ -321,61 +363,89 @@ class QWenTokenizer(PreTrainedTokenizer):
 
         # this implementation takes a detour: text -> token id -> token surface forms
         for t in self.tokenizer.encode(
-            text, allowed_special=allowed_special, disallowed_special=disallowed_special
+            text,
+            allowed_special=allowed_special,
+            disallowed_special=disallowed_special,
         ):
             tokens.append(self.decoder[t])
 
         def _encode_audiourl(audio_tokens, audio_info, audio_idx):
-            assert audio_tokens[0] == self.audio_start_tag and audio_tokens[-1] == self.audio_end_tag
-
-
-
+            assert (
+                audio_tokens[0] == self.audio_start_tag
+                and audio_tokens[-1] == self.audio_end_tag
+            )
+            audio_token_span = audio_info["audio_span_tokens"][audio_idx]
+            out_audio_tokens = (
+                [self.audio_start_tag]
+                + [self.audio_pad_tag] * (audio_token_span - 2)
+                + [self.audio_end_tag]
+            )
             return out_audio_tokens
 
-        return _replace_closed_tag(
-
+        return _replace_closed_tag(
+            tokens,
+            self.audio_start_tag,
+            self.audio_end_tag,
+            _encode_audiourl,
+            audio_info=audio_info,
+        )
 
     def _batch_encode_plus(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        batch_text_or_text_pairs: Union[
+            List[TextInput],
+            List[TextInputPair],
+            List[PreTokenizedInput],
+            List[PreTokenizedInputPair],
+            List[EncodedInput],
+            List[EncodedInputPair],
+        ],
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
     ) -> BatchEncoding:
 
         def get_input_ids(text):
             if isinstance(text, str):
                 tokens = self.tokenize(text, **kwargs)
                 return self.convert_tokens_to_ids(tokens)
-            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
+            elif (
+                isinstance(text, (list, tuple))
+                and len(text) > 0
+                and isinstance(text[0], str)
+            ):
                 if is_split_into_words:
                     tokens = list(
-                        itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
+                        itertools.chain(
+                            *(
+                                self.tokenize(
+                                    t, is_split_into_words=True, **kwargs
+                                )
+                                for t in text
+                            )
+                        )
                     )
                     return self.convert_tokens_to_ids(tokens)
                 else:
                     return self.convert_tokens_to_ids(text)
-            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+            elif (
+                isinstance(text, (list, tuple))
+                and len(text) > 0
+                and isinstance(text[0], int)
+            ):
                 return text
             else:
                 raise ValueError(
@@ -392,18 +462,22 @@ class QWenTokenizer(PreTrainedTokenizer):
         input_ids = []
         audio_info = kwargs.pop("audio_info", None)
         for pair_id in range(len(batch_text_or_text_pairs)):
-            kwargs['audio_info'] = audio_info[pair_id]
+            kwargs["audio_info"] = audio_info[pair_id]
             ids_or_pair_ids = batch_text_or_text_pairs[pair_id]
             # for ids_or_pair_ids in batch_text_or_text_pairs:
             if not isinstance(ids_or_pair_ids, (list, tuple)):
                 ids, pair_ids = ids_or_pair_ids, None
-            elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
+            elif is_split_into_words and not isinstance(
+                ids_or_pair_ids[0], (list, tuple)
+            ):
                 ids, pair_ids = ids_or_pair_ids, None
             else:
                 ids, pair_ids = ids_or_pair_ids
 
             first_ids = get_input_ids(ids)
-            second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
+            second_ids = (
+                get_input_ids(pair_ids) if pair_ids is not None else None
+            )
             input_ids.append((first_ids, second_ids))
 
         batch_outputs = self._batch_prepare_for_model(
@@ -473,23 +547,35 @@ class QWenTokenizer(PreTrainedTokenizer):
         raise NotImplementedError
 
     def _decode(
-
-
-
-
-
+        self,
+        token_ids: Union[int, List[int]],
+        skip_special_tokens: bool = False,
+        errors: str = None,
+        **kwargs,
     ) -> str:
         if isinstance(token_ids, int):
             token_ids = [token_ids]
         audio_info = kwargs.pop("audio_info", None)
 
         def _decode_audiourl(audio_token_ids, audio_info, audio_idx):
-            assert audio_token_ids[0] == self.audio_start_id and audio_token_ids[-1] == self.audio_end_id
+            assert (
+                audio_token_ids[0] == self.audio_start_id
+                and audio_token_ids[-1] == self.audio_end_id
+            )
             audio_url = audio_info["audio_urls"][audio_idx]
-            return [self.audio_start_id] + self.tokenizer.encode(audio_url) + [self.audio_end_id]
+            return (
+                [self.audio_start_id]
+                + self.tokenizer.encode(audio_url)
+                + [self.audio_end_id]
+            )
 
-        token_ids = _replace_closed_tag(
-
+        token_ids = _replace_closed_tag(
+            token_ids,
+            self.audio_start_id,
+            self.audio_end_id,
+            _decode_audiourl,
+            audio_info=audio_info,
+        )
 
         if skip_special_tokens:
             token_ids = [i for i in token_ids if i < self.eod_id]
@@ -498,18 +584,32 @@ class QWenTokenizer(PreTrainedTokenizer):
     def to_list_format(self, text: str):
         text = unicodedata.normalize("NFC", text)
         token_ids = self.tokenizer.encode(
-            text, allowed_special=set(self.AUDIO_ST + (ENDOFTEXT,)))
+            text, allowed_special=set(self.AUDIO_ST + (ENDOFTEXT,))
+        )
 
         def _encode_audio_info(tokens):
             if len(tokens) == 0:
                 return []
-            if tokens[0] == self.audio_start_id and tokens[-1] == self.audio_end_id:
-                key = 'audio'
+            if (
+                tokens[0] == self.audio_start_id
+                and tokens[-1] == self.audio_end_id
+            ):
+                key = "audio"
             else:
-                _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
-
-
-
+                _tobytes = lambda x: (
+                    x.encode("utf-8") if isinstance(x, str) else x
+                )
+                return [
+                    {
+                        "text": b"".join(
+                            map(_tobytes, map(self.decoder.get, tokens))
+                        ).decode("utf-8")
+                    }
+                ]
+            _tobytes = lambda x: x.encode("utf-8") if isinstance(x, str) else x
+            val = b"".join(
+                map(_tobytes, map(self.decoder.get, tokens[1:-1]))
+            ).decode("utf-8")
             return [{key: val}]
 
         return _replace_closed_tag(
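Sketch of the resulting behavior (reviewer's note; the file name and text are invented): to_list_format is the inverse of from_list_format in the next hunk, splitting text runs and audio spans into separate dicts:

    # tokenizer.to_list_format("Audio 1:<audio>a.wav</audio>\nDescribe it.")
    # -> [{"text": "Audio 1:"}, {"audio": "a.wav"}, {"text": "\nDescribe it."}]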
@@ -521,21 +621,25 @@ class QWenTokenizer(PreTrainedTokenizer):
         )
 
     def from_list_format(self, list_format: List[Dict]):
-        text = ''
+        text = ""
         num_audios = 0
         for ele in list_format:
-            if 'audio' in ele:
+            if "audio" in ele:
                 num_audios += 1
-                text += f'Audio {num_audios}:'
-                text += self.audio_start_tag + ele['audio'] + self.audio_end_tag
-                text += '\n'
-            elif 'text' in ele:
-                text += ele['text']
-            elif 'box' in ele:
-                if 'ref' in ele:
-                    text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
-                for box in ele['box']:
-                    text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
+                text += f"Audio {num_audios}:"
+                text += self.audio_start_tag + ele["audio"] + self.audio_end_tag
+                text += "\n"
+            elif "text" in ele:
+                text += ele["text"]
+            elif "box" in ele:
+                if "ref" in ele:
+                    text += self.ref_start_tag + ele["ref"] + self.ref_end_tag
+                for box in ele["box"]:
+                    text += (
+                        self.box_start_tag
+                        + "(%d,%d),(%d,%d)" % (box[0], box[1], box[2], box[3])
+                        + self.box_end_tag
+                    )
             else:
                 raise ValueError("Unsupport element: " + str(ele))
         return text
@@ -549,12 +653,16 @@ class QWenTokenizer(PreTrainedTokenizer):
         if len(audio_urls) > 0:
             audios, audio_lens, audio_span_tokens = [], [], []
             for audio_path in audio_urls:
-                if audio_path.startswith("http://") or audio_path.startswith("https://"):  # http
+                if audio_path.startswith("http://") or audio_path.startswith(
+                    "https://"
+                ):  # http
                    data = bytes(requests.get(audio_path, stream=True).content)
                     audio = load_bytesio_audio(data)
                 else:
                     audio = load_audio(audio_path)
-                L = audio.shape[0] if audio.shape[0] <= 480000 else 480000  # max_length < 30s
+                L = (
+                    audio.shape[0] if audio.shape[0] <= 480000 else 480000
+                )  # max_length < 30s
                 mel_len = L // 160
                 audio = pad_or_trim(audio.flatten())
                 mel = log_mel_spectrogram(audio)
@@ -563,17 +671,16 @@ class QWenTokenizer(PreTrainedTokenizer):
                 audio_len = [audio_len_after_cnn, audio_token_num]
                 audios.append(mel)
                 audio_lens.append(audio_len)
-                audio_span_tokens.append(audio_token_num + 2)  # add audio bos eos
+                audio_span_tokens.append(
+                    audio_token_num + 2
+                )  # add audio bos eos
             input_audio_lengths = torch.IntTensor(audio_lens)
             input_audios = torch.stack(audios, dim=0)
-            return {
-
-
-
+            return {
+                "input_audios": input_audios,
+                "input_audio_lengths": input_audio_lengths,
+                "audio_span_tokens": audio_span_tokens,
+                "audio_urls": audio_urls,
+            }
         else:
             return None
-
-
-
-
-
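Taken together, a hedged end-to-end sketch of how the edited methods are meant to compose (the wav path and prompt are placeholders; tokenizer is assumed to be a QWenTokenizer loaded from this repo):

    query = tokenizer.from_list_format([
        {"audio": "assets/example.wav"},          # placeholder path
        {"text": "What is the speaker saying?"},  # placeholder prompt
    ])
    # process_audio() loads each URL/path, computes mel features, and records
    # how many tokens each audio span must occupy (audio_span_tokens).
    audio_info = tokenizer.process_audio(query)
    # tokenize() then fills every <audio>...</audio> span with audio_pad_tag
    # so its length matches the corresponding audio_span_tokens entry.
    tokens = tokenizer.tokenize(query, audio_info=audio_info)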