# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved. # # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from contextlib import nullcontext from typing import Dict, Literal, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.modeling_outputs import ( BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, logging, ) from transformers.utils.import_utils import is_triton_available, is_torchdynamo_compiling from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb from transformers.models.modernbert.modular_modernbert import (_pad_modernbert_output, _unpad_modernbert_input, ModernBertEmbeddings, ModernBertMLP, ModernBertUnpaddedRotaryEmbedding, ModernBertEmbeddings) if is_flash_attn_2_available(): from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.ops.triton.rotary import apply_rotary else: RotaryEmbedding = object _CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" _CONFIG_FOR_DOC = "ModernBertConfig" _MAX_SQRT_GRADIENT = 1000.0 logger = logging.get_logger(__name__) class ModernBergConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`ModernBergModel`]. It is used to instantiate an ModernBerg model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the ModernBERT-base. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 50368): Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`ModernBertModel`] hidden_size (`int`, *optional*, defaults to 768): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 1152): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 22): Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer decoder. lru_width (`int`, *optional*, defaults to 128): The dimension of the RG-LRU -- if None, this will be set to `hidden_size`. hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`): The non-linear activation function (function or string) in the decoder. Will default to `"gelu"` if not specified. max_position_embeddings (`int`, *optional*, defaults to 8192): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. initializer_cutoff_factor (`float`, *optional*, defaults to 2.0): The cutoff factor for the truncated_normal_initializer for initializing all weight matrices. norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the rms normalization layers. norm_bias (`bool`, *optional*, defaults to `False`): Whether to use bias in the normalization layers. pad_token_id (`int`, *optional*, defaults to 50283): Padding token id. eos_token_id (`int`, *optional*, defaults to 50282): End of stream token id. bos_token_id (`int`, *optional*, defaults to 50281): Beginning of stream token id. cls_token_id (`int`, *optional*, defaults to 50281): Classification token id. sep_token_id (`int`, *optional*, defaults to 50282): Separation token id. global_rope_theta (`float`, *optional*, defaults to 160000.0): The base period of the global RoPE embeddings. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. global_temporal_every_n_layers (`int`, *optional*, defaults to 3): The number of layers between global temporal mixing layers. local_attention (`int`, *optional*, defaults to 128): The window size for local attention. local_rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the local RoPE embeddings. embedding_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the embeddings. mlp_bias (`bool`, *optional*, defaults to `False`): Whether to use bias in the MLP layers. mlp_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the MLP layers. decoder_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the decoder layers. classifier_pooling (`str`, *optional*, defaults to `"cls"`): The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the CLS token doesn't attend to all tokens on long sequences. classifier_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the classifier. classifier_bias (`bool`, *optional*, defaults to `False`): Whether to use bias in the classifier. classifier_activation (`str`, *optional*, defaults to `"gelu"`): The activation function for the classifier. deterministic_flash_attn (`bool`, *optional*, defaults to `False`): Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic. sparse_prediction (`bool`, *optional*, defaults to `False`): Whether to use sparse prediction for the masked language model instead of returning the full dense logits. sparse_pred_ignore_index (`int`, *optional*, defaults to -100): The index to ignore for the sparse prediction. reference_compile (`bool`, *optional*): Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may be faster in some scenarios. repad_logits_with_grad (`bool`, *optional*, defaults to `False`): When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient. Examples: ```python >>> from transformers import ModernBertModel, ModernBertConfig >>> # Initializing a ModernBert style configuration >>> configuration = ModernBertConfig() >>> # Initializing a model from the modernbert-base style configuration >>> model = ModernBertModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "modernbert" keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, vocab_size=50368, hidden_size=768, intermediate_size=1152, num_hidden_layers=22, num_attention_heads=12, lru_width=1152, conv1d_width=4, hidden_activation="gelu", max_position_embeddings=8192, initializer_range=0.02, initializer_cutoff_factor=2.0, norm_eps=1e-5, norm_bias=False, pad_token_id=50283, eos_token_id=50282, bos_token_id=50281, cls_token_id=50281, sep_token_id=50282, global_rope_theta=160000.0, attention_bias=False, attention_dropout=0.0, global_temporal_every_n_layers=3, local_attention=128, local_rope_theta=10000.0, embedding_dropout=0.0, mlp_bias=False, mlp_dropout=0.0, decoder_bias=True, classifier_pooling: Literal["cls", "mean"] = "cls", classifier_dropout=0.0, classifier_bias=False, classifier_activation="gelu", deterministic_flash_attn=False, sparse_prediction=False, sparse_pred_ignore_index=-100, reference_compile=None, repad_logits_with_grad=False, **kwargs, ): super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, cls_token_id=cls_token_id, sep_token_id=sep_token_id, **kwargs, ) self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.lru_width = lru_width self.conv1d_width = conv1d_width self.initializer_range = initializer_range self.initializer_cutoff_factor = initializer_cutoff_factor self.norm_eps = norm_eps self.norm_bias = norm_bias self.global_rope_theta = global_rope_theta self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.hidden_activation = hidden_activation self.global_temporal_every_n_layers = global_temporal_every_n_layers self.local_attention = local_attention self.local_rope_theta = local_rope_theta self.embedding_dropout = embedding_dropout self.mlp_bias = mlp_bias self.mlp_dropout = mlp_dropout self.decoder_bias = decoder_bias self.classifier_pooling = classifier_pooling self.classifier_dropout = classifier_dropout self.classifier_bias = classifier_bias self.classifier_activation = classifier_activation self.deterministic_flash_attn = deterministic_flash_attn self.sparse_prediction = sparse_prediction self.sparse_pred_ignore_index = sparse_pred_ignore_index self.reference_compile = reference_compile self.repad_logits_with_grad = repad_logits_with_grad if self.classifier_pooling not in ["cls", "mean"]: raise ValueError( f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' ) class SqrtBoundDerivative(torch.autograd.Function): """Computes a square root with a gradient clipped at `_MAX_SQRT_GRADIENT`.""" @staticmethod def forward(ctx, x: torch.Tensor) -> torch.Tensor: """The forward pass, which is a normal `sqrt`.""" ctx.save_for_backward(x) return torch.sqrt(x) @staticmethod def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: """The backward pass, which clips the `sqrt` gradient.""" (x,) = ctx.saved_tensors clipped_x_times_4 = torch.clip(4.0 * x, min=1 / (_MAX_SQRT_GRADIENT**2)) return grad_output / torch.sqrt(clipped_x_times_4) class GriffinRglru(nn.Module): """A Real-Gated Linear Recurrent Unit (RG-LRU) layer.""" def __init__(self, config: ModernBergConfig): super().__init__() self.num_attention_heads = config.num_attention_heads self.block_width = config.lru_width // self.num_attention_heads self.recurrent_param = nn.Parameter(torch.empty([config.lru_width])) self.input_gate_weight = nn.Parameter( torch.empty([self.num_attention_heads, self.block_width, self.block_width]) ) self.input_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width])) self.recurrent_gate_weight = nn.Parameter( torch.empty([self.num_attention_heads, self.block_width, self.block_width]) ) self.recurrent_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width])) self.recurrent_states = None def forward( self, activations: torch.Tensor, position_ids: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, seq_len, lru_width = activations.shape reset = position_ids[:, :, None] == 0 reshape_act = activations.reshape(batch_size * seq_len, self.num_attention_heads, self.block_width) reshape_act = reshape_act.permute(1, 0, 2) res = torch.baddbmm(self.input_gate_bias[:, None, :], reshape_act, self.input_gate_weight) input_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width)) res = torch.baddbmm(self.recurrent_gate_bias[:, None, :], reshape_act, self.recurrent_gate_weight) recurrent_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width)) # Compute the parameter `A` of the recurrence. log_recurrent_gate = -8.0 * recurrent_gate * nn.functional.softplus(self.recurrent_param) recurrent_gate = torch.exp(log_recurrent_gate) a_square = torch.exp(2 * log_recurrent_gate) # Gate the input. gated_inputs = activations * input_gate # Apply gamma normalization to the input. We need to clip the derivatives of # `sqrt` in order to prevent NaNs during training in bfloat16. TODO a bit annoying multiplier = 1 tracing = isinstance(activations, torch.fx.Proxy) or is_torchdynamo_compiling() if not torch.jit.is_tracing() and not tracing: multiplier = SqrtBoundDerivative.apply(1 - a_square) multiplier = reset + ~reset * multiplier normalized_x = gated_inputs * multiplier.type(activations.dtype) hidden_states, recurrent_states = self._rnn_scan( hidden_states=normalized_x, recurrent_gate=recurrent_gate, reset=reset, recurrent_states=self.recurrent_states, ) self.recurrent_states = recurrent_states return hidden_states # TODO refactor def _rnn_scan( self, hidden_states: torch.Tensor, recurrent_gate: torch.Tensor, reset: torch.Tensor, recurrent_states: Union[torch.Tensor, None], acc_dtype: torch.dtype = torch.float32, ) -> Tuple[torch.Tensor, torch.Tensor]: """Runs the recurrence of a linear RNN. Args: hidden_states: The input sequence. recurrent_gate: The diagonal of the recurrence matrix `A`. reset: Indicator of document boundaries, e.g. when to reset the hidden state of the RNN. recurrent_states: The initial hidden state. acc_dtype: The data type for the accumulation. Returns: The output of the linear recurrence. """ # Multiply `a` by the reset. recurrent_gate = recurrent_gate * ~reset if hidden_states.shape[1] == 1: # Using scan in sampling mode. if recurrent_states is None: # same here, when decoding you always have cache return hidden_states, hidden_states[:, 0].type(acc_dtype) else: contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to( recurrent_gate.device ) contextualized_states += hidden_states.type(acc_dtype) return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1] else: # Using scan in linear mode. if recurrent_states is None: recurrent_states = torch.zeros(hidden_states[:, 0].shape, dtype=acc_dtype, device=hidden_states.device) contextualized_states = torch.zeros_like(hidden_states) for t in range(hidden_states.shape[1]): recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device) recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype) contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype) return contextualized_states, recurrent_states class GriffinRecurrentblock(nn.Module): """Griffin and Hawk's recurrent block.""" def __init__(self, config: ModernBergConfig, layer_id: Optional[int] = None): super().__init__() self.lru_width = config.lru_width self.hidden_size = config.hidden_size self.linear_y = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width) self.linear_x = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width) self.linear_out = nn.Linear(in_features=config.lru_width, out_features=config.hidden_size) self.conv1d_width = config.conv1d_width self.conv_1d = nn.Conv1d( config.lru_width, config.lru_width, kernel_size=config.conv1d_width, groups=config.lru_width, padding=config.conv1d_width - 1, ) self.rg_lru = GriffinRglru(config) self.act_fn = ACT2FN[config.hidden_activation] self.conv1d_state = None def forward( self, input_states: torch.Tensor, position_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: _, seq_len, _ = input_states.shape y_branch = self.linear_y(input_states) y_branch = self.act_fn(y_branch) x_branch = self.linear_x(input_states) x_branch = x_branch.transpose(1, 2) x_branch = self.conv_1d(x_branch)[..., :seq_len] x_branch = self.rg_lru(x_branch.transpose(1, 2), position_ids) hidden_states = x_branch * y_branch hidden_states = self.linear_out(hidden_states) return hidden_states def _setup_cache(self, batch, device, dtype): # recurrent_states always computed in full precision self.rg_lru.recurrent_states = torch.zeros((batch, self.lru_width), device=device, dtype=torch.float32) self.conv1d_state = torch.zeros((batch, self.hidden_size, self.conv1d_width - 1), device=device, dtype=dtype) def eager_attention_forward( module: "ModernBergAttention", qkv: torch.Tensor, attention_mask: torch.Tensor, sliding_window_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], local_attention: Tuple[int, int], bs: int, dim: int, output_attentions: Optional[bool] = False, **_kwargs, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: # qkv: [batch_size, seqlen, 3, nheads, headdim] cos, sin = module.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] query, key = apply_rotary_pos_emb(query, key, cos, sin) scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale if local_attention != (-1, -1): attention_mask = sliding_window_mask attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bs, -1, dim) if output_attentions: return (attn_output, attn_weights) return (attn_output,) def flash_attention_forward( module: "ModernBergAttention", qkv: torch.Tensor, rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, max_seqlen: int, local_attention: Tuple[int, int], bs: int, dim: int, target_dtype: torch.dtype = torch.bfloat16, **_kwargs, ) -> Tuple[torch.Tensor]: # (total_seqlen, 3, nheads, headdim) qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) if convert_dtype: # FA2 implementation only supports fp16 and bf16. If FA2 is supported, # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported) orig_dtype = qkv.dtype qkv = qkv.to(target_dtype) attn = flash_attn_varlen_qkvpacked_func( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=module.attention_dropout if module.training else 0.0, deterministic=module.deterministic_flash_attn, window_size=local_attention, ) attn = attn.to(orig_dtype) # type: ignore else: attn = flash_attn_varlen_qkvpacked_func( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=module.attention_dropout if module.training else 0.0, deterministic=module.deterministic_flash_attn, window_size=local_attention, ) return (attn.view(bs, dim),) def sdpa_attention_forward( module: "ModernBergAttention", qkv: torch.Tensor, attention_mask: torch.Tensor, sliding_window_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], local_attention: Tuple[int, int], bs: int, dim: int, **_kwargs, ) -> Tuple[torch.Tensor]: # qkv: [batch_size, seqlen, 3, nheads, headdim] cos, sin = module.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] query, key = apply_rotary_pos_emb(query, key, cos, sin) if local_attention != (-1, -1): attention_mask = sliding_window_mask attn_output = ( F.scaled_dot_product_attention( query, key, value, dropout_p=module.attention_dropout if module.training else 0.0, attn_mask=attention_mask, ) .transpose(1, 2) .contiguous() ) attn_output = attn_output.view(bs, -1, dim) return (attn_output,) MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, } class ModernBergRotaryEmbedding(GemmaRotaryEmbedding): def __init__(self, config: ModernBergConfig, dim: int, base: float, device: Optional[torch.device] = None): # JANK!!! JANK!!! JANK!!! config.rope_theta = base super().__init__(config=config, device=device) inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base) class ModernBergAttention(nn.Module): """Performs multi-headed self attention on a batch of unpadded sequences. If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel, which requires padding and unpadding inputs, adding some overhead. See `forward` method for additional details. """ def __init__(self, config: ModernBergConfig, layer_id: Optional[int] = None): super().__init__() self.config = config self.layer_id = layer_id if config.hidden_size % config.num_attention_heads != 0: raise ValueError( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" ) self.attention_dropout = config.attention_dropout self.deterministic_flash_attn = config.deterministic_flash_attn self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads self.all_head_size = self.head_dim * self.num_heads self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) assert layer_id % config.global_temporal_every_n_layers != 0, "ModernBerg does not support global self-attention" self.local_attention = (config.local_attention // 2, config.local_attention // 2) rope_theta = config.global_rope_theta max_position_embeddings = config.max_position_embeddings if self.local_attention != (-1, -1): if config.local_rope_theta is not None: rope_theta = config.local_rope_theta max_position_embeddings = config.local_attention if config._attn_implementation == "flash_attention_2": self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) else: self.rotary_emb = ModernBergRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() self.pruned_heads = set() def forward( self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, **kwargs, ) -> torch.Tensor: qkv = self.Wqkv(hidden_states) bs = hidden_states.shape[0] if self.config._attn_implementation == "flash_attention_2": qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) else: qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( self, qkv=qkv, rotary_emb=self.rotary_emb, local_attention=self.local_attention, bs=bs, dim=self.all_head_size, output_attentions=output_attentions, **kwargs, ) hidden_states = attn_outputs[0] hidden_states = self.out_drop(self.Wo(hidden_states)) return hidden_states class ModernBergTemporalLayer(nn.Module): def __init__(self, config: ModernBergConfig, layer_id: Optional[int] = None): super().__init__() self.config = config if layer_id % config.global_temporal_every_n_layers == 0: self.temporal = GriffinRecurrentblock(config=config, layer_id=layer_id) else: self.temporal = ModernBergAttention(config=config, layer_id=layer_id) def forward(self, hidden_states: torch.Tensor, **kwargs): return self.temporal(hidden_states, **kwargs) class ModernBergEncoderLayer(nn.Module): def __init__(self, config: ModernBergConfig, layer_id: Optional[int] = None): super().__init__() self.config = config if layer_id == 0: self.temporal_norm = nn.Identity() else: self.temporal_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.temporal = ModernBergTemporalLayer(config=config, layer_id=layer_id) self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.mlp = ModernBertMLP(config) @torch.compile(dynamic=True) def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.mlp(self.mlp_norm(hidden_states)) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: attn_outputs = self.temporal( self.temporal_norm(hidden_states), attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, output_attentions=output_attentions, ) hidden_states = hidden_states + attn_outputs[0] mlp_output = ( self.compiled_mlp(hidden_states) if self.config.reference_compile else self.mlp(self.mlp_norm(hidden_states)) ) hidden_states = hidden_states + mlp_output return hidden_states MODERNBERG_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`ModernBergConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( "The bare ModernBerg Model outputting raw hidden-states without any specific head on top.", MODERNBERG_START_DOCSTRING, ) class ModernBergPreTrainedModel(PreTrainedModel): config_class = ModernBergConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["ModernBergEmbeddings", "ModernBergEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = False def _init_weights(self, module: nn.Module): cutoff_factor = self.config.initializer_cutoff_factor if cutoff_factor is None: cutoff_factor = 3 def init_weight(module: nn.Module, std: float): nn.init.trunc_normal_( module.weight, mean=0.0, std=std, a=-cutoff_factor * std, b=cutoff_factor * std, ) if isinstance(module, nn.Linear): if module.bias is not None: nn.init.zeros_(module.bias) stds = { "in": self.config.initializer_range, "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers), "embedding": self.config.initializer_range, "final_out": self.config.hidden_size**-0.5, } std = math.sqrt(self.config.initializer_range / self.config.conv1d_width) if isinstance(module, ModernBertEmbeddings): init_weight(module.tok_embeddings, stds["embedding"]) elif isinstance(module, ModernBertMLP): init_weight(module.Wi, stds["in"]) init_weight(module.Wo, stds["out"]) elif isinstance(module, ModernBergAttention): init_weight(module.Wqkv, stds["in"]) init_weight(module.Wo, stds["out"]) elif isinstance(module, GriffinRecurrentblock): torch.nn.init.zeros_(module.linear_x.bias) torch.nn.init.normal_(module.linear_x.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) torch.nn.init.zeros_(module.linear_y.bias) torch.nn.init.normal_(module.linear_y.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size)) std = math.sqrt(self.config.initializer_range / self.config.lru_width) torch.nn.init.normal_(module.linear_out.weight, mean=0.0, std=std) torch.nn.init.zeros_(module.linear_out.bias) elif isinstance(module, GriffinRglru): std = math.sqrt( self.config.initializer_range / (self.config.lru_width // self.config.num_attention_heads) ) torch.nn.init.normal_(module.input_gate_weight, mean=0.0, std=std) torch.nn.init.normal_(module.recurrent_gate_weight, mean=0.0, std=std) torch.nn.init.zeros_(module.input_gate_bias) torch.nn.init.zeros_(module.recurrent_gate_bias) module.recurrent_param.data.uniform_(0.9**2 + 1e-8, 0.999**2 + 1e-8) module.recurrent_param.data.log_().mul_(0.5) module.recurrent_param.data.neg_().exp_().sub_(1.0).log_() elif isinstance(module, ModernBergPredictionHead): init_weight(module.dense, stds["out"]) elif isinstance(module, ModernBergForMaskedLM): init_weight(module.decoder, stds["out"]) elif isinstance(module, (ModernBergForSequenceClassification, ModernBergForTokenClassification)): init_weight(module.classifier, stds["final_out"]) elif isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=std) if getattr(module, "bias", None) is not None: torch.nn.init.zeros_(module.bias) @classmethod def _autoset_attn_implementation( cls, config, use_flash_attention_2: bool = False, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None, check_device_map: bool = True, ): # If the user didn't specify anything, try to use flash_attention_2 if available. # Otherwise we fall back to the default SDPA -> Eager from the super() method. # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check. if config._attn_implementation_internal is None: config._attn_implementation_internal = "flash_attention_2" try: return cls._check_and_enable_flash_attn_2( config, torch_dtype=torch.float16, device_map=device_map, hard_check_only=False, check_device_map=check_device_map, ) except (ValueError, ImportError): config._attn_implementation_internal = None return super()._autoset_attn_implementation( config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch.float16, device_map=device_map, check_device_map=check_device_map, ) def _maybe_set_compile(self): if self.config.reference_compile is False: return if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: if self.config.reference_compile: logger.warning_once( "If `accelerate` split the model across devices, `torch.compile` will not work. " "Falling back to non-compiled mode." ) self.config.reference_compile = False if self.device.type == "mps": if self.config.reference_compile: logger.warning_once( "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " "Falling back to non-compiled mode." ) self.config.reference_compile = False if self.device.type == "cpu": if self.config.reference_compile: logger.warning_once( "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. " "Falling back to non-compiled mode." ) self.config.reference_compile = False if self.config.reference_compile is None: self.config.reference_compile = is_triton_available() def resize_token_embeddings(self, *args, **kwargs): model_embeds = super().resize_token_embeddings(*args, **kwargs) if self.config.reference_compile in {True, None}: if self.config.reference_compile: logger.warning_once( "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." ) self.config.reference_compile = False return model_embeds MODERNBERG_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers perform global attention, while the rest perform local attention. This mask is used to avoid attending to far-away tokens in the local attention layers when not using Flash Attention. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): Indices of the non-padding tokens in the input sequence. Used for unpadding the output. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. max_seqlen (`int`, *optional*): Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. batch_size (`int`, *optional*): Batch size of the input sequences. Used to pad the output tensors. seq_len (`int`, *optional*): Sequence length of the input sequences including padding tokens. Used to pad the output tensors. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( "The bare ModernBerg Model outputting raw hidden-states without any specific head on top.", MODERNBERG_START_DOCSTRING, ) class ModernBergModel(ModernBergPreTrainedModel): def __init__(self, config: ModernBergConfig): super().__init__(config) self.config = config self.embeddings = ModernBertEmbeddings(config) self.layers = nn.ModuleList( [ModernBergEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] ) self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.gradient_checkpointing = False self.post_init() def get_input_embeddings(self): return self.embeddings.tok_embeddings def set_input_embeddings(self, value): self.embeddings.tok_embeddings = value @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None self._maybe_set_compile() if input_ids is not None: self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) if batch_size is None and seq_len is None: if inputs_embeds is not None: batch_size, seq_len = inputs_embeds.shape[:2] else: batch_size, seq_len = input_ids.shape[:2] device = input_ids.device if input_ids is not None else inputs_embeds.device if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) repad = False if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: repad = True if inputs_embeds is None: with torch.no_grad(): input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( inputs=input_ids, attention_mask=attention_mask ) else: inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( inputs=inputs_embeds, attention_mask=attention_mask ) else: if position_ids is None: position_ids = torch.arange(seq_len, device=device).unsqueeze(0) attention_mask, sliding_window_mask = self._update_attention_mask( attention_mask, output_attentions=output_attentions ) hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) for encoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, sliding_window_mask, position_ids, cu_seqlens, max_seqlen, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions and len(layer_outputs) > 1: all_self_attentions = all_self_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) hidden_states = self.final_norm(hidden_states) if repad: hidden_states = _pad_modernbert_output( inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len ) if all_hidden_states is not None: all_hidden_states = tuple( _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) for hs in all_hidden_states ) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor: if output_attentions: if self.config._attn_implementation == "sdpa": logger.warning_once( "Outputting attentions is only supported with the 'eager' attention implementation, " 'not with "sdpa". Falling back to `attn_implementation="eager"`.' ) self.config._attn_implementation = "eager" elif self.config._attn_implementation != "eager": logger.warning_once( "Outputting attentions is only supported with the eager attention implementation, " f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.' " Setting `output_attentions=False`." ) global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype) # Create position indices rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0) # Calculate distance between positions distance = torch.abs(rows - rows.T) # Create sliding window mask (1 for positions within window, 0 outside) window_mask = ( (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device) ) # Combine with existing mask sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min) return global_attention_mask, sliding_window_mask class ModernBergPredictionHead(nn.Module): def __init__(self, config: ModernBergConfig): super().__init__() self.config = config self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) self.act = ACT2FN[config.classifier_activation] self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(self.act(self.dense(hidden_states))) @add_start_docstrings( "The ModernBerg Model with a decoder head on top that is used for masked language modeling.", MODERNBERG_START_DOCSTRING, ) class ModernBergForMaskedLM(ModernBergPreTrainedModel): _tied_weights_keys = ["decoder.weight"] def __init__(self, config: ModernBergConfig): super().__init__(config) self.config = config self.model = ModernBergModel(config) self.head = ModernBergPredictionHead(config) self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias) self.sparse_prediction = self.config.sparse_prediction self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index # Initialize weights and apply final processing self.post_init() def get_output_embeddings(self): return self.decoder def set_output_embeddings(self, new_embeddings: nn.Linear): self.decoder = new_embeddings @torch.compile(dynamic=True) def compiled_head(self, output: torch.Tensor) -> torch.Tensor: return self.decoder(self.head(output)) @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: if batch_size is None and seq_len is None: if inputs_embeds is not None: batch_size, seq_len = inputs_embeds.shape[:2] else: batch_size, seq_len = input_ids.shape[:2] device = input_ids.device if input_ids is not None else inputs_embeds.device if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) if inputs_embeds is None: with torch.no_grad(): input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) else: inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = outputs[0] if self.sparse_prediction and labels is not None: # flatten labels and output first labels = labels.view(-1) last_hidden_state = last_hidden_state.view(labels.shape[0], -1) # then filter out the non-masked tokens mask_tokens = labels != self.sparse_pred_ignore_index last_hidden_state = last_hidden_state[mask_tokens] labels = labels[mask_tokens] logits = ( self.compiled_head(last_hidden_state) if self.config.reference_compile else self.decoder(self.head(last_hidden_state)) ) loss = None if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) if self.config._attn_implementation == "flash_attention_2": with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad(): logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) if not return_dict: output = (logits,) return ((loss,) + output) if loss is not None else output return MaskedLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @add_start_docstrings( "The ModernBerg Model with a sequence classification head on top that performs pooling.", MODERNBERG_START_DOCSTRING, ) class ModernBergForSequenceClassification(ModernBergPreTrainedModel): def __init__(self, config: ModernBergConfig): super().__init__(config) self.num_labels = config.num_labels self.config = config self.model = ModernBergModel(config) self.head = ModernBergPredictionHead(config) self.drop = torch.nn.Dropout(config.classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = outputs[0] if self.config.classifier_pooling == "cls": last_hidden_state = last_hidden_state[:, 0] elif self.config.classifier_pooling == "mean": last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( dim=1, keepdim=True ) pooled_output = self.head(last_hidden_state) pooled_output = self.drop(pooled_output) logits = self.classifier(pooled_output) loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @add_start_docstrings( "The ModernBerg Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.", MODERNBERG_START_DOCSTRING, ) class ModernBergForTokenClassification(ModernBergPreTrainedModel): def __init__(self, config: ModernBergConfig): super().__init__(config) self.num_labels = config.num_labels self.model = ModernBergModel(config) self.head = ModernBergPredictionHead(config) self.drop = torch.nn.Dropout(config.classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = outputs[0] last_hidden_state = self.head(last_hidden_state) last_hidden_state = self.drop(last_hidden_state) logits = self.classifier(last_hidden_state) loss = None if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) __all__ = [ "ModernBergConfig", "ModernBergModel", "ModernBergPreTrainedModel", "ModernBergForMaskedLM", "ModernBergForSequenceClassification", "ModernBergForTokenClassification", ]