# --------------------------------------------------------
# Eagle2
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import warnings
from typing import Any, List, Optional, Tuple, Union

import torch.utils.checkpoint
import transformers
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                          LlamaTokenizer)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging
from peft import LoraConfig, get_peft_model

from .configuration_eagle_chat import Eagle2ChatConfig
from .conversation import get_conv_template
from .modeling_siglip import SiglipVisionModel
from .modeling_qwen2 import Qwen2ForCausalLM
from .flash_attention import *
from .multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModel
from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower
from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig
from .siglip_vision_tower import SiglipVisionTower
from .convnext_encoder import ConvNextVisionTower
from .convnext import ConvNeXt

logger = logging.get_logger(__name__)


def version_cmp(v1, v2, op='eq'):
    """Compare two version strings using a function from the ``operator`` module.

    Args:
        v1: Left-hand version string.
        v2: Right-hand version string.
        op: Name of an ``operator`` function, e.g. ``'eq'``, ``'ge'``, ``'le'``.

    Returns:
        ``operator.<op>(version.parse(v1), version.parse(v2))``.
    """
    import operator

    from packaging import version
    op_func = getattr(operator, op)
    return op_func(version.parse(v1), version.parse(v2))


class Eagle2ChatModel(PreTrainedModel):
    """Eagle2 vision-language chat model.

    Images are encoded by ``vision_model``, pixel-shuffled, projected by ``mlp1``
    into the language model's hidden size, and written into the token embeddings
    at every position whose token id equals ``img_context_token_id``.
    """
    config_class = Eagle2ChatConfig
    main_input_name = 'pixel_values'
    _no_split_modules = ['LlamaDecoderLayer']

    def __init__(self, config: Eagle2ChatConfig, vision_model=None, language_model=None):
        """Build the vision tower, language model and vision->LLM projector.

        Args:
            config: Model configuration.
            vision_model: Optional pre-built vision encoder; built from
                ``config.vision_config`` when ``None``.
            language_model: Optional pre-built causal LM; built from
                ``config.llm_config`` when ``None``.

        Raises:
            NotImplementedError: for an unsupported LLM architecture or a
                ``'seq_concat'`` / ``'feat_concat'`` ``moe_version_type``.
        """
        super().__init__(config)

        # Only transformers versions in [4.37.2, 4.39.2] are accepted.
        assert version_cmp(transformers.__version__, '4.37.2', 'ge')
        assert version_cmp(transformers.__version__, '4.39.2', 'le')

        image_size = config.force_image_size or config.vision_config.image_size
        if hasattr(config.vision_config, 'grid_size'):
            grid_size = config.vision_config.grid_size
            self.patch_size = 14
            self.num_image_token = int((grid_size * config.downsample_ratio) ** 2)
        else:
            patch_size = config.vision_config.patch_size
            self.patch_size = patch_size
            self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.select_layer = config.select_layer
        self.template = config.template
        self.downsample_ratio = config.downsample_ratio
        logger.info(f'num_image_token: {self.num_image_token}')

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            if config.vision_config.model_type == 'siglip_vision_model':
                self.vision_model = SiglipVisionModel(config.vision_config)
            elif config.vision_config.model_type.startswith("MOB"):
                self.vision_model = MultiBackboneChannelConcatenationVisionModel(config.vision_config, config)

        if language_model is not None:
            self.language_model = language_model
        else:
            if config.llm_config.architectures[0] == 'LlamaForCausalLM':
                self.language_model = LlamaForCausalLM(config.llm_config)
            elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
                self.language_model = Qwen2ForCausalLM(config.llm_config)
            else:
                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        if vit_hidden_size == 'lazy_calculation':
            # a hack for Mixture of Backbones: the hidden size is only known
            # once the vision model has been built.
            vit_hidden_size = self.vision_model.hidden_size
            print("The lazy calculated hidden_size: {} .. ".format(vit_hidden_size))
        llm_hidden_size = config.llm_config.hidden_size
        self.moe_version_type = getattr(config.vision_config, 'moe_version_type', None)

        if self.moe_version_type in ['seq_concat', 'feat_concat']:
            raise NotImplementedError
        elif self.moe_version_type == 'convnext_512_siglip_448':
            # Here vit_hidden_size is indexed per backbone; the pixel-shuffled SigLIP
            # features are concatenated channel-wise with the ConvNeXt features.
            convnext_hidden_size = vit_hidden_size['convnext']
            siglip_hidden_size = vit_hidden_size['siglip']
            feature_concat_hidden_size = convnext_hidden_size + siglip_hidden_size * int(1 / self.downsample_ratio) ** 2
            self.mlp1 = nn.Sequential(
                nn.LayerNorm(feature_concat_hidden_size),
                nn.Linear(feature_concat_hidden_size, llm_hidden_size),
                nn.GELU(),
                nn.Linear(llm_hidden_size, llm_hidden_size)
            )
        else:
            # pixel_shuffle multiplies the channel count by (1 / downsample_ratio) ** 2.
            self.mlp1 = nn.Sequential(
                nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
                nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
                nn.GELU(),
                nn.Linear(llm_hidden_size, llm_hidden_size)
            )
        # Set by chat()/batch_chat() from the tokenizer; required by generate().
        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            num_patches_list: Optional[List[torch.Tensor]] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training forward pass: splice image features into the text embeddings and run the LM.

        Args:
            pixel_values: Image tiles, or (MoE backbones) a dict that carries a
                ``'num_patches'`` entry.
            input_ids: Token ids; reshaped below as ``(B, N)``.
            image_flags: Optional per-tile flags; only features of tiles whose
                flag equals 1 are kept. ``None`` keeps every tile. A trailing
                singleton dimension (e.g. ``(num_tiles, 1)``) is squeezed.
            labels: Optional targets; when given, a shifted next-token
                cross-entropy loss is computed.
            num_patches_list: Accepted for interface compatibility; unused.

        Returns:
            ``CausalLMOutputWithPast`` or, when ``return_dict`` is false, a tuple
            ``(loss?, logits, *rest)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        if self.moe_version_type in ['seq_concat', 'feat_concat'] and not isinstance(pixel_values, dict):
            raise NotImplementedError
        vit_embeds = self.extract_feature(pixel_values)

        if image_flags is not None:
            if isinstance(image_flags, list):
                image_flags = torch.as_tensor(image_flags, device=vit_embeds.device)
            # Squeeze exactly once, and only a trailing singleton axis, so that a
            # single-tile batch keeps a 1-D mask instead of collapsing to 0-d.
            if image_flags.dim() > 1:
                image_flags = image_flags.squeeze(-1)
            vit_embeds = vit_embeds[image_flags == 1]

        if isinstance(pixel_values, dict):
            # for MOE
            vit_batch_size = sum(pixel_values['num_patches'])
        else:
            vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        # get_rank() raises when no process group exists, so treat a
        # non-distributed run as rank 0 instead of crashing.
        dist = torch.distributed
        is_rank0 = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
        if is_rank0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        try:
            # `* 0.0 +` keeps the text embedding in the autograd graph while replacing its value.
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        except Exception as e:
            # Best effort on a token/feature count mismatch: fill as many
            # placeholder positions as there are selected tokens.
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels (space-to-depth).

        Args:
            x: 4-D tensor unpacked as ``(n, w, h, c)``.
            scale_factor: Spatial scale; ``h * scale_factor``, ``w * scale_factor``
                and ``c / scale_factor ** 2`` are truncated to int, so they are
                expected to divide evenly.

        Returns:
            Tensor of shape ``(n, w * s, h * s, c / s ** 2)`` where ``s = scale_factor``.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        """Encode images and project them to the language model's hidden size.

        Runs the vision model, takes the last hidden state (``select_layer == -1``)
        or ``hidden_states[select_layer]``, drops the first token for backbones
        other than ``SiglipVisionModel`` / ``MultiBackboneChannelConcatenationVisionModel``,
        pixel-shuffles the patch grid and applies ``mlp1``.

        Returns:
            Tensor of shape ``(batch, tokens, llm_hidden_size)``.
        """
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state  # torch.Size([B, 1025, 1024])
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]

        # Exact-type check (not isinstance), matching the original dispatch.
        if type(self.vision_model) == SiglipVisionModel:
            pass
        elif type(self.vision_model) == MultiBackboneChannelConcatenationVisionModel:
            pass
        else:
            vit_embeds = vit_embeds[:, 1:, :]  # torch.Size([B, 1024, 1024])

        # `neftune_alpha` is never assigned in this class; read it defensively so
        # training mode does not raise AttributeError when it was not configured.
        # NOTE(review): `noised_embed` is also not defined in this class; it is
        # presumably attached externally when `neftune_alpha` is set — confirm.
        neftune_alpha = getattr(self, 'neftune_alpha', None)
        if self.training and neftune_alpha is not None:
            vit_embeds = self.noised_embed(vit_embeds, neftune_alpha)

        if self.moe_version_type in ['feat_concat', 'seq_concat']:
            raise NotImplementedError
        elif self.moe_version_type == 'convnext_512_siglip_448':
            siglip_embeds = vit_embeds['siglip']
            convnext_embeds = vit_embeds['convnext']
            h = w = int(siglip_embeds.shape[1] ** 0.5)
            siglip_embeds = siglip_embeds.reshape(siglip_embeds.shape[0], h, w, -1)
            siglip_embeds = self.pixel_shuffle(siglip_embeds, scale_factor=self.downsample_ratio)
            siglip_embeds = siglip_embeds.reshape(siglip_embeds.shape[0], -1, siglip_embeds.shape[-1])
            vit_embeds = self.mlp1(torch.cat([siglip_embeds, convnext_embeds], dim=-1))
        else:
            # Assumes the token count is a perfect square (a square patch grid).
            h = w = int(vit_embeds.shape[1] ** 0.5)
            vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
            vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)  # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
            vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])  # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
            vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='', IMG_END_TOKEN='',
                   IMG_CONTEXT_TOKEN='', verbose=False, image_counts=None):
        """Answer several independent single-turn questions with one ``generate`` call.

        Args:
            tokenizer: Tokenizer; its ``padding_side`` is set to ``'left'``.
            pixel_values: Image tiles for all questions, or ``None``.
            questions: One question string per entry of ``num_patches_list``.
            generation_config: Dict of generation kwargs; ``'eos_token_id'`` is
                overwritten in place.
            num_patches_list: Number of image tiles for each question.
            image_counts: Deprecated alias of ``num_patches_list``.

        Returns:
            List of decoded responses, truncated at the template separator.

        Raises:
            NotImplementedError: if ``history`` or ``return_history`` is given.
        """
        if history is not None or return_history:
            print('Now multi-turn chat is not supported in batch_chat.')
            raise NotImplementedError

        if image_counts is not None:
            num_patches_list = image_counts
            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        queries = []
        for idx, num_patches in enumerate(num_patches_list):
            question = questions[idx]
            # NOTE(review): the image placeholder literal is empty here, so
            # `'' not in question` is always False and `replace('', ...)` inserts
            # at the start of the prompt. Confirm the intended placeholder tag.
            if pixel_values is not None and '' not in question:
                question = '\n' + question
            template = get_conv_template(self.template)
            # Honour a caller-customised system prompt, consistent with chat().
            template.system_message = self.system_message
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('', image_tokens, 1)
            queries.append(query)

        # Left padding so every prompt ends right where generation starts.
        tokenizer.padding_side = 'left'
        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
        generation_config['eos_token_id'] = eos_token_id
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        responses = [response.split(template.sep)[0].strip() for response in responses]
        return responses

    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='', IMG_END_TOKEN='', IMG_CONTEXT_TOKEN='',
             verbose=False, llm_only=False):
        """Single-conversation chat with optional multi-turn history.

        Args:
            tokenizer: Tokenizer used to encode the prompt and decode the answer.
            pixel_values: Image tiles or ``None``; ``len(pixel_values)`` must equal
                ``sum(num_patches_list)``.
            question: The new user turn.
            generation_config: Dict of generation kwargs; ``'eos_token_id'`` is
                overwritten in place.
            history: List of ``(question, answer)`` pairs from earlier turns.
            return_history: Also return the updated history.
            num_patches_list: Tiles per image; defaults to one entry covering all
                of ``pixel_values``.
            llm_only: Replace image placeholders with nothing instead of image tokens.

        Returns:
            The response string, or ``(response, history)`` if ``return_history``.
        """
        # NOTE(review): the image placeholder literal is empty here, so
        # `'' not in question` is always False. Confirm the intended tag.
        if history is None and pixel_values is not None and '' not in question:
            question = '\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

        history = [] if history is None else history
        for (old_question, old_answer) in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        # Replace one placeholder per image, in order.
        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            if llm_only:
                query = query.replace('', '', 1)
            else:
                query = query.replace('', image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        generation_config['eos_token_id'] = eos_token_id
        if self.moe_version_type is not None and self.moe_version_type != 'all_tiling' and self.moe_version_type != 'convnext_512_siglip_448':
            pixel_values = {
                'pixel_values': pixel_values,
                'num_patches': num_patches_list  # num patch of each image.
            }
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep)[0].strip()
        history.append((question, response))
        if return_history:
            return response, history
        else:
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '')
            if verbose:
                print(query_to_print, response)
            return response

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate text, splicing image features into the prompt embeddings.

        Image features come from ``visual_features`` when given, otherwise from
        ``extract_feature(pixel_values)``. Supplying either one enables
        splicing; with neither, plain token embeddings are used.

        Raises:
            AssertionError: if ``img_context_token_id`` is unset, or image
                features are supplied but the prompt has no image-context tokens.
        """
        assert self.img_context_token_id is not None
        if pixel_values is not None or visual_features is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values)

            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs

    def get_input_embeddings(self):
        """Return the language model's input embedding module."""
        return self.language_model.get_input_embeddings()

    def get_output_embeddings(self):
        """Return the language model's output embedding module."""
        return self.language_model.get_output_embeddings()