File size: 5,691 Bytes
2d3bbc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from distutils.version import LooseVersion
from types import MethodType
from typing import List, Optional, Tuple, Union
import warnings

import torch
from torch import nn
import torch.nn.functional as F

from timm.models.registry import register_model
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from .forward_intermediates import forward_intermediates
from .input_conditioner import InputConditioner

# True when the installed torch exposes ``F.scaled_dot_product_attention``.
# Checked once at import time; the DINOv2 loader uses it to decide whether the
# attention modules' ``forward`` gets swapped for the SDPA-based implementation.
_has_torch_sdpa = hasattr(F, 'scaled_dot_product_attention')


class PaliGemmaWrapper(nn.Module):
    """Adapter that exposes a vision model through a ``(summary, features)`` interface.

    The wrapped model is expected to provide ``embeddings.patch_size``,
    ``embeddings.embed_dim`` and ``encoder.layers``, and to accept the
    ``return_dict`` and ``interpolate_pos_encoding`` keyword arguments when called.

    Args:
        vis_model: The vision model to wrap.
        embed_dim: Explicit embedding width reported by :attr:`embed_dim`. When
            ``None``, the width is read from ``vis_model.embeddings.embed_dim``.
    """

    def __init__(self, vis_model: nn.Module, embed_dim: Optional[int]):
        super().__init__()

        self.vis_model = vis_model
        # ``embed_dim`` is a read-only property on this class, so assigning
        # ``self.embed_dim`` raises AttributeError. Keep the caller-supplied value
        # under a private name and let the property serve it.
        self._embed_dim = embed_dim

    @property
    def patch_size(self):
        """Patch size as reported by the wrapped model's embedding layer."""
        return self.vis_model.embeddings.patch_size

    @property
    def blocks(self):
        """The encoder layers of the wrapped model."""
        return self.vis_model.encoder.layers

    @property
    def embed_dim(self):
        """Embedding width: the explicit value if one was given, else the model's own."""
        if self._embed_dim is not None:
            return self._embed_dim
        return self.vis_model.embeddings.embed_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the wrapped model and return ``(summary, features)``.

        ``features`` is the first output of the wrapped model cast to float32 and
        ``summary`` is its mean over dimension 1.
        """
        outputs = self.vis_model(
            x,
            return_dict=False,
            interpolate_pos_encoding=True,
        )

        # With ``return_dict=False`` the model returns a tuple; the first entry is used.
        features = outputs[0].to(torch.float32)

        summary = features.mean(dim=1)

        return summary, features

    def forward_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Alias for calling the module; returns the same ``(summary, features)`` pair."""
        return self(x)


def _get_paligemma_model(repo: str, embed_dim: int = None, dtype: torch.dtype = torch.bfloat16):
    """Load a PaliGemma checkpoint and return its vision model wrapped in ``PaliGemmaWrapper``.

    Args:
        repo: Repository identifier handed to ``from_pretrained``.
        embed_dim: Embedding width forwarded to ``PaliGemmaWrapper``.
        dtype: When not ``None``, passed as ``torch_dtype`` and also used to pick the
            checkpoint ``revision`` (the dtype name without the ``torch.`` prefix).
            When ``None``, neither argument is passed.

    Returns:
        A ``PaliGemmaWrapper`` around ``model.vision_tower.vision_model``.
    """
    # Imported lazily so that transformers is only required when this loader is used.
    from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version

    newest_known_good = LooseVersion('4.44.2')
    if LooseVersion(tx_version) > newest_known_good:
        warnings.warn(f'Your transformers version "{tx_version}" is higher than 4.44.2, and for whatever reason, PaliGemma might be broken.')

    load_kwargs = {}
    if dtype is not None:
        # e.g. ``torch.bfloat16`` -> ``'bfloat16'``.
        # NOTE(review): assumes the repo has a revision named after the dtype - confirm
        # for dtypes other than the default.
        dtype_name = str(dtype).split('.')[-1]
        load_kwargs['torch_dtype'] = dtype
        load_kwargs['revision'] = dtype_name

    full_model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **load_kwargs)

    # Only the vision model is kept; the rest of the loaded model is not returned.
    return PaliGemmaWrapper(full_model.vision_tower.vision_model, embed_dim)

@register_model
def paligemma_896_student(**kwargs):
    """Registered factory for the ``google/paligemma-3b-pt-896`` vision model.

    Loads with ``embed_dim=1152`` and ``dtype=None``. ``kwargs`` are accepted for
    registry compatibility but are not used.
    """
    return _get_paligemma_model(
        'google/paligemma-3b-pt-896',
        embed_dim=1152,
        dtype=None,
    )


