# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_modernbert.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved. # # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Dict, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, logging, ) import importlib is_triton_available = lambda: importlib.util.find_spec("triton") is not None from .configuration_modernbert import ModernBertConfig if is_flash_attn_2_available(): from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.ops.triton.rotary import apply_rotary else: RotaryEmbedding = object logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" _CONFIG_FOR_DOC = "ModernBertConfig" class ApplyRotaryEmbUnpad(torch.autograd.Function): @staticmethod def forward( ctx, qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, ): # (total_nnz, 3, nheads, headdim) qkv = qkv.contiguous() total_nnz, _three, _nheads, headdim = qkv.shape # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, # we get the same tensor # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") qk = qkv[:, :2].view(total_nnz, -1, headdim) apply_rotary( qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False, inplace=True, ) ctx.save_for_backward(cos, sin, cu_seqlens) ctx.max_seqlen = max_seqlen return qkv @staticmethod def backward(ctx, do): cos, sin, cu_seqlens = ctx.saved_tensors do = do.contiguous() total_nnz, _three, _nheads, headdim = do.shape # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, # we get the same tensor dqk = do[:, :2].view(total_nnz, -1, headdim) apply_rotary( dqk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=ctx.max_seqlen, interleaved=False, inplace=True, conjugate=True, ) return do, None, None, None, None, None, None def apply_rotary_unpadded( qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, ): """ Arguments: qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV. cos, sin: (seqlen_rotary, rotary_dim / 2) interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). inplace: if True, apply rotary embedding in-place. seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. Most commonly used in inference when we have KV cache. cu_seqlens: (batch + 1,) or None max_seqlen: int Return: out: (total_nnz, dim) rotary_dim must be <= headdim Apply rotary embedding to the first rotary_dim of x. """ return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen) class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding): """ The rotary position embeddings applied directly to unpadded sequences. """ def __init__( self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): """ max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, the cos_sin_cache wll be recomputed during the forward pass. """ super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False) self.max_seqlen = max_seqlen if max_seqlen is not None and device is not None and dtype is not None: self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype) def forward( self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Apply rotary embedding *inplace* to qkv. qkv: (total_nnz, 3, nheads, headdim) cu_seqlens: (batch + 1,) cumulative sequence lengths max_seqlen: int max seq length in the batch """ if max_seqlen is not None: self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) qkv = apply_rotary_unpadded( qkv, self._cos_cached, self._sin_cached, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) return qkv def extra_repr(self) -> str: return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" class ModernBertEmbeddings(nn.Module): """ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. """ def __init__(self, config: ModernBertConfig): super().__init__() self.config = config self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.drop = nn.Dropout(config.embedding_dropout) @torch.compile(dynamic=True) def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: return self.drop(self.norm(self.tok_embeddings(input_ids))) def forward( self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None ) -> torch.Tensor: if inputs_embeds is not None: hidden_states = self.drop(self.norm(inputs_embeds)) else: hidden_states = ( self.compiled_embeddings(input_ids) if self.config.reference_compile else self.drop(self.norm(self.tok_embeddings(input_ids))) ) return hidden_states class ModernBertMLP(nn.Module): """Applies the GLU at the end of each ModernBERT layer. Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. """ def __init__(self, config: ModernBertConfig): super().__init__() self.config = config self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias) self.act = ACT2FN[config.hidden_activation] self.drop = nn.Dropout(config.mlp_dropout) self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input, gate = self.Wi(hidden_states).chunk(2, dim=-1) return self.Wo(self.drop(self.act(input) * gate)) class ModernBertRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() def forward(self, x, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] self.inv_freq.to(x.device) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def eager_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, attention_mask: torch.Tensor, sliding_window_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], local_attention: Tuple[int, int], bs: int, dim: int, output_attentions: Optional[bool] = False, **_kwargs, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: # qkv: [batch_size, seqlen, 3, nheads, headdim] cos, sin = module.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] query, key = apply_rotary_pos_emb(query, key, cos, sin) scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale if local_attention != (-1, -1): attention_mask = sliding_window_mask attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bs, -1, dim) if output_attentions: return (attn_output, attn_weights) return (attn_output,) def flash_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, max_seqlen: int, local_attention: Tuple[int, int], bs: int, dim: int, target_dtype: torch.dtype = torch.bfloat16, **_kwargs, ) -> Tuple[torch.Tensor]: # (total_seqlen, 3, nheads, headdim) qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) if convert_dtype: # FA2 implementation only supports fp16 and bf16. If FA2 is supported, # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported) orig_dtype = qkv.dtype qkv = qkv.to(target_dtype) attn = flash_attn_varlen_qkvpacked_func( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=module.attention_dropout if module.training else 0.0, deterministic=module.deterministic_flash_attn, window_size=local_attention, ) attn = attn.to(orig_dtype) # type: ignore else: attn = flash_attn_varlen_qkvpacked_func( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=module.attention_dropout if module.training else 0.0, deterministic=module.deterministic_flash_attn, window_size=local_attention, ) return (attn.view(bs, dim),) def sdpa_attention_forward( module: "ModernBertAttention", qkv: torch.Tensor, attention_mask: torch.Tensor, sliding_window_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], local_attention: Tuple[int, int], bs: int, dim: int, **_kwargs, ) -> Tuple[torch.Tensor]: # qkv: [batch_size, seqlen, 3, nheads, headdim] cos, sin = module.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] query, key = apply_rotary_pos_emb(query, key, cos, sin) if local_attention != (-1, -1): attention_mask = sliding_window_mask attn_output = ( F.scaled_dot_product_attention( query, key, value, dropout_p=module.attention_dropout if module.training else 0.0, attn_mask=attention_mask, ) .transpose(1, 2) .contiguous() ) attn_output = attn_output.view(bs, -1, dim) return (attn_output,) MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, } class ModernBertAttention(nn.Module): """Performs multi-headed self attention on a batch of unpadded sequences. If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel, which requires padding and unpadding inputs, adding some overhead. See `forward` method for additional details. """ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config self.layer_id = layer_id if config.hidden_size % config.num_attention_heads != 0: raise ValueError( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" ) self.attention_dropout = config.attention_dropout self.deterministic_flash_attn = config.deterministic_flash_attn self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads self.all_head_size = self.head_dim * self.num_heads self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) if layer_id % config.global_attn_every_n_layers != 0: self.local_attention = (config.local_attention // 2, config.local_attention // 2) else: self.local_attention = (-1, -1) rope_theta = config.global_rope_theta max_position_embeddings = config.max_position_embeddings if self.local_attention != (-1, -1): if config.local_rope_theta is not None: rope_theta = config.local_rope_theta max_position_embeddings = config.local_attention if config._attn_implementation == "flash_attention_2": self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) else: self.rotary_emb = ModernBertRotaryEmbedding( dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta ) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() self.pruned_heads = set() def forward( self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, **kwargs, ) -> torch.Tensor: qkv = self.Wqkv(hidden_states) bs = hidden_states.shape[0] if self.config._attn_implementation == "flash_attention_2": qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) else: qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( self, qkv=qkv, rotary_emb=self.rotary_emb, local_attention=self.local_attention, bs=bs, dim=self.all_head_size, output_attentions=output_attentions, **kwargs, ) hidden_states = attn_outputs[0] hidden_states = self.out_drop(self.Wo(hidden_states)) return (hidden_states,) + attn_outputs[1:] # add attentions if outputted class ModernBertEncoderLayer(nn.Module): def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config if layer_id == 0: self.attn_norm = nn.Identity() else: self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.attn = ModernBertAttention(config=config, layer_id=layer_id) self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.mlp = ModernBertMLP(config) @torch.compile(dynamic=True) def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.mlp(self.mlp_norm(hidden_states)) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: attn_outputs = self.attn( self.attn_norm(hidden_states), attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, output_attentions=output_attentions, ) hidden_states = hidden_states + attn_outputs[0] mlp_output = ( self.compiled_mlp(hidden_states) if self.config.reference_compile else self.mlp(self.mlp_norm(hidden_states)) ) hidden_states = hidden_states + mlp_output return (hidden_states,) + attn_outputs[1:] # add attentions if outputted MODERNBERT_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`ModernBertConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", MODERNBERT_START_DOCSTRING, ) class ModernBertPreTrainedModel(PreTrainedModel): config_class = ModernBertConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = False def _init_weights(self, module: nn.Module): cutoff_factor = self.config.initializer_cutoff_factor if cutoff_factor is None: cutoff_factor = 3 def init_weight(module: nn.Module, std: float): nn.init.trunc_normal_( module.weight, mean=0.0, std=std, a=-cutoff_factor * std, b=cutoff_factor * std, ) if isinstance(module, nn.Linear): if module.bias is not None: nn.init.zeros_(module.bias) stds = { "in": self.config.initializer_range, "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers), "embedding": self.config.initializer_range, "final_out": self.config.hidden_size**-0.5, } if isinstance(module, ModernBertEmbeddings): init_weight(module.tok_embeddings, stds["embedding"]) elif isinstance(module, ModernBertMLP): init_weight(module.Wi, stds["in"]) init_weight(module.Wo, stds["out"]) elif isinstance(module, ModernBertAttention): init_weight(module.Wqkv, stds["in"]) init_weight(module.Wo, stds["out"]) elif isinstance(module, ModernBertPredictionHead): init_weight(module.dense, stds["out"]) elif isinstance(module, ModernBertForMaskedLM): init_weight(module.decoder, stds["out"]) elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): init_weight(module.classifier, stds["final_out"]) @classmethod def _autoset_attn_implementation( cls, config, use_flash_attention_2: bool = False, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None, check_device_map: bool = True, ): # If the user didn't specify anything, try to use flash_attention_2 if available. # Otherwise we fall back to the default SDPA -> Eager from the super() method. if config._attn_implementation_internal is None: config._attn_implementation_internal = "flash_attention_2" try: return cls._check_and_enable_flash_attn_2( config, torch_dtype=torch_dtype, device_map=device_map, hard_check_only=False, check_device_map=check_device_map, ) except (ValueError, ImportError): config._attn_implementation_internal = None return super()._autoset_attn_implementation( config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map, check_device_map=check_device_map, ) def _maybe_set_compile(self): if self.config.reference_compile is False: return if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: if self.config.reference_compile: logger.warning_once( "If `accelerate` split the model across devices, `torch.compile` will not work. " "Falling back to non-compiled mode." ) self.config.reference_compile = False if self.device.type == "mps": if self.config.reference_compile: logger.warning_once( "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " "Falling back to non-compiled mode." ) self.config.reference_compile = False if self.config.reference_compile is None: self.config.reference_compile = is_triton_available() def resize_token_embeddings(self, *args, **kwargs): model_embeds = super().resize_token_embeddings(*args, **kwargs) if self.config.reference_compile in {True, None}: if self.config.reference_compile: logger.warning_once( "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." ) self.config.reference_compile = False return model_embeds def _unpad_modernbert_input( inputs: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Remove padding from input sequences. Args: inputs: (batch, seqlen, ...) or (batch, seqlen) attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. position_ids: (batch, seqlen), int, position ids labels: (batch, seqlen), int, labels Returns: unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. indices: (total_nnz) cu_seqlens: (batch + 1), the cumulative sequence lengths max_seqlen_in_batch: int unpadded_position_ids: (total_nnz) or None unpadded_labels: (total_nnz) or None """ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = int(seqlens_in_batch.max().item()) cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) if inputs.dim() == 2: unpadded_inputs = inputs.flatten()[indices] else: batch, seqlen, *rest = inputs.shape shape = batch * seqlen unpadded_inputs = inputs.view(shape, *rest)[indices] unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None unpadded_labels = labels.flatten()[indices] if labels is not None else None return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels def _pad_modernbert_output( inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int, ) -> torch.Tensor: """ Add padding to sequences. Args: inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. indices: (total_nnz) batch: int, batch size seqlen: int, max sequence length Returns: padded_inputs: (batch, seqlen, ...) or (batch, seqlen) """ if inputs.dim() == 1: output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) output[indices] = inputs padded_inputs = output.view(batch, seqlen) else: _, *rest = inputs.shape output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) output[indices] = inputs padded_inputs = output.view(batch, seqlen, *rest) return padded_inputs MODERNBERT_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers perform global attention, while the rest perform local attention. This mask is used to avoid attending to far-away tokens in the local attention layers. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): Indices of the non-padding tokens in the input sequence. Used for unpadding the output. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. max_seqlen (`int`, *optional*): Maximum sequence length in the batch. Used to pad the output tensors. batch_size (`int`, *optional*): Batch size of the input sequences. Used to pad the output tensors. seq_len (`int`, *optional*): Sequence length of the input sequences. Used to pad the output tensors. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", MODERNBERT_START_DOCSTRING, ) class ModernBertModel(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): super().__init__(config) self.config = config self.embeddings = ModernBertEmbeddings(config) self.layers = nn.ModuleList( [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] ) self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.gradient_checkpointing = False self.post_init() def get_input_embeddings(self): return self.embeddings.tok_embeddings def set_input_embeddings(self, value): self.embeddings.tok_embeddings = value @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None self._maybe_set_compile() if input_ids is not None: self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) if batch_size is None and seq_len is None: if inputs_embeds is not None: batch_size, seq_len = inputs_embeds.shape[:2] else: batch_size, seq_len = input_ids.shape[:2] device = input_ids.device if input_ids is not None else inputs_embeds.device if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) repad = False if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: repad = True if inputs_embeds is None: with torch.no_grad(): input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( inputs=input_ids, attention_mask=attention_mask ) else: inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( inputs=inputs_embeds, attention_mask=attention_mask ) else: if position_ids is None: position_ids = torch.arange(seq_len, device=device).unsqueeze(0) attention_mask, sliding_window_mask = self._update_attention_mask( attention_mask, output_attentions=output_attentions ) hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) for encoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, sliding_window_mask, position_ids, cu_seqlens, max_seqlen, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions and len(layer_outputs) > 1: all_self_attentions = all_self_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) hidden_states = self.final_norm(hidden_states) if repad: hidden_states = _pad_modernbert_output( inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len ) if all_hidden_states is not None: all_hidden_states = tuple( _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) for hs in all_hidden_states ) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor: if output_attentions: if self.config._attn_implementation == "sdpa": logger.warning_once( "Outputting attentions is only supported with the 'eager' attention implementation, " 'not with "sdpa". Falling back to `attn_implementation="eager"`.' ) self.config._attn_implementation = "eager" elif self.config._attn_implementation != "eager": logger.warning_once( "Outputting attentions is only supported with the eager attention implementation, " f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.' " Setting `output_attentions=False`." ) global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype) # Create position indices rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0) # Calculate distance between positions distance = torch.abs(rows - rows.T) # Create sliding window mask (1 for positions within window, 0 outside) window_mask = ( (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device) ) # Combine with existing mask sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min) return global_attention_mask, sliding_window_mask class ModernBertPredictionHead(nn.Module): def __init__(self, config: ModernBertConfig): super().__init__() self.config = config self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) self.act = ACT2FN[config.classifier_activation] self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(self.act(self.dense(hidden_states))) @add_start_docstrings( "The ModernBert Model with a decoder head on top that is used for masked language modeling.", MODERNBERT_START_DOCSTRING, ) class ModernBertForMaskedLM(ModernBertPreTrainedModel): _tied_weights_keys = ["decoder.weight"] def __init__(self, config: ModernBertConfig): super().__init__(config) self.config = config self.model = ModernBertModel(config) self.head = ModernBertPredictionHead(config) self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias) self.sparse_prediction = self.config.sparse_prediction self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index # Initialize weights and apply final processing self.post_init() def get_output_embeddings(self): return self.decoder def set_output_embeddings(self, new_embeddings: nn.Linear): self.decoder = new_embeddings @torch.compile(dynamic=True) def compiled_head(self, output: torch.Tensor) -> torch.Tensor: return self.decoder(self.head(output)) @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() if self.config._attn_implementation == "flash_attention_2": if indices is None and cu_seqlens is None and max_seqlen is None: if batch_size is None and seq_len is None: if inputs_embeds is not None: batch_size, seq_len = inputs_embeds.shape[:2] else: batch_size, seq_len = input_ids.shape[:2] device = input_ids.device if input_ids is not None else inputs_embeds.device if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) if inputs_embeds is None: with torch.no_grad(): input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) else: inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels ) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = outputs[0] if self.sparse_prediction and labels is not None: # flatten labels and output first labels = labels.view(-1) last_hidden_state = last_hidden_state.view(labels.shape[0], -1) # then filter out the non-masked tokens mask_tokens = labels != self.sparse_pred_ignore_index last_hidden_state = last_hidden_state[mask_tokens] labels = labels[mask_tokens] logits = ( self.compiled_head(last_hidden_state) if self.config.reference_compile else self.decoder(self.head(last_hidden_state)) ) loss = None if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) if self.config._attn_implementation == "flash_attention_2": with torch.no_grad(): logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) if not return_dict: output = (logits,) return ((loss,) + output) if loss is not None else output return MaskedLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @add_start_docstrings( "The ModernBert Model with a sequence classification head on top that performs pooling.", MODERNBERT_START_DOCSTRING, ) class ModernBertForSequenceClassification(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): super().__init__(config) self.num_labels = config.num_labels self.config = config self.model = ModernBertModel(config) self.head = ModernBertPredictionHead(config) self.drop = torch.nn.Dropout(config.classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = outputs[0] if self.config.classifier_pooling == "cls": last_hidden_state = last_hidden_state[:, 0] elif self.config.classifier_pooling == "mean": last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( dim=1, keepdim=True ) pooled_output = self.head(last_hidden_state) pooled_output = self.drop(pooled_output) logits = self.classifier(pooled_output) loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @add_start_docstrings( "The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.", MODERNBERT_START_DOCSTRING, ) class ModernBertForTokenClassification(ModernBertPreTrainedModel): def __init__(self, config: ModernBertConfig): super().__init__(config) self.num_labels = config.num_labels self.model = ModernBertModel(config) self.head = ModernBertPredictionHead(config) self.drop = torch.nn.Dropout(config.classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, batch_size: Optional[int] = None, seq_len: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = outputs[0] last_hidden_state = self.head(last_hidden_state) last_hidden_state = self.drop(last_hidden_state) logits = self.classifier(last_hidden_state) loss = None if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) __all__ = [ "ModernBertModel", "ModernBertPreTrainedModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification", "ModernBertForTokenClassification", ]