|
|
|
|
|
|
|
""" PyTorch ViT model."""
|
|
|
|
|
|
from functools import partial
|
|
from einops import rearrange
|
|
import torch.nn.functional as F
|
|
|
|
import collections.abc
|
|
import math
|
|
from typing import Dict, List, Optional, Set, Tuple, Union
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
|
from transformers.activations import ACT2FN
|
|
from transformers.modeling_outputs import (
|
|
BaseModelOutput,
|
|
BaseModelOutputWithPooling,
|
|
ImageClassifierOutput,
|
|
MaskedImageModelingOutput,
|
|
BaseModelOutputWithNoAttention,
|
|
ImageClassifierOutputWithNoAttention,
|
|
)
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
from transformers.utils import (
|
|
add_code_sample_docstrings,
|
|
add_start_docstrings,
|
|
add_start_docstrings_to_model_forward,
|
|
logging,
|
|
replace_return_docstrings,
|
|
)
|
|
from .configuration_fdvit import FDViTConfig
|
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.get_logger(__name__)

# Config class name passed to the code-sample docstring decorators below.
_CONFIG_FOR_DOC = "FDViTConfig"

# Checkpoint and expected output used for the base model's generated code sample.
_CHECKPOINT_FOR_DOC = "amd/fdvit_ti"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 260]

# Checkpoint and expected output used for the image-classification code sample.
_IMAGE_CLASS_CHECKPOINT = "amd/fdvit_ti"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
|
|
|
|
|
|
|
|
|
|
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Apply stochastic depth: randomly zero entire samples of ``x``.

    A single Bernoulli draw is made per element of the leading (batch)
    dimension and broadcast over all remaining dimensions, so each sample is
    either kept whole or dropped whole.

    Args:
        x: Input tensor; dimension 0 is treated as the per-sample dimension.
        drop_prob: Probability of dropping a sample.
        training: When ``False`` the input is returned untouched.
        scale_by_keep: When ``True`` (and the keep probability is non-zero),
            surviving samples are divided by the keep probability.

    Returns:
        ``x`` itself when ``drop_prob`` is zero or ``training`` is false,
        otherwise ``x`` multiplied by the per-sample mask.
    """
    if not training or drop_prob == 0.:
        return x

    survival = 1 - drop_prob
    # Shape (batch, 1, 1, ...) so the mask broadcasts across every non-batch dim.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(survival)
    if scale_by_keep and survival > 0.0:
        mask.div_(survival)
    return x * mask
|
|
|
|
|
|
class FDViTDropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (per-sample stochastic depth).

    The drop is only active while the module is in training mode, because the
    module's ``training`` flag is forwarded to ``drop_path``.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Return ``x`` with whole samples randomly dropped (training mode only)."""
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        """Show the drop probability, rounded to three decimals, in the module repr."""
        rounded = round(self.drop_prob, 3)
        return f'drop_prob={rounded:0.3f}'
|
|
|
|
|
|
class FDViTEmbeddings(nn.Module):
    """
    Construct patch embeddings with a single (biased) 2D convolution.

    ``patch_size`` is used as the convolution kernel size; ``stride`` and
    ``padding`` are forwarded to the convolution unchanged.
    """

    def __init__(self, in_channels, out_channels, patch_size, stride, padding):
        super().__init__()
        conv_options = {
            "kernel_size": patch_size,
            "stride": stride,
            "padding": padding,
            "bias": True,
        }
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        """Return the convolved feature map for the image batch ``x``."""
        return self.conv(x)
