# mitre_913m / modeling_mitre.py — author: zhiqu22, commit 0517e25 ("update codes"), 36.7 kB
# coding=utf-8
import copy
import math
from typing import List, Optional, Tuple, Union, Dict, Any

import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.generation.beam_search import BeamSearchScorer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.integrations.fsdp import is_fsdp_managed_module
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_mitre import MitreConfig
# Module-level logger from transformers' logging utilities (used for the gradient-checkpointing warning).
logger = logging.get_logger(__name__)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Map every non-padding token to its position number; padding tokens map to ``padding_idx``.

    Positions of real tokens start at ``padding_idx + 1 + past_key_values_length`` and increase by
    one per real token along dim 1. This is modified from fairseq's `utils.make_positions`.

    Returns:
        A LongTensor with the same shape as ``input_ids``.
    """
    is_token = (input_ids != padding_idx).int()
    running_count = torch.cumsum(is_token, dim=1).type_as(is_token)
    # multiplying by is_token zeroes padding slots, which then become padding_idx below
    positions = (running_count + past_key_values_length) * is_token
    return positions.long() + padding_idx
# Modified from transformers.models.m2m_100.modeling_m2m_100.M2M100Attention
# and transformers.models.m2m_100.modeling_m2m_100.M2M100SdpaAttention
class MitreSdpaAttention(nn.Module):
    """
    Multi-head self-attention computed with `torch.nn.functional.scaled_dot_product_attention`.

    Attention weights are never materialized, so `output_attentions=True` and head masks are not
    supported and the second forward output is always None. Improving this has low priority;
    a Flash Attention v2 path is planned.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        config: Optional[MitreConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # Kept for parity with M2M100; SDPA applies its default 1/sqrt(head_dim) scaling itself.
        self.scaling = self.head_dim**-0.5
        # Creation order (k, v, q, out) is preserved so parameter initialization is unchanged.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split heads: (bsz, seq_len, embed_dim) -> contiguous (bsz, num_heads, seq_len, head_dim)."""
        split_heads = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return split_heads.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Args:
            hidden_states: input of shape (batch, tgt_len, embed_dim).
            past_key_value: optional cached (key, value), each shaped like `_shape` output
                (batch, num_heads, past_len, head_dim); the new step is appended along dim 2.
            attention_mask: passed unchanged to SDPA as ``attn_mask``.

        Returns:
            (attn_output of shape (batch, tgt_len, embed_dim), None, (key_states, value_states)),
            where the returned cache already includes the current step.
        """
        batch_size, query_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
        value_states = self._shape(self.v_proj(hidden_states), -1, batch_size)
        if past_key_value is not None:
            # extend the cached keys/values with the current step along the sequence axis
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        present_key_value = (key_states, value_states)
        query_states = self._shape(query_states, query_len, batch_size)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,
        )
        expected_shape = (batch_size, self.num_heads, query_len, self.head_dim)
        if attn_output.size() != expected_shape:
            raise ValueError(
                f"`attn_output` should be of size {expected_shape}, but is"
                f" {attn_output.size()}"
            )
        # Merge heads using the module's embed_dim (not hidden_states' width), since the output
        # can be partitioned across GPUs under tensor parallelism.
        merged = attn_output.transpose(1, 2).reshape(batch_size, query_len, self.embed_dim)
        return self.out_proj(merged), None, present_key_value
# Modified from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer
class MitreDecoderLayer(nn.Module):
    """
    Pre-LayerNorm decoder block: self-attention, then a two-layer feed-forward network,
    each wrapped in a residual connection with dropout.
    """

    def __init__(self, config: MitreConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = MitreSdpaAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = True,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`): mask of size `(batch, 1, tgt_len, src_len)` where
                blocked positions hold very large negative values.
            past_key_value (`Tuple(torch.FloatTensor)`): cached self-attention key/value states;
                only its first two entries (key, value) are used.
            use_cache (`bool`): whether to also return the updated key/value cache.

        Returns:
            A tuple `(hidden_states,)`, or `(hidden_states, present_key_value)` when `use_cache` is true.
        """
        # --- self-attention sub-block ---
        attn_input = self.self_attn_layer_norm(hidden_states)
        cached_self_attn = None if past_key_value is None else past_key_value[:2]
        attn_out, _, present_key_value = self.self_attn(
            hidden_states=attn_input,
            past_key_value=cached_self_attn,
            attention_mask=attention_mask,
        )
        hidden_states = hidden_states + nn.functional.dropout(attn_out, p=self.dropout, training=self.training)

        # --- feed-forward sub-block ---
        ffn_hidden = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        ffn_hidden = nn.functional.dropout(ffn_hidden, p=self.activation_dropout, training=self.training)
        ffn_out = nn.functional.dropout(self.fc2(ffn_hidden), p=self.dropout, training=self.training)
        hidden_states = hidden_states + ffn_out

        return (hidden_states, present_key_value) if use_cache else (hidden_states,)
class MitrePreTrainedModel(PreTrainedModel):
    """Base class wiring MITRE models into transformers' `PreTrainedModel` machinery."""

    config_class = MitreConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MitreDecoderLayer"]
    # Flash Attention v2 is planned but not implemented; only SDPA is supported.
    _supports_flash_attn_2 = False
    _supports_sdpa = True

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, init_std); zero Linear biases and Embedding padding rows."""
        init_std = self.config.init_std
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=init_std)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
        elif isinstance(module, nn.Embedding) and module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
class MitreDecoder(MitrePreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MitreDecoderLayer`].

    On the first (uncached) step the decoder runs over one concatenated sequence
    ``[source tokens | registers | target tokens]``. On cached steps it processes only the new
    target token and reuses `registering_cache` from the previous call.

    Args:
        config: MitreConfig
    """

    def __init__(self, config: MitreConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.embed_tokens = MitreScaledWordEmbedding(
            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
        )
        # separate sinusoidal position tables for source tokens, registers and target tokens
        self.src_embed_positions = MitreSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            self.padding_idx,
        )
        self.register_embed_positions = MitreSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            self.padding_idx,
        )
        self.tgt_embed_positions = MitreSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            self.padding_idx,
        )
        self.layers = nn.ModuleList([MitreDecoderLayer(config) for _ in range(config.decoder_layers)])
        if config._attn_implementation != "sdpa":
            raise NotImplementedError("Other attention mechanism are not implemented yet.")
        # TODO implement flash attention v2 for MITRE
        # self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.gradient_checkpointing = False
        # Lazily grown causal-mask cache (created on CPU); sliced and copied in build_future_mask.
        self._future_mask = torch.empty(0)
        # Initialize weights and apply final processing
        self.post_init()

    def create_registers(self, input_ids):
        """
        Create registers by duplicating each sentence's language tag.

        One register is created per non-padding token, so
        length(registers) = length(real_tokens) = length(tokens) - length(pads).
        The duplicated token is the highest token id of the row (``argmax``).
        NOTE(review): this relies on the language tag having the largest id in every sentence;
        confirm against the tokenizer/vocabulary.

        Returns:
            registers: (batch, max_register_nums) tensor; each row repeats that row's tag.
            register_nums: (batch,) number of non-padding tokens per row.
            total_token_nums: input length + max_register_nums.
        """
        register_nums = (~input_ids.eq(self.padding_idx)).sum(dim=1)
        max_register_nums = register_nums.max().item()
        total_token_nums = input_ids.size(1) + max_register_nums
        batch_size = input_ids.size(0)
        registers = input_ids[range(batch_size), torch.argmax(input_ids, dim=-1)].unsqueeze(1).repeat(1, max_register_nums)
        return registers, register_nums, total_token_nums

    def combine_src_and_registers(self, input_ids, registers, register_nums, total_token_nums):
        """
        Return the expanded source tokens used for embedding, plus gather indices.

        ``expanded_src_tokens`` is ``[pads | input_ids | registers]`` with ``max_register_nums`` pads
        and registers. Gathering row ``i`` at ``register_nums[i] + arange(total_token_nums)`` keeps
        exactly ``register_nums[i]`` registers at the end of the row; the remaining leading
        positions are pads.

        Returns:
            expanded_src_tokens, batch_indices, indices; the last two are (batch, total_token_nums)
            and live on ``input_ids.device``.
        """
        device = input_ids.device
        pads = torch.full_like(registers, self.padding_idx)
        expanded_src_tokens = torch.cat((pads, input_ids, registers), dim=1)
        indices = torch.arange(total_token_nums, device=device).expand(input_ids.size(0), -1)
        indices = indices + register_nums.unsqueeze(1)
        # Create the batch index on the same device as `indices` and the tensors it will index.
        batch_indices = (
            torch.arange(input_ids.shape[0], device=device)
            .unsqueeze(1)
            .expand(-1, indices.size(1))
            .contiguous()
        )
        return expanded_src_tokens, batch_indices, indices

    def fill_with_neg_inf(self, t):
        """Return a tensor like ``t`` (same dtype) filled with -inf."""
        return t.float().fill_(float("-inf")).type_as(t)

    def build_future_mask(self, embeds, src_length, register_nums, padding_mask=None, past_key_values_length=0):
        """
        Build the additive attention mask (0 = attend, -inf = blocked) of shape (batch, 1, q_len, k_len).

        Without a cache (``past_key_values_length == 0``), over ``[source | registers | target]``:
          1. start from a causal mask;
          2. source tokens and registers attend bidirectionally among themselves;
          4. except that source tokens cannot attend to registers;
          5. target tokens cannot attend to source tokens (only to registers and earlier targets);
          6. key positions flagged in ``padding_mask`` are blocked.
        With a cache, the mask covers one new query (one token per step is assumed) over
        ``past_key_values_length + 1`` keys and blocks only the source tokens.

        Args:
            embeds: hidden states; the mask follows their length (dim 1), dtype and device.
            src_length: length of the source + registers part.
            register_nums: (batch,) registers per row; ``src_length - register_nums`` source tokens.
            padding_mask: optional (batch, k_len) bool tensor, True at padding positions.
            past_key_values_length: length of the cached keys/values.
        """
        b = register_nums.size(0)
        ns = src_length - register_nums
        if past_key_values_length == 0:
            # in training / the first generation step
            # 1. create mask by cache
            dim = embeds.size(1)
            if (
                self._future_mask.size(0) == 0
                or self._future_mask.size(0) < dim
            ):
                self._future_mask = torch.triu(self.fill_with_neg_inf(torch.zeros([dim, dim])), 1)
            # Always copy into embeds' dtype AND device: the cache stays at its creation dtype/device,
            # and SDPA requires a float mask to match the query dtype (e.g. CPU fp16/bf16).
            mask = self._future_mask[:dim, :dim].to(embeds, copy=True)
            # 2. bi-directional attention in source tokens and registers
            mask[:src_length, :src_length] = 0.
            # 3. create batch mask
            batch_mask = mask.unsqueeze(0).expand(b, -1, -1).clone().contiguous()
            # 4. mask source tokens -> registers
            # 5. mask target -> source tokens
            batch_indices = torch.arange(b).to(batch_mask.device).view(-1, 1, 1).expand(b, dim, dim).contiguous()
            row_indices = torch.arange(dim).to(batch_mask.device).view(1, -1, 1).expand(b, dim, dim).contiguous()
            col_indices = torch.arange(dim).to(batch_mask.device).view(1, 1, -1).expand(b, dim, dim).contiguous()
            source_indices = (row_indices < ns.view(-1, 1, 1)) & (col_indices >= ns.view(-1, 1, 1)) & (col_indices < (ns + register_nums).view(-1, 1, 1)).contiguous()
            target_indices = (row_indices >= (ns + register_nums).view(-1, 1, 1)) & (col_indices < ns.view(-1, 1, 1)).contiguous()
            # 4
            batch_mask[batch_indices[source_indices], row_indices[source_indices], col_indices[source_indices]] = float('-inf')
            # 5
            batch_mask[batch_indices[target_indices], row_indices[target_indices], col_indices[target_indices]] = float('-inf')
            # shape: batch_size, head_num (1 for broadcasting), seq_len, seq_len
            batch_mask = batch_mask.unsqueeze(1)
            # 6. masking pads
            if padding_mask is not None:
                if padding_mask.any():
                    padding_mask = padding_mask.to(batch_mask.device).unsqueeze(1).unsqueeze(2)
                    batch_mask = batch_mask.masked_fill(padding_mask == 1, float('-inf'))
        elif past_key_values_length > 0:
            # in generation: one new target token attends to registers and previous targets only
            mask = torch.zeros(past_key_values_length + 1)
            mask = mask.to(embeds, copy=True)
            batch_mask = mask.unsqueeze(0).expand(b, -1).clone().contiguous()
            batch_indices = torch.arange(b).view(-1, 1).expand(b, past_key_values_length + 1).to(batch_mask.device)
            token_indices = torch.arange(past_key_values_length + 1).view(1, -1).expand(b, past_key_values_length + 1).to(batch_mask.device)
            target_to_source_mask = token_indices < ns.view(-1, 1)
            batch_mask[batch_indices[target_to_source_mask], token_indices[target_to_source_mask]] = float('-inf')
            batch_mask = batch_mask.unsqueeze(1)
        # normalize to (batch, 1, q_len, k_len); both branches produce a contiguous tensor here
        batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1])
        return batch_mask

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        registering_cache: dict = None,
    ):
        """
        Args:
            input_ids: source token ids. Required on the first (uncached) step and ignored once
                `past_key_values` is non-empty.
            decoder_input_ids: target token ids. On the first step, if any row's first token differs
                from that row's language tag (the last register), column 0 is overwritten in place
                with the tags (enc-tgt-dec-tgt strategy).
            past_key_values: per-layer (key, value) caches from a previous call.
            use_cache: return updated caches (forced off under gradient checkpointing in training).
            output_attentions: accepted for API compatibility; the returned attention tuples stay empty.
            output_hidden_states: also return each layer's input and the final (target-only) output.
            registering_cache: ``{"register_nums", "src_length"}`` from the previous call; required
                whenever `past_key_values` is given.

        Returns:
            `BaseModelOutputWithPastAndCrossAttentions` whose `last_hidden_state` covers only target
            positions, with an extra ``registering_cache`` attribute for the next step.

        Raises:
            ValueError: if a cache is given without `registering_cache`, or no cache and no `input_ids`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        decoder_input_shape = decoder_input_ids.size()
        decoder_input_ids = decoder_input_ids.view(-1, decoder_input_shape[-1])
        padding_mask = None
        if past_key_values_length > 0:
            if registering_cache is None:
                raise ValueError("`registering_cache` must be provided together with `past_key_values`.")
            register_nums = registering_cache["register_nums"]
            src_length = registering_cache["src_length"]
        else:
            if input_ids is None:
                raise ValueError("`input_ids` must be provided when `past_key_values` is empty.")
            # .view() additionally ensures that the memory is contiguous
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            registers, register_nums, total_token_nums = self.create_registers(input_ids)
            expanded_src_tokens, batch_indices, indices = self.combine_src_and_registers(input_ids, registers, register_nums, total_token_nums)
            # positional embedding for source tokens and registers
            inputs_embeds = self.embed_tokens(expanded_src_tokens)
            inputs_embeds_1 = inputs_embeds[:, :total_token_nums, :] + self.src_embed_positions(expanded_src_tokens[:, :total_token_nums])
            inputs_embeds_2 = inputs_embeds[:, total_token_nums:, :] + self.register_embed_positions(expanded_src_tokens[:, total_token_nums:])
            inputs_embeds = torch.cat((inputs_embeds_1, inputs_embeds_2), dim=1)
            inputs_embeds = inputs_embeds[batch_indices, indices]
            # padding mask
            source_tokens = expanded_src_tokens[batch_indices, indices]
            src_length = source_tokens.shape[1]
            # replace the inference trigger with langtok, namely the enc-tgt-dec-tgt strategy.
            # Check every row (not only row 0) so mixed batches are handled consistently.
            if (decoder_input_ids[:, 0] != source_tokens[:, -1]).any():
                decoder_input_ids[:, 0] = source_tokens[:, -1]
            tokens = torch.cat([source_tokens, decoder_input_ids], dim=1)
            padding_mask = tokens.eq(self.padding_idx)
        decoder_inputs_embeds = self.embed_tokens(decoder_input_ids)
        decoder_inputs_embeds = decoder_inputs_embeds + self.tgt_embed_positions(decoder_input_ids, past_key_values_length, src_length=src_length)
        if past_key_values_length == 0:
            hidden_states = torch.cat([inputs_embeds, decoder_inputs_embeds], dim=1)
        else:
            hidden_states = decoder_inputs_embeds
        attention_mask = self.build_future_mask(hidden_states, src_length, register_nums, padding_mask, past_key_values_length)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting" " `use_cache=False`..."
                )
                use_cache = False
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None
        next_decoder_cache = () if use_cache else None
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            past_key_value = past_key_values[idx] if past_key_values is not None else None
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_value=None,
                    use_cache=use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[1],)
        if past_key_values_length == 0:
            # drop source/register positions; only target positions feed the LM head
            hidden_states = hidden_states[:, src_length:, :]
        hidden_states = self.layer_norm(hidden_states)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None
        model_output = BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
        model_output.registering_cache = {
            "register_nums": register_nums,
            "src_length": src_length
        }
        return model_output
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding
class MitreScaledWordEmbedding(nn.Embedding):
    """
    `nn.Embedding` whose output is multiplied by a constant `embed_scale`.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        """Look up ``input_ids`` and scale the resulting vectors by ``embed_scale``."""
        looked_up = super().forward(input_ids)
        return looked_up * self.embed_scale
class MitreSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        # extra table rows reserved beyond `num_positions`
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """(Re)build the non-persistent `weights` buffer, keeping the dtype/device of an existing one."""
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            # in forward put the weights on the correct dtype and device of the param
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
        self.register_buffer("weights", emb_weights, persistent=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings of shape (num_embeddings, embedding_dim); the padding row is zero.

        This matches the implementation in tensor2tensor, but differs slightly from the description in
        Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(
        self, input_ids: torch.Tensor = None, past_key_values_length: int = 0, src_length: int = 0
    ):
        """
        Map ``input_ids`` (batch, seq_len) to positional embeddings (batch, seq_len, embedding_dim).

        Padding tokens get position ``padding_idx`` (a zero row). When both ``past_key_values_length``
        and ``src_length`` are positive (cached decoding, where the cache also holds the source and
        register positions), ``src_length`` is subtracted from non-padding positions.
        """
        bsz, seq_len = input_ids.size()
        # Create the position ids from the input token ids. Any padded tokens remain padded.
        position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
            input_ids.device
        )
        if past_key_values_length > 0 and src_length > 0:
            # Padding positions equal padding_idx and must not be shifted; compare against
            # self.padding_idx rather than a hard-coded 1 so any pad id works.
            position_ids = torch.where(position_ids == self.padding_idx, position_ids, position_ids - src_length)
        # expand embeddings if needed
        max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
