mingdali committed on
Commit
08c9928
·
verified ·
1 Parent(s): 921331d

Delete tokenization_qwen.py

Browse files
Files changed (1) hide show
  1. tokenization_qwen.py +0 -587
tokenization_qwen.py DELETED
@@ -1,587 +0,0 @@
1
- # Copyright (c) Alibaba Cloud.
2
- #
3
- # This source code is licensed under the license found in the
4
- # LICENSE file in the root directory of this source tree.
5
-
6
- """Tokenization classes for QWen."""
7
-
8
- import base64
9
- import logging
10
- import os
11
- import requests
12
- import unicodedata
13
- from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
14
-
15
- import tiktoken
16
- import numpy as np
17
- from PIL import Image
18
- from PIL import ImageFont
19
- from PIL import ImageDraw
20
- from transformers import PreTrainedTokenizer, AddedToken
21
- from transformers.utils import try_to_load_from_cache
22
-
23
- import matplotlib.colors as mcolors
24
- from matplotlib.font_manager import FontProperties
25
-
26
logger = logging.getLogger(__name__)


# File names this tokenizer ships with: the tiktoken vocabulary and the font
# used when rendering box labels.
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken", "ttf": "SimSun.ttf"}

# Look the font up in the local Hugging Face cache first.  The lookup may
# return ``None`` (not cached) or a non-string sentinel meaning "known not to
# exist"; neither is a usable path, so anything that is not a string falls
# back to a file name relative to the working directory.
FONT_PATH = try_to_load_from_cache("Qwen/Qwen-VL-Chat", "SimSun.ttf")
if not isinstance(FONT_PATH, str):
    FONT_PATH = "SimSun.ttf"

# Pre-tokenisation split pattern handed to tiktoken.
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# as the default behavior is changed to allow special tokens in
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple(f"<|extra_{i}|>" for i in range(205))
SPECIAL_TOKENS = (
    ENDOFTEXT,
    IMSTART,
    IMEND,
) + EXTRAS
# Fixed number of token slots reserved between the image start/end tags.
IMG_TOKEN_SPAN = 256
48
-
49
-
50
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
    """Read a tiktoken vocabulary file into a ``{token_bytes: rank}`` mapping.

    Each non-empty line holds a base64-encoded token and its integer rank,
    separated by whitespace.

    Args:
        tiktoken_bpe_file: Path of the vocabulary file.

    Returns:
        Mapping from the decoded token bytes to the rank.
    """
    with open(tiktoken_bpe_file, "rb") as handle:
        raw = handle.read()

    ranks = {}
    for entry in raw.splitlines():
        if not entry:
            continue  # skip empty lines
        encoded, rank = entry.split()
        ranks[base64.b64decode(encoded)] = int(rank)
    return ranks
57
-
58
def _list_find(
    input_list: List[Any],
    candidates: Tuple[Any],
    start: int = 0,
):
    """Return the first index at or after ``start`` whose item is in ``candidates``.

    Args:
        input_list: Sequence to scan.
        candidates: Values that count as a match.
        start: Index at which scanning begins.

    Returns:
        The index of the first match, or ``-1`` when there is none.
    """
    matches = (
        idx
        for idx in range(start, len(input_list))
        if input_list[idx] in candidates
    )
    return next(matches, -1)
