FDViT_ti / modeling_fdvit.py
yttdebaba's picture
Upload 4 files
6b9cc7b verified
raw
history blame
19.7 kB
# Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License") and the MIT License (the "License2");
""" PyTorch FDViT model."""
from functools import partial
from einops import rearrange
import torch.nn.functional as F
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
MaskedImageModelingOutput,
BaseModelOutputWithNoAttention,
ImageClassifierOutputWithNoAttention,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_fdvit import FDViTConfig
# Logger for this module, obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)

# ---------------------------------------------------------------------------
# Values substituted into the auto-generated code samples of the forward()
# docstrings (used by the add_code_sample_docstrings decorators in this file).
# ---------------------------------------------------------------------------
_CONFIG_FOR_DOC = "FDViTConfig"

# Bare model (FDViTModel): checkpoint and expected output shape for the sample.
_CHECKPOINT_FOR_DOC = "amd/fdvit_ti"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 260]

# Classification model (FDViTForImageClassification): checkpoint and expected label.
_IMAGE_CLASS_CHECKPOINT = "amd/fdvit_ti"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Randomly zero whole samples of ``x`` (stochastic depth / "drop path").

    Meant for the main branch of a residual block. Each sample along the first
    dimension is independently kept or zeroed, with drop probability
    ``drop_prob``. The name "drop path" is used instead of "DropConnect",
    which refers to a different kind of dropout.

    Args:
        x: Input tensor; dimension 0 is treated as the batch dimension.
        drop_prob: Probability of dropping each sample.
        training: When False, ``x`` is returned unchanged.
        scale_by_keep: If True, surviving samples are divided by the keep
            probability so the output's expectation equals ``x``.

    Returns:
        ``x`` itself when inactive; otherwise a new tensor of the same shape.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every remaining dimension,
    # so this works for tensors of any rank.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
class FDViTDropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    The module's ``training`` flag is forwarded, so the op is a pass-through in eval mode.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        shown = round(self.drop_prob, 3)
        return f'drop_prob={shown:0.3f}'
class FDViTEmbeddings(nn.Module):
    """Patch embedding implemented as one strided ``nn.Conv2d`` (with bias).

    The convolution output is returned as-is (no flattening into a token sequence).
    """

    def __init__(self, in_channels, out_channels, patch_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=patch_size,
            stride=stride,
            padding=padding,
            bias=True,
        )

    def forward(self, x):
        return self.conv(x)
