# Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License") and the MIT License (the "License2");
"""PyTorch FDViT model."""

import collections.abc
import math
from functools import partial
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
    ImageClassifierOutputWithNoAttention,
    MaskedImageModelingOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

from .configuration_fdvit import FDViTConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "FDViTConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "amd/fdvit_ti"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 260]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "amd/fdvit_ti"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Args:
        x: Input tensor. The first dimension is treated as the sample (batch) dimension.
        drop_prob: Probability of zeroing an entire sample.
        training: When False (or when ``drop_prob == 0``), ``x`` is returned unchanged.
        scale_by_keep: If True, surviving samples are divided by ``1 - drop_prob`` so the
            output has the same expectation as the input.

    Returns:
        A tensor with the same shape as ``x``.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims
    # (works for tensors of any rank, not just 2D ConvNets).
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class FDViTDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around :func:`drop_path` that uses ``self.training`` as the training flag.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'


class FDViTEmbeddings(nn.Module):
    """Construct patch embeddings with a single (biased) Conv2d."""

    def __init__(self, in_channels, out_channels, patch_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=patch_size, stride=stride, padding=padding, bias=True
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class FDViTAttention(nn.Module):
    """Multi-head self-attention with a depthwise-convolution term added to the attention output.

    The extra term (named ``lepe`` here, presumably the "locally-enhanced positional encoding"
    idea from CSWin -- confirm) is a depthwise 3x3 convolution of the value tensor, viewed as a 2D map.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Depthwise (groups=dim) 3x3 conv over the values laid out as a (dim, H, W) map.
        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def get_lepe(self, x):
        """Compute the depthwise-conv term for the value tensor.

        Args:
            x: Values of shape (B, num_heads, N, head_dim). N must be a perfect square because
                the tokens are folded into an H x W map with H = W = int(sqrt(N)).

        Returns:
            ``(x, lepe)``, both of shape (B, num_heads, N, head_dim). ``x`` holds the same values
            as the input, returned as a contiguous copy.
        """
        B, Head, N, C_p = x.shape
        H = W = int(math.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C_p * Head, H, W)
        lepe = self.get_v(x)
        lepe = lepe.reshape(B, Head, C_p, N).permute(0, 1, 3, 2).contiguous()
        x = x.reshape(B, Head, C_p, N).permute(0, 1, 3, 2).contiguous()
        return x, lepe

    def forward(self, x):
        """Apply attention to tokens ``x`` of shape (B, N, C). Returns a tensor of the same shape."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        v, lepe = self.get_lepe(v)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v) + lepe
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class FDViTOutput(nn.Module):
    """Two-layer MLP: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer block: x + DropPath(Attn(LN(x))), then x + DropPath(MLP(LN(x)))."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = FDViTAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = FDViTDropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = FDViTOutput(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class FDViTLayer(nn.Module):
    """One FDViT stage: flattens a (b, c, h, w) map into tokens and runs ``depth`` Blocks over them.

    The embedding dimension is ``base_dim * heads``. ``drop_path_prob`` (one entry per block)
    defaults to all zeros.
    """

    def __init__(self, base_dim, depth, heads, mlp_ratio, drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):
        super().__init__()
        # Unused; kept so the module attribute layout stays unchanged.
        self.layers = nn.ModuleList([])
        embed_dim = base_dim * heads

        if drop_path_prob is None:
            drop_path_prob = [0.0 for _ in range(depth)]

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_prob[i],
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
            for i in range(depth)
        ])

    def forward(self, x):
        """Map ``x`` of shape (b, c, h, w) to tokens of shape (b, h*w, c) and apply every block."""
        x = rearrange(x, 'b c h w -> b (h w) c')
        for blk in self.blocks:
            x = blk(x)
        return x


class FDViTPooling(nn.Module):
    """Resample a square token map to ``out_size x out_size`` and project its channels.

    Tokens are layer-normalized, folded back into a 2D map, resampled with ``F.grid_sample`` on a
    fixed uniform grid spanning [-1, 1] (``align_corners=True``), then passed through a 3x3 conv.
    """

    def __init__(self, in_feature, out_feature, out_size):
        super().__init__()
        d = torch.linspace(-1, 1, out_size)
        # Explicit "ij" indexing is the historical default; passing it avoids the torch.meshgrid
        # deprecation warning without changing the resulting grid.
        meshx, meshy = torch.meshgrid(d, d, indexing="ij")
        # grid_sample reads (x, y) from the last dim: meshy varies along width, meshx along height.
        # Kept as a plain attribute (not a buffer); forward() moves it to the input's device/dtype.
        self.grid = torch.stack((meshy, meshx), 2)
        self.conv = nn.Conv2d(in_feature, out_feature, kernel_size=3, padding=1, stride=1)
        self.ln = nn.LayerNorm(in_feature)

    def forward(self, x):
        """Map tokens (b, h*w, in_feature), with h == w, to a (b, out_feature, out_size, out_size) map."""
        h = w = int(math.sqrt(x.shape[1]))
        x = self.ln(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        grid = self.grid.expand(x.shape[0], -1, -1, -1)
        x = F.grid_sample(x, grid.to(x.device).type_as(x), align_corners=True)
        x = self.conv(x)
        return x


class FDViTEncoder(nn.Module):
    """Patch embedding + positional embedding followed by a stack of FDViT stages.

    Every stage except the last is followed by an :class:`FDViTPooling` layer that resizes the token
    map to ``config.out_size[stage + 1]`` and maps ``config.channels[stage]`` to
    ``config.channels[stage + 1]`` channels. NOTE(review): this presumes
    ``channels[s] == base_dims[s] * heads[s]`` (the pooling LayerNorm width must equal the stage's
    embedding dim) -- this is a config invariant that is not checked here.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        image_size = config.image_size
        patch_size = config.patch_size
        stride = config.stride
        base_dims = config.base_dims
        depth = config.depth
        heads = config.heads
        channels = config.channels
        out_size = config.out_size
        mlp_ratio = config.mlp_ratio

        # Fallbacks applied when the config leaves these unset.
        num_classes = config.num_classes if config.num_classes is not None else 1000
        in_chans = config.in_chans if config.in_chans is not None else 3
        attn_drop_rate = config.attn_drop_rate if config.attn_drop_rate is not None else .0
        drop_rate = config.drop_rate if config.drop_rate is not None else .0
        drop_path_rate = config.drop_path_rate if config.drop_path_rate is not None else .0

        total_block = sum(depth)
        padding = 0
        block_idx = 0

        # Spatial size of the patch-embedding output (standard conv output-size formula).
        width = math.floor((image_size + 2 * padding - patch_size) / stride + 1)

        self.base_dims = base_dims
        self.heads = heads
        self.num_classes = num_classes

        self.patch_size = patch_size
        self.pos_embed = nn.Parameter(
            torch.randn(1, base_dims[0] * heads[0], width, width),
            requires_grad=True,
        )
        self.patch_embed = FDViTEmbeddings(in_chans, base_dims[0] * heads[0], patch_size, stride, padding)

        self.pos_drop = nn.Dropout(p=drop_rate)

        self.transformers = nn.ModuleList([])
        self.pools = nn.ModuleList([])
        # Unused; kept so the module attribute layout stays unchanged.
        self.decoders = nn.ModuleList([])

        for stage in range(len(depth)):
            # Stochastic-depth rate grows linearly with the global block index across all stages.
            drop_path_prob = [
                drop_path_rate * i / total_block for i in range(block_idx, block_idx + depth[stage])
            ]
            block_idx += depth[stage]

            self.transformers.append(
                FDViTLayer(
                    base_dims[stage], depth[stage], heads[stage], mlp_ratio,
                    drop_rate, attn_drop_rate, drop_path_prob,
                )
            )
            if stage < len(heads) - 1:
                self.pools.append(
                    FDViTPooling(channels[stage], channels[stage + 1], out_size[stage + 1])
                )

        self.embed_dim = base_dims[-1] * heads[-1]

    def forward(self, x, output_hidden_states=False, return_dict=True):
        """Run the patch embedding and every stage.

        Args:
            x: Pixel values fed to a Conv2d, presumably (batch, in_chans, image_size, image_size).
            output_hidden_states: If True, collect the token sequence output by every stage.
            return_dict: If True, return a ``BaseModelOutputWithNoAttention``; otherwise a tuple of
                the non-None values among (last_hidden_state, hidden_states).
        """
        all_hidden_states = () if output_hidden_states else None

        x = self.patch_embed(x)
        # pos_embed has a fixed (width x width) spatial size derived from config.image_size.
        x = self.pos_drop(x + self.pos_embed)

        # zip stops at the shorter list, i.e. every stage that has a following pooling layer.
        for transformer, pool in zip(self.transformers, self.pools):
            xt = transformer(x)
            x = pool(xt)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (xt,)

        x = self.transformers[-1](x)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (x,)

        if not return_dict:
            return tuple(v for v in [x, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(last_hidden_state=x, hidden_states=all_hidden_states)
class FDViTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FDViTConfig
    base_model_prefix = "fdvit"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights.

        Linear and Conv2d weights are drawn from N(0, config.initializer_range) with biases zeroed;
        LayerNorm is reset to weight=1, bias=0. Other parameters (e.g. the encoder's ``pos_embed``)
        are not touched here.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


FDVIT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`FDViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

FDVIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`FDViTImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the output of every stage. Defaults to `config.output_hidden_states`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
            `config.use_return_dict`.
"""


@add_start_docstrings(
    "The bare FDViT Model transformer outputting raw hidden-states without any specific head on top.",
    FDVIT_START_DOCSTRING,
)
class FDViTModel(FDViTPreTrainedModel):
    def __init__(self, config: FDViTConfig):
        super().__init__(config)
        self.config = config

        self.encoder = FDViTEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FDVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        encoder_outputs = self.encoder(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            # This model has no pooler, so there is no pooled-output slot. The tuple keeps the same
            # field order as BaseModelOutputWithNoAttention: (last_hidden_state, [hidden_states]).
            # Downstream heads rely on this layout via `outputs[1:]`.
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutputWithNoAttention(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
        )


class FDViTPooler(nn.Module):
    """Apply a dense layer followed by tanh to the first token of ``hidden_states``.

    NOTE(review): this class is not referenced by the other classes visible in this module, and it reads
    ``config.hidden_size``, which FDViTConfig may not define. FDViT also has no [CLS] token, so
    "first token" is just the first spatial token. Confirm before wiring it in.
    """

    def __init__(self, config: FDViTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


@add_start_docstrings(
    """
    FDViT Model transformer with an image classification head on top (a linear layer on top of the layer-normalized,
    token-averaged final hidden states) e.g. for ImageNet.
    """,
    FDVIT_START_DOCSTRING,
)
class FDViTForImageClassification(FDViTPreTrainedModel):
    def __init__(self, config: FDViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.fdvit = FDViTModel(config)

        hidden_dim = config.base_dims[-1] * config.heads[-1]
        # Same fallback FDViTEncoder uses, so num_classes=None does not crash on the `> 0` comparison.
        num_classes = config.num_classes if config.num_classes is not None else 1000

        # Final norm
        self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)

        # Classifier head
        self.classifier = nn.Linear(hidden_dim, num_classes) if num_classes > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FDVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fdvit(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # LayerNorm over the channel dim, then average over all tokens (no [CLS] token is used).
        logits = self.classifier(self.norm(sequence_output).mean(dim=1))

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # NOTE(review): the head width comes from config.num_classes, while the loss uses
            # config.num_labels. These are presumably kept equal by FDViTConfig -- confirm.
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)