class MitreModel(MitrePreTrainedModel):
    """Bare MITRE model: a [`MitreDecoder`] whose outputs are repackaged as `Seq2SeqModelOutput`."""

    _tied_weights_keys = ["decoder.embed_tokens.weight"]

    def __init__(self, config: MitreConfig):
        super().__init__(config)
        self.decoder = MitreDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.decoder.embed_tokens

    def get_decoder(self):
        """Return the wrapped decoder."""
        return self.decoder

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        registering_cache: dict = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
        """
        Run the decoder and return a `Seq2SeqModelOutput` with an extra `registering_cache` attribute.

        Unset flags fall back to the config. `output_attentions` is resolved but not forwarded,
        since the decoder does not collect attentions.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache

        dec_out = self.decoder(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            registering_cache=registering_cache
        )
        wrapped = Seq2SeqModelOutput(
            last_hidden_state=dec_out.last_hidden_state,
            past_key_values=dec_out.past_key_values,
            decoder_hidden_states=dec_out.hidden_states,
            decoder_attentions=dec_out.attentions,
        )
        wrapped.registering_cache = dec_out.registering_cache
        return wrapped
class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
    """MITRE model with a language-modeling head and a self-contained beam-search `generate`."""

    base_model_prefix = "model"
    _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: MitreConfig):
        super().__init__(config)
        self.model = MitreModel(config)
        self.lm_head = nn.Linear(config.d_model, self.model.decoder.embed_tokens.num_embeddings, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_decoder(self):
        return self.model.get_decoder()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        registering_cache: dict = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        """
        Compute LM logits for the target positions.

        Raises:
            NotImplementedError: if `labels` is given (no loss function is implemented).
        """
        outputs = self.model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            registering_cache=registering_cache,
        )
        lm_logits = self.lm_head(outputs[0])
        if labels is not None:
            raise NotImplementedError("Please implement your loss function here.")
        model_output = Seq2SeqLMOutput(
            loss=None,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
        )
        model_output.registering_cache = outputs.registering_cache
        return model_output

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached key/value tensor along the batch axis to follow the selected beams."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past

    @staticmethod
    def _reorder_register_nums(register_nums, beam_idx):
        """Reorder per-row register counts to follow the selected beams."""
        return register_nums.index_select(0, beam_idx.to(register_nums.device))

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: Optional[torch.LongTensor] = None,
        beam_size: int = 1,
    ) -> torch.LongTensor:
        """
        Expands input_ids from [batch_size, len(tokens)] to [batch_size * beam_size, len(tokens)].
        This is simplified from 'transformers.generation.utils.GenerationMixin._expand_inputs_for_generation'
        """
        if beam_size == 1:
            return input_ids
        return input_ids.repeat_interleave(beam_size, dim=0)

    def generate(self,
                 input_ids: Optional[torch.Tensor] = None,
                 generation_config: Optional[GenerationConfig] = None,
                 **kwargs: Dict
                 ):
        """
        Inference with beam search.

        Simplified from 'transformers.generation.utils.GenerationMixin.generate', following the
        style of M2M100 and NLLB.

        Args:
            input_ids: (batch, src_len) source token ids.
            generation_config: optional config; defaults to `self.generation_config`. It is copied,
                so neither it nor `self.generation_config` is modified by this call.
            **kwargs: per-call overrides applied to the copy via `GenerationConfig.update`.

        Returns:
            The ``sequences`` tensor produced by `BeamSearchScorer.finalize`.

        Raises:
            TypeError: if `generation_config` is given but is not a `GenerationConfig`.

        Decoding stops once the beam scorer reports every sentence done, or once the sequences
        reach the length limit. The limit is `max_length`, or `max_new_tokens` plus the one-token
        decoder prompt when `max_new_tokens` is set.

        TODO
        1. Early stop in beam search.
           Early stopping currently happens at the beam-search level, not the model level.
           Once the beam scorer appends eos to a sequence, the sequence is padded with pad tokens
           but still goes through the model at every step. We plan to drop finished sequences,
           as Fairseq does.
        2. Build the self-attention mask outside the model.
           The mask is currently built inside the model, so beam search has to build a
           (beam_size * batch_size) mask from scratch. Building it outside the model would allow
           duplicating it beam_size times, and a mask cache could avoid rebuilding it at each step.
        """
        if generation_config is None:
            generation_config = self.generation_config
        elif not isinstance(generation_config, GenerationConfig):
            raise TypeError(
                f"`generation_config` must be a GenerationConfig, got {type(generation_config).__name__}."
            )
        # Work on a private copy so per-call kwargs never leak into self.generation_config
        # or into a config object owned by the caller.
        generation_config = copy.deepcopy(generation_config)
        generation_config.update(**kwargs)

        batch_size = input_ids.shape[0]
        beam_size = generation_config.num_beams
        device = input_ids.device
        max_cache_length = generation_config.max_length
        if getattr(generation_config, "max_new_tokens", None) is not None:
            # the decoder prompt is a single start token
            max_cache_length = generation_config.max_new_tokens + 1
        eos = generation_config.eos_token_id
        # Accept a single id or a list of ids (torch.Tensor([a_list]) would be 2-D and miscount them).
        eos_token_id = torch.tensor([eos] if isinstance(eos, int) else list(eos))
        # initial target tokens: the decoder replaces this trigger with each row's language tag
        decoder_input_ids = torch.full(
            (batch_size, 1),
            generation_config.decoder_start_token_id,
            dtype=input_ids.dtype,
            device=device
        )
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=beam_size,
            device=device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=max_cache_length,
        )
        input_ids = self._expand_inputs_for_generation(input_ids, beam_size)
        decoder_input_ids = self._expand_inputs_for_generation(decoder_input_ids, beam_size)
        cur_len = decoder_input_ids.shape[1]
        this_peer_finished = False
        past_key_values = None
        registering_cache = None
        logits_processor = LogitsProcessorList()
        beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input_ids.device)
        # only the first beam of each sentence is live at step 0, so identical beams are not all expanded
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * beam_size,))
        while not this_peer_finished:
            if past_key_values is not None:
                decoder_input_ids_for_generation = decoder_input_ids[:, -1:]
            else:
                decoder_input_ids_for_generation = decoder_input_ids
            outputs = self(input_ids, decoder_input_ids_for_generation, past_key_values=past_key_values, use_cache=True, registering_cache=registering_cache)
            # the source is fully cached after the first step
            del input_ids
            input_ids = None
            past_key_values = outputs.past_key_values
            registering_cache = outputs.registering_cache
            next_token_logits = outputs.logits[:, -1, :].clone().float()
            next_token_logits = next_token_logits.to(device)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(decoder_input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                next_token_scores_processed
            )
            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, beam_size * vocab_size)
            # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
            # non eos token per beam.
            n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
            n_tokens_to_keep = max(2, 1 + n_eos_tokens) * beam_size
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
            )
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            beam_outputs = beam_scorer.process(
                decoder_input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                decoder_prompt_len=1,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            decoder_input_ids = torch.cat([decoder_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            del outputs
            past_key_values = self._reorder_cache(past_key_values, beam_idx)
            registering_cache["register_nums"] = self._reorder_register_nums(registering_cache["register_nums"], beam_idx)
            cur_len = cur_len + 1
            # Stop when all beams are done OR the length limit is hit (as MaxLengthCriteria would);
            # without the length check a model that never emits EOS would loop forever.
            if beam_scorer.is_done or cur_len >= max_cache_length:
                this_peer_finished = True
        sequence_outputs = beam_scorer.finalize(
            decoder_input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=eos_token_id,
            max_length=max_cache_length,
            decoder_prompt_len=1,
        )
        return sequence_outputs["sequences"]
# Register this class with transformers' auto-class machinery under "AutoModel"
# (presumably so the hub checkpoint can be loaded via AutoModel with remote code — confirm).
MitreForConditionalGeneration.register_for_auto_class("AutoModel")