class FDViTAttention(nn.Module):
    """Multi-head self-attention plus a convolutional positional term ("lepe").

    The value tensor is laid out as a square feature map, and a depthwise 3x3
    convolution is applied to it. The result is added to the attention output
    before the final projection.

    Args:
        dim: Token width; split evenly across ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        qk_scale: Optional logit scale; ``head_dim ** -0.5`` is used when falsy.
        attn_drop: Dropout on the attention probabilities.
        proj_drop: Dropout after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        # The scale may be set manually for compatibility with older weights.
        self.scale = qk_scale or per_head ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Depthwise conv (groups == dim) that produces the positional term from v.
        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def get_lepe(self, x):
        """Return ``(values, lepe)`` for per-head values ``x`` of shape (B, heads, N, head_dim).

        ``N`` is assumed to be a perfect square (tokens of a square map).
        ``values`` holds the same numbers as ``x``; ``lepe`` has the same shape.
        """
        batch, heads, tokens, head_dim = x.shape
        side = int(math.sqrt(tokens))
        # (B, heads, N, d) -> (B, heads * d, side, side); channels grouped by head.
        feature_map = x.transpose(-2, -1).contiguous().view(batch, head_dim * heads, side, side)
        lepe = self.get_v(feature_map)

        def to_tokens(t):
            # (B, heads * d, side, side) -> (B, heads, N, d)
            return t.reshape(batch, heads, head_dim, tokens).permute(0, 1, 3, 2).contiguous()

        return to_tokens(feature_map), to_tokens(lepe)

    def forward(self, x):
        """Attend over tokens ``x`` of shape (B, N, dim); returns the same shape."""
        batch, tokens, width = x.shape
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, width // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        v, lepe = self.get_lepe(v)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        probs = self.attn_drop(scores.softmax(dim=-1))
        out = (probs @ v) + lepe
        out = out.transpose(1, 2).reshape(batch, tokens, width)
        return self.proj_drop(self.proj(out))
class FDViTOutput(nn.Module):
    """Two-layer MLP: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: Input width.
        hidden_features: Hidden width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Activation class, instantiated with no arguments.
        drop: Dropout probability (the same module is used after both stages).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Block(nn.Module):
    """Pre-norm transformer block with stochastic depth on both residual branches.

    forward computes ``x + DropPath(Attn(LN1(x)))`` followed by
    ``x + DropPath(MLP(LN2(x)))``. The same drop-path module is used for both
    branches, and it is an ``nn.Identity`` when ``drop_path`` is not positive.

    Args:
        dim: Token width.
        num_heads: Attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: Bias on the fused q/k/v projection.
        qk_scale: Optional attention logit scale override.
        drop: Dropout for the attention projection and the MLP.
        attn_drop: Dropout on attention probabilities.
        drop_path: Stochastic-depth rate.
        act_layer: MLP activation class.
        norm_layer: Normalisation layer factory called with ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = FDViTAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = FDViTDropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = FDViTOutput(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        attended = self.attn(self.norm1(x))
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x))
        return x + self.drop_path(transformed)
class FDViTLayer(nn.Module):
    """One FDViT stage: ``depth`` transformer blocks of width ``base_dim * heads``.

    Args:
        base_dim: Per-head channel width.
        depth: Number of blocks in the stage.
        heads: Number of attention heads per block.
        mlp_ratio: MLP hidden-width multiplier for each block.
        drop_rate: Dropout used for each block's attention projection and MLP.
        attn_drop_rate: Dropout on attention probabilities.
        drop_path_prob: Per-block stochastic-depth rates (indexable, at least
            ``depth`` entries). ``None`` means 0.0 for every block.
    """

    def __init__(self, base_dim, depth, heads, mlp_ratio,
                 drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):
        super().__init__()
        # Not used by forward(); kept so the module's attributes/repr are unchanged.
        self.layers = nn.ModuleList([])
        embed_dim = base_dim * heads
        if drop_path_prob is None:
            drop_path_prob = [0.0] * depth
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_prob[i],
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
            for i in range(depth)
        ])

    def forward(self, x):
        """Flatten a ``(b, c, h, w)`` map into ``(b, h*w, c)`` tokens and run all blocks.

        The token sequence is returned; it is not folded back into a 2-D map.
        """
        x = rearrange(x, 'b c h w -> b (h w) c')
        for blk in self.blocks:
            x = blk(x)
        return x
class FDViTPooling(nn.Module):
    """Resample a square token map to ``out_size x out_size`` and project its channels.

    Tokens are layer-normalised and folded back into a ``(b, c, h, w)`` map.
    The map is resampled with ``F.grid_sample`` (bilinear, ``align_corners=True``)
    on a fixed uniform grid spanning [-1, 1] in both axes, then passed through
    a 3x3 convolution.

    Args:
        in_feature: Channel width of the incoming tokens.
        out_feature: Channel width after the convolution.
        out_size: Side length of the output feature map.
    """

    def __init__(self, in_feature, out_feature, out_size):
        super().__init__()
        coords = torch.linspace(-1, 1, out_size)
        # indexing="ij" is the previous implicit default; passing it explicitly
        # avoids torch.meshgrid's deprecation warning.
        rows, cols = torch.meshgrid(coords, coords, indexing="ij")
        # grid_sample reads (x, y) = (width, height) coordinates from the last dim.
        # Non-persistent buffer: follows .to()/.cuda() with the module (no per-call
        # host->device copy) but is not stored in or expected from the state dict.
        self.register_buffer("grid", torch.stack((cols, rows), 2), persistent=False)
        self.conv = nn.Conv2d(in_feature, out_feature, kernel_size=3,
                              padding=1, stride=1)
        self.ln = nn.LayerNorm(in_feature)

    def forward(self, x):
        """Map tokens ``(b, n, in_feature)`` to ``(b, out_feature, out_size, out_size)``.

        Raises:
            ValueError: if ``n`` is not a perfect square.
        """
        batch, tokens, channels = x.shape
        side = math.isqrt(tokens)
        if side * side != tokens:
            raise ValueError(f"FDViTPooling expects a square number of tokens, got {tokens}")
        x = self.ln(x)
        # (b, h*w, c) -> (b, c, h, w); equivalent to einops 'b (h w) c -> b c h w'.
        x = x.transpose(1, 2).reshape(batch, channels, side, side)
        grid = self.grid.to(device=x.device, dtype=x.dtype).expand(batch, -1, -1, -1)
        x = F.grid_sample(x, grid, align_corners=True)
        return self.conv(x)
class FDViTEncoder(nn.Module):
    """FDViT backbone: patch embedding, learned positional embedding, staged transformers.

    Stage ``s`` is an ``FDViTLayer``. Between consecutive stages an
    ``FDViTPooling`` resamples the token map to ``config.out_size[s + 1]`` and
    maps ``config.channels[s]`` -> ``config.channels[s + 1]`` channels.

    Config fields read: image_size, patch_size, stride, base_dims, depth, heads,
    channels, out_size, mlp_ratio. The fields num_classes (1000), in_chans (3),
    attn_drop_rate, drop_rate and drop_path_rate (0.0) fall back to the values
    in parentheses when they are ``None``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        def _or_default(value, default):
            # Only a missing (None) value falls back; other falsy values are kept.
            return default if value is None else value

        image_size = config.image_size
        patch_size = config.patch_size
        stride = config.stride
        base_dims = config.base_dims
        depth = config.depth
        heads = config.heads
        channels = config.channels
        out_size = config.out_size
        mlp_ratio = config.mlp_ratio
        num_classes = _or_default(config.num_classes, 1000)
        in_chans = _or_default(config.in_chans, 3)
        attn_drop_rate = _or_default(config.attn_drop_rate, .0)
        drop_rate = _or_default(config.drop_rate, .0)
        drop_path_rate = _or_default(config.drop_path_rate, .0)

        total_blocks = sum(depth)
        padding = 0
        # Side length of the patch grid produced by the strided conv embedding.
        grid_side = math.floor((image_size + 2 * padding - patch_size) / stride + 1)

        self.base_dims = base_dims
        self.heads = heads
        self.num_classes = num_classes
        self.patch_size = patch_size

        stem_dim = base_dims[0] * heads[0]
        self.pos_embed = nn.Parameter(
            torch.randn(1, stem_dim, grid_side, grid_side),
            requires_grad=True,
        )
        self.patch_embed = FDViTEmbeddings(in_chans, stem_dim, patch_size, stride, padding)
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.transformers = nn.ModuleList([])
        self.pools = nn.ModuleList([])
        # Not used by forward(); kept as an (empty) attribute.
        self.decoders = nn.ModuleList([])

        block_offset = 0
        for stage, stage_depth in enumerate(depth):
            # Stochastic-depth rate grows linearly with the global block index.
            rates = [drop_path_rate * idx / total_blocks
                     for idx in range(block_offset, block_offset + stage_depth)]
            block_offset += stage_depth
            self.transformers.append(
                FDViTLayer(base_dims[stage], stage_depth, heads[stage],
                           mlp_ratio,
                           drop_rate, attn_drop_rate, rates)
            )
            if stage < len(heads) - 1:
                self.pools.append(
                    FDViTPooling(channels[stage], channels[stage + 1], out_size[stage + 1])
                )

        self.embed_dim = base_dims[-1] * heads[-1]

    def forward(self, x, output_hidden_states=False, return_dict=True):
        """Run the backbone on pixel tensor ``x``.

        ``last_hidden_state`` is the token sequence produced by the last stage.
        When ``output_hidden_states`` is truthy, ``hidden_states`` holds each
        pooled stage's token output (taken before its pooling step), followed by
        the final output. With ``return_dict`` falsy, a tuple of the non-None
        values among ``(last_hidden_state, hidden_states)`` is returned.
        """
        collected = [] if output_hidden_states else None
        x = self.pos_drop(self.patch_embed(x) + self.pos_embed)
        for transformer, pool in zip(self.transformers, self.pools):
            tokens = transformer(x)
            x = pool(tokens)
            if collected is not None:
                collected.append(tokens)
        x = self.transformers[-1](x)
        hidden_states = (tuple(collected) + (x,)) if collected is not None else None
        if not return_dict:
            return tuple(item for item in (x, hidden_states) if item is not None)
        return BaseModelOutputWithNoAttention(last_hidden_state=x, hidden_states=hidden_states)
class FDViTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FDViTConfig
    base_model_prefix = "fdvit"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule (PreTrainedModel init hook).

        - ``nn.Linear`` / ``nn.Conv2d``: weight ~ N(0, config.initializer_range),
          bias (if any) set to zero.
        - ``nn.LayerNorm``: weight set to one, bias set to zero. Either is skipped
          when it is ``None`` (e.g. ``elementwise_affine=False``) instead of raising.

        Other parameters (such as the encoder's ``pos_embed``) are not touched here.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
# Shared docstring fragment, prepended via add_start_docstrings to the
# FDViTModel and FDViTForImageClassification class docs.
FDVIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`FDViTConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared docstring fragment describing forward() inputs; used with
# add_start_docstrings_to_model_forward on the models' forward methods.
# NOTE(review): FDViTImageProcessor is not defined in this file; confirm the reference is valid.
FDVIT_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`FDViTImageProcessor.__call__`]
for details.
"""
@add_start_docstrings(
    "The bare FDViT Model transformer outputting raw hidden-states without any specific head on top.",
    FDVIT_START_DOCSTRING,
)
class FDViTModel(FDViTPreTrainedModel):
    # PreTrainedModel wrapper around FDViTEncoder; no pooler is applied, so
    # outputs carry only last_hidden_state and (optionally) hidden_states.

    def __init__(self, config: FDViTConfig):
        super().__init__(config)
        self.config = config
        self.encoder = FDViTEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FDVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        # Unset flags fall back to the config defaults.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        encoder_outputs = self.encoder(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        if not return_dict:
            # Same element order as BaseModelOutputWithNoAttention.to_tuple():
            # (last_hidden_state,) followed by hidden_states when requested.
            # There is no pooled output, so no None placeholder is inserted.
            return (sequence_output,) + encoder_outputs[1:]
        return BaseModelOutputWithNoAttention(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
        )
class FDViTPooler(nn.Module):
    """Dense + tanh applied to the first token of ``hidden_states``.

    NOTE(review): nothing in the visible module instantiates this class, and
    FDViT has no [CLS] token (FDViTForImageClassification mean-pools the tokens
    instead). Confirm whether first-token pooling is intended before using it.
    """

    def __init__(self, config: "FDViTConfig"):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token: (batch, tokens, hidden) -> (batch, hidden).
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
@add_start_docstrings(
    """
    FDViT Model transformer with an image classification head on top (a linear layer on top of the layer-normalized,
    mean-pooled final hidden states) e.g. for ImageNet.
    """,
    FDVIT_START_DOCSTRING,
)
class FDViTForImageClassification(FDViTPreTrainedModel):
    def __init__(self, config: FDViTConfig) -> None:
        super().__init__(config)
        self.num_labels = config.num_labels
        self.fdvit = FDViTModel(config)
        embed_dim = config.base_dims[-1] * config.heads[-1]
        # Final norm
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        # Classifier head sized by config.num_classes (identity when num_classes <= 0)
        self.classifier = nn.Linear(embed_dim, config.num_classes) if config.num_classes > 0 else nn.Identity()
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FDVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_classes - 1]`. If the classifier has a single output a regression loss is computed (Mean-Square
            loss); with more outputs a classification loss is computed (Cross-Entropy for integer labels,
            binary cross-entropy for multi-label float targets).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.fdvit(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        # Normalize, mean-pool over tokens (dim 1), then classify.
        logits = self.classifier(self.norm(sequence_output).mean(dim=1))

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # The head is built from config.num_classes, which is not guaranteed to
            # equal config.num_labels; use the real head width so the loss matches it.
            num_labels = logits.shape[-1]
            if self.config.problem_type is None:
                if num_labels == 1:
                    self.config.problem_type = "regression"
                elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)