File size: 9,608 Bytes
a5f8592 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 |
import torch, os
import torch.nn as nn
from timm import create_model
from transformers import CLIPImageProcessor
from .convnext import convnext_xxlarge
from torch.utils.checkpoint import checkpoint
import torch
from torchvision import transforms as T
from PIL import Image
# CLIP per-channel (R, G, B) normalization statistics.
MEAN_CLIP = [0.48145466, 0.4578275, 0.40821073]
STD_CLIP = [0.26862954, 0.26130258, 0.27577711]

# SigLIP per-channel normalization statistics ((x - 0.5) / 0.5 maps [0, 1] to [-1, 1]).
MEAN_SLIP = [0.5, 0.5, 0.5]
STD_SLIP = [0.5, 0.5, 0.5]

# Keyword arguments for transformers.CLIPImageProcessor (used by ConvNextFPNVisionTower).
# The mean/std lists are copies so mutating cfg never alters the CLIP constants.
cfg = dict(
    crop_size=256,
    do_center_crop=True,
    do_normalize=True,
    do_resize=True,
    feature_extractor_type="CLIPFeatureExtractor",
    image_mean=list(MEAN_CLIP),
    image_std=list(STD_CLIP),
    resample=3,  # PIL's BICUBIC resampling enum value
    size=256,
)

# Per-channel affine coefficients converting a SigLIP-normalized tensor into a
# CLIP-normalized one:  x_clip = a * x_slip + b
#   a = std_slip / std_clip
#   b = (mean_slip - mean_clip) / std_clip
a = [std_s / std_c for std_s, std_c in zip(STD_SLIP, STD_CLIP)]
b = [
    (mean_s - mean_c) / std_c
    for mean_s, mean_c, std_c in zip(MEAN_SLIP, MEAN_CLIP, STD_CLIP)
]
class SlipToClipTransform:
    """Apply the per-channel affine map ``x * a + b`` to a channel-first tensor.

    Built with the module-level ``a``/``b`` coefficients, it converts a
    SigLIP-normalized image tensor into a CLIP-normalized one.

    Args:
        a: per-channel scale factors.
        b: per-channel offsets.
    """

    def __init__(self, a, b):
        # Reshape to (C, 1, 1) so the coefficients broadcast over trailing H, W dims.
        self.a = torch.tensor(a).reshape(-1, 1, 1)
        self.b = torch.tensor(b).reshape(-1, 1, 1)

    def __call__(self, x_slip):
        """Return ``x_slip * a + b`` with the coefficients moved to ``x_slip``'s device."""
        target = x_slip.device
        scale = self.a.to(target)
        shift = self.b.to(target)
        return x_slip * scale + shift
