# designvlm/ixc_layout.py
import copy
import numpy as np
import torch
import torch.utils.checkpoint
from peft.mapping import get_peft_model
from peft.peft_model import PeftModel
from peft.tuners.lora import LoraConfig
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoTokenizer
from transformers.generation.streamers import BaseStreamer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.utils import (
replace_return_docstrings,
)
from designvlm.crello_dataset import int_wrap
from designvlm.loading import LoraArguments
from yosematvlm.configuration_internlm_xcomposer2 import InternLMXcomposer2Config
from yosematvlm.modeling_internlm_xcomposer2 import (
FROM_TOKEN_2,
InternLMXComposer2ForCausalLM,
)
IM_END_TOKEN = 92542
EOS = 2 # </s>
SWITCH_TOKEN_LENGTH = 6  # <|im_end|> itself plus the 5 tokens that follow it: '\n', '<|im_start|>', 'ass', 'istant', '\n'; the model should not learn from these.
MASK_ID = -100
class IXCLayoutConfig(InternLMXcomposer2Config):
def __init__(
self,
vocab_size=103168,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
bias=True,
rope_theta=10000,
rope_scaling=None,
attn_implementation="flash_attention_2",
max_length: int = 16384,
discrete_coordinate_tokens: int | None = None,
**kwargs,
):
self.discrete_coordinate_tokens = discrete_coordinate_tokens
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
hidden_act=hidden_act,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
rms_norm_eps=rms_norm_eps,
use_cache=use_cache,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
bias=bias,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
attn_implementation=attn_implementation,
**kwargs,
)
self.max_length = max_length
self.use_cache = False
@classmethod
def with_internlm_config(
cls,
config: InternLMXcomposer2Config,
discrete_custom_tokens: int | None = None,
max_length: int = 16384,
):
return cls(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
hidden_act=config.hidden_act,
max_position_embeddings=config.max_position_embeddings,
initializer_range=config.initializer_range,
rms_norm_eps=config.rms_norm_eps,
use_cache=config.use_cache,
pad_token_id=config.pad_token_id,
bos_token_id=config.bos_token_id,
eos_token_id=config.eos_token_id,
tie_word_embeddings=config.tie_word_embeddings,
bias=config.bias,
rope_theta=config.rope_theta,
rope_scaling=config.rope_scaling,
attn_implementation=config.attn_implementation,
discrete_coordinate_tokens=discrete_custom_tokens,
max_length=max_length,
)
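# Hedged usage sketch (kept as a comment so nothing runs at import time): building the
# layout config from a base InternLM-XComposer2 config. The checkpoint name below is an
# assumption for illustration, not something this module fixes.
#
#   base_config = AutoConfig.from_pretrained(
#       "internlm/internlm-xcomposer2-4khd-7b", trust_remote_code=True
#   )
#   layout_config = IXCLayoutConfig.with_internlm_config(
#       base_config, discrete_custom_tokens=128, max_length=16384
#   )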
class IXCLayout(InternLMXComposer2ForCausalLM):
config_class = IXCLayoutConfig
def __init__(self, config: IXCLayoutConfig):
super().__init__(config)
self.vit.vision_tower.vision_model.post_layernorm = torch.nn.Identity()
self.tokenizer = AutoTokenizer.from_pretrained(
"yosematvlm",
padding_side="right",
use_fast=False,
trust_remote_code=True,
) # type: ignore
# Add coordinate tokens
self.coordinate_token_ids: set[int] = set()
if config.discrete_coordinate_tokens is not None:
self.add_coordinate_tokens(config.discrete_coordinate_tokens)
self.config = config
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=InternLMXcomposer2Config
)
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
**kwargs,
) -> CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
"""
infer_mode = "base"
return_dict, output_attentions, output_hidden_states = self.or_config(
return_dict=return_dict,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
if "samples" in kwargs:
# Training
samples = kwargs["samples"]
# encode text
text = samples["text_input"]
# encode image
image = samples["image"][0]
bs = len(samples["text_input"][0])
image_nums = []
temp_image = []
for im in image:
if type(im) is list:
image_nums.append(len(im))
temp_image.extend(im)
else:
image_nums.append(1)
temp_image.append(im)
image = temp_image
assert type(image) is list and len(image_nums) == bs
to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
image, text, image_nums
)
inputs_embeds = to_regress_embeds[:, : self.max_length] # type: ignore
attention_mask = attention_mask[:, : self.max_length] # type: ignore
targets = targets[:, : self.max_length]
im_mask = im_mask[:, : self.max_length].bool()
labels = targets # type: ignore
elif inputs_embeds is not None or input_ids is not None:
            im_mask = kwargs.get("im_mask")
            if im_mask is None:
                # fall back to an all-zero (no-image) mask shaped like the inputs
                ref = inputs_embeds if inputs_embeds is not None else input_ids
                im_mask = torch.zeros(ref.shape[:2], device=ref.device)
            im_mask = im_mask.bool()
else:
raise ValueError(
"Either samples, inputs_embeds or input_ids should be provided."
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
im_mask=im_mask,
infer_mode=infer_mode,
)
logits: torch.Tensor = self.output(
outputs.last_hidden_state
).float() # B x L x V
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
ce_loss: torch.Tensor = loss_fct(shift_logits, shift_labels)
assert not ce_loss.isnan().any()
kl_loss = self.coordinate_kl_loss(logits, labels)
assert not kl_loss.isnan().any()
loss = ce_loss + kl_loss
else:
loss = None
        if not return_dict:
            # return_dict was already resolved via or_config above
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss, # type: ignore
logits=logits, # type: ignore
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def coordinate_kl_loss(
self, logits: torch.Tensor, labels: torch.Tensor, eps: float = 1e-9
) -> torch.Tensor:
"""
For coordinate token, calculate the KL loss between the predicted logits and the target labels.
Instead of one-hot vector, we assume that the target labels are the probability distribution of the target token.
The distribution is a discrete gaussian distribution with mean at the target token and variance of 1.
Args:
logits: B x T x V; The predicted logits of the model.
labels: B x T; The target labels of the model.
"""
label_std_dev = 2.0
coordinate_token_ids = torch.tensor(
list(self.coordinate_token_ids), device=labels.device, dtype=labels.dtype
)
assert len(self.coordinate_token_ids) > 0
# Get the mask of the coordinate tokens
is_label_coordinate = torch.isin(
labels,
coordinate_token_ids,
).type_as(labels) # B x T
# Get the target labels of the coordinate tokens
# Range of indices
indices = torch.arange(
0, self.vocab_size, dtype=labels.dtype, device=logits.device
).repeat(labels.shape[0], labels.shape[1], 1) # B x T x V
        # Boolean mask over the vocabulary axis marking which indices are coordinate tokens
        # To reduce memory consumption, we iterate over the sequence length
is_indice_coordinate_token = torch.stack(
[
torch.isin(indices[:, idx], coordinate_token_ids)
for idx in range(indices.size(1))
],
dim=1,
) # B x T x V
total_mask = is_label_coordinate.unsqueeze(-1) * is_indice_coordinate_token
# Create a Gaussian distribution centered at the label index
gauss_label = torch.exp(
-0.5 * ((indices - labels.unsqueeze(-1)) / label_std_dev) ** 2
) # B x T x V
# Apply the mask so that only the coordinate tokens are considered
gauss_label = gauss_label * total_mask + eps # Add eps for numerical stability
# Normalize the distribution
gauss_label /= gauss_label.sum(dim=-1, keepdim=True)
pointwise_kl = (
nn.functional.kl_div(
logits.log_softmax(dim=-1), gauss_label, reduction="none"
).nan_to_num()
* total_mask
)
kl_loss = pointwise_kl.sum(-1).mean()
assert kl_loss > 0
return kl_loss
def interleav_wrap(
self,
img_list: list[torch.Tensor], # V x 1 x 3 x H x W
        text_list_list: list[list[str]],  # 1 x B; a single list holding the batch of prompt strings
image_nums: list[int], # B
):
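        # Behavior sketch (derived from the code below): for a prompt "A<ImageHere>B"
        # paired with one image, the wrapped sequence becomes
        # [BOS + tokens("A"), image embeddings, tokens("B")], with im_mask set to 1 over
        # the image positions and dummy token id -100 stored for those positions.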
temp_embeds = []
temp_im_mask = []
temp_tars = []
# encode_image
if len(img_list) > 0:
img_embeds, img_split = self.vit(
img_list, self.plora_glb_GN, self.plora_sub_GN
)
img_embeds = self.vision_proj(img_embeds)
else:
img_embeds = None
img_split = []
text_list = text_list_list[0]
for idx, text in enumerate(text_list):
image_num = image_nums[idx]
im_id = int(np.sum(image_nums[:idx]))
images = []
for i in range(image_nums[idx]):
st = int(np.sum(img_split[: im_id + i]))
sp = img_split[im_id + i]
temp_img = img_embeds[:, st : st + sp] # type: ignore
images.append(temp_img)
if image_num == 1 and text.find("<ImageHere>") == -1:
text = "<ImageHere>" + text
parts = text.split("<ImageHere>")
wrap_tokens, wrap_embeds, wrap_im_mask = [], [], []
temp_len = 0
need_bos = True
for idx, part in enumerate(parts):
if len(part) > 0:
part_tokens = self.tokenizer(
part,
return_tensors="pt",
padding="longest",
add_special_tokens=need_bos,
).to(self.device) # type: ignore
if need_bos:
need_bos = False
wrap_tokens.append(part_tokens.input_ids)
part_embeds = self.model.tok_embeddings(part_tokens.input_ids)
wrap_embeds.append(part_embeds)
wrap_im_mask.append(
torch.zeros(part_embeds.shape[:2]).to(self.device)
)
temp_len += part_embeds.shape[1]
if idx < image_num:
wrap_embeds.append(images[idx])
wrap_token = (
torch.ones(images[idx].shape[:2], dtype=torch.long).to(
self.device
)
* -100
)
wrap_tokens.append(wrap_token)
wrap_im_mask.append(
torch.ones(images[idx].shape[:2]).to(self.device)
)
temp_len += images[idx].shape[1]
if temp_len > self.max_length:
break
wrap_tokens = torch.cat(wrap_tokens, dim=1)
wrap_embeds = torch.cat(wrap_embeds, dim=1)
wrap_im_mask = torch.cat(wrap_im_mask, dim=1)
wrap_target = self.mask_human_targets(wrap_tokens).to(self.device)
temp_embeds.append(wrap_embeds)
temp_im_mask.append(wrap_im_mask)
temp_tars.append(wrap_target)
temp_max_len = np.max([i.shape[1] for i in temp_embeds])
temp_max_len = min(temp_max_len, self.max_length)
final_input, final_atts, final_tars, final_mask = [], [], [], []
pad = torch.ones([1, 1]) * self.tokenizer.pad_token_id # type: ignore
pad = pad.long().to(self.device)
pad_emb = self.model.tok_embeddings(pad)
for idx in range(len(temp_embeds)):
temp_len = temp_embeds[idx].shape[1]
if temp_len >= temp_max_len:
final_input.append(temp_embeds[idx][:, :temp_max_len])
final_atts.append(
torch.ones(1, temp_max_len).to(wrap_target.dtype).to(self.device)
)
final_tars.append(temp_tars[idx][:, :temp_max_len])
final_mask.append(temp_im_mask[idx][:, :temp_max_len])
else:
final_input.append(
torch.cat(
[
temp_embeds[idx],
pad_emb.repeat(1, temp_max_len - temp_len, 1),
],
dim=1,
)
)
final_atts.append(
torch.cat(
[
torch.ones(1, temp_len),
torch.zeros(1, temp_max_len - temp_len),
],
dim=1,
)
.to(wrap_target.dtype)
.to(self.device)
)
final_tars.append(
torch.cat(
[
temp_tars[idx],
(torch.ones(1, temp_max_len - temp_len) * MASK_ID)
.to(wrap_target.dtype)
.to(self.device),
],
dim=1,
)
)
final_mask.append(
torch.cat(
[
temp_im_mask[idx],
(torch.zeros(1, temp_max_len - temp_len))
.to(wrap_target.dtype)
.to(self.device),
],
dim=1,
)
)
inputs_embeds = torch.cat(final_input, dim=0)
attention_mask = torch.cat(final_atts, dim=0)
targets = torch.cat(final_tars, dim=0)
im_mask = torch.cat(final_mask, dim=0)
return inputs_embeds, attention_mask, targets, im_mask
def mask_human_targets(self, input_ids: torch.Tensor) -> torch.Tensor:
target_batch = []
for bs in range(input_ids.shape[0]):
ids = input_ids[bs]
targets = copy.deepcopy(ids)
end_count = 0
last_eoa = 0
for i, temp_id in enumerate(ids):
                # Counterintuitively, IM_END_TOKEN marks the end of an utterance.
                # Throughout the upstream source, "[UNUSED_TOKEN_145]" corresponds to IM_END_TOKEN.
if temp_id == IM_END_TOKEN:
if end_count % 2 == 0:
targets[last_eoa : i + SWITCH_TOKEN_LENGTH] = MASK_ID
else:
last_eoa = i + 1
end_count += 1
                # EOS and any padding that follows it
                elif temp_id == EOS:
                    # keep the loss on EOS itself, but not on the padding after it
targets[i + 1 :] = MASK_ID
break
target_batch.append(targets.unsqueeze(0))
target_batch = torch.cat(target_batch, dim=0)
return target_batch
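    # Masking sketch (derived from mask_human_targets above): given token ids laid out as
    #   <user turn> <|im_end|> '\n' <|im_start|> 'ass' 'istant' '\n' <assistant turn> <|im_end|> ...
    # the user turn, its <|im_end|>, and the SWITCH_TOKEN_LENGTH - 1 role-switch tokens
    # after it get MASK_ID (even end_count), while the assistant turn and its <|im_end|>
    # keep their labels (odd end_count). After EOS, everything that follows is masked.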
@classmethod
def from_ixc_pretrained(
cls,
ixc_pretrained_model_name_or_path: str,
max_length: int,
img_size: int,
discrete_custom_tokens: int | None = 128,
) -> torch.nn.Module:
r"""
Instantiate a pretrained InternLMXComposer2 model from a pre-trained model configuration.
"""
config: IXCLayoutConfig = AutoConfig.from_pretrained(
ixc_pretrained_model_name_or_path,
trust_remote_code=True,
)
config = IXCLayoutConfig.with_internlm_config(
config,
max_length=max_length,
discrete_custom_tokens=discrete_custom_tokens,
)
model: IXCLayout = super().from_pretrained(
ixc_pretrained_model_name_or_path,
config=config,
torch_dtype=torch.bfloat16,
) # type: ignore
if img_size != 336:
model.vit.resize_pos()
model.vit.requires_grad_(False)
model.vision_proj.requires_grad_(True)
return model
def or_config(
self,
return_dict: bool | None,
output_attentions: bool | None,
output_hidden_states: bool | None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return return_dict, output_attentions, output_hidden_states
@torch.no_grad()
def chat(
self,
query: str,
image: list[tuple[str, str]] | torch.Tensor = [],
hd_num: int = 24,
history: list[tuple[str, str]] = [],
streamer: BaseStreamer | None = None,
max_new_tokens: int = 1024,
do_sample: bool = True,
num_beams: int = 1,
temperature: float = 1.0,
top_p: float = 0.8,
repetition_penalty: float = 1.005,
infer_mode: str = "base",
use_meta: bool = False,
meta_instruction: str = "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n"
"- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by the user such as English and 中文.\n"
"- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively based on the provided image.",
**kwargs,
):
if not use_meta:
meta_instruction = ""
inputs, im_mask, _ = self.interleav_wrap_chat(
query,
image,
history=history,
meta_instruction=meta_instruction,
hd_num=hd_num,
)
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
        # also treat the end-of-assistant token as an EOS token to stop generation early
eos_token_id = [
self.tokenizer.eos_token_id,
self.tokenizer.convert_tokens_to_ids([FROM_TOKEN_2])[0], # type: ignore
]
outputs = self.generate(
**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
im_mask=im_mask,
infer_mode=infer_mode,
**kwargs,
)
outputs = outputs[0].cpu().tolist()
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
response = response.split(FROM_TOKEN_2)[0]
history = history + [(query, response)]
return response, history
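    # Hedged usage sketch for chat(); the query text is an assumption, and the image
    # argument is left empty here because its element format follows the type hint above:
    #
    #   response, history = model.chat(
    #       query="Describe the layout of this design.",
    #       image=[],
    #       do_sample=False,
    #       max_new_tokens=512,
    #   )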
def add_coordinate_tokens(self, bins: int):
start = -bins
end = bins * 2
new_tokens = [int_wrap(idx) for idx in range(start, end + 1)]
new_token_ids = self._add_tokens(new_tokens)
self.coordinate_token_ids.update(new_token_ids)
        self.config.discrete_coordinate_tokens = bins  # end // 2 == bins
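    # Example (derived from add_coordinate_tokens above): with bins=128, the integers
    # -128 .. 256 are each wrapped by int_wrap into a token string, so up to 385 new
    # coordinate tokens are appended to the vocabulary (tokens already present are
    # skipped by _add_tokens below).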
def _add_tokens(self, new_tokens: list[str]) -> list[int]:
prev_vocab_size = len(self.tokenizer)
vocab = self.tokenizer.get_vocab()
new_tokens = [token for token in new_tokens if token not in vocab]
self.tokenizer.add_tokens(new_tokens) # type: ignore
self.model.resize_token_embeddings(len(self.tokenizer))
self.vocab_size = len(self.tokenizer)
        # self.output must be resized accordingly without losing the existing weights
new_output = nn.Linear(
self.model.config.hidden_size,
self.vocab_size,
bias=False,
dtype=self.output.weight.dtype,
device=self.output.weight.device,
).to(self.device)
new_output.weight.data[: self.output.weight.shape[0]] = self.output.weight.data
self.output = new_output
return list(range(prev_vocab_size, self.vocab_size))
def setup_ixclayout(
model_name_or_path: str,
max_length: int,
img_size: int,
use_lora: bool,
gradient_checkpointing: bool,
discrete_custom_tokens: int | None = 128,
) -> IXCLayout | PeftModel:
model = IXCLayout.from_ixc_pretrained(
ixc_pretrained_model_name_or_path=model_name_or_path,
max_length=max_length,
img_size=img_size,
discrete_custom_tokens=discrete_custom_tokens,
)
if use_lora:
lora_args = LoraArguments()
lora_config = LoraConfig(
r=lora_args.lora_r,
lora_alpha=lora_args.lora_alpha,
target_modules=lora_args.lora_target_modules,
lora_dropout=lora_args.lora_dropout,
bias=lora_args.lora_bias,
task_type="CAUSAL_LM", # type: ignore
)
model = get_peft_model(model, lora_config) # type: ignore
model.print_trainable_parameters()
if gradient_checkpointing:
model.enable_input_require_grads()
model.vit.vision_tower.gradient_checkpointing_enable(
{"use_reentrant": True}
)
return model # type: ignore
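# Hedged demo entry point: the checkpoint name, image size, and token budget below are
# assumptions for illustration, not values fixed anywhere in this module; running it also
# requires the local "yosematvlm" tokenizer assets used by IXCLayout.__init__.
if __name__ == "__main__":
    demo_model = setup_ixclayout(
        model_name_or_path="internlm/internlm-xcomposer2-4khd-7b",  # assumed checkpoint
        max_length=16384,
        img_size=560,  # assumed; any value != 336 triggers vit.resize_pos()
        use_lora=True,
        gradient_checkpointing=False,
        discrete_custom_tokens=128,
    )
    print(type(demo_model).__name__)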