# Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License") and the MIT License (the "License2");
"""FDViT model configuration"""

from collections import OrderedDict
from typing import Mapping, Optional, Sequence

from packaging import version

from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

# Defaults for the per-stage list arguments of `FDViTConfig`. They are tuples so
# they can never be mutated through a config instance; `FDViTConfig.__init__`
# always stores a fresh list copy on the instance.
_DEFAULT_BASE_DIMS = (32, 23, 21, 23, 26)
_DEFAULT_DEPTH = (2, 3, 3, 2, 2)
_DEFAULT_HEADS = (2, 4, 6, 8, 10)
_DEFAULT_CHANNELS = (64, 92, 126, 184, 260)
_DEFAULT_OUT_SIZE = (27, 19, 14, 10, 7)


def _list_or_default(value, default):
    """Return a new list copied from *value*, or from *default* when *value* is None."""
    return list(default if value is None else value)


class FDViTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FDViTModel`]. It is used to instantiate an FDViT
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the FDViT
    [amd/fdvit_ti](https://huggingface.co/amd/fdvit_ti) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    The list-valued arguments (`base_dims`, `depth`, `heads`, `channels`, `out_size`) are copied into new lists, so
    mutating a config never affects the caller's sequences or the defaults used by other configs. Passing `None` for
    any of them selects its default.

    Args:
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size of the input patch.
        stride (`int`, *optional*, defaults to 8):
            The stride of the input patch.
        base_dims (`list`, *optional*, defaults to `[32, 23, 21, 23, 26]`):
            The base dimension of each encoder block. In the default configuration
            `channels[i] == base_dims[i] * heads[i]`.
        depth (`list`, *optional*, defaults to `[2, 3, 3, 2, 2]`):
            The depth of each encoder block.
        heads (`list`, *optional*, defaults to `[2, 4, 6, 8, 10]`):
            The number of attention heads of each encoder block.
        channels (`list`, *optional*, defaults to `[64, 92, 126, 184, 260]`):
            The number of channels of each encoder block.
        out_size (`list`, *optional*, defaults to `[27, 19, 14, 10, 7]`):
            The output size of each encoder block.
        mlp_ratio (`float`, *optional*, defaults to 4):
            The ratio of the number of channels in the output of the MLP to the number of channels in the input.
        num_classes (`int`, *optional*, defaults to 1000):
            The number of classes of the dataset.
        in_chans (`int`, *optional*, defaults to 3):
            The number of channels in the input image.
        attn_drop_rate (`float`, *optional*, defaults to 0.0):
            The dropout rate for the attention dropout layers.
        drop_rate (`float`, *optional*, defaults to 0.0):
            The dropout rate for the dropout layers.
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            The drop-path rate for the drop-path layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The initializer range for the weights.

    Example:

    ```python
    >>> from transformers import FDViTConfig, FDViTModel

    >>> # Initializing a FDViT fdvit_ti style configuration
    >>> configuration = FDViTConfig()

    >>> # Initializing a model (with random weights) from the fdvit_ti style configuration
    >>> model = FDViTModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "fdvit"

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        stride: int = 8,
        base_dims: Optional[Sequence[int]] = None,
        depth: Optional[Sequence[int]] = None,
        heads: Optional[Sequence[int]] = None,
        channels: Optional[Sequence[int]] = None,
        out_size: Optional[Sequence[int]] = None,
        mlp_ratio: float = 4,
        num_classes: int = 1000,
        in_chans: int = 3,
        attn_drop_rate: float = 0.0,
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.patch_size = patch_size
        self.stride = stride
        # Fresh copies: list defaults in the signature would be shared by every
        # instance, and storing the caller's list would alias it.
        self.base_dims = _list_or_default(base_dims, _DEFAULT_BASE_DIMS)
        self.depth = _list_or_default(depth, _DEFAULT_DEPTH)
        self.heads = _list_or_default(heads, _DEFAULT_HEADS)
        self.channels = _list_or_default(channels, _DEFAULT_CHANNELS)
        self.out_size = _list_or_default(out_size, _DEFAULT_OUT_SIZE)
        self.mlp_ratio = mlp_ratio
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.attn_drop_rate = attn_drop_rate
        self.drop_rate = drop_rate
        self.drop_path_rate = drop_path_rate
        self.initializer_range = initializer_range


class FDViTOnnxConfig(OnnxConfig):
    """ONNX export configuration for FDViT models."""

    # Minimum PyTorch version this config supports for ONNX export.
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Model input names mapped to their dynamic axes (index -> axis name)."""
        return OrderedDict(
            [
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        """Absolute tolerance used when validating the exported model's outputs."""
        return 1e-4