|
|
|
|
|
import copy |
|
import math |
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
from transformers.modeling_utils import ModuleUtilsMixin |
|
from transformers.modeling_outputs import ModelOutput, Seq2SeqModelOutput, BaseModelOutput, Seq2SeqLMOutput |
|
from transformers import PreTrainedModel |
|
|
|
try: |
|
from .rms_norm import fast_rms_layernorm |
|
except ImportError: |
|
fast_rms_layernorm = None |
|
|
|
try: |
|
from .cross_entropy_loss import cross_entropy_loss as fast_cross_entropy_loss |
|
except ImportError: |
|
fast_cross_entropy_loss = None |
|
|
|
try: |
|
from .flash_attention_v2_bias import flash_attention_v2_bias |
|
except ImportError: |
|
flash_attention_v2_bias = None |
|
|
|
try: |
|
from flash_attn import flash_attn_kvpacked_func, flash_attn_func |
|
except ImportError: |
|
flash_attn_kvpacked_func, flash_attn_func = None, None |
|
|
|
from .attn_ref import attn_ref |
|
|
|
from .configuration_flash_t5 import FlashT5Config |
|
from .positional_encoding import ALiBiPositionalEncoding, RelativePositionalEncoding, RotaryPositionalEncoding, FIRE |
|
|
|
class FlashT5CrossEntropyLoss(nn.Module):
    """Token-level cross-entropy loss with optional z-loss and label smoothing.

    Uses the optional Triton kernel (``fast_cross_entropy_loss``) when
    ``use_triton_crossentropy`` is set, otherwise ``nn.CrossEntropyLoss``.
    In the PyTorch path, positions labelled -100 (the ``nn.CrossEntropyLoss``
    default ``ignore_index``) are ignored by the loss and excluded from the z-loss.
    """

    def __init__(self, z_loss_factor=0.0, label_smoothing=0.0, use_triton_crossentropy=False, inplace_backward=False):
        """
        Args:
            z_loss_factor: weight of the squared log-partition (z-loss) penalty; 0 disables it.
            label_smoothing: label smoothing applied by the loss.
            use_triton_crossentropy: use the Triton kernel instead of PyTorch.
            inplace_backward: forwarded to the Triton kernel.

        Raises:
            ImportError: if the Triton kernel is requested but could not be imported.
        """
        super().__init__()

        if use_triton_crossentropy and fast_cross_entropy_loss is None:
            raise ImportError("fast_cross_entropy_loss is not available")

        self.use_triton_crossentropy = use_triton_crossentropy
        self.z_loss_factor = z_loss_factor
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

        self.cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    def compute_zloss(self, logits: torch.Tensor, z_loss: float):
        """Return ``z_loss * mean(logsumexp(logits, dim=-1) ** 2)``."""
        log_z = torch.logsumexp(logits, dim=-1)
        total_z_loss = z_loss * torch.square(log_z)
        return total_z_loss.mean()

    def forward(self, logits, labels):
        """Compute the scalar training loss.

        Args:
            logits: scores with the vocabulary on the last dimension, e.g. ``(batch, seq_len, vocab)``.
            labels: integer targets matching the leading dimensions of ``logits``.

        Returns:
            Mean cross-entropy over non-ignored tokens, plus the z-loss term when
            ``z_loss_factor`` is non-zero.
        """
        if self.use_triton_crossentropy:
            # NOTE(review): ``.mean()`` averages over every returned position. If the
            # kernel returns 0 for ignored (-100) labels, this differs from the PyTorch
            # path below, which averages over non-ignored tokens only — confirm.
            return fast_cross_entropy_loss(logits, labels,
                                           lse_square_scale=self.z_loss_factor,
                                           label_smoothing=self.label_smoothing,
                                           inplace_backward=self.inplace_backward
                                           )[0].mean()

        vocab_size = logits.size(-1)
        # reshape (not view): works for non-contiguous logits/labels and for any
        # number of leading dimensions.
        logits_flatten = logits.float().reshape(-1, vocab_size)
        labels_flatten = labels.reshape(-1)
        loss = self.cross_entropy_loss(logits_flatten, labels_flatten)

        z_loss = 0.0
        if self.z_loss_factor != 0.0:
            # Only penalise the partition function at positions that contribute to the loss.
            z_loss = self.compute_zloss(logits_flatten[labels_flatten != -100],
                                        z_loss=self.z_loss_factor)
        return loss + z_loss