|
|
|
|
|
|
class FDViTAttention(nn.Module):
    """Multi-head self-attention with an added depthwise-convolution term.

    In addition to the usual ``softmax(q @ k^T * scale) @ v`` product, the
    value tensor is laid out as a square feature map, passed through a 3x3
    depthwise convolution (``get_v``) and the result (``lepe``) is added to
    the attention output before the final projection.

    Args:
        dim: Channel width of the input tokens; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        qk_scale: Optional override for the attention scale; a falsy value
            selects ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        if dim % num_heads != 0:
            # forward() reshapes channels into (num_heads, dim // num_heads); fail early and clearly.
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Depthwise (groups=dim) 3x3 convolution applied to the value feature map.
        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def get_lepe(self, x):
        """Compute the depthwise-convolution term from the per-head value tensor.

        Args:
            x: Tensor of shape ``(B, heads, N, head_dim)`` where ``N`` is a
                perfect square (the tokens are viewed as a ``sqrt(N) x sqrt(N)`` map).

        Returns:
            Tuple ``(x, lepe)``, both of shape ``(B, heads, N, head_dim)``:
            the value tensor after the round trip through the map layout, and
            the convolved map in the same layout.

        Raises:
            ValueError: If ``N`` is not a perfect square.
        """
        B, Head, N, C_p = x.shape
        # Exact integer square root; avoids float rounding from math.sqrt.
        side = math.isqrt(N)
        if side * side != N:
            raise ValueError(
                f"FDViTAttention expects a square number of tokens, got {N}"
            )
        x = x.transpose(-2, -1).contiguous().view(B, C_p * Head, side, side)

        lepe = self.get_v(x)
        lepe = lepe.reshape(B, Head, C_p, N).permute(0, 1, 3, 2).contiguous()
        x = x.reshape(B, Head, C_p, N).permute(0, 1, 3, 2).contiguous()
        return x, lepe

    def forward(self, x):
        """Apply attention to tokens ``x`` of shape ``(B, N, C)`` and return the same shape."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        v, lepe = self.get_lepe(v)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v) + lepe
        # Merge heads back into the channel dimension.
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
|
|
|
|
|
class FDViTOutput(nn.Module):
    """Two-layer feed-forward network: linear, activation, dropout, linear, dropout.

    Args:
        in_features: Width of the input.
        hidden_features: Width of the hidden layer; a falsy value selects ``in_features``.
        out_features: Width of the output; a falsy value selects ``in_features``.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability; the same dropout module is applied after
            the activation and after the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        output_width = out_features or in_features
        hidden_width = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through the feed-forward network."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a residual connection.

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to the attention layer.
        qk_scale: Forwarded to the attention layer.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Dropout probability on the attention weights.
        drop_path: Stochastic-depth probability; when not positive an identity is used.
        act_layer: Activation factory for the MLP.
        norm_layer: Factory taking ``dim`` and returning a normalization module.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = FDViTAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        # One drop-path module is shared by both residual branches.
        if drop_path > 0.:
            self.drop_path = FDViTDropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = FDViTOutput(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Apply the attention branch, then the MLP branch, each added back onto ``x``."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        x = x + self.drop_path(mlp_branch)
        return x
|
|
|
|
|
|
class FDViTLayer(nn.Module):
    """One FDViT stage: a stack of ``depth`` transformer blocks.

    The input feature map is flattened to a token sequence and passed through
    every block in order; the token sequence is returned (it is not reshaped
    back to a feature map).

    Args:
        base_dim: Per-head channel width; the block width is ``base_dim * heads``.
        depth: Number of blocks in the stage.
        heads: Number of attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of the block width.
        drop_rate: Dropout probability forwarded to each block.
        attn_drop_rate: Attention dropout probability forwarded to each block.
        drop_path_prob: Per-block stochastic-depth probabilities (indexed
            ``0 .. depth - 1``); ``None`` means no stochastic depth.
    """

    def __init__(self, base_dim, depth, heads, mlp_ratio,
                 drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):
        super().__init__()
        # Never populated or read in this class; kept so the module's attributes are unchanged.
        self.layers = nn.ModuleList([])
        embed_dim = base_dim * heads

        if drop_path_prob is None:
            drop_path_prob = [0.0 for _ in range(depth)]

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_prob[i],
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
            for i in range(depth)
        ])

    def forward(self, x):
        """Flatten ``x`` from ``(b, c, h, w)`` to ``(b, h*w, c)`` and run all blocks."""
        x = rearrange(x, 'b c h w -> b (h w) c')
        for blk in self.blocks:
            x = blk(x)
        return x
|
|
|
|
|
|
class FDViTPooling(nn.Module):
    """Resample a token sequence to an ``out_size x out_size`` map and change its channel width.

    The tokens are layer-normalized, viewed as a square feature map, sampled
    on a fixed regular grid spanning ``[-1, 1]`` in both axes with
    ``F.grid_sample(align_corners=True)``, and finally passed through a 3x3
    convolution mapping ``in_feature`` to ``out_feature`` channels.

    Args:
        in_feature: Channel width of the incoming tokens.
        out_feature: Channel width of the returned feature map.
        out_size: Side length of the returned feature map.
    """

    def __init__(self, in_feature, out_feature, out_size):
        super().__init__()

        d = torch.linspace(-1, 1, out_size)
        # "ij" is the layout torch.meshgrid produced when `indexing` was omitted;
        # passing it explicitly keeps the grid identical and avoids the deprecation warning.
        meshx, meshy = torch.meshgrid(d, d, indexing="ij")
        # Last dim is (x, y): x varies along the width axis, y along the height axis.
        self.grid = torch.stack((meshy, meshx), 2)

        self.conv = nn.Conv2d(in_feature, out_feature, kernel_size=3,
                              padding=1, stride=1)
        self.ln = nn.LayerNorm(in_feature)

    def forward(self, x):
        """Pool tokens ``x`` of shape ``(b, n, c)`` into a ``(b, out_feature, out_size, out_size)`` map.

        Raises:
            ValueError: If ``n`` is not a perfect square.
        """
        batch, num_tokens, channels = x.shape
        side = math.isqrt(num_tokens)
        if side * side != num_tokens:
            raise ValueError(
                f"FDViTPooling expects a square number of tokens, got {num_tokens}"
            )
        x = self.ln(x)
        # (b, h*w, c) -> (b, c, h, w)
        x = x.transpose(1, 2).reshape(batch, channels, side, side)

        grid = self.grid.expand(batch, -1, -1, -1)
        # The grid is a plain tensor attribute, so it is moved/cast to match x on every call.
        x = F.grid_sample(x, grid.to(x.device).type_as(x), align_corners=True)
        x = self.conv(x)

        return x
|
|
|
|
class FDViTEncoder(nn.Module):
    """FDViT backbone: patch embedding, learned position embedding, then transformer stages.

    Every stage except the last is followed by a pooling layer that feeds the
    next stage. ``forward`` returns the token sequence produced by the final
    stage and, optionally, the token sequences produced by every stage.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        image_size = config.image_size
        patch_size = config.patch_size
        stride = config.stride
        base_dims = config.base_dims
        depth = config.depth
        heads = config.heads
        channels = config.channels
        out_size = config.out_size
        mlp_ratio = config.mlp_ratio

        # Optional settings: a `None` in the config selects the default on the left.
        num_classes = 1000 if config.num_classes is None else config.num_classes
        in_chans = 3 if config.in_chans is None else config.in_chans
        attn_drop_rate = .0 if config.attn_drop_rate is None else config.attn_drop_rate
        drop_rate = .0 if config.drop_rate is None else config.drop_rate
        drop_path_rate = .0 if config.drop_path_rate is None else config.drop_path_rate

        padding = 0
        # Side length of the patch grid produced by the embedding convolution.
        grid_width = math.floor((image_size + 2 * padding - patch_size) / stride + 1)
        first_stage_dim = base_dims[0] * heads[0]

        self.base_dims = base_dims
        self.heads = heads
        self.num_classes = num_classes

        self.patch_size = patch_size
        self.pos_embed = nn.Parameter(
            torch.randn(1, first_stage_dim, grid_width, grid_width),
            requires_grad=True,
        )
        self.patch_embed = FDViTEmbeddings(in_chans, first_stage_dim, patch_size, stride, padding)

        self.pos_drop = nn.Dropout(p=drop_rate)

        self.transformers = nn.ModuleList([])
        self.pools = nn.ModuleList([])
        self.decoders = nn.ModuleList([])

        # Stochastic-depth probability grows linearly with the global block index.
        total_block = sum(depth)
        first_block = 0
        for stage, stage_depth in enumerate(depth):
            stage_drop_path = [
                drop_path_rate * block / total_block
                for block in range(first_block, first_block + stage_depth)
            ]
            first_block += stage_depth

            self.transformers.append(
                FDViTLayer(
                    base_dims[stage],
                    stage_depth,
                    heads[stage],
                    mlp_ratio,
                    drop_rate,
                    attn_drop_rate,
                    stage_drop_path,
                )
            )
            # No pooling layer after the final stage.
            if stage < len(heads) - 1:
                self.pools.append(
                    FDViTPooling(channels[stage], channels[stage + 1], out_size[stage + 1])
                )

        self.embed_dim = base_dims[-1] * heads[-1]

    def forward(self, x, output_hidden_states=False, return_dict=True):
        """Encode an image batch.

        Args:
            x: Image batch fed to the patch embedding.
            output_hidden_states: When true, also collect each stage's token output.
            return_dict: When false, return a plain tuple of the non-``None`` outputs.

        Returns:
            ``BaseModelOutputWithNoAttention`` (or a tuple) holding the last
            stage's tokens and, if requested, a tuple with every stage's tokens.
        """
        stage_outputs = () if output_hidden_states else None

        embedded = self.patch_embed(x)
        x = self.pos_drop(embedded + self.pos_embed)

        # Each pooled stage: transformer tokens -> pooled feature map for the next stage.
        for transformer, pool in zip(self.transformers, self.pools):
            tokens = transformer(x)
            x = pool(tokens)

            if output_hidden_states:
                stage_outputs = stage_outputs + (tokens,)

        x = self.transformers[-1](x)
        if output_hidden_states:
            stage_outputs = stage_outputs + (x,)

        if return_dict:
            return BaseModelOutputWithNoAttention(last_hidden_state=x, hidden_states=stage_outputs)

        return tuple(v for v in [x, stage_outputs] if v is not None)
