File size: 16,341 Bytes
46280bb
769e287
46280bb
51e5ad2
46280bb
769e287
51e5ad2
46280bb
 
 
 
 
51e5ad2
46280bb
 
769e287
 
 
46280bb
 
769e287
 
 
46280bb
 
 
 
 
 
 
 
 
 
 
 
 
769e287
 
 
46280bb
769e287
 
 
46280bb
 
 
 
 
 
 
 
 
 
 
769e287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46280bb
 
 
 
769e287
 
 
46280bb
 
769e287
46280bb
769e287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46280bb
 
 
 
 
 
769e287
46280bb
 
 
 
 
 
 
 
 
 
 
 
 
769e287
46280bb
 
 
 
 
 
 
 
 
 
769e287
 
 
 
 
51e5ad2
 
 
769e287
 
51e5ad2
 
 
46280bb
769e287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46280bb
769e287
 
 
46280bb
769e287
 
 
 
 
 
 
46280bb
51e5ad2
46280bb
51e5ad2
d6e13c5
51e5ad2
d6e13c5
 
46280bb
769e287
 
51e5ad2
 
 
 
d6e13c5
 
51e5ad2
d6e13c5
 
 
 
 
 
 
 
 
51e5ad2
d6e13c5
 
 
 
 
 
51e5ad2
d6e13c5
 
 
 
 
 
 
51e5ad2
d6e13c5
 
 
 
 
 
 
 
 
 
 
51e5ad2
d6e13c5
 
 
 
 
 
 
 
51e5ad2
 
d6e13c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51e5ad2
 
46280bb
 
51e5ad2
 
 
 
 
 
 
 
 
 
 
 
 
46280bb
51e5ad2
 
 
 
 
 
 
46280bb
51e5ad2
46280bb
51e5ad2
46280bb
 
51e5ad2
 
 
 
 
 
 
 
 
 
 
769e287
d6e13c5
 
 
 
 
 
46280bb
51e5ad2
 
 
 
 
 
d6e13c5
51e5ad2
 
 
 
 
d6e13c5
 
51e5ad2
 
 
d6e13c5
 
 
 
 
51e5ad2
 
d6e13c5
46280bb
 
 
 
 
 
 
 
51e5ad2
46280bb
 
51e5ad2
46280bb
51e5ad2
 
 
 
 
 
 
46280bb
 
51e5ad2
46280bb
51e5ad2
d6e13c5
51e5ad2
d6e13c5
51e5ad2
 
 
d6e13c5
 
46280bb
 
51e5ad2
d6e13c5
 
 
 
 
46280bb
 
51e5ad2
 
 
 
 
 
 
 
 
 
 
769e287
51e5ad2
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
import copy
import os
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence
import warnings
from torch import Tensor, nn

from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    Blip2VisionModel,
    Blip2QFormerModel,
    Blip2Model,
    Blip2PreTrainedModel,
    Blip2ForConditionalGeneration,
    GenerationConfig,
)
from transformers.models.blip_2.modeling_blip_2 import (
    Blip2ForConditionalGenerationModelOutput,
)
from transformers.utils import logging
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList

from .modeling_chatglm import (
    ChatGLMForConditionalGeneration,
    InvalidScoreLogitsProcessor,
)
from .configuration_blip2chatglm import Blip2ChatGLMConfig


# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)