|
|
|
class FlashT5LayerNorm(nn.Module):
    """T5-style RMS normalisation: scale-only, with no bias and no mean subtraction."""

    def __init__(self, hidden_size, eps=1e-6, use_triton_layernorm=False):
        """
        Args:
            hidden_size: size of the last dimension being normalised.
            eps: value added to the mean square before the inverse square root.
            use_triton_layernorm: dispatch to ``fast_rms_layernorm`` instead of PyTorch ops.

        Raises:
            ImportError: if the Triton kernel is requested but could not be imported.
        """
        super().__init__()

        if use_triton_layernorm and fast_rms_layernorm is None:
            raise ImportError("fast_rms_layernorm is not available")

        self.use_triton_layernorm = use_triton_layernorm
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` by its root-mean-square over the last dim, then scale."""
        if self.use_triton_layernorm:
            return fast_rms_layernorm(hidden_states, self.weight, self.variance_epsilon)

        # The mean square is always accumulated in float32 for stability.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states * (mean_square + self.variance_epsilon).rsqrt()

        # Drop back to half precision only when the scale itself is half precision.
        target_dtype = self.weight.dtype
        if target_dtype in (torch.float16, torch.bfloat16):
            normed = normed.to(target_dtype)

        return self.weight * normed
|
|
|
class FlashT5DenseAct(nn.Module):
    """Non-gated feed-forward input projection: ``dropout(act(wi(x)))``.

    Produces ``d_ff`` features. The projection back to ``d_model`` (``wo``) is owned
    and applied by ``FlashT5LayerFF``, not by this module.
    """

    def __init__(self, config: FlashT5Config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU()

    def forward(self, hidden_states):
        """Project to ``d_ff``, apply the activation and dropout."""
        # The original code cast to ``self.wo.weight.dtype`` here, but this module has
        # no ``wo`` attribute, so every call raised AttributeError. The cast is dropped,
        # matching FlashT5DenseGatedAct.
        hidden_states = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
|
|
|
class FlashT5DenseGatedAct(nn.Module):
    """Gated feed-forward input projection: ``dropout(act(wi_0(x)) * wi_1(x))``.

    Produces ``d_ff`` features; ``FlashT5LayerFF`` applies the output projection.
    """

    def __init__(self, config: FlashT5Config):
        super().__init__()
        in_features, out_features = config.d_model, config.d_ff
        self.wi_0 = nn.Linear(in_features, out_features, bias=False)
        self.wi_1 = nn.Linear(in_features, out_features, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        if config.use_gelu_act:
            self.act = torch.nn.GELU(approximate='tanh')
        else:
            self.act = torch.nn.ReLU()

        self.use_gelu_act = config.use_gelu_act

    def forward(self, hidden_states):
        """Return the gated, dropped-out ``d_ff`` representation of ``hidden_states``."""
        gate = self.act(self.wi_0(hidden_states))
        linear = self.wi_1(hidden_states)
        return self.dropout(gate * linear)
|
|
|
class FlashT5LayerFF(nn.Module):
    """Pre-norm feed-forward sub-layer with residual: ``x + dropout(wo(act(norm(x))))``."""

    def __init__(self, config: FlashT5Config):
        super().__init__()
        # Gated (GLU-style) or plain input projection, chosen by the config.
        act_cls = FlashT5DenseGatedAct if config.use_glu_mlp else FlashT5DenseAct
        self.act = act_cls(config)

        self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        """Apply the feed-forward block and add it to the residual stream."""
        residual = hidden_states
        normed = self.layer_norm(hidden_states).type_as(residual)
        projected = self.wo(self.act(normed))
        return residual + self.dropout(projected)
|
|
|
|
|
class FlashT5Attention(nn.Module, ModuleUtilsMixin):
    """Multi-head attention for FlashT5, used for both self- and cross-attention.

    The kernel is selected by ``config.attention_type``:
    - ``"fa2_bias"`` and ``"fa2_rpe"`` call ``flash_attn_func``;
    - ``"triton"`` calls ``flash_attention_v2_bias``;
    - any other value falls back to ``attn_ref``.

    Positional information comes from ``self.pe_encoding`` (ALiBi, T5 relative
    buckets, RoPE or FIRE), selected by ``config.position_encoding_type``.
    """

    def __init__(self, config: FlashT5Config, has_positional_encoding=False, is_causal=False):
        """
        Args:
            config: model configuration.
            has_positional_encoding: build the ALiBi / T5 / FIRE encoding module in this
                layer. RoPE is built regardless of this flag.
            is_causal: forwarded as the ``causal`` flag to the attention kernel.

        Raises:
            ValueError: on incompatible masking / attention / position-encoding options.
            ImportError: if the requested attention kernel could not be imported.
        """
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_positional_encoding = has_positional_encoding
        self.is_causal = is_causal
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.p_dropout = config.attention_dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.attention_type = config.attention_type
        self.position_encoding_type = config.position_encoding_type
        self.max_sequence_length = config.max_sequence_length
        # NOTE(review): the fallback scale is 1/sqrt(num_heads), not the usual
        # 1/sqrt(d_kv). Left unchanged because trained checkpoints depend on it — confirm intent.
        self.softmax_scale = config.attention_scale if config.attention_scale is not None else 1.0/math.sqrt(self.n_heads)
        self.use_full_bias_size = config.use_full_bias_size
        self.use_masking = config.use_masking

        # Masking is applied in forward() by rewriting the expanded, full-size bias.
        if self.use_masking and not self.use_full_bias_size:
            raise ValueError("Masking can only be used with full bias size.")

        if self.attention_type == "triton" and flash_attention_v2_bias is None:
            raise ImportError("flash_attention_triton is not available")
        elif self.attention_type.startswith("fa2") and flash_attn_func is None:
            raise ImportError("Flash Attention 2 is not available")

        if self.attention_type == "fa2_rpe" and self.position_encoding_type != "t5":
            raise ValueError("fa2_rpe is not compatible with non-T5 position encoding")

        assert (self.p_dropout == 0.0) or (self.attention_type != "triton"), "Triton attention does not support dropout"

        self.pe_encoding = None
        if self.position_encoding_type == "ALiBi" and has_positional_encoding:
            self.pe_encoding = ALiBiPositionalEncoding(self.max_sequence_length,
                                                       self.n_heads,
                                                       config.alibi_mode,
                                                       randomized_position=config.use_randomized_position_encoding)
        elif self.position_encoding_type == "t5" and has_positional_encoding:
            self.pe_encoding = RelativePositionalEncoding(self.relative_attention_num_buckets,
                                                          self.relative_attention_max_distance,
                                                          self.n_heads,
                                                          self.max_sequence_length,
                                                          bidirectional=(not self.is_decoder),
                                                          randomized_position=config.use_randomized_position_encoding)
        elif self.position_encoding_type == "RoPE":
            # NOTE(review): unlike the other encodings, RoPE is built whether or not
            # has_positional_encoding is set (so in every layer, including
            # cross-attention) — confirm this is intended.
            self.pe_encoding = RotaryPositionalEncoding(int(self.key_value_proj_dim * config.rotary_emb_fraction),
                                                        self.max_sequence_length,
                                                        config.rotary_base,
                                                        config.rotary_interleaved,
                                                        config.rotary_scale_base,
                                                        randomized_position=config.use_randomized_position_encoding)
        elif self.position_encoding_type == "FIRE" and has_positional_encoding:
            self.pe_encoding = FIRE(num_heads=self.n_heads,
                                    mlp_width=config.fire_mlp_width,
                                    init_c=0.1,
                                    init_L=self.relative_attention_max_distance)

        self.Wq = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.Wk = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.Wv = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Args:
            hidden_states: query-side input of shape ``(batch, query_len, d_model)``.
            mask: optional key mask (bool or 0/1 integers, where true/1 means "attend").
                A ``(batch, key_len)`` mask is broadcast over heads and query positions.
                A ``(batch, query_len, key_len)`` mask is broadcast over heads. It is only
                applied when ``use_masking`` is set and a position bias exists.
            key_value_states: key/value-side input for cross-attention.
            position_bias: precomputed bias from an earlier layer, if any.

        Returns:
            ``(output, position_bias)``: ``output`` has shape ``(batch, query_len, d_model)``,
            and ``position_bias`` is the (possibly expanded/masked) bias for reuse by
            later layers.
        """
        batch_size, seq_length = hidden_states.shape[:2]
        key_length = seq_length if key_value_states is None else key_value_states.shape[1]
        q = self.Wq(hidden_states)
        if key_value_states is None:
            k = self.Wk(hidden_states)
            v = self.Wv(hidden_states)
        else:
            k = self.Wk(key_value_states)
            v = self.Wv(key_value_states)

        # (batch, length, heads, head_dim) layout, as expected by the flash_attn path.
        q = q.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim)
        k = k.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)
        v = v.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)

        # fa2_rpe computes the T5 relative bias inside the kernel, so it is skipped here.
        if position_bias is None and self.pe_encoding is not None and self.attention_type != "fa2_rpe":
            q, k, v, position_bias = self.pe_encoding(q, k, v)

        if position_bias is not None and self.use_full_bias_size:
            # Full bias layout: (batch, heads, query_len, key_len).
            position_bias = position_bias.expand(q.shape[0], q.shape[2], q.shape[1], k.shape[1])
            if self.attention_type == "fa2_bias" or self.attention_type == "triton":
                position_bias = position_bias.contiguous()

        if position_bias is not None and mask is not None and self.use_masking:
            # torch.where needs a boolean condition; callers may pass 0/1 integer masks.
            mask = mask.to(torch.bool).unsqueeze(1)
            if mask.dim() == 3:
                # (batch, 1, key_len) -> (batch, 1, 1, key_len): mask the key columns of
                # the (batch, heads, query_len, key_len) bias for every query row.
                mask = mask.unsqueeze(2)
            position_bias = torch.where(mask, position_bias, torch.finfo(hidden_states.dtype).min)

        if self.attention_type == "fa2_bias":
            output = flash_attn_func(q, k, v, dropout_p=self.p_dropout, softmax_scale=self.softmax_scale,
                                     attn_bias=position_bias, causal=self.is_causal)
        elif self.attention_type == "fa2_rpe":
            output = flash_attn_func(q, k, v, dropout_p=self.p_dropout, softmax_scale=self.softmax_scale,
                                     rpe_weights=self.pe_encoding.relative_attention_bias.weight.t(),
                                     rpe_max_distance=self.relative_attention_max_distance,
                                     causal=self.is_causal)
        elif self.attention_type == "triton":
            # Head-major layout (batch, heads, length, head_dim) for the Triton kernel.
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            output = flash_attention_v2_bias(q, k, v, position_bias, self.is_causal, self.softmax_scale)
            output = output.permute(0, 2, 1, 3)
        else:
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            output = attn_ref(q, k, v, position_bias, dropout_p=self.p_dropout, sm_scale=self.softmax_scale, causal=self.is_causal)
            output = output.permute(0, 2, 1, 3)

        output = self.o(output.reshape(output.shape[0], output.shape[1], self.inner_dim))
        return (output, position_bias)
|
|
|
|
|
class FlashT5LayerSelfAttention(nn.Module):
    """Pre-norm self-attention sub-layer with a residual connection.

    Attention is causal when ``config.is_decoder`` is set.
    """

    def __init__(self, config, has_positional_encoding=False):
        super().__init__()
        self.self_attention = FlashT5Attention(
            config,
            has_positional_encoding=has_positional_encoding,
            is_causal=config.is_decoder,
        )
        self.layer_norm = FlashT5LayerNorm(
            config.d_model,
            eps=config.layer_norm_epsilon,
            use_triton_layernorm=config.use_triton_layernorm,
        )
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
    ):
        """Return ``(hidden_states + dropout(attention), *extra attention outputs)``.

        The extra outputs are whatever follows the attention result in the tuple
        returned by ``FlashT5Attention`` (the position bias).
        """
        normed = self.layer_norm(hidden_states).type_as(hidden_states)
        attn_result, *extras = self.self_attention(
            normed,
            mask=attention_mask,
            position_bias=position_bias,
        )
        return (hidden_states + self.dropout(attn_result), *extras)
|
|
|
|
|
class FlashT5LayerCrossAttention(nn.Module):
    """Pre-norm cross-attention sub-layer (decoder attending to encoder states) with a residual."""

    def __init__(self, config):
        super().__init__()
        self.cross_attention = FlashT5Attention(config, has_positional_encoding=False)
        self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
    ):
        """Attend from ``hidden_states`` to ``key_value_states`` and add the result residually.

        Args:
            hidden_states: decoder-side queries.
            key_value_states: encoder outputs used as keys and values.
            attention_mask: optional mask over ``key_value_states``.
            position_bias: optional precomputed bias from an earlier layer.

        Returns:
            ``(layer_output, *extra attention outputs)``, where the extras are the
            position bias returned by ``FlashT5Attention``.
        """
        # Cast back to the residual dtype, as the self-attention and feed-forward
        # sub-layers do: FlashT5LayerNorm returns float32 for half inputs whenever its
        # weight is float32.
        normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states)
        attention_output = self.cross_attention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]
        return outputs
|
|
|
|
|
class FlashT5Block(nn.Module):
    """One transformer layer: self-attention, optional cross-attention (decoder only), feed-forward."""

    def __init__(self, config, has_positional_encoding=False):
        super().__init__()
        self.is_decoder = config.is_decoder

        self.self_attention_layer = FlashT5LayerSelfAttention(config, has_positional_encoding=has_positional_encoding)

        # Only decoder blocks get a cross-attention sub-layer.
        if self.is_decoder:
            self.cross_attention_layer = FlashT5LayerCrossAttention(config)

        self.ff_layer = FlashT5LayerFF(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
    ):
        """Run the layer.

        Returns:
            ``(hidden_states, self_attn_extras..., cross_attn_extras...)``. The cross-attention
            extras are present only when this is a decoder block and
            ``encoder_hidden_states`` is given.
        """
        hidden_states, *self_extras = self.self_attention_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
        )
        extras = tuple(self_extras)

        if self.is_decoder and encoder_hidden_states is not None:
            hidden_states, *cross_extras = self.cross_attention_layer(
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
            )
            extras = extras + tuple(cross_extras)

        return (self.ff_layer(hidden_states),) + extras
|
|
|
|
|
class FlashT5Stack(nn.Module, ModuleUtilsMixin):
    """A stack of ``FlashT5Block`` layers sharing one token-embedding module.

    Acts as the encoder or the decoder depending on ``config.is_decoder``.
    """

    def __init__(self, config, embed_tokens):
        super().__init__()
        assert embed_tokens is not None

        self.config = config
        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        # Only the first block is built with has_positional_encoding=True; the bias it
        # returns is threaded through the remaining blocks in forward().
        self.block = nn.ModuleList(
            [FlashT5Block(config, has_positional_encoding=bool(i == 0)) for i in range(config.num_layers)]
        )

        self.final_layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
        self.dropout = nn.Dropout(config.dropout_rate)

    def get_input_embeddings(self):
        """Return the token-embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        """Replace the token-embedding module."""
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ) -> BaseModelOutput:
        """Embed the inputs, run every block and apply the final norm and dropout.

        Args:
            input_ids: token ids of shape ``(batch, seq_len)``. May be omitted when
                ``inputs_embeds`` is given.
            attention_mask: self-attention mask; defaults to all-True.
            encoder_hidden_states: encoder output attended to by decoder blocks.
            encoder_attention_mask: mask over ``encoder_hidden_states``; defaults to
                all-True for decoders.
            inputs_embeds: precomputed embeddings whose first two dims are ``(batch, seq_len)``.
            head_mask, cross_attn_head_mask, past_key_values, use_cache,
            output_attentions, output_hidden_states, return_dict: accepted for API
                compatibility and ignored.

        Returns:
            ``BaseModelOutput`` with only ``last_hidden_state`` set.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if input_ids is not None:
            batch_size, seq_length = input_ids.size()
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Embedding lookups are not autocast, so cast explicitly to the autocast dtype.
        if torch.is_autocast_enabled() and inputs_embeds.device.type == 'cuda':
            inputs_embeds = inputs_embeds.to(torch.get_autocast_gpu_dtype())

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device, dtype=torch.bool)

        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.bool
            )

        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for layer_module in self.block:
            layer_outputs = layer_module(
                hidden_states,
                attention_mask=attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
            )

            # Reuse the biases produced by this layer in the next one.
            position_bias = layer_outputs[1]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[2]

            hidden_states = layer_outputs[0]

        hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return BaseModelOutput(
            last_hidden_state=hidden_states
        )
|
|
|
|
|
|
|
class FlashT5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FlashT5Config
    base_model_prefix = "transformer"
    is_parallelizable = False
    supports_gradient_checkpointing = True
    _no_split_modules = ["FlashT5Block"]
    _keep_in_fp32_modules = []

    def _init_weights(self, module):
        """Initialise ``module``'s weights using the T5 scheme, scaled by ``config.initializer_factor``."""
        factor = self.config.initializer_factor
        if isinstance(module, FlashT5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, (FlashT5Model, FlashT5ForConditionalGeneration, FlashT5EncoderModel)):
            # Every top-level model owns a ``shared`` embedding, not only the LM head model.
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * self.config.d_model ** -0.5)
        elif isinstance(module, FlashT5DenseAct):
            # Previously missing: the non-gated projection was left at PyTorch's default init.
            d_ff, d_model = module.wi.weight.data.size()
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
        elif isinstance(module, FlashT5DenseGatedAct):
            d_ff, d_model = module.wi_0.weight.data.size()
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
        elif isinstance(module, FlashT5LayerFF):
            d_ff, d_model = module.wo.weight.data.size()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
        elif isinstance(module, FlashT5Attention):
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.Wq.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.Wk.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.Wv.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_positional_encoding:
                if hasattr(module.pe_encoding, "relative_attention_bias"):
                    module.pe_encoding.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

    def _shift_right(self, input_ids):
        """Build decoder inputs from labels.

        Shifts right by one, prepends ``decoder_start_token_id`` and replaces
        ignored positions (-100) with ``pad_token_id``.
        """
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        # -100 marks ignored label positions; it is not a valid token id.
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
|
|
|
|
|
class FlashT5Model(FlashT5PreTrainedModel):
    """Bare FlashT5 encoder-decoder returning the decoder's last hidden states (no LM head)."""

    def __init__(self, config: FlashT5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = FlashT5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = FlashT5Stack(decoder_config, self.shared)

        self.post_init()

        self.model_parallel = False
        self.device_map = None

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared embedding in the model and in both stacks."""
        self.shared = new_embeddings
        # Assign directly: the stacks are plain nn.Modules and must pick up the new module.
        self.encoder.embed_tokens = new_embeddings
        self.decoder.embed_tokens = new_embeddings

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_inputs_embeds=None,
        encoder_outputs=None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
        """Encode (unless ``encoder_outputs`` is given), then decode.

        Args:
            input_ids, attention_mask, inputs_embeds: encoder inputs.
            decoder_input_ids, decoder_attention_mask, decoder_inputs_embeds: decoder inputs.
            encoder_outputs: precomputed encoder output (``BaseModelOutput`` or a tuple
                whose first item is the last hidden state); skips the encoder.
            encoder_hidden_states, encoder_attention_mask, head_mask, cross_attn_head_mask,
            past_key_values, use_cache, output_attentions, output_hidden_states,
            return_dict: accepted for API compatibility and ignored.

        Returns:
            ``Seq2SeqModelOutput`` with decoder and encoder last hidden states.
        """
        # The decoder_* and encoder_outputs arguments were previously referenced here
        # without being parameters, so every call raised.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds
            )
        elif not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(last_hidden_state=encoder_outputs[0])

        hidden_states = encoder_outputs[0]

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask
        )

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
        )
|
|
|
class FlashT5ForConditionalGeneration(FlashT5PreTrainedModel):
    """FlashT5 encoder-decoder with an untied LM head and a simple greedy ``generate``."""

    def __init__(self, config: FlashT5Config):
        super().__init__(config)
        config.is_encoder_decoder = False
        assert not config.tie_word_embeddings

        self.config = config
        self.model_dim = config.d_model
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        self.encoder = FlashT5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = FlashT5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.loss_fct = FlashT5CrossEntropyLoss(z_loss_factor=config.z_loss,
                                                label_smoothing=config.label_smoothing,
                                                use_triton_crossentropy=config.use_triton_crossentropy,
                                                inplace_backward=config.crossentropy_inplace_backward)

        self.post_init()

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        return model_inputs

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        """Replace the shared embedding in the model and in both stacks.

        Updating only ``self.shared`` left the encoder/decoder on the old embedding
        (e.g. after ``resize_token_embeddings``).
        """
        self.shared = value
        self.encoder.embed_tokens = value
        self.decoder.embed_tokens = value

    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        max_length = 32,
        **kwargs,
    ) -> torch.LongTensor:
        """
        input_ids: B x L_encoder, int64
        attention_mask: B x L_encoder, int64
            1 for tokens to attend to, 0 for tokens to ignore

        Generation:
            Starts with 0, ends with 1, padding is 0

        Greedy decoding without a KV cache: the encoder runs once, and the decoder is
        re-run on the whole prefix at each step.

        # For 20 input/outputs, the diff between my implementation and HF is 9.8s vs 11.4s
        """
        B, _ = input_ids.size()
        labels = torch.zeros(B, 1, dtype=torch.long, device=input_ids.device)
        encoder_hidden_states = None

        for _ in range(max_length):
            out = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=labels,
                encoder_hidden_states=encoder_hidden_states,
            )
            # Reuse the encoder output on later steps.
            encoder_hidden_states = out.encoder_hidden_states
            top_labels = out.logits[:, -1].argmax(-1).unsqueeze(-1)
            labels = torch.cat([labels, top_labels], dim=-1)

            # Stop once every sequence has produced at least one EOS (id 1).
            if (labels == 1).any(dim=-1).all().item():
                break

        # Guarantee an EOS in every row (rows that hit max_length end with one).
        labels[:, -1] = 1

        # Zero out everything after each row's first EOS.
        B, L = labels.size()
        mask = torch.arange(L, device=labels.device).unsqueeze(0) <= (labels == 1).long().argmax(-1).unsqueeze(-1)
        labels = labels.masked_fill(~mask, 0)

        return labels

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Seq2SeqLMOutput:
        """
        input_ids: B x L_encoder, int64
        attention_mask: B x L_encoder, int64
            1 for tokens to attend to, 0 for tokens to ignore
        labels: B x L_decoder, int64
        encoder_hidden_states: precomputed encoder output; when given, the encoder is skipped.
            It is returned as ``encoder_hidden_states`` so callers can reuse it.
        """
        if encoder_hidden_states is None:
            encoder_hidden_states = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )[0]

        hidden_states = encoder_hidden_states

        # Teacher forcing: derive decoder inputs from labels when not given explicitly.
        if labels is not None and decoder_input_ids is None:
            decoder_input_ids = self._shift_right(labels)

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
        )

        sequence_output = decoder_outputs[0]
        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_fct(lm_logits, labels)

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            encoder_hidden_states=encoder_hidden_states
        )
|
|
|
|
|
class FlashT5EncoderModel(FlashT5PreTrainedModel):
    """Encoder-only FlashT5 returning the encoder's ``BaseModelOutput``."""

    def __init__(self, config: FlashT5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        encoder_config = copy.deepcopy(config)
        # Force a non-causal encoder stack even if the incoming config says otherwise,
        # as FlashT5Model does for its encoder.
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = FlashT5Stack(encoder_config, self.shared)

        self.post_init()

        self.model_parallel = False
        self.device_map = None

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared embedding in the model and in the encoder stack."""
        self.shared = new_embeddings
        # Assign directly: the stack is a plain nn.Module and must pick up the new module.
        self.encoder.embed_tokens = new_embeddings

    def get_encoder(self):
        return self.encoder

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_type_ids: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        """Encode the inputs.

        Only ``input_ids``, ``attention_mask`` and ``inputs_embeds`` are used; the
        remaining arguments are accepted for API compatibility and ignored.
        """
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds
        )
        return encoder_outputs