67
-
68
def _replace_closed_tag(
    input_tokens: List[Any],
    start_tags: Union[Any, Tuple[Any]],
    end_tags: Union[Any, Tuple[Any]],
    inclusive_replace_func: Callable,
    exclusive_replace_func: Callable = lambda x: x,
):
    """Rewrite a token list span by span, split at matching start/end tags.

    Every ``start_tag ... end_tag`` span (tags included) is passed to
    ``inclusive_replace_func``; every stretch outside such spans is passed to
    ``exclusive_replace_func``.  The results are concatenated in order.

    Args:
        input_tokens: Tokens (surface forms or ids) to process.
        start_tags: One opening tag, or a tuple of them.
        end_tags: The closing tag(s); position ``i`` closes ``start_tags[i]``.
        inclusive_replace_func: Maps a tagged span to its replacement list.
        exclusive_replace_func: Maps an untagged stretch to its replacement
            list; identity by default.

    Returns:
        The rewritten token list.

    Raises:
        ValueError: If an opening tag has no matching closing tag after it.
    """
    # A lone tag (string or id) is promoted to a one-element tuple.
    start_tags = (start_tags,) if isinstance(start_tags, (str, int)) else start_tags
    end_tags = (end_tags,) if isinstance(end_tags, (str, int)) else end_tags
    assert len(start_tags) == len(end_tags)

    rewritten = []
    cursor = 0  # first token not yet consumed
    opening = _list_find(input_tokens, start_tags, cursor)
    while opening != -1:
        rewritten.extend(exclusive_replace_func(input_tokens[cursor:opening]))
        # Only the closing tag paired with this particular opening tag counts.
        wanted_end = end_tags[start_tags.index(input_tokens[opening])]
        closing = _list_find(input_tokens, (wanted_end,), opening)
        if closing == -1:
            raise ValueError("Unclosed image token")
        rewritten.extend(inclusive_replace_func(input_tokens[opening:closing + 1]))
        cursor = closing + 1
        opening = _list_find(input_tokens, start_tags, cursor)
    rewritten.extend(exclusive_replace_func(input_tokens[cursor:]))
    return rewritten
