from distutils.version import LooseVersion
from types import MethodType
from typing import List, Optional, Tuple, Union
import warnings

import torch
from torch import nn
import torch.nn.functional as F

from timm.models.registry import register_model
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from .forward_intermediates import forward_intermediates
from .input_conditioner import InputConditioner


_has_torch_sdpa = hasattr(F, 'scaled_dot_product_attention')


class PaliGemmaWrapper(nn.Module):
    def __init__(self, vis_model: nn.Module, embed_dim: int):
        super().__init__()

        self.vis_model = vis_model
        # Set from the constructor argument (matches vis_model.embeddings.embed_dim).
        # A read-only `embed_dim` property on this class would make this assignment raise.
        self.embed_dim = embed_dim

    @property
    def patch_size(self):
        return self.vis_model.embeddings.patch_size

    @property
    def blocks(self):
        return self.vis_model.encoder.layers

    def forward(self, x: torch.Tensor):
        outputs = self.vis_model(
            x,
            return_dict=False,
            interpolate_pos_encoding=True,
        )

        features = outputs[0].to(torch.float32)
        summary = features.mean(dim=1)

        return summary, features

    def forward_features(self, x: torch.Tensor):
        return self(x)


def _get_paligemma_model(repo: str, embed_dim: Optional[int] = None, dtype: torch.dtype = torch.bfloat16):
    from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version

    if LooseVersion(tx_version) > LooseVersion('4.44.2'):
        warnings.warn(f'Your transformers version "{tx_version}" is newer than 4.44.2, and PaliGemma support may be broken.')

    extra_args = dict()
    if dtype is not None:
        extra_args['torch_dtype'] = dtype
        # The Hub checkpoints publish revisions named after the dtype (e.g. "bfloat16").
        rev = str(dtype).split('.')[-1]
        extra_args['revision'] = rev

    model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **extra_args)

    vis_model = model.vision_tower.vision_model
    vis_model = PaliGemmaWrapper(vis_model, embed_dim)

    return vis_model


@register_model
def paligemma_896_student(**kwargs):
    model = _get_paligemma_model('google/paligemma-3b-pt-896', embed_dim=1152, dtype=None)
    return model
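

# Illustrative usage sketch (not part of the original module): once this file has been
# imported, the `register_model` decorator above exposes the student through timm's
# registry. Access to 'google/paligemma-3b-pt-896' on the HuggingFace Hub (including any
# required authentication) is assumed; the 896x896 resolution mirrors the checkpoint name.
def _demo_paligemma_student():
    import timm

    model = timm.create_model('paligemma_896_student').eval()
    x = torch.randn(1, 3, 896, 896)
    with torch.no_grad():
        summary, features = model(x)
    # `summary` is the mean-pooled patch embedding; `features` holds the per-patch tokens.
    return summary.shape, features.shape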


# Replacement forward for the DINOv2 attention module that routes through
# torch.nn.functional.scaled_dot_product_attention (fused kernels when available).
def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]

    x = F.scaled_dot_product_attention(
        q, k, v,
        is_causal=False,
        dropout_p=self.attn_drop.p if self.training else 0.,
        scale=self.scale,
    )
    x = x.transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x


def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs):
    if cache_dir:
        torch.hub.set_dir(cache_dir)

    model: nn.Module = torch.hub.load(
        'facebookresearch/dinov2',
        dino_v2_model,
        pretrained=pretrained,
        # **kwargs,
    )

    if _has_torch_sdpa:
        # Monkey-patch every attention module so its forward uses the fused SDPA path.
        for n, m in model.named_modules():
            if n.endswith('.attn'):
                m.forward = MethodType(dv2_sdpa, m)

    return model
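

# Illustrative sketch (not part of the original module): loads a small DINOv2 backbone
# through the helper above and runs a forward pass through the SDPA-patched attention.
# Network access to the 'facebookresearch/dinov2' torch.hub repo is assumed;
# 'dinov2_vits14' is simply a convenient small variant for the check.
def _demo_load_dino_v2_sdpa():
    model = _load_dino_v2('dinov2_vits14', pretrained=False).eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        parts = model.forward_features(x)
    # DINOv2 returns a dict of normalized cls / patch tokens.
    return parts['x_norm_clstoken'].shape, parts['x_norm_patchtokens'].shape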


class DinoWrapper(nn.Module):
    def __init__(self, dino_model: nn.Module):
        super().__init__()
        self.inner = dino_model
        # Wrap the block list in nn.Sequential so it can be called directly in forward_features.
        dino_model.blocks = nn.Sequential(*dino_model.blocks)

    @property
    def embed_dim(self):
        return self.inner.embed_dim

    @property
    def patch_size(self):
        return self.inner.patch_size

    @property
    def num_cls_tokens(self):
        return getattr(self.inner, 'num_tokens', 1)

    @property
    def num_registers(self):
        return getattr(self.inner, 'num_register_tokens', 0)

    @property
    def num_summary_tokens(self):
        return self.num_cls_tokens + self.num_registers

    @property
    def blocks(self):
        return self.inner.blocks

    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        parts = self.inner.forward_features(*args, **kwargs)

        cls_token = parts['x_norm_clstoken']
        features = parts['x_norm_patchtokens']

        return cls_token, features

    def forward_features(self, x: torch.Tensor):
        x = self.inner.prepare_tokens_with_masks(x)
        x = self.inner.blocks(x)
        x_norm = self.inner.norm(x)

        # Summary (cls) token first, then the patch tokens (cls + register tokens are skipped).
        return x_norm[:, 0], x_norm[:, self.num_summary_tokens:]

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner.prepare_tokens_with_masks(x)

    def forward_intermediates(
        self,
        x: torch.Tensor,
        norm: bool = False,
        **kwargs,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        return forward_intermediates(
            self,
            patch_extractor=self.inner.prepare_tokens_with_masks,
            num_summary_tokens=self.num_summary_tokens,
            num_cls_tokens=self.num_cls_tokens,
            norm=self.inner.norm if norm else lambda y: y,
            x=x,
            **kwargs,
        )


def _dino_student(arch: str, **kwargs):
    from . import dinov2_arch

    factory = getattr(dinov2_arch, arch)
    model = factory()
    model = DinoWrapper(model)

    # Inputs are expected in [0, 1]; only ImageNet mean/std normalization is applied.
    conditioner = InputConditioner(
        input_scale=1.0,
        norm_mean=IMAGENET_DEFAULT_MEAN,
        norm_std=IMAGENET_DEFAULT_STD,
    )

    model.input_conditioner = conditioner

    return model


@register_model
def dino_v2_l_student(**kwargs):
    return _dino_student('dinov2_vitl14_reg', **kwargs)


@register_model
def dino_v2_g_student(**kwargs):
    return _dino_student('dinov2_vitg14_reg', **kwargs)
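

# Illustrative sketch (not part of the original module): builds the ViT-L/14 student via
# timm's registry and applies the input conditioner before the forward pass. It assumes the
# sibling `dinov2_arch` module provides the `dinov2_vitl14_reg` constructor that
# `_dino_student` looks up.
def _demo_dino_student():
    import timm

    model = timm.create_model('dino_v2_l_student').eval()
    x = torch.rand(1, 3, 224, 224)      # values in [0, 1]
    x = model.input_conditioner(x)      # ImageNet mean/std normalization
    with torch.no_grad():
        cls_token, features = model(x)
    # 224 / 14 = 16 patches per side -> 256 patch tokens of width `model.embed_dim`.
    return cls_token.shape, features.shape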