# Eagle2-1B / multi_backbone_channel_concatentation_model.py
# Revision 7b7790d ("update"), by Zhiding
# --------------------------------------------------------
# Eagle2
# Copyright (c) 2025 NVIDIA
# Licensed under The Apache License [see LICENSE for details]
# --------------------------------------------------------
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutputWithPooling
from typing import Optional, Tuple, Union
from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower
from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig
class MultiBackboneChannelConcatenationVisionModel(nn.Module):
    """
    A vision model wrapper that concatenates channels from multiple backbones.

    Args:
        config (MultiBackboneChannelConcatenationVisionModelConfig): The configuration for the model.
            Its ``vision_tower``, ``grid_size``, ``convnext_img_size`` and ``normalize_type``
            attributes are forwarded to the vision tower, and the whole object is passed to the
            tower as ``args``.
        raw_config: Passed through unchanged to the vision tower as ``raw_config``.
            NOTE(review): the expected type is defined by the tower, not here — confirm.

    Attributes:
        vision_model (MultiBackboneChannelConcatenationVisionTower): The vision tower that performs
            the channel concatenation.

    Notes:
        **The class is not inherited from the PreTrainedModel in transformers.** It is a plain
        ``nn.Module`` that exposes a few PreTrainedModel-like attributes (``config_class``,
        ``main_input_name``, ``config``, ``dtype``, ``device``), the properties being delegated
        to ``vision_model``.
    """

    config_class = MultiBackboneChannelConcatenationVisionModelConfig
    main_input_name = "pixel_values"

    def __init__(self, config: MultiBackboneChannelConcatenationVisionModelConfig, raw_config):
        super().__init__()
        self.vision_model = MultiBackboneChannelConcatenationVisionTower(
            vision_tower=config.vision_tower,
            args=config,
            grid_size=config.grid_size,
            convnext_img_size=config.convnext_img_size,
            normalize_type=config.normalize_type,
            raw_config=raw_config,
        )

    def get_input_embeddings(self):
        """Return ``get_input_embeddings()`` of the first backbone in the tower.

        NOTE: only ``vision_towers[0]`` is consulted; adjust this if the embeddings of a
        different backbone are needed.
        """
        return self.vision_model.vision_towers[0].get_input_embeddings()

    def forward(
        self,
        pixel_values,
        return_dict: Optional[bool] = True,
        output_hidden_states: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Run the vision tower and wrap its output in a ``BaseModelOutputWithPooling``.

        Args:
            pixel_values: Input images, forwarded unchanged to ``self.vision_model``.
                NOTE(review): expected shape/dtype is defined by the tower — confirm.
            return_dict: Must resolve to ``True``; ``None`` falls back to
                ``self.config.use_return_dict``.
            output_hidden_states: Must be ``False``; hidden states are not supported.

        Returns:
            A ``BaseModelOutputWithPooling`` whose ``last_hidden_state`` is the tower output and
            whose ``pooler_output``, ``hidden_states`` and ``attentions`` are all ``None``.
            (A tuple is never returned despite the ``Union`` annotation.)

        Raises:
            ValueError: If ``return_dict`` does not resolve to ``True`` or
                ``output_hidden_states`` is not ``False``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Explicit checks rather than ``assert``: asserts are stripped under ``python -O``,
        # which would let unsupported options through silently.
        if return_dict is not True:
            raise ValueError("We only support return_dict")
        if output_hidden_states is not False:
            raise ValueError("We do not support output_hidden_states")

        features = self.vision_model(pixel_values)
        # Only the final features are exposed as model outputs.
        return BaseModelOutputWithPooling(
            last_hidden_state=features,
            pooler_output=None,
            hidden_states=None,
            attentions=None,
        )

    @property
    def dummy_feature(self):
        """Delegate to ``self.vision_model.dummy_feature``."""
        return self.vision_model.dummy_feature

    @property
    def dtype(self):
        """Delegate to ``self.vision_model.dtype``."""
        return self.vision_model.dtype

    @property
    def device(self):
        """Delegate to ``self.vision_model.device``."""
        return self.vision_model.device

    @property
    def config(self):
        """Delegate to ``self.vision_model.config`` (read-only; this wrapper stores no config of its own)."""
        return self.vision_model.config

    @property
    def hidden_size(self):
        """Delegate to ``self.vision_model.hidden_size``."""
        return self.vision_model.hidden_size

    @property
    def num_patches(self):
        """Delegate to ``self.vision_model.num_patches``."""
        return self.vision_model.num_patches