ModernBERG-base-uninit / modernberg_model.py
Fizzarolli's picture
Update modernberg_model.py
7b1fc1a verified
raw
history blame
64.3 kB
# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from contextlib import nullcontext
from typing import Dict, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
)
from transformers.utils.import_utils import is_triton_available, is_torchdynamo_compiling
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb
from transformers.models.modernbert.modular_modernbert import (_pad_modernbert_output, _unpad_modernbert_input, ModernBertEmbeddings, ModernBertMLP, ModernBertUnpaddedRotaryEmbedding, ModernBertEmbeddings)
if is_flash_attn_2_available():
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.ops.triton.rotary import apply_rotary
else:
RotaryEmbedding = object
# Names presumably consumed by docstring helpers (e.g. `add_code_sample_docstrings`) further down
# the file — not visible here.
# NOTE(review): these still point at the upstream ModernBERT checkpoint / config class rather than
# ModernBerg equivalents; confirm whether that is intended.
_CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base"
_CONFIG_FOR_DOC = "ModernBertConfig"
# Upper bound on the gradient magnitude produced by `SqrtBoundDerivative.backward`.
_MAX_SQRT_GRADIENT = 1000.0
logger = logging.get_logger(__name__)
class ModernBergConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ModernBergModel`]. It is used to instantiate a
    ModernBerg model according to the specified arguments, defining the model architecture. ModernBerg is a
    ModernBERT-style model in which every `global_temporal_every_n_layers`-th layer uses a Griffin RG-LRU recurrent
    block for temporal mixing, while all remaining layers use local (sliding-window) attention. Instantiating a
    configuration with the defaults will yield a configuration similar to that of ModernBERT-base.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50368):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling the model.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 1152):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 22):
            Number of layers (temporal-mixing block + MLP) in the model.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads in each attention layer. The RG-LRU gates are also split into this many
            block-diagonal blocks.
        lru_width (`int`, *optional*, defaults to 1152):
            Width of the RG-LRU recurrent state. If `None`, it is set to `hidden_size`. It should be divisible by
            `num_attention_heads`, because the RG-LRU gates reshape their input into `num_attention_heads` blocks
            of width `lru_width // num_attention_heads`.
        conv1d_width (`int`, *optional*, defaults to 4):
            Kernel size of the causal depthwise temporal convolution in the recurrent block.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) used by the MLP and by the gating branch of
            the recurrent block.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
            The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
        norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the normalization layers.
        norm_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the normalization layers.
        pad_token_id (`int`, *optional*, defaults to 50283):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 50282):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 50281):
            Beginning of stream token id.
        cls_token_id (`int`, *optional*, defaults to 50281):
            Classification token id.
        sep_token_id (`int`, *optional*, defaults to 50282):
            Separation token id.
        global_rope_theta (`float`, *optional*, defaults to 160000.0):
            The base period of the RoPE embeddings used when `local_rope_theta` is `None` or local attention is
            disabled.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        global_temporal_every_n_layers (`int`, *optional*, defaults to 3):
            Layers whose index is a multiple of this value use the RG-LRU recurrent block; all other layers use
            local attention (global self-attention is not supported).
        local_attention (`int`, *optional*, defaults to 128):
            The window size for local attention (half of it on each side of a token).
        local_rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the local RoPE embeddings.
        embedding_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the embeddings.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the MLP layers.
        mlp_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the MLP layers.
        decoder_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the decoder layers.
        classifier_pooling (`str`, *optional*, defaults to `"cls"`):
            The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
            CLS token doesn't attend to all tokens on long sequences.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the classifier.
        classifier_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the classifier.
        classifier_activation (`str`, *optional*, defaults to `"gelu"`):
            The activation function for the classifier.
        deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
            Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
        sparse_prediction (`bool`, *optional*, defaults to `False`):
            Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
        sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
            The index to ignore for the sparse prediction.
        reference_compile (`bool`, *optional*):
            Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
            the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
            shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
            be faster in some scenarios.
        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
            When True, ModernBergForMaskedLM keeps track of the logits' gradient when repadding for output. This only
            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.

    Examples:

    ```python
    >>> # Initializing a ModernBerg style configuration
    >>> configuration = ModernBergConfig()

    >>> # Initializing a model from that configuration
    >>> model = ModernBergModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # NOTE(review): this shares its `model_type` with transformers' own ModernBertConfig, so auto-class
    # resolution without an explicit `auto_map` may pick the upstream classes — confirm this is intended.
    model_type = "modernbert"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50368,
        hidden_size=768,
        intermediate_size=1152,
        num_hidden_layers=22,
        num_attention_heads=12,
        lru_width=1152,
        conv1d_width=4,
        hidden_activation="gelu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        initializer_cutoff_factor=2.0,
        norm_eps=1e-5,
        norm_bias=False,
        pad_token_id=50283,
        eos_token_id=50282,
        bos_token_id=50281,
        cls_token_id=50281,
        sep_token_id=50282,
        global_rope_theta=160000.0,
        attention_bias=False,
        attention_dropout=0.0,
        global_temporal_every_n_layers=3,
        local_attention=128,
        local_rope_theta=10000.0,
        embedding_dropout=0.0,
        mlp_bias=False,
        mlp_dropout=0.0,
        decoder_bias=True,
        classifier_pooling: Literal["cls", "mean"] = "cls",
        classifier_dropout=0.0,
        classifier_bias=False,
        classifier_activation="gelu",
        deterministic_flash_attn=False,
        sparse_prediction=False,
        sparse_pred_ignore_index=-100,
        reference_compile=None,
        repad_logits_with_grad=False,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            cls_token_id=cls_token_id,
            sep_token_id=sep_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Documented fallback: a missing LRU width means "same as the hidden size". Storing None would
        # make the recurrent layers fail on `lru_width // num_attention_heads`.
        self.lru_width = hidden_size if lru_width is None else lru_width
        self.conv1d_width = conv1d_width
        self.initializer_range = initializer_range
        self.initializer_cutoff_factor = initializer_cutoff_factor
        self.norm_eps = norm_eps
        self.norm_bias = norm_bias
        self.global_rope_theta = global_rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.global_temporal_every_n_layers = global_temporal_every_n_layers
        self.local_attention = local_attention
        self.local_rope_theta = local_rope_theta
        self.embedding_dropout = embedding_dropout
        self.mlp_bias = mlp_bias
        self.mlp_dropout = mlp_dropout
        self.decoder_bias = decoder_bias
        self.classifier_pooling = classifier_pooling
        self.classifier_dropout = classifier_dropout
        self.classifier_bias = classifier_bias
        self.classifier_activation = classifier_activation
        self.deterministic_flash_attn = deterministic_flash_attn
        self.sparse_prediction = sparse_prediction
        self.sparse_pred_ignore_index = sparse_pred_ignore_index
        self.reference_compile = reference_compile
        self.repad_logits_with_grad = repad_logits_with_grad

        if self.classifier_pooling not in ["cls", "mean"]:
            raise ValueError(
                f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
            )