|
|
|
|
|
|
|
|
|
|
class FDViTPreTrainedModel(PreTrainedModel):
    """
    Abstract base class providing weight initialization and the pretrained-model
    download/load interface for FDViT models.
    """

    config_class = FDViTConfig
    base_model_prefix = "fdvit"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        Layer norms are reset to weight 1 / bias 0. Linear and 2D convolution
        weights are drawn from a normal distribution with standard deviation
        ``config.initializer_range`` and their biases, if any, are zeroed.
        Other module types are left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
|
|
|
|
|
|
FDVIT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`FDViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
|
FDVIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`FDViTImageProcessor.__call__`] for details.
"""
|
|
|
|
|
|
@add_start_docstrings(
    "The bare FDViT Model transformer outputting raw hidden-states without any specific head on top.",
    FDVIT_START_DOCSTRING,
)
class FDViTModel(FDViTPreTrainedModel):
    # Thin wrapper around FDViTEncoder that resolves the output flags from the
    # config and returns the encoder's outputs. (Documented with comments rather
    # than a docstring so the decorator-generated documentation is unchanged.)

    def __init__(self, config: FDViTConfig):
        super().__init__(config)
        self.config = config

        self.encoder = FDViTEncoder(config)

        # Initialize weights and apply final processing.
        self.post_init()

    @add_start_docstrings_to_model_forward(FDVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        # Flags left as None fall back to the values stored on the config.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        encoder_outputs = self.encoder(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            # Tuple layout mirrors the fields of BaseModelOutputWithNoAttention:
            # (last_hidden_state[, hidden_states]). There is no pooled output, so
            # no `None` placeholder is inserted.
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutputWithNoAttention(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
        )
|
|
|
|
|
|
class FDViTPooler(nn.Module):
    """Pool a token sequence by passing its first token through a dense layer and tanh.

    NOTE(review): nothing in the visible part of this file instantiates this
    class, and the classification head mean-pools tokens instead — confirm
    whether first-token pooling is the intended behaviour before relying on it.

    Args:
        config: Configuration object; only ``config.hidden_size`` is read.
    """

    def __init__(self, config: "FDViTConfig"):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``.

        Args:
            hidden_states: Tensor whose dimension 1 indexes tokens and whose
                last dimension has size ``config.hidden_size``.
        """
        # "Pool" by taking the hidden state of the first token. This assignment
        # was missing, which made every call fail with a NameError.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
|
|
|
|
|
|
@add_start_docstrings(
    """
    FDViT Model transformer with an image classification head on top (a linear layer on top of the layer-normalized,
    mean-pooled final hidden states) e.g. for ImageNet.
    """,
    FDVIT_START_DOCSTRING,
)
class FDViTForImageClassification(FDViTPreTrainedModel):
    def __init__(self, config: FDViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.fdvit = FDViTModel(config)

        # Width of the final stage's tokens.
        final_dim = config.base_dims[-1] * config.heads[-1]
        self.norm = nn.LayerNorm(final_dim, eps=1e-6)

        # Classifier head; an identity when `config.num_classes` is not positive.
        self.classifier = nn.Linear(final_dim, config.num_classes) if config.num_classes > 0 else nn.Identity()

        # Initialize weights and apply final processing.
        self.post_init()

    @add_start_docstrings_to_model_forward(FDVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fdvit(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Normalize every token, average over the token dimension, then classify.
        logits = self.classifier(self.norm(sequence_output).mean(dim=1))

        loss = None
        if labels is not None:
            # Move labels to the logits' device (e.g. for model parallelism).
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer the problem type once and cache it on the config.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                # The classifier width comes from `config.num_classes`, which may differ
                # from `num_labels`; reshape by the logits' real width so the batch
                # dimension is never silently folded into the class dimension.
                loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Keep only real outputs so the tuple matches the fields of
            # ImageClassifierOutputWithNoAttention: ([loss,] logits[, hidden_states]).
            output = (logits,) + tuple(v for v in outputs[1:] if v is not None)
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
|
|
|