# Hugging Face Hub page header (web-scrape residue, kept here as a comment):
# uploader: Josephgflowers
# commit: "Upload LM.py" (121854f, verified)
# file size: 6.79 kB
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaRMSNorm
# Custom Modules
class AdaptiveRMSNorm(nn.Module):
    """
    RMSNorm layer whose output is additionally scaled by an input-dependent gain.

    The input is divided by its root-mean-square over the last dimension,
    then multiplied by a learned per-channel ``weight`` and by ``gamma``, a
    per-sample gain obtained from a linear projection of ``adapt_input``.

    Args:
        normalized_shape: Size of the last (feature) dimension of ``x``.
        adaptive_dim: Size of the last dimension of ``adapt_input``.
        eps: Constant added to the mean square for numerical stability.
    """

    def __init__(self, normalized_shape, adaptive_dim, eps=1e-6):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        # Standard RMSNorm weight parameter (initialised to ones).
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        # Linear projection that produces the adaptive scaling factor gamma.
        self.fc_gamma = nn.Linear(adaptive_dim, normalized_shape)

    def forward(self, x, adapt_input):
        """
        Normalize ``x`` and apply the static and adaptive scales.

        Args:
            x: Tensor whose last dimension has size ``normalized_shape``.
                The broadcast of ``gamma`` below is written for
                ``[batch_size, seq_length, hidden_size]`` inputs.
            adapt_input: Tensor of shape ``[batch_size, adaptive_dim]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Adaptive scaling factor, shape [batch_size, 1, hidden_size] so that
        # it broadcasts over the sequence dimension.
        gamma = self.fc_gamma(adapt_input).unsqueeze(1)

        # Do the reduction in float32 for half-precision inputs; the mean of
        # squares is otherwise prone to overflow/underflow.
        input_dtype = x.dtype
        if input_dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)

        # True RMS normalization: x / sqrt(mean(x^2) + eps).  Dividing by the
        # L2 norm instead would shrink the result by sqrt(hidden_size).
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        norm_x = (x * torch.rsqrt(mean_square + self.eps)).to(input_dtype)

        # Apply the learned static weight and the adaptive gain.
        return self.weight * norm_x * gamma
class TokenMixing(nn.Module):
    """
    Token mixing layer: a depthwise 1-D convolution along the sequence dimension.

    By default the convolution is causal: the sequence is padded on the left
    only, so the output at position ``t`` depends on positions
    ``t - kernel_size + 1 .. t`` and never on later tokens.  A symmetric
    window would let position ``t`` read position ``t + 1``, which leaks the
    next token when the layer is used inside a causal language model.

    Args:
        hidden_size: Number of channels (the feature dimension of the input).
        kernel_size: Width of the convolution window.
        causal: If True (default), pad only on the left.  If False, use
            symmetric "same" padding, which requires an odd ``kernel_size``.

    Raises:
        ValueError: If ``kernel_size`` is smaller than 1, or if
            ``causal=False`` is combined with an even ``kernel_size``
            (symmetric padding could not preserve the sequence length).
    """

    def __init__(self, hidden_size, kernel_size=3, causal=True):
        super().__init__()
        if kernel_size < 1:
            raise ValueError("kernel_size must be >= 1")
        if not causal and kernel_size % 2 == 0:
            raise ValueError("non-causal token mixing requires an odd kernel_size")
        self.kernel_size = kernel_size
        self.causal = causal
        self.token_mixing = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            # Causal padding is applied manually in forward(); the built-in
            # padding argument can only pad both sides equally.
            padding=0 if causal else (kernel_size - 1) // 2,
            groups=hidden_size,  # Depthwise convolution
        )

    def forward(self, x):
        """
        Mix information across neighbouring tokens.

        Args:
            x: Tensor of shape ``[batch_size, seq_length, hidden_size]``.

        Returns:
            Tensor of shape ``[batch_size, seq_length, hidden_size]``.
        """
        # Conv1d expects channels first: [batch_size, hidden_size, seq_length]
        x = x.transpose(1, 2)
        if self.causal:
            # Left-pad so each output position only sees itself and the past.
            x = nn.functional.pad(x, (self.kernel_size - 1, 0))
        x = self.token_mixing(x)
        # Back to [batch_size, seq_length, hidden_size]
        return x.transpose(1, 2)
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation gate over the feature (channel) dimension.

    The input is averaged over the sequence dimension, passed through a
    bias-free bottleneck MLP (Linear -> ReLU -> Linear -> Sigmoid), and the
    resulting per-channel gate in (0, 1) rescales every position of the input.

    Args:
        hidden_size: Feature dimension of the input.
        reduction: Bottleneck ratio; the inner width is
            ``hidden_size // reduction``.

    NOTE(review): the average runs over the whole sequence, so the gate at
    every position depends on all positions, including later ones.
    """

    def __init__(self, hidden_size, reduction=16):
        super().__init__()
        bottleneck = hidden_size // reduction
        squeeze = nn.Linear(hidden_size, bottleneck, bias=False)
        excite = nn.Linear(bottleneck, hidden_size, bias=False)
        self.fc = nn.Sequential(
            squeeze,
            nn.ReLU(inplace=True),
            excite,
            nn.Sigmoid(),
        )

    def forward(self, x):
        """
        Rescale ``x`` channel-wise.

        Args:
            x: Tensor of shape ``[batch_size, seq_length, hidden_size]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Squeeze: average over the sequence -> [batch_size, hidden_size]
        pooled = x.mean(dim=1)
        # Excite: per-channel gate in (0, 1) -> [batch_size, hidden_size]
        gate = self.fc(pooled)
        # Broadcast the gate over the sequence dimension and apply it.
        return x * gate.unsqueeze(1)
# Modified Decoder Layer
class ModifiedLlamaDecoderLayer(nn.Module):
    """
    Modified Llama decoder layer with AdaptiveRMSNorm, TokenMixing, and SEBlock.

    The attention and MLP sub-modules of ``original_layer`` are reused (not
    copied), so their pretrained weights are shared with the original layer.
    Both norms are replaced by ``AdaptiveRMSNorm`` and start from the
    pretrained norm weights of ``original_layer`` when those are available.

    Args:
        original_layer: Decoder layer to wrap; must expose ``self_attn`` and
            ``mlp``.  ``input_layernorm`` / ``post_attention_layernorm`` are
            read if present, to seed the new norm weights.
        config: Object exposing ``hidden_size`` and ``rms_norm_eps``.
    """

    def __init__(self, original_layer, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.adaptive_dim = config.hidden_size  # Using hidden_size for adapt_input

        # Reuse the original attention and MLP layers
        self.self_attn = original_layer.self_attn
        self.mlp = original_layer.mlp

        # Replace RMSNorm layers with AdaptiveRMSNorm
        self.input_layernorm = AdaptiveRMSNorm(
            self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = AdaptiveRMSNorm(
            self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps
        )
        # Keep the pretrained per-channel norm scales instead of silently
        # resetting them to ones.
        self._copy_norm_weight(
            getattr(original_layer, "input_layernorm", None),
            self.input_layernorm,
        )
        self._copy_norm_weight(
            getattr(original_layer, "post_attention_layernorm", None),
            self.post_attention_layernorm,
        )

        # Add Token Mixing Layer
        self.token_mixing = TokenMixing(self.hidden_size)
        # Add SE Block
        self.se_block = SEBlock(self.hidden_size, reduction=16)

    @staticmethod
    def _copy_norm_weight(source_norm, target_norm):
        """Copy ``source_norm.weight`` into ``target_norm.weight`` if the shapes match."""
        source_weight = getattr(source_norm, "weight", None)
        if source_weight is None or source_weight.shape != target_norm.weight.shape:
            return
        with torch.no_grad():
            target_norm.weight.copy_(source_weight)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,  # Capture additional arguments
    ):
        """
        Run the modified decoder layer.

        Args:
            hidden_states: Tensor of shape ``[batch_size, seq_length, hidden_size]``.
            attention_mask, position_ids, past_key_value, use_cache,
            output_attentions, **kwargs: Forwarded unchanged to ``self_attn``.

        Returns:
            Tuple ``(hidden_states, [attn_weights], [present_key_value])``.
            Attention weights are included only when ``output_attentions`` is
            true and the cache only when ``use_cache`` is true, in that
            order, which is the order the Hugging Face Llama decoder layer
            uses.
        """
        # Compute adaptation input, shape [batch_size, hidden_size].
        # NOTE(review): this mean covers every position (future and padded
        # tokens included), so the adaptive gain is not causal - confirm
        # that this is intended.
        adapt_input = hidden_states.mean(dim=1)

        residual = hidden_states
        # Input layer normalization with adaptive RMSNorm
        hidden_states = self.input_layernorm(hidden_states, adapt_input)

        # Self-attention
        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,  # Pass additional arguments to self_attn
        )
        # Hugging Face attention modules return
        # (attn_output, attn_weights[, past_key_value]): the weights are at
        # index 1 and the cache, when returned at all, at index 2.  When the
        # cache is not returned it is updated in place, so fall back to the
        # object that was passed in.
        attn_output = attn_outputs[0]
        attn_weights = attn_outputs[1] if len(attn_outputs) > 1 else None
        present_key_value = (
            attn_outputs[2] if len(attn_outputs) > 2 else past_key_value
        )

        hidden_states = residual + attn_output

        # Token Mixing (residual)
        token_mixed = self.token_mixing(hidden_states)
        hidden_states = hidden_states + token_mixed

        # Post-attention layer normalization with adaptive RMSNorm
        hidden_states = self.post_attention_layernorm(hidden_states, adapt_input)

        # MLP
        # NOTE(review): the residual is taken after the norm, so the MLP
        # branch is added to the normalized activations rather than to the
        # pre-norm stream - kept as in the original design.
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        # SE Block
        hidden_states = self.se_block(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
# ---------------------------------------------------------------------------
# Script: load the pretrained model, swap in the modified decoder layers,
# save the result, and run one example forward pass.
# ---------------------------------------------------------------------------

# Hub id of the base checkpoint; used for config, weights, and tokenizer.
MODEL_ID = 'Josephgflowers/TinyLlama-v1.1-Cinders-World'

# Load the configuration from the pre-trained model
config = AutoConfig.from_pretrained(MODEL_ID)
# Load the pre-trained model
pretrained_model = LlamaForCausalLM.from_pretrained(MODEL_ID)

# Replace the decoder layers with modified layers.  Iterate over a snapshot
# of the layer list because the list is mutated inside the loop.
for i, original_layer in enumerate(list(pretrained_model.model.layers)):
    pretrained_model.model.layers[i] = ModifiedLlamaDecoderLayer(original_layer, config)

# The modified model is now ready
modified_model = pretrained_model

# Save the model and tokenizer
# NOTE(review): the saved config presumably still describes a stock Llama
# model, so reloading with from_pretrained() will likely not rebuild the
# custom layers - they would have to be re-applied after loading.  Verify.
output_dir = "./saved_model"
modified_model.save_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, legacy=False)
tokenizer.save_pretrained(output_dir)
print(f"Model and tokenizer saved to {output_dir}")

# Example Usage
input_text = "Hello, how are you?"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

# Forward pass: inference only, so switch to eval mode and skip building the
# autograd graph.
modified_model.eval()
with torch.no_grad():
    outputs = modified_model(input_ids=input_ids)
logits = outputs.logits
print("Logits shape:", logits.shape)  # Should be [batch_size, seq_length, vocab_size]