class SqrtBoundDerivative(torch.autograd.Function):
    """Element-wise square root whose backward pass caps the gradient at `_MAX_SQRT_GRADIENT`.

    The true derivative 1 / (2 * sqrt(x)) == 1 / sqrt(4 * x) diverges as x -> 0. The backward pass
    therefore floors `4 * x` at `1 / _MAX_SQRT_GRADIENT**2`, which bounds the gradient multiplier by
    `_MAX_SQRT_GRADIENT`. The forward value is an ordinary `sqrt`.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        """Return `sqrt(x)` and keep `x` for the backward pass."""
        ctx.save_for_backward(x)
        return x.sqrt()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """Scale `grad_output` by the clipped derivative 1 / sqrt(max(4x, 1 / MAX**2))."""
        (saved_input,) = ctx.saved_tensors
        floor = 1 / (_MAX_SQRT_GRADIENT**2)
        denominator = (4.0 * saved_input).clip(min=floor).sqrt()
        return grad_output / denominator
class GriffinRglru(nn.Module):
    """A Real-Gated Linear Recurrent Unit (RG-LRU) layer.

    Input and recurrence gates are computed with per-head block-diagonal weights. The gated, normalized
    input then drives a diagonal linear recurrence that is scanned left to right over the sequence
    dimension. A position id of 0 marks the start of a sequence and resets the recurrence there.

    Every call starts from `self.recurrent_states`, which is `None` (zeros) unless explicitly set, e.g.
    by the owning block's `_setup_cache`. The final state is deliberately NOT written back to the module.
    Storing it would keep the previous batch's autograd graph alive, so the next training backward
    would fail. It would also break as soon as the batch size changes.
    """

    def __init__(self, config: "ModernBergConfig"):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        # Each head owns a `block_width`-wide slice of the LRU width.
        self.block_width = config.lru_width // self.num_attention_heads

        self.recurrent_param = nn.Parameter(torch.empty([config.lru_width]))
        self.input_gate_weight = nn.Parameter(
            torch.empty([self.num_attention_heads, self.block_width, self.block_width])
        )
        self.input_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width]))

        self.recurrent_gate_weight = nn.Parameter(
            torch.empty([self.num_attention_heads, self.block_width, self.block_width])
        )
        self.recurrent_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width]))
        # Optional explicit initial state for the scan; read by `forward`, never updated by it.
        self.recurrent_states = None

    def forward(
        self,
        activations: torch.Tensor,
        position_ids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Apply the RG-LRU to `activations` of shape (batch, seq_len, lru_width).

        Args:
            activations: Input of shape (batch, seq_len, lru_width).
            position_ids: (batch or 1, seq_len) positions. Entries equal to 0 reset the recurrence.
                If `None`, every row is treated as one sequence starting at position 0.

        Returns:
            Tensor with the same shape and dtype as `activations`.
        """
        batch_size, seq_len, lru_width = activations.shape
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=activations.device).unsqueeze(0)
        reset = position_ids[:, :, None] == 0

        # Block-diagonal gates: (heads, batch*seq, block) @ (heads, block, block) + bias.
        reshape_act = activations.reshape(batch_size * seq_len, self.num_attention_heads, self.block_width)
        reshape_act = reshape_act.permute(1, 0, 2)

        res = torch.baddbmm(self.input_gate_bias[:, None, :], reshape_act, self.input_gate_weight)
        input_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width))

        res = torch.baddbmm(self.recurrent_gate_bias[:, None, :], reshape_act, self.recurrent_gate_weight)
        recurrent_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width))

        # Compute the parameter `A` of the recurrence (in log space), and A**2 for the normalization.
        log_recurrent_gate = -8.0 * recurrent_gate * nn.functional.softplus(self.recurrent_param)
        recurrent_gate = torch.exp(log_recurrent_gate)
        a_square = torch.exp(2 * log_recurrent_gate)

        # Gate the input.
        gated_inputs = activations * input_gate

        # Gamma normalization: scale the input by sqrt(1 - A**2), except at reset positions (scale 1).
        # In eager mode the sqrt gradient is clipped to avoid NaNs in bfloat16 training. Under
        # tracing/compilation the custom autograd.Function is avoided, but the same forward value is
        # still computed with a plain sqrt, so traced and eager outputs match.
        tracing = isinstance(activations, torch.fx.Proxy) or is_torchdynamo_compiling()
        if torch.jit.is_tracing() or tracing:
            multiplier = torch.sqrt(1 - a_square)
        else:
            multiplier = SqrtBoundDerivative.apply(1 - a_square)
        multiplier = reset + ~reset * multiplier
        normalized_x = gated_inputs * multiplier.type(activations.dtype)

        hidden_states, _final_state = self._rnn_scan(
            hidden_states=normalized_x,
            recurrent_gate=recurrent_gate,
            reset=reset,
            recurrent_states=self.recurrent_states,
        )
        return hidden_states

    def _rnn_scan(
        self,
        hidden_states: torch.Tensor,
        recurrent_gate: torch.Tensor,
        reset: torch.Tensor,
        recurrent_states: Union[torch.Tensor, None],
        acc_dtype: torch.dtype = torch.float32,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Runs the recurrence h_t = a_t * h_{t-1} + x_t of a diagonal linear RNN.

        Args:
            hidden_states: The input sequence x, shape (batch, seq_len, width).
            recurrent_gate: The diagonal of the recurrence matrix `A` per step, same shape as the input.
            reset: Boolean indicator of sequence starts; `a_t` is zeroed there so the previous state
                is dropped.
            recurrent_states: The initial hidden state of shape (batch, width), or `None` for zeros.
            acc_dtype: The data type used for the accumulation.

        Returns:
            A tuple of the per-step outputs (in the input dtype) and the final state (in `acc_dtype`).
        """
        # Multiply `a` by the reset.
        recurrent_gate = recurrent_gate * ~reset

        if hidden_states.shape[1] == 1:
            # Single-step update.
            if recurrent_states is None:
                return hidden_states, hidden_states[:, 0].type(acc_dtype)
            contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to(
                recurrent_gate.device
            )
            contextualized_states += hidden_states.type(acc_dtype)
            return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1]

        # Sequential scan over time; casts and the device move are hoisted out of the loop.
        if recurrent_states is None:
            recurrent_states = torch.zeros(hidden_states[:, 0].shape, dtype=acc_dtype, device=hidden_states.device)
        recurrent_states = recurrent_states.to(recurrent_gate.device)
        gate_acc = recurrent_gate.type(acc_dtype)
        inputs_acc = hidden_states.type(acc_dtype)

        contextualized_states = torch.zeros_like(hidden_states)
        for t in range(hidden_states.shape[1]):
            recurrent_states = gate_acc[:, t] * recurrent_states + inputs_acc[:, t]
            contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype)

        return contextualized_states, recurrent_states
class GriffinRecurrentblock(nn.Module):
    """Griffin and Hawk's recurrent temporal-mixing block.

    output = linear_out( RG-LRU(causal_depthwise_conv(linear_x(x))) * act(linear_y(x)) )

    The convolution is padded by `conv1d_width - 1` and truncated to the first `seq_len` outputs, so
    position t only sees inputs at positions <= t. The RG-LRU also scans left to right.
    """

    def __init__(self, config: "ModernBergConfig", layer_id: Optional[int] = None):
        super().__init__()
        self.lru_width = config.lru_width
        self.hidden_size = config.hidden_size
        # Gating branch and recurrent branch both project hidden_size -> lru_width.
        self.linear_y = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width)
        self.linear_x = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width)
        self.linear_out = nn.Linear(in_features=config.lru_width, out_features=config.hidden_size)
        self.conv1d_width = config.conv1d_width
        # Depthwise (groups == channels) temporal convolution over the lru_width channels.
        self.conv_1d = nn.Conv1d(
            config.lru_width,
            config.lru_width,
            kernel_size=config.conv1d_width,
            groups=config.lru_width,
            padding=config.conv1d_width - 1,
        )
        self.rg_lru = GriffinRglru(config)
        self.act_fn = ACT2FN[config.hidden_activation]

        # Convolution-state buffer allocated by `_setup_cache`; not read by `forward` in this file.
        self.conv1d_state = None

    def forward(
        self,
        input_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        """Mix `input_states` of shape (batch, seq_len, hidden_size) over time.

        Args:
            input_states: Hidden states of shape (batch, seq_len, hidden_size).
            position_ids: Positions forwarded to the RG-LRU (0 resets the recurrence).
            attention_mask: Accepted for interface parity with attention layers; unused.
                NOTE(review): padding is not masked, so left-padded inputs would leak pad tokens into
                the left-to-right recurrence — confirm inputs are right-padded.
            **kwargs: Ignored (attention-specific arguments).

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        _, seq_len, _ = input_states.shape

        y_branch = self.linear_y(input_states)
        y_branch = self.act_fn(y_branch)

        x_branch = self.linear_x(input_states)
        x_branch = x_branch.transpose(1, 2)
        # Keep only the first seq_len outputs so the convolution is causal.
        x_branch = self.conv_1d(x_branch)[..., :seq_len]
        x_branch = self.rg_lru(x_branch.transpose(1, 2), position_ids)

        hidden_states = x_branch * y_branch
        hidden_states = self.linear_out(hidden_states)
        return hidden_states

    def _setup_cache(self, batch, device, dtype):
        """Allocate zeroed recurrent and convolution state buffers for `batch` sequences."""
        # recurrent_states always computed in full precision
        self.rg_lru.recurrent_states = torch.zeros((batch, self.lru_width), device=device, dtype=torch.float32)
        # `conv_1d` operates on `lru_width` channels, so its state buffer must have `lru_width`
        # channels (not `hidden_size`) and `conv1d_width - 1` time steps.
        self.conv1d_state = torch.zeros((batch, self.lru_width, self.conv1d_width - 1), device=device, dtype=dtype)
def eager_attention_forward(
    module: "ModernBergAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    output_attentions: Optional[bool] = False,
    **_kwargs,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
    """Plain PyTorch attention over padded inputs.

    Args:
        module: The owning attention module (supplies `rotary_emb`, `head_dim`, dropout settings).
        qkv: Packed projections viewed as (batch, seq, 3, heads, head_dim) by the caller.
        attention_mask: Additive mask used when local attention is disabled.
        sliding_window_mask: Additive mask used whenever `local_attention != (-1, -1)`.
        position_ids: Positions passed to the rotary embedding.
        local_attention: Window tuple; `(-1, -1)` means no sliding window.
        bs: Batch size used to reshape the output.
        dim: Output feature size (heads * head_dim).
        output_attentions: If true, the attention probabilities are returned too.

    Returns:
        `(output,)` or `(output, attention_probs)`; output has shape (bs, seq, dim).
    """
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    # (batch, seq, 3, heads, head_dim) -> three (batch, heads, seq, head_dim) tensors
    q, k, v = qkv.transpose(3, 1).unbind(dim=2)
    q, k = apply_rotary_pos_emb(q, k, cos, sin)

    mask = sliding_window_mask if local_attention != (-1, -1) else attention_mask
    scores = torch.matmul(q, k.transpose(2, 3)) * module.head_dim**-0.5 + mask

    # Softmax is computed in fp32, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    probs = nn.functional.dropout(probs, p=module.attention_dropout, training=module.training)

    context = torch.matmul(probs, v).transpose(1, 2).contiguous().view(bs, -1, dim)
    return (context, probs) if output_attentions else (context,)
def flash_attention_forward(
    module: "ModernBergAttention",
    qkv: torch.Tensor,
    rotary_emb: ModernBertUnpaddedRotaryEmbedding,
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    target_dtype: torch.dtype = torch.bfloat16,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Flash-Attention-2 varlen attention over unpadded (packed) tokens.

    Args:
        module: The owning attention module (supplies dropout and determinism settings).
        qkv: Packed projections of shape (total_tokens, 3, heads, head_dim).
        rotary_emb: Unpadded rotary embedding applied to `qkv` before attention.
        cu_seqlens: Cumulative sequence lengths delimiting the packed sequences.
        max_seqlen: Longest sequence in the pack.
        local_attention: Passed to flash-attn as `window_size`.
        bs: Leading size of the output view (total tokens here).
        dim: Output feature size (heads * head_dim).
        target_dtype: Dtype used when `qkv` is neither fp16 nor bf16.

    Returns:
        `(output,)` with output viewed as (bs, dim) and in the dtype of the rotated `qkv`.
    """
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    orig_dtype = qkv.dtype
    # FA2 only runs in fp16/bf16; other dtypes are cast in and the result cast back out.
    needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
    if needs_cast:
        qkv = qkv.to(target_dtype)

    attn = flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        dropout_p=module.attention_dropout if module.training else 0.0,
        deterministic=module.deterministic_flash_attn,
        window_size=local_attention,
    )
    if needs_cast:
        attn = attn.to(orig_dtype)  # type: ignore

    return (attn.view(bs, dim),)
def sdpa_attention_forward(
    module: "ModernBergAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Attention over padded inputs via `torch.nn.functional.scaled_dot_product_attention`.

    Args:
        module: The owning attention module (supplies `rotary_emb` and dropout settings).
        qkv: Packed projections viewed as (batch, seq, 3, heads, head_dim) by the caller.
        attention_mask: Mask used when local attention is disabled.
        sliding_window_mask: Mask used whenever `local_attention != (-1, -1)`.
        position_ids: Positions passed to the rotary embedding.
        local_attention: Window tuple; `(-1, -1)` means no sliding window.
        bs: Batch size used to reshape the output.
        dim: Output feature size (heads * head_dim).

    Returns:
        `(output,)` with output of shape (bs, seq, dim).
    """
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    # (batch, seq, 3, heads, head_dim) -> three (batch, heads, seq, head_dim) tensors
    q, k, v = qkv.transpose(3, 1).unbind(dim=2)
    q, k = apply_rotary_pos_emb(q, k, cos, sin)

    mask = sliding_window_mask if local_attention != (-1, -1) else attention_mask
    dropout_p = module.attention_dropout if module.training else 0.0

    context = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, attn_mask=mask)
    context = context.transpose(1, 2).contiguous().view(bs, -1, dim)
    return (context,)
# Dispatch table from `config._attn_implementation` to the kernel wrapper used by ModernBergAttention.
MODERNBERT_ATTENTION_FUNCTION = dict(
    flash_attention_2=flash_attention_forward,
    eager=eager_attention_forward,
    sdpa=sdpa_attention_forward,
)
class ModernBergRotaryEmbedding(GemmaRotaryEmbedding):
    """Rotary position embedding for the padded (eager / SDPA) attention paths.

    `GemmaRotaryEmbedding` takes its RoPE parameters from a config object. This wrapper hands it a
    shallow copy of `config` with `rope_theta` set to `base` and `head_dim` set to `dim`. The caller's
    shared model config is therefore never mutated: no per-layer overwrites of `rope_theta`, and no
    stray keys in the saved config.

    Args:
        config: Model configuration (left unmodified).
        dim: Per-head dimension to rotate.
        base: RoPE base period (theta).
        device: Optional device for the frequency buffer.
    """

    def __init__(self, config: ModernBergConfig, dim: int, base: float, device: Optional[torch.device] = None):
        import copy

        rope_config = copy.copy(config)
        rope_config.rope_theta = base
        # NOTE(review): transformers' default RoPE init takes the rotary dimension from `config.head_dim`
        # when present (else hidden_size // num_attention_heads) — verify for the pinned transformers
        # version. Callers in this file pass dim == hidden_size // num_attention_heads, so both agree.
        rope_config.head_dim = dim
        super().__init__(config=rope_config, device=device)
class ModernBergAttention(nn.Module):
    """Local (sliding-window) multi-headed self attention for the non-recurrent ModernBerg layers.

    If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput on
    unpadded inputs. Otherwise it uses PyTorch's SDPA kernel or eager attention on padded inputs,
    which requires padding and unpadding and adds some overhead. Global attention is not supported;
    the layers that would carry it use the recurrent block instead.
    See `forward` method for additional details.
    """

    def __init__(self, config: "ModernBergConfig", layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
            )
        # Layer indices divisible by `global_temporal_every_n_layers` are reserved for the recurrent
        # block. A real exception (not `assert`) keeps this check active under `python -O`, and
        # `layer_id=None` (the signature default) is accepted instead of failing with a TypeError.
        if layer_id is not None and layer_id % config.global_temporal_every_n_layers == 0:
            raise ValueError("ModernBerg does not support global self-attention")

        self.attention_dropout = config.attention_dropout
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)

        # Always local: half of the configured window on each side.
        self.local_attention = (config.local_attention // 2, config.local_attention // 2)

        rope_theta = config.global_rope_theta
        max_position_embeddings = config.max_position_embeddings
        if self.local_attention != (-1, -1):
            if config.local_rope_theta is not None:
                rope_theta = config.local_rope_theta
            max_position_embeddings = config.local_attention

        if config._attn_implementation == "flash_attention_2":
            self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
                dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
            )
        else:
            self.rotary_emb = ModernBergRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)

        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        """Project to Q/K/V, run the configured attention kernel, and apply the output projection.

        Args:
            hidden_states: (total_tokens, hidden) for flash_attention_2, else (batch, seq, hidden).
            output_attentions: Forwarded to the kernel. Any returned attention weights are discarded.
            **kwargs: Kernel-specific arguments (masks, position ids, cu_seqlens, max_seqlen).

        Returns:
            A single tensor with the same leading shape as `hidden_states` (not a tuple).
        """
        qkv = self.Wqkv(hidden_states)

        bs = hidden_states.shape[0]
        if self.config._attn_implementation == "flash_attention_2":
            qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        else:
            qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)

        attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
            self,
            qkv=qkv,
            rotary_emb=self.rotary_emb,
            local_attention=self.local_attention,
            bs=bs,
            dim=self.all_head_size,
            output_attentions=output_attentions,
            **kwargs,
        )
        hidden_states = attn_outputs[0]
        hidden_states = self.out_drop(self.Wo(hidden_states))

        return hidden_states