def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

    q, k, v = qkv[0], qkv[1], qkv[2]
    x = F.scaled_dot_product_attention(
        q, k, v,
        is_causal=False,
        dropout_p=self.attn_drop.p if self.training else 0.,
        scale=self.scale,
    )
    x = x.transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs):
    """Load a DINOv2 model from torch hub, patching attention to use SDPA when available.

    Args:
        dino_v2_model: Entry-point name within the ``facebookresearch/dinov2`` hub repo.
        cache_dir: If truthy, set as the torch hub directory before loading.
        pretrained: Forwarded to the hub entry point.
        **kwargs: Accepted but not forwarded to ``torch.hub.load``.

    Returns:
        The loaded module. When ``_has_torch_sdpa`` is true, every submodule whose
        name ends in ``.attn`` has its ``forward`` rebound to ``dv2_sdpa``.
    """
    if cache_dir:
        torch.hub.set_dir(cache_dir)

    model: nn.Module = torch.hub.load('facebookresearch/dinov2', dino_v2_model, pretrained=pretrained)

    if not _has_torch_sdpa:
        return model

    attn_modules = [module for name, module in model.named_modules() if name.endswith('.attn')]
    for attn in attn_modules:
        # Bind per instance so ``self`` inside dv2_sdpa is this attention module.
        attn.forward = MethodType(dv2_sdpa, attn)

    return model

class DinoWrapper(nn.Module):
    """Wrapper around a DINO-style model, stored as ``inner``.

    On construction the inner model's ``blocks`` are re-packed into an
    ``nn.Sequential`` so the whole stack can be run with a single call.
    """

    def __init__(self, dino_model: nn.Module):
        super().__init__()

        # Note: this replaces ``blocks`` on the model object that was passed in.
        dino_model.blocks = nn.Sequential(*dino_model.blocks)
        self.inner = dino_model

    @property
    def embed_dim(self):
        """Embedding width of the inner model."""
        return self.inner.embed_dim

    @property
    def patch_size(self):
        """Patch size of the inner model."""
        return self.inner.patch_size

    @property
    def num_cls_tokens(self):
        """The inner model's ``num_tokens``, or 1 when that attribute is absent."""
        return getattr(self.inner, 'num_tokens', 1)

    @property
    def num_registers(self):
        """The inner model's ``num_register_tokens``, or 0 when that attribute is absent."""
        return getattr(self.inner, 'num_register_tokens', 0)

    @property
    def num_summary_tokens(self):
        """Number of leading non-patch tokens: class tokens plus registers."""
        return self.num_registers + self.num_cls_tokens

    @property
    def blocks(self):
        """The inner model's (sequential) block stack."""
        return self.inner.blocks

    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Call the inner model's ``forward_features`` and return ``(cls_token, patch_tokens)``."""
        out = self.inner.forward_features(*args, **kwargs)
        return out['x_norm_clstoken'], out['x_norm_patchtokens']

    def forward_features(self, x: torch.Tensor):
        """Tokenize ``x``, run the blocks and the final norm, then split the result.

        Returns:
            ``(first_token, patch_tokens)`` where ``patch_tokens`` is everything after
            the first ``num_summary_tokens`` tokens.
        """
        backbone = self.inner
        normed = backbone.norm(backbone.blocks(backbone.prepare_tokens_with_masks(x)))

        summary = normed[:, 0]
        patches = normed[:, self.num_summary_tokens:]
        return summary, patches

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """Return the inner model's token sequence for ``x`` (before any block is run)."""
        return self.inner.prepare_tokens_with_masks(x)

    def forward_intermediates(self,
        x: torch.Tensor,
        norm: bool = False,
        **kwargs,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Delegate to the module-level ``forward_intermediates`` helper.

        Args:
            x: Input tensor.
            norm: When true, the inner model's ``norm`` is handed to the helper;
                otherwise an identity function is passed.
            **kwargs: Forwarded unchanged to the helper.
        """
        def _passthrough(y):
            return y

        norm_fn = self.inner.norm if norm else _passthrough

        # Inside a method body the bare name resolves to the module-level helper,
        # not to this method.
        return forward_intermediates(
            self,
            patch_extractor=self.inner.prepare_tokens_with_masks,
            num_summary_tokens=self.num_summary_tokens,
            num_cls_tokens=self.num_cls_tokens,
            norm=norm_fn,
            x=x,
            **kwargs,
        )


def _dino_student(arch: str, **kwargs):
    """Build a ``DinoWrapper`` student from a factory in the local ``dinov2_arch`` module.

    Args:
        arch: Name of the factory attribute looked up on ``dinov2_arch``; it is
            called with no arguments.
        **kwargs: Accepted but not used.

    Returns:
        The wrapped model with an ``input_conditioner`` attribute attached, built
        from the ImageNet default mean/std and an input scale of 1.0.
    """
    from . import dinov2_arch

    backbone = getattr(dinov2_arch, arch)()
    student = DinoWrapper(backbone)

    student.input_conditioner = InputConditioner(
        input_scale=1.0,
        norm_mean=IMAGENET_DEFAULT_MEAN,
        norm_std=IMAGENET_DEFAULT_STD,
    )

    return student


@register_model
def dino_v2_l_student(**kwargs):
    """Registered factory for the ``dinov2_vitl14_reg`` student architecture."""
    arch = 'dinov2_vitl14_reg'
    return _dino_student(arch, **kwargs)

@register_model
def dino_v2_g_student(**kwargs):
    """Registered factory for the ``dinov2_vitg14_reg`` student architecture."""
    arch = 'dinov2_vitg14_reg'
    return _dino_student(arch, **kwargs)