|
|
|
|
|
import math |
|
from typing import List, Optional, Tuple, Union, Dict, Any |
|
|
|
import torch |
|
from torch import nn |
|
from .configuration_mitre import MitreConfig |
|
from transformers.utils import logging |
|
|
|
from transformers.generation import GenerationMixin |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.activations import ACT2FN |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPastAndCrossAttentions, |
|
Seq2SeqLMOutput, |
|
Seq2SeqModelOutput, |
|
) |
|
from transformers.generation.configuration_utils import GenerationConfig |
|
from transformers.generation.beam_search import BeamSearchScorer |
|
from transformers.generation.logits_process import LogitsProcessorList |
|
from transformers.generation.stopping_criteria import StoppingCriteriaList |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): |
|
""" |
|
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols |
|
are ignored. This is modified from fairseq's `utils.make_positions`. |
|
""" |
|
mask = input_ids.ne(padding_idx).int() |
|
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask |
|
return incremental_indices.long() + padding_idx |
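# Illustrative worked example (not from the original source), assuming padding_idx = 1:
#   >>> input_ids = torch.tensor([[5, 6, 7, 1, 1]])
#   >>> create_position_ids_from_input_ids(input_ids, padding_idx=1)
#   tensor([[2, 3, 4, 1, 1]])
# Non-padding tokens receive positions padding_idx + 1, padding_idx + 2, ...,
# while padding tokens keep position padding_idx.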
|
|
|
|
|
|
|
|
|
class MitreSdpaAttention(nn.Module): |
|
|
|
def __init__( |
|
self, |
|
embed_dim: int, |
|
num_heads: int, |
|
dropout: float = 0.0, |
|
bias: bool = True, |
|
config: Optional[MitreConfig] = None, |
|
): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.num_heads = num_heads |
|
self.dropout = dropout |
|
self.head_dim = embed_dim // num_heads |
|
self.config = config |
|
|
|
if (self.head_dim * num_heads) != self.embed_dim: |
|
raise ValueError( |
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" |
|
f" and `num_heads`: {num_heads})." |
|
) |
|
self.scaling = self.head_dim**-0.5 |
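# Note (added for clarity): `self.scaling` is kept for parity with the reference
# attention implementations; torch.nn.functional.scaled_dot_product_attention used
# below already applies the default 1/sqrt(head_dim) scaling internally, so this
# attribute is not referenced in the forward pass.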
|
|
|
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
|
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): |
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
""" |
|
1. MitreModel uses MitreSdpaAttention, which is modified from M2M100SdpaAttention.
   Notably, neither of them supports 'output_attentions=True' or a non-None 'layer_head_mask',
   so attention weights are not included in the output.
   Improving this is currently a low priority; we leave this functionality for users to customize.
2. We plan to enhance this code with Flash Attention v2 in the future.
|
""" |
|
bsz, tgt_len, _ = hidden_states.size() |
|
|
|
|
|
query_states = self.q_proj(hidden_states) |
|
if past_key_value is not None: |
|
|
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) |
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) |
|
key_states = torch.cat([past_key_value[0], key_states], dim=2) |
|
value_states = torch.cat([past_key_value[1], value_states], dim=2) |
|
else: |
|
|
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) |
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) |
|
|
|
past_key_value = (key_states, value_states) |
|
|
|
query_states = self._shape(query_states, tgt_len, bsz) |
|
|
|
attn_output = torch.nn.functional.scaled_dot_product_attention( |
|
query_states, |
|
key_states, |
|
value_states, |
|
attn_mask=attention_mask, |
|
dropout_p=self.dropout if self.training else 0.0, |
|
is_causal=False, |
|
) |
|
|
|
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): |
|
raise ValueError( |
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" |
|
f" {attn_output.size()}" |
|
) |
|
|
|
attn_output = attn_output.transpose(1, 2) |
|
|
|
|
|
|
|
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) |
|
|
|
attn_output = self.out_proj(attn_output) |
|
|
|
return attn_output, None, past_key_value |
|
|
|
|
|
|
|
class MitreDecoderLayer(nn.Module): |
|
def __init__(self, config: MitreConfig): |
|
super().__init__() |
|
self.embed_dim = config.d_model |
|
|
|
self.self_attn = MitreSdpaAttention( |
|
embed_dim=self.embed_dim, |
|
num_heads=config.decoder_attention_heads, |
|
dropout=config.attention_dropout, |
|
config=config, |
|
) |
|
self.dropout = config.dropout |
|
self.activation_fn = ACT2FN[config.activation_function] |
|
self.activation_dropout = config.activation_dropout |
|
|
|
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) |
|
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) |
|
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) |
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
use_cache: Optional[bool] = True, |
|
) -> torch.Tensor: |
|
""" |
|
Args: |
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
|
attention_mask (`torch.FloatTensor`): attention mask of size |
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. |
|
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states |
|
""" |
|
residual = hidden_states |
|
hidden_states = self.self_attn_layer_norm(hidden_states) |
|
|
|
|
|
|
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None |
|
|
|
hidden_states, _, present_key_value = self.self_attn( |
|
hidden_states=hidden_states, |
|
past_key_value=self_attn_past_key_value, |
|
attention_mask=attention_mask, |
|
) |
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = self.final_layer_norm(hidden_states) |
|
hidden_states = self.activation_fn(self.fc1(hidden_states)) |
|
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) |
|
hidden_states = self.fc2(hidden_states) |
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
hidden_states = residual + hidden_states |
|
|
|
outputs = (hidden_states,) |
|
|
|
if use_cache: |
|
outputs += (present_key_value,) |
|
|
|
return outputs |
|
|
|
|
|
class MitrePreTrainedModel(PreTrainedModel): |
|
config_class = MitreConfig |
|
base_model_prefix = "model" |
|
supports_gradient_checkpointing = True |
|
_no_split_modules = ["MitreDecoderLayer"] |
|
|
|
_supports_flash_attn_2 = False |
|
_supports_sdpa = True |
|
|
|
def _init_weights(self, module): |
|
std = self.config.init_std |
|
if isinstance(module, nn.Linear): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
|
|
class MitreDecoder(MitrePreTrainedModel): |
|
""" |
|
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MitreDecoderLayer`] |
|
|
|
Args: |
|
config: MitreConfig |
|
|
""" |
|
|
|
def __init__(self, config: MitreConfig): |
|
super().__init__(config) |
|
self.dropout = config.dropout |
|
self.padding_idx = config.pad_token_id |
|
self.max_target_positions = config.max_position_embeddings |
|
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 |
|
|
|
self.embed_tokens = MitreScaledWordEmbedding( |
|
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale |
|
) |
|
|
|
self.src_embed_positions = MitreSinusoidalPositionalEmbedding( |
|
config.max_position_embeddings, |
|
config.d_model, |
|
self.padding_idx, |
|
) |
|
self.register_embed_positions = MitreSinusoidalPositionalEmbedding( |
|
config.max_position_embeddings, |
|
config.d_model, |
|
self.padding_idx, |
|
) |
|
self.tgt_embed_positions = MitreSinusoidalPositionalEmbedding( |
|
config.max_position_embeddings, |
|
config.d_model, |
|
self.padding_idx, |
|
) |
|
self.layers = nn.ModuleList([MitreDecoderLayer(config) for _ in range(config.decoder_layers)]) |
|
if config._attn_implementation != "sdpa": |
|
raise NotImplementedError("Other attention mechanism are not implemented yet.") |
|
|
|
|
|
|
|
self._use_sdpa = config._attn_implementation == "sdpa" |
|
self.layer_norm = nn.LayerNorm(config.d_model) |
|
|
|
self.gradient_checkpointing = False |
|
self._future_mask = torch.empty(0) |
|
|
|
self.post_init() |
|
|
|
def create_registers(self, input_ids): |
|
''' |
|
Create registers by duplicating the language tag of each sentence.
|
length(registers) = length(real_tokens) = length(tokens) - length(pads) |
|
''' |
|
register_nums = (~input_ids.eq(self.padding_idx)).sum(dim=1) |
|
max_register_nums = register_nums.max().item() |
|
total_token_nums = input_ids.size(1) + max_register_nums |
|
batch_size = input_ids.size(0) |
|
registers = input_ids[range(batch_size), torch.argmax(input_ids, dim=-1)].unsqueeze(1).repeat(1, max_register_nums) |
|
return registers, register_nums, total_token_nums |
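# Illustrative example (not from the original source), assuming `decoder` is a MitreDecoder
# instance with padding_idx = 1, and that the language tag has the largest id in each row
# (here 256), which is what the per-row argmax relies on:
#   >>> input_ids = torch.tensor([[12, 34, 256, 1], [256, 9, 1, 1]])
#   >>> registers, register_nums, total_token_nums = decoder.create_registers(input_ids)
#   >>> registers            # the language tag repeated max(register_nums) times per row
#   tensor([[256, 256, 256], [256, 256, 256]])
#   >>> register_nums        # number of real (non-pad) tokens per row
#   tensor([3, 2])
#   >>> total_token_nums     # input length + max number of registers
#   7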
|
|
|
def get_token_indices(self, input_ids, total_token_nums, register_nums): |
|
''' |
|
Return token indices for selecting source tokens from expanded_src_tokens.
|
''' |
|
token_indices = torch.arange(total_token_nums).expand(input_ids.size(0), -1).to(input_ids.device) |
|
token_indices = token_indices + register_nums.unsqueeze(1) |
|
return token_indices |
|
|
|
def get_batch_indices(self, input_ids, token_indices): |
|
''' |
|
Return batch indices for selecting source tokens from expanded_src_tokens.
|
''' |
|
batch_indices = torch.arange(input_ids.shape[0]).unsqueeze(1).expand(-1, token_indices.size(1)).contiguous() |
|
return batch_indices |
|
|
|
def combine_src_and_registers(self, input_ids, registers): |
|
''' |
|
Return expanded_src_tokens (pads, source tokens, and registers concatenated) for positional embedding.
|
''' |
|
pads = torch.full_like(registers, self.padding_idx) |
|
expanded_src_tokens = torch.cat((pads, input_ids, registers), dim=1) |
|
return expanded_src_tokens |
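# Illustrative example (not from the original source), continuing the create_registers
# example above (padding_idx = 1, registers filled with the language tag 256):
#   row [12, 34, 256, 1] with register_nums = 3, max_register_nums = 3:
#     expanded_src_tokens = [1, 1, 1, 12, 34, 256, 1, 256, 256, 256]
#     token_indices       = arange(total_token_nums) + register_nums = [3, 4, ..., 9]
#     source_tokens       = expanded_src_tokens[3:10] = [12, 34, 256, 1, 256, 256, 256]
#   row [256, 9, 1, 1] with register_nums = 2:
#     token_indices       = [2, 3, ..., 8]
#     source_tokens       = [1, 256, 9, 1, 1, 256, 256]
# Every row of source_tokens has the same length (total_token_nums) and always ends with
# that sentence's own registers, so the last column is always the language tag.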
|
|
|
def source_tokens_embedding_with_positions(self, expanded_src_tokens, total_token_nums, batch_indices, indices): |
|
''' |
|
Return the embeddings of the source tokens with positional embeddings added.
|
''' |
|
inputs_embeds = self.embed_tokens(expanded_src_tokens) |
|
inputs_embeds_1 = inputs_embeds[:,:total_token_nums,:] + self.src_embed_positions(expanded_src_tokens[:,:total_token_nums]) |
|
inputs_embeds_2 = inputs_embeds[:,total_token_nums:,:] + self.register_embed_positions(expanded_src_tokens[:,total_token_nums:]) |
|
inputs_embeds = torch.cat((inputs_embeds_1, inputs_embeds_2), dim=1) |
|
inputs_embeds = inputs_embeds[batch_indices, indices] |
|
|
|
return inputs_embeds |
|
|
|
def fill_with_neg_inf(self, t): |
|
return t.float().fill_(float("-inf")).type_as(t) |
|
|
|
def check_contiguous(self, t: torch.Tensor): |
|
return t if t.is_contiguous() else t.contiguous() |
|
|
|
def build_future_mask(self, embeds, src_length, register_nums, past_key_values_length=0): |
|
b = register_nums.size(0) |
|
ns = src_length - register_nums |
|
if past_key_values_length == 0: |
|
|
|
|
|
dim = embeds.size(1) |
|
if ( |
|
self._future_mask.size(0) == 0 |
|
or self._future_mask.size(0) < dim |
|
): |
|
self._future_mask = torch.triu(self.fill_with_neg_inf(torch.zeros([dim, dim])), 1) |
|
if self._future_mask.device == embeds.device: |
|
mask = self._future_mask[:dim, :dim].clone() |
|
else: |
|
mask = self._future_mask[:dim, :dim].to(embeds, copy=True) |
|
|
|
|
|
mask[:src_length, :src_length] = 0.
|
|
|
|
|
batch_mask = mask.unsqueeze(0).expand(b, -1, -1).clone().contiguous() |
|
|
|
|
|
|
|
batch_indices = torch.arange(b).to(batch_mask.device).view(-1, 1, 1).expand(b, dim, dim).contiguous() |
|
row_indices = torch.arange(dim).to(batch_mask.device).view(1, -1, 1).expand(b, dim, dim).contiguous() |
|
col_indices = torch.arange(dim).to(batch_mask.device).view(1, 1, -1).expand(b, dim, dim).contiguous() |
|
source_indices = (row_indices < ns.view(-1, 1, 1)) & (col_indices >= ns.view(-1, 1, 1)) & (col_indices < (ns + register_nums).view(-1, 1, 1)).contiguous() |
|
target_indices = (row_indices >= (ns + register_nums).view(-1, 1, 1)) & (col_indices < ns.view(-1, 1, 1)).contiguous() |
|
|
|
batch_mask[batch_indices[source_indices], row_indices[source_indices], col_indices[source_indices]] = float('-inf') |
|
|
|
batch_mask[batch_indices[target_indices], row_indices[target_indices], col_indices[target_indices]] = float('-inf') |
|
|
|
batch_mask = batch_mask.unsqueeze(1) |
|
|
|
elif past_key_values_length > 0: |
|
|
|
|
|
|
|
mask = torch.zeros(past_key_values_length + 1) |
|
mask = mask.to(embeds, copy=True) |
|
batch_mask = mask.unsqueeze(0).expand(b, -1).clone().contiguous() |
|
|
|
batch_indices = torch.arange(b).view(-1, 1).expand(b, past_key_values_length + 1).to(batch_mask.device) |
|
token_indices = torch.arange(past_key_values_length + 1).view(1, -1).expand(b, past_key_values_length + 1).to(batch_mask.device) |
|
target_to_source_mask = token_indices < ns.view(-1, 1) |
|
|
|
batch_mask[batch_indices[target_to_source_mask], token_indices[target_to_source_mask]] = float('-inf') |
|
batch_mask = batch_mask.unsqueeze(1) |
|
|
|
batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1]) |
|
return batch_mask |
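# Layout of the mask built above on the first forward pass (past_key_values_length == 0),
# added as a reading aid. Let ns = src_length - register_nums (pads + source tokens) so that
# the register block spans columns [ns, src_length):
#   - rows < ns (source tokens):        may attend to columns < ns, but not to registers
#                                       or target tokens;
#   - register rows in [ns, src_length): may attend to all source and register columns;
#   - target rows >= src_length:        may attend to registers and to previous target
#                                       tokens (causal), but not to columns < ns.
# For incremental steps (past_key_values_length > 0) only a single row per new token is
# built: the first (src_length - register_nums) cached positions are masked out, i.e. the
# cached slots that are not that sentence's own registers.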
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
decoder_input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
use_cache: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
registering_cache: dict = None, |
|
): |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
|
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 |
|
|
|
if past_key_values_length > 0: |
|
register_nums = registering_cache["register_nums"] |
|
src_length = registering_cache["src_length"] |
|
|
|
if input_ids is not None and past_key_values_length == 0: |
|
|
|
input_ids = self.check_contiguous(input_ids) |
|
decoder_input_ids = self.check_contiguous(decoder_input_ids) |
|
|
|
if attention_mask is None: |
|
|
|
registers, register_nums, total_token_nums = self.create_registers(input_ids) |
|
|
|
expanded_src_tokens = self.combine_src_and_registers(input_ids, registers) |
|
token_indices = self.get_token_indices(input_ids, total_token_nums, register_nums) |
|
batch_indices = self.get_batch_indices(input_ids, token_indices) |
|
|
|
source_tokens = expanded_src_tokens[batch_indices, token_indices] |
|
|
|
else: |
|
|
|
|
|
if registering_cache is None or \ |
|
not all(key in registering_cache for key in \ |
|
("register_nums", "total_token_nums", "expanded_src_tokens",\ |
|
"batch_indices", "token_indices", "source_tokens")): |
|
raise ValueError( |
|
"If you generate registers by external codes, \ |
|
you must provide 'register_nums', 'total_token_nums', \ |
|
'expanded_src_tokens', 'batch_indices', 'token_indices' \ |
|
and 'source_tokens' in 'registering_cache' in the training." |
|
) |
|
register_nums, total_token_nums = registering_cache["register_nums"], registering_cache["total_token_nums"] |
|
expanded_src_tokens = registering_cache["expanded_src_tokens"] |
|
batch_indices, token_indices = registering_cache["batch_indices"], registering_cache["token_indices"] |
|
source_tokens = registering_cache["source_tokens"] |
|
|
|
|
|
expanded_src_tokens = self.check_contiguous(expanded_src_tokens) |
|
source_tokens = self.check_contiguous(source_tokens) |
|
|
|
|
|
inputs_embeds = self.source_tokens_embedding_with_positions(expanded_src_tokens, total_token_nums, batch_indices, token_indices) |
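# Added note: source_tokens always ends with the target-language tag (see
# combine_src_and_registers above), so the check below forces the first decoder
# token to be that tag if the caller passed a different start token.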
|
|
|
|
|
|
|
if decoder_input_ids[0][0].item() != source_tokens[0][-1].item(): |
|
decoder_input_ids[:, 0] = source_tokens[:, -1] |
|
|
|
tokens = torch.cat([source_tokens, decoder_input_ids], dim=1) |
|
src_length = source_tokens.shape[1] |
|
|
|
decoder_inputs_embeds = self.embed_tokens(decoder_input_ids) |
|
decoder_inputs_embeds = decoder_inputs_embeds + self.tgt_embed_positions(decoder_input_ids, past_key_values_length, src_length=src_length) |
|
|
|
if past_key_values_length == 0: |
|
hidden_states = torch.cat([inputs_embeds, decoder_inputs_embeds], dim=1) |
|
else: |
|
hidden_states = decoder_inputs_embeds |
|
|
|
|
|
hidden_states = self.check_contiguous(hidden_states) |
|
|
|
|
|
|
|
if attention_mask is None: |
|
attention_mask = self.build_future_mask(hidden_states, src_length, register_nums, past_key_values_length) |
|
else: |
|
bsz, src_len = hidden_states.shape[0], hidden_states.shape[1] |
|
tgt_len = hidden_states.shape[1] if past_key_values_length == 0 else past_key_values_length + 1 |
|
if attention_mask.size() != (bsz, 1, src_len, tgt_len): |
|
raise ValueError( |
|
f"Attention mask should be of size {(bsz, 1, src_len, tgt_len)}, but is {attention_mask.size()}" |
|
) |
|
|
|
|
|
attention_mask = self.check_contiguous(attention_mask) |
|
|
|
|
|
|
|
max_register_num = None |
|
|
|
if past_key_values_length == 0: |
|
|
|
max_register_num = register_nums.max().item() if use_cache else None |
|
|
|
padding_mask = tokens.eq(self.padding_idx) |
|
if padding_mask.any(): |
|
padding_mask = padding_mask.unsqueeze(1).unsqueeze(2) |
|
attention_mask = attention_mask.masked_fill(padding_mask == 1, float('-inf')) |
|
|
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
|
|
if self.gradient_checkpointing and self.training: |
|
if use_cache: |
|
logger.warning_once( |
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting" " `use_cache=False`..." |
|
) |
|
use_cache = False |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
next_decoder_cache = () if use_cache else None |
|
|
|
for idx, decoder_layer in enumerate(self.layers): |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
past_key_value = past_key_values[idx] if past_key_values is not None else None |
|
|
|
if self.gradient_checkpointing and self.training: |
|
layer_outputs = self._gradient_checkpointing_func( |
|
decoder_layer.__call__, |
|
hidden_states, |
|
attention_mask, |
|
past_key_value=None, |
|
use_cache=use_cache, |
|
) |
|
else: |
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
past_key_value=past_key_value, |
|
use_cache=use_cache, |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
if use_cache: |
|
if past_key_values_length > 0: |
|
next_decoder_cache += (layer_outputs[1],) |
|
else: |
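# Added note: on the first pass the key/value cache is clipped so that only the last
# max_register_num source positions (the register block, padded for shorter sentences)
# and the target positions are kept; later target tokens never attend to the raw source
# tokens, so caching them would only waste memory.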
|
cache_key, cache_value = layer_outputs[1] |
|
clipped_rep = ( |
|
cache_key[:, :, src_length - max_register_num:, :], |
|
cache_value[:, :, src_length - max_register_num:, :] |
|
) |
|
next_decoder_cache += (clipped_rep,) |
|
|
|
if past_key_values_length == 0: |
|
hidden_states = hidden_states[:,src_length:,:] |
|
|
|
hidden_states = self.layer_norm(hidden_states) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
next_cache = next_decoder_cache if use_cache else None |
|
|
|
model_output = BaseModelOutputWithPastAndCrossAttentions( |
|
last_hidden_state=hidden_states, |
|
past_key_values=next_cache, |
|
hidden_states=all_hidden_states, |
|
) |
|
|
|
|
|
|
|
if use_cache: |
|
model_output.registering_cache = { |
|
"register_nums": register_nums, |
|
"src_length": src_length if past_key_values_length > 0 else max_register_num, |
|
"attention_mask": attention_mask if past_key_values_length > 0 else None |
|
} |
|
else: |
|
model_output.registering_cache = None |
|
|
|
return model_output |
|
|
|
|
|
|
|
class MitreScaledWordEmbedding(nn.Embedding): |
|
""" |
|
This module overrides nn.Embedding's forward by multiplying the embeddings by `embed_scale`.
|
""" |
|
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): |
|
super().__init__(num_embeddings, embedding_dim, padding_idx) |
|
self.embed_scale = embed_scale |
|
|
|
def forward(self, input_ids: torch.Tensor): |
|
return super().forward(input_ids) * self.embed_scale |
|
|
|
|
|
class MitreSinusoidalPositionalEmbedding(nn.Module): |
|
"""This module produces sinusoidal positional embeddings of any length.""" |
|
|
|
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): |
|
super().__init__() |
|
self.offset = 2 |
|
self.embedding_dim = embedding_dim |
|
self.padding_idx = padding_idx |
|
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) |
|
|
|
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): |
|
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) |
|
if hasattr(self, "weights"): |
|
|
|
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) |
|
|
|
self.register_buffer("weights", emb_weights, persistent=False) |
|
|
|
@staticmethod |
|
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): |
|
""" |
|
Build sinusoidal embeddings. |
|
|
|
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of |
|
"Attention Is All You Need". |
|
""" |
|
half_dim = embedding_dim // 2 |
|
emb = math.log(10000) / (half_dim - 1) |
|
emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) |
|
emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) |
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) |
|
if embedding_dim % 2 == 1: |
|
|
|
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) |
|
if padding_idx is not None: |
|
emb[padding_idx, :] = 0 |
|
|
|
return emb.to(torch.get_default_dtype()) |
|
|
|
@torch.no_grad() |
|
def forward( |
|
self, input_ids: torch.Tensor = None, past_key_values_length: int = 0, src_length: int = 0 |
|
): |
|
bsz, seq_len = input_ids.size() |
|
|
|
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( |
|
input_ids.device |
|
) |
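# Added note: during incremental decoding, past_key_values_length also counts the cached
# register slots, so the raw positions over-count by src_length (the stored register
# block length). Subtracting src_length below keeps target positions counting target
# tokens only; positions equal to 1 (padding, given padding_idx = 1) are left untouched.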
|
|
|
if past_key_values_length > 0 and src_length > 0: |
|
position_ids = torch.where(position_ids == 1, position_ids, position_ids - src_length) |
|
|
|
|
|
max_pos = self.padding_idx + 1 + seq_len + past_key_values_length |
|
|
|
if max_pos > self.weights.size(0): |
|
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) |
|
|
|
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() |
|
|
|
class MitreModel(MitrePreTrainedModel): |
|
_tied_weights_keys = ["decoder.embed_tokens.weight"] |
|
|
|
def __init__(self, config: MitreConfig): |
|
super().__init__(config) |
|
|
|
self.decoder = MitreDecoder(config) |
|
|
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.decoder.embed_tokens |
|
|
|
def get_decoder(self): |
|
return self.decoder |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
decoder_input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
registering_cache: dict = None, |
|
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
decoder_outputs = self.decoder( |
|
input_ids=input_ids, |
|
decoder_input_ids=decoder_input_ids, |
|
attention_mask=attention_mask, |
|
past_key_values=past_key_values, |
|
use_cache=use_cache, |
|
output_hidden_states=output_hidden_states, |
|
registering_cache=registering_cache |
|
) |
|
|
|
model_output = Seq2SeqModelOutput( |
|
last_hidden_state=decoder_outputs.last_hidden_state, |
|
past_key_values=decoder_outputs.past_key_values, |
|
decoder_hidden_states=decoder_outputs.hidden_states, |
|
decoder_attentions=decoder_outputs.attentions, |
|
) |
|
model_output.registering_cache = decoder_outputs.registering_cache |
|
return model_output |
|
|
|
class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin): |
|
base_model_prefix = "model" |
|
_tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"] |
|
|
|
def __init__(self, config: MitreConfig): |
|
super().__init__(config) |
|
self.model = MitreModel(config) |
|
self.lm_head = nn.Linear(config.d_model, self.model.decoder.embed_tokens.num_embeddings, bias=False) |
|
|
|
|
|
self.post_init() |
|
|
|
def get_decoder(self): |
|
return self.model.get_decoder() |
|
|
|
def get_output_embeddings(self): |
|
return self.lm_head |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
self.lm_head = new_embeddings |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
decoder_input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
registering_cache: dict = None, |
|
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: |
|
|
|
outputs = self.model( |
|
input_ids=input_ids, |
|
decoder_input_ids=decoder_input_ids, |
|
attention_mask=attention_mask, |
|
past_key_values=past_key_values, |
|
use_cache=use_cache, |
|
output_hidden_states=output_hidden_states, |
|
registering_cache=registering_cache, |
|
) |
|
|
|
lm_logits = self.lm_head(outputs[0]) |
|
|
|
if labels is not None: |
|
raise NotImplementedError("Please implement your loss function here.") |
|
|
|
model_output = Seq2SeqLMOutput( |
|
loss=None, |
|
logits=lm_logits, |
|
past_key_values=outputs.past_key_values, |
|
decoder_hidden_states=outputs.decoder_hidden_states, |
|
decoder_attentions=outputs.decoder_attentions, |
|
) |
|
model_output.registering_cache = outputs.registering_cache |
|
return model_output |
|
|
|
@staticmethod |
|
def _reorder_cache(past_key_values, beam_idx): |
|
reordered_past = () |
|
for layer_past in past_key_values: |
|
reordered_past += ( |
|
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), |
|
) |
|
return reordered_past |
|
|
|
@staticmethod |
|
def _reorder_register_cache(t, beam_idx): |
|
""" a costumized reorder method """ |
|
return t.index_select(dim=0, index=beam_idx.to(t.device)) |
|
|
|
@staticmethod |
|
def _expand_inputs_for_generation( |
|
input_ids: Optional[torch.LongTensor] = None, |
|
beam_size: int = 1, |
|
) -> torch.LongTensor: |
|
""" |
|
Expands input_ids from [batch_size, seq_len] to [batch_size * beam_size, seq_len].
This is simplified from 'transformers.generation.utils.GenerationMixin._expand_inputs_for_generation'.
|
""" |
|
if beam_size == 1: |
|
return input_ids |
|
|
|
return input_ids.repeat_interleave(beam_size, dim=0) |
|
|
|
def generate(self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
generation_config: Optional[GenerationConfig] = None, |
|
**kwargs: Dict |
|
): |
|
""" |
|
Inference with beam search. |
|
This code is an improved version of transformers.generation.utils.GenerationMixin.generate. |
|
There are two main improvements: |
|
1. 'soft early_stop' in beam search. |
|
a) problem in the vanilla version. |
|
In multilingual translation models such as NLLB and M2M, the vanilla early stop in BeamSearchScorer |
|
(the official implementation by HuggingFace) marks ended sequences with pad(1). However, these ended |
|
sequences are still fed into the model, leading to significant memory waste. |
|
b) our improvement. |
|
We implemented a "soft early stop" to address this issue. Instead of modifying BeamSearchScorer |
|
(to maintain code flexibility), we remove ended sequences from the input. Since this changes the |
|
shape of the output hidden states, we insert placeholders to maintain compatibility with |
|
BeamSearchScorer's state shapes. |
|
Based on our tests, this improvement reduces memory usage by half. |
|
2. mask reusing. |
|
a) problem: |
|
Registers require attention masks at each step. |
|
A sequence may consist of four parts: padding, source tokens, registers, and target tokens. |
|
During training, we mask all tokens before registers for target token generation. During generation, |
|
we cannot allow target tokens to "see" padding tokens, requiring masks at every step. |
|
This leads to computational inefficiency. |
|
b) our improvement. |
|
First, we truncate the source tokens and their representations to reduce cost.
|
Second, for source tokens acting as placeholders, we modified the mask generation logic compared to |
|
our Fairseq implementation. |
|
Third, to avoid regenerating masks at each step, we cache the mask in 'registering_cache', where the
cached mask is managed like the key-value cache in beam search. Then, at every step, we add a column
of zeros to maintain alignment.
|
""" |
|
if generation_config is not None:
|
assert type(generation_config) is GenerationConfig |
|
self.generation_config = generation_config |
|
self.generation_config.update(**kwargs) |
|
|
|
generation_config = self.generation_config |
|
|
|
batch_size = input_ids.shape[0] |
|
beam_size = generation_config.num_beams |
|
device = input_ids.device |
|
max_cache_length = generation_config.max_length |
|
eos_token_id = torch.Tensor([generation_config.eos_token_id]) |
|
|
|
|
|
decoder_input_ids = torch.full( |
|
(batch_size, 1), |
|
self.generation_config.decoder_start_token_id, |
|
dtype=input_ids.dtype, |
|
device=device |
|
) |
|
|
|
beam_scorer = BeamSearchScorer( |
|
batch_size=batch_size, |
|
num_beams=beam_size, |
|
device=device, |
|
length_penalty=self.generation_config.length_penalty, |
|
do_early_stopping=self.generation_config.early_stopping, |
|
num_beam_hyps_to_keep=self.generation_config.num_return_sequences, |
|
max_length=max_cache_length, |
|
) |
|
|
|
input_ids = self._expand_inputs_for_generation(input_ids, beam_size) |
|
decoder_input_ids = self._expand_inputs_for_generation(decoder_input_ids, beam_size) |
|
cur_len = decoder_input_ids.shape[1] |
|
|
|
this_peer_finished = False |
|
past_key_values = None |
|
registering_cache = None
|
attention_mask = None |
|
|
|
|
|
done_mask = None |
|
|
|
|
|
|
|
logits_processor = LogitsProcessorList() |
|
stopping_criteria = StoppingCriteriaList() |
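# Added note: only the first beam of each batch entry starts with score 0; the remaining
# beams start at -1e9 so that the first top-k selection does not pick the same token from
# several identical beams.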
|
|
|
beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input_ids.device) |
|
beam_scores[:, 1:] = -1e9 |
|
beam_scores = beam_scores.view((batch_size * beam_size,)) |
|
while not this_peer_finished: |
|
|
|
if past_key_values is not None: |
|
decoder_input_ids_for_generation = decoder_input_ids[:, -1:] |
|
attention_mask = registering_cache["attention_mask"] |
|
|
|
|
|
|
|
if attention_mask is not None: |
|
attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1) |
|
else: |
|
decoder_input_ids_for_generation = decoder_input_ids |
|
|
|
outputs = self( |
|
input_ids, |
|
decoder_input_ids_for_generation, |
|
attention_mask=attention_mask, |
|
past_key_values=past_key_values, |
|
use_cache=True, |
|
registering_cache=registering_cache |
|
) |
|
del input_ids |
|
input_ids = None |
|
|
|
past_key_values = outputs.past_key_values |
|
registering_cache = outputs.registering_cache |
|
next_token_logits = outputs.logits[:, -1, :].clone().float() |
|
del outputs |
|
|
|
next_token_logits = next_token_logits.to(device) |
|
next_token_scores = nn.functional.log_softmax( |
|
next_token_logits, dim=-1 |
|
) |
|
|
|
next_token_scores_processed = logits_processor(decoder_input_ids, next_token_scores) |
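# Added note (soft early stop, see the docstring above): finished beams were removed from
# the model input, so the scores and token rows are scattered back into full
# (batch_size * beam_size) tensors with placeholder rows before BeamSearchScorer.process
# sees them.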
|
|
|
|
|
|
|
if done_mask is not None: |
|
if done_mask.any(): |
|
|
|
restored_tensor = torch.zeros( |
|
(batch_size * beam_size, next_token_scores_processed.shape[1]), |
|
dtype=next_token_scores_processed.dtype, |
|
device=next_token_scores_processed.device |
|
) |
|
restored_tensor[~done_mask] = next_token_scores_processed |
|
next_token_scores_processed = restored_tensor |
|
|
|
restored_tokens = torch.full( |
|
(batch_size * beam_size, decoder_input_ids.shape[1]), |
|
self.generation_config.pad_token_id, |
|
dtype=decoder_input_ids.dtype, |
|
device=device |
|
) |
|
restored_tokens[~done_mask] = decoder_input_ids |
|
decoder_input_ids = restored_tokens |
|
|
|
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( |
|
next_token_scores_processed |
|
) |
|
|
|
|
|
vocab_size = next_token_scores.shape[-1] |
|
next_token_scores = next_token_scores.view(batch_size, beam_size * vocab_size) |
|
|
|
|
|
|
|
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 |
|
n_tokens_to_keep = max(2, 1 + n_eos_tokens) * beam_size |
|
next_token_scores, next_tokens = torch.topk( |
|
next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True |
|
) |
|
|
|
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") |
|
next_tokens = next_tokens % vocab_size |
|
|
|
beam_outputs = beam_scorer.process( |
|
decoder_input_ids, |
|
next_token_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=generation_config.pad_token_id, |
|
eos_token_id=generation_config.eos_token_id, |
|
decoder_prompt_len=1, |
|
) |
|
beam_scores = beam_outputs["next_beam_scores"] |
|
beam_next_tokens = beam_outputs["next_beam_tokens"] |
|
beam_idx = beam_outputs["next_beam_indices"] |
|
|
|
|
|
|
|
if done_mask is not None: |
|
last_done_mask = done_mask |
|
|
|
|
|
|
|
done_mask = beam_scorer._done.clone().view(-1) |
|
done_mask = self._expand_inputs_for_generation(done_mask, beam_size) |
|
beam_idx = beam_idx[~done_mask] |
|
|
|
decoder_input_ids = torch.cat([decoder_input_ids[beam_idx, :], beam_next_tokens[~done_mask].unsqueeze(-1)], dim=-1) |
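# Added note: beam_idx indexes the full (batch_size * beam_size) layout, while the
# key/value cache and registering_cache only hold rows for unfinished beams. The
# prefix-sum below converts full-layout beam indices into indices over the compact
# cache and drops beams that have just finished.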
|
|
|
|
|
|
|
if decoder_input_ids_for_generation.shape[0] < beam_next_tokens.shape[0]: |
|
|
|
|
|
|
|
if (~done_mask).sum() < decoder_input_ids_for_generation.shape[0]: |
|
count_mask = last_done_mask |
|
else: |
|
count_mask = done_mask |
|
|
|
|
|
|
|
|
|
|
|
prefix_sum = torch.cat([ |
|
torch.zeros_like(count_mask[:1], dtype=torch.long), |
|
torch.cumsum(count_mask.long(), dim=0) |
|
], dim=0) |
|
reorder_idx = beam_idx - prefix_sum[beam_idx] |
|
not_done = ~done_mask[beam_idx] |
|
reorder_idx = reorder_idx[not_done] |
|
else: |
|
reorder_idx = beam_idx |
|
|
|
past_key_values = self._reorder_cache(past_key_values, reorder_idx) |
|
registering_cache["register_nums"] = self._reorder_register_cache(registering_cache["register_nums"], reorder_idx) |
|
if registering_cache["attention_mask"] is not None: |
|
registering_cache["attention_mask"] = self._reorder_register_cache(registering_cache["attention_mask"], reorder_idx) |
|
|
|
cur_len = cur_len + 1 |
|
|
|
if beam_scorer.is_done: |
|
this_peer_finished = True |
|
|
|
sequence_outputs = beam_scorer.finalize( |
|
decoder_input_ids, |
|
beam_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=generation_config.pad_token_id, |
|
eos_token_id=eos_token_id, |
|
max_length=stopping_criteria.max_length, |
|
decoder_prompt_len=1, |
|
) |
|
|
|
return sequence_outputs["sequences"] |
|
|
|
|
|
MitreForConditionalGeneration.register_for_auto_class("AutoModel") |
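# Illustrative usage sketch (not part of the original file). The checkpoint name and the
# tokenizer call are hypothetical placeholders; the exact tokenizer interface is defined
# elsewhere in this repository. trust_remote_code=True is assumed because the class above
# is registered for AutoModel with custom code.
#
#   >>> from transformers import AutoModel, AutoTokenizer
#   >>> tokenizer = AutoTokenizer.from_pretrained("<mitre-checkpoint>", trust_remote_code=True)
#   >>> model = AutoModel.from_pretrained("<mitre-checkpoint>", trust_remote_code=True)
#   >>> inputs = tokenizer("Hello world", return_tensors="pt")  # source ids should end with the target-language tag
#   >>> generated = model.generate(inputs["input_ids"])         # beam size etc. come from model.generation_config
#   >>> tokenizer.batch_decode(generated, skip_special_tokens=True)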