Image-Text-to-Text
Safetensors
llava
StarCycle commited on
Commit
5d6d601
1 Parent(s): 7e38877

initial commit

Browse files
llava_internlm2_chat_7b_dinov2_e1_gpu8_finetune.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig,
10
+ AutoImageProcessor, Dinov2Model,
11
+ CLIPImageProcessor, CLIPVisionModel)
12
+
13
+ from xtuner.dataset import LLaVADataset
14
+ from xtuner.dataset.collate_fns import default_collate_fn
15
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
16
+ from xtuner.dataset.samplers import LengthGroupedSampler
17
+ from xtuner.engine import DatasetInfoHook, EvaluateChatHook
18
+ from xtuner.model import LLaVAModel
19
+ from xtuner.utils import PROMPT_TEMPLATE
20
+
21
+ #######################################################################
22
+ # PART 1 Settings #
23
+ #######################################################################
24
+ # Model
25
+ llm_name_or_path = 'internlm/internlm2-chat-7b'
26
+ visual_encoder_name_or_path = 'facebook/dinov2-large'
27
+ # Specify the pretrained pth
28
+ pretrained_pth = './work_dirs/llava_internlm2_chat_7b_dinov2_e1_gpu8_pretrain_copy/epoch_1.pth' # noqa: E501
29
+
30
+ # Data
31
+ data_root = './'
32
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
33
+ image_folder = data_root + 'llava_images/'
34
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
35
+ max_length = int(2048 - (336 / 14)**2)
36
+
37
+ # Scheduler & Optimizer
38
+ batch_size = 16 # per_device
39
+ accumulative_counts = 1
40
+ dataloader_num_workers = 0
41
+ max_epochs = 1
42
+ optim_type = AdamW
43
+ lr = 2e-4
44
+ betas = (0.9, 0.999)
45
+ weight_decay = 0
46
+ max_norm = 1 # grad clip
47
+ warmup_ratio = 0.03
48
+
49
+ # Evaluate the generation performance during the training
50
+ evaluation_freq = 500
51
+ SYSTEM = ''
52
+ evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
53
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
54
+
55
+ #######################################################################
56
+ # PART 2 Model & Tokenizer & Image Processor #
57
+ #######################################################################
58
+ tokenizer = dict(
59
+ type=AutoTokenizer.from_pretrained,
60
+ pretrained_model_name_or_path=llm_name_or_path,
61
+ trust_remote_code=True,
62
+ padding_side='right')
63
+
64
+ image_processor = dict(
65
+ type=AutoImageProcessor.from_pretrained,
66
+ pretrained_model_name_or_path=visual_encoder_name_or_path,
67
+ trust_remote_code=True)
68
+
69
+ model = dict(
70
+ type=LLaVAModel,
71
+ freeze_llm=True,
72
+ freeze_visual_encoder=True,
73
+ pretrained_pth=pretrained_pth,
74
+ llm=dict(
75
+ type=AutoModelForCausalLM.from_pretrained,
76
+ pretrained_model_name_or_path=llm_name_or_path,
77
+ trust_remote_code=True,
78
+ torch_dtype=torch.float16,
79
+ quantization_config=dict(
80
+ type=BitsAndBytesConfig,
81
+ load_in_4bit=True,
82
+ load_in_8bit=False,
83
+ llm_int8_threshold=6.0,
84
+ llm_int8_has_fp16_weight=False,
85
+ bnb_4bit_compute_dtype=torch.float16,
86
+ bnb_4bit_use_double_quant=True,
87
+ bnb_4bit_quant_type='nf4')),
88
+ llm_lora=dict(
89
+ type=LoraConfig,
90
+ r=512,
91
+ lora_alpha=256,
92
+ lora_dropout=0.05,
93
+ bias='none',
94
+ task_type='CAUSAL_LM'),
95
+ visual_encoder=dict(
96
+ type=Dinov2Model.from_pretrained,
97
+ pretrained_model_name_or_path=visual_encoder_name_or_path),
98
+ visual_encoder_lora=dict(
99
+ type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05, bias='none'))
100
+
101
+ #######################################################################
102
+ # PART 3 Dataset & Dataloader #
103
+ #######################################################################
104
+ llava_dataset = dict(
105
+ type=LLaVADataset,
106
+ data_path=data_path,
107
+ image_folder=image_folder,
108
+ tokenizer=tokenizer,
109
+ image_processor=image_processor,
110
+ dataset_map_fn=llava_map_fn,
111
+ template_map_fn=dict(
112
+ type=template_map_fn_factory, template=prompt_template),
113
+ max_length=max_length,
114
+ pad_image_to_square=True)
115
+
116
+ train_dataloader = dict(
117
+ batch_size=batch_size,
118
+ num_workers=dataloader_num_workers,
119
+ dataset=llava_dataset,
120
+ sampler=dict(
121
+ type=LengthGroupedSampler,
122
+ length_property='modality_length',
123
+ per_device_batch_size=batch_size * accumulative_counts),
124
+ collate_fn=dict(type=default_collate_fn))
125
+
126
+ #######################################################################
127
+ # PART 4 Scheduler & Optimizer #
128
+ #######################################################################
129
+ # optimizer
130
+ optim_wrapper = dict(
131
+ type=AmpOptimWrapper,
132
+ optimizer=dict(
133
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
134
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
135
+ accumulative_counts=accumulative_counts,
136
+ loss_scale='dynamic',
137
+ dtype='float16')
138
+
139
+ # learning policy
140
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
141
+ param_scheduler = [
142
+ dict(
143
+ type=LinearLR,
144
+ start_factor=1e-5,
145
+ by_epoch=True,
146
+ begin=0,
147
+ end=warmup_ratio * max_epochs,
148
+ convert_to_iter_based=True),
149
+ dict(
150
+ type=CosineAnnealingLR,
151
+ eta_min=0.0,
152
+ by_epoch=True,
153
+ begin=warmup_ratio * max_epochs,
154
+ T_max=max_epochs,
155
+ convert_to_iter_based=True)
156
+ ]
157
+
158
+ # train, val, test setting
159
+ train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
160
+
161
+ #######################################################################
162
+ # PART 5 Runtime #
163
+ #######################################################################
164
+ # Log the dialogue periodically during the training process, optional
165
+ custom_hooks = [
166
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
167
+ dict(
168
+ type=EvaluateChatHook,
169
+ tokenizer=tokenizer,
170
+ image_processor=image_processor,
171
+ every_n_iters=evaluation_freq,
172
+ evaluation_inputs=evaluation_inputs,
173
+ evaluation_images=evaluation_images,
174
+ system=SYSTEM,
175
+ prompt_template=prompt_template)
176
+ ]
177
+
178
+ # configure default hooks
179
+ default_hooks = dict(
180
+ # record the time of every iteration.
181
+ timer=dict(type=IterTimerHook),
182
+ # print log every 100 iterations.
183
+ logger=dict(type=LoggerHook, interval=10),
184
+ # enable the parameter scheduler.
185
+ param_scheduler=dict(type=ParamSchedulerHook),
186
+ # save checkpoint per epoch.
187
+ checkpoint=dict(type=CheckpointHook, interval=1),
188
+ # set sampler seed in distributed evrionment.
189
+ sampler_seed=dict(type=DistSamplerSeedHook),
190
+ )
191
+
192
+ # configure environment
193
+ env_cfg = dict(
194
+ # whether to enable cudnn benchmark
195
+ cudnn_benchmark=False,
196
+ # set multi process parameters
197
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
198
+ # set distributed parameters
199
+ dist_cfg=dict(backend='nccl'),
200
+ )
201
+
202
+ # set visualizer
203
+ from mmengine.visualization import Visualizer, TensorboardVisBackend
204
+ visualizer = dict(
205
+ type=Visualizer,
206
+ vis_backends=[dict(type=TensorboardVisBackend)]
207
+ )
208
+
209
+ # set log level
210
+ log_level = 'INFO'
211
+
212
+ # load from which checkpoint
213
+ load_from = None
214
+
215
+ # whether to resume training from the loaded checkpoint
216
+ resume = False
217
+
218
+ # Defaults to use random seed and disable `deterministic`
219
+ randomness = dict(seed=None, deterministic=False)
llava_internlm2_chat_7b_dinov2_e1_gpu8_pretrain.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, AutoImageProcessor,
10
+ Dinov2Model)
11
+
12
+ from xtuner.dataset import LLaVADataset
13
+ from xtuner.dataset.collate_fns import default_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
16
+ from xtuner.engine.runner import TrainLoop
17
+ from xtuner.model import LLaVAModel
18
+ from xtuner.utils import PROMPT_TEMPLATE
19
+
20
+ #######################################################################
21
+ # PART 1 Settings #
22
+ #######################################################################
23
+ # Model
24
+ llm_name_or_path = 'internlm/internlm2-chat-7b'
25
+ visual_encoder_name_or_path = 'facebook/dinov2-large'
26
+
27
+ # Data
28
+ data_root = './data/llava_data/'
29
+ data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
30
+ image_folder = data_root + 'LLaVA-Pretrain/images'
31
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
32
+ max_length = int(2048 - (336 / 14)**2)
33
+
34
+ # Scheduler & Optimizer
35
+ batch_size = 32 # per_device
36
+ accumulative_counts = 1
37
+ dataloader_num_workers = 0
38
+ max_epochs = 1
39
+ optim_type = AdamW
40
+ lr = 1e-3
41
+ betas = (0.9, 0.999)
42
+ weight_decay = 0
43
+ max_norm = 1 # grad clip
44
+ warmup_ratio = 0.03
45
+
46
+ # Save
47
+ save_steps = 500
48
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
49
+
50
+ # Evaluate the generation performance during the training
51
+ evaluation_freq = 500
52
+ SYSTEM = ''
53
+ evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
54
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
55
+
56
+ #######################################################################
57
+ # PART 2 Model & Tokenizer & Image Processor #
58
+ #######################################################################
59
+ tokenizer = dict(
60
+ type=AutoTokenizer.from_pretrained,
61
+ pretrained_model_name_or_path=llm_name_or_path,
62
+ trust_remote_code=True,
63
+ padding_side='right')
64
+
65
+ image_processor = dict(
66
+ type=AutoImageProcessor.from_pretrained,
67
+ pretrained_model_name_or_path=visual_encoder_name_or_path,
68
+ trust_remote_code=True)
69
+
70
+ model = dict(
71
+ type=LLaVAModel,
72
+ freeze_llm=True,
73
+ freeze_visual_encoder=True,
74
+ llm=dict(
75
+ type=AutoModelForCausalLM.from_pretrained,
76
+ pretrained_model_name_or_path=llm_name_or_path,
77
+ trust_remote_code=True,
78
+ torch_dtype=torch.float16,
79
+ quantization_config=dict(
80
+ type=BitsAndBytesConfig,
81
+ load_in_4bit=True,
82
+ load_in_8bit=False,
83
+ llm_int8_threshold=6.0,
84
+ llm_int8_has_fp16_weight=False,
85
+ bnb_4bit_compute_dtype=torch.float16,
86
+ bnb_4bit_use_double_quant=True,
87
+ bnb_4bit_quant_type='nf4')),
88
+ visual_encoder=dict(
89
+ type=Dinov2Model.from_pretrained,
90
+ pretrained_model_name_or_path=visual_encoder_name_or_path))
91
+
92
+ #######################################################################
93
+ # PART 3 Dataset & Dataloader #
94
+ #######################################################################
95
+ llava_dataset = dict(
96
+ type=LLaVADataset,
97
+ data_path=data_path,
98
+ image_folder=image_folder,
99
+ tokenizer=tokenizer,
100
+ image_processor=image_processor,
101
+ dataset_map_fn=llava_map_fn,
102
+ template_map_fn=dict(
103
+ type=template_map_fn_factory, template=prompt_template),
104
+ max_length=max_length,
105
+ pad_image_to_square=False)
106
+
107
+ train_dataloader = dict(
108
+ batch_size=batch_size,
109
+ num_workers=dataloader_num_workers,
110
+ dataset=llava_dataset,
111
+ sampler=dict(type=DefaultSampler, shuffle=True),
112
+ collate_fn=dict(type=default_collate_fn))
113
+
114
+ #######################################################################
115
+ # PART 4 Scheduler & Optimizer #
116
+ #######################################################################
117
+ # optimizer
118
+ optim_wrapper = dict(
119
+ type=AmpOptimWrapper,
120
+ optimizer=dict(
121
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
122
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
123
+ accumulative_counts=accumulative_counts,
124
+ loss_scale='dynamic',
125
+ dtype='float16')
126
+
127
+ # learning policy
128
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
129
+ param_scheduler = [
130
+ dict(
131
+ type=LinearLR,
132
+ start_factor=1e-5,
133
+ by_epoch=True,
134
+ begin=0,
135
+ end=warmup_ratio * max_epochs,
136
+ convert_to_iter_based=True),
137
+ dict(
138
+ type=CosineAnnealingLR,
139
+ eta_min=0.0,
140
+ by_epoch=True,
141
+ begin=warmup_ratio * max_epochs,
142
+ end=max_epochs,
143
+ convert_to_iter_based=True)
144
+ ]
145
+
146
+ # train, val, test setting
147
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
148
+
149
+ #######################################################################
150
+ # PART 5 Runtime #
151
+ #######################################################################
152
+ # Log the dialogue periodically during the training process, optional
153
+ custom_hooks = [
154
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
155
+ dict(
156
+ type=EvaluateChatHook,
157
+ tokenizer=tokenizer,
158
+ image_processor=image_processor,
159
+ every_n_iters=evaluation_freq,
160
+ evaluation_inputs=evaluation_inputs,
161
+ evaluation_images=evaluation_images,
162
+ system=SYSTEM,
163
+ prompt_template=prompt_template)
164
+ ]
165
+
166
+ # configure default hooks
167
+ default_hooks = dict(
168
+ # record the time of every iteration.
169
+ timer=dict(type=IterTimerHook),
170
+ # print log every 10 iterations.
171
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
172
+ # enable the parameter scheduler.
173
+ param_scheduler=dict(type=ParamSchedulerHook),
174
+ # save checkpoint per `save_steps`.
175
+ checkpoint=dict(
176
+ type=CheckpointHook,
177
+ by_epoch=False,
178
+ interval=save_steps,
179
+ max_keep_ckpts=save_total_limit),
180
+ # set sampler seed in distributed evrionment.
181
+ sampler_seed=dict(type=DistSamplerSeedHook),
182
+ )
183
+
184
+ # configure environment
185
+ env_cfg = dict(
186
+ # whether to enable cudnn benchmark
187
+ cudnn_benchmark=False,
188
+ # set multi process parameters
189
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
190
+ # set distributed parameters
191
+ dist_cfg=dict(backend='nccl'),
192
+ )
193
+
194
+ # set visualizer
195
+ from mmengine.visualization import Visualizer, TensorboardVisBackend
196
+ visualizer = dict(
197
+ type=Visualizer,
198
+ vis_backends=[dict(type=TensorboardVisBackend)]
199
+ )
200
+
201
+ # set log level
202
+ log_level = 'INFO'
203
+
204
+ # load from which checkpoint
205
+ load_from = None
206
+
207
+ # whether to resume training from the loaded checkpoint
208
+ resume = False
209
+
210
+ # Defaults to use random seed and disable `deterministic`
211
+ randomness = dict(seed=None, deterministic=False)
212
+
213
+ # set log processor
214
+ log_processor = dict(by_epoch=False)
lora_and_projector/llm_adapter/.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ adapter_model.safetensors filter=lfs diff=lfs merge=lfs -text
lora_and_projector/llm_adapter/README.md ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: peft
3
+ base_model: ../internlm2-chat-7b/
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+
201
+
202
+ ### Framework versions
203
+
204
+ - PEFT 0.7.1
lora_and_projector/llm_adapter/adapter_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "../internlm2-chat-7b/",
5
+ "bias": "none",
6
+ "fan_in_fan_out": false,
7
+ "inference_mode": true,
8
+ "init_lora_weights": true,
9
+ "layers_pattern": null,
10
+ "layers_to_transform": null,
11
+ "loftq_config": {},
12
+ "lora_alpha": 256,
13
+ "lora_dropout": 0.05,
14
+ "megatron_config": null,
15
+ "megatron_core": "megatron.core",
16
+ "modules_to_save": null,
17
+ "peft_type": "LORA",
18
+ "r": 512,
19
+ "rank_pattern": {},
20
+ "revision": null,
21
+ "target_modules": [
22
+ "output",
23
+ "w3",
24
+ "w2",
25
+ "w1",
26
+ "wo",
27
+ "wqkv"
28
+ ],
29
+ "task_type": "CAUSAL_LM"
30
+ }
lora_and_projector/llm_adapter/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c8f4f8c4b7d1e163de56982a1e5d97755837ab52846c3e3da5dee107d6827f1
3
+ size 2514922648
lora_and_projector/projector/.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
lora_and_projector/projector/config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ProjectorModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_projector.ProjectorConfig",
7
+ "AutoModel": "modeling_projector.ProjectorModel"
8
+ },
9
+ "bias": true,
10
+ "depth": 2,
11
+ "hidden_act": "gelu",
12
+ "llm_hidden_size": 4096,
13
+ "model_type": "projector",
14
+ "torch_dtype": "float32",
15
+ "transformers_version": "4.37.1",
16
+ "visual_hidden_size": 1024
17
+ }
lora_and_projector/projector/configuration_projector.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class ProjectorConfig(PretrainedConfig):
6
+ model_type = 'projector'
7
+ _auto_class = 'AutoConfig'
8
+
9
+ def __init__(
10
+ self,
11
+ visual_hidden_size=4096,
12
+ llm_hidden_size=4096,
13
+ depth=2,
14
+ hidden_act='gelu',
15
+ bias=True,
16
+ **kwargs,
17
+ ):
18
+ self.visual_hidden_size = visual_hidden_size
19
+ self.llm_hidden_size = llm_hidden_size
20
+ self.depth = depth
21
+ self.hidden_act = hidden_act
22
+ self.bias = bias
23
+ super().__init__(**kwargs)
lora_and_projector/projector/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4dca1653bb4b6d9024d8c383caf196304a84ab8d115022e320ec4f7a9f46b6be
3
+ size 83919216
lora_and_projector/projector/modeling_projector.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel
5
+ from transformers.activations import ACT2FN
6
+
7
+ from .configuration_projector import ProjectorConfig
8
+
9
+
10
+ class ProjectorModel(PreTrainedModel):
11
+ _auto_class = 'AutoModel'
12
+ config_class = ProjectorConfig
13
+ base_model_prefix = 'model'
14
+ supports_gradient_checkpointing = True
15
+
16
+ def __init__(self, config: ProjectorConfig) -> None:
17
+ super().__init__(config)
18
+ self.gradient_checkpointing = False
19
+
20
+ modules = [
21
+ nn.Linear(
22
+ config.visual_hidden_size,
23
+ config.llm_hidden_size,
24
+ bias=config.bias)
25
+ ]
26
+ for _ in range(1, config.depth):
27
+ modules.append(ACT2FN[config.hidden_act])
28
+ modules.append(
29
+ nn.Linear(
30
+ config.llm_hidden_size,
31
+ config.llm_hidden_size,
32
+ bias=config.bias))
33
+ self.model = nn.Sequential(*modules)
34
+
35
+ def enable_input_require_grads(self):
36
+
37
+ def make_inputs_require_grad(module, input, output):
38
+ output.requires_grad_(True)
39
+
40
+ self.model.register_forward_hook(make_inputs_require_grad)
41
+
42
+ def _set_gradient_checkpointing(self, module, value=False):
43
+ if isinstance(module, ProjectorModel):
44
+ module.gradient_checkpointing = value
45
+
46
+ def forward(self, x):
47
+ if self.gradient_checkpointing and self.training:
48
+ layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
49
+ else:
50
+ layer_outputs = self.model(x)
51
+ return layer_outputs
lora_and_projector/visual_encoder_adapter/.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ adapter_model.safetensors filter=lfs diff=lfs merge=lfs -text
lora_and_projector/visual_encoder_adapter/README.md ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: peft
3
+ base_model: ../dinov2-large/
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+
201
+
202
+ ### Framework versions
203
+
204
+ - PEFT 0.7.1
lora_and_projector/visual_encoder_adapter/adapter_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": {
4
+ "base_model_class": "Dinov2Model",
5
+ "parent_library": "transformers.models.dinov2.modeling_dinov2"
6
+ },
7
+ "base_model_name_or_path": "../dinov2-large/",
8
+ "bias": "none",
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layers_pattern": null,
13
+ "layers_to_transform": null,
14
+ "loftq_config": {},
15
+ "lora_alpha": 16,
16
+ "lora_dropout": 0.05,
17
+ "megatron_config": null,
18
+ "megatron_core": "megatron.core",
19
+ "modules_to_save": null,
20
+ "peft_type": "LORA",
21
+ "r": 64,
22
+ "rank_pattern": {},
23
+ "revision": null,
24
+ "target_modules": [
25
+ "fc2",
26
+ "fc1",
27
+ "dense",
28
+ "key",
29
+ "query",
30
+ "value"
31
+ ],
32
+ "task_type": null
33
+ }
lora_and_projector/visual_encoder_adapter/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59453737c9e0d1fc0354772a1949c5deb9bcf104c8e991f570f811b823666c14
3
+ size 113285920
lora_and_projector/xtuner_config.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig,
10
+ AutoImageProcessor, Dinov2Model,
11
+ CLIPImageProcessor, CLIPVisionModel)
12
+
13
+ from xtuner.dataset import LLaVADataset
14
+ from xtuner.dataset.collate_fns import default_collate_fn
15
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
16
+ from xtuner.dataset.samplers import LengthGroupedSampler
17
+ from xtuner.engine import DatasetInfoHook, EvaluateChatHook
18
+ from xtuner.model import LLaVAModel
19
+ from xtuner.utils import PROMPT_TEMPLATE
20
+
21
+ #######################################################################
22
+ # PART 1 Settings #
23
+ #######################################################################
24
+ # Model
25
+ llm_name_or_path = '../internlm2-chat-7b/'
26
+ visual_encoder_name_or_path = '../dinov2-large/'
27
+ # Specify the pretrained pth
28
+ pretrained_pth = './work_dirs/llava_internlm2_chat_7b_clip_vit_large_p14_336_e1_gpu8_pretrain_copy/epoch_1.pth' # noqa: E501
29
+
30
+ # Data
31
+ data_root = './'
32
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
33
+ image_folder = data_root + 'llava_images/'
34
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
35
+ max_length = int(2048 - (336 / 14)**2)
36
+
37
+ # Scheduler & Optimizer
38
+ batch_size = 16 # per_device
39
+ accumulative_counts = 4
40
+ dataloader_num_workers = 4
41
+ max_epochs = 1
42
+ optim_type = AdamW
43
+ lr = 2e-4
44
+ betas = (0.9, 0.999)
45
+ weight_decay = 0
46
+ max_norm = 1 # grad clip
47
+ warmup_ratio = 0.03
48
+
49
+ # Evaluate the generation performance during the training
50
+ evaluation_freq = 500
51
+ SYSTEM = ''
52
+ evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
53
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
54
+
55
+ #######################################################################
56
+ # PART 2 Model & Tokenizer & Image Processor #
57
+ #######################################################################
58
+ tokenizer = dict(
59
+ type=AutoTokenizer.from_pretrained,
60
+ pretrained_model_name_or_path=llm_name_or_path,
61
+ trust_remote_code=True,
62
+ padding_side='right')
63
+
64
+ image_processor = dict(
65
+ type=AutoImageProcessor.from_pretrained,
66
+ pretrained_model_name_or_path=visual_encoder_name_or_path,
67
+ trust_remote_code=True)
68
+
69
+ model = dict(
70
+ type=LLaVAModel,
71
+ freeze_llm=True,
72
+ freeze_visual_encoder=True,
73
+ pretrained_pth=pretrained_pth,
74
+ llm=dict(
75
+ type=AutoModelForCausalLM.from_pretrained,
76
+ pretrained_model_name_or_path=llm_name_or_path,
77
+ trust_remote_code=True,
78
+ torch_dtype=torch.float16,
79
+ quantization_config=dict(
80
+ type=BitsAndBytesConfig,
81
+ load_in_4bit=True,
82
+ load_in_8bit=False,
83
+ llm_int8_threshold=6.0,
84
+ llm_int8_has_fp16_weight=False,
85
+ bnb_4bit_compute_dtype=torch.float16,
86
+ bnb_4bit_use_double_quant=True,
87
+ bnb_4bit_quant_type='nf4')),
88
+ llm_lora=dict(
89
+ type=LoraConfig,
90
+ r=512,
91
+ lora_alpha=256,
92
+ lora_dropout=0.05,
93
+ bias='none',
94
+ task_type='CAUSAL_LM'),
95
+ visual_encoder=dict(
96
+ type=Dinov2Model.from_pretrained,
97
+ pretrained_model_name_or_path=visual_encoder_name_or_path),
98
+ visual_encoder_lora=dict(
99
+ type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05, bias='none'))
100
+
101
+ #######################################################################
102
+ # PART 3 Dataset & Dataloader #
103
+ #######################################################################
104
+ llava_dataset = dict(
105
+ type=LLaVADataset,
106
+ data_path=data_path,
107
+ image_folder=image_folder,
108
+ tokenizer=tokenizer,
109
+ image_processor=image_processor,
110
+ dataset_map_fn=llava_map_fn,
111
+ template_map_fn=dict(
112
+ type=template_map_fn_factory, template=prompt_template),
113
+ max_length=max_length,
114
+ pad_image_to_square=True)
115
+
116
+ train_dataloader = dict(
117
+ batch_size=batch_size,
118
+ num_workers=dataloader_num_workers,
119
+ dataset=llava_dataset,
120
+ sampler=dict(
121
+ type=LengthGroupedSampler,
122
+ length_property='modality_length',
123
+ per_device_batch_size=batch_size * accumulative_counts),
124
+ collate_fn=dict(type=default_collate_fn))
125
+
126
+ #######################################################################
127
+ # PART 4 Scheduler & Optimizer #
128
+ #######################################################################
129
+ # optimizer
130
+ optim_wrapper = dict(
131
+ type=AmpOptimWrapper,
132
+ optimizer=dict(
133
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
134
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
135
+ accumulative_counts=accumulative_counts,
136
+ loss_scale='dynamic',
137
+ dtype='float16')
138
+
139
+ # learning policy
140
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
141
+ param_scheduler = [
142
+ dict(
143
+ type=LinearLR,
144
+ start_factor=1e-5,
145
+ by_epoch=True,
146
+ begin=0,
147
+ end=warmup_ratio * max_epochs,
148
+ convert_to_iter_based=True),
149
+ dict(
150
+ type=CosineAnnealingLR,
151
+ eta_min=0.0,
152
+ by_epoch=True,
153
+ begin=warmup_ratio * max_epochs,
154
+ T_max=max_epochs,
155
+ convert_to_iter_based=True)
156
+ ]
157
+
158
+ # train, val, test setting
159
+ train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
160
+
161
+ #######################################################################
162
+ # PART 5 Runtime #
163
+ #######################################################################
164
+ # Log the dialogue periodically during the training process, optional
165
+ custom_hooks = [
166
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
167
+ dict(
168
+ type=EvaluateChatHook,
169
+ tokenizer=tokenizer,
170
+ image_processor=image_processor,
171
+ every_n_iters=evaluation_freq,
172
+ evaluation_inputs=evaluation_inputs,
173
+ evaluation_images=evaluation_images,
174
+ system=SYSTEM,
175
+ prompt_template=prompt_template)
176
+ ]
177
+
178
+ # configure default hooks
179
+ default_hooks = dict(
180
+ # record the time of every iteration.
181
+ timer=dict(type=IterTimerHook),
182
+ # print log every 100 iterations.
183
+ logger=dict(type=LoggerHook, interval=10),
184
+ # enable the parameter scheduler.
185
+ param_scheduler=dict(type=ParamSchedulerHook),
186
+ # save checkpoint per epoch.
187
+ checkpoint=dict(type=CheckpointHook, interval=1),
188
+ # set sampler seed in distributed evrionment.
189
+ sampler_seed=dict(type=DistSamplerSeedHook),
190
+ )
191
+
192
+ # configure environment
193
+ env_cfg = dict(
194
+ # whether to enable cudnn benchmark
195
+ cudnn_benchmark=False,
196
+ # set multi process parameters
197
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
198
+ # set distributed parameters
199
+ dist_cfg=dict(backend='nccl'),
200
+ )
201
+
202
+ # set visualizer
203
+ from mmengine.visualization import Visualizer, TensorboardVisBackend
204
+ visualizer = dict(
205
+ type=Visualizer,
206
+ vis_backends=[dict(type=TensorboardVisBackend)]
207
+ )
208
+
209
+ # set log level
210
+ log_level = 'INFO'
211
+
212
+ # load from which checkpoint
213
+ load_from = None
214
+
215
+ # whether to resume training from the loaded checkpoint
216
+ resume = False
217
+
218
+ # Defaults to use random seed and disable `deterministic`
219
+ randomness = dict(seed=None, deterministic=False)
modified_xtuner_code/xtuner/tools/chat.py ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os
4
+ import os.path as osp
5
+ import re
6
+ import sys
7
+
8
+ import torch
9
+ from huggingface_hub import snapshot_download
10
+ from peft import PeftModel
11
+ from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
12
+ BitsAndBytesConfig, AutoImageProcessor,
13
+ Dinov2Model, GenerationConfig)
14
+
15
+ from xtuner.dataset.utils import expand2square, load_image
16
+ from xtuner.model.utils import prepare_inputs_labels_for_multimodal
17
+ from xtuner.tools.utils import get_stop_criteria, get_streamer
18
+ from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
19
+ PROMPT_TEMPLATE, SYSTEM_TEMPLATE)
20
+
21
+ TORCH_DTYPE_MAP = dict(
22
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
23
+
24
+
25
+ def remove_prefix(state_dict, prefix):
26
+ new_state_dict = {}
27
+ for key, value in state_dict.items():
28
+ if key.startswith(prefix):
29
+ new_key = key[len(prefix):]
30
+ new_state_dict[new_key] = value
31
+ else:
32
+ new_state_dict[key] = value
33
+ return new_state_dict
34
+
35
+
36
+ def parse_args():
37
+ parser = argparse.ArgumentParser(description='Chat with a HF model')
38
+ parser.add_argument(
39
+ 'model_name_or_path', help='Hugging Face model name or path')
40
+ adapter_group = parser.add_mutually_exclusive_group()
41
+ adapter_group.add_argument(
42
+ '--adapter', default=None, help='adapter name or path')
43
+ adapter_group.add_argument(
44
+ '--llava', default=None, help='llava name or path')
45
+ parser.add_argument(
46
+ '--visual-encoder', default=None, help='visual encoder name or path')
47
+ parser.add_argument(
48
+ '--visual-select-layer', default=-2, help='visual select layer')
49
+ parser.add_argument('--image', default=None, help='image')
50
+ parser.add_argument(
51
+ '--torch-dtype',
52
+ default='fp16',
53
+ choices=TORCH_DTYPE_MAP.keys(),
54
+ help='Override the default `torch.dtype` and load the model under '
55
+ 'a specific `dtype`.')
56
+ parser.add_argument(
57
+ '--prompt-template',
58
+ choices=PROMPT_TEMPLATE.keys(),
59
+ default=None,
60
+ help='Specify a prompt template')
61
+ system_group = parser.add_mutually_exclusive_group()
62
+ system_group.add_argument(
63
+ '--system', default=None, help='Specify the system text')
64
+ system_group.add_argument(
65
+ '--system-template',
66
+ choices=SYSTEM_TEMPLATE.keys(),
67
+ default=None,
68
+ help='Specify a system template')
69
+ parser.add_argument(
70
+ '--bits',
71
+ type=int,
72
+ choices=[4, 8, None],
73
+ default=None,
74
+ help='LLM bits')
75
+ parser.add_argument(
76
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
77
+ parser.add_argument(
78
+ '--with-plugins',
79
+ nargs='+',
80
+ choices=['calculate', 'solve', 'search'],
81
+ help='Specify plugins to use')
82
+ parser.add_argument(
83
+ '--no-streamer', action='store_true', help='Whether to with streamer')
84
+ parser.add_argument(
85
+ '--lagent', action='store_true', help='Whether to use lagent')
86
+ parser.add_argument(
87
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
88
+ parser.add_argument(
89
+ '--offload-folder',
90
+ default=None,
91
+ help='The folder in which to offload the model weights (or where the '
92
+ 'model weights are already offloaded).')
93
+ parser.add_argument(
94
+ '--max-new-tokens',
95
+ type=int,
96
+ default=2048,
97
+ help='Maximum number of new tokens allowed in generated text')
98
+ parser.add_argument(
99
+ '--temperature',
100
+ type=float,
101
+ default=0.1,
102
+ help='The value used to modulate the next token probabilities.')
103
+ parser.add_argument(
104
+ '--top-k',
105
+ type=int,
106
+ default=40,
107
+ help='The number of highest probability vocabulary tokens to '
108
+ 'keep for top-k-filtering.')
109
+ parser.add_argument(
110
+ '--top-p',
111
+ type=float,
112
+ default=0.75,
113
+ help='If set to float < 1, only the smallest set of most probable '
114
+ 'tokens with probabilities that add up to top_p or higher are '
115
+ 'kept for generation.')
116
+ parser.add_argument(
117
+ '--repetition-penalty',
118
+ type=float,
119
+ default=1.0,
120
+ help='The parameter for repetition penalty. 1.0 means no penalty.')
121
+ parser.add_argument(
122
+ '--seed',
123
+ type=int,
124
+ default=0,
125
+ help='Random seed for reproducible text generation')
126
+ args = parser.parse_args()
127
+ return args
128
+
129
+
130
+ def get_input():
131
+ """Helper function for getting input from users."""
132
+ sentinel = '' # ends when this string is seen
133
+ result = None
134
+ while result is None:
135
+ print(('\ndouble enter to end input (EXIT: exit chat, '
136
+ 'RESET: reset history) >>> '),
137
+ end='')
138
+ try:
139
+ result = '\n'.join(iter(input, sentinel))
140
+ except UnicodeDecodeError:
141
+ print('Invalid characters detected. Please enter again.')
142
+ return result
143
+
144
+
145
+ def main():
146
+ args = parse_args()
147
+ torch.manual_seed(args.seed)
148
+
149
+ # build llm
150
+ quantization_config = None
151
+ load_in_8bit = False
152
+ if args.bits == 4:
153
+ quantization_config = BitsAndBytesConfig(
154
+ load_in_4bit=True,
155
+ load_in_8bit=False,
156
+ llm_int8_threshold=6.0,
157
+ llm_int8_has_fp16_weight=False,
158
+ bnb_4bit_compute_dtype=torch.float16,
159
+ bnb_4bit_use_double_quant=True,
160
+ bnb_4bit_quant_type='nf4')
161
+ elif args.bits == 8:
162
+ load_in_8bit = True
163
+ model_kwargs = {
164
+ 'quantization_config': quantization_config,
165
+ 'load_in_8bit': load_in_8bit,
166
+ 'device_map': 'auto',
167
+ 'offload_folder': args.offload_folder,
168
+ 'trust_remote_code': True,
169
+ 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
170
+ }
171
+ if args.lagent:
172
+ from lagent.actions import ActionExecutor, GoogleSearch
173
+ from lagent.agents import (CALL_PROTOCOL_CN, FORCE_STOP_PROMPT_CN,
174
+ ReAct, ReActProtocol)
175
+ from lagent.llms import HFTransformerCasualLM
176
+
177
+ try:
178
+ SERPER_API_KEY = os.environ['SERPER_API_KEY']
179
+ except Exception:
180
+ print('Please obtain the `SERPER_API_KEY` from https://serper.dev '
181
+ 'and set it using `export SERPER_API_KEY=xxx`.')
182
+ sys.exit(1)
183
+
184
+ model_kwargs.pop('trust_remote_code')
185
+ llm = HFTransformerCasualLM(
186
+ args.model_name_or_path, model_kwargs=model_kwargs)
187
+ if args.adapter is not None:
188
+ print(f'Loading adapter from {args.adapter}...')
189
+ llm.model = PeftModel.from_pretrained(
190
+ llm.model,
191
+ args.adapter,
192
+ offload_folder=args.offload_folder,
193
+ trust_remote_code=True)
194
+ search_tool = GoogleSearch(api_key=SERPER_API_KEY)
195
+ chatbot = ReAct(
196
+ llm=llm,
197
+ action_executor=ActionExecutor(actions=[search_tool]),
198
+ protocol=ReActProtocol(
199
+ call_protocol=CALL_PROTOCOL_CN,
200
+ force_stop=FORCE_STOP_PROMPT_CN))
201
+ while True:
202
+ text = get_input()
203
+ while text.strip() == 'RESET':
204
+ print('Log: History responses have been removed!')
205
+ chatbot._session_history = []
206
+ inputs = ''
207
+ text = get_input()
208
+ if text.strip() == 'EXIT':
209
+ print('Log: Exit!')
210
+ exit(0)
211
+ response = chatbot.chat(text)
212
+ print(response.response)
213
+ else:
214
+ if args.with_plugins is None:
215
+ inner_thoughts_open = False
216
+ calculate_open = False
217
+ solve_open = False
218
+ search_open = False
219
+ else:
220
+ assert args.prompt_template == args.system_template == 'moss_sft'
221
+ from plugins import plugins_api
222
+ inner_thoughts_open = True
223
+ calculate_open = 'calculate' in args.with_plugins
224
+ solve_open = 'solve' in args.with_plugins
225
+ search_open = 'search' in args.with_plugins
226
+ # pre-import for api and model preparation
227
+ if calculate_open:
228
+ from plugins import calculate # noqa: F401
229
+ if solve_open:
230
+ from plugins import solve # noqa: F401
231
+ if search_open:
232
+ from plugins import search # noqa: F401
233
+ # build llm
234
+ llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
235
+ **model_kwargs)
236
+ tokenizer = AutoTokenizer.from_pretrained(
237
+ args.model_name_or_path,
238
+ trust_remote_code=True,
239
+ encode_special_tokens=True)
240
+ print(f'Load LLM from {args.model_name_or_path}')
241
+ if args.adapter is not None:
242
+ llm = PeftModel.from_pretrained(
243
+ llm,
244
+ args.adapter,
245
+ offload_folder=args.offload_folder,
246
+ trust_remote_code=True)
247
+ print(f'Load adapter from {args.adapter}')
248
+ if args.llava is not None:
249
+ llava_path = snapshot_download(
250
+ repo_id=args.llava) if not osp.isdir(
251
+ args.llava) else args.llava
252
+
253
+ # build visual_encoder
254
+ if 'visual_encoder' in os.listdir(llava_path):
255
+ assert args.visual_encoder is None, (
256
+ "Please don't specify the `--visual-encoder` since passed "
257
+ '`--llava` contains a visual encoder!')
258
+ visual_encoder_path = osp.join(llava_path, 'visual_encoder')
259
+ else:
260
+ assert args.visual_encoder is not None, (
261
+ 'Please specify the `--visual-encoder`!')
262
+ visual_encoder_path = args.visual_encoder
263
+ visual_encoder = Dinov2Model.from_pretrained(
264
+ visual_encoder_path,
265
+ torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
266
+ image_processor = AutoImageProcessor.from_pretrained(
267
+ visual_encoder_path)
268
+ print(f'Load visual_encoder from {visual_encoder_path}')
269
+
270
+ # load adapter
271
+ if 'llm_adapter' in os.listdir(llava_path):
272
+ adapter_path = osp.join(llava_path, 'llm_adapter')
273
+ llm = PeftModel.from_pretrained(
274
+ llm,
275
+ adapter_path,
276
+ offload_folder=args.offload_folder,
277
+ trust_remote_code=True)
278
+ print(f'Load LLM adapter from {args.llava}')
279
+ if 'visual_encoder_adapter' in os.listdir(llava_path):
280
+ adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
281
+ visual_encoder = PeftModel.from_pretrained(
282
+ visual_encoder,
283
+ adapter_path,
284
+ offload_folder=args.offload_folder)
285
+ print(f'Load visual_encoder adapter from {args.llava}')
286
+
287
+ # build projector
288
+ projector_path = osp.join(llava_path, 'projector')
289
+ projector = AutoModel.from_pretrained(
290
+ projector_path,
291
+ torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype],
292
+ trust_remote_code=True)
293
+ print(f'Load projector from {args.llava}')
294
+
295
+ projector.cuda()
296
+ projector.eval()
297
+ visual_encoder.cuda()
298
+ visual_encoder.eval()
299
+
300
+ llm.eval()
301
+
302
+ if args.image is not None:
303
+ image = load_image(args.image)
304
+ image = expand2square(
305
+ image, tuple(int(x * 255) for x in image_processor.image_mean))
306
+ image = image_processor.preprocess(
307
+ image, return_tensors='pt')['pixel_values'][0]
308
+ image = image.cuda().unsqueeze(0)
309
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
310
+ pixel_values = projector(
311
+ visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
312
+
313
+ stop_words = args.stop_words
314
+ sep = ''
315
+ if args.prompt_template:
316
+ template = PROMPT_TEMPLATE[args.prompt_template]
317
+ stop_words += template.get('STOP_WORDS', [])
318
+ sep = template.get('SEP', '')
319
+ stop_criteria = get_stop_criteria(
320
+ tokenizer=tokenizer, stop_words=stop_words)
321
+
322
+ if args.no_streamer:
323
+ Streamer = None
324
+ else:
325
+ Streamer = get_streamer(llm)
326
+
327
+ gen_config = GenerationConfig(
328
+ max_new_tokens=args.max_new_tokens,
329
+ do_sample=args.temperature > 0,
330
+ temperature=args.temperature,
331
+ top_p=args.top_p,
332
+ top_k=args.top_k,
333
+ repetition_penalty=args.repetition_penalty,
334
+ eos_token_id=tokenizer.eos_token_id,
335
+ pad_token_id=tokenizer.pad_token_id
336
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
337
+ )
338
+
339
+ n_turn = 0
340
+ inputs = ''
341
+ while True:
342
+ text = get_input()
343
+ while text.strip() == 'RESET':
344
+ print('Log: History responses have been removed!')
345
+ n_turn = 0
346
+ inputs = ''
347
+ text = get_input()
348
+ if text.strip() == 'EXIT':
349
+ print('Log: Exit!')
350
+ exit(0)
351
+
352
+ if args.image is not None and n_turn == 0:
353
+ text = DEFAULT_IMAGE_TOKEN + '\n' + text
354
+
355
+ if args.prompt_template:
356
+ prompt_text = ''
357
+ template = PROMPT_TEMPLATE[args.prompt_template]
358
+ if 'SYSTEM' in template and n_turn == 0:
359
+ system_text = None
360
+ if args.system_template is not None:
361
+ system_text = SYSTEM_TEMPLATE[
362
+ args.system_template].format(
363
+ round=n_turn + 1, bot_name=args.bot_name)
364
+ elif args.system is not None:
365
+ system_text = args.system
366
+ if system_text is not None:
367
+ prompt_text += template['SYSTEM'].format(
368
+ system=system_text,
369
+ round=n_turn + 1,
370
+ bot_name=args.bot_name)
371
+ prompt_text += template['INSTRUCTION'].format(
372
+ input=text, round=n_turn + 1, bot_name=args.bot_name)
373
+ if args.prompt_template == args.system_template == 'moss_sft':
374
+ if not inner_thoughts_open:
375
+ prompt_text.replace('- Inner thoughts: enabled.',
376
+ '- Inner thoughts: disabled.')
377
+ if not calculate_open:
378
+ prompt_text.replace(('- Calculator: enabled. API: '
379
+ 'Calculate(expression)'),
380
+ '- Calculator: disabled.')
381
+ if not solve_open:
382
+ prompt_text.replace(
383
+ '- Equation solver: enabled. API: Solve(equation)',
384
+ '- Equation solver: disabled.')
385
+ if not search_open:
386
+ prompt_text.replace(
387
+ '- Web search: enabled. API: Search(query)',
388
+ '- Web search: disabled.')
389
+ else:
390
+ prompt_text = text
391
+ inputs += prompt_text
392
+ if args.image is None:
393
+ if n_turn == 0:
394
+ ids = tokenizer.encode(inputs, return_tensors='pt')
395
+ else:
396
+ ids = tokenizer.encode(
397
+ inputs, return_tensors='pt', add_special_tokens=False)
398
+ streamer = Streamer(
399
+ tokenizer) if Streamer is not None else None
400
+ if args.with_plugins is not None:
401
+ generate_output = llm.generate(
402
+ inputs=ids.cuda(),
403
+ generation_config=gen_config,
404
+ streamer=streamer,
405
+ stopping_criteria=stop_criteria).cpu()
406
+ generate_output_text = tokenizer.decode(
407
+ generate_output[0][len(ids[0]):])
408
+ if streamer is None:
409
+ end = '' if generate_output_text[-1] == '\n' else '\n'
410
+ print(generate_output_text, end=end)
411
+ pattern = r'<\|Commands\|>:(.*?)<eoc>'
412
+ command_text = ', '.join(
413
+ re.findall(pattern, generate_output_text))
414
+ extent_text = plugins_api(
415
+ command_text,
416
+ calculate_open=calculate_open,
417
+ solve_open=solve_open,
418
+ search_open=search_open)
419
+ end = '' if extent_text[-1] == '\n' else '\n'
420
+ print(extent_text, end=end)
421
+ extent_text_ids = tokenizer.encode(
422
+ extent_text,
423
+ return_tensors='pt',
424
+ add_special_tokens=False)
425
+ new_ids = torch.cat((generate_output, extent_text_ids),
426
+ dim=1)
427
+ new_streamer = Streamer(
428
+ tokenizer) if Streamer is not None else None
429
+ generate_output = llm.generate(
430
+ inputs=new_ids.cuda(),
431
+ generation_config=gen_config,
432
+ streamer=new_streamer,
433
+ stopping_criteria=stop_criteria)
434
+ if streamer is None:
435
+ output_text = tokenizer.decode(
436
+ generate_output[0][len(new_ids[0]):])
437
+ end = '' if output_text[-1] == '\n' else '\n'
438
+ print(output_text, end=end)
439
+ else:
440
+ generate_output = llm.generate(
441
+ inputs=ids.cuda(),
442
+ generation_config=gen_config,
443
+ streamer=streamer,
444
+ stopping_criteria=stop_criteria)
445
+ if streamer is None:
446
+ output_text = tokenizer.decode(
447
+ generate_output[0][len(ids[0]):])
448
+ end = '' if output_text[-1] == '\n' else '\n'
449
+ print(output_text, end=end)
450
+ inputs = tokenizer.decode(generate_output[0])
451
+ else:
452
+ chunk_encode = []
453
+ for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
454
+ if idx == 0 and n_turn == 0:
455
+ cur_encode = tokenizer.encode(chunk)
456
+ else:
457
+ cur_encode = tokenizer.encode(
458
+ chunk, add_special_tokens=False)
459
+ chunk_encode.append(cur_encode)
460
+ assert len(chunk_encode) == 2
461
+ ids = []
462
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
463
+ ids.extend(cur_chunk_encode)
464
+ if idx != len(chunk_encode) - 1:
465
+ ids.append(IMAGE_TOKEN_INDEX)
466
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
467
+ mm_inputs = prepare_inputs_labels_for_multimodal(
468
+ llm=llm, input_ids=ids, pixel_values=pixel_values)
469
+
470
+ streamer = Streamer(
471
+ tokenizer) if Streamer is not None else None
472
+ generate_output = llm.generate(
473
+ **mm_inputs,
474
+ generation_config=gen_config,
475
+ streamer=streamer,
476
+ bos_token_id=tokenizer.bos_token_id,
477
+ stopping_criteria=stop_criteria)
478
+ if streamer is None:
479
+ output_text = tokenizer.decode(generate_output[0])
480
+ end = '' if output_text[-1] == '\n' else '\n'
481
+ print(output_text, end=end)
482
+ inputs += tokenizer.decode(generate_output[0])
483
+ n_turn += 1
484
+ inputs += sep
485
+ if len(generate_output[0]) >= args.max_new_tokens:
486
+ print(
487
+ 'Remove the memory of history responses, since '
488
+ f'it exceeds the length limitation {args.max_new_tokens}.')
489
+ n_turn = 0
490
+ inputs = ''
491
+
492
+
493
+ if __name__ == '__main__':
494
+ main()
modified_xtuner_code/xtuner/tools/mmbench.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import json
4
+ import math
5
+ import os
6
+ import os.path as osp
7
+ import re
8
+ import string
9
+ import time
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+ import tqdm
15
+ from huggingface_hub import snapshot_download
16
+ from mmengine import mkdir_or_exist
17
+ from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
18
+ master_only)
19
+ from mmengine.utils.dl_utils import set_multi_processing
20
+ from peft import PeftModel
21
+ from rich.console import Console
22
+ from rich.table import Table
23
+ from torch.utils.data import Dataset
24
+ from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
25
+ BitsAndBytesConfig, AutoImageProcessor,
26
+ Dinov2Model, GenerationConfig)
27
+
28
+ from xtuner.dataset.utils import decode_base64_to_image, expand2square
29
+ from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
30
+ from xtuner.tools.utils import get_stop_criteria, is_cn_string
31
+ from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
32
+ PROMPT_TEMPLATE)
33
+
34
+ TORCH_DTYPE_MAP = dict(
35
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
36
+
37
+
38
+ def parse_args():
39
+ parser = argparse.ArgumentParser(description='MMBench')
40
+ parser.add_argument(
41
+ 'model_name_or_path', help='Hugging Face model name or path')
42
+ parser.add_argument('--data-path', default=None, help='data path')
43
+ parser.add_argument('--work-dir', help='the dir to save results')
44
+ parser.add_argument('--llava', default=None, help='llava name or path')
45
+ parser.add_argument(
46
+ '--visual-encoder', default=None, help='visual encoder name or path')
47
+ parser.add_argument(
48
+ '--visual-select-layer', default=-2, help='visual select layer')
49
+ parser.add_argument(
50
+ '--prompt-template',
51
+ choices=PROMPT_TEMPLATE.keys(),
52
+ default=None,
53
+ help='Specify a prompt template')
54
+ parser.add_argument(
55
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
56
+ parser.add_argument(
57
+ '--torch-dtype',
58
+ default='fp16',
59
+ choices=TORCH_DTYPE_MAP.keys(),
60
+ help='Override the default `torch.dtype` and load the model under '
61
+ 'a specific `dtype`.')
62
+ parser.add_argument(
63
+ '--bits',
64
+ type=int,
65
+ choices=[4, 8, None],
66
+ default=None,
67
+ help='LLM bits')
68
+ parser.add_argument(
69
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
70
+ parser.add_argument(
71
+ '--offload-folder',
72
+ default=None,
73
+ help='The folder in which to offload the model weights (or where the '
74
+ 'model weights are already offloaded).')
75
+ parser.add_argument(
76
+ '--max-new-tokens',
77
+ type=int,
78
+ default=100,
79
+ help='Maximum number of new tokens allowed in generated text')
80
+ parser.add_argument(
81
+ '--seed',
82
+ type=int,
83
+ default=0,
84
+ help='Random seed for reproducible text generation')
85
+ parser.add_argument(
86
+ '--launcher',
87
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
88
+ default='none',
89
+ help='job launcher')
90
+ args = parser.parse_args()
91
+ return args
92
+
93
+
94
+ @master_only
95
+ def master_print(msg):
96
+ print(msg)
97
+
98
+
99
+ class MMBenchDataset(Dataset):
100
+ ABBRS = {
101
+ 'coarse_perception': 'CP',
102
+ 'finegrained_perception (instance-level)': 'FP-S',
103
+ 'finegrained_perception (cross-instance)': 'FP-C',
104
+ 'logic_reasoning': 'LR',
105
+ 'relation_reasoning': 'RR',
106
+ 'attribute_reasoning': 'AR',
107
+ 'sketch_reasoning': 'Sketch Reasoning',
108
+ 'scenery_building': 'Scenery & Building',
109
+ 'food_clothes': 'Food & Clothes',
110
+ 'historical_figure': 'Historical Figure',
111
+ 'traditional_show': 'Traditional Show',
112
+ 'calligraphy_painting': 'Calligraphy Painting',
113
+ 'cultural_relic': 'Cultural Relic'
114
+ }
115
+
116
+ def __init__(self, data_file):
117
+ self.data_file = data_file
118
+ self.df = pd.read_csv(data_file, sep='\t')
119
+ self.split = 'dev' if 'answer' in self.df.iloc[0].keys() else 'test'
120
+ self.has_l2_category = 'l2-category' in self.df.columns.to_list()
121
+
122
+ def get_image(self, image):
123
+ while len(image) < 16:
124
+ image = self.df[self.df['index'] == int(image)]['image'].values
125
+ assert len(image) == 1
126
+ image = image[0]
127
+ image = decode_base64_to_image(image)
128
+ return image
129
+
130
+ def __len__(self):
131
+ return len(self.df)
132
+
133
+ def __getitem__(self, idx):
134
+ index = self.df.iloc[idx]['index']
135
+ image = self.df.iloc[idx]['image']
136
+ image = self.get_image(image)
137
+ question = self.df.iloc[idx]['question']
138
+ answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[
139
+ 0].keys() else None
140
+ category = self.df.iloc[idx]['category']
141
+
142
+ options = {
143
+ cand: self.load_from_df(idx, cand)
144
+ for cand in string.ascii_uppercase
145
+ if self.load_from_df(idx, cand) is not None
146
+ }
147
+ options_prompt = ''
148
+ for key, item in options.items():
149
+ options_prompt += f'{key}. {item}\n'
150
+
151
+ hint = self.load_from_df(idx, 'hint')
152
+ data = {
153
+ 'img': image,
154
+ 'question': question,
155
+ 'answer': answer,
156
+ 'options': options_prompt,
157
+ 'category': category,
158
+ 'options_dict': options,
159
+ 'index': index,
160
+ 'context': hint,
161
+ }
162
+ if self.has_l2_category:
163
+ data.update({'l2-category': self.df.iloc[idx]['l2-category']})
164
+ return data
165
+
166
+ def load_from_df(self, idx, key):
167
+ if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
168
+ return self.df.iloc[idx][key]
169
+ else:
170
+ return None
171
+
172
+ @master_only
173
+ def eval_result(self, result_df, show=True):
174
+
175
+ def calc_acc(df, group='category'):
176
+ assert group in ['overall', 'category', 'l2-category']
177
+ if group == 'overall':
178
+ res = {'Average': np.mean(df['hit'])}
179
+ else:
180
+ res = {}
181
+ abilities = list(set(df[group]))
182
+ abilities.sort()
183
+ for ab in abilities:
184
+ sub_df = df[df[group] == ab]
185
+ ab = self.ABBRS[ab] if ab in self.ABBRS else ab
186
+ res[ab] = np.mean(sub_df['hit'])
187
+ return res
188
+
189
+ def eval_sub_data(sub_data, answer_map):
190
+ lt = len(sub_data)
191
+ for i in range(lt):
192
+ item = sub_data.iloc[i]
193
+ match = re.search(r'([A-D]+)', item['prediction'])
194
+ pred = match.group(1) if match else ''
195
+ gt = answer_map[item['index']]
196
+ if gt != pred:
197
+ return 0
198
+ return 1
199
+
200
+ def show_result(ret_json):
201
+ show_dict = ret_json.copy()
202
+ table = Table(title=f' MMBench ({self.data_file}) ')
203
+ console = Console()
204
+ table.add_column('Category', justify='left')
205
+ table.add_column('Accuracy (%)', justify='right')
206
+ average = show_dict.pop('Average') * 100
207
+ table.add_row('Average', f'{average:.1f}')
208
+ table.add_section()
209
+ for cat_name, cat_acc in show_dict.items():
210
+ table.add_row(cat_name, f'{cat_acc * 100:.1f}')
211
+ with console.capture() as capture:
212
+ console.print(table, end='')
213
+ print('\n' + capture.get())
214
+ print('Note: Please be cautious if you use the results in papers, '
215
+ "since we don't use ChatGPT as a helper for choice "
216
+ 'extraction')
217
+
218
+ data = result_df.sort_values(by='index')
219
+ data['prediction'] = [str(x) for x in data['prediction']]
220
+ for k in data.keys():
221
+ data[k.lower() if k not in 'ABCD' else k] = data.pop(k)
222
+
223
+ data_main = data[data['index'] < int(1e6)]
224
+ cate_map = {
225
+ i: c
226
+ for i, c in zip(self.df['index'], self.df['category'])
227
+ }
228
+ if self.has_l2_category:
229
+ l2_cate_map = {
230
+ i: c
231
+ for i, c in zip(self.df['index'], self.df['l2-category'])
232
+ }
233
+ answer_map = {
234
+ i: c
235
+ for i, c in zip(self.df['index'], self.df['answer'])
236
+ }
237
+
238
+ lt = len(data_main)
239
+ hit, tot = 0, 0
240
+ result = {}
241
+ for i in range(lt):
242
+ item_main = data_main.iloc[i]
243
+ idx = item_main['index']
244
+ assert idx not in result
245
+ sub_data = data[data['index'] % int(1e6) == idx]
246
+ ret = eval_sub_data(sub_data, answer_map)
247
+ result[idx] = ret
248
+ hit += ret
249
+ tot += 1
250
+
251
+ indices = data_main['index']
252
+ data_main = data_main.copy()
253
+ data_main['hit'] = [result[i] for i in indices]
254
+ main_idx = data_main['index']
255
+ data_main['category'] = [cate_map[i] for i in main_idx]
256
+
257
+ ret_json = calc_acc(data_main, 'overall')
258
+
259
+ if self.has_l2_category:
260
+ data_main['l2-category'] = [l2_cate_map[i] for i in main_idx]
261
+ l2 = calc_acc(data_main, 'l2-category')
262
+ ret_json.update(l2)
263
+ else:
264
+ leaf = calc_acc(data_main, 'category')
265
+ ret_json.update(leaf)
266
+ if show:
267
+ show_result(ret_json)
268
+ return ret_json
269
+
270
+
271
+ def main():
272
+ args = parse_args()
273
+
274
+ torch.manual_seed(args.seed)
275
+
276
+ if args.launcher != 'none':
277
+ set_multi_processing(distributed=True)
278
+ init_dist(args.launcher)
279
+
280
+ rank, world_size = get_dist_info()
281
+ torch.cuda.set_device(rank)
282
+ else:
283
+ rank = 0
284
+ world_size = 1
285
+
286
+ # build llm
287
+ quantization_config = None
288
+ load_in_8bit = False
289
+ if args.bits == 4:
290
+ quantization_config = BitsAndBytesConfig(
291
+ load_in_4bit=True,
292
+ load_in_8bit=False,
293
+ llm_int8_threshold=6.0,
294
+ llm_int8_has_fp16_weight=False,
295
+ bnb_4bit_compute_dtype=torch.float16,
296
+ bnb_4bit_use_double_quant=True,
297
+ bnb_4bit_quant_type='nf4')
298
+ elif args.bits == 8:
299
+ load_in_8bit = True
300
+ model_kwargs = {
301
+ 'quantization_config': quantization_config,
302
+ 'load_in_8bit': load_in_8bit,
303
+ 'device_map': rank if world_size > 1 else 'auto',
304
+ 'offload_folder': args.offload_folder,
305
+ 'trust_remote_code': True,
306
+ 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
307
+ }
308
+
309
+ # build llm
310
+ with LoadWoInit():
311
+ llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
312
+ **model_kwargs)
313
+ tokenizer = AutoTokenizer.from_pretrained(
314
+ args.model_name_or_path,
315
+ trust_remote_code=True,
316
+ encode_special_tokens=True)
317
+ master_print(f'Load LLM from {args.model_name_or_path}')
318
+
319
+ llava_path = snapshot_download(
320
+ repo_id=args.llava) if not osp.isdir(args.llava) else args.llava
321
+
322
+ # build visual_encoder
323
+ if 'visual_encoder' in os.listdir(llava_path):
324
+ assert args.visual_encoder is None, (
325
+ "Please don't specify the `--visual-encoder` since passed "
326
+ '`--llava` contains a visual encoder!')
327
+ visual_encoder_path = osp.join(llava_path, 'visual_encoder')
328
+ else:
329
+ assert args.visual_encoder is not None, (
330
+ 'Please specify the `--visual-encoder`!')
331
+ visual_encoder_path = args.visual_encoder
332
+ with LoadWoInit():
333
+ visual_encoder = Dinov2Model.from_pretrained(
334
+ visual_encoder_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
335
+ image_processor = AutoImageProcessor.from_pretrained(
336
+ visual_encoder_path)
337
+ master_print(f'Load visual_encoder from {visual_encoder_path}')
338
+
339
+ # load adapter
340
+ if 'llm_adapter' in os.listdir(llava_path):
341
+ adapter_path = osp.join(llava_path, 'llm_adapter')
342
+
343
+ with LoadWoInit():
344
+ llm = PeftModel.from_pretrained(
345
+ llm, adapter_path, offload_folder=args.offload_folder)
346
+
347
+ master_print(f'Load LLM adapter from {args.llava}')
348
+
349
+ if 'visual_encoder_adapter' in os.listdir(llava_path):
350
+ adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
351
+ visual_encoder = PeftModel.from_pretrained(
352
+ visual_encoder, adapter_path, offload_folder=args.offload_folder)
353
+ master_print(f'Load visual_encoder adapter from {args.llava}')
354
+
355
+ # build projector
356
+ projector_path = osp.join(llava_path, 'projector')
357
+ with LoadWoInit():
358
+ projector = AutoModel.from_pretrained(
359
+ projector_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
360
+ master_print(f'Load projector from {args.llava}')
361
+
362
+ projector.cuda()
363
+ projector.eval()
364
+
365
+ visual_encoder.cuda()
366
+ visual_encoder.eval()
367
+
368
+ llm.eval()
369
+
370
+ stop_words = args.stop_words
371
+ if args.prompt_template:
372
+ template = PROMPT_TEMPLATE[args.prompt_template]
373
+ stop_words += template.get('STOP_WORDS', [])
374
+ stop_criteria = get_stop_criteria(
375
+ tokenizer=tokenizer, stop_words=stop_words)
376
+
377
+ gen_config = GenerationConfig(
378
+ max_new_tokens=args.max_new_tokens,
379
+ do_sample=False,
380
+ eos_token_id=tokenizer.eos_token_id,
381
+ pad_token_id=tokenizer.pad_token_id
382
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
383
+ )
384
+
385
+ # work_dir
386
+ if args.work_dir is not None:
387
+ # update configs according to CLI args if args.work_dir is not None
388
+ save_dir = args.work_dir
389
+ else:
390
+ # use config filename as default work_dir
391
+ save_dir = osp.join('./work_dirs',
392
+ osp.splitext(osp.basename(args.data_path))[0])
393
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
394
+ save_dir = osp.join(save_dir, timestamp)
395
+
396
+ if rank == 0:
397
+ mkdir_or_exist(osp.abspath(save_dir))
398
+ print('=======================================================')
399
+ print(f'Dataset path: {osp.abspath(args.data_path)}\n'
400
+ f'Results will be saved to {osp.abspath(save_dir)}')
401
+ print('=======================================================')
402
+
403
+ args_path = osp.join(save_dir, 'args.json')
404
+ with open(args_path, 'w') as f:
405
+ json.dump(args.__dict__, f, indent=2)
406
+
407
+ results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
408
+ results_json_path = osp.join(save_dir, 'mmbench_result.json')
409
+
410
+ dataset = MMBenchDataset(args.data_path)
411
+
412
+ results = []
413
+ n_samples = len(dataset)
414
+ per_rank_samples = math.ceil(n_samples / world_size)
415
+
416
+ per_rank_ids = range(per_rank_samples * rank,
417
+ min(n_samples, per_rank_samples * (rank + 1)))
418
+ for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
419
+ data_sample = dataset[i]
420
+ if data_sample['context'] is not None:
421
+ text = data_sample['context'] + '\n' + data_sample[
422
+ 'question'] + '\n' + data_sample['options']
423
+ else:
424
+ text = data_sample['question'] + '\n' + data_sample['options']
425
+
426
+ text = DEFAULT_IMAGE_TOKEN + '\n' + text
427
+
428
+ if is_cn_string(text):
429
+ text = text + '请直接回答选项字母。'
430
+ else:
431
+ text = text + ("Answer with the option's letter from the "
432
+ 'given choices directly.')
433
+
434
+ if args.prompt_template:
435
+ prompt_text = ''
436
+ template = PROMPT_TEMPLATE[args.prompt_template]
437
+ prompt_text += template['INSTRUCTION'].format(
438
+ input=text, round=1, bot_name=args.bot_name)
439
+ else:
440
+ prompt_text = text
441
+ inputs = prompt_text
442
+
443
+ image = data_sample['img'].convert('RGB')
444
+ image = expand2square(
445
+ image, tuple(int(x * 255) for x in image_processor.image_mean))
446
+ image = image_processor.preprocess(
447
+ image, return_tensors='pt')['pixel_values'][0]
448
+ image = image.cuda().unsqueeze(0)
449
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
450
+ pixel_values = projector(
451
+ visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
452
+
453
+ chunk_encode = []
454
+ for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
455
+ if idx == 0:
456
+ cur_encode = tokenizer.encode(chunk)
457
+ else:
458
+ cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
459
+ chunk_encode.append(cur_encode)
460
+ assert len(chunk_encode) == 2
461
+ ids = []
462
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
463
+ ids.extend(cur_chunk_encode)
464
+ if idx != len(chunk_encode) - 1:
465
+ ids.append(IMAGE_TOKEN_INDEX)
466
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
467
+ mm_inputs = prepare_inputs_labels_for_multimodal(
468
+ llm=llm, input_ids=ids, pixel_values=pixel_values)
469
+
470
+ generate_output = llm.generate(
471
+ **mm_inputs,
472
+ generation_config=gen_config,
473
+ streamer=None,
474
+ bos_token_id=tokenizer.bos_token_id,
475
+ stopping_criteria=stop_criteria)
476
+
477
+ predict = tokenizer.decode(
478
+ generate_output[0], skip_special_tokens=True).strip()
479
+ cur_result = {}
480
+ cur_result['question'] = data_sample.get('question')
481
+ cur_result.update(data_sample.get('options_dict'))
482
+ cur_result['prediction'] = predict
483
+ if data_sample.get('category') is not None:
484
+ cur_result['category'] = data_sample.get('category')
485
+ if data_sample.get('l2-category') is not None:
486
+ cur_result['l2-category'] = data_sample.get('l2-category')
487
+ cur_result['index'] = data_sample.get('index')
488
+ cur_result['split'] = data_sample.get('split')
489
+ cur_result['answer'] = data_sample.get('answer')
490
+ results.append(cur_result)
491
+
492
+ results = collect_results(results, n_samples)
493
+
494
+ if get_rank() == 0:
495
+
496
+ results_df = pd.DataFrame(results)
497
+ with pd.ExcelWriter(results_xlsx_path, engine='openpyxl') as writer:
498
+ results_df.to_excel(writer, index=False)
499
+
500
+ if dataset.split == 'dev':
501
+ results_dict = dataset.eval_result(results_df, show=True)
502
+ with open(results_json_path, 'w') as f:
503
+ json.dump(results_dict, f, indent=2)
504
+ else:
505
+ print('All done!')
506
+
507
+
508
+ if __name__ == '__main__':
509
+
510
+ main()