# C-RADIOv2-G / cls_token.py
# Uploaded by gheinrich: "Upload model (#1)", commit 2d3bbc7 (verified)
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from typing import Optional
import torch
from torch import nn
class ClsToken(nn.Module):
    """Learnable prefix tokens (class tokens plus optional registers).

    When enabled, the module owns one parameter ``token`` of shape
    ``(num_tokens + num_registers, ndim)``. ``forward`` prepends a copy of it
    to every sample of the input along dimension 1.

    Args:
        ndim: Width of each token (last dimension of ``token``).
        num_tokens: Number of class tokens.
        enabled: If False, no parameter is created and ``forward`` returns
            its input unchanged.
        register_multiple: Used only when ``num_registers`` is not given
            (or is falsy). Adds
            ``register_multiple - (num_tokens % register_multiple)``
            registers so the prefix length is a multiple of it. NOTE: if
            ``num_tokens`` is already a multiple, this adds a full extra
            group of ``register_multiple`` registers, not zero.
        num_registers: Explicit register count. Takes precedence over
            ``register_multiple`` when truthy.

    Attributes:
        num_patches: ``num_tokens + num_registers``, fixed at construction.
            NOTE(review): with ``enabled=False`` this still equals
            ``num_tokens`` although nothing is prepended, and ``disable()``
            does not update it. Confirm callers account for that.
    """

    def __init__(self, ndim: int,
                 num_tokens: int = 1,
                 enabled: bool = True,
                 register_multiple: Optional[int] = None,
                 num_registers: Optional[int] = None,
                 ):
        super().__init__()

        self.ndim = ndim
        self.enabled = enabled
        self.num_tokens = num_tokens

        # Registers only exist when the module is enabled. An explicit count
        # wins over padding up to a multiple.
        extra = 0
        if enabled:
            if num_registers:
                extra = num_registers
            elif register_multiple:
                extra = register_multiple - (num_tokens % register_multiple)
        self.num_registers = extra

        if enabled:
            # Standard-normal init scaled by 1/sqrt(ndim).
            rows = num_tokens + extra
            self.token = nn.Parameter(torch.randn(rows, ndim) * (ndim ** -0.5))
        else:
            self.token = None

        self.num_patches = self.num_tokens + self.num_registers

    def disable(self):
        """Drop the learned tokens so ``forward`` becomes the identity.

        Only ``token`` and ``enabled`` change. ``num_tokens``,
        ``num_registers`` and ``num_patches`` keep their constructed values.
        """
        self.enabled = False
        self.token = None

    def forward(self, x: torch.Tensor):
        """Prepend the learned tokens to ``x`` along dimension 1.

        For the concatenation to succeed, ``x`` must be 3-D with last
        dimension ``ndim``; presumably it is ``(batch, seq, ndim)``.
        The tokens are broadcast over ``x.shape[0]`` with ``expand``, which
        makes a view rather than a copy, before ``torch.cat``. When there is
        no token, ``x`` itself is returned.
        """
        if self.token is None:
            return x

        batch_size = x.shape[0]
        prefix = self.token.unsqueeze(0).expand(batch_size, -1, -1)
        return torch.cat((prefix, x), dim=1)

    def no_weight_decay(self):
        """Return the names of parameters to exempt from weight decay.

        This is presumably read by optimizer-construction code; confirm
        against the caller. The result is always ``['token']``, even after
        ``disable()``.
        """
        names = ['token']
        return names