|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
|
|
from transformers.modeling_outputs import BaseModelOutputWithPooling |
|
from typing import Optional, Tuple, Union |
|
|
|
from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower |
|
from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig |
|
|
|
|
|
class MultiBackboneChannelConcatenationVisionModel(nn.Module):
    """
    A vision model wrapper that concatenates channels from multiple backbones.

    Args:
        config (MultiBackboneChannelConcatenationVisionModelConfig): The configuration
            for the model. Its ``vision_tower``, ``grid_size``, ``convnext_img_size``
            and ``normalize_type`` fields are forwarded to the vision tower, and the
            whole object is also passed to the tower as ``args``.
        raw_config: Passed through unchanged to the vision tower as ``raw_config``.
            NOTE(review): the expected type is not visible here -- confirm against
            ``MultiBackboneChannelConcatenationVisionTower``.

    Attributes:
        vision_model (MultiBackboneChannelConcatenationVisionTower): The vision tower
            that performs the channel concatenation.

    Notes:
        **This class does not inherit from ``transformers.PreTrainedModel``.** It only
        exposes a few attributes with the same names (``config_class``,
        ``main_input_name``, ``get_input_embeddings``, ``config``, ``dtype``,
        ``device``). Everything except the two class attributes is delegated to
        ``vision_model``. In particular, the ``config`` passed to ``__init__`` is not
        stored; ``self.config`` returns ``self.vision_model.config``.
    """

    config_class = MultiBackboneChannelConcatenationVisionModelConfig
    main_input_name = "pixel_values"

    def __init__(self, config: MultiBackboneChannelConcatenationVisionModelConfig, raw_config):
        super().__init__()
        # The constructor config is only used to build the tower; it is not kept on
        # ``self`` (the read-only ``config`` property reads the tower's config instead).
        self.vision_model = MultiBackboneChannelConcatenationVisionTower(
            vision_tower=config.vision_tower,
            args=config,
            grid_size=config.grid_size,
            convnext_img_size=config.convnext_img_size,
            normalize_type=config.normalize_type,
            raw_config=raw_config,
        )

    def get_input_embeddings(self):
        """Return the input embeddings of the first backbone in ``vision_model.vision_towers``."""
        return self.vision_model.vision_towers[0].get_input_embeddings()

    def forward(
        self,
        pixel_values,
        return_dict: Optional[bool] = True,
        output_hidden_states: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """
        Run the vision tower on ``pixel_values`` and wrap the result.

        Args:
            pixel_values: Input passed directly to ``self.vision_model``.
                NOTE(review): presumably a batched image tensor -- confirm the
                expected shape against the vision tower.
            return_dict: Only ``True`` is supported. ``None`` falls back to
                ``self.config.use_return_dict``, which must then be ``True``.
            output_hidden_states: Hidden states are never produced, so only
                ``False`` is supported. ``None`` (the transformers "unspecified"
                value) is treated the same as ``False``.

        Returns:
            BaseModelOutputWithPooling: ``last_hidden_state`` holds the tower's
            output; ``pooler_output``, ``hidden_states`` and ``attentions`` are
            always ``None``.

        Raises:
            ValueError: If an unsupported ``return_dict`` or
                ``output_hidden_states`` value is requested.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Explicit checks rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently ignore unsupported options.
        if return_dict is not True:
            raise ValueError("We only support return_dict")
        if output_hidden_states is not None and output_hidden_states is not False:
            raise ValueError("We do not support output_hidden_states")

        features = self.vision_model(pixel_values)

        return BaseModelOutputWithPooling(
            last_hidden_state=features,
            pooler_output=None,
            hidden_states=None,
            attentions=None,
        )

    @property
    def dummy_feature(self):
        """Delegate to ``vision_model.dummy_feature``."""
        return self.vision_model.dummy_feature

    @property
    def dtype(self):
        """Delegate to ``vision_model.dtype``."""
        return self.vision_model.dtype

    @property
    def device(self):
        """Delegate to ``vision_model.device``."""
        return self.vision_model.device

    @property
    def config(self):
        """Return the vision tower's config (not the config passed to ``__init__``)."""
        return self.vision_model.config

    @property
    def hidden_size(self):
        """Delegate to ``vision_model.hidden_size``."""
        return self.vision_model.hidden_size

    @property
    def num_patches(self):
        """Delegate to ``vision_model.num_patches``."""
        return self.vision_model.num_patches
|
|