import copy

import numpy as np
import torch
import torch.utils.checkpoint
from peft.mapping import get_peft_model
from peft.peft_model import PeftModel
from peft.tuners.lora import LoraConfig
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoTokenizer
from transformers.generation.streamers import BaseStreamer
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.utils import (
    replace_return_docstrings,
)

from designvlm.crello_dataset import int_wrap
from designvlm.loading import LoraArguments
from yosematvlm.configuration_internlm_xcomposer2 import InternLMXcomposer2Config
from yosematvlm.modeling_internlm_xcomposer2 import (
    FROM_TOKEN_2,
    InternLMXComposer2ForCausalLM,
)

IM_END_TOKEN = 92542
EOS = 2
# After <|im_end|>, there are 5 additional tokens: ['\n', '<|im_start|>', 'ass', 'istant', '\n'].
# You don't want the model to learn these tokens (see mask_human_targets).
SWITCH_TOKEN_LENGTH = 6
MASK_ID = -100


class IXCLayoutConfig(InternLMXcomposer2Config):
    def __init__(
        self,
        vocab_size=103168,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        bias=True,
        rope_theta=10000,
        rope_scaling=None,
        attn_implementation="flash_attention_2",
        max_length: int = 16384,
        discrete_coordinate_tokens: int | None = None,
        **kwargs,
    ):
        self.discrete_coordinate_tokens = discrete_coordinate_tokens
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            bias=bias,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attn_implementation=attn_implementation,
            **kwargs,
        )
        self.max_length = max_length
        self.use_cache = False

    @classmethod
    def with_internlm_config(
        cls,
        config: InternLMXcomposer2Config,
        discrete_custom_tokens: int | None = None,
        max_length: int = 16384,
    ):
        return cls(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            hidden_act=config.hidden_act,
            max_position_embeddings=config.max_position_embeddings,
            initializer_range=config.initializer_range,
            rms_norm_eps=config.rms_norm_eps,
            use_cache=config.use_cache,
            pad_token_id=config.pad_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
            tie_word_embeddings=config.tie_word_embeddings,
            bias=config.bias,
            rope_theta=config.rope_theta,
            rope_scaling=config.rope_scaling,
            attn_implementation=config.attn_implementation,
            discrete_coordinate_tokens=discrete_custom_tokens,
            max_length=max_length,
        )


class IXCLayout(InternLMXComposer2ForCausalLM):
    config_class = IXCLayoutConfig

    def __init__(self, config: IXCLayoutConfig):
        super().__init__(config)
        self.vit.vision_tower.vision_model.post_layernorm = torch.nn.Identity()
        self.tokenizer = AutoTokenizer.from_pretrained(
            "yosematvlm",
            padding_side="right",
            use_fast=False,
            trust_remote_code=True,
        )  # type: ignore

        # Add coordinate tokens
        self.coordinate_token_ids: set[int] = set()
        if config.discrete_coordinate_tokens is not None:
            self.add_coordinate_tokens(config.discrete_coordinate_tokens)
        self.config = config

    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=InternLMXcomposer2Config
    )
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in
                `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices
                set to `-100` are ignored (masked), the loss is only computed for the tokens with
                labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        infer_mode = "base"
        return_dict, output_attentions, output_hidden_states = self.or_config(
            return_dict=return_dict,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if "samples" in kwargs:
            # Training
            samples = kwargs["samples"]

            # encode text
            text = samples["text_input"]
            # encode image
            image = samples["image"][0]
            bs = len(samples["text_input"][0])

            image_nums = []
            temp_image = []
            for im in image:
                if type(im) is list:
                    image_nums.append(len(im))
                    temp_image.extend(im)
                else:
                    image_nums.append(1)
                    temp_image.append(im)
            image = temp_image
            assert type(image) is list and len(image_nums) == bs

            to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
                image, text, image_nums
            )

            inputs_embeds = to_regress_embeds[:, : self.max_length]  # type: ignore
            attention_mask = attention_mask[:, : self.max_length]  # type: ignore
            targets = targets[:, : self.max_length]
            im_mask = im_mask[:, : self.max_length].bool()
            labels = targets  # type: ignore
        elif inputs_embeds is not None or input_ids is not None:
            im_mask = kwargs["im_mask"]
            if im_mask is None and inputs_embeds is not None:
                im_mask = torch.zeros(inputs_embeds.shape[:2]).to(inputs_embeds.device)
                im_mask = im_mask.bool()
        else:
            raise ValueError(
                "Either samples, inputs_embeds or input_ids should be provided."
            )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            im_mask=im_mask,
            infer_mode=infer_mode,
        )

        logits: torch.Tensor = self.output(
            outputs.last_hidden_state
        ).float()  # B x L x V

        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            ce_loss: torch.Tensor = loss_fct(shift_logits, shift_labels)
            assert not ce_loss.isnan().any()
            kl_loss = self.coordinate_kl_loss(logits, labels)
            assert not kl_loss.isnan().any()
            loss = ce_loss + kl_loss
        else:
            loss = None

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,  # type: ignore
            logits=logits,  # type: ignore
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def coordinate_kl_loss(
        self, logits: torch.Tensor, labels: torch.Tensor, eps: float = 1e-9
    ) -> torch.Tensor:
        """
        For coordinate tokens, compute the KL divergence between the predicted
        distribution and a soft target distribution. Instead of a one-hot vector,
        the target is a discrete Gaussian centered at the target token id with a
        fixed standard deviation (``label_std_dev``), restricted to the coordinate
        part of the vocabulary.

        Args:
            logits: B x T x V; the predicted logits of the model.
            labels: B x T; the target labels of the model.
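
        Illustrative sketch (toy numbers, not taken from the training setup): with
        ``label_std_dev = 2.0``, a label id of 5, and coordinate ids 3..7, the
        unnormalized soft target is ``exp(-0.5 * ((v - 5) / 2.0) ** 2)`` for each
        coordinate id ``v``, i.e. roughly (0.61, 0.88, 1.00, 0.88, 0.61), which
        normalizes to roughly (0.15, 0.22, 0.25, 0.22, 0.15); non-coordinate
        vocabulary entries receive only the ``eps`` mass.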
""" label_std_dev = 2.0 coordinate_token_ids = torch.tensor( list(self.coordinate_token_ids), device=labels.device, dtype=labels.dtype ) assert len(self.coordinate_token_ids) > 0 # Get the mask of the coordinate tokens is_label_coordinate = torch.isin( labels, coordinate_token_ids, ).type_as(labels) # B x T # Get the target labels of the coordinate tokens # Range of indices indices = torch.arange( 0, self.vocab_size, dtype=labels.dtype, device=logits.device ).repeat(labels.shape[0], labels.shape[1], 1) # B x T x V # Indices that are not coordinate tokens are set to 0 # To reduce memory consumption, we iterate over the sequence length is_indice_coordinate_token = torch.stack( [ torch.isin(indices[:, idx], coordinate_token_ids) for idx in range(indices.size(1)) ], dim=1, ) # B x T x V total_mask = is_label_coordinate.unsqueeze(-1) * is_indice_coordinate_token # Create a Gaussian distribution centered at the label index gauss_label = torch.exp( -0.5 * ((indices - labels.unsqueeze(-1)) / label_std_dev) ** 2 ) # B x T x V # Apply the mask so that only the coordinate tokens are considered gauss_label = gauss_label * total_mask + eps # Add eps for numerical stability # Normalize the distribution gauss_label /= gauss_label.sum(dim=-1, keepdim=True) pointwise_kl = ( nn.functional.kl_div( logits.log_softmax(dim=-1), gauss_label, reduction="none" ).nan_to_num() * total_mask ) kl_loss = pointwise_kl.sum(-1).mean() assert kl_loss > 0 return kl_loss def interleav_wrap( self, img_list: list[torch.Tensor], # V x 1 x 3 x H x W text_list_list: list[list[str]], # B x 1 x str image_nums: list[int], # B ): temp_embeds = [] temp_im_mask = [] temp_tars = [] # encode_image if len(img_list) > 0: img_embeds, img_split = self.vit( img_list, self.plora_glb_GN, self.plora_sub_GN ) img_embeds = self.vision_proj(img_embeds) else: img_embeds = None img_split = [] text_list = text_list_list[0] for idx, text in enumerate(text_list): image_num = image_nums[idx] im_id = int(np.sum(image_nums[:idx])) images = [] for i in range(image_nums[idx]): st = int(np.sum(img_split[: im_id + i])) sp = img_split[im_id + i] temp_img = img_embeds[:, st : st + sp] # type: ignore images.append(temp_img) if image_num == 1 and text.find("") == -1: text = "" + text parts = text.split("") wrap_tokens, wrap_embeds, wrap_im_mask = [], [], [] temp_len = 0 need_bos = True for idx, part in enumerate(parts): if len(part) > 0: part_tokens = self.tokenizer( part, return_tensors="pt", padding="longest", add_special_tokens=need_bos, ).to(self.device) # type: ignore if need_bos: need_bos = False wrap_tokens.append(part_tokens.input_ids) part_embeds = self.model.tok_embeddings(part_tokens.input_ids) wrap_embeds.append(part_embeds) wrap_im_mask.append( torch.zeros(part_embeds.shape[:2]).to(self.device) ) temp_len += part_embeds.shape[1] if idx < image_num: wrap_embeds.append(images[idx]) wrap_token = ( torch.ones(images[idx].shape[:2], dtype=torch.long).to( self.device ) * -100 ) wrap_tokens.append(wrap_token) wrap_im_mask.append( torch.ones(images[idx].shape[:2]).to(self.device) ) temp_len += images[idx].shape[1] if temp_len > self.max_length: break wrap_tokens = torch.cat(wrap_tokens, dim=1) wrap_embeds = torch.cat(wrap_embeds, dim=1) wrap_im_mask = torch.cat(wrap_im_mask, dim=1) wrap_target = self.mask_human_targets(wrap_tokens).to(self.device) temp_embeds.append(wrap_embeds) temp_im_mask.append(wrap_im_mask) temp_tars.append(wrap_target) temp_max_len = np.max([i.shape[1] for i in temp_embeds]) temp_max_len = min(temp_max_len, self.max_length) final_input, 
        pad = torch.ones([1, 1]) * self.tokenizer.pad_token_id  # type: ignore
        pad = pad.long().to(self.device)
        pad_emb = self.model.tok_embeddings(pad)

        for idx in range(len(temp_embeds)):
            temp_len = temp_embeds[idx].shape[1]
            if temp_len >= temp_max_len:
                final_input.append(temp_embeds[idx][:, :temp_max_len])
                final_atts.append(
                    torch.ones(1, temp_max_len).to(wrap_target.dtype).to(self.device)
                )
                final_tars.append(temp_tars[idx][:, :temp_max_len])
                final_mask.append(temp_im_mask[idx][:, :temp_max_len])
            else:
                final_input.append(
                    torch.cat(
                        [
                            temp_embeds[idx],
                            pad_emb.repeat(1, temp_max_len - temp_len, 1),
                        ],
                        dim=1,
                    )
                )
                final_atts.append(
                    torch.cat(
                        [
                            torch.ones(1, temp_len),
                            torch.zeros(1, temp_max_len - temp_len),
                        ],
                        dim=1,
                    )
                    .to(wrap_target.dtype)
                    .to(self.device)
                )
                final_tars.append(
                    torch.cat(
                        [
                            temp_tars[idx],
                            (torch.ones(1, temp_max_len - temp_len) * MASK_ID)
                            .to(wrap_target.dtype)
                            .to(self.device),
                        ],
                        dim=1,
                    )
                )
                final_mask.append(
                    torch.cat(
                        [
                            temp_im_mask[idx],
                            (torch.zeros(1, temp_max_len - temp_len))
                            .to(wrap_target.dtype)
                            .to(self.device),
                        ],
                        dim=1,
                    )
                )

        inputs_embeds = torch.cat(final_input, dim=0)
        attention_mask = torch.cat(final_atts, dim=0)
        targets = torch.cat(final_tars, dim=0)
        im_mask = torch.cat(final_mask, dim=0)

        return inputs_embeds, attention_mask, targets, im_mask

    def mask_human_targets(self, input_ids: torch.Tensor) -> torch.Tensor:
        target_batch = []
        for bs in range(input_ids.shape[0]):
            ids = input_ids[bs]
            targets = copy.deepcopy(ids)
            end_count = 0
            last_eoa = 0
            for i, temp_id in enumerate(ids):
                # Counterintuitively, IM_END_TOKEN is the token for the end of utterance.
                # In the whole source code, "[UNUSED_TOKEN_145]" corresponds to IM_END_TOKEN.
                if temp_id == IM_END_TOKEN:
                    if end_count % 2 == 0:
                        targets[last_eoa : i + SWITCH_TOKEN_LENGTH] = MASK_ID
                    else:
                        last_eoa = i + 1
                    end_count += 1
                # eos and following pad
                elif temp_id == EOS:
                    # loss on eos, but not on pad
                    targets[i + 1 :] = MASK_ID
                    break
            target_batch.append(targets.unsqueeze(0))
        target_batch = torch.cat(target_batch, dim=0)
        return target_batch

    @classmethod
    def from_ixc_pretrained(
        cls,
        ixc_pretrained_model_name_or_path: str,
        max_length: int,
        img_size: int,
        discrete_custom_tokens: int | None = 128,
    ) -> torch.nn.Module:
        r"""
        Instantiate a pretrained InternLMXComposer2 model from a pre-trained model configuration.
""" config: IXCLayoutConfig = AutoConfig.from_pretrained( ixc_pretrained_model_name_or_path, trust_remote_code=True, ) config = IXCLayoutConfig.with_internlm_config( config, max_length=max_length, discrete_custom_tokens=discrete_custom_tokens, ) model: IXCLayout = super().from_pretrained( ixc_pretrained_model_name_or_path, config=config, torch_dtype=torch.bfloat16, ) # type: ignore if img_size != 336: model.vit.resize_pos() model.vit.requires_grad_(False) model.vision_proj.requires_grad_(True) return model def or_config( self, return_dict: bool | None, output_attentions: bool | None, output_hidden_states: bool | None, ): output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) return return_dict, output_attentions, output_hidden_states @torch.no_grad() def chat( self, query: str, image: list[tuple[str, str]] | torch.Tensor = [], hd_num: int = 24, history: list[tuple[str, str]] = [], streamer: BaseStreamer | None = None, max_new_tokens: int = 1024, do_sample: bool = True, num_beams: int = 1, temperature: float = 1.0, top_p: float = 0.8, repetition_penalty: float = 1.005, infer_mode: str = "base", use_meta: bool = False, meta_instruction: str = "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n" "- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" "- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by the user such as English and 中文.\n" "- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively based on the provided image.", **kwargs, ): if not use_meta: meta_instruction = "" inputs, im_mask, _ = self.interleav_wrap_chat( query, image, history=history, meta_instruction=meta_instruction, hd_num=hd_num, ) inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)} # also add end-of-assistant token in eos token id to avoid unnecessary generation eos_token_id = [ self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids([FROM_TOKEN_2])[0], # type: ignore ] outputs = self.generate( **inputs, streamer=streamer, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=do_sample, temperature=temperature, top_p=top_p, eos_token_id=eos_token_id, repetition_penalty=repetition_penalty, im_mask=im_mask, infer_mode=infer_mode, **kwargs, ) outputs = outputs[0].cpu().tolist() response = self.tokenizer.decode(outputs, skip_special_tokens=True) response = response.split(FROM_TOKEN_2)[0] history = history + [(query, response)] return response, history def add_coordinate_tokens(self, bins: int): start = -bins end = bins * 2 new_tokens = [int_wrap(idx) for idx in range(start, end + 1)] new_token_ids = self._add_tokens(new_tokens) self.coordinate_token_ids.update(new_token_ids) self.config.discrete_coordinate_tokens = end // 2 def _add_tokens(self, new_tokens: list[str]) -> list[int]: prev_vocab_size = len(self.tokenizer) vocab = self.tokenizer.get_vocab() new_tokens = [token for token in new_tokens if token not in vocab] self.tokenizer.add_tokens(new_tokens) # type: ignore self.model.resize_token_embeddings(len(self.tokenizer)) self.vocab_size = len(self.tokenizer) # self.output needs to 
        # self.output needs to be resized accordingly, but without losing the existing weights.
        new_output = nn.Linear(
            self.model.config.hidden_size,
            self.vocab_size,
            bias=False,
            dtype=self.output.weight.dtype,
            device=self.output.weight.device,
        ).to(self.device)
        # Copy the old rows; rows for newly added tokens keep their fresh initialization.
        new_output.weight.data[: self.output.weight.shape[0]] = self.output.weight.data
        self.output = new_output
        return list(range(prev_vocab_size, self.vocab_size))


def setup_ixclayout(
    model_name_or_path: str,
    max_length: int,
    img_size: int,
    use_lora: bool,
    gradient_checkpointing: bool,
    discrete_custom_tokens: int | None = 128,
) -> IXCLayout | PeftModel:
    model = IXCLayout.from_ixc_pretrained(
        ixc_pretrained_model_name_or_path=model_name_or_path,
        max_length=max_length,
        img_size=img_size,
        discrete_custom_tokens=discrete_custom_tokens,
    )
    if use_lora:
        lora_args = LoraArguments()
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type="CAUSAL_LM",  # type: ignore
        )
        model = get_peft_model(model, lora_config)  # type: ignore
        model.print_trainable_parameters()

    if gradient_checkpointing:
        model.enable_input_require_grads()
        model.vit.vision_tower.gradient_checkpointing_enable(
            {"use_reentrant": True}
        )
    return model  # type: ignore
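

# Minimal usage sketch (illustrative only): the checkpoint path and hyperparameters
# below are placeholders/assumptions, not values prescribed by this module.
if __name__ == "__main__":
    layout_model = setup_ixclayout(
        model_name_or_path="internlm/internlm-xcomposer2-vl-7b",  # placeholder checkpoint
        max_length=16384,
        img_size=336,
        use_lora=True,
        gradient_checkpointing=False,
        discrete_custom_tokens=128,
    )
    print(type(layout_model).__name__)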