96
-
97
class QWenTokenizer(PreTrainedTokenizer):
    """QWen tokenizer.

    Wraps a ``tiktoken`` encoding built from a rank file plus a fixed set of
    special tokens (``SPECIAL_TOKENS`` followed by the image/ref/box/quad
    tags).  Regular tokens are ``bytes``; special tokens are ``str``.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        errors="replace",
        image_start_tag='<img>',
        image_end_tag='</img>',
        image_pad_tag='<imgpad>',
        ref_start_tag='<ref>',
        ref_end_tag='</ref>',
        box_start_tag='<box>',
        box_end_tag='</box>',
        quad_start_tag='<quad>',
        quad_end_tag='</quad>',
        **kwargs,
    ):
        """Build the tokenizer from a tiktoken rank file.

        Args:
            vocab_file: Path of the base64 rank file.
            errors: Error policy used when decoding token bytes as UTF-8.
            image_start_tag, image_end_tag, image_pad_tag: Surface forms of
                the image tags.
            ref_start_tag, ref_end_tag: Surface forms of the reference tags.
            box_start_tag, box_end_tag: Surface forms of the box tags.
            quad_start_tag, quad_end_tag: Surface forms of the quad tags.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.image_pad_tag = image_pad_tag
        self.ref_start_tag = ref_start_tag
        self.ref_end_tag = ref_end_tag
        self.box_start_tag = box_start_tag
        self.box_end_tag = box_end_tag
        self.quad_start_tag = quad_start_tag
        self.quad_end_tag = quad_end_tag
        self.IMAGE_ST = (
            ref_start_tag, ref_end_tag,
            box_start_tag, box_end_tag,
            quad_start_tag, quad_end_tag,
            image_start_tag, image_end_tag,
            image_pad_tag,
        )

        # The overridden ``_add_tokens`` reads ``self.IMAGE_ST``, so the tags
        # are set before the base constructor runs.
        # NOTE(review): newer transformers releases appear to call
        # ``_add_tokens`` from the base constructor - confirm against the
        # pinned version.
        super().__init__(**kwargs)

        self.errors = errors  # how to handle errors in decoding

        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: dict[bytes, int]
        # Special-token ids continue right after the last regular rank.
        self.special_tokens = {
            token: index
            for index, token in enumerate(
                SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
            )
        }
        self.img_start_id = self.special_tokens[self.image_start_tag]
        self.img_end_id = self.special_tokens[self.image_end_tag]
        self.img_pad_id = self.special_tokens[self.image_pad_tag]
        self.ref_start_id = self.special_tokens[self.ref_start_tag]
        self.ref_end_id = self.special_tokens[self.ref_end_tag]
        self.box_start_id = self.special_tokens[self.box_start_tag]
        self.box_end_id = self.special_tokens[self.box_end_tag]
        self.quad_start_id = self.special_tokens[self.quad_start_tag]
        self.quad_end_id = self.special_tokens[self.quad_end_tag]

        enc = self._build_encoding()
        assert (
            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
        ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"

        self.decoder = {
            v: k for k, v in self.mergeable_ranks.items()
        }  # type: dict[int, bytes|str]
        self.decoder.update({v: k for k, v in self.special_tokens.items()})

        self.tokenizer = enc  # type: tiktoken.Encoding

        self.eod_id = self.tokenizer.eot_token
        self.im_start_id = self.special_tokens[IMSTART]
        self.im_end_id = self.special_tokens[IMEND]

    def _build_encoding(self):
        """Create the tiktoken encoding from the stored ranks and special tokens."""
        return tiktoken.Encoding(
            "Qwen",
            pat_str=PAT_STR,
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )

    def __getstate__(self):
        """Return the instance state without the tiktoken encoding object."""
        # for pickle lovers
        state = self.__dict__.copy()
        del state['tokenizer']
        return state

    def __setstate__(self, state):
        """Restore the instance state and rebuild the tiktoken encoding."""
        # tokenizer is not python native; don't pass it; rebuild it
        self.__dict__.update(state)
        self.tokenizer = self._build_encoding()

    def __len__(self) -> int:
        """Return the size of the encoding's vocabulary."""
        return self.tokenizer.n_vocab

    def get_vocab(self) -> Dict[bytes, int]:
        """Return the regular-token rank table (special tokens are not included)."""
        return self.mergeable_ranks

    def convert_tokens_to_ids(
        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
    ) -> List[int]:
        """Map token surface forms to ids.

        A single ``str``/``bytes`` token yields a single id rather than a
        list; any other iterable yields a list of ids.  A token found in
        neither table maps to ``None``.
        """
        if isinstance(tokens, (str, bytes)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            return self.mergeable_ranks.get(tokens)
        return [
            self.special_tokens[token]
            if token in self.special_tokens
            else self.mergeable_ranks.get(token)
            for token in tokens
        ]

    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
        """Validate a request to add tokens; the vocabulary itself is fixed.

        Returns:
            Always ``0``: nothing is ever added.

        Raises:
            ValueError: If regular tokens are requested, or a special token is
                not one of the predefined surface forms.
        """
        if not special_tokens and new_tokens:
            raise ValueError('Adding regular tokens is not supported')
        known_forms = SPECIAL_TOKENS + self.IMAGE_ST
        for token in new_tokens:
            surface_form = token.content if isinstance(token, AddedToken) else token
            if surface_form not in known_forms:
                raise ValueError('Adding unknown special tokens is not supported')
        return 0

    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
        """
        Save only the vocabulary of the tokenizer (vocabulary).

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        file_path = os.path.join(save_directory, "qwen.tiktoken")
        with open(file_path, "w", encoding="utf8") as w:
            for k, v in self.mergeable_ranks.items():
                line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
                w.write(line)
        return (file_path,)

    def tokenize(
        self,
        text: str,
        allowed_special: Union[Set, str] = "all",
        disallowed_special: Union[Collection, str] = (),
        **kwargs,
    ) -> List[Union[bytes, str]]:
        """
        Converts a string in a sequence of tokens.

        Args:
            text (`str`):
                The sequence to be encoded.
            allowed_special (`Literal["all"]` or `set`):
                The surface forms of the tokens to be encoded as special tokens in regular texts.
                Default to "all".
            disallowed_special (`Literal["all"]` or `Collection`):
                The surface forms of the tokens that should not be in regular texts and trigger errors.
                Default to an empty tuple.

            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model specific encode method.

        Returns:
            `List[bytes|str]`: The list of tokens.

        Raises:
            ValueError: If the content between the image tags exceeds
                ``IMG_TOKEN_SPAN`` bytes, or an image start tag is not closed.
        """
        text = unicodedata.normalize("NFC", text)

        # this implementation takes a detour: text -> token id -> token surface forms
        tokens = [
            self.decoder[t]
            for t in self.tokenizer.encode(
                text, allowed_special=allowed_special, disallowed_special=disallowed_special
            )
        ]

        def _encode_imgurl(img_tokens):
            assert img_tokens[0] == self.image_start_tag and img_tokens[-1] == self.image_end_tag
            img_tokens = img_tokens[1:-1]
            img_url = b''.join(img_tokens)
            # Iterating bytes yields ints; each byte value is looked up as a
            # token id, so the payload becomes one token per byte.
            out_img_tokens = list(map(self.decoder.get, img_url))
            if len(out_img_tokens) > IMG_TOKEN_SPAN:
                raise ValueError("The content in {}..{} is too long".format(
                    self.image_start_tag, self.image_end_tag))
            # Pad so every image span occupies exactly IMG_TOKEN_SPAN slots.
            out_img_tokens.extend([self.image_pad_tag] * (IMG_TOKEN_SPAN - len(out_img_tokens)))
            out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag]
            return out_img_tokens

        return _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl)

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """
        Converts a sequence of tokens in a single string.

        Consecutive ``bytes`` tokens are buffered and decoded together as
        UTF-8 (using ``self.errors``); ``str`` tokens are appended verbatim.

        Raises:
            TypeError: If a token is neither ``bytes`` nor ``str``.
        """
        text = ""
        temp = b""
        for t in tokens:
            if isinstance(t, str):
                if temp:
                    text += temp.decode("utf-8", errors=self.errors)
                    temp = b""
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError("token should only be of type bytes or str")
        if temp:
            text += temp.decode("utf-8", errors=self.errors)
        return text

    @property
    def vocab_size(self):
        """Size of the encoding's vocabulary."""
        return self.tokenizer.n_vocab

    def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
        """Converts an id to a token, special tokens included"""
        if index in self.decoder:
            return self.decoder[index]
        raise ValueError("unknown ids")

    def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
        """Converts a token to an id using the vocab, special tokens included"""
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.mergeable_ranks:
            return self.mergeable_ranks[token]
        raise ValueError("unknown token")

    def _tokenize(self, text: str, **kwargs):
        """
        Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).

        Do NOT take care of added tokens.
        """
        raise NotImplementedError

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        errors: str = None,
        **kwargs,
    ) -> str:
        """Decode token ids to text, restoring image payloads first.

        Args:
            token_ids: One id or a list of ids.
            skip_special_tokens: Drop every id that is not below ``eod_id``.
            errors: UTF-8 error policy; falls back to ``self.errors``.

        Raises:
            ValueError: If an image start id has no matching end id.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        def _decode_imgurl(img_token_ids):
            assert img_token_ids[0] == self.img_start_id and img_token_ids[-1] == self.img_end_id
            payload = img_token_ids[1:-1]
            # A payload that fills the whole span carries no pad id, so only
            # truncate when one is actually present.
            if self.img_pad_id in payload:
                payload = payload[:payload.index(self.img_pad_id)]
            img_url = bytes(payload).decode('utf-8')
            return [self.img_start_id] + self.tokenizer.encode(img_url) + [self.img_end_id]

        token_ids = _replace_closed_tag(token_ids, self.img_start_id, self.img_end_id, _decode_imgurl)

        if skip_special_tokens:
            token_ids = [i for i in token_ids if i < self.eod_id]
        return self.tokenizer.decode(token_ids, errors=errors or self.errors)

    def to_list_format(self, text: str):
        """Split ``text`` into a list of single-key dicts.

        Tagged spans become ``{'image': ...}``, ``{'ref': ...}``,
        ``{'box': ...}`` or ``{'quad': ...}`` (tags stripped); everything else
        becomes ``{'text': ...}``.

        Raises:
            ValueError: If an opening tag has no matching closing tag.
        """
        text = unicodedata.normalize("NFC", text)
        token_ids = self.tokenizer.encode(
            text, allowed_special=set(self.IMAGE_ST + (ENDOFTEXT,)))

        def _tobytes(x):
            return x.encode('utf-8') if isinstance(x, str) else x

        def _join(tokens):
            # ids -> surface forms -> bytes -> one UTF-8 string
            return b''.join(map(_tobytes, map(self.decoder.get, tokens))).decode('utf-8')

        def _encode_vl_info(tokens):
            if len(tokens) == 0:
                return []
            if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
                key = 'image'
            elif tokens[0] == self.ref_start_id and tokens[-1] == self.ref_end_id:
                key = 'ref'
            elif tokens[0] == self.box_start_id and tokens[-1] == self.box_end_id:
                key = 'box'
            elif tokens[0] == self.quad_start_id and tokens[-1] == self.quad_end_id:
                key = 'quad'
            else:
                return [{'text': _join(tokens)}]
            return [{key: _join(tokens[1:-1])}]

        return _replace_closed_tag(
            token_ids,
            (self.img_start_id, self.ref_start_id, self.box_start_id, self.quad_start_id),
            (self.img_end_id, self.ref_end_id, self.box_end_id, self.quad_end_id),
            _encode_vl_info,
            _encode_vl_info,
        )

    def from_list_format(self, list_format: List[Dict]):
        """Render a list of ``image`` / ``text`` / ``box`` dicts as tagged text.

        Images are numbered ``Picture N:`` in order of appearance.  A ``box``
        element may carry a ``ref`` label, which is emitted before its boxes.

        Raises:
            ValueError: If an element has none of the supported keys.
        """
        text = ''
        num_images = 0
        for ele in list_format:
            if 'image' in ele:
                num_images += 1
                text += f'Picture {num_images}:'
                text += self.image_start_tag + ele['image'] + self.image_end_tag
                text += '\n'
            elif 'text' in ele:
                text += ele['text']
            elif 'box' in ele:
                if 'ref' in ele:
                    text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
                for box in ele['box']:
                    text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
            else:
                raise ValueError("Unsupport element: " + str(ele))
        return text

    def _fetch_latest_picture(self, response, history):
        """Return the most recent image reference in ``response`` or ``history``.

        ``history`` is a list of pairs; only the first item of each pair is
        searched, newest first, with ``response`` searched before any of them.
        Returns ``None`` when no image is found.
        """
        if history is None:
            history = []
        _history = history + [(response, None)]
        for q, r in _history[::-1]:
            for ele in self.to_list_format(q)[::-1]:
                if 'image' in ele:
                    return ele['image']
        return None

    def _fetch_all_box_with_ref(self, text):
        """Collect every box in ``text`` as ``{'box': (x1, y1, x2, y2)}``.

        When the element immediately before a box is a ``ref``, its stripped
        text is attached under the ``'ref'`` key.
        """
        list_format = self.to_list_format(text)
        output = []
        for i, ele in enumerate(list_format):
            if 'box' in ele:
                bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
                assert len(bbox) == 4
                output.append({'box': bbox})
                if i > 0 and 'ref' in list_format[i-1]:
                    output[-1]['ref'] = list_format[i-1]['ref'].strip()
        return output

    def draw_bbox_on_latest_picture(
        self,
        response,
        history=None,
    ) -> Optional["VisImage"]:
        """Draw the boxes found in ``response`` on the latest referenced image.

        The image reference is loaded over HTTP(S) when it starts with
        ``http://`` or ``https://`` and from the local path otherwise.  Box
        coordinates are divided by 1000 and scaled by the image width/height.

        Returns:
            The visualizer's output canvas, or ``None`` when there is no image
            or no box.
        """
        image = self._fetch_latest_picture(response, history)
        if image is None:
            return None
        if image.startswith("http://") or image.startswith("https://"):
            image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
            h, w = image.height, image.width
        else:
            image = np.asarray(Image.open(image).convert("RGB"))
            h, w = image.shape[0], image.shape[1]
        visualizer = Visualizer(image)

        boxes = self._fetch_all_box_with_ref(response)
        if not boxes:
            return None
        color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()])  # init color
        for box in boxes:
            if 'ref' in box:  # random new color for new refexps
                color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()])
            x1, y1, x2, y2 = box['box']
            x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
            visualizer.draw_box((x1, y1, x2, y2), alpha=1, edge_color=color)
            if 'ref' in box:
                visualizer.draw_text(box['ref'], (x1, y1), color=color, horizontal_alignment="left")
        return visualizer.output
459
-
460
-
461
- import colorsys
462
- import logging
463
- import math
464
- import numpy as np
465
- import matplotlib as mpl
466
- import matplotlib.colors as mplc
467
- import matplotlib.figure as mplfigure
468
- import torch
469
- from matplotlib.backends.backend_agg import FigureCanvasAgg
470
- from PIL import Image
471
- import random
472
-
473
- logger = logging.getLogger(__name__)
474
-
475
-
476
class VisImage:
    """A matplotlib figure sized to an image, used as a drawing canvas."""

    def __init__(self, img, scale=1.0):
        """Store the image and scale, then build the figure around it.

        Args:
            img: Image array; ``shape[0]`` is taken as the height and
                ``shape[1]`` as the width.
            scale: Factor applied to the figure size.
        """
        self.img = img
        self.scale = scale
        self.width = img.shape[1]
        self.height = img.shape[0]
        self._setup_figure(img)

    def _setup_figure(self, img):
        """Create a frameless figure with one full-size axes showing ``img``."""
        figure = mplfigure.Figure(frameon=False)
        self.dpi = figure.get_dpi()
        # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
        # (https://github.com/matplotlib/matplotlib/issues/15363)
        width_inches = (self.width * self.scale + 1e-2) / self.dpi
        height_inches = (self.height * self.scale + 1e-2) / self.dpi
        figure.set_size_inches(width_inches, height_inches)
        self.canvas = FigureCanvasAgg(figure)
        # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
        axes = figure.add_axes([0.0, 0.0, 1.0, 1.0])
        axes.axis("off")
        self.fig = figure
        self.ax = axes
        self.reset_image(img)

    def reset_image(self, img):
        """Show ``img`` (cast to uint8) on the axes, origin at the top left."""
        as_uint8 = img.astype("uint8")
        self.ax.imshow(
            as_uint8,
            extent=(0, self.width, self.height, 0),
            interpolation="nearest",
        )

    def save(self, filepath):
        """Write the figure to ``filepath``."""
        self.fig.savefig(filepath)

    def get_image(self):
        """Render the canvas and return its colour channels as a uint8 array."""
        raw, (width, height) = self.canvas.print_to_buffer()
        rgba = np.frombuffer(raw, dtype="uint8").reshape(height, width, 4)
        # Keep the first three channels; the fourth (alpha) is dropped.
        return rgba[:, :, :3].astype("uint8")
516
-
517
-
518
class Visualizer:
    """Draws text labels and rectangles on top of an image via ``VisImage``."""

    def __init__(self, img_rgb, metadata=None, scale=1.0):
        """Prepare the canvas and the default font size.

        Args:
            img_rgb: Image data; converted to an array, clipped to 0..255 and
                cast to uint8.
            metadata: Accepted but not used by this class.
            scale: Scale factor passed on to the canvas.
        """
        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
        self.font_path = FONT_PATH
        self.output = VisImage(self.img, scale=scale)
        self.cpu_device = torch.device("cpu")

        # too small texts are useless, so the size never drops below 15 // scale
        size_from_area = np.sqrt(self.output.height * self.output.width) // 30
        self._default_font_size = max(size_from_area, 15 // scale)

    def draw_text(
        self,
        text,
        position,
        *,
        font_size=None,
        color="g",
        horizontal_alignment="center",
        rotation=0,
    ):
        """Draw ``text`` on a dark background box at ``position``.

        Args:
            text: String to draw.
            position: ``(x, y)`` anchor; the text is top-aligned to it.
            font_size: Font size; the default size is used when falsy.
            color: Any matplotlib colour; it is brightened before use.
            horizontal_alignment: Horizontal alignment relative to ``x``.
            rotation: Rotation passed to matplotlib.

        Returns:
            The output canvas.
        """
        font_size = font_size or self._default_font_size

        # since the text background is dark, we don't want the text to be dark
        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
        brightest = np.argmax(color)
        color[brightest] = max(0.8, np.max(color))

        x, y = position
        background = {"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            fontproperties=FontProperties(fname=self.font_path),
            bbox=background,
            verticalalignment="top",
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,
            rotation=rotation,
        )
        return self.output

    def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
        """Draw an unfilled rectangle.

        Args:
            box_coord: ``(x0, y0, x1, y1)`` corner coordinates.
            alpha: Opacity of the outline.
            edge_color: Outline colour.
            line_style: Matplotlib line style of the outline.

        Returns:
            The output canvas.
        """
        x0, y0, x1, y1 = box_coord
        # Outline thickness follows the default font size, at least 1.
        linewidth = max(self._default_font_size / 4, 1)

        rectangle = mpl.patches.Rectangle(
            (x0, y0),
            x1 - x0,
            y1 - y0,
            fill=False,
            edgecolor=edge_color,
            linewidth=linewidth * self.output.scale,
            alpha=alpha,
            linestyle=line_style,
        )
        self.output.ax.add_patch(rectangle)
        return self.output

    def get_output(self):
        """Return the output canvas."""
        return self.output