# --------------------------------------------------------
# Eagle2
# Copyright (c) 2025 NVIDIA
# Licensed under The Apache License [see LICENSE for details]
# --------------------------------------------------------
import torch, os
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .siglip_vision_tower import SiglipVisionTower
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from copy import deepcopy
import random
import math


class MultiBackboneChannelConcatenationVisionTower(nn.Module):
    def __init__(self,
                 vision_tower,
                 args,
                 grid_size=32,
                 convnext_img_size=1024,
                 normalize_type=None,
                 raw_config=None):
        super().__init__()

        self.is_loaded = False
        self.grid_size = grid_size
        self.num_tokens = self.grid_size ** 2
        self.normalize_type = args.normalize_type
        self.moe_version_type = args.moe_version_type
        self.raw_config = raw_config
        print("moe_version_type: ", self.moe_version_type)
        assert self.moe_version_type in [None, 'all_tiling', 'seq_concat', 'feat_concat', 'convnext_512_siglip_448'], \
            f"Unknown self.moe_version_type: {self.moe_version_type}"

        vision_tower_name_list = vision_tower.split(";")
        self.input_image_size = 1024
        self.convnext_img_size = convnext_img_size
        self.load_vision_towers(vision_tower_name_list, args)

    def load_vision_towers(self, vision_tower_name_list, args):
        self.vision_towers = nn.ModuleList()

        freeze_backbone_list = args.freeze_backbones  # note: this is a str
        if freeze_backbone_list is not None and len(freeze_backbone_list) > 0:
            print("The frozen backbones: ", freeze_backbone_list)
        else:
            # make it a blank str so the `in` checks below are safe
            freeze_backbone_list = ""

        for name in vision_tower_name_list:
            ## ConvNeXt
            if name == 'convnext-1024':
                convnext_args = deepcopy(args)
                convnext_args.freeze_vision = False
                if 'convnext-1024' in freeze_backbone_list:
                    convnext_args.freeze_vision = True
                from .convnext_encoder import ConvNextVisionTower
                convnext_args.input_image_size = self.convnext_img_size
                convnext_vision_tower_path = args.vision_tower_convnext_path
                convnext_vision_tower = ConvNextVisionTower(convnext_vision_tower_path,
                                                            convnext_args,
                                                            delay_load=args.delay_load,
                                                            normalize_type=self.normalize_type)
                convnext_vision_tower.load_model()
                self.vision_towers.append(convnext_vision_tower)

            ## PaliSigLIP
            elif name == 'palisiglip':
                palisiglip_args = deepcopy(args)
                palisiglip_args.input_image_size = 448
                palisiglip_args.freeze_vision = False
                if 'palisiglip' in freeze_backbone_list:
                    palisiglip_args.freeze_vision = True
                palisiglip_vision_tower = SiglipVisionTower(args.vision_tower_siglip_path,
                                                            palisiglip_args,
                                                            delay_load=args.delay_load,
                                                            raw_config=self.raw_config)
                palisiglip_vision_tower.load_model()
                self.vision_towers.append(palisiglip_vision_tower)

        # Set the image processor
        self.image_processor = None
        self.is_loaded = True

    def load_model(self):
        assert self.is_loaded, "All the vision encoders should be loaded during initialization!"
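
    # forward() input/output contract, keyed by moe_version_type (set in __init__):
    #   - None / 'all_tiling':          x is a (B, 3, H, W) Tensor; returns a
    #                                   (B, num_tokens, C_total) Tensor
    #   - 'convnext_512_siglip_448':    x is a (B, 3, H, W) Tensor; returns a dict
    #                                   mapping each tower's name to its raw feature
    #   - 'seq_concat' / 'feat_concat': x is a dict with 'pixel_values' and
    #                                   'num_patches'; returns a dict with
    #                                   'clip_embeds', 'no_tiling_embeds', 'num_patches'
    # In the concatenating paths, every tower's feature map is resampled to a
    # grid_size x grid_size token grid (32**2 = 1024 tokens by default) and the
    # towers are concatenated along the channel dim, so the output channel size is
    # the sum of the towers' hidden sizes (see the hidden_size property).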

    def forward(self, x):
        # x is a Tensor if moe_version_type is None or 'all_tiling',
        # otherwise a dict (see the contract above)
        if self.moe_version_type in [None, 'all_tiling']:
            # The default pipeline
            features = []
            image_input_size = x.shape[2]
            assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
            for vision_tower in self.vision_towers:
                if vision_tower.input_image_size != image_input_size:
                    resized_x = F.interpolate(x.float(),
                                              size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                              mode='bilinear',
                                              align_corners=True).to(dtype=x.dtype)
                else:
                    resized_x = x

                feature = vision_tower(resized_x)
                if len(feature.shape) == 3:  # b, n, c
                    b, n, c = feature.shape
                    if n == self.num_tokens:
                        features.append(feature)
                        continue
                    w = h = int(n ** 0.5)
                    feature = feature.transpose(1, 2).reshape(b, c, h, w)
                else:
                    b, c, h, w = feature.shape

                if w != self.grid_size:
                    feature = F.interpolate(feature.float(),
                                            size=(self.grid_size, self.grid_size),
                                            mode='bilinear',
                                            align_corners=True).to(dtype=x.dtype)
                features.append(feature.flatten(2, 3).transpose(1, 2))

            features = torch.cat(features, dim=-1)

        elif self.moe_version_type == 'convnext_512_siglip_448':
            features = {}
            image_input_size = x.shape[2]
            assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
            for vision_tower in self.vision_towers:
                if vision_tower.input_image_size != image_input_size:
                    resized_x = F.interpolate(x.float(),
                                              size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                              mode='bilinear',
                                              align_corners=True).to(dtype=x.dtype)
                else:
                    resized_x = x
                feature = vision_tower(resized_x)
                features[vision_tower.name] = feature

        else:
            assert isinstance(x, dict), "x is expected to be a dict but {}".format(type(x))
            pixel_values = x['pixel_values']
            # per-sample patch counts; for 'seq_concat' each count also includes the
            # text padding token, hence the -1 below
            num_patches = x['num_patches']

            # calculate the real image patches
            if self.moe_version_type == 'seq_concat':
                image_in_num_patches = [i - 1 for i in num_patches]
            else:
                image_in_num_patches = [i for i in num_patches]

            assert sum(image_in_num_patches) == pixel_values.size(0), \
                "sum(image_in_num_patches) ({}) != pixel_values.size(0) ({})".format(
                    sum(image_in_num_patches), pixel_values.size(0))

            # find the thumbnail image id (the last patch of each sample)
            thumbnail_image_id = torch.cumsum(torch.tensor(image_in_num_patches).to(pixel_values.device), 0) - 1
            image_no_tiling = pixel_values[thumbnail_image_id]

            # By default, we use the 1st vision_tower for all tiles and the other
            # towers for the no-tiling (thumbnail) images only
            features = []
            for layer_id, vision_tower in enumerate(self.vision_towers):
                if layer_id == 0:
                    x = pixel_values
                else:
                    x = image_no_tiling

                if vision_tower.input_image_size != self.input_image_size:
                    resized_x = F.interpolate(x.float(),
                                              size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                              mode='bilinear',
                                              align_corners=True).to(dtype=x.dtype)
                else:
                    resized_x = x

                feature = vision_tower(resized_x)
                if len(feature.shape) == 3:  # b, n, c
                    b, n, c = feature.shape
                    if n == self.num_tokens:
                        features.append(feature)
                        continue
                    w = h = int(n ** 0.5)
                    feature = feature.transpose(1, 2).reshape(b, c, h, w)
                else:
                    b, c, h, w = feature.shape

                if w != self.grid_size:
                    feature = F.interpolate(feature.float(),
                                            size=(self.grid_size, self.grid_size),
                                            mode='bilinear',
                                            align_corners=True).to(dtype=x.dtype)
                features.append(feature.flatten(2, 3).transpose(1, 2))
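
            # Tower 0 saw every tile while the other towers saw only the per-sample
            # thumbnail, so features[0] is tile-level and features[1:] are merged
            # into the no-tiling embedding below.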
            clip_embeds = features[0]
            if len(features) <= 1:
                no_tiling_embeds = None
            else:
                no_tiling_embeds = torch.cat(features[1:], dim=-1)

            if self.moe_version_type == 'feat_concat':
                # concat the thumbnail image features together
                clip_thumbnail_embeds = clip_embeds[thumbnail_image_id]
                if no_tiling_embeds is not None:
                    no_tiling_embeds = torch.cat([clip_thumbnail_embeds, no_tiling_embeds], dim=-1)
                else:
                    no_tiling_embeds = clip_thumbnail_embeds

                # extra patch features (everything except the thumbnails)
                clip_embeds_mask = ~torch.isin(torch.arange(clip_embeds.shape[0]).to(clip_embeds.device),
                                               thumbnail_image_id)
                clip_embeds = clip_embeds[clip_embeds_mask]

            features = {
                'clip_embeds': clip_embeds,
                'no_tiling_embeds': no_tiling_embeds,
                'num_patches': num_patches
            }

        # features is a Tensor for the default pipeline and a dict otherwise
        return features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # use the first tower as the reference for dtype/device
        return next(self.vision_towers[0].parameters()).dtype

    @property
    def device(self):
        return next(self.vision_towers[0].parameters()).device

    @property
    def config(self):
        raise NotImplementedError

    @property
    def hidden_size(self):
        if self.moe_version_type == 'convnext_512_siglip_448':
            res = {}
            for vision_tower in self.vision_towers:
                res[vision_tower.name] = vision_tower.hidden_size
            return res
        else:
            return sum([_.hidden_size for _ in self.vision_towers])

    @property
    def num_patches(self):
        return self.num_tokens
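

if __name__ == "__main__":
    # A minimal, self-contained sketch of the default (moe_version_type in
    # [None, 'all_tiling']) feature pipeline, run on dummy tensors instead of real
    # towers (which would require checkpoint paths). The channel sizes below are
    # illustrative assumptions, not the actual towers' hidden sizes.
    grid_size = 32
    num_tokens = grid_size ** 2

    # One ConvNeXt-style (b, c, h, w) map and one ViT-style (b, n, c) sequence.
    dummy_outputs = [torch.randn(2, 1536, 16, 16), torch.randn(2, num_tokens, 1152)]

    features = []
    for feature in dummy_outputs:
        if feature.dim() == 3:  # (b, n, c): already token-shaped
            b, n, c = feature.shape
            if n == num_tokens:
                features.append(feature)
                continue
            w = h = int(n ** 0.5)  # reshape back to a square map
            feature = feature.transpose(1, 2).reshape(b, c, h, w)
        if feature.shape[-1] != grid_size:  # resample the map to the target grid
            feature = F.interpolate(feature.float(), size=(grid_size, grid_size),
                                    mode='bilinear', align_corners=True)
        features.append(feature.flatten(2, 3).transpose(1, 2))

    # Channel concatenation across towers: 1536 + 1152 = 2688 channels per token.
    print(torch.cat(features, dim=-1).shape)  # torch.Size([2, 1024, 2688])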