File size: 25,854 Bytes
d251895 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 |
import copy
import numpy as np
import torch
import torch.utils.checkpoint
from peft.mapping import get_peft_model
from peft.peft_model import PeftModel
from peft.tuners.lora import LoraConfig
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoTokenizer
from transformers.generation.streamers import BaseStreamer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.utils import (
replace_return_docstrings,
)
from designvlm.crello_dataset import int_wrap
from designvlm.loading import LoraArguments
from yosematvlm.configuration_internlm_xcomposer2 import InternLMXcomposer2Config
from yosematvlm.modeling_internlm_xcomposer2 import (
FROM_TOKEN_2,
InternLMXComposer2ForCausalLM,
)
IM_END_TOKEN = 92542
EOS = 2 # </s>
SWITCH_TOKEN_LENGTH = 6 # after <|im_end|>, there are 5 additional tokens ['\n', '<|im_start|>', 'ass', 'istant', '\n']. You don't want the model to learn these tokens.
MASK_ID = -100
class IXCLayoutConfig(InternLMXcomposer2Config):
def __init__(
self,
vocab_size=103168,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
bias=True,
rope_theta=10000,
rope_scaling=None,
attn_implementation="flash_attention_2",
max_length: int = 16384,
discrete_coordinate_tokens: int | None = None,
**kwargs,
):
self.discrete_coordinate_tokens = discrete_coordinate_tokens
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
hidden_act=hidden_act,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
rms_norm_eps=rms_norm_eps,
use_cache=use_cache,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
bias=bias,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
attn_implementation=attn_implementation,
**kwargs,
)
self.max_length = max_length
self.use_cache = False
@classmethod
def with_internlm_config(
cls,
config: InternLMXcomposer2Config,
discrete_custom_tokens: int | None = None,
max_length: int = 16384,
):
return cls(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
hidden_act=config.hidden_act,
max_position_embeddings=config.max_position_embeddings,
initializer_range=config.initializer_range,
rms_norm_eps=config.rms_norm_eps,
use_cache=config.use_cache,
pad_token_id=config.pad_token_id,
bos_token_id=config.bos_token_id,
eos_token_id=config.eos_token_id,
tie_word_embeddings=config.tie_word_embeddings,
bias=config.bias,
rope_theta=config.rope_theta,
rope_scaling=config.rope_scaling,
attn_implementation=config.attn_implementation,
discrete_coordinate_tokens=discrete_custom_tokens,
max_length=max_length,
)
class IXCLayout(InternLMXComposer2ForCausalLM):
config_class = IXCLayoutConfig
def __init__(self, config: IXCLayoutConfig):
super().__init__(config)
self.vit.vision_tower.vision_model.post_layernorm = torch.nn.Identity()
self.tokenizer = AutoTokenizer.from_pretrained(
"yosematvlm",
padding_side="right",
use_fast=False,
trust_remote_code=True,
) # type: ignore
# Add coordinate tokens
self.coordinate_token_ids: set[int] = set()
if config.discrete_coordinate_tokens is not None:
self.add_coordinate_tokens(config.discrete_coordinate_tokens)
self.config = config
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=InternLMXcomposer2Config
)
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
**kwargs,
) -> CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
"""
infer_mode = "base"
return_dict, output_attentions, output_hidden_states = self.or_config(
return_dict=return_dict,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
if "samples" in kwargs:
# Training
samples = kwargs["samples"]
# encode text
text = samples["text_input"]
# encode image
image = samples["image"][0]
bs = len(samples["text_input"][0])
image_nums = []
temp_image = []
for im in image:
if type(im) is list:
image_nums.append(len(im))
temp_image.extend(im)
else:
image_nums.append(1)
temp_image.append(im)
image = temp_image
assert type(image) is list and len(image_nums) == bs
to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
image, text, image_nums
)
inputs_embeds = to_regress_embeds[:, : self.max_length] # type: ignore
attention_mask = attention_mask[:, : self.max_length] # type: ignore
targets = targets[:, : self.max_length]
im_mask = im_mask[:, : self.max_length].bool()
labels = targets # type: ignore
elif inputs_embeds is not None or input_ids is not None:
im_mask = kwargs["im_mask"]
if im_mask is None and inputs_embeds is not None:
im_mask = torch.zeros(inputs_embeds.shape[:2]).to(inputs_embeds.device)
im_mask = im_mask.bool()
else:
raise ValueError(
"Either samples, inputs_embeds or input_ids should be provided."
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
im_mask=im_mask,
infer_mode=infer_mode,
)
logits: torch.Tensor = self.output(
outputs.last_hidden_state
).float() # B x L x V
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
ce_loss: torch.Tensor = loss_fct(shift_logits, shift_labels)
assert not ce_loss.isnan().any()
kl_loss = self.coordinate_kl_loss(logits, labels)
assert not kl_loss.isnan().any()
loss = ce_loss + kl_loss
else:
loss = None
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if not return_dict:
output = (logits,) + outputs[1:]
return (ce_loss,) + output if ce_loss is not None else output
return CausalLMOutputWithPast(
loss=loss, # type: ignore
logits=logits, # type: ignore
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def coordinate_kl_loss(
self, logits: torch.Tensor, labels: torch.Tensor, eps: float = 1e-9
) -> torch.Tensor:
"""
For coordinate token, calculate the KL loss between the predicted logits and the target labels.
Instead of one-hot vector, we assume that the target labels are the probability distribution of the target token.
The distribution is a discrete gaussian distribution with mean at the target token and variance of 1.
Args:
logits: B x T x V; The predicted logits of the model.
labels: B x T; The target labels of the model.
"""
label_std_dev = 2.0
coordinate_token_ids = torch.tensor(
list(self.coordinate_token_ids), device=labels.device, dtype=labels.dtype
)
assert len(self.coordinate_token_ids) > 0
# Get the mask of the coordinate tokens
is_label_coordinate = torch.isin(
labels,
coordinate_token_ids,
).type_as(labels) # B x T
# Get the target labels of the coordinate tokens
# Range of indices
indices = torch.arange(
0, self.vocab_size, dtype=labels.dtype, device=logits.device
).repeat(labels.shape[0], labels.shape[1], 1) # B x T x V
# Indices that are not coordinate tokens are set to 0
# To reduce memory consumption, we iterate over the sequence length
is_indice_coordinate_token = torch.stack(
[
torch.isin(indices[:, idx], coordinate_token_ids)
for idx in range(indices.size(1))
],
dim=1,
) # B x T x V
total_mask = is_label_coordinate.unsqueeze(-1) * is_indice_coordinate_token
# Create a Gaussian distribution centered at the label index
gauss_label = torch.exp(
-0.5 * ((indices - labels.unsqueeze(-1)) / label_std_dev) ** 2
) # B x T x V
# Apply the mask so that only the coordinate tokens are considered
gauss_label = gauss_label * total_mask + eps # Add eps for numerical stability
# Normalize the distribution
gauss_label /= gauss_label.sum(dim=-1, keepdim=True)
pointwise_kl = (
nn.functional.kl_div(
logits.log_softmax(dim=-1), gauss_label, reduction="none"
).nan_to_num()
* total_mask
)
kl_loss = pointwise_kl.sum(-1).mean()
assert kl_loss > 0
return kl_loss
def interleav_wrap(
self,
img_list: list[torch.Tensor], # V x 1 x 3 x H x W
text_list_list: list[list[str]], # B x 1 x str
image_nums: list[int], # B
):
temp_embeds = []
temp_im_mask = []
temp_tars = []
# encode_image
if len(img_list) > 0:
img_embeds, img_split = self.vit(
img_list, self.plora_glb_GN, self.plora_sub_GN
)
img_embeds = self.vision_proj(img_embeds)
else:
img_embeds = None
img_split = []
text_list = text_list_list[0]
for idx, text in enumerate(text_list):
image_num = image_nums[idx]
im_id = int(np.sum(image_nums[:idx]))
images = []
for i in range(image_nums[idx]):
st = int(np.sum(img_split[: im_id + i]))
sp = img_split[im_id + i]
temp_img = img_embeds[:, st : st + sp] # type: ignore
images.append(temp_img)
if image_num == 1 and text.find("<ImageHere>") == -1:
text = "<ImageHere>" + text
parts = text.split("<ImageHere>")
wrap_tokens, wrap_embeds, wrap_im_mask = [], [], []
temp_len = 0
need_bos = True
for idx, part in enumerate(parts):
if len(part) > 0:
part_tokens = self.tokenizer(
part,
return_tensors="pt",
padding="longest",
add_special_tokens=need_bos,
).to(self.device) # type: ignore
if need_bos:
need_bos = False
wrap_tokens.append(part_tokens.input_ids)
part_embeds = self.model.tok_embeddings(part_tokens.input_ids)
wrap_embeds.append(part_embeds)
wrap_im_mask.append(
torch.zeros(part_embeds.shape[:2]).to(self.device)
)
temp_len += part_embeds.shape[1]
if idx < image_num:
wrap_embeds.append(images[idx])
wrap_token = (
torch.ones(images[idx].shape[:2], dtype=torch.long).to(
self.device
)
* -100
)
wrap_tokens.append(wrap_token)
wrap_im_mask.append(
torch.ones(images[idx].shape[:2]).to(self.device)
)
temp_len += images[idx].shape[1]
if temp_len > self.max_length:
break
wrap_tokens = torch.cat(wrap_tokens, dim=1)
wrap_embeds = torch.cat(wrap_embeds, dim=1)
wrap_im_mask = torch.cat(wrap_im_mask, dim=1)
wrap_target = self.mask_human_targets(wrap_tokens).to(self.device)
temp_embeds.append(wrap_embeds)
temp_im_mask.append(wrap_im_mask)
temp_tars.append(wrap_target)
temp_max_len = np.max([i.shape[1] for i in temp_embeds])
temp_max_len = min(temp_max_len, self.max_length)
final_input, final_atts, final_tars, final_mask = [], [], [], []
pad = torch.ones([1, 1]) * self.tokenizer.pad_token_id # type: ignore
pad = pad.long().to(self.device)
pad_emb = self.model.tok_embeddings(pad)
for idx in range(len(temp_embeds)):
temp_len = temp_embeds[idx].shape[1]
if temp_len >= temp_max_len:
final_input.append(temp_embeds[idx][:, :temp_max_len])
final_atts.append(
torch.ones(1, temp_max_len).to(wrap_target.dtype).to(self.device)
)
final_tars.append(temp_tars[idx][:, :temp_max_len])
final_mask.append(temp_im_mask[idx][:, :temp_max_len])
else:
final_input.append(
torch.cat(
[
temp_embeds[idx],
pad_emb.repeat(1, temp_max_len - temp_len, 1),
],
dim=1,
)
)
final_atts.append(
torch.cat(
[
torch.ones(1, temp_len),
torch.zeros(1, temp_max_len - temp_len),
],
dim=1,
)
.to(wrap_target.dtype)
.to(self.device)
)
final_tars.append(
torch.cat(
[
temp_tars[idx],
(torch.ones(1, temp_max_len - temp_len) * MASK_ID)
.to(wrap_target.dtype)
.to(self.device),
],
dim=1,
)
)
final_mask.append(
torch.cat(
[
temp_im_mask[idx],
(torch.zeros(1, temp_max_len - temp_len))
.to(wrap_target.dtype)
.to(self.device),
],
dim=1,
)
)
inputs_embeds = torch.cat(final_input, dim=0)
attention_mask = torch.cat(final_atts, dim=0)
targets = torch.cat(final_tars, dim=0)
im_mask = torch.cat(final_mask, dim=0)
return inputs_embeds, attention_mask, targets, im_mask
def mask_human_targets(self, input_ids: torch.Tensor) -> torch.Tensor:
target_batch = []
for bs in range(input_ids.shape[0]):
ids = input_ids[bs]
targets = copy.deepcopy(ids)
end_count = 0
last_eoa = 0
for i, temp_id in enumerate(ids):
# Counterintuitively, IM_END_TOKEN is the token for the end of utterance
# In the whole source code, "[UNUSED_TOKEN_145]" corresponds to IM_END_TOKEN
if temp_id == IM_END_TOKEN:
if end_count % 2 == 0:
targets[last_eoa : i + SWITCH_TOKEN_LENGTH] = MASK_ID
else:
last_eoa = i + 1
end_count += 1
# # eos and following pad
elif temp_id == EOS:
# loss on eos, but not on pad
targets[i + 1 :] = MASK_ID
break
target_batch.append(targets.unsqueeze(0))
target_batch = torch.cat(target_batch, dim=0)
return target_batch
@classmethod
def from_ixc_pretrained(
cls,
ixc_pretrained_model_name_or_path: str,
max_length: int,
img_size: int,
discrete_custom_tokens: int | None = 128,
) -> torch.nn.Module:
r"""
Instantiate a pretrained InternLMXComposer2 model from a pre-trained model configuration.
"""
config: IXCLayoutConfig = AutoConfig.from_pretrained(
ixc_pretrained_model_name_or_path,
trust_remote_code=True,
)
config = IXCLayoutConfig.with_internlm_config(
config,
max_length=max_length,
discrete_custom_tokens=discrete_custom_tokens,
)
model: IXCLayout = super().from_pretrained(
ixc_pretrained_model_name_or_path,
config=config,
torch_dtype=torch.bfloat16,
) # type: ignore
if img_size != 336:
model.vit.resize_pos()
model.vit.requires_grad_(False)
model.vision_proj.requires_grad_(True)
return model
def or_config(
self,
return_dict: bool | None,
output_attentions: bool | None,
output_hidden_states: bool | None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return return_dict, output_attentions, output_hidden_states
@torch.no_grad()
def chat(
self,
query: str,
image: list[tuple[str, str]] | torch.Tensor = [],
hd_num: int = 24,
history: list[tuple[str, str]] = [],
streamer: BaseStreamer | None = None,
max_new_tokens: int = 1024,
do_sample: bool = True,
num_beams: int = 1,
temperature: float = 1.0,
top_p: float = 0.8,
repetition_penalty: float = 1.005,
infer_mode: str = "base",
use_meta: bool = False,
meta_instruction: str = "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n"
"- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by the user such as English and 中文.\n"
"- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively based on the provided image.",
**kwargs,
):
if not use_meta:
meta_instruction = ""
inputs, im_mask, _ = self.interleav_wrap_chat(
query,
image,
history=history,
meta_instruction=meta_instruction,
hd_num=hd_num,
)
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
# also add end-of-assistant token in eos token id to avoid unnecessary generation
eos_token_id = [
self.tokenizer.eos_token_id,
self.tokenizer.convert_tokens_to_ids([FROM_TOKEN_2])[0], # type: ignore
]
outputs = self.generate(
**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
im_mask=im_mask,
infer_mode=infer_mode,
**kwargs,
)
outputs = outputs[0].cpu().tolist()
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
response = response.split(FROM_TOKEN_2)[0]
history = history + [(query, response)]
return response, history
def add_coordinate_tokens(self, bins: int):
start = -bins
end = bins * 2
new_tokens = [int_wrap(idx) for idx in range(start, end + 1)]
new_token_ids = self._add_tokens(new_tokens)
self.coordinate_token_ids.update(new_token_ids)
self.config.discrete_coordinate_tokens = end // 2
def _add_tokens(self, new_tokens: list[str]) -> list[int]:
prev_vocab_size = len(self.tokenizer)
vocab = self.tokenizer.get_vocab()
new_tokens = [token for token in new_tokens if token not in vocab]
self.tokenizer.add_tokens(new_tokens) # type: ignore
self.model.resize_token_embeddings(len(self.tokenizer))
self.vocab_size = len(self.tokenizer)
# self.output needs to be resized accordingly but without loosing the weight
new_output = nn.Linear(
self.model.config.hidden_size,
self.vocab_size,
bias=False,
dtype=self.output.weight.dtype,
device=self.output.weight.device,
).to(self.device)
new_output.weight.data[: self.output.weight.shape[0]] = self.output.weight.data
self.output = new_output
return list(range(prev_vocab_size, self.vocab_size))
def setup_ixclayout(
model_name_or_path: str,
max_length: int,
img_size: int,
use_lora: bool,
gradient_checkpointing: bool,
discrete_custom_tokens: int | None = 128,
) -> IXCLayout | PeftModel:
model = IXCLayout.from_ixc_pretrained(
ixc_pretrained_model_name_or_path=model_name_or_path,
max_length=max_length,
img_size=img_size,
discrete_custom_tokens=discrete_custom_tokens,
)
if use_lora:
lora_args = LoraArguments()
lora_config = LoraConfig(
r=lora_args.lora_r,
lora_alpha=lora_args.lora_alpha,
target_modules=lora_args.lora_target_modules,
lora_dropout=lora_args.lora_dropout,
bias=lora_args.lora_bias,
task_type="CAUSAL_LM", # type: ignore
)
model = get_peft_model(model, lora_config) # type: ignore
model.print_trainable_parameters()
if gradient_checkpointing:
model.enable_input_require_grads()
model.vit.vision_tower.gradient_checkpointing_enable(
{"use_reentrant": True}
)
return model # type: ignore
|