class ModernBergTemporalLayer(nn.Module):
    """Selects the temporal-mixing module for one layer by its index.

    Layers whose index is a multiple of `config.global_temporal_every_n_layers` get the Griffin
    recurrent block; every other layer gets local self-attention.
    """

    def __init__(self, config: ModernBergConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        is_recurrent_layer = layer_id % config.global_temporal_every_n_layers == 0
        temporal_cls = GriffinRecurrentblock if is_recurrent_layer else ModernBergAttention
        self.temporal = temporal_cls(config=config, layer_id=layer_id)

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        """Delegate to the selected module, returning whatever it returns."""
        return self.temporal(hidden_states, **kwargs)
class ModernBergEncoderLayer(nn.Module):
    """One ModernBerg layer: pre-norm temporal mixing, then a pre-norm MLP, each with a residual add."""

    def __init__(self, config: "ModernBergConfig", layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        # Layer 0 skips the temporal pre-norm; every other layer normalizes its input first.
        if layer_id == 0:
            self.temporal_norm = nn.Identity()
        else:
            self.temporal_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.temporal = ModernBergTemporalLayer(config=config, layer_id=layer_id)
        self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.mlp = ModernBertMLP(config)

    @torch.compile(dynamic=True)
    def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """`torch.compile`d MLP sub-block, used when `config.reference_compile` is truthy."""
        return self.mlp(self.mlp_norm(hidden_states))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """Apply the layer and return the updated hidden states (same shape as the input)."""
        temporal_outputs = self.temporal(
            self.temporal_norm(hidden_states),
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            output_attentions=output_attentions,
        )
        # The temporal modules return a bare tensor. Indexing it with [0] would select the first batch
        # element (or the first token of unpadded inputs) and broadcast it into every sequence's
        # residual, so only unwrap genuine tuple/list outputs.
        if isinstance(temporal_outputs, (tuple, list)):
            temporal_output = temporal_outputs[0]
        else:
            temporal_output = temporal_outputs
        hidden_states = hidden_states + temporal_output

        mlp_output = (
            self.compiled_mlp(hidden_states)
            if self.config.reference_compile
            else self.mlp(self.mlp_norm(hidden_states))
        )
        return hidden_states + mlp_output
# Shared docstring prefix for the model classes; passed to `add_start_docstrings` on
# ModernBergPreTrainedModel (and presumably on other model classes later in the file).
MODERNBERG_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`ModernBergConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare ModernBerg Model outputting raw hidden-states without any specific head on top.",
MODERNBERG_START_DOCSTRING,
)
class ModernBergPreTrainedModel(PreTrainedModel):
config_class = ModernBergConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["ModernBergEmbeddings", "ModernBergEncoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = False
def _init_weights(self, module: nn.Module):
cutoff_factor = self.config.initializer_cutoff_factor
if cutoff_factor is None:
cutoff_factor = 3
def init_weight(module: nn.Module, std: float):
nn.init.trunc_normal_(
module.weight,
mean=0.0,
std=std,
a=-cutoff_factor * std,
b=cutoff_factor * std,
)
if isinstance(module, nn.Linear):
if module.bias is not None:
nn.init.zeros_(module.bias)
stds = {
"in": self.config.initializer_range,
"out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
"embedding": self.config.initializer_range,
"final_out": self.config.hidden_size**-0.5,
}
std = math.sqrt(self.config.initializer_range / self.config.conv1d_width)
if isinstance(module, ModernBertEmbeddings):
init_weight(module.tok_embeddings, stds["embedding"])
elif isinstance(module, ModernBertMLP):
init_weight(module.Wi, stds["in"])
init_weight(module.Wo, stds["out"])
elif isinstance(module, ModernBergAttention):
init_weight(module.Wqkv, stds["in"])
init_weight(module.Wo, stds["out"])
elif isinstance(module, GriffinRecurrentblock):
torch.nn.init.zeros_(module.linear_x.bias)
torch.nn.init.normal_(module.linear_x.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
torch.nn.init.zeros_(module.linear_y.bias)
torch.nn.init.normal_(module.linear_y.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
std = math.sqrt(self.config.initializer_range / self.config.lru_width)
torch.nn.init.normal_(module.linear_out.weight, mean=0.0, std=std)
torch.nn.init.zeros_(module.linear_out.bias)
elif isinstance(module, GriffinRglru):
std = math.sqrt(
self.config.initializer_range / (self.config.lru_width // self.config.num_attention_heads)
)
torch.nn.init.normal_(module.input_gate_weight, mean=0.0, std=std)
torch.nn.init.normal_(module.recurrent_gate_weight, mean=0.0, std=std)
torch.nn.init.zeros_(module.input_gate_bias)
torch.nn.init.zeros_(module.recurrent_gate_bias)
module.recurrent_param.data.uniform_(0.9**2 + 1e-8, 0.999**2 + 1e-8)
module.recurrent_param.data.log_().mul_(0.5)
module.recurrent_param.data.neg_().exp_().sub_(1.0).log_()
elif isinstance(module, ModernBergPredictionHead):
init_weight(module.dense, stds["out"])
elif isinstance(module, ModernBergForMaskedLM):
init_weight(module.decoder, stds["out"])
elif isinstance(module, (ModernBergForSequenceClassification, ModernBergForTokenClassification)):
init_weight(module.classifier, stds["final_out"])
elif isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if getattr(module, "bias", None) is not None:
torch.nn.init.zeros_(module.bias)
@classmethod
def _autoset_attn_implementation(
cls,
config,
use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
):
# If the user didn't specify anything, try to use flash_attention_2 if available.
# Otherwise we fall back to the default SDPA -> Eager from the super() method.
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
if config._attn_implementation_internal is None:
config._attn_implementation_internal = "flash_attention_2"
try:
return cls._check_and_enable_flash_attn_2(
config,
torch_dtype=torch.float16,
device_map=device_map,
hard_check_only=False,
check_device_map=check_device_map,
)
except (ValueError, ImportError):
config._attn_implementation_internal = None
return super()._autoset_attn_implementation(
config,
use_flash_attention_2=use_flash_attention_2,
torch_dtype=torch.float16,
device_map=device_map,
check_device_map=check_device_map,
)
def _maybe_set_compile(self):
if self.config.reference_compile is False:
return
if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
if self.config.reference_compile:
logger.warning_once(
"If `accelerate` split the model across devices, `torch.compile` will not work. "
"Falling back to non-compiled mode."
)
self.config.reference_compile = False
if self.device.type == "mps":
if self.config.reference_compile:
logger.warning_once(
"Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
"Falling back to non-compiled mode."
)
self.config.reference_compile = False
if self.device.type == "cpu":
if self.config.reference_compile:
logger.warning_once(
"Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
"Falling back to non-compiled mode."
)
self.config.reference_compile = False
if self.config.reference_compile is None:
self.config.reference_compile = is_triton_available()
def resize_token_embeddings(self, *args, **kwargs):
    """Resize the token embeddings through the parent class, then disable compilation.

    Resized embeddings are not supported by the reference ``torch.compile`` path, so a
    ``reference_compile`` of ``True`` or ``None`` is forced to ``False`` (with a one-time
    warning only for ``True``). Returns whatever the parent implementation returns.
    """
    resized = super().resize_token_embeddings(*args, **kwargs)
    requested = self.config.reference_compile
    if requested not in {True, None}:
        return resized
    if requested:
        logger.warning_once(
            "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
        )
    self.config.reference_compile = False
    return resized
# Shared argument documentation injected into every `forward` below via
# `add_start_docstrings_to_model_forward`.
MODERNBERG_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored
            by default should you provide it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            If not provided, every position is attended to.

            [What are attention masks?](../glossary#attention-mask)
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention. When not using Flash
            Attention 2, the model rebuilds this mask from `attention_mask` and `config.local_attention`, so a value
            passed here is overwritten.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare ModernBerg Model outputting raw hidden-states without any specific head on top.",
    MODERNBERG_START_DOCSTRING,
)
class ModernBergModel(ModernBergPreTrainedModel):
    # Encoder stack: embeddings -> `config.num_hidden_layers` encoder layers -> final LayerNorm.

    def __init__(self, config: ModernBergConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.layers = nn.ModuleList(
            [ModernBergEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
        )
        self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.tok_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.tok_embeddings = value

    @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # XOR is True when both or neither are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        self._maybe_set_compile()

        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)

        # Derive the padded shape from the inputs unless the caller supplied it
        # (ModernBergForMaskedLM passes both after unpadding its inputs itself).
        if batch_size is None and seq_len is None:
            if inputs_embeds is not None:
                batch_size, seq_len = inputs_embeds.shape[:2]
            else:
                batch_size, seq_len = input_ids.shape[:2]
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # A missing mask means "attend to every position".
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

        repad = False
        if self.config._attn_implementation == "flash_attention_2":
            # Unpad here only when the caller has not already done so; in that case the
            # outputs are scattered back to (batch_size, seq_len) at the end.
            if indices is None and cu_seqlens is None and max_seqlen is None:
                repad = True
                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
                            inputs=input_ids, attention_mask=attention_mask
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
                        inputs=inputs_embeds, attention_mask=attention_mask
                    )
        else:
            if position_ids is None:
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            # Non-FA2 paths need explicit 4D additive masks; any caller-supplied
            # sliding_window_mask is replaced here.
            attention_mask, sliding_window_mask = self._update_attention_mask(
                attention_mask, output_attentions=output_attentions
            )

        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)

        for encoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    sliding_window_mask,
                    position_ids,
                    cu_seqlens,
                    max_seqlen,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    sliding_window_mask=sliding_window_mask,
                    position_ids=position_ids,
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen,
                    output_attentions=output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions and len(layer_outputs) > 1:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # Note: the last entry of all_hidden_states is taken before final_norm.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.final_norm(hidden_states)

        if repad:
            hidden_states = _pad_modernbert_output(
                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
            )
            if all_hidden_states is not None:
                all_hidden_states = tuple(
                    _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
                    for hs in all_hidden_states
                )

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def _update_attention_mask(
        self, attention_mask: torch.Tensor, output_attentions: bool
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the 4D additive masks used by the non-Flash-Attention paths.

        Returns:
            A ``(global_attention_mask, sliding_window_mask)`` pair. The global mask only
            hides padding. The sliding-window mask additionally hides every key farther
            than ``config.local_attention // 2`` positions from the query.
        """
        if output_attentions:
            if self.config._attn_implementation == "sdpa":
                logger.warning_once(
                    "Outputting attentions is only supported with the 'eager' attention implementation, "
                    'not with "sdpa". Falling back to `attn_implementation="eager"`.'
                )
                self.config._attn_implementation = "eager"
            elif self.config._attn_implementation != "eager":
                logger.warning_once(
                    "Outputting attentions is only supported with the eager attention implementation, "
                    f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
                    " Setting `output_attentions=False`."
                )

        global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)

        # Build the (seq, seq) window directly on the mask's device. Constructing it on
        # the CPU and copying afterwards costs O(seq_len^2) host memory and a transfer
        # on every forward pass.
        rows = torch.arange(global_attention_mask.shape[2], device=global_attention_mask.device).unsqueeze(0)
        distance = torch.abs(rows - rows.T)
        window_mask = (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0)

        # Positions outside the window get the most negative representable value.
        sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)

        return global_attention_mask, sliding_window_mask
class ModernBergPredictionHead(nn.Module):
    # Projection shared by the task heads: Linear(hidden -> hidden), then the
    # `config.classifier_activation` nonlinearity, then LayerNorm.

    def __init__(self, config: ModernBergConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.dense = nn.Linear(width, width, bias=config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        self.norm = nn.LayerNorm(width, eps=config.norm_eps, bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        activated = self.act(projected)
        return self.norm(activated)
@add_start_docstrings(
    "The ModernBerg Model with a decoder head on top that is used for masked language modeling.",
    MODERNBERG_START_DOCSTRING,
)
class ModernBergForMaskedLM(ModernBergPreTrainedModel):
    _tied_weights_keys = ["decoder.weight"]

    def __init__(self, config: ModernBergConfig):
        super().__init__(config)
        self.config = config
        self.model = ModernBergModel(config)
        self.head = ModernBergPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)

        # When enabled (and labels are given), only labelled positions are decoded.
        self.sparse_prediction = self.config.sparse_prediction
        self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.decoder = new_embeddings

    @torch.compile(dynamic=True)
    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.head(output))

    @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        if self.config._attn_implementation == "flash_attention_2":
            # Unpad inputs, position ids and labels together so they stay aligned with the
            # unpadded hidden states the encoder returns when `indices` is provided.
            if indices is None and cu_seqlens is None and max_seqlen is None:
                if batch_size is None and seq_len is None:
                    if inputs_embeds is not None:
                        batch_size, seq_len = inputs_embeds.shape[:2]
                    else:
                        batch_size, seq_len = input_ids.shape[:2]
                device = input_ids.device if input_ids is not None else inputs_embeds.device

                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                    )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        sparse_applied = bool(self.sparse_prediction) and labels is not None
        if sparse_applied:
            # flatten labels and output first
            labels = labels.view(-1)
            last_hidden_state = last_hidden_state.view(labels.shape[0], -1)

            # then filter out the non-masked tokens
            mask_tokens = labels != self.sparse_pred_ignore_index
            last_hidden_state = last_hidden_state[mask_tokens]
            labels = labels[mask_tokens]

        logits = (
            self.compiled_head(last_hidden_state)
            if self.config.reference_compile
            else self.decoder(self.head(last_hidden_state))
        )

        loss = None
        if labels is not None:
            # Forward extra loss kwargs (e.g. `num_items_in_batch` from the Trainer) so the
            # loss is normalised correctly under gradient accumulation.
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)

        # Sparse logits cover only the labelled subset of tokens, so they have no
        # (batch, seq_len) layout to be scattered back into; `indices` addresses every
        # non-padding token and would not line up with them.
        if self.config._attn_implementation == "flash_attention_2" and not sparse_applied:
            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

        if not return_dict:
            output = (logits,) + tuple(outputs[1:])
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    "The ModernBerg Model with a sequence classification head on top that performs pooling.",
    MODERNBERG_START_DOCSTRING,
)
class ModernBergForSequenceClassification(ModernBergPreTrainedModel):
    def __init__(self, config: ModernBergConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.model = ModernBergModel(config)
        self.head = ModernBergPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        if self.config.classifier_pooling == "cls":
            last_hidden_state = last_hidden_state[:, 0]
        elif self.config.classifier_pooling == "mean":
            # `attention_mask` is optional; the encoder treats a missing mask as
            # "attend to every position", so pool with the same all-ones mask
            # instead of failing on `None.unsqueeze`.
            if attention_mask is None:
                attention_mask = torch.ones(
                    last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool
                )
            last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
                dim=1, keepdim=True
            )

        pooled_output = self.head(last_hidden_state)
        pooled_output = self.drop(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Keep hidden states / attentions in tuple mode, like the token classifier.
            output = (logits,) + tuple(outputs[1:])
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    "The ModernBerg Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.",
    MODERNBERG_START_DOCSTRING,
)
class ModernBergForTokenClassification(ModernBergPreTrainedModel):
    # Per-token classifier: encoder -> shared prediction head -> dropout -> Linear(hidden -> num_labels).

    def __init__(self, config: ModernBergConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Submodules are created in this order on purpose (matches checkpoint layout).
        self.model = ModernBergModel(config)
        self.head = ModernBergPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Per-token labels for the classification loss. Values should lie in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        self._maybe_set_compile()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_states = self.drop(self.head(outputs[0]))
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            flat_logits = logits.view(-1, self.num_labels)
            loss = CrossEntropyLoss()(flat_logits, labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        return output if loss is None else (loss,) + output
# Public API of this module, in order: the configuration, the bare encoder, the
# shared pretrained base class, and the three task-specific heads.
__all__ = [
    "ModernBergConfig",
    "ModernBergModel",
    "ModernBergPreTrainedModel",
    "ModernBergForMaskedLM",
    "ModernBergForSequenceClassification",
    "ModernBergForTokenClassification",
]