from typing import Optional, Tuple

import torch
from torch import nn
from transformers.models.gpt_neox import modeling_gpt_neox

from .configuration_italia import ItaliaConfig


class GPTNeoXLayer(nn.Module):
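    """GPT-NeoX decoder layer variant used by Italia.

    Unlike the stock GPTNeoXLayer there is no post_attention_layernorm: the
    same input_layernorm output feeds both the attention branch and the MLP
    branch, and the two branches are always combined as a parallel residual
    (config.use_parallel_residual is stored but not consulted here).
    """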
    def __init__(self, config):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
        self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
        self.attention = modeling_gpt_neox.GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config)
        self.mlp = modeling_gpt_neox.GPTNeoXMLP(config)

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
    ):
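        # Attention branch: pre-norm using the shared input_layernorm.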
        attention_layer_outputs = self.attention(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            layer_past=layer_past,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attention_layer_outputs[0]
        attn_output = self.post_attention_dropout(attn_output)
        outputs = attention_layer_outputs[1:]
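
        # MLP branch: the same input_layernorm output is reused; both branches
        # are summed with the residual stream below (parallel residual).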
        mlp_output = self.mlp(self.input_layernorm(hidden_states))
        mlp_output = self.post_mlp_dropout(mlp_output)
        hidden_states = mlp_output + attn_output + hidden_states
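
        # Keep the present key/value entry in the returned tuple only when
        # use_cache is set; otherwise drop that slot.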
        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs

# Patch the GPT-NeoX module so that GPTNeoXModel builds its decoder stack from
# the layer defined above instead of the stock implementation.
modeling_gpt_neox.GPTNeoXLayer = GPTNeoXLayer

from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM, GPTNeoXModel


class ItaliaForCausalLM(GPTNeoXForCausalLM):
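    """Causal LM on the GPT-NeoX architecture, using the layer variant above
    and an output projection that carries a bias term."""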

    config_class = ItaliaConfig

    def __init__(self, config):
        super().__init__(config)

        self.gpt_neox = GPTNeoXModel(config)
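        # Unlike the stock GPTNeoXForCausalLM head, embed_out includes a bias.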
        self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

        self.post_init()