import os
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .siglip_vision_tower import SiglipVisionTower

# from .hr_clip_encoder import HRCLIPVisionTower
# from .eva_vit import EVAVITVisionTower
# from .SAM.modeling_sam import SAMVisionTower
# from .pix2struct_large import Pix2StructLargeVisionTower
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from copy import deepcopy
import random
import math

class MultiBackboneChannelConcatenationVisionTower(nn.Module):
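    """Wraps several vision backbones (e.g. ConvNeXt and SigLIP), resizes each
    backbone's feature map to a common `grid_size x grid_size` token grid in the
    default pipeline, and concatenates the per-token features along the channel
    dimension. How `forward` packs its output depends on `args.moe_version_type`
    (see the assertion in `__init__` for the supported values)."""
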
    def __init__(self,
                 vision_tower,
                 args,
                 grid_size=32,
                 convnext_img_size=1024,
                 normalize_type=None, raw_config=None):
        
        super().__init__()

        self.is_loaded = False
        self.grid_size = grid_size
        self.num_tokens = self.grid_size ** 2
        self.normalize_type = args.normalize_type
        self.moe_version_type = args.moe_version_type
        self.raw_config = raw_config
        print("moe_version_type: ", self.moe_version_type)
        assert self.moe_version_type in [None, 'all_tiling', 'seq_concat', 'feat_concat', 'convnext_512_siglip_448'], f"Unknown self.moe_version_type: {self.moe_version_type}"
        
        vision_tower_name_list = vision_tower.split(";")
        self.input_image_size = 1024
        self.convnext_img_size = convnext_img_size
        self.load_vision_towers(vision_tower_name_list, args)

      
    def load_vision_towers(self, vision_tower_name_list, args):
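        """Instantiate one tower per entry of `vision_tower_name_list`
        ('convnext-1024' and 'palisiglip' are supported); any tower whose name
        appears in `args.freeze_backbones` is created with `freeze_vision=True`."""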
        self.vision_towers = nn.ModuleList()

        freeze_backbone_list = args.freeze_backbones # note this is a str
        if freeze_backbone_list is not None and len(freeze_backbone_list) > 0:
            print("The frozen backbones: ", freeze_backbone_list)
        else:
            # make it a blank str
            freeze_backbone_list = ""

        for name in vision_tower_name_list:
            
            ## ConvNeXt
            if name == 'convnext-1024':
                convnext_args = deepcopy(args)

                convnext_args.freeze_vision = False
                if 'convnext-1024' in freeze_backbone_list:
                    convnext_args.freeze_vision = True

                from .convnext_encoder import ConvNextVisionTower
                convnext_args.input_image_size = self.convnext_img_size
                convnext_vision_tower = args.vision_tower_convnext_path
                convnext_vision_tower = ConvNextVisionTower(convnext_vision_tower, 
                                                                convnext_args, delay_load=args.delay_load, normalize_type=self.normalize_type)
                convnext_vision_tower.load_model()      
                self.vision_towers.append(convnext_vision_tower)

            ## PaliSigLIP
            elif name == 'palisiglip':
                palisiglip_args = deepcopy(args)
                palisiglip_args.input_image_size = 448

                palisiglip_args.freeze_vision = False
                if 'palisiglip' in freeze_backbone_list:
                    palisiglip_args.freeze_vision = True

                palisiglip_vision_tower = SiglipVisionTower(args.vision_tower_siglip_path, palisiglip_args, delay_load=args.delay_load, raw_config=self.raw_config)     
   
                palisiglip_vision_tower.load_model()
                self.vision_towers.append(palisiglip_vision_tower)

        # Set the image processor
        self.image_processor = None
        self.is_loaded = True

    def load_model(self):
        assert self.is_loaded, "All the vision encoders should be loaded during initialization!"

    def forward(self, x):
        # x is a Tensor when moe_version_type is None, 'all_tiling' or
        # 'convnext_512_siglip_448'; otherwise it is a dict with
        # 'pixel_values' and 'num_patches'
        if self.moe_version_type in [None, 'all_tiling']:
            # The default pipeline
            features = []
            image_input_size = x.shape[2]
            assert x.shape[2] == x.shape[3], f"Image should be square but got size ({x.shape[2]} x {x.shape[3]})"
            for vision_tower in self.vision_towers:
        
                if vision_tower.input_image_size != image_input_size:
                    resized_x = F.interpolate(x.float(), 
                                            size=(vision_tower.input_image_size, vision_tower.input_image_size), 
                                            mode='bilinear', 
                                            align_corners=True).to(dtype=x.dtype)
                else:
                    resized_x = x
                
                feature = vision_tower(resized_x)
                
                if len(feature.shape) == 3: # b, n, c
                    b, n, c = feature.shape
                    if n == self.num_tokens:
                        features.append(feature)
                        continue
                    w = h = int(n**0.5)
                    feature = feature.transpose(1,2).reshape(b, c, h, w)
                else:
                    b, c, h, w = feature.shape

                if w != self.grid_size:
                    feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
                features.append(feature.flatten(2,3).transpose(1,2))
            
            features = torch.cat(features, dim=-1)
        elif self.moe_version_type == 'convnext_512_siglip_448':
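            # 'convnext_512_siglip_448': keep each backbone's features separate,
            # keyed by tower name, instead of concatenating along the channel dim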
            features = {}
            image_input_size = x.shape[2]
            assert x.shape[2] == x.shape[3], f"Image should be square but got size ({x.shape[2]} x {x.shape[3]})"
            for vision_tower in self.vision_towers:
        
                if vision_tower.input_image_size != image_input_size:
                    resized_x = F.interpolate(x.float(), 
                                            size=(vision_tower.input_image_size, vision_tower.input_image_size), 
                                            mode='bilinear', 
                                            align_corners=True).to(dtype=x.dtype)
                else:
                    resized_x = x
                
                feature = vision_tower(resized_x)
                
                # if len(feature.shape) == 3: # b, n, c
                #     b, n, c = feature.shape
                #     if n == self.num_tokens:
                #         features.append(feature)
                #         continue
                #     w = h = int(n**0.5)
                #     feature = feature.transpose(1,2).reshape(b, c, h, w)
                # else:
                #     b, c, h, w = feature.shape
                features[vision_tower.name] = feature

        else:
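            # 'seq_concat' / 'feat_concat': x is a dict carrying all tiled image
            # patches plus the per-sample patch counts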
            assert isinstance(x, dict), "x is expected to be a dict but got {}".format(type(x))
            pixel_values = x['pixel_values']
            num_patches = x['num_patches'] # per-sample patch counts (for 'seq_concat' this includes the padding token in the text)

            # compute the real number of image patches per sample
            if self.moe_version_type == 'seq_concat':
                image_in_num_patches = [i-1 for i in num_patches]
            else:
                image_in_num_patches = [i for i in num_patches]


            assert sum(image_in_num_patches) == pixel_values.size(0), "sum(image_in_num_patches) ({}) != pixel_values.size(0) ({})".format(sum(image_in_num_patches), pixel_values.size(0))

            # find the thumbnail image ids (the last patch of each sample)
            thumbnail_image_id = torch.cumsum(torch.tensor(image_in_num_patches).to(pixel_values.device), 0) - 1
            image_no_tiling = pixel_values[thumbnail_image_id]

            # By default, the 1st vision_tower processes all tiled patches (x),
            # while the remaining towers only process the no-tiling thumbnails
            features = []
            for layer_id, vision_tower in enumerate(self.vision_towers):
                if layer_id == 0:
                    x = pixel_values
                else:
                    x = image_no_tiling

                if vision_tower.input_image_size != self.input_image_size:
                    resized_x = F.interpolate(x.float(), 
                                            size=(vision_tower.input_image_size, vision_tower.input_image_size), 
                                            mode='bilinear', 
                                            align_corners=True).to(dtype=x.dtype)
                else:
                    resized_x = x
                
                feature = vision_tower(resized_x)
                if len(feature.shape) == 3: # b, n, c
                    b, n, c = feature.shape
                    if n == self.num_tokens:
                        features.append(feature)
                        continue

                    w = h = int(n**0.5)
                    feature = feature.transpose(1,2).reshape(b, c, h, w)
                else:
                    b, c, h, w = feature.shape

                if w != self.grid_size:
                    feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
                features.append(feature.flatten(2,3).transpose(1,2))

            clip_embeds = features[0]
            if len(features) <= 1:
                no_tiling_embeds = None
            else:
                no_tiling_embeds = torch.cat(features[1:], dim=-1)

            if self.moe_version_type == 'feat_concat':
                # concat the thumbnail CLIP features with the no-tiling features
                clip_thumbnail_embeds = clip_embeds[thumbnail_image_id]
                if no_tiling_embeds is not None:
                    no_tiling_embeds = torch.cat([clip_thumbnail_embeds, no_tiling_embeds], dim=-1)
                else:
                    no_tiling_embeds = clip_thumbnail_embeds

                # keep only the extra (non-thumbnail) patch features
                clip_embeds_mask = ~torch.isin(torch.arange(clip_embeds.shape[0]).to(clip_embeds.device), thumbnail_image_id)
                clip_embeds = clip_embeds[clip_embeds_mask]
            

            features = {
                    'clip_embeds': clip_embeds, 
                    'no_tiling_embeds': no_tiling_embeds,
                    'num_patches': num_patches
                }

        # features is a Tensor for the None / 'all_tiling' pipelines and a dict otherwise

        return features
        
    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # use the first tower as the reference for dtype/device
        return next(self.vision_towers[0].parameters()).dtype

    @property
    def device(self):
        return next(self.vision_towers[0].parameters()).device

    @property
    def config(self):
        raise NotImplementedError

    @property
    def hidden_size(self):
        if self.moe_version_type == 'convnext_512_siglip_448':
            res = {}
            for vision_tower in self.vision_towers:
                res[vision_tower.name] = vision_tower.hidden_size
            return res
        else:
            return sum([_.hidden_size for _ in self.vision_towers])

    @property
    def num_patches(self):
        return self.num_tokens
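

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the training pipeline).
    # The attribute names below are the ones this module reads from `args`; the
    # checkpoint paths are placeholders, and the tower implementations may require
    # additional attributes, so treat this as a sketch rather than a ready-to-run script.
    from types import SimpleNamespace

    args = SimpleNamespace(
        normalize_type=None,                        # forwarded to ConvNextVisionTower
        moe_version_type=None,                      # default channel-concat pipeline
        freeze_backbones='convnext-1024',           # freeze ConvNeXt, leave SigLIP trainable
        delay_load=False,
        vision_tower_convnext_path='/path/to/convnext',  # placeholder checkpoint path
        vision_tower_siglip_path='/path/to/siglip',      # placeholder checkpoint path
    )
    tower = MultiBackboneChannelConcatenationVisionTower(
        vision_tower='convnext-1024;palisiglip',    # tower names are split on ';'
        args=args,
        grid_size=32,
    )
    images = torch.randn(2, 3, 1024, 1024)          # inputs must be square
    feats = tower(images)                           # (2, grid_size**2, sum of tower hidden sizes)
    print(feats.shape)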