Zhiding committed
Commit a5f8592 · 1 Parent(s): 471b572
README.md ADDED
@@ -0,0 +1,576 @@
1
+ ---
2
+ license: mit
3
+ pipeline_tag: image-text-to-text
4
+ library_name: transformers
5
+ base_model:
6
+ - google/paligemma-3b-mix-448
7
+ - Qwen/Qwen2.5-0.5B-Instruct
8
+ base_model_relation: merge
9
+ language:
10
+ - multilingual
11
+ tags:
12
+ - eagle
13
+ - VLM
14
+ ---
15
+
16
+
17
+ # Eagle-2
18
+
19
+ [\[📂 GitHub\]](https://github.com/NVlabs/EAGLE) [\[📜 Eagle 2\]](TODO)
20
+ [\[🗨️ Chat Demo\]](http://eagle-vlm.xyz/) [\[🤗 HF Demo\]](TODO)
21
+ ## Introduction
22
+
23
+ We are thrilled to release the latest Vision-Language Models in our Eagle2 series. Open-source Vision-Language Models (VLMs) have made significant strides in narrowing the gap with proprietary models. However, critical details about data strategies and implementation are often missing, limiting reproducibility and innovation. In this project, we focus on VLM post-training from a data-centric perspective, sharing insights into building effective data strategies from scratch. By combining these strategies with robust training recipes and model design, we introduce Eagle 2, a family of performant VLMs. Our work aims to empower the open-source community to develop competitive VLMs with transparent processes.
24
+
25
+
26
+
27
+ In this repo, we are open-sourcing Eagle2-1B, the lightest model in the family, which balances strong performance with fast, low-cost inference.
28
+
29
+
30
+
31
+
32
+
33
+
34
+
35
+
36
+
37
+ ## Model Zoo
38
+ We provide the following models:
39
+
40
+ | model name | LLM | Vision | Max Length| HF Link|
41
+ | ----------- | ------- |---------|-|-|
42
+ | Eagle2-1B | [Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) | Siglip | 16K| [🤗 link](https://huggingface.co/NVIDIA/Eagle2-1B)|
43
+ | Eagle2-2B | [Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | Siglip | 16K| [🤗 link](https://huggingface.co/NVIDIA/Eagle2-2B)|
44
+ | Eagle2-9B | [Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) | Siglip+ConvNext | 16K| [🤗 link](https://huggingface.co/NVIDIA/Eagle2-9B)|
45
+ | Eagle2-32B | [Qwen2.5-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-32B-Instruct) | Siglip+ConvNext | 16K| [🤗 link](https://huggingface.co/NVIDIA/Eagle2-32B)|
46
+
47
+ ## Benchmark Results
48
+ | Benchmark | LLaVA-OneVision-0.5B | InternVL2-1B | InternVL2.5-1B | Qwen2-VL-2B | Eagle2-1B |
49
+ | :--------------------------: | :------------------: | :----------------: | :----------: |:----------: |:----------: |
50
+ | DocVQA<sub>test</sub> | 70.0 | 81.7 | 84.8 |90.1|81.8|
51
+ | ChartQA<sub>test</sub> | 61.4 | 72.9 | 75.9 |73.0|77.0|
52
+ | InfoVQA<sub>test</sub> | 41.8 | 50.9 | 56.0 |65.5|54.8|
53
+ | TextVQA<sub>val</sub> | - | 70.0 | 72.0 |79.7|76.6|
54
+ | OCRBench | 565 | 754 | 785 |809|767|
55
+ | MME<sub>sum</sub> | 1438.0 | 1794.4 | 1950.5 | 1872.0| 1790.2|
56
+ | RealWorldQA | 55.6 | 50.3 | 57.5 |62.6|55.4|
57
+ | AI2D<sub>test</sub> | 57.1 | 64.1 | 69.3 | 74.7 |70.9|
58
+ | MMMU<sub>val</sub> | 31.4 | 36.7 | 40.9 |41.1|38.8|
59
+ | MMVet<sub>GPT-4-Turbo</sub> | 32.2 | 32.7 | 48.8 | 49.5 | 40.9 |
+ | HallBench<sub>avg</sub> | 27.9 | 34.0 | 39.0 | **41.7** | 35.3 |
60
+ | MathVista<sub>testmini</sub> | 3.8 | 37.7 | 43.2 |43.0|45.3|
61
+ | MMstar | 37.7 | 45.7 | 50.1|48.0|48.5|
62
+
63
+
64
+
65
+ ## Quick Start
66
+
67
+
68
+
69
+ We provide a [demo inference script](./demo.py) to help you quickly start using the model. We support different input types:
70
+ - pure text input
71
+ - single image input
72
+ - multiple image input
73
+ - video input
74
+
75
+ ### 0. Install the dependencies
76
+
77
+ ```bash
78
+ pip install transformers==4.37.2
79
+ pip install flash-attn
80
+ ```
81
+ **Note**: The latest version of transformers is not compatible with the model; please install the pinned version above.
82
+
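+ If you just want a quick sanity check that the checkpoint and its remote code load correctly, the minimal sketch below mirrors the calls used by the model worker in the next section (it assumes the `NVIDIA/Eagle2-1B` checkpoint from the Model Zoo table and a CUDA device):
+
+ ```python
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ model_path = 'NVIDIA/Eagle2-1B'
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
+ model = AutoModel.from_pretrained(
+     model_path,
+     torch_dtype=torch.bfloat16,
+     trust_remote_code=True,
+ ).eval().cuda()
+ ```
+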
83
+ ### 1. Prepare the Model worker
84
+
85
+ <details>
86
+ <summary>Click to expand</summary>
87
+
88
+ ```python
89
+
90
+ """
91
+ A model worker executes the model.
92
+ Copied and modified from https://github.com/OpenGVLab/InternVL/blob/main/streamlit_demo/model_worker.py
93
+ """
94
+ # Importing torch before transformers can cause `segmentation fault`
95
+ from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer, AutoConfig
96
+
97
+ import argparse
98
+ import base64
99
+ import json
100
+ import os
101
+ import decord
102
+ import threading
103
+ import time
104
+ from io import BytesIO
105
+ from threading import Thread
106
+ import math
107
+ import requests
108
+ import torch
109
+ import torchvision.transforms as T
110
+ from PIL import Image
111
+ from torchvision.transforms.functional import InterpolationMode
112
+ import numpy as np
113
+
114
+
115
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
116
+ IMAGENET_STD = (0.229, 0.224, 0.225)
117
+
118
+ SIGLIP_MEAN = (0.5, 0.5, 0.5)
119
+ SIGLIP_STD = (0.5, 0.5, 0.5)
120
+
121
+
122
+ def get_seq_frames(total_num_frames, desired_num_frames=-1, stride=-1):
123
+ """
124
+ Calculate the indices of frames to extract from a video.
125
+
126
+ Parameters:
127
+ total_num_frames (int): Total number of frames in the video.
128
+ desired_num_frames (int): Desired number of frames to extract.
129
+
130
+ Returns:
131
+ list: List of indices of frames to extract.
132
+ """
133
+
134
+ assert (desired_num_frames > 0) != (stride > 0), 'exactly one of desired_num_frames or stride must be specified'
135
+
136
+ if stride > 0:
137
+ return list(range(0, total_num_frames, stride))
138
+
139
+ # Calculate the size of each segment from which a frame will be extracted
140
+ seg_size = float(total_num_frames - 1) / desired_num_frames
141
+
142
+ seq = []
143
+ for i in range(desired_num_frames):
144
+ # Calculate the start and end indices of each segment
145
+ start = int(np.round(seg_size * i))
146
+ end = int(np.round(seg_size * (i + 1)))
147
+
148
+ # Append the middle index of the segment to the list
149
+ seq.append((start + end) // 2)
150
+
151
+ return seq
152
+
153
+ def build_video_prompt(meta_list, num_frames, time_position=False):
154
+ # if time_position is True, the frame_timestamp is used.
155
+ # 1. pass time_position, 2. use env TIME_POSITION
156
+ time_position = os.environ.get("TIME_POSITION", time_position)
157
+ prefix = f"This is a video:\n"
158
+ for i in range(num_frames):
159
+ if time_position:
160
+ frame_txt = f"Frame {i+1} sampled at {meta_list[i]:.2f} seconds: <image>\n"
161
+ else:
162
+ frame_txt = f"Frame {i+1}: <image>\n"
163
+ prefix += frame_txt
164
+ return prefix
165
+
166
+ def load_video(video_path, num_frames=64, frame_cache_root=None):
167
+ if isinstance(video_path, str):
168
+ video = decord.VideoReader(video_path)
169
+ elif isinstance(video_path, dict):
170
+ assert False, 'we do not support video input given as a dict; "video_path" must be a file path'
171
+ fps = video.get_avg_fps()
172
+ sampled_frames = get_seq_frames(len(video), num_frames)
173
+ sampled_timestamps = [i / fps for i in sampled_frames]
174
+ frames = video.get_batch(sampled_frames).asnumpy()
175
+ images = [Image.fromarray(frame) for frame in frames]
176
+
177
+ return images, build_video_prompt(sampled_timestamps, len(images), time_position=True)
178
+
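+ # Accept either a path to an image file on disk, or a dict carrying one of:
+ # 'disk_path', 'base64', 'url', or 'bytes'.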
179
+ def load_image(image):
180
+ if isinstance(image, str) and os.path.exists(image):
181
+ return Image.open(image)
182
+ elif isinstance(image, dict):
183
+ if 'disk_path' in image:
184
+ return Image.open(image['disk_path'])
185
+ elif 'base64' in image:
186
+ return Image.open(BytesIO(base64.b64decode(image['base64'])))
187
+ elif 'url' in image:
188
+ response = requests.get(image['url'])
189
+ return Image.open(BytesIO(response.content))
190
+ elif 'bytes' in image:
191
+ return Image.open(BytesIO(image['bytes']))
192
+ else:
193
+ raise ValueError(f'Invalid image: {image}')
194
+ else:
195
+ raise ValueError(f'Invalid image: {image}')
196
+
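+ # Square-resize + normalization pipeline; the normalization statistics depend on the
+ # vision backbone (SigLIP-style 0.5/0.5 stats vs. ImageNet stats).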
197
+ def build_transform(input_size, norm_type='imagenet'):
198
+ if norm_type == 'imagenet':
199
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
200
+ elif norm_type == 'siglip':
201
+ MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
202
+
203
+ transform = T.Compose([
204
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
205
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
206
+ T.ToTensor(),
207
+ T.Normalize(mean=MEAN, std=STD)
208
+ ])
209
+ return transform
210
+
211
+
212
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
213
+ """
214
+ The previous version mainly focused on the aspect ratio.
215
+ We also consider area ratio here.
216
+ """
217
+ best_factor = float('-inf')
218
+ best_ratio = (1, 1)
219
+ area = width * height
220
+ for ratio in target_ratios:
221
+ target_aspect_ratio = ratio[0] / ratio[1]
222
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
223
+ area_ratio = (ratio[0]*ratio[1]*image_size*image_size)/ area
224
+ """
225
+ Covering more than 60% of the original image area is considered enough; the area term is capped at 0.6.
226
+ """
227
+ factor_based_on_area_n_ratio = min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6)* \
228
+ min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)
229
+
230
+ if factor_based_on_area_n_ratio > best_factor:
231
+ best_factor = factor_based_on_area_n_ratio
232
+ best_ratio = ratio
233
+
234
+ return best_ratio
235
+
236
+
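+ # Tile an arbitrary-resolution image into up to `max_num` squares of side `image_size`,
+ # using find_closest_aspect_ratio to pick the tiling grid that best matches the input's
+ # aspect ratio and area. If `use_thumbnail` is set and more than one tile is produced,
+ # a downscaled copy of the full image is appended as a global-view tile.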
237
+ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
238
+ orig_width, orig_height = image.size
239
+ aspect_ratio = orig_width / orig_height
240
+
241
+ # calculate the existing image aspect ratio
242
+ target_ratios = set(
243
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
244
+ i * j <= max_num and i * j >= min_num)
245
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
246
+
247
+ # find the closest aspect ratio to the target
248
+ target_aspect_ratio = find_closest_aspect_ratio(
249
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
250
+
251
+ # calculate the target width and height
252
+ target_width = image_size * target_aspect_ratio[0]
253
+ target_height = image_size * target_aspect_ratio[1]
254
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
255
+
256
+ # resize the image
257
+ resized_img = image.resize((target_width, target_height))
258
+ processed_images = []
259
+ for i in range(blocks):
260
+ box = (
261
+ (i % (target_width // image_size)) * image_size,
262
+ (i // (target_width // image_size)) * image_size,
263
+ ((i % (target_width // image_size)) + 1) * image_size,
264
+ ((i // (target_width // image_size)) + 1) * image_size
265
+ )
266
+ # split the image
267
+ split_img = resized_img.crop(box)
268
+ processed_images.append(split_img)
269
+ assert len(processed_images) == blocks
270
+ if use_thumbnail and len(processed_images) != 1:
271
+ thumbnail_img = image.resize((image_size, image_size))
272
+ processed_images.append(thumbnail_img)
273
+ return processed_images
274
+
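+ # Build a HF-style `device_map` that spreads the LLM layers across all visible GPUs while
+ # pinning the vision tower, projector (mlp1), embeddings, norm, lm_head and the final LLM
+ # layer to `device`; used below for large checkpoints that do not fit on a single GPU.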
275
+ def split_model(model_path, device):
276
+
277
+ device_map = {}
278
+ world_size = torch.cuda.device_count()
279
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
280
+ num_layers = config.llm_config.num_hidden_layers
281
+
282
+ print('world_size', world_size)
283
+ num_layers_per_gpu_ = math.floor(num_layers / (world_size - 1))
284
+ num_layers_per_gpu = [num_layers_per_gpu_] * world_size
285
+ num_layers_per_gpu[device] = num_layers - num_layers_per_gpu_ * (world_size-1)
286
+ print(num_layers_per_gpu)
287
+ layer_cnt = 0
288
+ for i, num_layer in enumerate(num_layers_per_gpu):
289
+ for j in range(num_layer):
290
+ device_map[f'language_model.model.layers.{layer_cnt}'] = i
291
+ layer_cnt += 1
292
+ device_map['vision_model'] = device
293
+ device_map['mlp1'] = device
294
+ device_map['language_model.model.tok_embeddings'] = device
295
+ device_map['language_model.model.embed_tokens'] = device
296
+ device_map['language_model.output'] = device
297
+ device_map['language_model.model.norm'] = device
298
+ device_map['language_model.lm_head'] = device
299
+ device_map['language_model.model.rotary_emb'] = device
300
+ device_map[f'language_model.model.layers.{num_layers - 1}'] = device
301
+ return device_map
302
+
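+ # Loads the tokenizer and the Eagle2 chat model (optionally in 8-bit and/or sharded across
+ # GPUs via split_model) and exposes generate(), which turns a chat-style `params` dict into
+ # a model.chat() call.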
303
+ class ModelWorker:
304
+ def __init__(self, model_path, model_name,
305
+ load_8bit, device):
306
+
307
+ if model_path.endswith('/'):
308
+ model_path = model_path[:-1]
309
+ if model_name is None:
310
+ model_paths = model_path.split('/')
311
+ if model_paths[-1].startswith('checkpoint-'):
312
+ self.model_name = model_paths[-2] + '_' + model_paths[-1]
313
+ else:
314
+ self.model_name = model_paths[-1]
315
+ else:
316
+ self.model_name = model_name
317
+
318
+ print(f'Loading the model {self.model_name}')
319
+
320
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
321
+ tokens_to_keep = ['<box>', '</box>', '<ref>', '</ref>']
322
+ tokenizer.additional_special_tokens = [item for item in tokenizer.additional_special_tokens if item not in tokens_to_keep]
323
+ self.tokenizer = tokenizer
324
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
325
+ model_type = config.vision_config.model_type
326
+ self.device = torch.cuda.current_device()
327
+ if model_type == 'siglip_vision_model':
328
+ self.norm_type = 'siglip'
329
+ elif model_type == 'MOB':
330
+ self.norm_type = 'siglip'
331
+ else:
332
+ self.norm_type = 'imagenet'
333
+
334
+ if any(x in model_path.lower() for x in ['34b']):
335
+ device_map = split_model(model_path, self.device)
336
+ else:
337
+ device_map = None
338
+
339
+ if device_map is not None:
340
+ self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
341
+ low_cpu_mem_usage=True,
342
+ device_map=device_map,
343
+ trust_remote_code=True,
344
+ load_in_8bit=load_8bit).eval()
345
+ else:
346
+ self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
347
+ trust_remote_code=True,
348
+ load_in_8bit=load_8bit).eval()
349
+
350
+ if not load_8bit and device_map is None:
351
+ self.model = self.model.to(device)
352
+ self.load_8bit = load_8bit
353
+
354
+ self.model_path = model_path
355
+ self.image_size = self.model.config.force_image_size
356
+ self.context_len = tokenizer.model_max_length
357
+ self.per_tile_len = 256
358
+
359
+ def reload_model(self):
360
+ del self.model
361
+ torch.cuda.empty_cache()
362
+ if self.device == 'auto':
363
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
364
+ # This can make distributed deployment work properly
365
+ self.model = AutoModel.from_pretrained(
366
+ self.model_path,
367
+ load_in_8bit=self.load_8bit,
368
+ torch_dtype=torch.bfloat16,
369
+ device_map=self.device_map,
370
+ trust_remote_code=True).eval()
371
+ else:
372
+ self.model = AutoModel.from_pretrained(
373
+ self.model_path,
374
+ load_in_8bit=self.load_8bit,
375
+ torch_dtype=torch.bfloat16,
376
+ trust_remote_code=True).eval()
377
+ if not self.load_8bit and not self.device == 'auto':
378
+ self.model = self.model.cuda()
379
+
380
+ @torch.inference_mode()
381
+ def generate(self, params):
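+ # `params`: 'prompt' is a chat-style message list whose first entry is the system message;
+ # the remaining keys control decoding ('max_input_tiles', 'temperature', 'top_p',
+ # 'max_new_tokens', 'repetition_penalty'), plus an optional 'video_frame_num' (default 64).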
382
+ system_message = params['prompt'][0]['content']
383
+ send_messages = params['prompt'][1:]
384
+ max_input_tiles = params['max_input_tiles']
385
+ temperature = params['temperature']
386
+ top_p = params['top_p']
387
+ max_new_tokens = params['max_new_tokens']
388
+ repetition_penalty = params['repetition_penalty']
389
+ video_frame_num = params.get('video_frame_num', 64)
390
+ do_sample = True if temperature > 0.0 else False
391
+
392
+ global_image_cnt = 0
393
+ history, pil_images, max_input_tile_list = [], [], []
394
+ for message in send_messages:
395
+ if message['role'] == 'user':
396
+ prefix = ''
397
+ if 'image' in message:
398
+ for image_data in message['image']:
399
+ pil_images.append(load_image(image_data))
400
+ prefix = prefix + f'<image {global_image_cnt + 1}><image>\n'
401
+ global_image_cnt += 1
402
+ max_input_tile_list.append(max_input_tiles)
403
+ if 'video' in message:
404
+ for video_data in message['video']:
405
+ video_frames, tmp_prefix = load_video(video_data, num_frames=video_frame_num)
406
+ pil_images.extend(video_frames)
407
+ prefix = prefix + tmp_prefix
408
+ global_image_cnt += len(video_frames)
409
+ max_input_tile_list.extend([1] * len(video_frames))
410
+ content = prefix + message['content']
411
+ history.append([content, ])
412
+ else:
413
+ history[-1].append(message['content'])
414
+ question, history = history[-1][0], history[:-1]
415
+
416
+ if global_image_cnt == 1:
417
+ question = question.replace('<image 1><image>\n', '<image>\n')
418
+ history = [[item[0].replace('<image 1><image>\n', '<image>\n'), item[1]] for item in history]
419
+
420
+
421
+ try:
422
+ assert len(max_input_tile_list) == len(pil_images), 'The number of max_input_tile_list and pil_images should be the same.'
423
+ except Exception as e:
+ print(f'Error: {e}')
+ print(f'max_input_tile_list: {max_input_tile_list}, pil_images: {pil_images}')
+ raise e
429
+
430
+ old_system_message = self.model.system_message
431
+ self.model.system_message = system_message
432
+
433
+ transform = build_transform(input_size=self.image_size, norm_type=self.norm_type)
434
+ if len(pil_images) > 0:
435
+ max_input_tiles_limited_by_context = params['max_input_tiles']
436
+ while True:
437
+ image_tiles = []
438
+ for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
439
+ if self.model.config.dynamic_image_size:
440
+ tiles = dynamic_preprocess(
441
+ pil_image, image_size=self.image_size, max_num=min(current_max_input_tiles, max_input_tiles_limited_by_context),
442
+ use_thumbnail=self.model.config.use_thumbnail)
443
+ else:
444
+ tiles = [pil_image]
445
+ image_tiles += tiles
446
+ if (len(image_tiles) * self.per_tile_len < self.context_len):
447
+ break
448
+ else:
449
+ max_input_tiles_limited_by_context -= 2
450
+
451
+ if max_input_tiles_limited_by_context < 1:
452
+ break
453
+
454
+ pixel_values = [transform(item) for item in image_tiles]
455
+ pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
456
+ print(f'Split images to {pixel_values.shape}')
457
+ else:
458
+ pixel_values = None
459
+
460
+ generation_config = dict(
461
+ num_beams=1,
462
+ max_new_tokens=max_new_tokens,
463
+ do_sample=do_sample,
464
+ temperature=temperature,
465
+ repetition_penalty=repetition_penalty,
466
+ max_length=self.context_len,
467
+ top_p=top_p,
468
+ )
469
+
470
+ response = self.model.chat(
471
+ tokenizer=self.tokenizer,
472
+ pixel_values=pixel_values,
473
+ question=question,
474
+ history=history,
475
+ return_history=False,
476
+ generation_config=generation_config,
477
+ )
478
+ self.model.system_message = old_system_message
479
+ return {'text': response, 'error_code': 0}
480
+
481
+
482
+
483
+
484
+
485
+ if __name__ == '__main__':
486
+ parser = argparse.ArgumentParser()
487
+ parser.add_argument('--model-path', type=str, default='NVIDIA/Eagle2-1B')
488
+ parser.add_argument('--model-name', type=str, default='Eagle2-1B')
489
+ parser.add_argument('--device', type=str, default='cuda')
490
+ parser.add_argument('--load-8bit', action='store_true')
491
+ args = parser.parse_args()
492
+ print(f'args: {args}')
493
+
494
+ worker = ModelWorker(
495
+ args.model_path,
496
+ args.model_name,
497
+ args.load_8bit,
498
+ args.device)
499
+ ```
500
+ </details>
501
+
502
+
503
+ ### 2. Prepare the Prompt
504
+
505
+ - Single image input
506
+ ```python
507
+ prompt = [
508
+ {'role': 'system', 'content': 'You are a helpful assistant.'},
509
+ {'role': 'user', 'content': 'Describe this image in details.',
510
+ 'image':[
511
+ {'url': 'https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]'}
512
+ ],
513
+ }
514
+ ]
515
+ ```
516
+
517
+ - Multiple image input
518
+ ```python
519
+ prompt = [
520
+ {'role': 'system', 'content': 'You are a helpful assistant.'},
521
+ {'role': 'user', 'content': 'Describe these two images in details.',
522
+ 'image':[
523
+ {'url': 'https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]'},
524
+ {'url': 'https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]'}
525
+ ],
526
+ }
527
+ ]
528
+ ```
529
+
530
+ - Video input
531
+ ```python
532
+ prompt = [
533
+ {'role': 'system', 'content': 'You are a helpful assistant.'},
534
+ {'role': 'user', 'content': 'Describe this video in details.',
535
+ 'video':[
536
+ 'path/to/your/video.mp4'
537
+ ],
538
+ }
539
+ ]
540
+ ```
541
+
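+ - Pure text input (a minimal example: per the worker code above, a user message without an `image`/`video` field is passed through as plain text)
+ ```python
+ prompt = [
+     {'role': 'system', 'content': 'You are a helpful assistant.'},
+     {'role': 'user', 'content': 'Briefly introduce the Eagle2 series of vision-language models.'}
+ ]
+ ```
+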
542
+ ### 3. Generate the response
543
+ ```python
544
+ params = {
545
+ 'prompt': prompt,
546
+ 'max_input_tiles': 24,
547
+ 'temperature': 0.7,
548
+ 'top_p': 1.0,
549
+ 'max_new_tokens': 4096,
550
+ 'repetition_penalty': 1.0,
551
+ }
552
+ worker.generate(params)
553
+ ```
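+
+ For video prompts, the number of sampled frames can optionally be controlled with `video_frame_num` (`ModelWorker.generate` defaults to 64), and setting `temperature` to 0 switches the worker to greedy decoding. A minimal sketch reusing the video prompt from step 2:
+ ```python
+ params = {
+     'prompt': prompt,            # the video prompt from step 2
+     'max_input_tiles': 24,
+     'temperature': 0.0,          # 0.0 -> do_sample=False (greedy decoding)
+     'top_p': 1.0,
+     'max_new_tokens': 4096,
+     'repetition_penalty': 1.0,
+     'video_frame_num': 32,       # optional; defaults to 64
+ }
+ worker.generate(params)
+ ```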
554
+
555
+ ## TODO
556
+ - [ ] Support vLLM Inference
557
+ - [ ] Provide AWQ Quantization Weights
558
+ - [ ] Provide fine-tuning scripts
559
+
560
+
561
+ ## License/Terms of Use
562
+ - The code is released under the Apache 2.0 license as found in the [LICENSE](https://huggingface.co/NVEagle/Eagle-X5-13B-Chat/blob/main/LICENSE) file.
563
+ - The pretrained model weights are released under the [Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)](https://spdx.org/licenses/CC-BY-NC-4.0) license. <br>
564
+ - The service is a research preview intended for non-commercial use only, and is subject to the following licenses and terms:
565
+ - Model License of Qwen2.5-0.5B-Instruct: [Apache-2.0](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct)
566
+ - Model License of PaliGemma: [Gemma license](https://ai.google.dev/gemma/terms)
567
+
568
+
569
+
570
+ ## Citation
571
+
572
+ ## Ethical Considerations
573
+ NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
574
+
575
+ Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
576
+
added_tokens.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "</box>": 151673,
3
+ "</img>": 151666,
4
+ "</quad>": 151669,
5
+ "</ref>": 151671,
6
+ "</tool_call>": 151658,
7
+ "<IMG_CONTEXT>": 151667,
8
+ "<box>": 151672,
9
+ "<img>": 151665,
10
+ "<quad>": 151668,
11
+ "<ref>": 151670,
12
+ "<tool_call>": 151657,
13
+ "<|box_end|>": 151649,
14
+ "<|box_start|>": 151648,
15
+ "<|endoftext|>": 151643,
16
+ "<|file_sep|>": 151664,
17
+ "<|fim_middle|>": 151660,
18
+ "<|fim_pad|>": 151662,
19
+ "<|fim_prefix|>": 151659,
20
+ "<|fim_suffix|>": 151661,
21
+ "<|im_end|>": 151645,
22
+ "<|im_start|>": 151644,
23
+ "<|image_pad|>": 151655,
24
+ "<|object_ref_end|>": 151647,
25
+ "<|object_ref_start|>": 151646,
26
+ "<|quad_end|>": 151651,
27
+ "<|quad_start|>": 151650,
28
+ "<|repo_name|>": 151663,
29
+ "<|video_pad|>": 151656,
30
+ "<|vision_end|>": 151653,
31
+ "<|vision_pad|>": 151654,
32
+ "<|vision_start|>": 151652
33
+ }
config.json ADDED
@@ -0,0 +1,205 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "_name_or_path": "",
4
+ "architectures": [
5
+ "Eagle2ChatModel"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_eagle_chat.Eagle2ChatConfig",
9
+ "AutoModel": "modeling_eagle_chat.Eagle2ChatModel",
10
+ "AutoModelForCausalLM": "modeling_eagle_chat.Eagle2ChatModel"
11
+ },
12
+ "downsample_ratio": 0.5,
13
+ "dynamic_image_size": true,
14
+ "efficient_loss": true,
15
+ "force_image_size": 448,
16
+ "keep_aspect_ratio": false,
17
+ "llm_config": {
18
+ "_name_or_path": "./pretrained/Qwen2_5-0_5B-Instruct",
19
+ "add_cross_attention": false,
20
+ "architectures": [
21
+ "Qwen2ForCausalLM"
22
+ ],
23
+ "attention_dropout": 0.0,
24
+ "attn_implementation": "flash_attention_2",
25
+ "auto_map": {
26
+ "AutoConfig": "configuration_qwen2.Qwen2Config",
27
+ "AutoModel": "modeling_qwen2.Qwen2Model",
28
+ "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM"
29
+ },
30
+ "bad_words_ids": null,
31
+ "begin_suppress_tokens": null,
32
+ "bos_token_id": 151643,
33
+ "chunk_size_feed_forward": 0,
34
+ "cross_attention_hidden_size": null,
35
+ "decoder_start_token_id": null,
36
+ "diversity_penalty": 0.0,
37
+ "do_sample": false,
38
+ "early_stopping": false,
39
+ "encoder_no_repeat_ngram_size": 0,
40
+ "eos_token_id": 151645,
41
+ "exponential_decay_length_penalty": null,
42
+ "finetuning_task": null,
43
+ "forced_bos_token_id": null,
44
+ "forced_eos_token_id": null,
45
+ "hidden_act": "silu",
46
+ "hidden_size": 896,
47
+ "id2label": {
48
+ "0": "LABEL_0",
49
+ "1": "LABEL_1"
50
+ },
51
+ "initializer_range": 0.02,
52
+ "intermediate_size": 4864,
53
+ "is_decoder": false,
54
+ "is_encoder_decoder": false,
55
+ "label2id": {
56
+ "LABEL_0": 0,
57
+ "LABEL_1": 1
58
+ },
59
+ "length_penalty": 1.0,
60
+ "max_length": 20,
61
+ "max_position_embeddings": 32768,
62
+ "max_window_layers": 21,
63
+ "min_length": 0,
64
+ "model_type": "qwen2",
65
+ "no_repeat_ngram_size": 0,
66
+ "num_attention_heads": 14,
67
+ "num_beam_groups": 1,
68
+ "num_beams": 1,
69
+ "num_hidden_layers": 24,
70
+ "num_key_value_heads": 2,
71
+ "num_return_sequences": 1,
72
+ "output_attentions": false,
73
+ "output_hidden_states": false,
74
+ "output_scores": false,
75
+ "pad_token_id": null,
76
+ "prefix": null,
77
+ "problem_type": null,
78
+ "pruned_heads": {},
79
+ "remove_invalid_values": false,
80
+ "repetition_penalty": 1.0,
81
+ "return_dict": true,
82
+ "return_dict_in_generate": false,
83
+ "rms_norm_eps": 1e-06,
84
+ "rope_theta": 1000000.0,
85
+ "sep_token_id": null,
86
+ "sliding_window": 32768,
87
+ "suppress_tokens": null,
88
+ "task_specific_params": null,
89
+ "temperature": 1.0,
90
+ "tf_legacy_loss": false,
91
+ "tie_encoder_decoder": false,
92
+ "tie_word_embeddings": true,
93
+ "tokenizer_class": null,
94
+ "top_k": 50,
95
+ "top_p": 1.0,
96
+ "torch_dtype": "bfloat16",
97
+ "torchscript": false,
98
+ "transformers_version": "4.37.2",
99
+ "typical_p": 1.0,
100
+ "use_bfloat16": false,
101
+ "use_cache": false,
102
+ "use_sliding_window": false,
103
+ "vocab_size": 151674
104
+ },
105
+ "loss_version": "v4",
106
+ "max_dynamic_patch": 12,
107
+ "min_dynamic_patch": 1,
108
+ "mlp_checkpoint": true,
109
+ "model_type": "eagle_chat",
110
+ "pad2square": false,
111
+ "pre_feature_reduction": false,
112
+ "ps_version": "v2",
113
+ "select_layer": -1,
114
+ "template": "qwen2-chat",
115
+ "torch_dtype": "bfloat16",
116
+ "transformers_version": null,
117
+ "use_backbone_lora": 0,
118
+ "use_llm_lora": 0,
119
+ "use_thumbnail": true,
120
+ "vision_config": {
121
+ "_name_or_path": "",
122
+ "add_cross_attention": false,
123
+ "architectures": [
124
+ "SiglipVisionModel"
125
+ ],
126
+ "attention_dropout": 0.0,
127
+ "auto_map": {
128
+ "AutoConfig": "configuration_siglip.SiglipVisionConfig",
129
+ "AutoModel": "modeling_siglip.SiglipVisionModel"
130
+ },
131
+ "bad_words_ids": null,
132
+ "begin_suppress_tokens": null,
133
+ "bos_token_id": null,
134
+ "chunk_size_feed_forward": 0,
135
+ "cross_attention_hidden_size": null,
136
+ "decoder_start_token_id": null,
137
+ "diversity_penalty": 0.0,
138
+ "do_sample": false,
139
+ "drop_path_rate": 0.1,
140
+ "early_stopping": false,
141
+ "encoder_no_repeat_ngram_size": 0,
142
+ "eos_token_id": null,
143
+ "exponential_decay_length_penalty": null,
144
+ "finetuning_task": null,
145
+ "forced_bos_token_id": null,
146
+ "forced_eos_token_id": null,
147
+ "hidden_act": "gelu_pytorch_tanh",
148
+ "hidden_size": 1152,
149
+ "id2label": {
150
+ "0": "LABEL_0",
151
+ "1": "LABEL_1"
152
+ },
153
+ "image_size": 448,
154
+ "intermediate_size": 4304,
155
+ "is_decoder": false,
156
+ "is_encoder_decoder": false,
157
+ "label2id": {
158
+ "LABEL_0": 0,
159
+ "LABEL_1": 1
160
+ },
161
+ "layer_norm_eps": 1e-06,
162
+ "length_penalty": 1.0,
163
+ "max_length": 20,
164
+ "min_length": 0,
165
+ "model_type": "siglip_vision_model",
166
+ "no_repeat_ngram_size": 0,
167
+ "num_attention_heads": 16,
168
+ "num_beam_groups": 1,
169
+ "num_beams": 1,
170
+ "num_channels": 3,
171
+ "num_hidden_layers": 27,
172
+ "num_image_tokens": 1024,
173
+ "num_return_sequences": 1,
174
+ "output_attentions": false,
175
+ "output_hidden_states": false,
176
+ "output_scores": false,
177
+ "pad_token_id": null,
178
+ "patch_size": 14,
179
+ "prefix": null,
180
+ "problem_type": null,
181
+ "projection_dim": 2048,
182
+ "projector_hidden_act": "gelu_fast",
183
+ "pruned_heads": {},
184
+ "remove_invalid_values": false,
185
+ "repetition_penalty": 1.0,
186
+ "return_dict": true,
187
+ "return_dict_in_generate": false,
188
+ "sep_token_id": null,
189
+ "suppress_tokens": null,
190
+ "task_specific_params": null,
191
+ "temperature": 1.0,
192
+ "tf_legacy_loss": false,
193
+ "tie_encoder_decoder": false,
194
+ "tie_word_embeddings": true,
195
+ "tokenizer_class": null,
196
+ "top_k": 50,
197
+ "top_p": 1.0,
198
+ "torch_dtype": "bfloat16",
199
+ "torchscript": false,
200
+ "transformers_version": "4.37.2",
201
+ "typical_p": 1.0,
202
+ "use_bfloat16": false,
203
+ "vision_use_head": false
204
+ }
205
+ }
configuration_eagle_chat.py ADDED
@@ -0,0 +1,104 @@
1
+ # --------------------------------------------------------
2
+ # Eagle2
3
+ # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import copy
8
+
9
+ from transformers import AutoConfig, LlamaConfig
10
+ from transformers.configuration_utils import PretrainedConfig
11
+ from transformers.utils import logging
12
+ from .configuration_siglip import SiglipVisionConfig
13
+ from .configuration_qwen2 import Qwen2Config
14
+ from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig
15
+ logger = logging.get_logger(__name__)
16
+
17
+
18
+ class Eagle2ChatConfig(PretrainedConfig):
19
+ model_type = 'eagle_chat'
20
+ is_composition = True
21
+
22
+ def __init__(
23
+ self,
24
+ vision_config=None,
25
+ llm_config=None,
26
+ use_backbone_lora=0,
27
+ use_llm_lora=0,
28
+ select_layer=-1,
29
+ force_image_size=None,
30
+ downsample_ratio=0.5,
31
+ template=None,
32
+ dynamic_image_size=False,
33
+ use_thumbnail=False,
34
+ min_dynamic_patch=1,
35
+ max_dynamic_patch=6,
36
+ mlp_checkpoint=True,
37
+ pre_feature_reduction=False,
38
+ keep_aspect_ratio=False,
39
+ **kwargs):
40
+ super().__init__(**kwargs)
41
+
42
+ if vision_config is None:
43
+ vision_config = {}
44
+ logger.info('vision_config is None. Initializing Vision Encoders with default values.')
45
+
46
+ if llm_config is None:
47
+ llm_config = {}
48
+ logger.info('llm_config is None. Initializing the LLM config with default values')
49
+
50
+ if vision_config['model_type'] == 'siglip_vision_model':
51
+ self.vision_config = SiglipVisionConfig(**vision_config)
52
+ elif vision_config['model_type'].startswith("MOB"):
53
+ self.vision_config = MultiBackboneChannelConcatenationVisionModelConfig(**vision_config)
54
+ else:
55
+ raise ValueError('Unsupported model_type: {}'.format(vision_config['model_type']))
56
+
57
+ if llm_config['architectures'][0] == 'LlamaForCausalLM':
58
+ self.llm_config = LlamaConfig(**llm_config)
59
+ elif llm_config['architectures'][0] == 'Qwen2ForCausalLM':
60
+ self.llm_config = Qwen2Config(**llm_config)
61
+ else:
62
+ raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))
63
+ self.use_backbone_lora = use_backbone_lora
64
+ self.use_llm_lora = use_llm_lora
65
+ self.select_layer = select_layer
66
+ self.force_image_size = force_image_size
67
+ self.downsample_ratio = downsample_ratio
68
+ self.template = template
69
+ self.dynamic_image_size = dynamic_image_size
70
+ self.use_thumbnail = use_thumbnail
71
+ self.min_dynamic_patch = min_dynamic_patch
72
+ self.max_dynamic_patch = max_dynamic_patch
73
+ self.mlp_checkpoint = mlp_checkpoint
74
+ self.pre_feature_reduction = pre_feature_reduction
75
+ self.keep_aspect_ratio = keep_aspect_ratio
76
+ logger.info(f'keep_aspect_ratio: {self.keep_aspect_ratio}')
77
+ logger.info(f'vision_select_layer: {self.select_layer}')
78
+ logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
79
+ logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
80
+
81
+ def to_dict(self):
82
+ """
83
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
84
+
85
+ Returns:
86
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
87
+ """
88
+ output = copy.deepcopy(self.__dict__)
89
+ output['vision_config'] = self.vision_config.to_dict()
90
+ output['llm_config'] = self.llm_config.to_dict()
91
+ output['model_type'] = self.__class__.model_type
92
+ output['use_backbone_lora'] = self.use_backbone_lora
93
+ output['use_llm_lora'] = self.use_llm_lora
94
+ output['select_layer'] = self.select_layer
95
+ output['force_image_size'] = self.force_image_size
96
+ output['downsample_ratio'] = self.downsample_ratio
97
+ output['template'] = self.template
98
+ output['dynamic_image_size'] = self.dynamic_image_size
99
+ output['use_thumbnail'] = self.use_thumbnail
100
+ output['min_dynamic_patch'] = self.min_dynamic_patch
101
+ output['max_dynamic_patch'] = self.max_dynamic_patch
102
+ output['keep_aspect_ratio'] = self.keep_aspect_ratio
103
+
104
+ return output
configuration_multi_backbone_channel_concatentation_model.py ADDED
@@ -0,0 +1,86 @@
1
+ # --------------------------------------------------------
2
+ # Eagle2
3
+ # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import os
8
+ from typing import Union
9
+
10
+ from transformers.configuration_utils import PretrainedConfig
11
+ from transformers.utils import logging
12
+ from .configuration_siglip import SiglipVisionConfig
13
+ logger = logging.get_logger(__name__)
14
+
15
+
16
+ class MultiBackboneChannelConcatenationVisionModelConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`MultiBackboneChannelConcatenationVisionModelConfig`]. It is used to
19
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
20
+
21
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
+ documentation from [`PretrainedConfig`] for more information.
23
+
24
+ Args:
25
+ vision_path (str): Path to the vision model or its configuration.
26
+ mm_vision_select_layer (int, optional): The layer to select from the vision model
27
+ for multi-modal processing. Defaults to -2.
28
+ grid_size (int, optional): The size of the grid for vision processing. Defaults to 32.
29
+ **kwargs: Additional keyword arguments to be passed to the parent PretrainedConfig.
30
+
31
+ """
32
+
33
+ model_type = 'MOB'
34
+
35
+ def __init__(
36
+ self,
37
+ vision_path,
38
+ mm_vision_select_layer=-2,
39
+ grid_size=32,
40
+ input_image_size=1024,
41
+ hidden_size='lazy_calculation',
42
+ image_size=1024,
43
+ freeze_backbones=None,
44
+ moe_version_type=None,
45
+ delay_load=False,
46
+ convnext_img_size=1024,
47
+ vision_tower_siglip_path=None,
48
+ vision_tower_convnext_path='convnext_xxlarge.clip_laion2b_soup',
49
+ normalize_type='siglip',
50
+ **kwargs,
51
+ ):
52
+ super().__init__(**kwargs)
53
+
54
+ self.normalize_type = normalize_type
55
+ self.vision_path = vision_path
56
+ self.mm_vision_select_layer = mm_vision_select_layer
57
+ self.grid_size = grid_size
58
+ self.input_image_size = input_image_size
59
+ self.image_size = image_size
60
+ self.hidden_size = hidden_size
61
+ self.freeze_backbones = freeze_backbones
62
+ self.moe_version_type = moe_version_type
63
+ self.delay_load = delay_load
64
+ self.convnext_img_size = convnext_img_size
65
+ # other args, to make it compatible with eagle-next
66
+ self.vision_tower_siglip_path = vision_tower_siglip_path
67
+ self.vision_tower_convnext_path = vision_tower_convnext_path
68
+ self.vision_tower = self.vision_path[4:] # remove `MOB:` prefix
69
+
70
+ # asserts
71
+ assert image_size == input_image_size, f"input_image_size ({input_image_size}) != image_size ({image_size})"
72
+
73
+ @classmethod
74
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
75
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
76
+
77
+ if 'vision_config' in config_dict:
78
+ config_dict = config_dict['vision_config']
79
+
80
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
81
+ logger.warning(
82
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
83
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
84
+ )
85
+
86
+ return cls.from_dict(config_dict, **kwargs)
configuration_qwen2.py ADDED
@@ -0,0 +1,149 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Qwen2 model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
+ "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json",
25
+ }
26
+
27
+
28
+ class Qwen2Config(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
31
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
32
+ with the defaults will yield a similar configuration to that of
33
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 151936):
41
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`Qwen2Model`]
43
+ hidden_size (`int`, *optional*, defaults to 4096):
44
+ Dimension of the hidden representations.
45
+ intermediate_size (`int`, *optional*, defaults to 22016):
46
+ Dimension of the MLP representations.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 32):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ num_key_value_heads (`int`, *optional*, defaults to 32):
52
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
55
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
+ by meanpooling all the original heads within that group. For more details checkout [this
57
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
61
+ The maximum sequence length that this model might ever be used with.
62
+ initializer_range (`float`, *optional*, defaults to 0.02):
63
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
65
+ The epsilon used by the rms normalization layers.
66
+ use_cache (`bool`, *optional*, defaults to `True`):
67
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
68
+ relevant if `config.is_decoder=True`.
69
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
70
+ Whether the model's input and output word embeddings should be tied.
71
+ rope_theta (`float`, *optional*, defaults to 10000.0):
72
+ The base period of the RoPE embeddings.
73
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
74
+ Whether to use sliding window attention.
75
+ sliding_window (`int`, *optional*, defaults to 4096):
76
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
77
+ max_window_layers (`int`, *optional*, defaults to 28):
78
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
79
+ attention_dropout (`float`, *optional*, defaults to 0.0):
80
+ The dropout ratio for the attention probabilities.
81
+
82
+ ```python
83
+ >>> from transformers import Qwen2Model, Qwen2Config
84
+
85
+ >>> # Initializing a Qwen2 style configuration
86
+ >>> configuration = Qwen2Config()
87
+
88
+ >>> # Initializing a model from the Qwen2-7B style configuration
89
+ >>> model = Qwen2Model(configuration)
90
+
91
+ >>> # Accessing the model configuration
92
+ >>> configuration = model.config
93
+ ```"""
94
+
95
+ model_type = "qwen2"
96
+ keys_to_ignore_at_inference = ["past_key_values"]
97
+
98
+ def __init__(
99
+ self,
100
+ vocab_size=151936,
101
+ hidden_size=4096,
102
+ intermediate_size=22016,
103
+ num_hidden_layers=32,
104
+ num_attention_heads=32,
105
+ num_key_value_heads=32,
106
+ hidden_act="silu",
107
+ max_position_embeddings=32768,
108
+ initializer_range=0.02,
109
+ rms_norm_eps=1e-6,
110
+ use_cache=True,
111
+ tie_word_embeddings=False,
112
+ rope_theta=10000.0,
113
+ use_sliding_window=False,
114
+ sliding_window=4096,
115
+ max_window_layers=28,
116
+ attention_dropout=0.0,
117
+ attn_implementation='flash_attention_2',
118
+ **kwargs,
119
+ ):
120
+ self.vocab_size = vocab_size
121
+ self.max_position_embeddings = max_position_embeddings
122
+ self.hidden_size = hidden_size
123
+ self.intermediate_size = intermediate_size
124
+ self.num_hidden_layers = num_hidden_layers
125
+ self.num_attention_heads = num_attention_heads
126
+ self.use_sliding_window = use_sliding_window
127
+ self.sliding_window = sliding_window
128
+ self.max_window_layers = max_window_layers
129
+
130
+ self.attn_implementation = attn_implementation
131
+ if self.attn_implementation is None:
132
+ self.attn_implementation = "flash_attention_2"
133
+
134
+ # for backward compatibility
135
+ if num_key_value_heads is None:
136
+ num_key_value_heads = num_attention_heads
137
+
138
+ self.num_key_value_heads = num_key_value_heads
139
+ self.hidden_act = hidden_act
140
+ self.initializer_range = initializer_range
141
+ self.rms_norm_eps = rms_norm_eps
142
+ self.use_cache = use_cache
143
+ self.rope_theta = rope_theta
144
+ self.attention_dropout = attention_dropout
145
+
146
+ super().__init__(
147
+ tie_word_embeddings=tie_word_embeddings,
148
+ **kwargs,
149
+ )
configuration_siglip.py ADDED
@@ -0,0 +1,302 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Siglip model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json",
28
+ }
29
+
30
+
31
+ class SiglipTextConfig(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
34
+ Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
35
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
36
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
37
+
38
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
+ documentation from [`PretrainedConfig`] for more information.
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 32000):
43
+ Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
44
+ the `inputs_ids` passed when calling [`SiglipModel`].
45
+ hidden_size (`int`, *optional*, defaults to 768):
46
+ Dimensionality of the encoder layers and the pooler layer.
47
+ intermediate_size (`int`, *optional*, defaults to 3072):
48
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
49
+ num_hidden_layers (`int`, *optional*, defaults to 12):
50
+ Number of hidden layers in the Transformer encoder.
51
+ num_attention_heads (`int`, *optional*, defaults to 12):
52
+ Number of attention heads for each attention layer in the Transformer encoder.
53
+ max_position_embeddings (`int`, *optional*, defaults to 64):
54
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
55
+ just in case (e.g., 512 or 1024 or 2048).
56
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
57
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
58
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
59
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
60
+ The epsilon used by the layer normalization layers.
61
+ attention_dropout (`float`, *optional*, defaults to 0.0):
62
+ The dropout ratio for the attention probabilities.
63
+ pad_token_id (`int`, *optional*, defaults to 1):
64
+ The id of the padding token in the vocabulary.
65
+ bos_token_id (`int`, *optional*, defaults to 49406):
66
+ The id of the beginning-of-sequence token in the vocabulary.
67
+ eos_token_id (`int`, *optional*, defaults to 49407):
68
+ The id of the end-of-sequence token in the vocabulary.
69
+
70
+ Example:
71
+
72
+ ```python
73
+ >>> from transformers import SiglipTextConfig, SiglipTextModel
74
+
75
+ >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
76
+ >>> configuration = SiglipTextConfig()
77
+
78
+ >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
79
+ >>> model = SiglipTextModel(configuration)
80
+
81
+ >>> # Accessing the model configuration
82
+ >>> configuration = model.config
83
+ ```"""
84
+
85
+ model_type = "siglip_text_model"
86
+
87
+ def __init__(
88
+ self,
89
+ vocab_size=32000,
90
+ hidden_size=768,
91
+ intermediate_size=3072,
92
+ num_hidden_layers=12,
93
+ num_attention_heads=12,
94
+ max_position_embeddings=64,
95
+ hidden_act="gelu_pytorch_tanh",
96
+ layer_norm_eps=1e-6,
97
+ attention_dropout=0.0,
98
+ # This differs from `CLIPTokenizer`'s default and from openai/siglip
99
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
100
+ pad_token_id=1,
101
+ bos_token_id=49406,
102
+ eos_token_id=49407,
103
+ **kwargs,
104
+ ):
105
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
106
+
107
+ self.vocab_size = vocab_size
108
+ self.hidden_size = hidden_size
109
+ self.intermediate_size = intermediate_size
110
+ self.num_hidden_layers = num_hidden_layers
111
+ self.num_attention_heads = num_attention_heads
112
+ self.max_position_embeddings = max_position_embeddings
113
+ self.layer_norm_eps = layer_norm_eps
114
+ self.hidden_act = hidden_act
115
+ self.attention_dropout = attention_dropout
116
+
117
+ @classmethod
118
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
119
+ cls._set_token_in_kwargs(kwargs)
120
+
121
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
122
+
123
+ # get the text config dict if we are loading from SiglipConfig
124
+ if config_dict.get("model_type") == "siglip":
125
+ config_dict = config_dict["text_config"]
126
+
127
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
128
+ logger.warning(
129
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
130
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
131
+ )
132
+
133
+ return cls.from_dict(config_dict, **kwargs)
134
+
135
+
136
+ class SiglipVisionConfig(PretrainedConfig):
137
+ r"""
138
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
139
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
140
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
141
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
142
+
143
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
144
+ documentation from [`PretrainedConfig`] for more information.
145
+
146
+ Args:
147
+ hidden_size (`int`, *optional*, defaults to 768):
148
+ Dimensionality of the encoder layers and the pooler layer.
149
+ intermediate_size (`int`, *optional*, defaults to 3072):
150
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
151
+ num_hidden_layers (`int`, *optional*, defaults to 12):
152
+ Number of hidden layers in the Transformer encoder.
153
+ num_attention_heads (`int`, *optional*, defaults to 12):
154
+ Number of attention heads for each attention layer in the Transformer encoder.
155
+ num_channels (`int`, *optional*, defaults to 3):
156
+ Number of channels in the input images.
157
+ image_size (`int`, *optional*, defaults to 224):
158
+ The size (resolution) of each image.
159
+ patch_size (`int`, *optional*, defaults to 16):
160
+ The size (resolution) of each patch.
161
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
162
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
163
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
164
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
165
+ The epsilon used by the layer normalization layers.
166
+ attention_dropout (`float`, *optional*, defaults to 0.0):
167
+ The dropout ratio for the attention probabilities.
168
+
169
+ Example:
170
+
171
+ ```python
172
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
173
+
174
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
175
+ >>> configuration = SiglipVisionConfig()
176
+
177
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
178
+ >>> model = SiglipVisionModel(configuration)
179
+
180
+ >>> # Accessing the model configuration
181
+ >>> configuration = model.config
182
+ ```"""
183
+
184
+ model_type = "siglip_vision_model"
185
+
186
+ def __init__(
187
+ self,
188
+ hidden_size=768,
189
+ intermediate_size=3072,
190
+ num_hidden_layers=12,
191
+ num_attention_heads=12,
192
+ num_channels=3,
193
+ image_size=224,
194
+ patch_size=16,
195
+ hidden_act="gelu_pytorch_tanh",
196
+ layer_norm_eps=1e-6,
197
+ attention_dropout=0.0,
198
+ **kwargs,
199
+ ):
200
+ super().__init__(**kwargs)
201
+
202
+ self.hidden_size = hidden_size
203
+ self.intermediate_size = intermediate_size
204
+ self.num_hidden_layers = num_hidden_layers
205
+ self.num_attention_heads = num_attention_heads
206
+ self.num_channels = num_channels
207
+ self.patch_size = patch_size
208
+ self.image_size = image_size
209
+ self.attention_dropout = attention_dropout
210
+ self.layer_norm_eps = layer_norm_eps
211
+ self.hidden_act = hidden_act
212
+
213
+ @classmethod
214
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
215
+ cls._set_token_in_kwargs(kwargs)
216
+
217
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
218
+
219
+ # get the vision config dict if we are loading from SiglipConfig
220
+ if config_dict.get("model_type") == "siglip":
221
+ config_dict = config_dict["vision_config"]
222
+
223
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
224
+ logger.warning(
225
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
226
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
227
+ )
228
+
229
+ return cls.from_dict(config_dict, **kwargs)
230
+
231
+
232
+ class SiglipConfig(PretrainedConfig):
233
+ r"""
234
+ [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
235
+ instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
236
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
237
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
238
+
239
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
240
+ documentation from [`PretrainedConfig`] for more information.
241
+
242
+ Args:
243
+ text_config (`dict`, *optional*):
244
+ Dictionary of configuration options used to initialize [`SiglipTextConfig`].
245
+ vision_config (`dict`, *optional*):
246
+ Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
247
+ kwargs (*optional*):
248
+ Dictionary of keyword arguments.
249
+
250
+ Example:
251
+
252
+ ```python
253
+ >>> from transformers import SiglipConfig, SiglipModel
254
+
255
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
256
+ >>> configuration = SiglipConfig()
257
+
258
+ >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
259
+ >>> model = SiglipModel(configuration)
260
+
261
+ >>> # Accessing the model configuration
262
+ >>> configuration = model.config
263
+
264
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
265
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
266
+
267
+ >>> # Initializing a SiglipText and SiglipVision configuration
268
+ >>> config_text = SiglipTextConfig()
269
+ >>> config_vision = SiglipVisionConfig()
270
+
271
+ >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
272
+ ```"""
273
+
274
+ model_type = "siglip"
275
+
276
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
277
+ super().__init__(**kwargs)
278
+
279
+ if text_config is None:
280
+ text_config = {}
281
+ logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
282
+
283
+ if vision_config is None:
284
+ vision_config = {}
285
+ logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
286
+
287
+ self.text_config = SiglipTextConfig(**text_config)
288
+ self.vision_config = SiglipVisionConfig(**vision_config)
289
+
290
+ self.initializer_factor = 1.0
291
+
292
+ @classmethod
293
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
294
+ r"""
295
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
296
+ model configuration.
297
+
298
+ Returns:
299
+ [`SiglipConfig`]: An instance of a configuration object
300
+ """
301
+
302
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
conversation.py ADDED
@@ -0,0 +1,434 @@
1
+ """
2
+ Conversation prompt templates.
3
+
4
+ We kindly request that you import fastchat instead of copying this file if you wish to use it.
5
+ If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
6
+ """
7
+
8
+ import dataclasses
9
+ from enum import IntEnum, auto
10
+ from typing import Any, Dict, List, Tuple, Union
11
+
12
+
13
+ class SeparatorStyle(IntEnum):
14
+ """Separator styles."""
15
+
16
+ ADD_COLON_SINGLE = auto()
17
+ ADD_COLON_TWO = auto()
18
+ ADD_COLON_SPACE_SINGLE = auto()
19
+ NO_COLON_SINGLE = auto()
20
+ NO_COLON_TWO = auto()
21
+ ADD_NEW_LINE_SINGLE = auto()
22
+ LLAMA2 = auto()
23
+ CHATGLM = auto()
24
+ CHATML = auto()
25
+ CHATINTERN = auto()
26
+ DOLLY = auto()
27
+ RWKV = auto()
28
+ PHOENIX = auto()
29
+ ROBIN = auto()
30
+ FALCON_CHAT = auto()
31
+ CHATGLM3 = auto()
32
+ INTERNVL_ZH = auto()
33
+ MPT = auto()
34
+ LLAMA3 = auto()
35
+
36
+
37
+ @dataclasses.dataclass
38
+ class Conversation:
39
+ """A class that manages prompt templates and keeps all conversation history."""
40
+
41
+ # The name of this template
42
+ name: str
43
+ # The template of the system prompt
44
+ system_template: str = '{system_message}'
45
+ # The system message
46
+ system_message: str = ''
47
+ # The names of two roles
48
+ roles: Tuple[str] = ('USER', 'ASSISTANT')
49
+ # All messages. Each item is (role, message).
50
+ messages: List[List[str]] = ()
51
+ # The number of few shot examples
52
+ offset: int = 0
53
+ # The separator style and configurations
54
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
55
+ sep: str = '\n'
56
+ sep2: str = None
57
+ # Stop criteria (the default one is EOS token)
58
+ stop_str: Union[str, List[str]] = None
59
+ # Stops generation if meeting any token in this list
60
+ stop_token_ids: List[int] = None
61
+
62
+ def get_prompt(self) -> str:
63
+ """Get the prompt for generation."""
64
+ system_prompt = self.system_template.format(system_message=self.system_message)
65
+ if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
66
+ ret = system_prompt + self.sep
67
+ for role, message in self.messages:
68
+ if message:
69
+ ret += role + ': ' + message + self.sep
70
+ else:
71
+ ret += role + ':'
72
+ return ret
73
+ elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
74
+ seps = [self.sep, self.sep2]
75
+ ret = system_prompt + seps[0]
76
+ for i, (role, message) in enumerate(self.messages):
77
+ if message:
78
+ ret += role + ': ' + message + seps[i % 2]
79
+ else:
80
+ ret += role + ':'
81
+ return ret
82
+ elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
83
+ ret = system_prompt + self.sep
84
+ for role, message in self.messages:
85
+ if message:
86
+ ret += role + ': ' + message + self.sep
87
+ else:
88
+ ret += role + ': ' # must end with a space
89
+ return ret
90
+ elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
91
+ ret = '' if system_prompt == '' else system_prompt + self.sep
92
+ for role, message in self.messages:
93
+ if message:
94
+ ret += role + '\n' + message + self.sep
95
+ else:
96
+ ret += role + '\n'
97
+ return ret
98
+ elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
99
+ ret = system_prompt
100
+ for role, message in self.messages:
101
+ if message:
102
+ ret += role + message + self.sep
103
+ else:
104
+ ret += role
105
+ return ret
106
+ elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
107
+ seps = [self.sep, self.sep2]
108
+ ret = system_prompt
109
+ for i, (role, message) in enumerate(self.messages):
110
+ if message:
111
+ ret += role + message + seps[i % 2]
112
+ else:
113
+ ret += role
114
+ return ret
115
+ elif self.sep_style == SeparatorStyle.RWKV:
116
+ ret = system_prompt
117
+ for i, (role, message) in enumerate(self.messages):
118
+ if message:
119
+ ret += (
120
+ role
121
+ + ': '
122
+ + message.replace('\r\n', '\n').replace('\n\n', '\n')
123
+ )
124
+ ret += '\n\n'
125
+ else:
126
+ ret += role + ':'
127
+ return ret
128
+ elif self.sep_style == SeparatorStyle.LLAMA2:
129
+ seps = [self.sep, self.sep2]
130
+ if self.system_message:
131
+ ret = system_prompt
132
+ else:
133
+ ret = '[INST] '
134
+ for i, (role, message) in enumerate(self.messages):
135
+ tag = self.roles[i % 2]
136
+ if message:
137
+ if i == 0:
138
+ ret += message + ' '
139
+ else:
140
+ ret += tag + ' ' + message + seps[i % 2]
141
+ else:
142
+ ret += tag
143
+ return ret
144
+ elif self.sep_style == SeparatorStyle.CHATGLM:
145
+ # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
146
+ # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
147
+ round_add_n = 1 if self.name == 'chatglm2' else 0
148
+ if system_prompt:
149
+ ret = system_prompt + self.sep
150
+ else:
151
+ ret = ''
152
+
153
+ for i, (role, message) in enumerate(self.messages):
154
+ if i % 2 == 0:
155
+ ret += f'[Round {i//2 + round_add_n}]{self.sep}'
156
+
157
+ if message:
158
+ ret += f'{role}:{message}{self.sep}'
159
+ else:
160
+ ret += f'{role}:'
161
+ return ret
162
+ elif self.sep_style == SeparatorStyle.CHATML:
163
+ ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
164
+ for role, message in self.messages:
165
+ if message:
166
+ ret += role + '\n' + message + self.sep + '\n'
167
+ else:
168
+ ret += role + '\n'
169
+ return ret
170
+ elif self.sep_style == SeparatorStyle.CHATGLM3:
171
+ ret = ''
172
+ if self.system_message:
173
+ ret += system_prompt
174
+ for role, message in self.messages:
175
+ if message:
176
+ ret += role + '\n' + ' ' + message
177
+ else:
178
+ ret += role
179
+ return ret
180
+ elif self.sep_style == SeparatorStyle.CHATINTERN:
181
+ # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
182
+ seps = [self.sep, self.sep2]
183
+ ret = system_prompt
184
+ for i, (role, message) in enumerate(self.messages):
185
+ # if i % 2 == 0:
186
+ # ret += "<s>"
187
+ if message:
188
+ ret += role + ':' + message + seps[i % 2] + '\n'
189
+ else:
190
+ ret += role + ':'
191
+ return ret
192
+ elif self.sep_style == SeparatorStyle.DOLLY:
193
+ seps = [self.sep, self.sep2]
194
+ ret = system_prompt
195
+ for i, (role, message) in enumerate(self.messages):
196
+ if message:
197
+ ret += role + ':\n' + message + seps[i % 2]
198
+ if i % 2 == 1:
199
+ ret += '\n\n'
200
+ else:
201
+ ret += role + ':\n'
202
+ return ret
203
+ elif self.sep_style == SeparatorStyle.PHOENIX:
204
+ ret = system_prompt
205
+ for role, message in self.messages:
206
+ if message:
207
+ ret += role + ': ' + '<s>' + message + '</s>'
208
+ else:
209
+ ret += role + ': ' + '<s>'
210
+ return ret
211
+ elif self.sep_style == SeparatorStyle.ROBIN:
212
+ ret = system_prompt + self.sep
213
+ for role, message in self.messages:
214
+ if message:
215
+ ret += role + ':\n' + message + self.sep
216
+ else:
217
+ ret += role + ':\n'
218
+ return ret
219
+ elif self.sep_style == SeparatorStyle.FALCON_CHAT:
220
+ ret = ''
221
+ if self.system_message:
222
+ ret += system_prompt + self.sep
223
+ for role, message in self.messages:
224
+ if message:
225
+ ret += role + ': ' + message + self.sep
226
+ else:
227
+ ret += role + ':'
228
+
229
+ return ret
230
+ elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
231
+ seps = [self.sep, self.sep2]
232
+ ret = self.system_message + seps[0]
233
+ for i, (role, message) in enumerate(self.messages):
234
+ if message:
235
+ ret += role + ': ' + message + seps[i % 2]
236
+ else:
237
+ ret += role + ':'
238
+ return ret
239
+ elif self.sep_style == SeparatorStyle.MPT:
240
+ ret = system_prompt + self.sep
241
+ for role, message in self.messages:
242
+ if message:
243
+ if type(message) is tuple:
244
+ message, _, _ = message
245
+ ret += role + message + self.sep
246
+ else:
247
+ ret += role
248
+ return ret
249
+ elif self.sep_style == SeparatorStyle.LLAMA3:
250
+ ret = system_prompt + self.sep
251
+ for role, message in self.messages:
252
+ if message:
253
+ if type(message) is tuple:
254
+ message, _, _ = message
255
+ ret += role + message + self.sep
256
+ else:
257
+ ret += role
258
+ return ret
259
+ else:
260
+ raise ValueError(f'Invalid style: {self.sep_style}')
261
+
262
+ def set_system_message(self, system_message: str):
263
+ """Set the system message."""
264
+ self.system_message = system_message
265
+
266
+ def append_message(self, role: str, message: str):
267
+ """Append a new message."""
268
+ self.messages.append([role, message])
269
+
270
+ def update_last_message(self, message: str):
271
+ """Update the last output.
272
+
273
+ The last message is typically set to be None when constructing the prompt,
274
+ so we need to update it in-place after getting the response from a model.
275
+ """
276
+ self.messages[-1][1] = message
277
+
278
+ def to_gradio_chatbot(self):
279
+ """Convert the conversation to gradio chatbot format."""
280
+ ret = []
281
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
282
+ if i % 2 == 0:
283
+ ret.append([msg, None])
284
+ else:
285
+ ret[-1][-1] = msg
286
+ return ret
287
+
288
+ def to_openai_api_messages(self):
289
+ """Convert the conversation to OpenAI chat completion format."""
290
+ ret = [{'role': 'system', 'content': self.system_message}]
291
+
292
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
293
+ if i % 2 == 0:
294
+ ret.append({'role': 'user', 'content': msg})
295
+ else:
296
+ if msg is not None:
297
+ ret.append({'role': 'assistant', 'content': msg})
298
+ return ret
299
+
300
+ def copy(self):
301
+ return Conversation(
302
+ name=self.name,
303
+ system_template=self.system_template,
304
+ system_message=self.system_message,
305
+ roles=self.roles,
306
+ messages=[[x, y] for x, y in self.messages],
307
+ offset=self.offset,
308
+ sep_style=self.sep_style,
309
+ sep=self.sep,
310
+ sep2=self.sep2,
311
+ stop_str=self.stop_str,
312
+ stop_token_ids=self.stop_token_ids,
313
+ )
314
+
315
+ def dict(self):
316
+ return {
317
+ 'template_name': self.name,
318
+ 'system_message': self.system_message,
319
+ 'roles': self.roles,
320
+ 'messages': self.messages,
321
+ 'offset': self.offset,
322
+ }
323
+
324
+
325
+ # A global registry for all conversation templates
326
+ conv_templates: Dict[str, Conversation] = {}
327
+
328
+
329
+ def register_conv_template(template: Conversation, override: bool = False):
330
+ """Register a new conversation template."""
331
+ if not override:
332
+ assert (
333
+ template.name not in conv_templates
334
+ ), f'{template.name} has been registered.'
335
+
336
+ conv_templates[template.name] = template
337
+
338
+
339
+ def get_conv_template(name: str) -> Conversation:
340
+ """Get a conversation template."""
341
+ return conv_templates[name].copy()
342
+
343
+
344
+ # Note that for inference, using the Hermes-2 and internlm2-chat templates is equivalent.
345
+ register_conv_template(
346
+ Conversation(
347
+ name='Hermes-2',
348
+ system_template='<|im_start|>system\n{system_message}',
349
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
350
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室及多家合作单位联合开发的多模态大语言模型。人工智能实验室致力于原始技术创新,开源开放,共享共创,推动科技进步和产业发展。',
351
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
352
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
353
+ sep_style=SeparatorStyle.MPT,
354
+ sep='<|im_end|>',
355
+ stop_token_ids=[
356
+ 2,
357
+ 6,
358
+ 7,
359
+ 8,
360
+ ],
361
+ stop_str='<|endoftext|>',
362
+ )
363
+ )
364
+
365
+
366
+ register_conv_template(
367
+ Conversation(
368
+ name='internlm2-chat',
369
+ system_template='<|im_start|>system\n{system_message}',
370
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
371
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室及多家合作单位联合开发的多模态大语言模型。人工智能实验室致力于原始技术创新,开源开放,共享共创,推动科技进步和产业发展。',
372
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
373
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
374
+ sep_style=SeparatorStyle.MPT,
375
+ sep='<|im_end|>',
376
+ stop_token_ids=[
377
+ 2,
378
+ 92543,
379
+ 92542
380
+ ]
381
+ )
382
+ )
383
+
384
+
385
+ register_conv_template(
386
+ Conversation(
387
+ name='phi3-chat',
388
+ system_template='<|system|>\n{system_message}',
389
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
390
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室及多家合作单位联合开发的多模态大语言模型。人工智能实验室致力于原始技术创新,开源开放,共享共创,推动科技进步和产业发展。',
391
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
392
+ roles=('<|user|>\n', '<|assistant|>\n'),
393
+ sep_style=SeparatorStyle.MPT,
394
+ sep='<|end|>',
395
+ stop_token_ids=[
396
+ 2,
397
+ 32000,
398
+ 32007
399
+ ]
400
+ )
401
+ )
402
+ register_conv_template(
403
+ Conversation(
404
+ name='llama3-chat',
405
+ system_template='<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}',
406
+ system_message='You are an AI assistant whose name is Eagle-Next.',
407
+ roles=('<|start_header_id|>user<|end_header_id|>\n\n', '<|start_header_id|>assistant<|end_header_id|>\n\n'),
408
+ sep_style=SeparatorStyle.LLAMA3,
409
+ sep='<|eot_id|>',
410
+ stop_token_ids=[
411
+ 128259,
412
+ 128001
413
+ ]
414
+ )
415
+ )
416
+
417
+ # Qwen-chat default template
418
+ # source: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L130
419
+ register_conv_template(
420
+ Conversation(
421
+ name='qwen2-chat',
422
+ system_template='<|im_start|>system\n{system_message}',
423
+ system_message='You are a helpful assistant.',
424
+ roles=('<|im_start|>user', '<|im_start|>assistant'),
425
+ sep_style=SeparatorStyle.CHATML,
426
+ sep='<|im_end|>',
427
+ stop_token_ids=[
428
+ 151643,
429
+ 151644,
430
+ 151645,
431
+ ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>"
432
+ stop_str='<|endoftext|>',
433
+ )
434
+ )
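A minimal usage sketch for the templates above (it assumes this file is importable as `conversation`; the expected output follows from the `CHATML` branch of `get_prompt`):

```python
from conversation import get_conv_template

conv = get_conv_template('qwen2-chat')      # returns a copy, so the registry stays untouched
conv.append_message(conv.roles[0], 'Describe this image. <image>')
conv.append_message(conv.roles[1], None)    # None marks the slot the model will fill in
print(conv.get_prompt())
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Describe this image. <image><|im_end|>
# <|im_start|>assistant
```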
convnext.py ADDED
@@ -0,0 +1,572 @@
1
+ """ ConvNeXt
2
+
3
+ Papers:
4
+ * `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
5
+ @Article{liu2022convnet,
6
+ author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
7
+ title = {A ConvNet for the 2020s},
8
+ journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
9
+ year = {2022},
10
+ }
11
+
12
+ * `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
13
+ @article{Woo2023ConvNeXtV2,
14
+ title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
15
+ author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie},
16
+ year={2023},
17
+ journal={arXiv preprint arXiv:2301.00808},
18
+ }
19
+
20
+ Original code and weights from:
21
+ * https://github.com/facebookresearch/ConvNeXt, original copyright below
22
+ * https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
23
+
24
+ Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
25
+
26
+ Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
27
+ """
28
+ # ConvNeXt
29
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
30
+ # All rights reserved.
31
+ # This source code is licensed under the MIT license
32
+
33
+ # ConvNeXt-V2
34
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
35
+ # All rights reserved.
36
+ # This source code is licensed under the license found in the
37
+ # LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
38
+ # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
39
+
40
+ from collections import OrderedDict
41
+ from functools import partial
42
+ from typing import Callable, Optional, Tuple, Union
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+
47
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
48
+ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
49
+ LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
50
+ from timm.layers import NormMlpClassifierHead, ClassifierHead
51
+ from timm.models._builder import build_model_with_cfg
52
+ from timm.models._manipulate import named_apply, checkpoint_seq
53
+ from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations
54
+
55
+ __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
56
+
57
+
58
+ class Downsample(nn.Module):
59
+
60
+ def __init__(self, in_chs, out_chs, stride=1, dilation=1):
61
+ super().__init__()
62
+ avg_stride = stride if dilation == 1 else 1
63
+ if stride > 1 or dilation > 1:
64
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
65
+ self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
66
+ else:
67
+ self.pool = nn.Identity()
68
+
69
+ if in_chs != out_chs:
70
+ self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
71
+ else:
72
+ self.conv = nn.Identity()
73
+
74
+ def forward(self, x):
75
+ x = self.pool(x)
76
+ x = self.conv(x)
77
+ return x
78
+
79
+
80
+ class ConvNeXtBlock(nn.Module):
81
+ """ ConvNeXt Block
82
+ There are two equivalent implementations:
83
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
84
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
85
+
86
+ Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
87
+ choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
88
+ is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ in_chs: int,
94
+ out_chs: Optional[int] = None,
95
+ kernel_size: int = 7,
96
+ stride: int = 1,
97
+ dilation: Union[int, Tuple[int, int]] = (1, 1),
98
+ mlp_ratio: float = 4,
99
+ conv_mlp: bool = False,
100
+ conv_bias: bool = True,
101
+ use_grn: bool = False,
102
+ ls_init_value: Optional[float] = 1e-6,
103
+ act_layer: Union[str, Callable] = 'gelu',
104
+ norm_layer: Optional[Callable] = None,
105
+ drop_path: float = 0.,
106
+ ):
107
+ """
108
+
109
+ Args:
110
+ in_chs: Block input channels.
111
+ out_chs: Block output channels (same as in_chs if None).
112
+ kernel_size: Depthwise convolution kernel size.
113
+ stride: Stride of depthwise convolution.
114
+ dilation: Tuple specifying input and output dilation of block.
115
+ mlp_ratio: MLP expansion ratio.
116
+ conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
117
+ conv_bias: Apply bias for all convolution (linear) layers.
118
+ use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
119
+ ls_init_value: Layer-scale init values, layer-scale applied if not None.
120
+ act_layer: Activation layer.
121
+ norm_layer: Normalization layer (defaults to LN if not specified).
122
+ drop_path: Stochastic depth probability.
123
+ """
124
+ super().__init__()
125
+ out_chs = out_chs or in_chs
126
+ dilation = to_ntuple(2)(dilation)
127
+ act_layer = get_act_layer(act_layer)
128
+ if not norm_layer:
129
+ norm_layer = LayerNorm2d if conv_mlp else LayerNorm
130
+ mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
131
+ self.use_conv_mlp = conv_mlp
132
+ self.conv_dw = create_conv2d(
133
+ in_chs,
134
+ out_chs,
135
+ kernel_size=kernel_size,
136
+ stride=stride,
137
+ dilation=dilation[0],
138
+ depthwise=True,
139
+ bias=conv_bias,
140
+ )
141
+ self.norm = norm_layer(out_chs)
142
+ self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
143
+ self.weight = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
144
+ if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
145
+ self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
146
+ else:
147
+ self.shortcut = nn.Identity()
148
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
149
+
150
+ def forward(self, x):
151
+ shortcut = x
152
+ x = self.conv_dw(x)
153
+ if self.use_conv_mlp:
154
+ x = self.norm(x)
155
+ x = self.mlp(x)
156
+ else:
157
+ x = x.permute(0, 2, 3, 1)
158
+ x = self.norm(x)
159
+ x = self.mlp(x)
160
+ x = x.permute(0, 3, 1, 2)
161
+ if self.weight is not None:
162
+ x = x.mul(self.weight.reshape(1, -1, 1, 1))
163
+
164
+ x = self.drop_path(x) + self.shortcut(shortcut)
165
+ return x
166
+
167
+
168
+ class ConvNeXtStage(nn.Module):
169
+
170
+ def __init__(
171
+ self,
172
+ in_chs,
173
+ out_chs,
174
+ kernel_size=7,
175
+ stride=2,
176
+ depth=2,
177
+ dilation=(1, 1),
178
+ drop_path_rates=None,
179
+ ls_init_value=1.0,
180
+ conv_mlp=False,
181
+ conv_bias=True,
182
+ use_grn=False,
183
+ act_layer='gelu',
184
+ norm_layer=None,
185
+ norm_layer_cl=None
186
+ ):
187
+ super().__init__()
188
+ self.grad_checkpointing = False
189
+
190
+ if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
191
+ ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
192
+ pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
193
+ self.downsample = nn.Sequential(
194
+ norm_layer(in_chs),
195
+ create_conv2d(
196
+ in_chs,
197
+ out_chs,
198
+ kernel_size=ds_ks,
199
+ stride=stride,
200
+ dilation=dilation[0],
201
+ padding=pad,
202
+ bias=conv_bias,
203
+ ),
204
+ )
205
+ in_chs = out_chs
206
+ else:
207
+ self.downsample = nn.Identity()
208
+
209
+ drop_path_rates = drop_path_rates or [0.] * depth
210
+ stage_blocks = []
211
+ for i in range(depth):
212
+ stage_blocks.append(ConvNeXtBlock(
213
+ in_chs=in_chs,
214
+ out_chs=out_chs,
215
+ kernel_size=kernel_size,
216
+ dilation=dilation[1],
217
+ drop_path=drop_path_rates[i],
218
+ ls_init_value=ls_init_value,
219
+ conv_mlp=conv_mlp,
220
+ conv_bias=conv_bias,
221
+ use_grn=use_grn,
222
+ act_layer=act_layer,
223
+ norm_layer=norm_layer if conv_mlp else norm_layer_cl,
224
+ ))
225
+ in_chs = out_chs
226
+ self.blocks = nn.Sequential(*stage_blocks)
227
+
228
+ def forward(self, x):
229
+ x = self.downsample(x)
230
+ if self.grad_checkpointing and not torch.jit.is_scripting():
231
+ x = checkpoint_seq(self.blocks, x)
232
+ else:
233
+ x = self.blocks(x)
234
+ return x
235
+
236
+
237
+ class ConvNeXt(nn.Module):
238
+ r""" ConvNeXt
239
+ A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ in_chans: int = 3,
245
+ num_classes: int = 1000,
246
+ global_pool: str = 'avg',
247
+ output_stride: int = 32,
248
+ depths: Tuple[int, ...] = (3, 3, 9, 3),
249
+ dims: Tuple[int, ...] = (96, 192, 384, 768),
250
+ kernel_sizes: Union[int, Tuple[int, ...]] = 7,
251
+ ls_init_value: Optional[float] = 1e-6,
252
+ stem_type: str = 'patch',
253
+ patch_size: int = 4,
254
+ head_init_scale: float = 1.,
255
+ head_norm_first: bool = False,
256
+ head_hidden_size: Optional[int] = None,
257
+ conv_mlp: bool = False,
258
+ conv_bias: bool = True,
259
+ use_grn: bool = False,
260
+ act_layer: Union[str, Callable] = 'gelu',
261
+ norm_layer: Optional[Union[str, Callable]] = None,
262
+ norm_eps: Optional[float] = None,
263
+ drop_rate: float = 0.,
264
+ drop_path_rate: float = 0.,
265
+ ):
266
+ """
267
+ Args:
268
+ in_chans: Number of input image channels.
269
+ num_classes: Number of classes for classification head.
270
+ global_pool: Global pooling type.
271
+ output_stride: Output stride of network, one of (8, 16, 32).
272
+ depths: Number of blocks at each stage.
273
+ dims: Feature dimension at each stage.
274
+ kernel_sizes: Depthwise convolution kernel-sizes for each stage.
275
+ ls_init_value: Init value for Layer Scale, disabled if None.
276
+ stem_type: Type of stem.
277
+ patch_size: Stem patch size for patch stem.
278
+ head_init_scale: Init scaling value for classifier weights and biases.
279
+ head_norm_first: Apply normalization before global pool + head.
280
+ head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
281
+ conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
282
+ conv_bias: Use bias layers w/ all convolutions.
283
+ use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
284
+ act_layer: Activation layer type.
285
+ norm_layer: Normalization layer type.
286
+ drop_rate: Head pre-classifier dropout rate.
287
+ drop_path_rate: Stochastic depth drop rate.
288
+ """
289
+ super().__init__()
290
+ assert output_stride in (8, 16, 32)
291
+ kernel_sizes = to_ntuple(4)(kernel_sizes)
292
+ if norm_layer is None:
293
+ norm_layer = LayerNorm2d
294
+ norm_layer_cl = norm_layer if conv_mlp else LayerNorm
295
+ if norm_eps is not None:
296
+ norm_layer = partial(norm_layer, eps=norm_eps)
297
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
298
+ else:
299
+ assert conv_mlp,\
300
+ 'If a norm_layer is specified, conv MLP must be used so that all norm layers expect rank-4, channels-first input'
301
+ norm_layer_cl = norm_layer
302
+ if norm_eps is not None:
303
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
304
+
305
+ self.num_classes = num_classes
306
+ self.drop_rate = drop_rate
307
+ self.feature_info = []
308
+
309
+ assert stem_type in ('patch', 'overlap', 'overlap_tiered')
310
+ if stem_type == 'patch':
311
+ # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
312
+ self.stem = nn.Sequential(
313
+ nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
314
+ norm_layer(dims[0]),
315
+ )
316
+ stem_stride = patch_size
317
+ else:
318
+ mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
319
+ self.stem = nn.Sequential(
320
+ nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
321
+ nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
322
+ norm_layer(dims[0]),
323
+ )
324
+ stem_stride = 4
325
+
326
+ self.stages = nn.Sequential()
327
+ dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
328
+ stages = []
329
+ prev_chs = dims[0]
330
+ curr_stride = stem_stride
331
+ dilation = 1
332
+ # 4 feature resolution stages, each consisting of multiple residual blocks
333
+ for i in range(4):
334
+ stride = 2 if curr_stride == 2 or i > 0 else 1
335
+ if curr_stride >= output_stride and stride > 1:
336
+ dilation *= stride
337
+ stride = 1
338
+ curr_stride *= stride
339
+ first_dilation = 1 if dilation in (1, 2) else 2
340
+ out_chs = dims[i]
341
+ stages.append(ConvNeXtStage(
342
+ prev_chs,
343
+ out_chs,
344
+ kernel_size=kernel_sizes[i],
345
+ stride=stride,
346
+ dilation=(first_dilation, dilation),
347
+ depth=depths[i],
348
+ drop_path_rates=dp_rates[i],
349
+ ls_init_value=ls_init_value,
350
+ conv_mlp=conv_mlp,
351
+ conv_bias=conv_bias,
352
+ use_grn=use_grn,
353
+ act_layer=act_layer,
354
+ norm_layer=norm_layer,
355
+ norm_layer_cl=norm_layer_cl,
356
+ ))
357
+ prev_chs = out_chs
358
+ # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
359
+ self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
360
+ self.stages = nn.Sequential(*stages)
361
+ self.num_features = prev_chs
362
+
363
+ # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
364
+ # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
365
+ if head_norm_first:
366
+ assert not head_hidden_size
367
+ self.norm_pre = norm_layer(self.num_features)
368
+ self.head = ClassifierHead(
369
+ self.num_features,
370
+ num_classes,
371
+ pool_type=global_pool,
372
+ drop_rate=self.drop_rate,
373
+ )
374
+ else:
375
+ self.norm_pre = nn.Identity()
376
+ self.head = NormMlpClassifierHead(
377
+ self.num_features,
378
+ num_classes,
379
+ hidden_size=head_hidden_size,
380
+ pool_type=global_pool,
381
+ drop_rate=self.drop_rate,
382
+ norm_layer=norm_layer,
383
+ act_layer='gelu',
384
+ )
385
+ named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
386
+
387
+ @torch.jit.ignore
388
+ def group_matcher(self, coarse=False):
389
+ return dict(
390
+ stem=r'^stem',
391
+ blocks=r'^stages\.(\d+)' if coarse else [
392
+ (r'^stages\.(\d+)\.downsample', (0,)), # blocks
393
+ (r'^stages\.(\d+)\.blocks\.(\d+)', None),
394
+ (r'^norm_pre', (99999,))
395
+ ]
396
+ )
397
+
398
+ @torch.jit.ignore
399
+ def set_grad_checkpointing(self, enable=True):
400
+ for s in self.stages:
401
+ s.grad_checkpointing = enable
402
+
403
+ @torch.jit.ignore
404
+ def get_classifier(self):
405
+ return self.head.fc
406
+
407
+ def reset_classifier(self, num_classes=0, global_pool=None):
408
+ self.head.reset(num_classes, global_pool)
409
+
410
+ def forward_features(self, x):
411
+ x = self.stem(x)
412
+ x = self.stages(x)
413
+ x = self.norm_pre(x)
414
+ return x
415
+
416
+ def forward_head(self, x, pre_logits: bool = False):
417
+ return self.head(x, pre_logits=True) if pre_logits else self.head(x)
418
+
419
+ def forward(self, x):
420
+ x = self.forward_features(x)
421
+ x = self.forward_head(x)
422
+ return x
423
+
424
+
425
+ def _init_weights(module, name=None, head_init_scale=1.0):
426
+ if isinstance(module, nn.Conv2d):
427
+ trunc_normal_(module.weight, std=.02)
428
+ if module.bias is not None:
429
+ nn.init.zeros_(module.bias)
430
+ elif isinstance(module, nn.Linear):
431
+ trunc_normal_(module.weight, std=.02)
432
+ nn.init.zeros_(module.bias)
433
+ if name and 'head.' in name:
434
+ module.weight.data.mul_(head_init_scale)
435
+ module.bias.data.mul_(head_init_scale)
436
+
437
+
438
+ def checkpoint_filter_fn(state_dict, model):
439
+ """ Remap FB checkpoints -> timm """
440
+ if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
441
+ out_dict={}
442
+ out_dict = {k.replace('gamma', 'weight'): v for k, v in state_dict.items()}
443
+ return out_dict # non-FB checkpoint
444
+ if 'model' in state_dict:
445
+ state_dict = state_dict['model']
446
+
447
+ out_dict = {}
448
+ if 'visual.trunk.stem.0.weight' in state_dict:
449
+ out_dict = {k.replace('visual.trunk.', '').replace('gamma', 'weight'): v for k, v in state_dict.items() if
450
+ k.startswith('visual.trunk.')}
451
+
452
+ if 'visual.head.proj.weight' in state_dict:
453
+ out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
454
+ out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
455
+ elif 'visual.head.mlp.fc1.weight' in state_dict:
456
+ out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
457
+ out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
458
+ out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
459
+ out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
460
+ return out_dict
461
+
462
+ import re
463
+ for k, v in state_dict.items():
464
+ k = k.replace('downsample_layers.0.', 'stem.')
465
+ k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
466
+ k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
467
+ k = k.replace('dwconv', 'conv_dw')
468
+ k = k.replace('pwconv', 'mlp.fc')
469
+ if 'grn' in k:
470
+ k = k.replace('grn.beta', 'mlp.grn.bias')
471
+ k = k.replace('grn.gamma', 'mlp.grn.weight')
472
+ v = v.reshape(v.shape[-1])
473
+ k = k.replace('head.', 'head.fc.')
474
+ if k.startswith('norm.'):
475
+ k = k.replace('norm', 'head.norm')
476
+ if v.ndim == 2 and 'head' not in k:
477
+ model_shape = model.state_dict()[k].shape
478
+ v = v.reshape(model_shape)
479
+ k = k.replace('gamma', 'weight')
480
+ out_dict[k] = v
481
+
482
+ return out_dict
483
+
484
+
485
+ def _create_convnext(variant, pretrained=False, **kwargs):
486
+ if kwargs.get('pretrained_cfg', '') == 'fcmae':
487
+ # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
488
+ # This is workaround loading with num_classes=0 w/o removing norm-layer.
489
+ kwargs.setdefault('pretrained_strict', False)
490
+
491
+ model = build_model_with_cfg(
492
+ ConvNeXt, variant, pretrained,
493
+ pretrained_filter_fn=checkpoint_filter_fn,
494
+ feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
495
+ **kwargs)
496
+ return model
497
+
498
+
499
+ def _cfg(url='', **kwargs):
500
+ return {
501
+ 'url': url,
502
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
503
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
504
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
505
+ 'first_conv': 'stem.0', 'classifier': 'head.fc',
506
+ **kwargs
507
+ }
508
+
509
+
510
+ def _cfgv2(url='', **kwargs):
511
+ return {
512
+ 'url': url,
513
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
514
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
515
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
516
+ 'first_conv': 'stem.0', 'classifier': 'head.fc',
517
+ 'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808',
518
+ 'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
519
+ 'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
520
+ **kwargs
521
+ }
522
+
523
+
524
+ default_cfgs = generate_default_cfgs({
525
+ 'convnext_xxlarge.clip_laion2b_soup_ft_in1k': _cfg(
526
+ hf_hub_id='timm/',
527
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
528
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
529
+
530
+ 'convnext_xxlarge.clip_laion2b_soup_ft_in12k': _cfg(
531
+ hf_hub_id='timm/',
532
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
533
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
534
+ 'convnext_xxlarge.clip_laion2b_soup': _cfg(
535
+ hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup',
536
+ hf_hub_filename='open_clip_pytorch_model.bin',
537
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
538
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
539
+ 'convnext_xxlarge.clip_laion2b_rewind': _cfg(
540
+ hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind',
541
+ hf_hub_filename='open_clip_pytorch_model.bin',
542
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
543
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
544
+ })
545
+
546
+
547
+
548
+ @register_model
549
+ def convnext_xxlarge(pretrained=False, **kwargs) -> ConvNeXt:
550
+ model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
551
+ model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
552
+ return model
553
+
554
+
555
+
556
+ # register_model_deprecations(__name__, {
557
+ # 'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
558
+ # 'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
559
+ # 'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
560
+ # 'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
561
+ # 'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
562
+ # 'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
563
+ # 'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
564
+ # 'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
565
+ # 'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
566
+ # 'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
567
+ # 'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
568
+ # 'convnext_small_in22k': 'convnext_small.fb_in22k',
569
+ # 'convnext_base_in22k': 'convnext_base.fb_in22k',
570
+ # 'convnext_large_in22k': 'convnext_large.fb_in22k',
571
+ # 'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
572
+ # })
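A small smoke test for the `convnext_xxlarge` entrypoint above (a sketch only: it assumes `timm` and `torch` are installed and that this file is importable as `convnext`; it builds the model with random weights, which takes several GB of memory):

```python
import torch
from convnext import convnext_xxlarge

model = convnext_xxlarge(pretrained=False).eval()   # random weights, no download
with torch.no_grad():
    feats = model.forward_features(torch.randn(1, 3, 256, 256))
print(feats.shape)  # torch.Size([1, 3072, 8, 8]): stem stride 4 plus three stride-2 stages = stride 32
```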
convnext_encoder.py ADDED
@@ -0,0 +1,301 @@
1
+ import torch, os
2
+ import torch.nn as nn
3
+ from timm import create_model
4
+ from transformers import CLIPImageProcessor
5
+ from .convnext import convnext_xxlarge
6
+ from torch.utils.checkpoint import checkpoint
7
+ import torch
8
+ from torchvision import transforms as T
9
+ from PIL import Image
10
+
11
+
12
+
13
+ cfg={
14
+ "crop_size": 256,
15
+ "do_center_crop": True,
16
+ "do_normalize": True,
17
+ "do_resize": True,
18
+ "feature_extractor_type": "CLIPFeatureExtractor",
19
+ "image_mean": [
20
+ 0.48145466,
21
+ 0.4578275,
22
+ 0.40821073
23
+ ],
24
+ "image_std": [
25
+ 0.26862954,
26
+ 0.26130258,
27
+ 0.27577711
28
+ ],
29
+ "resample": 3,
30
+ "size": 256
31
+ }
32
+
33
+
34
+
35
+ MEAN_SLIP = [0.5, 0.5, 0.5]
36
+ STD_SLIP = [0.5, 0.5, 0.5]
37
+
38
+ MEAN_CLIP = [0.48145466, 0.4578275, 0.40821073]
39
+ STD_CLIP = [0.26862954, 0.26130258, 0.27577711]
40
+
41
+
42
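+ # Affine remap from SigLIP normalization to CLIP normalization:
+ # x_clip = (x_siglip * STD_SLIP + MEAN_SLIP - MEAN_CLIP) / STD_CLIP = x_siglip * a + b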
+ a = [s_slip / s_clip for s_slip, s_clip in zip(STD_SLIP, STD_CLIP)]
43
+ b = [(m_slip - m_clip) / s_clip for m_slip, m_clip, s_clip in zip(MEAN_SLIP, MEAN_CLIP, STD_CLIP)]
44
+
45
+
46
+ class SlipToClipTransform:
47
+ def __init__(self, a, b):
48
+ self.a = torch.tensor(a).view(-1, 1, 1)
49
+ self.b = torch.tensor(b).view(-1, 1, 1)
50
+
51
+ def __call__(self, x_slip):
52
+ return x_slip * self.a.to(x_slip.device) + self.b.to(x_slip.device)
53
+ slip_to_clip = SlipToClipTransform(a, b)
54
+
55
+ class ConvNextVisionTower(nn.Module):
56
+ def __init__(self, vision_tower, args, delay_load=False, normalize_type=None):
57
+ super().__init__()
58
+
59
+ self.is_loaded = False
60
+ self.freeze_vision=args.freeze_vision
61
+ self.input_image_size=args.input_image_size
62
+ self.vision_tower_name = vision_tower
63
+ self.name = 'convnext'
64
+ self.select_layer = args.mm_vision_select_layer
65
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
66
+ self.pre_norm = normalize_type
67
+
68
+ print('pre_norm: ', self.pre_norm)
69
+ self.delay_load = delay_load
70
+ self.load_model()
71
+
72
+ def load_model(self):
73
+ if 'xxlarge' in self.vision_tower_name:
74
+ if self.delay_load:
75
+ self.vision_tower = convnext_xxlarge(pretrained=False)
76
+ else:
77
+ self.vision_tower = convnext_xxlarge(self.vision_tower_name)
78
+ setattr(self.vision_tower, 'hidden_size', 3072)
79
+ elif os.path.exists(self.vision_tower_name):
80
+ self.vision_tower = torch.load(self.vision_tower_name)
81
+ else:
82
+ assert False, 'Not implemented'
83
+
84
+
85
+ self.vision_tower = self.vision_tower.to(torch.bfloat16)
86
+
87
+ if self.freeze_vision:
88
+ self.vision_tower.requires_grad_(False)
89
+
90
+ # if self.vision_tower.grad_checkpointing:
91
+ for s in self.vision_tower.stages:
92
+ s.grad_checkpointing = True
93
+
94
+ self.is_loaded = True
95
+
96
+ def feature_select(self, image_forward_outs):
97
+
98
+ if self.select_layer>100:
99
+ image_features = image_forward_outs[-4:]
100
+ else:
101
+ image_features = image_forward_outs[-1]
102
+ return image_features
103
+
104
+ def forward_features(self, x):
105
+ x = self.vision_tower.stem(x)
106
+ image_forward_out=[]
107
+ for blk in self.vision_tower.stages:
108
+ x = blk(x)
109
+ b,c,h,w=x.shape
110
+ image_forward_out.append(x.view(b,c,-1).transpose(1,2))
111
+ return image_forward_out
112
+
113
+ def forward(self, images):
114
+ if self.freeze_vision:
115
+ with torch.no_grad():
116
+ image_features = self._forward_images(images)
117
+ else:
118
+ image_features = self._forward_images(images)
119
+
120
+ return image_features
121
+
122
+ def _forward_images(self, images):
123
+
124
+ if type(images) is list:
125
+ image_features = []
126
+ for image in images:
127
+ if self.pre_norm == 'siglip':
128
+ dtype = image.dtype
129
+ image = slip_to_clip(image.to(torch.float32)).to(dtype)
130
+ image_forward_out = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
131
+ image_feature = self.feature_select(image_forward_out)
132
+ image_features.append(image_feature)
133
+ else:
134
+ if self.pre_norm == 'siglip':
135
+ dtype = images.dtype
136
+ images = slip_to_clip(images.to(torch.float32)).to(dtype)
137
+ image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
138
+ image_features = self.feature_select(image_forward_outs)
139
+
140
+ return image_features
141
+
142
+ @property
143
+ def dummy_feature(self):
144
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
145
+
146
+ @property
147
+ def dtype(self):
148
+ return next(self.vision_tower.parameters()).dtype
149
+
150
+ @property
151
+ def device(self):
152
+ return next(self.vision_tower.parameters()).device
153
+
154
+ @property
155
+ def config(self):
156
+ raise NotImplementedError
157
+ pass
158
+
159
+ @property
160
+ def num_attention_heads(self):
161
+ # as constant
162
+ return 16
163
+ @property
164
+ def num_layers(self):
165
+ # as constant
166
+ return 4
167
+ @property
168
+ def hidden_size(self):
169
+ return self.vision_tower.hidden_size
170
+
171
+ @property
172
+ def num_patches(self):
173
+ return (self.input_image_size // self.patch_embed.patch_size[0]) ** 2
174
+
175
+
176
+ class ConvNextFPNVisionTower(nn.Module):
177
+ def __init__(self,
178
+ vision_tower,
179
+ args,
180
+ fpn_target_level=1,
181
+ fpn_layer_idx=[1,2,3],
182
+ fpn_input_dim=[768,1536,3072],
183
+ delay_load=False):
184
+
185
+ super().__init__()
186
+
187
+ self.is_loaded = False
188
+ self.vision_tower_name = vision_tower.replace('-fpn', 'fpn')
189
+ self.freeze_vision = getattr(args, "frozen_backbone", True)
190
+ # self.input_image_size = getattr(args, "vision_tower_input_size", 1024)
191
+ self.input_image_size = 1024 # hardcode
192
+ self.select_layer = args.mm_vision_select_layer # no effect
193
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
194
+
195
+ self.need_fpn = True
196
+ self.fpn_layer_idx = fpn_layer_idx # [1, 2, 3] # x8, x16, x32
197
+ self.fpn_input_dim = [768, 1536, 3072]
198
+ self.delay_load = delay_load
199
+ self.load_model()
200
+
201
+ def load_model(self):
202
+ if self.is_loaded:
203
+ return
204
+
205
+ self.image_processor = CLIPImageProcessor(**cfg)
206
+ if 'xxlarge' in self.vision_tower_name:
207
+ self.vision_tower = convnext_xxlarge(self.vision_tower_name)
208
+ setattr(self.vision_tower, 'hidden_size', self.fpn_input_dim)
209
+ # setattr(self.vision_tower, 'hidden_size', 3072)
210
+ else:
211
+ self.vision_tower = convnext_large_mlp(self.vision_tower_name)
212
+ setattr(self.vision_tower, 'hidden_size', 1536)
213
+ if self.freeze_vision:
214
+ self.vision_tower.requires_grad_(False)
215
+
216
+ # if self.vision_tower.grad_checkpointing:
217
+ for s in self.vision_tower.stages:
218
+ s.grad_checkpointing = True
219
+
220
+ if self.input_image_size is not None:
221
+ self.image_processor.size=self.input_image_size
222
+ self.image_processor.crop_size={
223
+ 'height':self.input_image_size,
224
+ 'width': self.input_image_size
225
+ }
226
+
227
+ self.is_loaded = True
228
+
229
+ @torch.no_grad()
230
+ def forward_features(self, x):
231
+ x = self.vision_tower.stem(x)
232
+ image_forward_out=[]
233
+ for blk in self.vision_tower.stages:
234
+ x = blk(x)
235
+ image_forward_out.append(x)
236
+ return image_forward_out
237
+
238
+ @torch.no_grad()
239
+ def forward(self, images):
240
+ if type(images) is list:
241
+ image_features = []
242
+ for image in images:
243
+ image_feature = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
244
+ image_features.append(image_feature)
245
+ else:
246
+ image_features = self.forward_features(images.to(device=self.device, dtype=self.dtype))
247
+ image_features = [image_features[idx] for idx in self.fpn_layer_idx]
248
+
249
+ return image_features
250
+
251
+ @property
252
+ def dummy_feature(self):
253
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
254
+
255
+ @property
256
+ def dtype(self):
257
+ return next(self.vision_tower.parameters()).dtype
258
+
259
+ @property
260
+ def device(self):
261
+ return next(self.vision_tower.parameters()).device
262
+
263
+ @property
264
+ def config(self):
265
+ raise NotImplementedError
266
+ pass
267
+
268
+ @property
269
+ def num_attention_heads(self):
270
+ # as constant
271
+ return 16
272
+ @property
273
+ def num_layers(self):
274
+ # as constant
275
+ return 4
276
+ @property
277
+ def hidden_size(self):
278
+ return self.vision_tower.hidden_size
279
+
280
+ @property
281
+ def num_patches(self):
282
+ return (cfg['image_size'] // self.patch_embed.patch_size[0]) ** 2
283
+
284
+ if __name__ == '__main__':
285
+ COMBINED_STD = [s_slip / s_clip for s_slip, s_clip in zip(STD_SLIP, STD_CLIP)]
286
+ COMBINED_MEAN = [(m_slip - m_clip) / s_clip for m_slip, m_clip, s_clip in zip(MEAN_SLIP, MEAN_CLIP, STD_CLIP)]
287
+
288
+ # define the combined normalization transform
289
+ combined_normalize = T.Normalize(mean=COMBINED_MEAN, std=COMBINED_STD)
290
+ x = torch.randn(1, 3, 256, 256).cuda()
291
+ a = normalize_clip(x).to(torch.bfloat16)
292
+ b = normalize_siglip(x).to(torch.bfloat16)
293
+ c = denormalize_siglip(b.to(torch.float32))
294
+ c2 = normalize_clip(c).to(torch.bfloat16)
295
+ c3 = combined_normalize(b)
296
+ print((c-x).abs().max())
297
+ print((c2-a).abs().max())
298
+ print((c3-a).abs().max())
299
+ from IPython import embed
300
+ embed()
301
+ exit()
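The `__main__` block above refers to normalization helpers (`normalize_clip`, `normalize_siglip`, `denormalize_siglip`) that are not defined in this file; a self-contained sketch of the same consistency check, using only `torch` and `torchvision`, could look like this:

```python
import torch
from torchvision import transforms as T

MEAN_SLIP, STD_SLIP = [0.5] * 3, [0.5] * 3
MEAN_CLIP = [0.48145466, 0.4578275, 0.40821073]
STD_CLIP = [0.26862954, 0.26130258, 0.27577711]

normalize_clip = T.Normalize(mean=MEAN_CLIP, std=STD_CLIP)
normalize_siglip = T.Normalize(mean=MEAN_SLIP, std=STD_SLIP)

# x_clip = x_siglip * a + b, expressed as a single torchvision Normalize
a = [s_s / s_c for s_s, s_c in zip(STD_SLIP, STD_CLIP)]
b = [(m_s - m_c) / s_c for m_s, m_c, s_c in zip(MEAN_SLIP, MEAN_CLIP, STD_CLIP)]
combined = T.Normalize(mean=[-bi / ai for ai, bi in zip(a, b)], std=[1 / ai for ai in a])

x = torch.rand(1, 3, 256, 256)
ref = normalize_clip(x)                  # normalize raw pixels with CLIP stats
out = combined(normalize_siglip(x))      # remap an already SigLIP-normalized tensor
print((ref - out).abs().max())           # on the order of 1e-7, i.e. equal up to float error
```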
demo.py ADDED
@@ -0,0 +1,421 @@
1
+
2
+ """
3
+ A model worker executes the model.
4
+ """
5
+ from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer, AutoConfig
6
+ import argparse
7
+ import base64
8
+ import json
9
+ import os
10
+ import decord
11
+ import threading
12
+ import time
13
+ from io import BytesIO
14
+ from threading import Thread
15
+ import math
16
+ import requests
17
+ import torch
18
+ import torchvision.transforms as T
19
+ from PIL import Image
20
+ from torchvision.transforms.functional import InterpolationMode
21
+
22
+ import numpy as np
23
+
24
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
25
+ IMAGENET_STD = (0.229, 0.224, 0.225)
26
+
27
+ SIGLIP_MEAN = (0.5, 0.5, 0.5)
28
+ SIGLIP_STD = (0.5, 0.5, 0.5)
29
+
30
+
31
+ def get_seq_frames(total_num_frames, desired_num_frames=-1, stride=-1):
32
+ """
33
+ Calculate the indices of frames to extract from a video.
34
+
35
+ Parameters:
36
+ total_num_frames (int): Total number of frames in the video.
37
+ desired_num_frames (int): Desired number of frames to extract.
38
+
39
+ Returns:
40
+ list: List of indices of frames to extract.
41
+ """
42
+
43
+ assert desired_num_frames > 0 or stride > 0 and not (desired_num_frames > 0 and stride > 0)
44
+
45
+ if stride > 0:
46
+ return list(range(0, total_num_frames, stride))
47
+
48
+ # Calculate the size of each segment from which a frame will be extracted
49
+ seg_size = float(total_num_frames - 1) / desired_num_frames
50
+
51
+ seq = []
52
+ for i in range(desired_num_frames):
53
+ # Calculate the start and end indices of each segment
54
+ start = int(np.round(seg_size * i))
55
+ end = int(np.round(seg_size * (i + 1)))
56
+
57
+ # Append the middle index of the segment to the list
58
+ seq.append((start + end) // 2)
59
+
60
+ return seq
61
+
62
+ def build_video_prompt(meta_list, num_frames, time_position=False):
63
+ # if time_position is True, the frame_timestamp is used.
64
+ # 1. pass time_position, 2. use env TIME_POSITION
65
+ time_position = os.environ.get("TIME_POSITION", time_position)
66
+ prefix = f"This is a video:\n"
67
+ for i in range(num_frames):
68
+ if time_position:
69
+ frame_txt = f"Frame {i+1} sampled at {meta_list[i]:.2f} seconds: <image>\n"
70
+ else:
71
+ frame_txt = f"Frame {i+1}: <image>\n"
72
+ prefix += frame_txt
73
+ return prefix
74
+
75
+ def load_video(video_path, num_frames=64, frame_cache_root=None):
76
+ if isinstance(video_path, str):
77
+ video = decord.VideoReader(video_path)
78
+ elif isinstance(video_path, dict):
79
+ assert False, 'video input given as a dict ("video_path") is not supported'
80
+ fps = video.get_avg_fps()
81
+ sampled_frames = get_seq_frames(len(video), num_frames)
82
+ sampled_timestamps = [i / fps for i in sampled_frames]
83
+ frames = video.get_batch(sampled_frames).asnumpy()
84
+ images = [Image.fromarray(frame) for frame in frames]
85
+
86
+ return images, build_video_prompt(sampled_timestamps, len(images), time_position=True)
87
+
88
+ def load_image(image):
89
+ if isinstance(image, str) and os.path.exists(image):
90
+ return Image.open(image)
91
+ elif isinstance(image, dict):
92
+ if 'disk_path' in image:
93
+ return Image.open(image['disk_path'])
94
+ elif 'base64' in image:
95
+ return Image.open(BytesIO(base64.b64decode(image['base64'])))
96
+ elif 'url' in image:
97
+ response = requests.get(image['url'])
98
+ return Image.open(BytesIO(response.content))
99
+ elif 'bytes' in image:
100
+ return Image.open(BytesIO(image['bytes']))
101
+ else:
102
+ raise ValueError(f'Invalid image: {image}')
103
+ else:
104
+ raise ValueError(f'Invalid image: {image}')
105
+
106
+ def build_transform(input_size, norm_type='imagenet'):
107
+ if norm_type == 'imagenet':
108
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
109
+ elif norm_type == 'siglip':
110
+ MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
+ else:
+ raise ValueError(f'Unsupported norm_type: {norm_type}')
111
+
112
+ transform = T.Compose([
113
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
114
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
115
+ T.ToTensor(),
116
+ T.Normalize(mean=MEAN, std=STD)
117
+ ])
118
+ return transform
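A quick shape check of the preprocessing pipeline (editorial example, assuming `build_transform` above is in scope):

```python
from PIL import Image

transform = build_transform(input_size=448, norm_type='siglip')
tile = transform(Image.new('RGB', (1024, 512), 'blue'))
print(tile.shape)  # torch.Size([3, 448, 448]), normalized with mean/std 0.5
```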
119
+
120
+
121
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
122
+ """
123
+ The previous version mainly focused on the aspect ratio;
+ here we also take the covered-area ratio into account.
125
+ """
126
+ best_factor = float('-inf')
127
+ best_ratio = (1, 1)
128
+ area = width * height
129
+ for ratio in target_ratios:
130
+ target_aspect_ratio = ratio[0] / ratio[1]
131
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
132
+ area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
+ # Cap the area factor at 0.6: covering more than 60% of the original image area is enough.
136
+ factor_based_on_area_n_ratio = min(area_ratio, 0.6)* \
137
+ min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)
138
+
139
+ if factor_based_on_area_n_ratio > best_factor:
140
+ best_factor = factor_based_on_area_n_ratio
141
+ best_ratio = ratio
142
+
143
+ return best_ratio
144
+
145
+
146
+ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
147
+ orig_width, orig_height = image.size
148
+ aspect_ratio = orig_width / orig_height
149
+
150
+ # enumerate the candidate tiling grids (columns x rows) within [min_num, max_num] tiles
151
+ target_ratios = set(
152
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
153
+ i * j <= max_num and i * j >= min_num)
154
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
155
+
156
+ # find the closest aspect ratio to the target
157
+ target_aspect_ratio = find_closest_aspect_ratio(
158
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
159
+
160
+ # calculate the target width and height
161
+ target_width = image_size * target_aspect_ratio[0]
162
+ target_height = image_size * target_aspect_ratio[1]
163
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
164
+
165
+ # resize the image
166
+ resized_img = image.resize((target_width, target_height))
167
+ processed_images = []
168
+ for i in range(blocks):
169
+ box = (
170
+ (i % (target_width // image_size)) * image_size,
171
+ (i // (target_width // image_size)) * image_size,
172
+ ((i % (target_width // image_size)) + 1) * image_size,
173
+ ((i // (target_width // image_size)) + 1) * image_size
174
+ )
175
+ # split the image
176
+ split_img = resized_img.crop(box)
177
+ processed_images.append(split_img)
178
+ assert len(processed_images) == blocks
179
+ if use_thumbnail and len(processed_images) != 1:
180
+ thumbnail_img = image.resize((image_size, image_size))
181
+ processed_images.append(thumbnail_img)
182
+ return processed_images
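To make the tiling behaviour concrete, an editorial example of the dynamic preprocessing on a 2:1 image:

```python
from PIL import Image

# Aspect ratio 2.0 with max_num=6 selects the (2, 1) grid: the image is resized to
# 896x448 and cut into two 448x448 tiles, plus a 448x448 thumbnail of the whole image.
tiles = dynamic_preprocess(Image.new('RGB', (1024, 512)), max_num=6,
                           image_size=448, use_thumbnail=True)
print(len(tiles), tiles[0].size)  # 3 (448, 448)
```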
183
+
184
+ def split_model(model_path, device):
185
+
186
+ device_map = {}
187
+ world_size = torch.cuda.device_count()
188
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
189
+ num_layers = config.llm_config.num_hidden_layers
190
+
191
+ num_layers_per_gpu_ = math.floor(num_layers / (world_size - 1))
192
+ num_layers_per_gpu = [num_layers_per_gpu_] * world_size
193
+ num_layers_per_gpu[device] = num_layers - num_layers_per_gpu_ * (world_size-1)
194
+ layer_cnt = 0
195
+ for i, num_layer in enumerate(num_layers_per_gpu):
196
+ for j in range(num_layer):
197
+ device_map[f'language_model.model.layers.{layer_cnt}'] = i
198
+ layer_cnt += 1
199
+ device_map['vision_model'] = device
200
+ device_map['mlp1'] = device
201
+ device_map['language_model.model.tok_embeddings'] = device
202
+ device_map['language_model.model.embed_tokens'] = device
203
+ device_map['language_model.output'] = device
204
+ device_map['language_model.model.norm'] = device
205
+ device_map['language_model.lm_head'] = device
206
+ device_map['language_model.model.rotary_emb'] = device
207
+ device_map[f'language_model.model.layers.{num_layers - 1}'] = device
208
+ return device_map
209
+
210
+ class ModelWorker:
211
+ def __init__(self, model_path, model_name,
212
+ load_8bit, device):
213
+
214
+ if model_path.endswith('/'):
215
+ model_path = model_path[:-1]
216
+ if model_name is None:
217
+ model_paths = model_path.split('/')
218
+ if model_paths[-1].startswith('checkpoint-'):
219
+ self.model_name = model_paths[-2] + '_' + model_paths[-1]
220
+ else:
221
+ self.model_name = model_paths[-1]
222
+ else:
223
+ self.model_name = model_name
224
+
225
+ print(f'Loading the model {self.model_name}')
226
+
227
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
228
+ tokens_to_keep = ['<box>', '</box>', '<ref>', '</ref>']
229
+ tokenizer.additional_special_tokens = [item for item in tokenizer.additional_special_tokens if item not in tokens_to_keep]
230
+ self.tokenizer = tokenizer
231
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
232
+ model_type = config.vision_config.model_type
233
+ self.device = torch.cuda.current_device()
234
+ if model_type == 'siglip_vision_model':
235
+ self.norm_type = 'siglip'
236
+ elif model_type == 'MOB':
237
+ self.norm_type = 'siglip'
238
+ else:
239
+ self.norm_type = 'imagenet'
240
+
241
+ if any(x in model_path.lower() for x in ['34b']):
242
+ device_map = split_model(model_path, self.device)
243
+ else:
244
+ device_map = None
245
+
246
+ if device_map is not None:
247
+ self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
248
+ low_cpu_mem_usage=True,
249
+ device_map=device_map,
250
+ trust_remote_code=True,
251
+ load_in_8bit=load_8bit).eval()
252
+ else:
253
+ self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
254
+ trust_remote_code=True,
255
+ load_in_8bit=load_8bit).eval()
256
+ if not load_8bit and device_map is None:
257
+ self.model = self.model.to(device)
258
+ self.load_8bit = load_8bit
259
+
260
+ self.model_path = model_path
261
+ self.image_size = self.model.config.force_image_size
262
+ self.context_len = tokenizer.model_max_length
263
+ self.per_tile_len = 256
264
+
265
+ def reload_model(self):
266
+ del self.model
267
+ torch.cuda.empty_cache()
268
+ if self.device == 'auto':
269
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
270
+ # This can make distributed deployment work properly
271
+ self.model = AutoModel.from_pretrained(
272
+ self.model_path,
273
+ load_in_8bit=self.load_8bit,
274
+ torch_dtype=torch.bfloat16,
275
+ device_map=self.device_map,
276
+ trust_remote_code=True).eval()
277
+ else:
278
+ self.model = AutoModel.from_pretrained(
279
+ self.model_path,
280
+ load_in_8bit=self.load_8bit,
281
+ torch_dtype=torch.bfloat16,
282
+ trust_remote_code=True).eval()
283
+ if not self.load_8bit and not self.device == 'auto':
284
+ self.model = self.model.cuda()
285
+
286
+ @torch.inference_mode()
287
+ def generate(self, params):
288
+ system_message = params['prompt'][0]['content']
289
+ send_messages = params['prompt'][1:]
290
+ max_input_tiles = params['max_input_tiles']
291
+ temperature = params['temperature']
292
+ top_p = params['top_p']
293
+ max_new_tokens = params['max_new_tokens']
294
+ repetition_penalty = params['repetition_penalty']
295
+ video_frame_num = params.get('video_frame_num', 64)
296
+ do_sample = True if temperature > 0.0 else False
297
+
298
+ global_image_cnt = 0
299
+ history, pil_images, max_input_tile_list = [], [], []
300
+ for message in send_messages:
301
+ if message['role'] == 'user':
302
+ prefix = ''
303
+ if 'image' in message:
304
+ for image_data in message['image']:
305
+ pil_images.append(load_image(image_data))
306
+ prefix = prefix + f'<image {global_image_cnt + 1}><image>\n'
307
+ global_image_cnt += 1
308
+ max_input_tile_list.append(max_input_tiles)
309
+ if 'video' in message:
310
+ for video_data in message['video']:
311
+ video_frames, tmp_prefix = load_video(video_data, num_frames=video_frame_num)
312
+ pil_images.extend(video_frames)
313
+ prefix = prefix + tmp_prefix
314
+ global_image_cnt += len(video_frames)
315
+ max_input_tile_list.extend([1] * len(video_frames))
316
+ content = prefix + message['content']
317
+ history.append([content, ])
318
+ else:
319
+ history[-1].append(message['content'])
320
+ question, history = history[-1][0], history[:-1]
321
+
322
+ if global_image_cnt == 1:
323
+ question = question.replace('<image 1><image>\n', '<image>\n')
324
+ history = [[item[0].replace('<image 1><image>\n', '<image>\n'), item[1]] for item in history]
325
+
326
+
327
+ try:
328
+ assert len(max_input_tile_list) == len(pil_images), 'The number of max_input_tile_list and pil_images should be the same.'
329
+ except Exception as e:
330
+ print(f'Error: {e}')
+ print(f'max_input_tile_list: {max_input_tile_list}, pil_images: {pil_images}')
+ raise
335
+
336
+ old_system_message = self.model.system_message
337
+ self.model.system_message = system_message
338
+
339
+ transform = build_transform(input_size=self.image_size, norm_type=self.norm_type)
340
+ if len(pil_images) > 0:
341
+ max_input_tiles_limited_by_context = params['max_input_tiles']
342
+ while True:
343
+ image_tiles = []
344
+ for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
345
+ if self.model.config.dynamic_image_size:
346
+ tiles = dynamic_preprocess(
347
+ pil_image, image_size=self.image_size, max_num=min(current_max_input_tiles, max_input_tiles_limited_by_context),
348
+ use_thumbnail=self.model.config.use_thumbnail)
349
+ else:
350
+ tiles = [pil_image]
351
+ image_tiles += tiles
352
+ if (len(image_tiles) * self.per_tile_len < self.context_len):
353
+ break
354
+ else:
355
+ max_input_tiles_limited_by_context -= 2
356
+
357
+ if max_input_tiles_limited_by_context < 1:
358
+ break
359
+
360
+ pixel_values = [transform(item) for item in image_tiles]
361
+ pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
362
+
363
+ else:
364
+ pixel_values = None
365
+
366
+ generation_config = dict(
367
+ num_beams=1,
368
+ max_new_tokens=max_new_tokens,
369
+ do_sample=do_sample,
370
+ temperature=temperature,
371
+ repetition_penalty=repetition_penalty,
372
+ max_length=self.context_len,
373
+ top_p=top_p,
374
+ )
375
+
376
+ response = self.model.chat(
377
+ tokenizer=self.tokenizer,
378
+ pixel_values=pixel_values,
379
+ question=question,
380
+ history=history,
381
+ return_history=False,
382
+ generation_config=generation_config,
383
+ )
384
+ self.model.system_message = old_system_message
385
+ return {'text': response, 'error_code': 0}
386
+
387
+
388
+
389
+
390
+
391
+ if __name__ == '__main__':
392
+ parser = argparse.ArgumentParser()
393
+ parser.add_argument('--model-path', type=str, default='/home/zhidingy/workspace/eagle-next/internvl_chat/work_dirs/release/Eagle2-1B')
394
+ parser.add_argument('--model-name', type=str, default='Eagle2-1B')
395
+ parser.add_argument('--device', type=str, default='cuda')
396
+ parser.add_argument('--load-8bit', action='store_true')
397
+ args = parser.parse_args()
398
+ print(f'args: {args}')
399
+
400
+ worker = ModelWorker(
401
+ args.model_path,
402
+ args.model_name,
403
+ args.load_8bit,
404
+ args.device)
405
+ prompt = [
406
+ {'role': 'system', 'content': 'You are a helpful assistant.'},
407
+ {'role': 'user', 'content': 'Describe this image in details.',
408
+ 'image':[
409
+ {'url': 'https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]'}
410
+ ]
411
+ }
412
+ ]
413
+ params = {
414
+ 'prompt': prompt,
415
+ 'max_input_tiles': 24,
416
+ 'temperature': 0.7,
417
+ 'top_p': 1.0,
418
+ 'max_new_tokens': 4096,
419
+ 'repetition_penalty': 1.0,
420
+ }
421
+ print(worker.generate(params))
flash_attention.py ADDED
@@ -0,0 +1,76 @@
1
+ # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+
6
+ try: # v1
7
+ from flash_attn.flash_attn_interface import \
8
+ flash_attn_unpadded_qkvpacked_func
9
+ except ImportError: # v2
10
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
11
+
12
+ from flash_attn.bert_padding import pad_input, unpad_input
13
+
14
+
15
+ class FlashAttention(nn.Module):
16
+ """Implement the scaled dot product attention with softmax.
17
+ Arguments
18
+ ---------
19
+ softmax_scale: The temperature to use for the softmax attention.
20
+ (default: 1/sqrt(d_keys) where d_keys is computed at
21
+ runtime)
22
+ attention_dropout: The dropout rate to apply to the attention
23
+ (default: 0.0)
24
+ """
25
+
26
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
27
+ super().__init__()
28
+ self.softmax_scale = softmax_scale
29
+ self.dropout_p = attention_dropout
30
+
31
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
32
+ max_s=None, need_weights=False):
33
+ """Implements the multihead softmax attention.
34
+ Arguments
35
+ ---------
36
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
37
+ if unpadded: (nnz, 3, h, d)
38
+ key_padding_mask: a bool tensor of shape (B, S)
39
+ """
40
+ assert not need_weights
41
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
42
+ assert qkv.is_cuda
43
+
44
+ if cu_seqlens is None:
45
+ batch_size = qkv.shape[0]
46
+ seqlen = qkv.shape[1]
47
+ if key_padding_mask is None:
48
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
49
+ max_s = seqlen
50
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
51
+ device=qkv.device)
52
+ output = flash_attn_unpadded_qkvpacked_func(
53
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
54
+ softmax_scale=self.softmax_scale, causal=causal
55
+ )
56
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
57
+ else:
58
+ nheads = qkv.shape[-2]
59
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
60
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
61
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
62
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
63
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
64
+ softmax_scale=self.softmax_scale, causal=causal
65
+ )
66
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
67
+ indices, batch_size, seqlen),
68
+ 'b s (h d) -> b s h d', h=nheads)
69
+ else:
70
+ assert max_s is not None
71
+ output = flash_attn_unpadded_qkvpacked_func(
72
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
73
+ softmax_scale=self.softmax_scale, causal=causal
74
+ )
75
+
76
+ return output, None
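A minimal smoke test of this wrapper (editorial sketch; it assumes the `flash-attn` package is installed and a CUDA GPU is available, since the kernel only supports fp16/bf16 on GPU):

```python
import torch

attn = FlashAttention(attention_dropout=0.0)
# qkv is packed as (batch, seqlen, 3, num_heads, head_dim).
qkv = torch.randn(2, 128, 3, 8, 64, device='cuda', dtype=torch.float16)
out, _ = attn(qkv, causal=True)  # no key_padding_mask -> fixed-length path
print(out.shape)                 # torch.Size([2, 128, 8, 64])
```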
generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.37.2"
4
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a67b1a076346d99d202427f39c173b22cf79fd6c666dd4461d3cade62dfa38ed
3
+ size 2126511128
modeling_eagle_chat.py ADDED
@@ -0,0 +1,424 @@
1
+ # --------------------------------------------------------
2
+ # Eagle2
3
+ # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import warnings
8
+ from typing import Any, List, Optional, Tuple, Union
9
+
10
+ import torch.utils.checkpoint
11
+ import transformers
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss
14
+ from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
15
+ LlamaTokenizer)
16
+ from transformers.modeling_outputs import CausalLMOutputWithPast
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import ModelOutput, logging
19
+ from peft import LoraConfig, get_peft_model
20
+ from .configuration_eagle_chat import Eagle2ChatConfig
21
+ from .conversation import get_conv_template
22
+ from .modeling_siglip import SiglipVisionModel
23
+ from .modeling_qwen2 import Qwen2ForCausalLM
24
+ from .flash_attention import *
25
+ from .multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModel
26
+ from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower
27
+ from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ def version_cmp(v1, v2, op='eq'):
33
+ import operator
34
+
35
+ from packaging import version
36
+ op_func = getattr(operator, op)
37
+ return op_func(version.parse(v1), version.parse(v2))
38
+
39
+
40
+ class Eagle2ChatModel(PreTrainedModel):
41
+ config_class = Eagle2ChatConfig
42
+ main_input_name = 'pixel_values'
43
+ _no_split_modules = ['LlamaDecoderLayer']
44
+
45
+ def __init__(self, config: Eagle2ChatConfig, vision_model=None, language_model=None):
46
+ super().__init__(config)
47
+
48
+ assert version_cmp(transformers.__version__, '4.37.2', 'ge')
49
+ assert version_cmp(transformers.__version__, '4.39.2', 'le')
50
+ image_size = config.force_image_size or config.vision_config.image_size
51
+ if hasattr(config.vision_config, 'grid_size'):
52
+ grid_size = config.vision_config.grid_size
53
+ self.patch_size = 14
54
+ self.num_image_token = int((grid_size * config.downsample_ratio) ** 2)
55
+ else:
56
+ patch_size = config.vision_config.patch_size
57
+ self.patch_size = patch_size
58
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
59
+
60
+ self.select_layer = config.select_layer
61
+ self.template = config.template
62
+
63
+ self.downsample_ratio = config.downsample_ratio
64
+
65
+ logger.info(f'num_image_token: {self.num_image_token}')
66
+ if vision_model is not None:
67
+ self.vision_model = vision_model
68
+ else:
69
+ if config.vision_config.model_type == 'siglip_vision_model':
70
+ self.vision_model = SiglipVisionModel(config.vision_config)
71
+ elif config.vision_config.model_type.startswith("MOB"):
72
+ self.vision_model = MultiBackboneChannelConcatenationVisionModel(config.vision_config, config)
73
+
74
+ if language_model is not None:
75
+ self.language_model = language_model
76
+ else:
77
+ if config.llm_config.architectures[0] == 'LlamaForCausalLM':
78
+ self.language_model = LlamaForCausalLM(config.llm_config)
79
+ elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
80
+ self.language_model = Qwen2ForCausalLM(config.llm_config)
81
+ else:
82
+ raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
83
+
84
+ vit_hidden_size = config.vision_config.hidden_size
85
+ if vit_hidden_size == 'lazy_calculation':
86
+ # a hack for Mixture of Backbones
87
+ vit_hidden_size = self.vision_model.hidden_size
88
+ print("The lazy calculated hidden_size: {} .. ".format(vit_hidden_size))
89
+ llm_hidden_size = config.llm_config.hidden_size
90
+ self.moe_version_type = getattr(config.vision_config, 'moe_version_type', None)
91
+
92
+ if self.moe_version_type in ['seq_concat', 'feat_concat']:
93
+ raise NotImplementedError
94
+ elif self.moe_version_type == 'convnext_512_siglip_448':
95
+ convnext_hidden_size = vit_hidden_size['convnext']
96
+ siglip_hidden_size = vit_hidden_size['siglip']
97
+ feature_concat_hidden_size = convnext_hidden_size + siglip_hidden_size * int(1 / self.downsample_ratio) ** 2
98
+ self.mlp1 = nn.Sequential(
99
+ nn.LayerNorm(feature_concat_hidden_size),
100
+ nn.Linear(feature_concat_hidden_size, llm_hidden_size),
101
+ nn.GELU(),
102
+ nn.Linear(llm_hidden_size, llm_hidden_size)
103
+ )
104
+ else:
105
+ self.mlp1 = nn.Sequential(
106
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
107
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
108
+ nn.GELU(),
109
+ nn.Linear(llm_hidden_size, llm_hidden_size)
110
+ )
111
+ self.img_context_token_id = None
112
+ self.conv_template = get_conv_template(self.template)
113
+ self.system_message = self.conv_template.system_message
114
+
115
+ def forward(
116
+ self,
117
+ pixel_values: torch.FloatTensor,
118
+ input_ids: torch.LongTensor = None,
119
+ attention_mask: Optional[torch.Tensor] = None,
120
+ position_ids: Optional[torch.LongTensor] = None,
121
+ image_flags: Optional[torch.LongTensor] = None,
122
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
123
+ labels: Optional[torch.LongTensor] = None,
124
+ use_cache: Optional[bool] = None,
125
+ output_attentions: Optional[bool] = None,
126
+ output_hidden_states: Optional[bool] = None,
127
+ return_dict: Optional[bool] = None,
128
+ num_patches_list: Optional[List[torch.Tensor]] = None,
129
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
130
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
131
+
132
+ image_flags = image_flags.squeeze(-1)
133
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
134
+
135
+
136
+ if self.moe_version_type in ['seq_concat', 'feat_concat'] and not isinstance(pixel_values, dict):
137
+ raise NotImplementedError
138
+ vit_embeds = self.extract_feature(pixel_values)
139
+
140
+ if not isinstance(image_flags, list):
141
+ image_flags = image_flags.squeeze(-1)
142
+ vit_embeds = vit_embeds[image_flags == 1]
143
+ if isinstance(pixel_values, dict):
144
+ # for MOE
145
+ vit_batch_size = sum(pixel_values['num_patches'])
146
+ else:
147
+ vit_batch_size = pixel_values.shape[0]
148
+
149
+ B, N, C = input_embeds.shape
150
+ input_embeds = input_embeds.reshape(B * N, C)
151
+
152
+ if torch.distributed.get_rank() == 0:
153
+ print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
154
+
155
+ input_ids = input_ids.reshape(B * N)
156
+ selected = (input_ids == self.img_context_token_id)
157
+ try:
158
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
159
+ except Exception as e:
160
+ vit_embeds = vit_embeds.reshape(-1, C)
161
+ print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
162
+ f'vit_embeds.shape={vit_embeds.shape}')
163
+ n_token = selected.sum()
164
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
165
+
166
+ input_embeds = input_embeds.reshape(B, N, C)
167
+
168
+ outputs = self.language_model(
169
+ inputs_embeds=input_embeds,
170
+ attention_mask=attention_mask,
171
+ position_ids=position_ids,
172
+ past_key_values=past_key_values,
173
+ use_cache=use_cache,
174
+ output_attentions=output_attentions,
175
+ output_hidden_states=output_hidden_states,
176
+ return_dict=return_dict,
177
+ )
178
+ logits = outputs.logits
179
+
180
+ loss = None
181
+ if labels is not None:
182
+ # Shift so that tokens < n predict n
183
+ shift_logits = logits[..., :-1, :].contiguous()
184
+ shift_labels = labels[..., 1:].contiguous()
185
+ # Flatten the tokens
186
+ loss_fct = CrossEntropyLoss()
187
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
188
+ shift_labels = shift_labels.view(-1)
189
+ # Enable model parallelism
190
+ shift_labels = shift_labels.to(shift_logits.device)
191
+ loss = loss_fct(shift_logits, shift_labels)
192
+
193
+ if not return_dict:
194
+ output = (logits,) + outputs[1:]
195
+ return (loss,) + output if loss is not None else output
196
+
197
+ return CausalLMOutputWithPast(
198
+ loss=loss,
199
+ logits=logits,
200
+ past_key_values=outputs.past_key_values,
201
+ hidden_states=outputs.hidden_states,
202
+ attentions=outputs.attentions,
203
+ )
204
+
205
+ def pixel_shuffle(self, x, scale_factor=0.5):
206
+ n, w, h, c = x.size()
207
+ # N, W, H, C --> N, W, H * scale, C // scale
208
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
209
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
210
+ x = x.permute(0, 2, 1, 3).contiguous()
211
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
212
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
213
+ int(c / (scale_factor * scale_factor)))
214
+ x = x.permute(0, 2, 1, 3).contiguous()
215
+ return x
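A shape-level illustration of the 0.5x pixel shuffle above (editorial example; the method does not use `self`, so it can be called unbound):

```python
import torch

x = torch.randn(1, 32, 32, 1152)  # (N, W, H, C): a 32x32 grid of vision tokens
y = Eagle2ChatModel.pixel_shuffle(None, x, scale_factor=0.5)
print(y.shape)  # torch.Size([1, 16, 16, 4608]) -> 4x fewer tokens, 4x wider channels
```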
216
+
217
+ def extract_feature(self, pixel_values):
+ """Encode pixel_values with the vision tower, apply pixel shuffle, and project to the LLM embedding space."""
221
+
222
+ if self.select_layer == -1:
223
+ vit_embeds = self.vision_model(
224
+ pixel_values=pixel_values,
225
+ output_hidden_states=False,
226
+ return_dict=True).last_hidden_state # torch.Size([B, 1025, 1024])
227
+
228
+ else:
229
+ vit_embeds = self.vision_model(
230
+ pixel_values=pixel_values,
231
+ output_hidden_states=True,
232
+ return_dict=True).hidden_states[self.select_layer]
233
+ if type(self.vision_model) == SiglipVisionModel:
234
+ pass
235
+ elif type(self.vision_model) == MultiBackboneChannelConcatenationVisionModel:
236
+ pass
237
+ else:
238
+ vit_embeds = vit_embeds[:, 1:, :] # torch.Size([B, 1024, 1024])
239
+
240
+ if self.training and getattr(self, 'neftune_alpha', None) is not None:
241
+ vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha)
242
+
243
+ if self.moe_version_type in ['feat_concat', 'seq_concat']:
244
+ raise NotImplementedError
245
+ elif self.moe_version_type == 'convnext_512_siglip_448':
246
+ siglip_embeds = vit_embeds['siglip']
247
+ convnext_embeds = vit_embeds['convnext']
248
+ h = w = int(siglip_embeds.shape[1] ** 0.5)
249
+ siglip_embeds = siglip_embeds.reshape(siglip_embeds.shape[0], h, w, -1)
250
+ siglip_embeds = self.pixel_shuffle(siglip_embeds, scale_factor=self.downsample_ratio)
251
+ siglip_embeds = siglip_embeds.reshape(siglip_embeds.shape[0], -1, siglip_embeds.shape[-1])
252
+ vit_embeds = self.mlp1(torch.cat([siglip_embeds, convnext_embeds], dim=-1))
253
+ else:
254
+ h = w = int(vit_embeds.shape[1] ** 0.5)
255
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
256
+
257
+ vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
258
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
259
+ vit_embeds = self.mlp1(vit_embeds)#.to(pixel_values.device)
260
+
261
+ return vit_embeds
262
+
263
+ def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
264
+ history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
265
+ IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
266
+ if history is not None or return_history:
267
+ print('Now multi-turn chat is not supported in batch_chat.')
268
+ raise NotImplementedError
269
+
270
+ if image_counts is not None:
271
+ num_patches_list = image_counts
272
+ print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
273
+
274
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
275
+ self.img_context_token_id = img_context_token_id
276
+
277
+ if verbose and pixel_values is not None:
278
+ image_bs = pixel_values.shape[0]
279
+ print(f'dynamic ViT batch size: {image_bs}')
280
+
281
+ queries = []
282
+ for idx, num_patches in enumerate(num_patches_list):
283
+ question = questions[idx]
284
+ if pixel_values is not None and '<image>' not in question:
285
+ question = '<image>\n' + question
286
+ template = get_conv_template(self.template)
287
+ template.append_message(template.roles[0], question)
288
+ template.append_message(template.roles[1], None)
289
+ query = template.get_prompt()
290
+
291
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
292
+ query = query.replace('<image>', image_tokens, 1)
293
+ queries.append(query)
294
+
295
+ tokenizer.padding_side = 'left'
296
+ model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
297
+ input_ids = model_inputs['input_ids'].cuda()
298
+ attention_mask = model_inputs['attention_mask'].cuda()
299
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
300
+ generation_config['eos_token_id'] = eos_token_id
301
+ generation_output = self.generate(
302
+ pixel_values=pixel_values,
303
+ input_ids=input_ids,
304
+ attention_mask=attention_mask,
305
+ **generation_config
306
+ )
307
+ responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
308
+ responses = [response.split(template.sep)[0].strip() for response in responses]
309
+ return responses
310
+
311
+ def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
312
+ num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
313
+ verbose=False, llm_only=False):
314
+
315
+ if history is None and pixel_values is not None and '<image>' not in question:
316
+ question = '<image>\n' + question
317
+
318
+ if num_patches_list is None:
319
+ num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
320
+ assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
321
+
322
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
323
+ self.img_context_token_id = img_context_token_id
324
+
325
+ template = get_conv_template(self.template)
326
+ template.system_message = self.system_message
327
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
328
+
329
+ history = [] if history is None else history
330
+ for (old_question, old_answer) in history:
331
+ template.append_message(template.roles[0], old_question)
332
+ template.append_message(template.roles[1], old_answer)
333
+ template.append_message(template.roles[0], question)
334
+ template.append_message(template.roles[1], None)
335
+ query = template.get_prompt()
336
+
337
+ if verbose and pixel_values is not None:
338
+ image_bs = pixel_values.shape[0]
339
+ print(f'dynamic ViT batch size: {image_bs}')
340
+
341
+ for num_patches in num_patches_list:
342
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
343
+ if llm_only:
344
+ query = query.replace('<image>', '', 1)
345
+ else:
346
+ query = query.replace('<image>', image_tokens, 1)
347
+
348
+ model_inputs = tokenizer(query, return_tensors='pt')
349
+ input_ids = model_inputs['input_ids'].cuda()
350
+ attention_mask = model_inputs['attention_mask'].cuda()
351
+ generation_config['eos_token_id'] = eos_token_id
352
+ if self.moe_version_type not in (None, 'all_tiling', 'convnext_512_siglip_448'):
353
+ pixel_values = {
354
+ 'pixel_values': pixel_values,
355
+ 'num_patches': num_patches_list # num patch of each image.
356
+ }
357
+ generation_output = self.generate(
358
+ pixel_values=pixel_values,
359
+ input_ids=input_ids,
360
+ attention_mask=attention_mask,
361
+ **generation_config
362
+ )
363
+ response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
364
+ response = response.split(template.sep)[0].strip()
365
+ history.append((question, response))
366
+ if return_history:
367
+ return response, history
368
+ else:
369
+ query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
370
+ query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
371
+ if verbose:
372
+ print(query_to_print, response)
373
+ return response
374
+
375
+ @torch.no_grad()
376
+ def generate(
377
+ self,
378
+ pixel_values: Optional[torch.FloatTensor] = None,
379
+ input_ids: Optional[torch.FloatTensor] = None,
380
+ attention_mask: Optional[torch.LongTensor] = None,
381
+ visual_features: Optional[torch.FloatTensor] = None,
382
+ generation_config: Optional[GenerationConfig] = None,
383
+ output_hidden_states: Optional[bool] = None,
384
+ return_dict: Optional[bool] = None,
385
+ **generate_kwargs,
386
+ ) -> torch.LongTensor:
387
+
388
+ assert self.img_context_token_id is not None
389
+ if pixel_values is not None:
390
+ if visual_features is not None:
391
+ vit_embeds = visual_features
392
+ else:
393
+ vit_embeds = self.extract_feature(pixel_values)
394
+
395
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
396
+ B, N, C = input_embeds.shape
397
+ input_embeds = input_embeds.reshape(B * N, C)
398
+
399
+ input_ids = input_ids.reshape(B * N)
400
+ selected = (input_ids == self.img_context_token_id)
401
+ assert selected.sum() != 0
402
+ input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
403
+
404
+ input_embeds = input_embeds.reshape(B, N, C)
405
+ else:
406
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
407
+
408
+ outputs = self.language_model.generate(
409
+ inputs_embeds=input_embeds,
410
+ attention_mask=attention_mask,
411
+ generation_config=generation_config,
412
+ output_hidden_states=output_hidden_states,
413
+ return_dict=return_dict,
414
+ use_cache=True,
415
+ **generate_kwargs,
416
+ )
417
+
418
+ return outputs
419
+
420
+ def get_input_embeddings(self):
421
+ return self.language_model.get_input_embeddings()
422
+
423
+ def get_output_embeddings(self):
424
+ return self.language_model.get_output_embeddings()
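For reference, a minimal single-image inference sketch against the `chat` interface defined above. The repository id, tile count, and generation settings are illustrative assumptions, not values fixed by this file:

```python
import torch
from transformers import AutoModel, AutoTokenizer

path = 'NVIDIA/Eagle2-1B'  # assumed checkpoint id; any checkpoint using this modeling file works
model = AutoModel.from_pretrained(path, torch_dtype=torch.bfloat16,
                                  trust_remote_code=True).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# One 448x448 tile per row, preprocessed as in the demo script (random here for brevity).
pixel_values = torch.randn(1, 3, 448, 448, dtype=torch.bfloat16).cuda()
question = '<image>\nDescribe this image.'
print(model.chat(tokenizer, pixel_values, question,
                 dict(max_new_tokens=128, do_sample=False)))
```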
modeling_qwen2.py ADDED
@@ -0,0 +1,1744 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch Qwen2 model."""
21
+ import inspect
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
35
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.utils import (
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ is_flash_attn_2_available,
41
+ is_flash_attn_greater_or_equal_2_10,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from .configuration_qwen2 import Qwen2Config
46
+
47
+
48
+ if is_flash_attn_2_available():
49
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
50
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
51
+
52
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+
58
+ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
59
+ _CONFIG_FOR_DOC = "Qwen2Config"
60
+
61
+ QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [
62
+ "Qwen/Qwen2-7B-beta",
63
+ # See all Qwen2 models at https://huggingface.co/models?filter=qwen2
64
+ ]
65
+
66
+
67
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
+ def _get_unpad_data(attention_mask):
69
+ seqlens_in_batch = (attention_mask>0).sum(dim=-1, dtype=torch.int32)
70
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
72
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
73
+ return (
74
+ indices,
75
+ cu_seqlens,
76
+ max_seqlen_in_batch,
77
+ )
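An editorial example of what this helper returns for a small padded batch (two sequences of length 3 and 2 in a batch padded to width 4):

```python
import torch

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
indices, cu_seqlens, max_len = _get_unpad_data(mask)
print(indices.tolist())     # [0, 1, 2, 4, 5] -> positions of real tokens in the flattened batch
print(cu_seqlens.tolist())  # [0, 3, 5]       -> cumulative sequence boundaries for varlen attention
print(max_len)              # 3
```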
78
+
79
+ def _get_unpad_data_packing(attention_mask, sub_sample_lengths):
80
+ seqlens_in_batch = []
81
+ for i, per_sub_sample_lengths in enumerate(sub_sample_lengths):
82
+ if (attention_mask[i]==0).sum() == per_sub_sample_lengths[-1]:
83
+ per_sub_sample_lengths = per_sub_sample_lengths[:-1]
84
+ seqlens_in_batch.extend(per_sub_sample_lengths)
85
+ seqlens_in_batch = torch.tensor(seqlens_in_batch, device=attention_mask.device, dtype=torch.int32)
86
+
87
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
88
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
89
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
90
+ return (
91
+ indices,
92
+ cu_seqlens,
93
+ max_seqlen_in_batch,
94
+ )
95
+
96
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
97
+ class Qwen2RMSNorm(nn.Module):
98
+ def __init__(self, hidden_size, eps=1e-6):
99
+ """
100
+ Qwen2RMSNorm is equivalent to T5LayerNorm
101
+ """
102
+ super().__init__()
103
+ self.weight = nn.Parameter(torch.ones(hidden_size))
104
+ self.variance_epsilon = eps
105
+
106
+ def forward(self, hidden_states):
107
+ input_dtype = hidden_states.dtype
108
+ hidden_states = hidden_states.to(torch.float32)
109
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
110
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
111
+ return self.weight * hidden_states.to(input_dtype)
112
+
113
+
114
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
115
+ class Qwen2RotaryEmbedding(nn.Module):
116
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
117
+ super().__init__()
118
+
119
+ self.dim = dim
120
+ self.max_position_embeddings = max_position_embeddings
121
+ self.base = base
122
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
123
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
124
+
125
+ # Build here to make `torch.jit.trace` work.
126
+ self._set_cos_sin_cache(
127
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
128
+ )
129
+
130
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
131
+ self.max_seq_len_cached = seq_len
132
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
133
+
134
+ freqs = torch.outer(t, self.inv_freq)
135
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
136
+ emb = torch.cat((freqs, freqs), dim=-1)
137
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
138
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
139
+
140
+ def forward(self, x, seq_len=None):
141
+ # x: [bs, num_attention_heads, seq_len, head_size]
142
+ if seq_len > self.max_seq_len_cached:
143
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
144
+
145
+ return (
146
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
147
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
148
+ )
149
+
150
+
151
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
152
+ def rotate_half(x):
153
+ """Rotates half the hidden dims of the input."""
154
+ x1 = x[..., : x.shape[-1] // 2]
155
+ x2 = x[..., x.shape[-1] // 2 :]
156
+ return torch.cat((-x2, x1), dim=-1)
157
+
158
+
159
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
160
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
161
+ """Applies Rotary Position Embedding to the query and key tensors.
162
+
163
+ Args:
164
+ q (`torch.Tensor`): The query tensor.
165
+ k (`torch.Tensor`): The key tensor.
166
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
167
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
168
+ position_ids (`torch.Tensor`):
169
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
170
+ used to pass offsetted position ids when working with a KV-cache.
171
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
172
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
173
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
174
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
175
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
176
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
177
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
178
+ Returns:
179
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
180
+ """
181
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
182
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
183
+ q_embed = (q * cos) + (rotate_half(q) * sin)
184
+ k_embed = (k * cos) + (rotate_half(k) * sin)
185
+ return q_embed, k_embed
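A small shape-level illustration of the rotary-embedding path (editorial example using `Qwen2RotaryEmbedding` and `apply_rotary_pos_emb` as defined in this file):

```python
import torch

rope = Qwen2RotaryEmbedding(dim=8, max_position_embeddings=32)
q = torch.randn(1, 2, 4, 8)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 2, 4, 8)
cos, sin = rope(q, seq_len=4)                # each: (4, 8)
position_ids = torch.arange(4).unsqueeze(0)  # (1, 4)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)              # both torch.Size([1, 2, 4, 8])
```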
186
+
187
+
188
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
189
+ class Qwen2MLP(nn.Module):
190
+ def __init__(self, config):
191
+ super().__init__()
192
+ self.config = config
193
+ self.hidden_size = config.hidden_size
194
+ self.intermediate_size = config.intermediate_size
195
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
196
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
197
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
198
+ self.act_fn = ACT2FN[config.hidden_act]
199
+
200
+ def forward(self, x):
201
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
202
+
203
+
204
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
205
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
206
+ """
207
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
208
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
209
+ """
210
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
211
+ if n_rep == 1:
212
+ return hidden_states
213
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
214
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
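For example (editorial), expanding 4 KV heads to serve 28 query heads for grouped-query attention:

```python
import torch

kv = torch.randn(2, 4, 10, 64)        # (batch, num_key_value_heads, seq_len, head_dim)
print(repeat_kv(kv, n_rep=7).shape)   # torch.Size([2, 28, 10, 64])
```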
215
+
216
+
217
+ class Qwen2Attention(nn.Module):
218
+ """
219
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
220
+ and "Generating Long Sequences with Sparse Transformers".
221
+ """
222
+
223
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
224
+ super().__init__()
225
+ self.config = config
226
+ self.layer_idx = layer_idx
227
+ if layer_idx is None:
228
+ logger.warning_once(
229
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
230
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
231
+ "when creating this class."
232
+ )
233
+
234
+ self.hidden_size = config.hidden_size
235
+ self.num_heads = config.num_attention_heads
236
+ self.head_dim = self.hidden_size // self.num_heads
237
+ self.num_key_value_heads = config.num_key_value_heads
238
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
239
+ self.max_position_embeddings = config.max_position_embeddings
240
+ self.rope_theta = config.rope_theta
241
+ self.is_causal = True
242
+ self.attention_dropout = config.attention_dropout
243
+
244
+ if (self.head_dim * self.num_heads) != self.hidden_size:
245
+ raise ValueError(
246
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
247
+ f" and `num_heads`: {self.num_heads})."
248
+ )
249
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
250
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
251
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
252
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
253
+
254
+ self.rotary_emb = Qwen2RotaryEmbedding(
255
+ self.head_dim,
256
+ max_position_embeddings=self.max_position_embeddings,
257
+ base=self.rope_theta,
258
+ )
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states: torch.Tensor,
263
+ attention_mask: Optional[torch.Tensor] = None,
264
+ position_ids: Optional[torch.LongTensor] = None,
265
+ past_key_value: Optional[Cache] = None,
266
+ output_attentions: bool = False,
267
+ use_cache: bool = False,
268
+ **kwargs,
269
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
270
+ if "padding_mask" in kwargs:
271
+ warnings.warn(
272
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
273
+ )
274
+ bsz, q_len, _ = hidden_states.size()
275
+
276
+ query_states = self.q_proj(hidden_states)
277
+ key_states = self.k_proj(hidden_states)
278
+ value_states = self.v_proj(hidden_states)
279
+
280
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
281
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
282
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
283
+
284
+ kv_seq_len = key_states.shape[-2]
285
+ if past_key_value is not None:
286
+ if self.layer_idx is None:
287
+ raise ValueError(
288
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
289
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
290
+ "with a layer index."
291
+ )
292
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
293
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
294
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
295
+
296
+ if past_key_value is not None:
297
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
298
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
299
+
300
+ # repeat k/v heads if n_kv_heads < n_heads
301
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
302
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
303
+
304
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
305
+
306
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
307
+ raise ValueError(
308
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
309
+ f" {attn_weights.size()}"
310
+ )
311
+
312
+ if attention_mask is not None:
313
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
314
+ raise ValueError(
315
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
316
+ )
317
+
318
+ attn_weights = attn_weights + attention_mask
319
+
320
+ # upcast attention to fp32
321
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
322
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
323
+ attn_output = torch.matmul(attn_weights, value_states)
324
+
325
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
326
+ raise ValueError(
327
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
328
+ f" {attn_output.size()}"
329
+ )
330
+
331
+ attn_output = attn_output.transpose(1, 2).contiguous()
332
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
333
+
334
+ attn_output = self.o_proj(attn_output)
335
+
336
+ if not output_attentions:
337
+ attn_weights = None
338
+
339
+ return attn_output, attn_weights, past_key_value
340
+
341
+
342
+ class Qwen2FlashAttention2(Qwen2Attention):
343
+ """
344
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
345
+ as the weights of the module stays untouched. The only required change would be on the forward pass
346
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
347
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
348
+ config.max_window_layers layers.
349
+ """
350
+
351
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
352
+ def __init__(self, *args, **kwargs):
353
+ super().__init__(*args, **kwargs)
354
+
355
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
356
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
357
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
358
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
359
+
360
+ def forward(
361
+ self,
362
+ hidden_states: torch.Tensor,
363
+ attention_mask: Optional[torch.Tensor] = None,
364
+ position_ids: Optional[torch.LongTensor] = None,
365
+ past_key_value: Optional[Cache] = None,
366
+ output_attentions: bool = False,
367
+ use_cache: bool = False,
368
+ **kwargs,
369
+ ):
370
+ if "padding_mask" in kwargs:
371
+ warnings.warn(
372
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
373
+ )
374
+
375
+ # overwrite attention_mask with padding_mask
376
+ attention_mask = kwargs.pop("padding_mask")
377
+ bsz, q_len, _ = hidden_states.size()
378
+
379
+ query_states = self.q_proj(hidden_states)
380
+ key_states = self.k_proj(hidden_states)
381
+ value_states = self.v_proj(hidden_states)
382
+
383
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
384
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
385
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
386
+
387
+ kv_seq_len = key_states.shape[-2]
388
+ if past_key_value is not None:
389
+ if self.layer_idx is None:
390
+ raise ValueError(
391
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
392
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
393
+ "with a layer index."
394
+ )
395
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
396
+
397
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
398
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
399
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
400
+
401
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
402
+
403
+ use_sliding_windows = (
404
+ _flash_supports_window_size
405
+ and getattr(self.config, "sliding_window", None) is not None
406
+ and kv_seq_len > self.config.sliding_window
407
+ and self.config.use_sliding_window
408
+ )
409
+
410
+ if not _flash_supports_window_size:
411
+ logger.warning_once(
412
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
413
+ " make sure to upgrade flash-attn library."
414
+ )
415
+
416
+ if past_key_value is not None:
417
+ # Activate cache slicing only if the config has a `sliding_window` attribute
418
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
419
+ if (
420
+ getattr(self.config, "sliding_window", None) is not None
421
+ and kv_seq_len > self.config.sliding_window
422
+ and cache_has_contents
423
+ ):
424
+ slicing_tokens = 1 - self.config.sliding_window
425
+
426
+ past_key = past_key_value[self.layer_idx][0]
427
+ past_value = past_key_value[self.layer_idx][1]
428
+
429
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
430
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
431
+
432
+ if past_key.shape[-2] != self.config.sliding_window - 1:
433
+ raise ValueError(
434
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
435
+ f" {past_key.shape}"
436
+ )
437
+
438
+ if attention_mask is not None:
439
+ attention_mask = attention_mask[:, slicing_tokens:]
440
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
441
+
442
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
443
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
444
+
445
+ # repeat k/v heads if n_kv_heads < n_heads
446
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
447
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
448
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
449
+
450
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
451
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
452
+ # cast them back in float16 just to be sure everything works as expected.
453
+ input_dtype = query_states.dtype
454
+ if input_dtype == torch.float32:
455
+ if torch.is_autocast_enabled():
456
+ target_dtype = torch.get_autocast_gpu_dtype()
457
+ # Handle the case where the model is quantized
458
+ elif hasattr(self.config, "_pre_quantization_dtype"):
459
+ target_dtype = self.config._pre_quantization_dtype
460
+ else:
461
+ target_dtype = self.q_proj.weight.dtype
462
+
463
+ logger.warning_once(
464
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
465
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
466
+ f" {target_dtype}."
467
+ )
468
+
469
+ query_states = query_states.to(target_dtype)
470
+ key_states = key_states.to(target_dtype)
471
+ value_states = value_states.to(target_dtype)
472
+
473
+ # Reshape to the expected shape for Flash Attention
474
+ query_states = query_states.transpose(1, 2)
475
+ key_states = key_states.transpose(1, 2)
476
+ value_states = value_states.transpose(1, 2)
477
+
478
+ attn_output = self._flash_attention_forward(
479
+ query_states,
480
+ key_states,
481
+ value_states,
482
+ attention_mask,
483
+ q_len,
484
+ dropout=dropout_rate,
485
+ use_sliding_windows=use_sliding_windows,
486
+ )
487
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
488
+ attn_output = self.o_proj(attn_output)
489
+
490
+ if not output_attentions:
491
+ attn_weights = None
492
+
493
+ return attn_output, attn_weights, past_key_value
494
+
495
+ def _flash_attention_forward(
496
+ self,
497
+ query_states,
498
+ key_states,
499
+ value_states,
500
+ attention_mask,
501
+ query_length,
502
+ dropout=0.0,
503
+ softmax_scale=None,
504
+ use_sliding_windows=False,
505
+ ):
506
+ """
507
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
508
+ it first unpads the input, then computes the attention scores and pads the final attention output.
509
+
510
+ Args:
511
+ query_states (`torch.Tensor`):
512
+ Input query states to be passed to Flash Attention API
513
+ key_states (`torch.Tensor`):
514
+ Input key states to be passed to Flash Attention API
515
+ value_states (`torch.Tensor`):
516
+ Input value states to be passed to Flash Attention API
517
+ attention_mask (`torch.Tensor`):
518
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
519
+ position of padding tokens and 1 for the position of non-padding tokens.
520
+ dropout (`float`, *optional*):
521
+ Attention dropout
522
+ softmax_scale (`float`, *optional*):
523
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
524
+ use_sliding_windows (`bool`, *optional*):
525
+ Whether to activate sliding window attention.
526
+ """
527
+ if not self._flash_attn_uses_top_left_mask:
528
+ causal = self.is_causal
529
+ else:
530
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
531
+ causal = self.is_causal and query_length != 1
532
+
533
+ # Decide whether to use SWA or not by layer index.
534
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
535
+ use_sliding_windows = False
536
+
537
+ # Contains at least one padding token in the sequence
538
+ if attention_mask is not None:
539
+ batch_size = query_states.shape[0]
540
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
541
+ query_states, key_states, value_states, attention_mask, query_length
542
+ )
543
+
544
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
545
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
546
+
547
+ if not use_sliding_windows:
548
+ attn_output_unpad = flash_attn_varlen_func(
549
+ query_states,
550
+ key_states,
551
+ value_states,
552
+ cu_seqlens_q=cu_seqlens_q,
553
+ cu_seqlens_k=cu_seqlens_k,
554
+ max_seqlen_q=max_seqlen_in_batch_q,
555
+ max_seqlen_k=max_seqlen_in_batch_k,
556
+ dropout_p=dropout,
557
+ softmax_scale=softmax_scale,
558
+ causal=causal,
559
+ )
560
+ else:
561
+ attn_output_unpad = flash_attn_varlen_func(
562
+ query_states,
563
+ key_states,
564
+ value_states,
565
+ cu_seqlens_q=cu_seqlens_q,
566
+ cu_seqlens_k=cu_seqlens_k,
567
+ max_seqlen_q=max_seqlen_in_batch_q,
568
+ max_seqlen_k=max_seqlen_in_batch_k,
569
+ dropout_p=dropout,
570
+ softmax_scale=softmax_scale,
571
+ causal=causal,
572
+ window_size=(self.config.sliding_window, self.config.sliding_window),
573
+ )
574
+
575
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
576
+ else:
577
+ if not use_sliding_windows:
578
+ attn_output = flash_attn_func(
579
+ query_states,
580
+ key_states,
581
+ value_states,
582
+ dropout,
583
+ softmax_scale=softmax_scale,
584
+ causal=causal,
585
+ )
586
+ else:
587
+ attn_output = flash_attn_func(
588
+ query_states,
589
+ key_states,
590
+ value_states,
591
+ dropout,
592
+ softmax_scale=softmax_scale,
593
+ causal=causal,
594
+ window_size=(self.config.sliding_window, self.config.sliding_window),
595
+ )
596
+
597
+ return attn_output
598
+
599
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
600
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
601
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
602
+
603
+ # On the first iteration we need to properly re-create the padding mask
604
+ # by slicing it on the proper place
605
+ if kv_seq_len != attention_mask.shape[-1]:
606
+ attention_mask_num_tokens = attention_mask.shape[-1]
607
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
608
+
609
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
610
+
611
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
612
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
613
+
614
+ if query_length == kv_seq_len:
615
+ query_layer = index_first_axis(
616
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
617
+ )
618
+ cu_seqlens_q = cu_seqlens_k
619
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
620
+ indices_q = indices_k
621
+ elif query_length == 1:
622
+ max_seqlen_in_batch_q = 1
623
+ cu_seqlens_q = torch.arange(
624
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
625
+ ) # There is a memcpy here, that is very bad.
626
+ indices_q = cu_seqlens_q[:-1]
627
+ query_layer = query_layer.squeeze(1)
628
+ else:
629
+ # The -q_len: slice assumes left padding.
630
+ attention_mask = attention_mask[:, -query_length:]
631
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
632
+
633
+ return (
634
+ query_layer,
635
+ key_layer,
636
+ value_layer,
637
+ indices_q,
638
+ (cu_seqlens_q, cu_seqlens_k),
639
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
640
+ )
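
For reference, `_upad_input` delegates the mask bookkeeping to a `_get_unpad_data` helper defined earlier in this file. Below is a minimal, illustrative sketch of what such a helper computes from a padding mask; the function name and shapes are assumptions for the example, not the exact upstream code.

```python
import torch
import torch.nn.functional as F

def get_unpad_data(attention_mask: torch.Tensor):
    # attention_mask: (batch, seq_len) with 1 = real token, 0 = padding
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
indices, cu_seqlens, max_len = get_unpad_data(mask)
print(indices)     # tensor([0, 1, 2, 4, 5]) -> flattened positions of real tokens
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32) -> cumulative lengths for flash_attn_varlen_func
print(max_len)     # 3
```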
641
+ class Qwen2FlashAttention2_packing(Qwen2Attention):
642
+ """
643
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
644
+ as the weights of the module stays untouched. The only required change would be on the forward pass
645
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
646
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
647
+ config.max_window_layers layers.
648
+ """
649
+
650
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
651
+ def __init__(self, *args, **kwargs):
652
+ super().__init__(*args, **kwargs)
653
+
654
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
655
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
656
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
657
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
658
+
659
+ def forward(
660
+ self,
661
+ hidden_states: torch.Tensor,
662
+ attention_mask: Optional[torch.Tensor] = None,
663
+ position_ids: Optional[torch.LongTensor] = None,
664
+ past_key_value: Optional[Cache] = None,
665
+ output_attentions: bool = False,
666
+ use_cache: bool = False,
667
+ sub_sample_lengths=None,
668
+ **kwargs,
669
+ ):
670
+ if "padding_mask" in kwargs:
671
+ warnings.warn(
672
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
673
+ )
674
+
675
+ # overwrite attention_mask with padding_mask
676
+ attention_mask = kwargs.pop("padding_mask")
677
+ bsz, q_len, _ = hidden_states.size()
678
+
679
+ query_states = self.q_proj(hidden_states)
680
+ key_states = self.k_proj(hidden_states)
681
+ value_states = self.v_proj(hidden_states)
682
+
683
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
684
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
685
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
686
+
687
+ kv_seq_len = key_states.shape[-2]
688
+ if past_key_value is not None:
689
+ if self.layer_idx is None:
690
+ raise ValueError(
691
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
692
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
693
+ "with a layer index."
694
+ )
695
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
696
+
697
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
698
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
699
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
700
+
701
+ if sub_sample_lengths is not None:
702
+ packing_position_ids = []
703
+ for b in range(bsz):
704
+ each_sub_sample_lengths = sub_sample_lengths[b]
705
+ packing_position_ids.append(torch.cat([torch.arange(each) for each in each_sub_sample_lengths]))
706
+ packing_position_ids = torch.stack(packing_position_ids)
707
+ packing_position_ids = packing_position_ids.to(query_states.device)
708
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, packing_position_ids)
709
+ else:
710
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
711
+
712
+ use_sliding_windows = (
713
+ _flash_supports_window_size
714
+ and getattr(self.config, "sliding_window", None) is not None
715
+ and kv_seq_len > self.config.sliding_window
716
+ and self.config.use_sliding_window
717
+ )
718
+
719
+ if not _flash_supports_window_size:
720
+ logger.warning_once(
721
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
722
+ " make sure to upgrade flash-attn library."
723
+ )
724
+
725
+ if past_key_value is not None:
726
+ # Activate cache slicing only if the config has a `sliding_window` attribute
727
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
728
+ if (
729
+ getattr(self.config, "sliding_window", None) is not None
730
+ and kv_seq_len > self.config.sliding_window
731
+ and cache_has_contents
732
+ ):
733
+ slicing_tokens = 1 - self.config.sliding_window
734
+
735
+ past_key = past_key_value[self.layer_idx][0]
736
+ past_value = past_key_value[self.layer_idx][1]
737
+
738
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
739
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
740
+
741
+ if past_key.shape[-2] != self.config.sliding_window - 1:
742
+ raise ValueError(
743
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
744
+ f" {past_key.shape}"
745
+ )
746
+
747
+ if attention_mask is not None:
748
+ attention_mask = attention_mask[:, slicing_tokens:]
749
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
750
+
751
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
752
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
753
+
754
+ # repeat k/v heads if n_kv_heads < n_heads
755
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
756
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
757
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
758
+
759
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
760
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
761
+ # cast them back in float16 just to be sure everything works as expected.
762
+ input_dtype = query_states.dtype
763
+ if input_dtype == torch.float32:
764
+ if torch.is_autocast_enabled():
765
+ target_dtype = torch.get_autocast_gpu_dtype()
766
+ # Handle the case where the model is quantized
767
+ elif hasattr(self.config, "_pre_quantization_dtype"):
768
+ target_dtype = self.config._pre_quantization_dtype
769
+ else:
770
+ target_dtype = self.q_proj.weight.dtype
771
+
772
+ logger.warning_once(
773
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
774
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
775
+ f" {target_dtype}."
776
+ )
777
+
778
+ query_states = query_states.to(target_dtype)
779
+ key_states = key_states.to(target_dtype)
780
+ value_states = value_states.to(target_dtype)
781
+
782
+ # Reshape to the expected shape for Flash Attention
783
+ query_states = query_states.transpose(1, 2)
784
+ key_states = key_states.transpose(1, 2)
785
+ value_states = value_states.transpose(1, 2)
786
+
787
+ attn_output = self._flash_attention_forward(
788
+ query_states,
789
+ key_states,
790
+ value_states,
791
+ attention_mask,
792
+ q_len,
793
+ dropout=dropout_rate,
794
+ use_sliding_windows=use_sliding_windows,
795
+ sub_sample_lengths=sub_sample_lengths
796
+ )
797
+
798
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
799
+ attn_output = self.o_proj(attn_output)
800
+
801
+ if not output_attentions:
802
+ attn_weights = None
803
+
804
+ return attn_output, attn_weights, past_key_value
805
+
806
+ def _flash_attention_forward(
807
+ self,
808
+ query_states,
809
+ key_states,
810
+ value_states,
811
+ attention_mask,
812
+ query_length,
813
+ dropout=0.0,
814
+ softmax_scale=None,
815
+ use_sliding_windows=False,
816
+ sub_sample_lengths=None,
817
+ ):
818
+ """
819
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
820
+ it first unpads the input, then computes the attention scores and pads the final attention output.
821
+
822
+ Args:
823
+ query_states (`torch.Tensor`):
824
+ Input query states to be passed to Flash Attention API
825
+ key_states (`torch.Tensor`):
826
+ Input key states to be passed to Flash Attention API
827
+ value_states (`torch.Tensor`):
828
+ Input value states to be passed to Flash Attention API
829
+ attention_mask (`torch.Tensor`):
830
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
831
+ position of padding tokens and 1 for the position of non-padding tokens.
832
+ dropout (`float`, *optional*):
833
+ Attention dropout
834
+ softmax_scale (`float`, *optional*):
835
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
836
+ use_sliding_windows (`bool`, *optional*):
837
+ Whether to activate sliding window attention.
+ sub_sample_lengths (*optional*):
+ Lengths of the packed sub-samples for each batch element, used to unpad and regroup attention per sub-sample.
838
+ """
839
+ if not self._flash_attn_uses_top_left_mask:
840
+ causal = self.is_causal
841
+ else:
842
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
843
+ causal = self.is_causal and query_length != 1
844
+
845
+ # Decide whether to use SWA or not by layer index.
846
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
847
+ use_sliding_windows = False
848
+
849
+ # Contains at least one padding token in the sequence
850
+
851
+ if attention_mask is not None:
852
+ batch_size = query_states.shape[0]
853
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input_packing(
854
+ query_states, key_states, value_states, attention_mask, query_length, sub_sample_lengths
855
+ )
856
+
857
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
858
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
859
+
860
+ if not use_sliding_windows:
861
+ attn_output_unpad = flash_attn_varlen_func(
862
+ query_states,
863
+ key_states,
864
+ value_states,
865
+ cu_seqlens_q=cu_seqlens_q,
866
+ cu_seqlens_k=cu_seqlens_k,
867
+ max_seqlen_q=max_seqlen_in_batch_q,
868
+ max_seqlen_k=max_seqlen_in_batch_k,
869
+ dropout_p=dropout,
870
+ softmax_scale=softmax_scale,
871
+ causal=causal,
872
+ )
873
+ else:
874
+ attn_output_unpad = flash_attn_varlen_func(
875
+ query_states,
876
+ key_states,
877
+ value_states,
878
+ cu_seqlens_q=cu_seqlens_q,
879
+ cu_seqlens_k=cu_seqlens_k,
880
+ max_seqlen_q=max_seqlen_in_batch_q,
881
+ max_seqlen_k=max_seqlen_in_batch_k,
882
+ dropout_p=dropout,
883
+ softmax_scale=softmax_scale,
884
+ causal=causal,
885
+ window_size=(self.config.sliding_window, self.config.sliding_window),
886
+ )
887
+
888
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
889
+ else:
890
+ if not use_sliding_windows:
891
+ attn_output = flash_attn_func(
892
+ query_states,
893
+ key_states,
894
+ value_states,
895
+ dropout,
896
+ softmax_scale=softmax_scale,
897
+ causal=causal,
898
+ )
899
+ else:
900
+ attn_output = flash_attn_func(
901
+ query_states,
902
+ key_states,
903
+ value_states,
904
+ dropout,
905
+ softmax_scale=softmax_scale,
906
+ causal=causal,
907
+ window_size=(self.config.sliding_window, self.config.sliding_window),
908
+ )
909
+
910
+ return attn_output
911
+
912
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
913
+ def _unpad_input_packing(self, query_layer, key_layer, value_layer, attention_mask, query_length, sub_sample_lengths):
914
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
915
+
916
+ # On the first iteration we need to properly re-create the padding mask
917
+ # by slicing it on the proper place
918
+ if kv_seq_len != attention_mask.shape[-1]:
919
+ attention_mask_num_tokens = attention_mask.shape[-1]
920
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
921
+
922
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data_packing(attention_mask, sub_sample_lengths)
923
+
924
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
925
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
926
+
927
+ if query_length == kv_seq_len:
928
+ query_layer = index_first_axis(
929
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
930
+ )
931
+ cu_seqlens_q = cu_seqlens_k
932
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
933
+ indices_q = indices_k
934
+ elif query_length == 1:
935
+ max_seqlen_in_batch_q = 1
936
+ cu_seqlens_q = torch.arange(
937
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
938
+ ) # There is a memcpy here, that is very bad.
939
+ indices_q = cu_seqlens_q[:-1]
940
+ query_layer = query_layer.squeeze(1)
941
+ else:
942
+ # The -q_len: slice assumes left padding.
943
+ attention_mask = attention_mask[:, -query_length:]
944
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
945
+
946
+ return (
947
+ query_layer,
948
+ key_layer,
949
+ value_layer,
950
+ indices_q,
951
+ (cu_seqlens_q, cu_seqlens_k),
952
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
953
+ )
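
The packing variant above restarts rotary position ids at every packed sub-sample boundary, so each sub-sample is encoded as if it began at position 0. A small standalone illustration of the position-id construction used in `Qwen2FlashAttention2_packing.forward` (toy lengths chosen for the example):

```python
import torch

# One batch row packing two samples of lengths 3 and 2 into a sequence of length 5.
sub_sample_lengths = [[3, 2]]
packing_position_ids = []
for lengths in sub_sample_lengths:
    packing_position_ids.append(torch.cat([torch.arange(n) for n in lengths]))
packing_position_ids = torch.stack(packing_position_ids)
print(packing_position_ids)  # tensor([[0, 1, 2, 0, 1]])
```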
954
+
955
+
956
+ # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2
957
+ class Qwen2SdpaAttention(Qwen2Attention):
958
+ """
959
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
960
+ `Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
961
+ SDPA API.
962
+ """
963
+
964
+ # Adapted from Qwen2Attention.forward
965
+ def forward(
966
+ self,
967
+ hidden_states: torch.Tensor,
968
+ attention_mask: Optional[torch.Tensor] = None,
969
+ position_ids: Optional[torch.LongTensor] = None,
970
+ past_key_value: Optional[Cache] = None,
971
+ output_attentions: bool = False,
972
+ use_cache: bool = False,
973
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
974
+ if output_attentions:
975
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
976
+ logger.warning_once(
977
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
978
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
979
+ )
980
+ return super().forward(
981
+ hidden_states=hidden_states,
982
+ attention_mask=attention_mask,
983
+ position_ids=position_ids,
984
+ past_key_value=past_key_value,
985
+ output_attentions=output_attentions,
986
+ use_cache=use_cache,
987
+ )
988
+
989
+ bsz, q_len, _ = hidden_states.size()
990
+
991
+ query_states = self.q_proj(hidden_states)
992
+ key_states = self.k_proj(hidden_states)
993
+ value_states = self.v_proj(hidden_states)
994
+
995
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
996
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
997
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
998
+
999
+ kv_seq_len = key_states.shape[-2]
1000
+ if past_key_value is not None:
1001
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1002
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1003
+
1004
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
1005
+
1006
+ if past_key_value is not None:
1007
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
1008
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1009
+
1010
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
1011
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
1012
+
1013
+ if attention_mask is not None:
1014
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
1015
+ raise ValueError(
1016
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
1017
+ )
1018
+
1019
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
1020
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
1021
+ if query_states.device.type == "cuda" and attention_mask is not None:
1022
+ query_states = query_states.contiguous()
1023
+ key_states = key_states.contiguous()
1024
+ value_states = value_states.contiguous()
1025
+
1026
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
1027
+ query_states,
1028
+ key_states,
1029
+ value_states,
1030
+ attn_mask=attention_mask,
1031
+ dropout_p=self.attention_dropout if self.training else 0.0,
1032
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
1033
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
1034
+ )
1035
+
1036
+ attn_output = attn_output.transpose(1, 2).contiguous()
1037
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
1038
+
1039
+ attn_output = self.o_proj(attn_output)
1040
+
1041
+ return attn_output, None, past_key_value
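
The SDPA path leaves mask construction to `_prepare_4d_causal_attention_mask_for_sdpa` (see `Qwen2Model.forward`) and only sets `is_causal=True` when no explicit mask is passed and `q_len > 1`. A minimal standalone call with the same convention (requires torch >= 2.0; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 5, 64)  # (batch, num_heads, q_len, head_dim)
k = torch.randn(1, 8, 5, 64)
v = torch.randn(1, 8, 5, 64)

# No explicit mask and q_len > 1, so rely on the built-in causal masking.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 5, 64])
```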
1042
+
1043
+
1044
+ QWEN2_ATTENTION_CLASSES = {
1045
+ "eager": Qwen2Attention,
1046
+ "flash_attention_2": Qwen2FlashAttention2,
1047
+ "sdpa": Qwen2SdpaAttention,
1048
+ "flash_attention_2_packing": Qwen2FlashAttention2_packing,
1049
+ }
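
The mapping above is the single dispatch point for the attention backend: each decoder layer instantiates the class registered under `config.attn_implementation`, with `"flash_attention_2_packing"` being the packed-sequence variant added in this file. An illustrative lookup:

```python
# Illustrative only: mirrors the selection done in Qwen2DecoderLayer.__init__,
# i.e. QWEN2_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx).
attn_cls = QWEN2_ATTENTION_CLASSES["flash_attention_2_packing"]
print(attn_cls.__name__)  # Qwen2FlashAttention2_packing
```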
1050
+
1051
+
1052
+ class Qwen2DecoderLayer(nn.Module):
1053
+ def __init__(self, config: Qwen2Config, layer_idx: int):
1054
+ super().__init__()
1055
+ self.hidden_size = config.hidden_size
1056
+
1057
+ if config.use_sliding_window and config.attn_implementation != "flash_attention_2":
1058
+ logger.warning_once(
1059
+ f"Sliding Window Attention is enabled but not implemented for `{config.attn_implementation}`; "
1060
+ "unexpected results may be encountered."
1061
+ )
1062
+
1063
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx)
1064
+
1065
+ self.mlp = Qwen2MLP(config)
1066
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1067
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1068
+
1069
+ def forward(
1070
+ self,
1071
+ hidden_states: torch.Tensor,
1072
+ attention_mask: Optional[torch.Tensor] = None,
1073
+ position_ids: Optional[torch.LongTensor] = None,
1074
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1075
+ sub_sample_lengths=None,
1076
+ output_attentions: Optional[bool] = False,
1077
+ use_cache: Optional[bool] = False,
1078
+ **kwargs,
1079
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1080
+ if "padding_mask" in kwargs:
1081
+ warnings.warn(
1082
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
1083
+ "Please make sure use `attention_mask` instead.`"
1084
+ )
1085
+ """
1086
+ Args:
1087
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1088
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1089
+ `(batch, sequence_length)` where padding elements are indicated by 0.
1090
+ output_attentions (`bool`, *optional*):
1091
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1092
+ returned tensors for more detail.
1093
+ use_cache (`bool`, *optional*):
1094
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1095
+ (see `past_key_values`).
1096
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1097
+ """
1098
+
1099
+ residual = hidden_states
1100
+
1101
+ hidden_states = self.input_layernorm(hidden_states)
1102
+
1103
+ # Self Attention
1104
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1105
+ hidden_states=hidden_states,
1106
+ attention_mask=attention_mask,
1107
+ position_ids=position_ids,
1108
+ past_key_value=past_key_value,
1109
+ output_attentions=output_attentions,
1110
+ use_cache=use_cache,
1111
+ sub_sample_lengths=sub_sample_lengths,
1112
+ )
1113
+ hidden_states = residual + hidden_states
1114
+
1115
+ # Fully Connected
1116
+ residual = hidden_states
1117
+ hidden_states = self.post_attention_layernorm(hidden_states)
1118
+ hidden_states = self.mlp(hidden_states)
1119
+ hidden_states = residual + hidden_states
1120
+
1121
+ outputs = (hidden_states,)
1122
+
1123
+ if output_attentions:
1124
+ outputs += (self_attn_weights,)
1125
+
1126
+ if use_cache:
1127
+ outputs += (present_key_value,)
1128
+
1129
+ return outputs
1130
+
1131
+
1132
+ QWEN2_START_DOCSTRING = r"""
1133
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1134
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1135
+ etc.)
1136
+
1137
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1138
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1139
+ and behavior.
1140
+
1141
+ Parameters:
1142
+ config ([`Qwen2Config`]):
1143
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1144
+ load the weights associated with the model, only the configuration. Check out the
1145
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1146
+ """
1147
+
1148
+
1149
+ @add_start_docstrings(
1150
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
1151
+ QWEN2_START_DOCSTRING,
1152
+ )
1153
+ class Qwen2PreTrainedModel(PreTrainedModel):
1154
+ config_class = Qwen2Config
1155
+ base_model_prefix = "model"
1156
+ supports_gradient_checkpointing = True
1157
+ _no_split_modules = ["Qwen2DecoderLayer"]
1158
+ _skip_keys_device_placement = "past_key_values"
1159
+ _supports_flash_attn_2 = True
1160
+ _supports_sdpa = True
1161
+ _supports_cache_class = True
1162
+
1163
+ def _init_weights(self, module):
1164
+ std = self.config.initializer_range
1165
+ if isinstance(module, nn.Linear):
1166
+ module.weight.data.normal_(mean=0.0, std=std)
1167
+ if module.bias is not None:
1168
+ module.bias.data.zero_()
1169
+ elif isinstance(module, nn.Embedding):
1170
+ module.weight.data.normal_(mean=0.0, std=std)
1171
+ if module.padding_idx is not None:
1172
+ module.weight.data[module.padding_idx].zero_()
1173
+
1174
+
1175
+ QWEN2_INPUTS_DOCSTRING = r"""
1176
+ Args:
1177
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1178
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1179
+ it.
1180
+
1181
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1182
+ [`PreTrainedTokenizer.__call__`] for details.
1183
+
1184
+ [What are input IDs?](../glossary#input-ids)
1185
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1186
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1187
+
1188
+ - 1 for tokens that are **not masked**,
1189
+ - 0 for tokens that are **masked**.
1190
+
1191
+ [What are attention masks?](../glossary#attention-mask)
1192
+
1193
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1194
+ [`PreTrainedTokenizer.__call__`] for details.
1195
+
1196
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1197
+ `past_key_values`).
1198
+
1199
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1200
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1201
+ information on the default strategy.
1202
+
1203
+ - 1 indicates the head is **not masked**,
1204
+ - 0 indicates the head is **masked**.
1205
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1206
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1207
+ config.n_positions - 1]`.
1208
+
1209
+ [What are position IDs?](../glossary#position-ids)
1210
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1211
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1212
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1213
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1214
+
1215
+ Two formats are allowed:
1216
+ - a [`~cache_utils.Cache`] instance;
1217
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1218
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1219
+ cache format.
1220
+
1221
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1222
+ legacy cache format will be returned.
1223
+
1224
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1225
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1226
+ of shape `(batch_size, sequence_length)`.
1227
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1228
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1229
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1230
+ model's internal embedding lookup matrix.
1231
+ use_cache (`bool`, *optional*):
1232
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1233
+ `past_key_values`).
1234
+ output_attentions (`bool`, *optional*):
1235
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1236
+ tensors for more detail.
1237
+ output_hidden_states (`bool`, *optional*):
1238
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1239
+ more detail.
1240
+ return_dict (`bool`, *optional*):
1241
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1242
+ """
1243
+
1244
+
1245
+ @add_start_docstrings(
1246
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
1247
+ QWEN2_START_DOCSTRING,
1248
+ )
1249
+ class Qwen2Model(Qwen2PreTrainedModel):
1250
+ """
1251
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
1252
+
1253
+ Args:
1254
+ config: Qwen2Config
1255
+ """
1256
+
1257
+ def __init__(self, config: Qwen2Config):
1258
+ super().__init__(config)
1259
+ self.padding_idx = config.pad_token_id
1260
+ self.vocab_size = config.vocab_size
1261
+
1262
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1263
+ self.layers = nn.ModuleList(
1264
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1265
+ )
1266
+ self.attn_implementation = config.attn_implementation
1267
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1268
+
1269
+ self.gradient_checkpointing = False
1270
+ # Initialize weights and apply final processing
1271
+ self.post_init()
1272
+
1273
+ def get_input_embeddings(self):
1274
+ return self.embed_tokens
1275
+
1276
+ def set_input_embeddings(self, value):
1277
+ self.embed_tokens = value
1278
+
1279
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1280
+ def forward(
1281
+ self,
1282
+ input_ids: torch.LongTensor = None,
1283
+ attention_mask: Optional[torch.Tensor] = None,
1284
+ position_ids: Optional[torch.LongTensor] = None,
1285
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1286
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1287
+ use_cache: Optional[bool] = None,
1288
+ output_attentions: Optional[bool] = None,
1289
+ output_hidden_states: Optional[bool] = None,
1290
+ return_dict: Optional[bool] = None,
1291
+ sub_sample_lengths=None,
1292
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1293
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1294
+ output_hidden_states = (
1295
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1296
+ )
1297
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1298
+
1299
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1300
+
1301
+ # retrieve input_ids and inputs_embeds
1302
+ if input_ids is not None and inputs_embeds is not None:
1303
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1304
+ elif input_ids is not None:
1305
+ batch_size, seq_length = input_ids.shape
1306
+ elif inputs_embeds is not None:
1307
+ batch_size, seq_length, _ = inputs_embeds.shape
1308
+ else:
1309
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1310
+
1311
+ if self.gradient_checkpointing and self.training:
1312
+ if use_cache:
1313
+ logger.warning_once(
1314
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1315
+ )
1316
+ use_cache = False
1317
+
1318
+ past_key_values_length = 0
1319
+
1320
+ if use_cache:
1321
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1322
+ if use_legacy_cache:
1323
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1324
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1325
+
1326
+ if position_ids is None:
1327
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1328
+ position_ids = torch.arange(
1329
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1330
+ )
1331
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1332
+ else:
1333
+ position_ids = position_ids.view(-1, seq_length).long()
1334
+
1335
+ if inputs_embeds is None:
1336
+ inputs_embeds = self.embed_tokens(input_ids)
1337
+
1338
+ if attention_mask is not None and self.attn_implementation == "flash_attention_2" and use_cache:
1339
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1340
+ if is_padding_right:
1341
+ raise ValueError(
1342
+ "You are attempting to perform batched generation with padding_side='right'"
1343
+ " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
1344
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1345
+ )
1346
+
1347
+ if self.attn_implementation == "flash_attention_2" or self.config.attn_implementation == "flash_attention_2_packing":
1348
+ # 2d mask is passed through the layers
1349
+ if attention_mask is not None:
1350
+ if attention_mask.dtype == torch.long:
1351
+ pass
1352
+ # attention_mask = attention_mask
1353
+ else:
1354
+ attention_mask = attention_mask if (0 in attention_mask) else None
1355
+
1356
+ elif self.attn_implementation == "sdpa" and not output_attentions:
1357
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1358
+ # the manual implementation that requires a 4D causal mask in all cases.
1359
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1360
+ attention_mask,
1361
+ (batch_size, seq_length),
1362
+ inputs_embeds,
1363
+ past_key_values_length,
1364
+ )
1365
+ else:
1366
+ # 4d mask is passed through the layers
1367
+ attention_mask = _prepare_4d_causal_attention_mask(
1368
+ attention_mask,
1369
+ (batch_size, seq_length),
1370
+ inputs_embeds,
1371
+ past_key_values_length,
1372
+ sliding_window=self.config.sliding_window,
1373
+ )
1374
+
1375
+ hidden_states = inputs_embeds
1376
+
1377
+ # decoder layers
1378
+ all_hidden_states = () if output_hidden_states else None
1379
+ all_self_attns = () if output_attentions else None
1380
+ next_decoder_cache = None
1381
+
1382
+ for decoder_layer in self.layers:
1383
+ if output_hidden_states:
1384
+ all_hidden_states += (hidden_states,)
1385
+ if self.gradient_checkpointing and self.training:
1386
+ layer_outputs = self._gradient_checkpointing_func(
1387
+ decoder_layer.__call__,
1388
+ hidden_states,
1389
+ attention_mask,
1390
+ position_ids,
1391
+ past_key_values,
1392
+ sub_sample_lengths,
1393
+ output_attentions,
1394
+ use_cache,
1395
+ )
1396
+ else:
1397
+ layer_outputs = decoder_layer(
1398
+ hidden_states,
1399
+ attention_mask=attention_mask,
1400
+ position_ids=position_ids,
1401
+ past_key_value=past_key_values,
1402
+ sub_sample_lengths=sub_sample_lengths,
1403
+ output_attentions=output_attentions,
1404
+ use_cache=use_cache,
1405
+ )
1406
+
1407
+ hidden_states = layer_outputs[0]
1408
+
1409
+ if use_cache:
1410
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1411
+
1412
+ if output_attentions:
1413
+ all_self_attns += (layer_outputs[1],)
1414
+
1415
+ hidden_states = self.norm(hidden_states)
1416
+
1417
+ # add hidden states from the last decoder layer
1418
+ if output_hidden_states:
1419
+ all_hidden_states += (hidden_states,)
1420
+
1421
+ next_cache = None
1422
+ if use_cache:
1423
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1424
+
1425
+ if not return_dict:
1426
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1427
+ return BaseModelOutputWithPast(
1428
+ last_hidden_state=hidden_states,
1429
+ past_key_values=next_cache,
1430
+ hidden_states=all_hidden_states,
1431
+ attentions=all_self_attns,
1432
+ )
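
`Qwen2Model.forward` accepts either a `Cache` instance or the legacy tuple-of-tuples format and converts between them with `DynamicCache.from_legacy_cache` / `to_legacy_cache`. A hedged sketch of that round trip, assuming a transformers version in the 4.36+ range that this file targets (shapes are made up):

```python
import torch
from transformers.cache_utils import DynamicCache

legacy = tuple(
    (torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8))  # (key, value) per layer: (batch, num_kv_heads, seq, head_dim)
    for _ in range(2)
)
cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_usable_length(4))    # 4 tokens already cached
print(len(cache.to_legacy_cache()))  # 2 layers, back in the legacy format
```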
1433
+
1434
+
1435
+ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1436
+ _tied_weights_keys = ["lm_head.weight"]
1437
+
1438
+ def __init__(self, config):
1439
+ super().__init__(config)
1440
+ self.model = Qwen2Model(config)
1441
+ self.vocab_size = config.vocab_size
1442
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1443
+
1444
+ # Initialize weights and apply final processing
1445
+ self.post_init()
1446
+ self.support_packing = True
1447
+
1448
+ def get_input_embeddings(self):
1449
+ return self.model.embed_tokens
1450
+
1451
+ def set_input_embeddings(self, value):
1452
+ self.model.embed_tokens = value
1453
+
1454
+ def get_output_embeddings(self):
1455
+ return self.lm_head
1456
+
1457
+ def set_output_embeddings(self, new_embeddings):
1458
+ self.lm_head = new_embeddings
1459
+
1460
+ def set_decoder(self, decoder):
1461
+ self.model = decoder
1462
+
1463
+ def get_decoder(self):
1464
+ return self.model
1465
+
1466
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1467
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1468
+ def forward(
1469
+ self,
1470
+ input_ids: torch.LongTensor = None,
1471
+ attention_mask: Optional[torch.Tensor] = None,
1472
+ position_ids: Optional[torch.LongTensor] = None,
1473
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1474
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1475
+ labels: Optional[torch.LongTensor] = None,
1476
+ use_cache: Optional[bool] = None,
1477
+ output_attentions: Optional[bool] = None,
1478
+ output_hidden_states: Optional[bool] = None,
1479
+ return_dict: Optional[bool] = None,
1480
+ sub_sample_lengths=None,
1481
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1482
+ r"""
1483
+ Args:
1484
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1485
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1486
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1487
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1488
+
1489
+ Returns:
1490
+
1491
+ Example:
1492
+
1493
+ ```python
1494
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
1495
+
1496
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1497
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1498
+
1499
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1500
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1501
+
1502
+ >>> # Generate
1503
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1504
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1505
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1506
+ ```"""
1507
+
1508
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1509
+ output_hidden_states = (
1510
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1511
+ )
1512
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1513
+
1514
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1515
+ outputs = self.model(
1516
+ input_ids=input_ids,
1517
+ attention_mask=attention_mask,
1518
+ position_ids=position_ids,
1519
+ past_key_values=past_key_values,
1520
+ inputs_embeds=inputs_embeds,
1521
+ use_cache=use_cache,
1522
+ output_attentions=output_attentions,
1523
+ output_hidden_states=output_hidden_states,
1524
+ return_dict=return_dict,
1525
+ sub_sample_lengths=sub_sample_lengths
1526
+ )
1527
+
1528
+ hidden_states = outputs[0]
1529
+ logits = self.lm_head(hidden_states)
1530
+ logits = logits.float()
1531
+
1532
+ loss = None
1533
+ if labels is not None:
1534
+ # Shift so that tokens < n predict n
1535
+ shift_logits = logits[..., :-1, :].contiguous()
1536
+ shift_labels = labels[..., 1:].contiguous()
1537
+ # Flatten the tokens
1538
+ loss_fct = CrossEntropyLoss()
1539
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1540
+ shift_labels = shift_labels.view(-1)
1541
+ # Enable model parallelism
1542
+ shift_labels = shift_labels.to(shift_logits.device)
1543
+ loss = loss_fct(shift_logits, shift_labels)
1544
+
1545
+ if not return_dict:
1546
+ output = (logits,) + outputs[1:]
1547
+ return (loss,) + output if loss is not None else output
1548
+
1549
+ return CausalLMOutputWithPast(
1550
+ loss=loss,
1551
+ logits=logits,
1552
+ past_key_values=outputs.past_key_values,
1553
+ hidden_states=outputs.hidden_states,
1554
+ attentions=outputs.attentions,
1555
+ )
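
The loss computation above uses the standard causal shift: the logits at position t are scored against the label at position t+1, and labels set to -100 are ignored. A compact numerical sketch of that alignment:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 4, vocab_size)    # (batch, seq_len, vocab)
labels = torch.tensor([[3, 5, -100, 7]])  # -100 marks positions excluded from the loss

shift_logits = logits[..., :-1, :].contiguous()  # predictions for tokens 0..2
shift_labels = labels[..., 1:].contiguous()      # targets are tokens 1..3
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss)  # scalar; the -100 entry is ignored by CrossEntropyLoss by default
```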
1556
+
1557
+ def prepare_inputs_for_generation(
1558
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1559
+ ):
1560
+ # Omit tokens covered by past_key_values
1561
+ if past_key_values is not None:
1562
+ if isinstance(past_key_values, Cache):
1563
+ cache_length = past_key_values.get_seq_length()
1564
+ past_length = past_key_values.seen_tokens
1565
+ max_cache_length = past_key_values.get_max_length()
1566
+ else:
1567
+ cache_length = past_length = past_key_values[0][0].shape[2]
1568
+ max_cache_length = None
1569
+
1570
+ # Keep only the unprocessed tokens:
1571
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1572
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1573
+ # input)
1574
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1575
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1576
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1577
+ # input_ids based on the past_length.
1578
+ elif past_length < input_ids.shape[1]:
1579
+ input_ids = input_ids[:, past_length:]
1580
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1581
+
1582
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1583
+ if (
1584
+ max_cache_length is not None
1585
+ and attention_mask is not None
1586
+ and cache_length + input_ids.shape[1] > max_cache_length
1587
+ ):
1588
+ attention_mask = attention_mask[:, -max_cache_length:]
1589
+
1590
+ position_ids = kwargs.get("position_ids", None)
1591
+ if attention_mask is not None and position_ids is None:
1592
+ # create position_ids on the fly for batch generation
1593
+ position_ids = attention_mask.long().cumsum(-1) - 1
1594
+ position_ids.masked_fill_(attention_mask == 0, 1)
1595
+ if past_key_values:
1596
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1597
+
1598
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1599
+ if inputs_embeds is not None and past_key_values is None:
1600
+ model_inputs = {"inputs_embeds": inputs_embeds}
1601
+ else:
1602
+ model_inputs = {"input_ids": input_ids}
1603
+
1604
+ model_inputs.update(
1605
+ {
1606
+ "position_ids": position_ids,
1607
+ "past_key_values": past_key_values,
1608
+ "use_cache": kwargs.get("use_cache"),
1609
+ "attention_mask": attention_mask,
1610
+ }
1611
+ )
1612
+ return model_inputs
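
When `position_ids` are not supplied, `prepare_inputs_for_generation` derives them from the attention mask with a cumulative sum, which keeps positions contiguous for left-padded batches. A small standalone illustration of that recipe:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence of length 3
                               [1, 1, 1, 1, 1]])  # full-length sequence
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```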
1613
+
1614
+ @staticmethod
1615
+ def _reorder_cache(past_key_values, beam_idx):
1616
+ reordered_past = ()
1617
+ for layer_past in past_key_values:
1618
+ reordered_past += (
1619
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1620
+ )
1621
+ return reordered_past
1622
+
1623
+
1624
+ @add_start_docstrings(
1625
+ """
1626
+ The Qwen2 Model transformer with a sequence classification head on top (linear layer).
1627
+
1628
+ [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1629
+ (e.g. GPT-2) do.
1630
+
1631
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1632
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1633
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1634
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1635
+ each row of the batch).
1636
+ """,
1637
+ QWEN2_START_DOCSTRING,
1638
+ )
1639
+ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1640
+ def __init__(self, config):
1641
+ super().__init__(config)
1642
+ self.num_labels = config.num_labels
1643
+ self.model = Qwen2Model(config)
1644
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1645
+
1646
+ # Initialize weights and apply final processing
1647
+ self.post_init()
1648
+
1649
+ def get_input_embeddings(self):
1650
+ return self.model.embed_tokens
1651
+
1652
+ def set_input_embeddings(self, value):
1653
+ self.model.embed_tokens = value
1654
+
1655
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1656
+ def forward(
1657
+ self,
1658
+ input_ids: torch.LongTensor = None,
1659
+ attention_mask: Optional[torch.Tensor] = None,
1660
+ position_ids: Optional[torch.LongTensor] = None,
1661
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1662
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1663
+ labels: Optional[torch.LongTensor] = None,
1664
+ use_cache: Optional[bool] = None,
1665
+ output_attentions: Optional[bool] = None,
1666
+ output_hidden_states: Optional[bool] = None,
1667
+ return_dict: Optional[bool] = None,
1668
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1669
+ r"""
1670
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1671
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1672
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1673
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1674
+ """
1675
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1676
+
1677
+ transformer_outputs = self.model(
1678
+ input_ids,
1679
+ attention_mask=attention_mask,
1680
+ position_ids=position_ids,
1681
+ past_key_values=past_key_values,
1682
+ inputs_embeds=inputs_embeds,
1683
+ use_cache=use_cache,
1684
+ output_attentions=output_attentions,
1685
+ output_hidden_states=output_hidden_states,
1686
+ return_dict=return_dict,
1687
+ )
1688
+ hidden_states = transformer_outputs[0]
1689
+ logits = self.score(hidden_states)
1690
+
1691
+ if input_ids is not None:
1692
+ batch_size = input_ids.shape[0]
1693
+ else:
1694
+ batch_size = inputs_embeds.shape[0]
1695
+
1696
+ if self.config.pad_token_id is None and batch_size != 1:
1697
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1698
+ if self.config.pad_token_id is None:
1699
+ sequence_lengths = -1
1700
+ else:
1701
+ if input_ids is not None:
1702
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1703
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1704
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1705
+ sequence_lengths = sequence_lengths.to(logits.device)
1706
+ else:
1707
+ sequence_lengths = -1
1708
+
1709
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1710
+
1711
+ loss = None
1712
+ if labels is not None:
1713
+ labels = labels.to(logits.device)
1714
+ if self.config.problem_type is None:
1715
+ if self.num_labels == 1:
1716
+ self.config.problem_type = "regression"
1717
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1718
+ self.config.problem_type = "single_label_classification"
1719
+ else:
1720
+ self.config.problem_type = "multi_label_classification"
1721
+
1722
+ if self.config.problem_type == "regression":
1723
+ loss_fct = MSELoss()
1724
+ if self.num_labels == 1:
1725
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1726
+ else:
1727
+ loss = loss_fct(pooled_logits, labels)
1728
+ elif self.config.problem_type == "single_label_classification":
1729
+ loss_fct = CrossEntropyLoss()
1730
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1731
+ elif self.config.problem_type == "multi_label_classification":
1732
+ loss_fct = BCEWithLogitsLoss()
1733
+ loss = loss_fct(pooled_logits, labels)
1734
+ if not return_dict:
1735
+ output = (pooled_logits,) + transformer_outputs[1:]
1736
+ return ((loss,) + output) if loss is not None else output
1737
+
1738
+ return SequenceClassifierOutputWithPast(
1739
+ loss=loss,
1740
+ logits=pooled_logits,
1741
+ past_key_values=transformer_outputs.past_key_values,
1742
+ hidden_states=transformer_outputs.hidden_states,
1743
+ attentions=transformer_outputs.attentions,
1744
+ )
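
The classification head above pools the logit at the last non-padding position of each sequence, falling back to the final position when no pad token is present. Below is a minimal, self-contained sketch of that pooling step; it is illustrative only and not part of the committed file, and the `pad_token_id` value is a placeholder:

```python
import torch

pad_token_id = 0  # placeholder; the real value comes from the model config
input_ids = torch.tensor([[5, 7, 9, 0, 0],    # right-padded sequence
                          [3, 4, 6, 8, 2]])   # sequence with no padding
logits = torch.randn(2, input_ids.shape[1], 3)  # (batch, seq_len, num_labels)

# argmax finds the first pad token; subtracting 1 gives the last real token.
# The modulo keeps the expression ONNX-friendly: with no pad token, argmax is 0,
# so 0 - 1 wraps around to the final position.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]

pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])
```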
modeling_siglip.py ADDED
@@ -0,0 +1,1229 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Siglip model."""
16
+
17
+
18
+ import math
19
+ import warnings
20
+ from dataclasses import dataclass
21
+ from typing import Any, Optional, Tuple, Union
22
+ from einops import rearrange
23
+ import numpy as np
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn.init import _calculate_fan_in_and_fan_out
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
31
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.utils import (
34
+ ModelOutput,
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
41
+
42
+ try:
43
+ from .flash_attention import FlashAttention
44
+ has_flash_attn = True
45
+ except ImportError:
46
+ print('FlashAttention is not installed.')
47
+ has_flash_attn = False
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
52
+
53
+ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
54
+ "google/siglip-base-patch16-224",
55
+ # See all SigLIP models at https://huggingface.co/models?filter=siglip
56
+ ]
57
+
58
+
59
+ def _trunc_normal_(tensor, mean, std, a, b):
60
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
61
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
62
+ def norm_cdf(x):
63
+ # Computes standard normal cumulative distribution function
64
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
65
+
66
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
67
+ warnings.warn(
68
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
69
+ "The distribution of values may be incorrect.",
70
+ stacklevel=2,
71
+ )
72
+
73
+ # Values are generated by using a truncated uniform distribution and
74
+ # then using the inverse CDF for the normal distribution.
75
+ # Get upper and lower cdf values
76
+ l = norm_cdf((a - mean) / std)
77
+ u = norm_cdf((b - mean) / std)
78
+
79
+ # Uniformly fill tensor with values from [l, u], then translate to
80
+ # [2l-1, 2u-1].
81
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
82
+
83
+ # Use inverse cdf transform for normal distribution to get truncated
84
+ # standard normal
85
+ tensor.erfinv_()
86
+
87
+ # Transform to proper mean, std
88
+ tensor.mul_(std * math.sqrt(2.0))
89
+ tensor.add_(mean)
90
+
91
+ # Clamp to ensure it's in the proper range
92
+ tensor.clamp_(min=a, max=b)
93
+
94
+
95
+ def trunc_normal_tf_(
96
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
97
+ ) -> torch.Tensor:
98
+ """Fills the input Tensor with values drawn from a truncated
99
+ normal distribution. The values are effectively drawn from the
100
+ normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
101
+ with values outside :math:`[a, b]` redrawn until they are within
102
+ the bounds. The method used for generating the random values works
103
+ best when :math:`a \\leq \\text{mean} \\leq b`.
104
+
105
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
106
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
107
+ and the result is subsequently scaled and shifted by the mean and std args.
108
+
109
+ Args:
110
+ tensor: an n-dimensional `torch.Tensor`
111
+ mean: the mean of the normal distribution
112
+ std: the standard deviation of the normal distribution
113
+ a: the minimum cutoff value
114
+ b: the maximum cutoff value
115
+ """
116
+ with torch.no_grad():
117
+ _trunc_normal_(tensor, 0, 1.0, a, b)
118
+ tensor.mul_(std).add_(mean)
119
+
120
+
121
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
122
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
123
+ if mode == "fan_in":
124
+ denom = fan_in
125
+ elif mode == "fan_out":
126
+ denom = fan_out
127
+ elif mode == "fan_avg":
128
+ denom = (fan_in + fan_out) / 2
129
+
130
+ variance = scale / denom
131
+
132
+ if distribution == "truncated_normal":
133
+ # constant is stddev of standard normal truncated to (-2, 2)
134
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
135
+ elif distribution == "normal":
136
+ with torch.no_grad():
137
+ tensor.normal_(std=math.sqrt(variance))
138
+ elif distribution == "uniform":
139
+ bound = math.sqrt(3 * variance)
140
+ with torch.no_grad():
141
+ tensor.uniform_(-bound, bound)
142
+ else:
143
+ raise ValueError(f"invalid distribution {distribution}")
144
+
145
+
146
+ def lecun_normal_(tensor):
147
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
148
+
149
+
150
+ def default_flax_embed_init(tensor):
151
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
152
+
153
+
154
+ @dataclass
155
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
156
+ class SiglipVisionModelOutput(ModelOutput):
157
+ """
158
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
159
+
160
+ Args:
161
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
162
+ The image embeddings obtained by applying the projection layer to the pooler_output.
163
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
164
+ Sequence of hidden-states at the output of the last layer of the model.
165
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
166
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
167
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
168
+
169
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
170
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
171
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
172
+ sequence_length)`.
173
+
174
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
175
+ heads.
176
+ """
177
+
178
+ image_embeds: Optional[torch.FloatTensor] = None
179
+ last_hidden_state: torch.FloatTensor = None
180
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
181
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
182
+
183
+
184
+ @dataclass
185
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
186
+ class SiglipTextModelOutput(ModelOutput):
187
+ """
188
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
189
+
190
+ Args:
191
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
192
+ The text embeddings obtained by applying the projection layer to the pooler_output.
193
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
194
+ Sequence of hidden-states at the output of the last layer of the model.
195
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
196
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
197
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
198
+
199
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
200
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
201
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
202
+ sequence_length)`.
203
+
204
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
205
+ heads.
206
+ """
207
+
208
+ text_embeds: Optional[torch.FloatTensor] = None
209
+ last_hidden_state: torch.FloatTensor = None
210
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
211
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
212
+
213
+
214
+ @dataclass
215
+ # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
216
+ class SiglipOutput(ModelOutput):
217
+ """
218
+ Args:
219
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
220
+ Contrastive loss for image-text similarity.
221
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
222
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
223
+ similarity scores.
224
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
225
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
226
+ similarity scores.
227
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
228
+ The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
229
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
230
+ The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
231
+ text_model_output(`BaseModelOutputWithPooling`):
232
+ The output of the [`SiglipTextModel`].
233
+ vision_model_output(`BaseModelOutputWithPooling`):
234
+ The output of the [`SiglipVisionModel`].
235
+ """
236
+
237
+ loss: Optional[torch.FloatTensor] = None
238
+ logits_per_image: torch.FloatTensor = None
239
+ logits_per_text: torch.FloatTensor = None
240
+ text_embeds: torch.FloatTensor = None
241
+ image_embeds: torch.FloatTensor = None
242
+ text_model_output: BaseModelOutputWithPooling = None
243
+ vision_model_output: BaseModelOutputWithPooling = None
244
+
245
+ def to_tuple(self) -> Tuple[Any]:
246
+ return tuple(
247
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
248
+ for k in self.keys()
249
+ )
250
+
251
+
252
+ class SiglipVisionEmbeddings(nn.Module):
253
+ def __init__(self, config: SiglipVisionConfig):
254
+ super().__init__()
255
+ self.config = config
256
+ self.embed_dim = config.hidden_size
257
+ self.image_size = config.image_size
258
+ self.patch_size = config.patch_size
259
+
260
+ self.patch_embedding = nn.Conv2d(
261
+ in_channels=config.num_channels,
262
+ out_channels=self.embed_dim,
263
+ kernel_size=self.patch_size,
264
+ stride=self.patch_size,
265
+ padding="valid",
266
+ )
267
+
268
+ self.num_patches = (self.image_size // self.patch_size) ** 2
269
+ self.num_positions = self.num_patches
270
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
271
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
272
+
273
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
274
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
275
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
276
+
277
+ embeddings = embeddings + self.position_embedding(self.position_ids)
278
+ return embeddings
279
+
280
+
281
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
282
+ class SiglipTextEmbeddings(nn.Module):
283
+ def __init__(self, config: SiglipTextConfig):
284
+ super().__init__()
285
+ embed_dim = config.hidden_size
286
+
287
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
288
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
289
+
290
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
291
+ self.register_buffer(
292
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
293
+ )
294
+
295
+ def forward(
296
+ self,
297
+ input_ids: Optional[torch.LongTensor] = None,
298
+ position_ids: Optional[torch.LongTensor] = None,
299
+ inputs_embeds: Optional[torch.FloatTensor] = None,
300
+ ) -> torch.Tensor:
301
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
302
+
303
+ if position_ids is None:
304
+ position_ids = self.position_ids[:, :seq_length]
305
+
306
+ if inputs_embeds is None:
307
+ inputs_embeds = self.token_embedding(input_ids)
308
+
309
+ position_embeddings = self.position_embedding(position_ids)
310
+ embeddings = inputs_embeds + position_embeddings
311
+
312
+ return embeddings
313
+
314
+
315
+ class SiglipAttention(nn.Module):
316
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
317
+
318
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
319
+ def __init__(self, config):
320
+ super().__init__()
321
+ self.config = config
322
+ self.embed_dim = config.hidden_size
323
+ self.num_heads = config.num_attention_heads
324
+ self.head_dim = self.embed_dim // self.num_heads
325
+ if self.head_dim * self.num_heads != self.embed_dim:
326
+ raise ValueError(
327
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
328
+ f" {self.num_heads})."
329
+ )
330
+ self.scale = self.head_dim**-0.5
331
+ self.dropout = config.attention_dropout
332
+
333
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
334
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
335
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
336
+ # self.use_flash_attn = config.use_flash_attn and has_flash_attn
337
+ self.use_flash_attn = has_flash_attn
338
+ if self.use_flash_attn:
339
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
340
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
341
+
342
+ def _flash_attn(self,
343
+ hidden_states: torch.Tensor,
344
+ attention_mask: Optional[torch.Tensor] = None,
345
+ output_attentions: Optional[bool] = False,
346
+ key_padding_mask=None,
347
+ need_weights=False
348
+ ):
349
+
350
+ batch_size, q_len, _ = hidden_states.size()
351
+
352
+ query_states = self.q_proj(hidden_states)
353
+ key_states = self.k_proj(hidden_states)
354
+ value_states = self.v_proj(hidden_states)
355
+
356
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
357
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
358
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
359
+
360
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
361
+ context, attn_weights = self.inner_attn(
362
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
363
+ )
364
+ attn_output = self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
365
+
366
+ return attn_output, attn_weights
367
+
368
+ def forward(
369
+ self,
370
+ hidden_states: torch.Tensor,
371
+ attention_mask: Optional[torch.Tensor] = None,
372
+ output_attentions: Optional[bool] = False,
373
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
374
+ """Input shape: Batch x Time x Channel"""
375
+ if self.use_flash_attn:
376
+ return self._flash_attn(hidden_states)
377
+ batch_size, q_len, _ = hidden_states.size()
378
+
379
+ query_states = self.q_proj(hidden_states)
380
+ key_states = self.k_proj(hidden_states)
381
+ value_states = self.v_proj(hidden_states)
382
+
383
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
384
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
385
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
386
+
387
+ k_v_seq_len = key_states.shape[-2]
388
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
389
+
390
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
391
+ raise ValueError(
392
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
393
+ f" {attn_weights.size()}"
394
+ )
395
+
396
+ if attention_mask is not None:
397
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
398
+ raise ValueError(
399
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
400
+ )
401
+ attn_weights = attn_weights + attention_mask
402
+
403
+ # upcast attention to fp32
404
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
405
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
406
+ attn_output = torch.matmul(attn_weights, value_states)
407
+
408
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
409
+ raise ValueError(
410
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
411
+ f" {attn_output.size()}"
412
+ )
413
+
414
+ attn_output = attn_output.transpose(1, 2).contiguous()
415
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
416
+
417
+ attn_output = self.out_proj(attn_output)
418
+
419
+ return attn_output, attn_weights
420
+
421
+
422
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
423
+ class SiglipMLP(nn.Module):
424
+ def __init__(self, config):
425
+ super().__init__()
426
+ self.config = config
427
+ self.activation_fn = ACT2FN[config.hidden_act]
428
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
429
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
430
+
431
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
432
+ hidden_states = self.fc1(hidden_states)
433
+ hidden_states = self.activation_fn(hidden_states)
434
+ hidden_states = self.fc2(hidden_states)
435
+ return hidden_states
436
+
437
+
438
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
439
+ class SiglipEncoderLayer(nn.Module):
440
+ def __init__(self, config: SiglipConfig):
441
+ super().__init__()
442
+ self.embed_dim = config.hidden_size
443
+ self.self_attn = SiglipAttention(config)
444
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
445
+ self.mlp = SiglipMLP(config)
446
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
447
+
448
+ # Ignore copy
449
+ def forward(
450
+ self,
451
+ hidden_states: torch.Tensor,
452
+ attention_mask: torch.Tensor,
453
+ output_attentions: Optional[bool] = False,
454
+ ) -> Tuple[torch.FloatTensor]:
455
+ """
456
+ Args:
457
+ hidden_states (`torch.FloatTensor`):
458
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
459
+ attention_mask (`torch.FloatTensor`):
460
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
461
+ output_attentions (`bool`, *optional*, defaults to `False`):
462
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
463
+ returned tensors for more detail.
464
+ """
465
+ residual = hidden_states
466
+
467
+ hidden_states = self.layer_norm1(hidden_states)
468
+ hidden_states, attn_weights = self.self_attn(
469
+ hidden_states=hidden_states,
470
+ attention_mask=attention_mask,
471
+ output_attentions=output_attentions,
472
+ )
473
+ hidden_states = residual + hidden_states
474
+
475
+ residual = hidden_states
476
+ hidden_states = self.layer_norm2(hidden_states)
477
+ hidden_states = self.mlp(hidden_states)
478
+ hidden_states = residual + hidden_states
479
+
480
+ outputs = (hidden_states,)
481
+
482
+ if output_attentions:
483
+ outputs += (attn_weights,)
484
+
485
+ return outputs
486
+
487
+
488
+ class SiglipPreTrainedModel(PreTrainedModel):
489
+ """
490
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
491
+ models.
492
+ """
493
+
494
+ config_class = SiglipConfig
495
+ base_model_prefix = "siglip"
496
+ supports_gradient_checkpointing = True
497
+
498
+ def _init_weights(self, module):
499
+ """Initialize the weights"""
500
+ if isinstance(module, SiglipVisionEmbeddings):
501
+ width = (
502
+ self.config.vision_config.hidden_size
503
+ if isinstance(self.config, SiglipConfig)
504
+ else self.config.hidden_size
505
+ )
506
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
507
+ elif isinstance(module, nn.Embedding):
508
+ default_flax_embed_init(module.weight)
509
+ elif isinstance(module, SiglipAttention):
510
+ nn.init.xavier_uniform_(module.q_proj.weight)
511
+ nn.init.xavier_uniform_(module.k_proj.weight)
512
+ nn.init.xavier_uniform_(module.v_proj.weight)
513
+ nn.init.xavier_uniform_(module.out_proj.weight)
514
+ nn.init.zeros_(module.q_proj.bias)
515
+ nn.init.zeros_(module.k_proj.bias)
516
+ nn.init.zeros_(module.v_proj.bias)
517
+ nn.init.zeros_(module.out_proj.bias)
518
+ elif isinstance(module, SiglipMLP):
519
+ nn.init.xavier_uniform_(module.fc1.weight)
520
+ nn.init.xavier_uniform_(module.fc2.weight)
521
+ nn.init.normal_(module.fc1.bias, std=1e-6)
522
+ nn.init.normal_(module.fc2.bias, std=1e-6)
523
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
524
+ nn.init.xavier_uniform_(module.probe.data)
525
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
526
+ nn.init.zeros_(module.attention.in_proj_bias.data)
527
+ elif isinstance(module, SiglipModel):
528
+ logit_scale_init = torch.log(torch.tensor(1.0))
529
+ module.logit_scale.data.fill_(logit_scale_init)
530
+ module.logit_bias.data.zero_()
531
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
532
+ lecun_normal_(module.weight)
533
+ if module.bias is not None:
534
+ nn.init.zeros_(module.bias)
535
+ elif isinstance(module, nn.LayerNorm):
536
+ module.bias.data.zero_()
537
+ module.weight.data.fill_(1.0)
538
+
539
+
540
+ SIGLIP_START_DOCSTRING = r"""
541
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
542
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
543
+ etc.)
544
+
545
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
546
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
547
+ and behavior.
548
+
549
+ Parameters:
550
+ config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
551
+ Initializing with a config file does not load the weights associated with the model, only the
552
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
553
+ """
554
+
555
+ SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
556
+ Args:
557
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
558
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
559
+ it.
560
+
561
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
562
+ [`PreTrainedTokenizer.__call__`] for details.
563
+
564
+ [What are input IDs?](../glossary#input-ids)
565
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
566
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
567
+
568
+ - 1 for tokens that are **not masked**,
569
+ - 0 for tokens that are **masked**.
570
+
571
+ [What are attention masks?](../glossary#attention-mask)
572
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
573
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
574
+ config.max_position_embeddings - 1]`.
575
+
576
+ [What are position IDs?](../glossary#position-ids)
577
+ output_attentions (`bool`, *optional*):
578
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
579
+ tensors for more detail.
580
+ output_hidden_states (`bool`, *optional*):
581
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
582
+ more detail.
583
+ return_dict (`bool`, *optional*):
584
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
585
+ """
586
+
587
+ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
588
+ Args:
589
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
590
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
591
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
592
+ output_attentions (`bool`, *optional*):
593
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
594
+ tensors for more detail.
595
+ output_hidden_states (`bool`, *optional*):
596
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
597
+ more detail.
598
+ return_dict (`bool`, *optional*):
599
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
600
+ """
601
+
602
+ SIGLIP_INPUTS_DOCSTRING = r"""
603
+ Args:
604
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
605
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
606
+ it.
607
+
608
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
609
+ [`PreTrainedTokenizer.__call__`] for details.
610
+
611
+ [What are input IDs?](../glossary#input-ids)
612
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
613
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
614
+
615
+ - 1 for tokens that are **not masked**,
616
+ - 0 for tokens that are **masked**.
617
+
618
+ [What are attention masks?](../glossary#attention-mask)
619
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
620
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
621
+ config.max_position_embeddings - 1]`.
622
+
623
+ [What are position IDs?](../glossary#position-ids)
624
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
625
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
626
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
627
+ return_loss (`bool`, *optional*):
628
+ Whether or not to return the contrastive loss.
629
+ output_attentions (`bool`, *optional*):
630
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
631
+ tensors for more detail.
632
+ output_hidden_states (`bool`, *optional*):
633
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
634
+ more detail.
635
+ return_dict (`bool`, *optional*):
636
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
637
+ """
638
+
639
+
640
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
641
+ class SiglipEncoder(nn.Module):
642
+ """
643
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
644
+ [`SiglipEncoderLayer`].
645
+
646
+ Args:
647
+ config: SiglipConfig
648
+ """
649
+
650
+ def __init__(self, config: SiglipConfig):
651
+ super().__init__()
652
+ self.config = config
653
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
654
+ self.gradient_checkpointing = False
655
+
656
+ # Ignore copy
657
+ def forward(
658
+ self,
659
+ inputs_embeds,
660
+ attention_mask: Optional[torch.Tensor] = None,
661
+ output_attentions: Optional[bool] = None,
662
+ output_hidden_states: Optional[bool] = None,
663
+ return_dict: Optional[bool] = None,
664
+ ) -> Union[Tuple, BaseModelOutput]:
665
+ r"""
666
+ Args:
667
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
668
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
669
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
670
+ than the model's internal embedding lookup matrix.
671
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
672
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
673
+
674
+ - 1 for tokens that are **not masked**,
675
+ - 0 for tokens that are **masked**.
676
+
677
+ [What are attention masks?](../glossary#attention-mask)
678
+ output_attentions (`bool`, *optional*):
679
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
680
+ returned tensors for more detail.
681
+ output_hidden_states (`bool`, *optional*):
682
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
683
+ for more detail.
684
+ return_dict (`bool`, *optional*):
685
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
686
+ """
687
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
688
+ output_hidden_states = (
689
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
690
+ )
691
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
692
+
693
+ encoder_states = () if output_hidden_states else None
694
+ all_attentions = () if output_attentions else None
695
+
696
+ hidden_states = inputs_embeds
697
+ for encoder_layer in self.layers:
698
+ if output_hidden_states:
699
+ encoder_states = encoder_states + (hidden_states,)
700
+ if self.gradient_checkpointing and self.training:
701
+ layer_outputs = self._gradient_checkpointing_func(
702
+ encoder_layer.__call__,
703
+ hidden_states,
704
+ attention_mask,
705
+ output_attentions,
706
+ )
707
+ else:
708
+ layer_outputs = encoder_layer(
709
+ hidden_states,
710
+ attention_mask,
711
+ output_attentions=output_attentions,
712
+ )
713
+
714
+ hidden_states = layer_outputs[0]
715
+
716
+ if output_attentions:
717
+ all_attentions = all_attentions + (layer_outputs[1],)
718
+
719
+ if output_hidden_states:
720
+ encoder_states = encoder_states + (hidden_states,)
721
+
722
+ if not return_dict:
723
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
724
+ return BaseModelOutput(
725
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
726
+ )
727
+
728
+
729
+ class SiglipTextTransformer(nn.Module):
730
+ def __init__(self, config: SiglipTextConfig):
731
+ super().__init__()
732
+ self.config = config
733
+ embed_dim = config.hidden_size
734
+ self.embeddings = SiglipTextEmbeddings(config)
735
+ self.encoder = SiglipEncoder(config)
736
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
737
+
738
+ self.head = nn.Linear(embed_dim, embed_dim)
739
+
740
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
741
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
742
+ def forward(
743
+ self,
744
+ input_ids: Optional[torch.Tensor] = None,
745
+ attention_mask: Optional[torch.Tensor] = None,
746
+ position_ids: Optional[torch.Tensor] = None,
747
+ output_attentions: Optional[bool] = None,
748
+ output_hidden_states: Optional[bool] = None,
749
+ return_dict: Optional[bool] = None,
750
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
751
+ r"""
752
+ Returns:
753
+
754
+ """
755
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
756
+ output_hidden_states = (
757
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
758
+ )
759
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
760
+
761
+ if input_ids is None:
762
+ raise ValueError("You have to specify input_ids")
763
+
764
+ input_shape = input_ids.size()
765
+ input_ids = input_ids.view(-1, input_shape[-1])
766
+
767
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
768
+
769
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
770
+ # expand attention_mask
771
+ if attention_mask is not None:
772
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
773
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
774
+
775
+ encoder_outputs = self.encoder(
776
+ inputs_embeds=hidden_states,
777
+ attention_mask=attention_mask,
778
+ output_attentions=output_attentions,
779
+ output_hidden_states=output_hidden_states,
780
+ return_dict=return_dict,
781
+ )
782
+
783
+ last_hidden_state = encoder_outputs[0]
784
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
785
+
786
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
787
+ pooled_output = last_hidden_state[:, -1, :]
788
+ pooled_output = self.head(pooled_output)
789
+
790
+ if not return_dict:
791
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
792
+
793
+ return BaseModelOutputWithPooling(
794
+ last_hidden_state=last_hidden_state,
795
+ pooler_output=pooled_output,
796
+ hidden_states=encoder_outputs.hidden_states,
797
+ attentions=encoder_outputs.attentions,
798
+ )
799
+
800
+
801
+ @add_start_docstrings(
802
+ """The text model from SigLIP without any head or projection on top.""",
803
+ SIGLIP_START_DOCSTRING,
804
+ )
805
+ class SiglipTextModel(SiglipPreTrainedModel):
806
+ config_class = SiglipTextConfig
807
+
808
+ _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
809
+
810
+ def __init__(self, config: SiglipTextConfig):
811
+ super().__init__(config)
812
+ self.text_model = SiglipTextTransformer(config)
813
+ # Initialize weights and apply final processing
814
+ self.post_init()
815
+
816
+ def get_input_embeddings(self) -> nn.Module:
817
+ return self.text_model.embeddings.token_embedding
818
+
819
+ def set_input_embeddings(self, value):
820
+ self.text_model.embeddings.token_embedding = value
821
+
822
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
823
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
824
+ def forward(
825
+ self,
826
+ input_ids: Optional[torch.Tensor] = None,
827
+ attention_mask: Optional[torch.Tensor] = None,
828
+ position_ids: Optional[torch.Tensor] = None,
829
+ output_attentions: Optional[bool] = None,
830
+ output_hidden_states: Optional[bool] = None,
831
+ return_dict: Optional[bool] = None,
832
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
833
+ r"""
834
+ Returns:
835
+
836
+ Examples:
837
+
838
+ ```python
839
+ >>> from transformers import AutoTokenizer, SiglipTextModel
840
+
841
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
842
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
843
+
844
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
845
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
846
+
847
+ >>> outputs = model(**inputs)
848
+ >>> last_hidden_state = outputs.last_hidden_state
849
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
850
+ ```"""
851
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
852
+
853
+ return self.text_model(
854
+ input_ids=input_ids,
855
+ attention_mask=attention_mask,
856
+ position_ids=position_ids,
857
+ output_attentions=output_attentions,
858
+ output_hidden_states=output_hidden_states,
859
+ return_dict=return_dict,
860
+ )
861
+
862
+
863
+ class SiglipVisionTransformer(nn.Module):
864
+ def __init__(self, config: SiglipVisionConfig):
865
+ super().__init__()
866
+ self.config = config
867
+ embed_dim = config.hidden_size
868
+
869
+ self.embeddings = SiglipVisionEmbeddings(config)
870
+ self.encoder = SiglipEncoder(config)
871
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
872
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
873
+
874
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
875
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
876
+ def forward(
877
+ self,
878
+ pixel_values,
879
+ output_attentions: Optional[bool] = None,
880
+ output_hidden_states: Optional[bool] = None,
881
+ return_dict: Optional[bool] = None,
882
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
883
+ r"""
884
+ Returns:
885
+
886
+ """
887
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
888
+ output_hidden_states = (
889
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
890
+ )
891
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
892
+
893
+ hidden_states = self.embeddings(pixel_values)
894
+
895
+ encoder_outputs = self.encoder(
896
+ inputs_embeds=hidden_states,
897
+ output_attentions=output_attentions,
898
+ output_hidden_states=output_hidden_states,
899
+ return_dict=return_dict,
900
+ )
901
+
902
+ last_hidden_state = encoder_outputs[0]
903
+ last_hidden_state = self.post_layernorm(last_hidden_state)
904
+
905
+ pooled_output = self.head(last_hidden_state)
906
+
907
+ if not return_dict:
908
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
909
+
910
+ return BaseModelOutputWithPooling(
911
+ last_hidden_state=last_hidden_state,
912
+ pooler_output=pooled_output,
913
+ hidden_states=encoder_outputs.hidden_states,
914
+ attentions=encoder_outputs.attentions,
915
+ )
916
+
917
+
918
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
919
+ """Multihead Attention Pooling."""
920
+
921
+ def __init__(self, config: SiglipVisionConfig):
922
+ super().__init__()
923
+
924
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
925
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
926
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
927
+ self.mlp = SiglipMLP(config)
928
+
929
+ def forward(self, hidden_state):
930
+ batch_size = hidden_state.shape[0]
931
+ probe = self.probe.repeat(batch_size, 1, 1)
932
+
933
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
934
+
935
+ residual = hidden_state
936
+ hidden_state = self.layernorm(hidden_state)
937
+ hidden_state = residual + self.mlp(hidden_state)
938
+
939
+ return hidden_state[:, 0]
940
+
941
+
942
+ @add_start_docstrings(
943
+ """The vision model from SigLIP without any head or projection on top.""",
944
+ SIGLIP_START_DOCSTRING,
945
+ )
946
+ class SiglipVisionModel(SiglipPreTrainedModel):
947
+ config_class = SiglipVisionConfig
948
+ main_input_name = "pixel_values"
949
+ _no_split_modules = [
950
+ "SiglipEncoderLayer",
951
+ "SiglipVisionEmbeddings",
952
+ "SiglipMultiheadAttentionPoolingHead",
953
+ ]
954
+
955
+ def __init__(self, config: SiglipVisionConfig):
956
+ super().__init__(config)
957
+
958
+ self.vision_model = SiglipVisionTransformer(config)
959
+
960
+ # Initialize weights and apply final processing
961
+ self.post_init()
962
+
963
+ def get_input_embeddings(self) -> nn.Module:
964
+ return self.vision_model.embeddings.patch_embedding
965
+
966
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
967
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
968
+ def forward(
969
+ self,
970
+ pixel_values,
971
+ output_attentions: Optional[bool] = None,
972
+ output_hidden_states: Optional[bool] = None,
973
+ return_dict: Optional[bool] = None,
974
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
975
+ r"""
976
+ Returns:
977
+
978
+ Examples:
979
+
980
+ ```python
981
+ >>> from PIL import Image
982
+ >>> import requests
983
+ >>> from transformers import AutoProcessor, SiglipVisionModel
984
+
985
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
986
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
987
+
988
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
989
+ >>> image = Image.open(requests.get(url, stream=True).raw)
990
+
991
+ >>> inputs = processor(images=image, return_tensors="pt")
992
+
993
+ >>> outputs = model(**inputs)
994
+ >>> last_hidden_state = outputs.last_hidden_state
995
+ >>> pooled_output = outputs.pooler_output # pooled features
996
+ ```"""
997
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
998
+
999
+ return self.vision_model(
1000
+ pixel_values=pixel_values,
1001
+ output_attentions=output_attentions,
1002
+ output_hidden_states=output_hidden_states,
1003
+ return_dict=return_dict,
1004
+ )
1005
+
1006
+
1007
+ @add_start_docstrings(SIGLIP_START_DOCSTRING)
1008
+ class SiglipModel(SiglipPreTrainedModel):
1009
+ config_class = SiglipConfig
1010
+
1011
+ def __init__(self, config: SiglipConfig):
1012
+ super().__init__(config)
1013
+
1014
+ if not isinstance(config.text_config, SiglipTextConfig):
1015
+ raise ValueError(
1016
+ "config.text_config is expected to be of type SiglipTextConfig but is of type"
1017
+ f" {type(config.text_config)}."
1018
+ )
1019
+
1020
+ if not isinstance(config.vision_config, SiglipVisionConfig):
1021
+ raise ValueError(
1022
+ "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
1023
+ f" {type(config.vision_config)}."
1024
+ )
1025
+
1026
+ text_config = config.text_config
1027
+ vision_config = config.vision_config
1028
+
1029
+ self.text_model = SiglipTextTransformer(text_config)
1030
+ self.vision_model = SiglipVisionTransformer(vision_config)
1031
+
1032
+ self.logit_scale = nn.Parameter(torch.randn(1))
1033
+ self.logit_bias = nn.Parameter(torch.randn(1))
1034
+
1035
+ # Initialize weights and apply final processing
1036
+ self.post_init()
1037
+
1038
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1039
+ def get_text_features(
1040
+ self,
1041
+ input_ids: Optional[torch.Tensor] = None,
1042
+ attention_mask: Optional[torch.Tensor] = None,
1043
+ position_ids: Optional[torch.Tensor] = None,
1044
+ output_attentions: Optional[bool] = None,
1045
+ output_hidden_states: Optional[bool] = None,
1046
+ return_dict: Optional[bool] = None,
1047
+ ) -> torch.FloatTensor:
1048
+ r"""
1049
+ Returns:
1050
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1051
+ applying the projection layer to the pooled output of [`SiglipTextModel`].
1052
+
1053
+ Examples:
1054
+
1055
+ ```python
1056
+ >>> from transformers import AutoTokenizer, AutoModel
1057
+ >>> import torch
1058
+
1059
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1060
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1061
+
1062
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1063
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1064
+ >>> with torch.no_grad():
1065
+ ... text_features = model.get_text_features(**inputs)
1066
+ ```"""
1067
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1068
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1069
+ output_hidden_states = (
1070
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1071
+ )
1072
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1073
+
1074
+ text_outputs = self.text_model(
1075
+ input_ids=input_ids,
1076
+ attention_mask=attention_mask,
1077
+ position_ids=position_ids,
1078
+ output_attentions=output_attentions,
1079
+ output_hidden_states=output_hidden_states,
1080
+ return_dict=return_dict,
1081
+ )
1082
+
1083
+ pooled_output = text_outputs[1]
1084
+
1085
+ return pooled_output
1086
+
1087
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1088
+ def get_image_features(
1089
+ self,
1090
+ pixel_values: Optional[torch.FloatTensor] = None,
1091
+ output_attentions: Optional[bool] = None,
1092
+ output_hidden_states: Optional[bool] = None,
1093
+ return_dict: Optional[bool] = None,
1094
+ ) -> torch.FloatTensor:
1095
+ r"""
1096
+ Returns:
1097
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1098
+ applying the projection layer to the pooled output of [`SiglipVisionModel`].
1099
+
1100
+ Examples:
1101
+
1102
+ ```python
1103
+ >>> from PIL import Image
1104
+ >>> import requests
1105
+ >>> from transformers import AutoProcessor, AutoModel
1106
+ >>> import torch
1107
+
1108
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1109
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1110
+
1111
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1112
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1113
+
1114
+ >>> inputs = processor(images=image, return_tensors="pt")
1115
+
1116
+ >>> with torch.no_grad():
1117
+ ... image_features = model.get_image_features(**inputs)
1118
+ ```"""
1119
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1120
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1121
+ output_hidden_states = (
1122
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1123
+ )
1124
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1125
+
1126
+ vision_outputs = self.vision_model(
1127
+ pixel_values=pixel_values,
1128
+ output_attentions=output_attentions,
1129
+ output_hidden_states=output_hidden_states,
1130
+ return_dict=return_dict,
1131
+ )
1132
+
1133
+ pooled_output = vision_outputs[1]
1134
+
1135
+ return pooled_output
1136
+
1137
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1138
+ @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
1139
+ def forward(
1140
+ self,
1141
+ input_ids: Optional[torch.LongTensor] = None,
1142
+ pixel_values: Optional[torch.FloatTensor] = None,
1143
+ attention_mask: Optional[torch.Tensor] = None,
1144
+ position_ids: Optional[torch.LongTensor] = None,
1145
+ return_loss: Optional[bool] = None,
1146
+ output_attentions: Optional[bool] = None,
1147
+ output_hidden_states: Optional[bool] = None,
1148
+ return_dict: Optional[bool] = None,
1149
+ ) -> Union[Tuple, SiglipOutput]:
1150
+ r"""
1151
+ Returns:
1152
+
1153
+ Examples:
1154
+
1155
+ ```python
1156
+ >>> from PIL import Image
1157
+ >>> import requests
1158
+ >>> from transformers import AutoProcessor, AutoModel
1159
+ >>> import torch
1160
+
1161
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1162
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1163
+
1164
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1165
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1166
+
1167
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
1168
+ >>> # important: we pass `padding=max_length` since the model was trained with this
1169
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
1170
+
1171
+ >>> with torch.no_grad():
1172
+ ... outputs = model(**inputs)
1173
+
1174
+ >>> logits_per_image = outputs.logits_per_image
1175
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
1176
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
1177
+ 31.9% that image 0 is 'a photo of 2 cats'
1178
+ ```"""
1179
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1180
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1181
+ output_hidden_states = (
1182
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1183
+ )
1184
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1185
+
1186
+ vision_outputs = self.vision_model(
1187
+ pixel_values=pixel_values,
1188
+ output_attentions=output_attentions,
1189
+ output_hidden_states=output_hidden_states,
1190
+ return_dict=return_dict,
1191
+ )
1192
+
1193
+ text_outputs = self.text_model(
1194
+ input_ids=input_ids,
1195
+ attention_mask=attention_mask,
1196
+ position_ids=position_ids,
1197
+ output_attentions=output_attentions,
1198
+ output_hidden_states=output_hidden_states,
1199
+ return_dict=return_dict,
1200
+ )
1201
+
1202
+ image_embeds = vision_outputs[1]
1203
+ text_embeds = text_outputs[1]
1204
+
1205
+ # normalized features
1206
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1207
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1208
+
1209
+ # cosine similarity as logits
1210
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
1211
+ logits_per_image = logits_per_text.t()
1212
+
1213
+ loss = None
1214
+ if return_loss:
1215
+ raise NotImplementedError("SigLIP loss to be implemented")
1216
+
1217
+ if not return_dict:
1218
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1219
+ return ((loss,) + output) if loss is not None else output
1220
+
1221
+ return SiglipOutput(
1222
+ loss=loss,
1223
+ logits_per_image=logits_per_image,
1224
+ logits_per_text=logits_per_text,
1225
+ text_embeds=text_embeds,
1226
+ image_embeds=image_embeds,
1227
+ text_model_output=text_outputs,
1228
+ vision_model_output=vision_outputs,
1229
+ )
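
`SiglipModel.forward` above computes the pairwise sigmoid logits but raises `NotImplementedError` when `return_loss=True`. For reference, here is a hedged sketch of the sigmoid loss described in the SigLIP paper, where matched image-text pairs on the diagonal get label +1 and all other pairs get -1; this is an assumption about how the loss could be filled in, not code from this commit:

```python
import torch
import torch.nn.functional as F

def siglip_sigmoid_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
    # logits_per_text: (text_batch, image_batch), already scaled and biased
    n = logits_per_text.shape[0]
    labels = 2 * torch.eye(n, device=logits_per_text.device) - 1  # +1 diagonal, -1 elsewhere
    # -log sigmoid(label * logit), summed over all pairs, averaged over the batch
    return -F.logsigmoid(labels * logits_per_text).sum() / n

loss = siglip_sigmoid_loss(torch.randn(4, 4))  # toy logits
print(loss.item())
```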
multi_backbone_channel_concatenation_encoder.py ADDED
@@ -0,0 +1,264 @@
1
+ import torch, os
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from .siglip_vision_tower import SiglipVisionTower
6
+
7
+ # from .hr_clip_encoder import HRCLIPVisionTower
8
+ # from .eva_vit import EVAVITVisionTower
9
+ # from .SAM.modeling_sam import SAMVisionTower
10
+ # from .pix2struct_large import Pix2StructLargeVisionTower
11
+ import torch.nn.functional as F
12
+ from torch.nn.init import trunc_normal_
13
+ from copy import deepcopy
14
+ import random
15
+ import math
16
+
17
+ class MultiBackboneChannelConcatenationVisionTower(nn.Module):
18
+ def __init__(self,
19
+ vision_tower,
20
+ args,
21
+ grid_size=32,
22
+ convnext_img_size=1024,
23
+ normalize_type=None, raw_config=None):
24
+
25
+ super().__init__()
26
+
27
+ self.is_loaded = False
28
+ self.grid_size = grid_size
29
+ self.num_tokens = self.grid_size ** 2
30
+ self.normalize_type = args.normalize_type
31
+ self.moe_version_type = args.moe_version_type
32
+ self.raw_config = raw_config
33
+ print("moe_version_type: ", self.moe_version_type)
34
+ assert self.moe_version_type in [None, 'all_tiling', 'seq_concat', 'feat_concat', 'convnext_512_siglip_448'], f"Unknown self.moe_version_type: {self.moe_version_type}"
35
+
36
+ vision_tower_name_list = vision_tower.split(";")
37
+ self.input_image_size = 1024
38
+ self.convnext_img_size = convnext_img_size
39
+ self.load_vision_towers(vision_tower_name_list, args)
40
+
41
+
42
+ def load_vision_towers(self, vision_tower_name_list, args):
43
+ self.vision_towers = nn.ModuleList()
44
+
45
+ freeze_backbone_list = args.freeze_backbones # note this is a str
46
+ if freeze_backbone_list is not None and len(freeze_backbone_list) > 0:
47
+ print("The frozen backbones: ", freeze_backbone_list)
48
+ else:
49
+ # make it a blank str
50
+ freeze_backbone_list = ""
51
+
52
+ for name in vision_tower_name_list:
53
+
54
+ ## ConvNeXt
55
+ if name == 'convnext-1024':
56
+ convnext_args = deepcopy(args)
57
+
58
+ convnext_args.freeze_vision = False
59
+ if 'convnext-1024' in freeze_backbone_list:
60
+ convnext_args.freeze_vision = True
61
+
62
+ from .convnext_encoder import ConvNextVisionTower
63
+ convnext_args.input_image_size = self.convnext_img_size
64
+ convnext_vision_tower = args.vision_tower_convnext_path
65
+ convnext_vision_tower = ConvNextVisionTower(convnext_vision_tower,
66
+ convnext_args, delay_load=args.delay_load, normalize_type=self.normalize_type)
67
+ convnext_vision_tower.load_model()
68
+ self.vision_towers.append(convnext_vision_tower)
69
+
70
+ ## PaliSigLIP
71
+ elif name == 'palisiglip':
72
+ palisiglip_args = deepcopy(args)
73
+ palisiglip_args.input_image_size = 448
74
+
75
+ palisiglip_args.freeze_vision = False
76
+ if 'palisiglip' in freeze_backbone_list:
77
+ palisiglip_args.freeze_vision = True
78
+
79
+ palisiglip_vision_tower = SiglipVisionTower(args.vision_tower_siglip_path, palisiglip_args, delay_load=args.delay_load, raw_config=self.raw_config)
80
+
81
+ palisiglip_vision_tower.load_model()
82
+ self.vision_towers.append(palisiglip_vision_tower)
83
+
84
+ # Set the image processor
85
+ self.image_processor = None
86
+ self.is_loaded = True
87
+
88
+ def load_model(self):
89
+ assert self.is_loaded, "All the vision encoders should be loaded during initialization!"
90
+
91
+ def forward(self, x):
92
+ # x is a Tensor when moe_version_type is None, 'all_tiling' or 'convnext_512_siglip_448';
93
+ # for 'seq_concat' / 'feat_concat' it is a dict with 'pixel_values' and 'num_patches'
94
+ if self.moe_version_type in [None, 'all_tiling']:
95
+ # The default pipeline
96
+ features = []
97
+ image_input_size = x.shape[2]
98
+ assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
99
+ for vision_tower in self.vision_towers:
100
+
101
+ if vision_tower.input_image_size != image_input_size:
102
+ resized_x = F.interpolate(x.float(),
103
+ size=(vision_tower.input_image_size, vision_tower.input_image_size),
104
+ mode='bilinear',
105
+ align_corners=True).to(dtype=x.dtype)
106
+ else:
107
+ resized_x = x
108
+
109
+ feature = vision_tower(resized_x)
110
+
111
+ if len(feature.shape) == 3: # b, n, c
112
+ b, n, c = feature.shape
113
+ if n == self.num_tokens:
114
+ features.append(feature)
115
+ continue
116
+ w = h = int(n**0.5)
117
+ feature = feature.transpose(1,2).reshape(b, c, h, w)
118
+ else:
119
+ b, c, h, w = feature.shape
120
+
121
+ if w != self.grid_size:
122
+ feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
123
+ features.append(feature.flatten(2,3).transpose(1,2))
124
+
125
+ features = torch.cat(features, dim=-1)
126
+ elif self.moe_version_type == 'convnext_512_siglip_448':
127
+ features = {}
128
+ image_input_size = x.shape[2]
129
+ assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
130
+ for vision_tower in self.vision_towers:
131
+
132
+ if vision_tower.input_image_size != image_input_size:
133
+ resized_x = F.interpolate(x.float(),
134
+ size=(vision_tower.input_image_size, vision_tower.input_image_size),
135
+ mode='bilinear',
136
+ align_corners=True).to(dtype=x.dtype)
137
+ else:
138
+ resized_x = x
139
+
140
+ feature = vision_tower(resized_x)
141
+
142
+ # if len(feature.shape) == 3: # b, n, c
143
+ # b, n, c = feature.shape
144
+ # if n == self.num_tokens:
145
+ # features.append(feature)
146
+ # continue
147
+ # w = h = int(n**0.5)
148
+ # feature = feature.transpose(1,2).reshape(b, c, h, w)
149
+ # else:
150
+ # b, c, h, w = feature.shape
151
+ features[vision_tower.name] = feature
152
+
153
+ else:
154
+ assert isinstance(x, dict), "x is expected to be a dict but got {}".format(type(x))
155
+ pixel_values = x['pixel_values']
156
+ num_patches = x['num_patches'] # number of patches per sample, counted via the image padding tokens in the text
157
+
158
+ # calculate the real number of image patches
159
+ if self.moe_version_type == 'seq_concat':
160
+ image_in_num_patches = [i-1 for i in num_patches]
161
+ else:
162
+ image_in_num_patches = [i for i in num_patches]
163
+
164
+
165
+ assert sum(image_in_num_patches) == pixel_values.size(0), "sum(image_in_num_patches) ({}) != pixel_values.size(0) ({})".format(sum(image_in_num_patches), pixel_values.size(0))
166
+
167
+ # find the thumbnail image id
168
+ thumbnail_image_id = torch.cumsum(torch.tensor(image_in_num_patches).to(pixel_values.device), 0) - 1
169
+ image_no_tiling = pixel_values[thumbnail_image_id]
170
+
171
+ # By default, we use the 1st vision_tower for x, others for x_nt
172
+ features = []
173
+ for layer_id, vision_tower in enumerate(self.vision_towers):
174
+ if layer_id == 0:
175
+ x = pixel_values
176
+ else:
177
+ x = image_no_tiling
178
+
179
+ if vision_tower.input_image_size != self.input_image_size:
180
+ resized_x = F.interpolate(x.float(),
181
+ size=(vision_tower.input_image_size, vision_tower.input_image_size),
182
+ mode='bilinear',
183
+ align_corners=True).to(dtype=x.dtype)
184
+ else:
185
+ resized_x = x
186
+
187
+ feature = vision_tower(resized_x)
188
+ if len(feature.shape) == 3: # b, n, c
189
+ b, n, c = feature.shape
190
+ if n == self.num_tokens:
191
+ features.append(feature)
192
+ continue
193
+
194
+ w = h = int(n**0.5)
195
+ feature = feature.transpose(1,2).reshape(b, c, h, w)
196
+ else:
197
+ b, c, h, w = feature.shape
198
+
199
+ if w != self.grid_size:
200
+ feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
201
+ features.append(feature.flatten(2,3).transpose(1,2))
202
+
203
+ clip_embeds = features[0]
204
+ if len(features) <= 1:
205
+ no_tiling_embeds = None
206
+ else:
207
+ no_tiling_embeds = torch.cat(features[1:], dim=-1)
208
+
209
+ if self.moe_version_type == 'feat_concat':
210
+ # concat thumbnail image features together
211
+ clip_thumbnail_embeds = clip_embeds[thumbnail_image_id]
212
+ if no_tiling_embeds is not None:
213
+ no_tiling_embeds = torch.cat([clip_thumbnail_embeds, no_tiling_embeds], dim=-1)
214
+ else:
215
+ no_tiling_embeds = clip_thumbnail_embeds
216
+
217
+ # extra patch features
218
+ clip_embeds_mask = ~torch.isin(torch.arange(clip_embeds.shape[0]).to(clip_embeds.device), thumbnail_image_id)
219
+ clip_embeds = clip_embeds[clip_embeds_mask]
220
+
221
+
222
+ features = {
223
+ 'clip_embeds': clip_embeds,
224
+ 'no_tiling_embeds': no_tiling_embeds,
225
+ 'num_patches': num_patches
226
+ }
227
+
228
+ # features is a Tensor on the default path and a dict for the other moe_version_type settings
229
+
230
+ return features
231
+
232
+ @property
233
+ def dummy_feature(self):
234
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
235
+
236
+ @property
237
+ def dtype(self):
238
+ return next(self.clip_vision_tower.parameters()).dtype
239
+
240
+ @property
241
+ def device(self):
242
+ return next(self.clip_vision_tower.parameters()).device
243
+
244
+ @property
245
+ def config(self):
246
+ raise NotImplementedError("config is not implemented for MultiBackboneChannelConcatenationVisionTower")
247
+ pass
248
+
249
+ @property
250
+ def hidden_size(self):
251
+ if self.moe_version_type == 'convnext_512_siglip_448':
252
+ res = {}
253
+ for vision_tower in self.vision_towers:
254
+ res[vision_tower.name] = vision_tower.hidden_size
255
+ return res
256
+ else:
257
+ return sum([_.hidden_size for _ in self.vision_towers])
258
+
259
+ @property
260
+ def num_patches(self):
261
+ return self.num_tokens
262
+
263
+
264
+
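To make the default (`moe_version_type in [None, 'all_tiling']`) path above concrete, the sketch below runs the same steps on two stand-in backbone outputs: resize each feature map to the shared `grid_size` x `grid_size` grid, flatten it into tokens, and concatenate along the channel dimension. The backbone shapes and channel counts are assumptions chosen only for illustration:

```python
import torch
import torch.nn.functional as F

grid_size = 32  # matches the default grid_size of the tower above

# Stand-in backbone outputs in (B, C, H, W) layout with different grids and widths.
feat_siglip_like = torch.randn(2, 1152, 32, 32)    # assumed SigLIP-style map, already on the target grid
feat_convnext_like = torch.randn(2, 1536, 24, 24)  # assumed ConvNeXt-style map on a coarser grid

tokens = []
for feat in (feat_siglip_like, feat_convnext_like):
    if feat.shape[-1] != grid_size:
        # Resize to the shared grid, as the forward above does.
        feat = F.interpolate(feat.float(), size=(grid_size, grid_size),
                             mode='bilinear', align_corners=True)
    # (B, C, H, W) -> (B, H*W, C)
    tokens.append(feat.flatten(2, 3).transpose(1, 2))

# Channel-wise concatenation: every spatial token carries features from both backbones.
features = torch.cat(tokens, dim=-1)
print(features.shape)  # torch.Size([2, 1024, 2688])
```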
multi_backbone_channel_concatentation_model.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch.nn as nn
2
+
3
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
4
+ from typing import Optional, Tuple, Union
5
+
6
+ from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower
7
+ from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig
8
+
9
+
10
+ class MultiBackboneChannelConcatenationVisionModel(nn.Module):
11
+
12
+ """
13
+ A vision model wrapper that concatenates channels from multiple backbones.
14
+
15
+ Args:
16
+ config (MultiBackboneChannelConcatenationVisionModelConfig): The configuration for the model.
17
+
18
+ Attributes:
19
+ vision_model (MultiBackboneChannelConcatenationVisionTower): The vision tower that performs the channel concatenation.
20
+
21
+ Notes:
22
+ **The class is not inherited from the PreTrainedModel in transformers**
23
+
24
+ """
25
+
26
+ config_class = MultiBackboneChannelConcatenationVisionModelConfig
27
+ main_input_name = "pixel_values"
28
+
29
+ def __init__(self, config: MultiBackboneChannelConcatenationVisionModelConfig, raw_config):
30
+ super().__init__()
31
+
32
+ self.vision_model = MultiBackboneChannelConcatenationVisionTower(
33
+ vision_tower=config.vision_tower,
34
+ args=config,
35
+ grid_size=config.grid_size,
36
+ convnext_img_size=config.convnext_img_size,
37
+ normalize_type=config.normalize_type,
38
+ raw_config=raw_config
39
+ )
40
+
41
+
42
+ def get_input_embeddings(self):
43
+ # You might need to adjust this depending on how you want to handle input embeddings
44
+ return self.vision_model.vision_towers[0].get_input_embeddings()
45
+
46
+ def forward(
47
+ self,
48
+ pixel_values,
49
+ return_dict: Optional[bool] = True,
50
+ output_hidden_states: Optional[bool] = False,
51
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
52
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
53
+
54
+ assert return_dict is True, "We only support return_dict"
55
+ assert output_hidden_states is False, "We do not support output_hidden_states"
56
+
57
+ features = self.vision_model(pixel_values)
58
+
59
+ # We only support features as model outputs
60
+ return BaseModelOutputWithPooling(
61
+ last_hidden_state=features,
62
+ pooler_output=None,
63
+ hidden_states=None,
64
+ attentions=None,
65
+ )
66
+
67
+ @property
68
+ def dummy_feature(self):
69
+ return self.vision_model.dummy_feature
70
+
71
+ @property
72
+ def dtype(self):
73
+ return self.vision_model.dtype
74
+
75
+ @property
76
+ def device(self):
77
+ return self.vision_model.device
78
+
79
+ @property
80
+ def config(self):
81
+ return self.vision_model.config
82
+
83
+ @property
84
+ def hidden_size(self):
85
+ return self.vision_model.hidden_size
86
+
87
+ @property
88
+ def num_patches(self):
89
+ return self.vision_model.num_patches
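Because the wrapper above only populates `last_hidden_state`, downstream code reads the concatenated features from that single field. A small, self-contained sketch with dummy features standing in for the tower output (the shapes are illustrative assumptions):

```python
import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling

# Dummy concatenated features: (batch, grid_size**2 tokens, summed hidden size).
features = torch.randn(2, 1024, 2688)

# The wrapper returns them like this and leaves the remaining fields empty.
outputs = BaseModelOutputWithPooling(
    last_hidden_state=features,
    pooler_output=None,
    hidden_states=None,
    attentions=None,
)
vit_embeds = outputs.last_hidden_state
print(vit_embeds.shape)  # torch.Size([2, 1024, 2688])
```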
siglip_vision_tower.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from .modeling_siglip import SiglipVisionModel
6
+ from .configuration_siglip import SiglipVisionConfig
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from typing import List, Optional
12
+ import os
13
+
14
+ class SiglipVisionTower(nn.Module):
15
+ # We use the same wrapper as the default clip encoder.
16
+ # See `clip_encoder.py` in the same folder
17
+ def __init__(self, vision_tower, args, delay_load=False, raw_config=None):
18
+ super().__init__()
19
+
20
+ self.is_loaded = False
21
+ self.freeze_vision=args.freeze_vision
22
+ self.input_image_size=args.input_image_size
23
+ self.vision_tower_name = vision_tower
24
+ self.select_layer = args.mm_vision_select_layer
25
+ self.name = 'siglip'
26
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
27
+ self.delay_load = delay_load
28
+ self.raw_config = raw_config
29
+ if not delay_load:
30
+ self.load_model()
31
+ else:
32
+ if os.path.isfile(self.vision_tower_name):
33
+ self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name, local_files_only=True)
34
+ else:
35
+ self.cfg_only = SiglipVisionConfig(**self.raw_config.vision_config.siglip_vision_config)
36
+
37
+
38
+ def load_model(self):
39
+ if self.is_loaded:
40
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
41
+ return
42
+
43
+ # self.image_processor = SiglipImageProcessor(size=1024)
44
+ # self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, local_files_only=True, torch_dtype=torch.bfloat16)
45
+ if self.delay_load:
46
+ # cfg = SiglipVisionConfig.from_pretrained(self.vision_tower_name, local_files_only=True)
47
+ self.vision_tower = SiglipVisionModel(self.cfg_only)
48
+ else:
49
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, local_files_only=True)
50
+
51
+ if self.freeze_vision:
52
+ self.vision_tower.requires_grad_(False)
53
+
54
+ self.vision_tower.vision_model.encoder.gradient_checkpointing = True
55
+ self.is_loaded = True
56
+
57
+ def forward(self, images):
58
+ return self.vision_tower(
59
+ pixel_values=images,
60
+ output_hidden_states=False,
61
+ return_dict=True).last_hidden_state
62
+
63
+
64
+ @property
65
+ def dummy_feature(self):
66
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
67
+
68
+ @property
69
+ def dtype(self):
70
+ return self.vision_tower.dtype
71
+
72
+ @property
73
+ def device(self):
74
+ return self.vision_tower.device
75
+
76
+ @property
77
+ def config(self):
78
+ if self.is_loaded:
79
+ return self.vision_tower.config
80
+ else:
81
+ return self.cfg_only
82
+
83
+ @property
84
+ def hidden_size(self):
85
+ return self.config.hidden_size
86
+
87
+ @property
88
+ def num_patches_per_side(self):
89
+ return self.config.image_size // self.config.patch_size
90
+
91
+ @property
92
+ def num_patches(self):
93
+ return (self.config.image_size // self.config.patch_size) ** 2
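The two properties above are simple functions of the vision config. As a worked example, assuming the 448-pixel inputs set for `palisiglip` in the encoder above and a 14-pixel SigLIP patch size (the patch size is an assumption, not read from this file):

```python
image_size = 448  # matches palisiglip_args.input_image_size above
patch_size = 14   # assumed SigLIP patch size

num_patches_per_side = image_size // patch_size  # 32
num_patches = num_patches_per_side ** 2          # 1024, i.e. grid_size**2 tokens
print(num_patches_per_side, num_patches)         # 32 1024
```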
special_tokens_map.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>",
16
+ "<img>",
17
+ "</img>",
18
+ "<IMG_CONTEXT>",
19
+ "<quad>",
20
+ "</quad>",
21
+ "<ref>",
22
+ "</ref>",
23
+ "<box>",
24
+ "</box>"
25
+ ],
26
+ "eos_token": {
27
+ "content": "<|im_end|>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ },
33
+ "pad_token": {
34
+ "content": "<|endoftext|>",
35
+ "lstrip": false,
36
+ "normalized": false,
37
+ "rstrip": false,
38
+ "single_word": false
39
+ }
40
+ }
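These entries only declare which strings are treated as special; the id assignments appear in `tokenizer_config.json` below (for example, `<IMG_CONTEXT>` is 151667). A hedged sketch of verifying that each special token maps to a single id once the tokenizer is loaded; the path is a placeholder for a local checkout of this repository:

```python
from transformers import AutoTokenizer

# Placeholder path: point this at a local checkout of this repository.
tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)

ids = tokenizer.convert_tokens_to_ids(["<img>", "<IMG_CONTEXT>", "</img>"])
print(ids)  # each entry is a single id, e.g. <IMG_CONTEXT> -> 151667
```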
tokenization_qwen2.py ADDED
@@ -0,0 +1,345 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Qwen2."""
16
+
17
+ import json
18
+ import os
19
+ import unicodedata
20
+ from functools import lru_cache
21
+ from typing import Optional, Tuple
22
+
23
+ import regex as re
24
+
25
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ VOCAB_FILES_NAMES = {
32
+ "vocab_file": "vocab.json",
33
+ "merges_file": "merges.txt",
34
+ }
35
+
36
+ PRETRAINED_VOCAB_FILES_MAP = {
37
+ "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},
38
+ "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},
39
+ }
40
+
41
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
42
+
43
+ PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
44
+
45
+
46
+ @lru_cache()
47
+ # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
48
+ def bytes_to_unicode():
49
+ """
50
+ Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
51
+ characters the bpe code barfs on.
52
+
53
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
54
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
55
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
56
+ tables between utf-8 bytes and unicode strings.
57
+ """
58
+ bs = (
59
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
60
+ )
61
+ cs = bs[:]
62
+ n = 0
63
+ for b in range(2**8):
64
+ if b not in bs:
65
+ bs.append(b)
66
+ cs.append(2**8 + n)
67
+ n += 1
68
+ cs = [chr(n) for n in cs]
69
+ return dict(zip(bs, cs))
70
+
71
+
72
+ # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
73
+ def get_pairs(word):
74
+ """
75
+ Return set of symbol pairs in a word.
76
+
77
+ Word is represented as tuple of symbols (symbols being variable-length strings).
78
+ """
79
+ pairs = set()
80
+ prev_char = word[0]
81
+ for char in word[1:]:
82
+ pairs.add((prev_char, char))
83
+ prev_char = char
84
+ return pairs
85
+
86
+
87
+ class Qwen2Tokenizer(PreTrainedTokenizer):
88
+ """
89
+ Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
90
+
91
+ As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens, so a word will
92
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
93
+
94
+ ```python
95
+ >>> from transformers import Qwen2Tokenizer
96
+
97
+ >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
98
+ >>> tokenizer("Hello world")["input_ids"]
99
+ [9707, 1879]
100
+
101
+ >>> tokenizer(" Hello world")["input_ids"]
102
+ [21927, 1879]
103
+ ```
104
+ This is expected.
105
+
106
+ You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
107
+
108
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
109
+ this superclass for more information regarding those methods.
110
+
111
+ Args:
112
+ vocab_file (`str`):
113
+ Path to the vocabulary file.
114
+ merges_file (`str`):
115
+ Path to the merges file.
116
+ errors (`str`, *optional*, defaults to `"replace"`):
117
+ Paradigm to follow when decoding bytes to UTF-8. See
118
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
119
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
120
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
121
+ token instead.
122
+ bos_token (`str`, *optional*):
123
+ The beginning of sequence token. Not applicable for this tokenizer.
124
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
125
+ The end of sequence token.
126
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
127
+ The token used for padding, for example when batching sequences of different lengths.
128
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
129
+ Whether or not the model should cleanup the spaces that were added when splitting the input text during the
130
+ tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
131
+ split_special_tokens (`bool`, *optional*, defaults to `False`):
132
+ Whether or not the special tokens should be split during the tokenization process. The default behavior is
133
+ to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") =
134
+ ['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
135
+ '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
136
+ """
137
+
138
+ vocab_files_names = VOCAB_FILES_NAMES
139
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
140
+ max_model_input_sizes = MAX_MODEL_INPUT_SIZES
141
+ model_input_names = ["input_ids", "attention_mask"]
142
+
143
+ def __init__(
144
+ self,
145
+ vocab_file,
146
+ merges_file,
147
+ errors="replace",
148
+ unk_token="<|endoftext|>",
149
+ bos_token=None,
150
+ eos_token="<|endoftext|>",
151
+ pad_token="<|endoftext|>",
152
+ clean_up_tokenization_spaces=False,
153
+ split_special_tokens=False,
154
+ **kwargs,
155
+ ):
156
+ # Qwen vocab does not contain control tokens; added tokens need to be special
157
+ bos_token = (
158
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
159
+ if isinstance(bos_token, str)
160
+ else bos_token
161
+ )
162
+ eos_token = (
163
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
164
+ if isinstance(eos_token, str)
165
+ else eos_token
166
+ )
167
+ unk_token = (
168
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
169
+ if isinstance(unk_token, str)
170
+ else unk_token
171
+ )
172
+ pad_token = (
173
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
174
+ if isinstance(pad_token, str)
175
+ else pad_token
176
+ )
177
+
178
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
179
+ self.encoder = json.load(vocab_handle)
180
+ self.decoder = {v: k for k, v in self.encoder.items()}
181
+ self.errors = errors # how to handle errors in decoding
182
+ self.byte_encoder = bytes_to_unicode()
183
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
184
+ bpe_merges = []
185
+ with open(merges_file, encoding="utf-8") as merges_handle:
186
+ for line in merges_handle:
187
+ line = line.strip()
188
+ if not line or line.startswith("#"):
189
+ continue
190
+ bpe_merges.append(tuple(line.split()))
191
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
192
+ # NOTE: the cache can grow without bound and will get really large for long running processes
193
+ # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
194
+ # not a memory leak but appears as one.
195
+ # GPT2Tokenizer has the same problem, so let's be consistent.
196
+ self.cache = {}
197
+
198
+ self.pat = re.compile(PRETOKENIZE_REGEX)
199
+
200
+ if kwargs.get("add_prefix_space", False):
201
+ logger.warning_once(
202
+ f"{self.__class__.__name__} does not support `add_prefix_space`; setting it to True has no effect."
203
+ )
204
+
205
+ super().__init__(
206
+ errors=errors,
207
+ bos_token=bos_token,
208
+ eos_token=eos_token,
209
+ pad_token=pad_token,
210
+ unk_token=unk_token,
211
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
212
+ split_special_tokens=split_special_tokens,
213
+ **kwargs,
214
+ )
215
+
216
+ @property
217
+ def vocab_size(self) -> int:
218
+ return len(self.encoder)
219
+
220
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
221
+ def get_vocab(self):
222
+ return dict(self.encoder, **self.added_tokens_encoder)
223
+
224
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
225
+ def bpe(self, token):
226
+ if token in self.cache:
227
+ return self.cache[token]
228
+ word = tuple(token)
229
+ pairs = get_pairs(word)
230
+
231
+ if not pairs:
232
+ return token
233
+
234
+ while True:
235
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
236
+ if bigram not in self.bpe_ranks:
237
+ break
238
+ first, second = bigram
239
+ new_word = []
240
+ i = 0
241
+ while i < len(word):
242
+ try:
243
+ j = word.index(first, i)
244
+ except ValueError:
245
+ new_word.extend(word[i:])
246
+ break
247
+ else:
248
+ new_word.extend(word[i:j])
249
+ i = j
250
+
251
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
252
+ new_word.append(first + second)
253
+ i += 2
254
+ else:
255
+ new_word.append(word[i])
256
+ i += 1
257
+ new_word = tuple(new_word)
258
+ word = new_word
259
+ if len(word) == 1:
260
+ break
261
+ else:
262
+ pairs = get_pairs(word)
263
+ word = " ".join(word)
264
+ self.cache[token] = word
265
+ return word
266
+
267
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
268
+ def _tokenize(self, text):
269
+ """Tokenize a string."""
270
+ bpe_tokens = []
271
+ for token in re.findall(self.pat, text):
272
+ token = "".join(
273
+ self.byte_encoder[b] for b in token.encode("utf-8")
274
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
275
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
276
+ return bpe_tokens
277
+
278
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
279
+ def _convert_token_to_id(self, token):
280
+ """Converts a token (str) in an id using the vocab."""
281
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
282
+
283
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
284
+ def _convert_id_to_token(self, index):
285
+ """Converts an index (integer) in a token (str) using the vocab."""
286
+ return self.decoder.get(index)
287
+
288
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
289
+ def convert_tokens_to_string(self, tokens):
290
+ """Converts a sequence of tokens (string) in a single string."""
291
+ text = "".join(tokens)
292
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
293
+ return text
294
+
295
+ def decode(
296
+ self,
297
+ token_ids,
298
+ skip_special_tokens: bool = False,
299
+ clean_up_tokenization_spaces: Optional[bool] = False,
300
+ spaces_between_special_tokens: bool = False,
301
+ **kwargs,
302
+ ) -> str:
303
+ # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
304
+ # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
305
+ return super().decode(
306
+ token_ids,
307
+ skip_special_tokens=skip_special_tokens,
308
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
309
+ spaces_between_special_tokens=spaces_between_special_tokens,
310
+ **kwargs,
311
+ )
312
+
313
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
314
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
315
+ if not os.path.isdir(save_directory):
316
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
317
+ return
318
+ vocab_file = os.path.join(
319
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
320
+ )
321
+ merge_file = os.path.join(
322
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
323
+ )
324
+
325
+ with open(vocab_file, "w", encoding="utf-8") as f:
326
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
327
+
328
+ index = 0
329
+ with open(merge_file, "w", encoding="utf-8") as writer:
330
+ writer.write("#version: 0.2\n")
331
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
332
+ if index != token_index:
333
+ logger.warning(
334
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
335
+ " Please check that the tokenizer is not corrupted!"
336
+ )
337
+ index = token_index
338
+ writer.write(" ".join(bpe_tokens) + "\n")
339
+ index += 1
340
+
341
+ return vocab_file, merge_file
342
+
343
+ def prepare_for_tokenization(self, text, **kwargs):
344
+ text = unicodedata.normalize("NFC", text)
345
+ return (text, kwargs)
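To illustrate the byte-level step `_tokenize` performs before BPE: the text is split with `PRETOKENIZE_REGEX`, and every byte of each piece is remapped through `bytes_to_unicode`, so a leading space survives as the visible character `Ġ`. A standalone sketch of just that mapping (no vocab or merges files needed); the regex and byte table are copied from the module above:

```python
import regex as re

PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

def bytes_to_unicode():
    # Same reversible byte -> unicode table as in the tokenizer above.
    bs = (list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

byte_encoder = bytes_to_unicode()
pieces = re.findall(PRETOKENIZE_REGEX, " Hello world")
mapped = ["".join(byte_encoder[b] for b in piece.encode("utf-8")) for piece in pieces]
print(pieces)  # [' Hello', ' world']
print(mapped)  # ['ĠHello', 'Ġworld']
```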
tokenization_qwen2_fast.py ADDED
@@ -0,0 +1,143 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Qwen2."""
16
+
17
+ from typing import Optional, Tuple
18
+
19
+ from transformers.tokenization_utils import AddedToken
20
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
21
+ from transformers.utils import logging
22
+ from .tokenization_qwen2 import Qwen2Tokenizer
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {
28
+ "vocab_file": "vocab.json",
29
+ "merges_file": "merges.txt",
30
+ "tokenizer_file": "tokenizer.json",
31
+ }
32
+
33
+ PRETRAINED_VOCAB_FILES_MAP = {
34
+ "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},
35
+ "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},
36
+ "tokenizer_file": {
37
+ "qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/tokenizer.json"
38
+ },
39
+ }
40
+
41
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
42
+
43
+
44
+ class Qwen2TokenizerFast(PreTrainedTokenizerFast):
45
+ """
46
+ Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
47
+ Byte-Pair-Encoding.
48
+
49
+ As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens, so a word will
50
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
51
+
52
+ ```python
53
+ >>> from transformers import Qwen2TokenizerFast
54
+
55
+ >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
56
+ >>> tokenizer("Hello world")["input_ids"]
57
+ [9707, 1879]
58
+
59
+ >>> tokenizer(" Hello world")["input_ids"]
60
+ [21927, 1879]
61
+ ```
62
+ This is expected.
63
+
64
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
65
+ refer to this superclass for more information regarding those methods.
66
+
67
+ Args:
68
+ vocab_file (`str`, *optional*):
69
+ Path to the vocabulary file.
70
+ merges_file (`str`, *optional*):
71
+ Path to the merges file.
72
+ tokenizer_file (`str`, *optional*):
73
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
74
+ contains everything needed to load the tokenizer.
75
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
76
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
77
+ token instead. Not applicable to this tokenizer.
78
+ bos_token (`str`, *optional*):
79
+ The beginning of sequence token. Not applicable for this tokenizer.
80
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
81
+ The end of sequence token.
82
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
83
+ The token used for padding, for example when batching sequences of different lengths.
84
+ """
85
+
86
+ vocab_files_names = VOCAB_FILES_NAMES
87
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
88
+ max_model_input_sizes = MAX_MODEL_INPUT_SIZES
89
+ model_input_names = ["input_ids", "attention_mask"]
90
+ slow_tokenizer_class = Qwen2Tokenizer
91
+
92
+ def __init__(
93
+ self,
94
+ vocab_file=None,
95
+ merges_file=None,
96
+ tokenizer_file=None,
97
+ unk_token="<|endoftext|>",
98
+ bos_token=None,
99
+ eos_token="<|endoftext|>",
100
+ pad_token="<|endoftext|>",
101
+ **kwargs,
102
+ ):
103
+ # We need to at least pass vocab_file and merges_file to base class
104
+ # in case a slow tokenizer needs to be initialized; the rest can be
105
+ # configured through files.
106
+ # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
107
+
108
+ bos_token = (
109
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
110
+ if isinstance(bos_token, str)
111
+ else bos_token
112
+ )
113
+ eos_token = (
114
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
115
+ if isinstance(eos_token, str)
116
+ else eos_token
117
+ )
118
+ unk_token = (
119
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
120
+ if isinstance(unk_token, str)
121
+ else unk_token
122
+ )
123
+ pad_token = (
124
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
125
+ if isinstance(pad_token, str)
126
+ else pad_token
127
+ )
128
+
129
+ super().__init__(
130
+ vocab_file,
131
+ merges_file,
132
+ tokenizer_file=tokenizer_file,
133
+ unk_token=unk_token,
134
+ bos_token=bos_token,
135
+ eos_token=eos_token,
136
+ pad_token=pad_token,
137
+ **kwargs,
138
+ )
139
+
140
+ # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
141
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
142
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
143
+ return tuple(files)
tokenizer_config.json ADDED
@@ -0,0 +1,289 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "151643": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "151644": {
15
+ "content": "<|im_start|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "151645": {
23
+ "content": "<|im_end|>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "151646": {
31
+ "content": "<|object_ref_start|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "151647": {
39
+ "content": "<|object_ref_end|>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "151648": {
47
+ "content": "<|box_start|>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "151649": {
55
+ "content": "<|box_end|>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": true
61
+ },
62
+ "151650": {
63
+ "content": "<|quad_start|>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": true
69
+ },
70
+ "151651": {
71
+ "content": "<|quad_end|>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": true
77
+ },
78
+ "151652": {
79
+ "content": "<|vision_start|>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": true
85
+ },
86
+ "151653": {
87
+ "content": "<|vision_end|>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": true
93
+ },
94
+ "151654": {
95
+ "content": "<|vision_pad|>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": true
101
+ },
102
+ "151655": {
103
+ "content": "<|image_pad|>",
104
+ "lstrip": false,
105
+ "normalized": false,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": true
109
+ },
110
+ "151656": {
111
+ "content": "<|video_pad|>",
112
+ "lstrip": false,
113
+ "normalized": false,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": true
117
+ },
118
+ "151657": {
119
+ "content": "<tool_call>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "151658": {
127
+ "content": "</tool_call>",
128
+ "lstrip": false,
129
+ "normalized": false,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "151659": {
135
+ "content": "<|fim_prefix|>",
136
+ "lstrip": false,
137
+ "normalized": false,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "151660": {
143
+ "content": "<|fim_middle|>",
144
+ "lstrip": false,
145
+ "normalized": false,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "151661": {
151
+ "content": "<|fim_suffix|>",
152
+ "lstrip": false,
153
+ "normalized": false,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "151662": {
159
+ "content": "<|fim_pad|>",
160
+ "lstrip": false,
161
+ "normalized": false,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "151663": {
167
+ "content": "<|repo_name|>",
168
+ "lstrip": false,
169
+ "normalized": false,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "151664": {
175
+ "content": "<|file_sep|>",
176
+ "lstrip": false,
177
+ "normalized": false,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "151665": {
183
+ "content": "<img>",
184
+ "lstrip": false,
185
+ "normalized": false,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": true
189
+ },
190
+ "151666": {
191
+ "content": "</img>",
192
+ "lstrip": false,
193
+ "normalized": false,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": true
197
+ },
198
+ "151667": {
199
+ "content": "<IMG_CONTEXT>",
200
+ "lstrip": false,
201
+ "normalized": false,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": true
205
+ },
206
+ "151668": {
207
+ "content": "<quad>",
208
+ "lstrip": false,
209
+ "normalized": false,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": true
213
+ },
214
+ "151669": {
215
+ "content": "</quad>",
216
+ "lstrip": false,
217
+ "normalized": false,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": true
221
+ },
222
+ "151670": {
223
+ "content": "<ref>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ },
230
+ "151671": {
231
+ "content": "</ref>",
232
+ "lstrip": false,
233
+ "normalized": false,
234
+ "rstrip": false,
235
+ "single_word": false,
236
+ "special": true
237
+ },
238
+ "151672": {
239
+ "content": "<box>",
240
+ "lstrip": false,
241
+ "normalized": false,
242
+ "rstrip": false,
243
+ "single_word": false,
244
+ "special": true
245
+ },
246
+ "151673": {
247
+ "content": "</box>",
248
+ "lstrip": false,
249
+ "normalized": false,
250
+ "rstrip": false,
251
+ "single_word": false,
252
+ "special": true
253
+ }
254
+ },
255
+ "additional_special_tokens": [
256
+ "<|im_start|>",
257
+ "<|im_end|>",
258
+ "<|object_ref_start|>",
259
+ "<|object_ref_end|>",
260
+ "<|box_start|>",
261
+ "<|box_end|>",
262
+ "<|quad_start|>",
263
+ "<|quad_end|>",
264
+ "<|vision_start|>",
265
+ "<|vision_end|>",
266
+ "<|vision_pad|>",
267
+ "<|image_pad|>",
268
+ "<|video_pad|>",
269
+ "<img>",
270
+ "</img>",
271
+ "<IMG_CONTEXT>",
272
+ "<quad>",
273
+ "</quad>",
274
+ "<ref>",
275
+ "</ref>",
276
+ "<box>",
277
+ "</box>"
278
+ ],
279
+ "bos_token": null,
280
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
281
+ "clean_up_tokenization_spaces": false,
282
+ "eos_token": "<|im_end|>",
283
+ "errors": "replace",
284
+ "model_max_length": 16384,
285
+ "pad_token": "<|endoftext|>",
286
+ "split_special_tokens": false,
287
+ "tokenizer_class": "Qwen2Tokenizer",
288
+ "unk_token": null
289
+ }
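The `chat_template` above is the standard Qwen2-style ChatML template, so `apply_chat_template` wraps every turn in `<|im_start|>`/`<|im_end|>` markers and, with `add_generation_prompt=True`, leaves an open assistant turn. A hedged sketch; the path is a placeholder for a local checkout of this repository, and image placeholders are handled by the model's processing code rather than shown here:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)  # placeholder path

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Describe the image."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Describe the image.<|im_end|>
# <|im_start|>assistant
```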
vocab.json ADDED
The diff for this file is too large to render. See raw diff