class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
    config_class = Blip2ChatGLMConfig

    def __init__(self, config: Blip2ChatGLMConfig):
        """Assemble the vision encoder, Q-Former, projection layer and ChatGLM.

        Args:
            config: Composite configuration. Its ``vision_config``,
                ``qformer_config``, ``text_config`` and ``num_query_tokens``
                fields are read here.
        """
        # NOTE: only the Blip2PreTrainedModel initializer is invoked on
        # purpose. Going through super().__init__() raises an error because
        # ChatGLM cannot be found by AutoModel.
        Blip2PreTrainedModel.__init__(self, config)

        qformer_config = config.qformer_config
        text_config = config.text_config
        qformer_width = qformer_config.hidden_size

        # Image encoder.
        self.vision_model = Blip2VisionModel(config.vision_config)

        # Learned queries (zero-initialized) consumed by the Q-Former.
        learned_queries = torch.zeros(1, config.num_query_tokens, qformer_width)
        self.query_tokens = nn.Parameter(learned_queries)
        self.qformer = Blip2QFormerModel(qformer_config)

        # Maps Q-Former outputs to the hidden size of the language model.
        self.language_projection = nn.Linear(qformer_width, text_config.hidden_size)
        self.language_model = ChatGLMForConditionalGeneration(text_config)

        # post_init() (weight initialization / final processing) is not
        # called here.

    def setup_dtype(self, vision_encoder_dtype: str = "fp32", lm_dtype: str = "fp16"):
        """Convert the vision encoder and the language model to the given precisions.

        Args:
            vision_encoder_dtype: ``"fp32"`` or ``"fp16"``.
            lm_dtype: ``"fp32"``, ``"fp16"``, ``"int4"`` or ``"int8"``. The
                integer options cast the language model to half precision and
                then call its ``quantize`` method with the matching bit width.

        Raises:
            NotImplementedError: If either argument is not one of the values
                above. The vision encoder is converted before ``lm_dtype`` is
                validated.
        """
        vision_casts = {
            "fp32": lambda module: module.float(),
            "fp16": lambda module: module.half(),
        }
        lm_casts = {
            "fp32": lambda module: module.float(),
            "fp16": lambda module: module.half(),
            "int4": lambda module: module.half().quantize(4),
            "int8": lambda module: module.half().quantize(8),
        }

        cast_vision = vision_casts.get(vision_encoder_dtype)
        if cast_vision is None:
            raise NotImplementedError(
                f"Unsupported vision_encoder_dtype: {vision_encoder_dtype}"
            )
        self.vision_model = cast_vision(self.vision_model)

        cast_lm = lm_casts.get(lm_dtype)
        if cast_lm is None:
            raise NotImplementedError(f"Unsupported lm_dtype: {lm_dtype}")
        self.language_model = cast_lm(self.language_model)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor,
        image_slot_offset: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        """Encode the images, splice the visual tokens into the text and run ChatGLM.

        Args:
            pixel_values: Images fed to the vision encoder (one image per
                sample is assumed — TODO confirm against callers).
            input_ids: Token ids of the text. The ``num_query_tokens``
                positions that hold the image (the prefix, or the positions
                starting at ``image_slot_offset[i]``) are placeholders whose
                embeddings get overwritten; they should be filled with
                ``tokenizer.unk_token_id``.
            image_slot_offset: Per-sample start index of the image slot. When
                ``None`` the visual tokens are placed as a prefix, i.e. at
                offset 0 for every sample.
            attention_mask: Forwarded unchanged to the language model.
            output_attentions: Forwarded to the vision encoder, the Q-Former
                and the language model.
            output_hidden_states: Forwarded to the vision encoder, the
                Q-Former and the language model.
            labels: Target token ids. When given, the logits are cropped to
                the last ``labels.size(1)`` positions and a shifted
                cross-entropy loss (mean reduction) is computed.
            return_dict: Whether to return a model-output object instead of a
                tuple. Defaults to ``self.config.use_return_dict``.

        Returns:
            A ``Blip2ForConditionalGenerationModelOutput`` when ``return_dict``
            is true, otherwise the tuple ``(logits, vision_outputs,
            query_outputs, language_model_outputs)`` prefixed with the loss
            when ``labels`` is given.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        num_query_tokens = self.config.num_query_tokens
        # Work on a copy and assign through regular indexing (not ``.data``):
        # the slice assignment is then recorded by autograd, so the loss can
        # back-propagate into the projection layer and the Q-Former.
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
        if image_slot_offset is None:
            # image as prefix
            inputs_embeds[:, :num_query_tokens, :] = language_model_inputs
        else:
            for i, offset in enumerate(image_slot_offset):
                start = int(offset)
                inputs_embeds[
                    i, start : start + num_query_tokens, :
                ] = language_model_inputs[i]

        outputs = self.language_model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            logits = logits[:, -labels.size(1) :, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            # Flatten the tokens; the class dimension is taken from the logits
            # themselves so it is always consistent with the tensor.
            loss_fct = CrossEntropyLoss(reduction="mean")
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return Blip2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    def prepare_inputs_for_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
        max_length: int,
        user_role: str = "问",
        bot_role: str = "答",
    ):
        """Build left-padded token ids and input embeddings for a batch of chats.

        Args:
            tokenizer: Tokenizer used to encode the prompt pieces; its
                ``unk_token_id`` fills image slots and its ``pad_token_id``
                is used for left padding.
            batch_messages: One conversation per sample. Each message is
                ``(role, text, images)`` where ``images`` is a list of
                ``(image_tensor, char_index)`` pairs; the image is inserted
                before ``text[char_index]``. Pairs are assumed to be sorted by
                ``char_index`` and every tensor is assumed to carry a leading
                batch dimension of size 1 — TODO confirm against callers.
            max_length: Maximum prompt length in tokens. Longer prompts are
                truncated from the left; an image slot that would be cut in
                half is dropped entirely.
            user_role: Role name of the user.
            bot_role: Role name of the assistant, appended as the final tag
                the model should answer to.

        Returns:
            ``(input_ids, inputs_embeds)``: the left-padded token ids on
            ``self.device`` and their embeddings, with every image slot
            replaced by the projected Q-Former output of its image.
        """
        device = self.device
        nvtokens = self.config.num_query_tokens

        # 1. Prepare token ids
        all_images = []
        all_image_slots = []
        all_input_ids = []
        for messages in batch_messages:
            images = []
            image_slots = []
            input_ids = []

            round_roles = [set()]
            for role, qtext, qimgs in messages:
                if role in round_roles[-1]:
                    # a new round (not the first round)
                    input_ids += tokenizer(
                        f"\n[Round {len(round_roles)}]\n{role}:",
                        add_special_tokens=False,
                    ).input_ids
                    round_roles.append({role})
                else:
                    round_roles[-1].add(role)
                    input_ids += tokenizer(
                        # For first role, no new line
                        f"\n{role}:" if len(input_ids) != 0 else f"{role}:",
                        add_special_tokens=False,
                    ).input_ids
                cur_index = 0
                for qimg, img_idx in qimgs:
                    if img_idx > cur_index:
                        input_ids += tokenizer(
                            qtext[cur_index:img_idx], add_special_tokens=False
                        ).input_ids
                        cur_index = img_idx
                    # image slot, embedding will be replaced by image embeddings
                    image_slots.append(len(input_ids))
                    input_ids += [tokenizer.unk_token_id] * nvtokens
                    images.append(qimg)
                input_ids += tokenizer(
                    qtext[cur_index:], add_special_tokens=False
                ).input_ids
            if len(round_roles) == 1:
                # only 1 round
                if len(round_roles[0]) == 1 and user_role in round_roles[0]:
                    # only user role
                    input_ids += tokenizer("").input_ids
                else:
                    input_ids += tokenizer(f"\n{bot_role}:").input_ids
            else:
                # add tag for round 0
                round0_ids = tokenizer(
                    "[Round 0]\n", add_special_tokens=False
                ).input_ids
                # The slots were recorded before this prefix existed, so they
                # must move right by the prefix length to stay on their
                # placeholder tokens.
                image_slots = [slot + len(round0_ids) for slot in image_slots]
                input_ids = round0_ids + input_ids
                input_ids += tokenizer(f"\n{bot_role}:").input_ids

            if len(input_ids) >= max_length:
                image_slots_after_truncate = []
                images_after_truncate = []
                # truncate from left
                truncate_index = len(input_ids) - max_length
                for image_slot, image in zip(image_slots, images):
                    if image_slot >= truncate_index:
                        # the whole slot survives the truncation
                        image_slots_after_truncate.append(image_slot)
                        images_after_truncate.append(image)
                    elif image_slot + nvtokens > truncate_index:
                        # a partially truncated image slot is not allowed:
                        # cut right after it so the image is dropped entirely
                        truncate_index = image_slot + nvtokens
                image_slots = [
                    image_slot - truncate_index
                    for image_slot in image_slots_after_truncate
                ]
                images = images_after_truncate
                input_ids = input_ids[truncate_index:]

            all_images.extend(images)
            all_image_slots.append(image_slots)
            all_input_ids.append(input_ids)

        # 2. Prepare image embeddings
        if all_images:
            vision_outputs = self.vision_model(
                pixel_values=torch.cat(all_images, dim=0)
            )
            all_image_embeds = vision_outputs[0]
            all_vtokens = []
            start = 0
            # TODO: qformer not batched
            for image_slots in all_image_slots:
                end = start + len(image_slots)
                if end == start:
                    # sample without images: nothing to encode
                    all_vtokens.append(())
                    continue
                image_embeds = all_image_embeds[start:end]
                start = end

                image_atts = torch.ones(
                    image_embeds.size()[:-1],
                    dtype=torch.long,
                    device=image_embeds.device,
                )
                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
                query_outputs = self.qformer(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                )
                query_output = query_outputs[0]

                all_vtokens.append(self.language_projection(query_output))
        else:
            all_vtokens = None

        # 3. Place image embeddings into slots
        max_len = max(len(ids) for ids in all_input_ids)
        pad_widths = [max_len - len(ids) for ids in all_input_ids]
        input_ids = torch.full(
            (len(all_input_ids), max_len), tokenizer.pad_token_id, dtype=torch.long
        )
        for i, ids in enumerate(all_input_ids):
            # pad left
            input_ids[i, pad_widths[i] :] = torch.as_tensor(ids, dtype=torch.long)
        input_ids = input_ids.to(device)
        inputs_embeds = self.language_model.transformer.word_embeddings(input_ids)
        if all_vtokens is not None:
            for i, (image_slots, vtokens) in enumerate(
                zip(all_image_slots, all_vtokens)
            ):
                for slot, vimg in zip(image_slots, vtokens):
                    # slots index the unpadded sequence; account for left padding
                    begin = pad_widths[i] + slot
                    inputs_embeds[i][begin : begin + nvtokens, :] = vimg

        return input_ids, inputs_embeds

    @torch.no_grad()
    def batch_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
        max_length: int = 2048,
        num_beams=1,
        do_sample=True,
        top_p=0.7,
        temperature=0.95,
        user_role: str = "问",
        bot_role: str = "答",
        **kwargs,
    ):
        """Generate one reply per conversation in ``batch_messages``.

        Args:
            tokenizer: Tokenizer used to build the prompt and decode the output.
            batch_messages: Conversations, in the format accepted by
                ``prepare_inputs_for_chat``.
            max_length: Used both as the prompt length limit and as the
                ``max_length`` generation option.
            num_beams: Generation option passed to ``generate``.
            do_sample: Generation option passed to ``generate``.
            top_p: Generation option passed to ``generate``.
            temperature: Generation option passed to ``generate``.
            user_role: Role name of the user.
            bot_role: Role name of the assistant.
            **kwargs: Extra generation options; they override the explicit
                ones on key collision.

        Returns:
            A list with, for every generated sequence, the result of
            ``language_model.process_response`` applied to the decoded tokens
            that follow the prompt.
        """
        prompt_ids, prompt_embeds = self.prepare_inputs_for_chat(
            tokenizer=tokenizer,
            batch_messages=batch_messages,
            max_length=max_length,
            user_role=user_role,
            bot_role=bot_role,
        )

        processors = LogitsProcessorList()
        processors.append(InvalidScoreLogitsProcessor())

        generation_options = dict(
            max_length=max_length,
            num_beams=num_beams,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            logits_processor=processors,
        )
        # caller-supplied options take precedence over the defaults above
        generation_options.update(kwargs)

        generated = self.language_model.generate(
            input_ids=prompt_ids, inputs_embeds=prompt_embeds, **generation_options
        )

        # drop the prompt tokens of each row, then decode and post-process
        return [
            self.language_model.process_response(
                tokenizer.decode(token_ids[len(prompt_ids[row]) :])
            )
            for row, token_ids in enumerate(generated.tolist())
        ]

    @torch.no_grad()
    def stream_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        messages: List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]],
        num_beams=5,
        max_length=512,
        top_p=0.9,
        do_sample=True,
        temperature=1,
        user_role: str = "问",
        bot_role: str = "答",
        **kwargs,
    ):
        """Stream the reply to a single conversation.

        Args:
            tokenizer: Tokenizer used to build the prompt and decode the output.
            messages: One conversation, in the per-sample format accepted by
                ``prepare_inputs_for_chat``.
            num_beams: Generation option passed to ``stream_generate``.
            max_length: Used both as the prompt length limit and as the
                ``max_length`` generation option.
            top_p: Generation option passed to ``stream_generate``.
            do_sample: Generation option passed to ``stream_generate``.
            temperature: Generation option passed to ``stream_generate``.
            user_role: Role name of the user.
            bot_role: Role name of the assistant.
            **kwargs: Extra generation options; they override the explicit
                ones on key collision.

        Yields:
            For every item produced by ``language_model.stream_generate``, the
            result of ``language_model.process_response`` applied to the
            decoded tokens that follow the prompt.
        """
        prompt_ids, prompt_embeds = self.prepare_inputs_for_chat(
            tokenizer=tokenizer,
            batch_messages=[messages],
            max_length=max_length,
            user_role=user_role,
            bot_role=bot_role,
        )

        processors = LogitsProcessorList()
        processors.append(InvalidScoreLogitsProcessor())

        generation_options = dict(
            max_length=max_length,
            num_beams=num_beams,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            logits_processor=processors,
        )
        # caller-supplied options take precedence over the defaults above
        generation_options.update(kwargs)

        token_stream = self.language_model.stream_generate(
            input_ids=prompt_ids, inputs_embeds=prompt_embeds, **generation_options
        )
        for step_output in token_stream:
            # keep only the tokens generated after the prompt of the single sample
            new_tokens = step_output.tolist()[0][len(prompt_ids[0]) :]
            decoded = tokenizer.decode(new_tokens)
            yield self.language_model.process_response(decoded)