import logging
from importlib import import_module
from typing import Optional

import torch
import torch.nn as nn
from torchvision.transforms import Compose, Resize, InterpolationMode, Normalize

logger = logging.getLogger(__name__)


class Dinov2Wrapper(nn.Module):
    """
    DINOv2 wrapper using the original implementation, extended with optional modulation.
    """

    def __init__(self, model_name: str, modulation_dim: Optional[int] = None, freeze: bool = True):
        super().__init__()
        self.modulation_dim = modulation_dim
        self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim)
        self.preprocess = Compose([
            Resize(self.model.patch_embed.img_size[0], interpolation=InterpolationMode.BICUBIC),
            # CLIP-style normalization statistics
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        if freeze:
            if modulation_dim is not None:
                raise ValueError("Modulated DINOv2 requires training; freezing is not allowed.")
            self._freeze()

    def _freeze(self):
        logger.warning("======== Freezing Dinov2Wrapper ========")
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def _build_dinov2(model_name: str, modulation_dim: Optional[int] = None, pretrained: bool = True):
        # Resolve the requested constructor from the vendored DINOv2 hub module at runtime.
        dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__)
        model_fn = getattr(dinov2_hub, model_name)
        logger.info(f"Modulation dim for DINOv2 is {modulation_dim}.")
        model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained)
        return model

    # @torch.compile
    def forward(self, image: torch.Tensor, mod: Optional[torch.Tensor] = None):
        # image: [N, H, W, C], values in [0, 255]; permuted to [N, C, H, W] below
        # mod: [N, D] or None
        assert image.shape[-1] == 3
        image = image.permute(0, 3, 1, 2) / 255.
        image = self.preprocess(image)
        if self.modulation_dim is None:
            assert mod is None, "Unexpected modulation input in DINOv2 forward."
            outs = self.model(image, is_training=True)
        else:
            assert mod is not None, "Modulation input is required in modulated DINOv2 forward."
            outs = self.model(image, mod=mod, is_training=True)
        # Prepend the normalized CLS token to the normalized patch tokens.
        ret = torch.cat([
            outs["x_norm_clstoken"].unsqueeze(dim=1),
            outs["x_norm_patchtokens"],
        ], dim=1)
        # ret: [B, 1 + num_patches, D], e.g. [B, 1370, 384] for dinov2_vits14 at 518x518
        return ret
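

# Usage sketch (not part of the original module; assumes the vendored
# `dinov2.hub.backbones` module exposes a `dinov2_vits14` constructor taking
# `modulation_dim` and `pretrained`, and that this file is imported as part of
# its package so the relative import in `_build_dinov2` resolves):
#
#     encoder = Dinov2Wrapper("dinov2_vits14", modulation_dim=None, freeze=True)
#     image = torch.randint(0, 256, (2, 518, 518, 3)).float()  # [N, H, W, C] in [0, 255]
#     with torch.no_grad():
#         tokens = encoder(image)
#     tokens.shape  # torch.Size([2, 1370, 384]): 1 CLS + 37*37 patch tokens, D=384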