# Shared SigLIP -> CLIP normalization converter built from the module-level
# coefficients; ConvNextVisionTower applies it when normalize_type == 'siglip'.
slip_to_clip = SlipToClipTransform(a, b)
class ConvNextVisionTower(nn.Module):
    """ConvNeXt image encoder wrapper that exposes per-stage token features.

    Args:
        vision_tower: model name or checkpoint path. Names containing
            ``'xxlarge'`` build ``convnext_xxlarge``; otherwise an existing
            file path is loaded with ``torch.load``.
        args: config object; reads ``freeze_vision``, ``input_image_size``,
            ``mm_vision_select_layer`` and optionally ``mm_vision_select_feature``.
        delay_load: for the xxlarge model, construct it with
            ``pretrained=False`` instead of passing ``vision_tower`` to the builder.
        normalize_type: ``'siglip'`` converts SigLIP-normalized inputs to CLIP
            normalization before the forward pass; any other value leaves
            inputs untouched.
    """

    def __init__(self, vision_tower, args, delay_load=False, normalize_type=None):
        super().__init__()
        self.is_loaded = False
        self.freeze_vision = args.freeze_vision
        self.input_image_size = args.input_image_size
        self.vision_tower_name = vision_tower
        self.name = 'convnext'
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.pre_norm = normalize_type
        print('pre_norm: ', self.pre_norm)
        self.delay_load = delay_load
        self.load_model()

    def load_model(self):
        """Build or load the backbone, cast it to bfloat16 and configure gradients.

        Raises:
            NotImplementedError: if the name is not an xxlarge variant and is
                not an existing file path.
        """
        if 'xxlarge' in self.vision_tower_name:
            if self.delay_load:
                self.vision_tower = convnext_xxlarge(pretrained=False)
            else:
                self.vision_tower = convnext_xxlarge(self.vision_tower_name)
            setattr(self.vision_tower, 'hidden_size', 3072)
        elif os.path.exists(self.vision_tower_name):
            # NOTE(review): loads a whole pickled module -- only use trusted files.
            # Newer torch versions default torch.load to weights_only=True, which
            # may reject full-module checkpoints.
            self.vision_tower = torch.load(self.vision_tower_name)
        else:
            # Explicit raise instead of `assert False` so the check survives `python -O`.
            raise NotImplementedError(
                f'Unsupported vision tower: {self.vision_tower_name!r}')
        self.vision_tower = self.vision_tower.to(torch.bfloat16)
        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)
        # Gradient checkpointing is switched on for every stage unconditionally.
        for s in self.vision_tower.stages:
            s.grad_checkpointing = True
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick stage outputs: the last four if ``select_layer > 100``, else the last one."""
        if self.select_layer > 100:
            image_features = image_forward_outs[-4:]
        else:
            image_features = image_forward_outs[-1]
        return image_features

    def forward_features(self, x):
        """Run the stem and every stage, returning each stage output as tokens.

        Each stage output of shape (B, C, H, W) becomes a (B, H*W, C) tensor.
        """
        x = self.vision_tower.stem(x)
        image_forward_out = []
        for blk in self.vision_tower.stages:
            x = blk(x)
            # flatten(2) equals view(b, c, -1) for contiguous tensors, but also
            # handles non-contiguous stage outputs where view() would raise.
            image_forward_out.append(x.flatten(2).transpose(1, 2))
        return image_forward_out

    def forward(self, images):
        """Encode ``images``; runs under ``torch.no_grad()`` when the tower is frozen."""
        if self.freeze_vision:
            with torch.no_grad():
                image_features = self._forward_images(images)
        else:
            image_features = self._forward_images(images)
        return image_features

    def _maybe_slip_to_clip(self, image):
        """Convert SigLIP to CLIP normalization when ``pre_norm == 'siglip'``."""
        if self.pre_norm != 'siglip':
            return image
        dtype = image.dtype
        # Apply the affine map in float32, then cast back to the input dtype.
        return slip_to_clip(image.to(torch.float32)).to(dtype)

    def _forward_images(self, images):
        """Encode a batched tensor, or a list of single unbatched images one by one."""
        if type(images) is list:
            image_features = []
            for image in images:
                image = self._maybe_slip_to_clip(image)
                image_forward_out = self.forward_features(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_features.append(self.feature_select(image_forward_out))
            return image_features
        images = self._maybe_slip_to_clip(images)
        image_forward_outs = self.forward_features(
            images.to(device=self.device, dtype=self.dtype))
        return self.feature_select(image_forward_outs)

    @property
    def dummy_feature(self):
        """Zero tensor of shape (1, hidden_size) on the tower's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the backbone's first parameter."""
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        """Device of the backbone's first parameter."""
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        """Not implemented for this tower; always returns None.

        Kept as ``None`` (rather than raising) so existing callers that probe
        ``.config`` keep working.
        """
        return None

    @property
    def num_attention_heads(self):
        # Constant placeholder value.
        return 16

    @property
    def num_layers(self):
        # Constant placeholder value.
        return 4

    @property
    def hidden_size(self):
        """Channel width recorded on the backbone (3072 for the xxlarge builder)."""
        return self.vision_tower.hidden_size

    @property
    def num_patches(self):
        """Token count of the final stage for a square ``input_image_size`` input.

        ConvNeXt has no ``patch_embed``; this assumes the stem plus stages
        downsample by 32 overall (stage strides x4/x8/x16/x32) -- TODO confirm
        for checkpoints loaded through ``torch.load``.
        """
        return (self.input_image_size // 32) ** 2
class ConvNextFPNVisionTower(nn.Module):
    """ConvNeXt backbone that returns raw (B, C, H, W) stage feature maps for an FPN.

    Args:
        vision_tower: model name; ``'-fpn'`` is rewritten to ``'fpn'``.
        args: config object; reads ``mm_vision_select_layer`` (stored, no effect)
            and optionally ``frozen_backbone`` (default True) and
            ``mm_vision_select_feature``.
        fpn_target_level: accepted for API compatibility; currently unused.
        fpn_layer_idx: stage indices returned by ``forward``
            ([1, 2, 3] correspond to strides x8, x16, x32).
        fpn_input_dim: channel count of each selected stage; exposed as
            ``hidden_size`` for the xxlarge backbone.
        delay_load: stored but not consulted; the model is always loaded.
    """

    def __init__(self,
                 vision_tower,
                 args,
                 fpn_target_level=1,
                 fpn_layer_idx=(1, 2, 3),
                 fpn_input_dim=(768, 1536, 3072),
                 delay_load=False):
        super().__init__()
        self.is_loaded = False
        self.vision_tower_name = vision_tower.replace('-fpn', 'fpn')
        self.freeze_vision = getattr(args, "frozen_backbone", True)
        # Input resolution is hard-coded rather than read from args.
        self.input_image_size = 1024
        self.select_layer = args.mm_vision_select_layer  # no effect
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.need_fpn = True
        # Copy the sequences so instances never share (or mutate) a caller's list.
        self.fpn_layer_idx = list(fpn_layer_idx)  # [1, 2, 3] -> x8, x16, x32
        # Honor the caller's fpn_input_dim instead of silently hard-coding it.
        self.fpn_input_dim = list(fpn_input_dim)
        self.delay_load = delay_load
        self.load_model()

    def load_model(self):
        """Build the backbone and image processor once; later calls are no-ops."""
        if self.is_loaded:
            return
        self.image_processor = CLIPImageProcessor(**cfg)
        if 'xxlarge' in self.vision_tower_name:
            self.vision_tower = convnext_xxlarge(self.vision_tower_name)
            setattr(self.vision_tower, 'hidden_size', self.fpn_input_dim)
        else:
            # NOTE(review): `convnext_large_mlp` is not imported in this file's
            # import block, so this branch raises NameError unless it is
            # provided elsewhere.
            self.vision_tower = convnext_large_mlp(self.vision_tower_name)
            setattr(self.vision_tower, 'hidden_size', 1536)
        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)
        # Gradient checkpointing is switched on for every stage unconditionally.
        for s in self.vision_tower.stages:
            s.grad_checkpointing = True
        if self.input_image_size is not None:
            # Make the processor resize/crop to the tower's input resolution.
            self.image_processor.size = self.input_image_size
            self.image_processor.crop_size = {
                'height': self.input_image_size,
                'width': self.input_image_size
            }
        self.is_loaded = True

    @torch.no_grad()
    def forward_features(self, x):
        """Run the stem and every stage, returning all stage outputs unchanged."""
        x = self.vision_tower.stem(x)
        image_forward_out = []
        for blk in self.vision_tower.stages:
            x = blk(x)
            image_forward_out.append(x)
        return image_forward_out

    def _select_fpn_levels(self, stage_outs):
        """Pick the stage outputs listed in ``fpn_layer_idx``."""
        return [stage_outs[idx] for idx in self.fpn_layer_idx]

    @torch.no_grad()
    def forward(self, images):
        """Return the feature maps of the stages listed in ``fpn_layer_idx``.

        Args:
            images: a batched tensor, or a list of single unbatched image tensors.

        Returns:
            For a tensor: a list with one feature map per selected stage.
            For a list: one such list per image (each map has batch size 1).
        """
        if type(images) is list:
            image_features = []
            for image in images:
                stage_outs = self.forward_features(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                # Select levels per image; indexing the per-image list with
                # stage indices would pick images instead of stages.
                image_features.append(self._select_fpn_levels(stage_outs))
            return image_features
        stage_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
        return self._select_fpn_levels(stage_outs)

    @property
    def dummy_feature(self):
        """Zero tensor of shape (1, hidden_size) on the tower's device and dtype."""
        # NOTE(review): for the xxlarge backbone hidden_size is a list
        # (fpn_input_dim), which torch.zeros does not accept as a size here.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the backbone's first parameter."""
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        """Device of the backbone's first parameter."""
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        """Not implemented for this tower; always returns None.

        Kept as ``None`` (rather than raising) so existing callers that probe
        ``.config`` keep working.
        """
        return None

    @property
    def num_attention_heads(self):
        # Constant placeholder value.
        return 16

    @property
    def num_layers(self):
        # Constant placeholder value.
        return 4

    @property
    def hidden_size(self):
        """Channel width(s) recorded on the backbone."""
        return self.vision_tower.hidden_size

    @property
    def num_patches(self):
        """Token count of the coarsest (stride-32) level for a square input.

        ConvNeXt has no ``patch_embed`` and ``cfg`` has no ``'image_size'``;
        this assumes an overall stride of 32 for the last stage -- TODO confirm
        which level downstream code expects.
        """
        return (self.input_image_size // 32) ** 2
if __name__ == '__main__':
    # Sanity check: a SigLIP-normalized tensor can be turned into the
    # CLIP-normalized tensor with a single per-channel affine transform.
    normalize_clip = T.Normalize(mean=MEAN_CLIP, std=STD_CLIP)
    normalize_siglip = T.Normalize(mean=MEAN_SLIP, std=STD_SLIP)
    # Inverse of normalize_siglip: x = y * std + mean == (y - (-mean / std)) / (1 / std).
    denormalize_siglip = T.Normalize(
        mean=[-m / s for m, s in zip(MEAN_SLIP, STD_SLIP)],
        std=[1.0 / s for s in STD_SLIP],
    )
    # Merged SigLIP -> CLIP normalization. T.Normalize computes (y - mean) / std,
    # while the desired map is y * a + b with a = s_slip / s_clip and
    # b = (m_slip - m_clip) / s_clip, so std = 1 / a and mean = -b / a.
    COMBINED_STD = [s_clip / s_slip for s_slip, s_clip in zip(STD_SLIP, STD_CLIP)]
    COMBINED_MEAN = [(m_clip - m_slip) / s_slip
                     for m_slip, m_clip, s_slip in zip(MEAN_SLIP, MEAN_CLIP, STD_SLIP)]
    combined_normalize = T.Normalize(mean=COMBINED_MEAN, std=COMBINED_STD)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1, 3, 256, 256, device=device)
    # Distinct names so the module-level coefficients `a` and `b` are not clobbered.
    x_clip = normalize_clip(x).to(torch.bfloat16)
    x_siglip = normalize_siglip(x).to(torch.bfloat16)
    x_restored = denormalize_siglip(x_siglip.to(torch.float32))
    x_clip_roundtrip = normalize_clip(x_restored).to(torch.bfloat16)
    x_clip_combined = combined_normalize(x_siglip.to(torch.float32)).to(torch.bfloat16)
    x_clip_affine = slip_to_clip(x_siglip.to(torch.float32)).to(torch.bfloat16)
    # All differences should be at bfloat16 rounding level.
    print((x_restored - x).abs().max())
    print((x_clip_roundtrip - x_clip).abs().max())
    print((x_clip_combined - x_clip).abs().max())
    print((x_clip_affine - x_clip).abs().max())