import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaRMSNorm

# Custom Modules
class AdaptiveRMSNorm(nn.Module):
    """
    Adaptive RMSNorm layer where the scaling parameter adapts based on input.
    """
    def __init__(self, normalized_shape, adaptive_dim, eps=1e-6):
        super(AdaptiveRMSNorm, self).__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        # Standard RMSNorm weight parameter
        self.weight = nn.Parameter(torch.ones(normalized_shape))

        # Adaptive scaling parameter
        self.fc_gamma = nn.Linear(adaptive_dim, normalized_shape)

    def forward(self, x, adapt_input):
        # Compute adaptive scaling factor gamma
        gamma = self.fc_gamma(adapt_input).unsqueeze(1)  # Shape: [batch_size, 1, hidden_size]
        # Compute RMSNorm: normalize by the root mean square over the hidden dimension
        norm_x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

        # Apply adaptive scaling
        return self.weight * norm_x * gamma
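
# Quick shape check for AdaptiveRMSNorm (illustrative only; the 2 x 8 x 64 sizes below are
# arbitrary toy values, not taken from the model config).
_norm_check = AdaptiveRMSNorm(normalized_shape=64, adaptive_dim=64)
_x_check = torch.randn(2, 8, 64)     # [batch_size, seq_length, hidden_size]
_adapt_check = _x_check.mean(dim=1)  # [batch_size, hidden_size]
assert _norm_check(_x_check, _adapt_check).shape == (2, 8, 64)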

class TokenMixing(nn.Module):
    """
    Token Mixing layer that performs depthwise convolution across the sequence dimension.
    """
    def __init__(self, hidden_size):
        super(TokenMixing, self).__init__()
        self.token_mixing = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=3,
            padding=1,
            groups=hidden_size  # Depthwise convolution
        )

    def forward(self, x):
        # x shape: [batch_size, seq_length, hidden_size]
        x = x.transpose(1, 2)  # Shape: [batch_size, hidden_size, seq_length]
        x = self.token_mixing(x)
        x = x.transpose(1, 2)  # Shape back to [batch_size, seq_length, hidden_size]
        return x
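
# Quick shape check for TokenMixing (illustrative only; toy sizes). The depthwise Conv1d with
# kernel_size=3 and padding=1 preserves the sequence length.
_mix_check = TokenMixing(hidden_size=64)
assert _mix_check(torch.randn(2, 8, 64)).shape == (2, 8, 64)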

class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block that adaptively recalibrates channel-wise features.
    """
    def __init__(self, hidden_size, reduction=16):
        super(SEBlock, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size // reduction, hidden_size, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # x shape: [batch_size, seq_length, hidden_size]
        y = x.mean(dim=1)   # Global average pooling over sequence length
        y = self.fc(y)      # Squeeze and Excitation
        y = y.unsqueeze(1)  # Shape: [batch_size, 1, hidden_size]
        return x * y        # Scale the original input
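
# Quick shape check for SEBlock (illustrative only; toy sizes). hidden_size should be at least
# `reduction` so the bottleneck Linear keeps a non-zero width.
_se_check = SEBlock(hidden_size=64, reduction=16)
assert _se_check(torch.randn(2, 8, 64)).shape == (2, 8, 64)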

# Modified Decoder Layer
class ModifiedLlamaDecoderLayer(nn.Module):
    """
    Modified Llama Decoder Layer with AdaptiveRMSNorm, TokenMixing, and SEBlock.
    """
    def __init__(self, original_layer, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.adaptive_dim = config.hidden_size  # Using hidden_size for adapt_input

        # Copy the original attention and MLP layers
        self.self_attn = original_layer.self_attn
        self.mlp = original_layer.mlp

        # Replace RMSNorm layers with AdaptiveRMSNorm
        self.input_layernorm = AdaptiveRMSNorm(self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps)
        self.post_attention_layernorm = AdaptiveRMSNorm(self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps)

        # Add Token Mixing Layer
        self.token_mixing = TokenMixing(self.hidden_size)

        # Add SE Block
        self.se_block = SEBlock(self.hidden_size, reduction=16)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,  # Capture additional arguments
    ):
        # Compute adaptation input
        adapt_input = hidden_states.mean(dim=1)  # Shape: [batch_size, hidden_size]

        residual = hidden_states

        # Input layer normalization with adaptive RMSNorm
        hidden_states = self.input_layernorm(hidden_states, adapt_input)

        # Self-attention
        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,  # Pass additional arguments to self_attn
        )
        attn_output = attn_outputs[0]

        if use_cache:
            present_key_value = attn_outputs[1]
        else:
            present_key_value = None

        if output_attentions:
            attn_weights = attn_outputs[-1]
        else:
            attn_weights = None

        hidden_states = residual + attn_output

        # Token Mixing
        token_mixed = self.token_mixing(hidden_states)
        hidden_states = hidden_states + token_mixed

        # Post-attention layer normalization with adaptive RMSNorm
        hidden_states = self.post_attention_layernorm(hidden_states, adapt_input)

        # MLP
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)

        # SE Block
        hidden_states = self.se_block(hidden_states)

        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if use_cache:
            outputs += (present_key_value,)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

# Load the configuration from the pre-trained model
config = AutoConfig.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World')

# Load the pre-trained model
pretrained_model = LlamaForCausalLM.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World')

# Replace the decoder layers with modified layers
for i in range(config.num_hidden_layers):
    # Original layer
    original_layer = pretrained_model.model.layers[i]
    # Replace with modified layer
    pretrained_model.model.layers[i] = ModifiedLlamaDecoderLayer(original_layer, config)

# The modified model is now ready
modified_model = pretrained_model
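
# Optional sanity check (illustrative): report the parameter count so the extra capacity added by
# the AdaptiveRMSNorm projections, TokenMixing convolutions, and SE blocks is visible.
total_params = sum(p.numel() for p in modified_model.parameters())
print(f"Total parameters after modification: {total_params:,}")
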
# Save the model and tokenizer
output_dir = "./saved_model"
modified_model.save_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World', legacy=False)
tokenizer.save_pretrained(output_dir)
print(f"Model and tokenizer saved to {output_dir}")
# Example Usage
input_text = "Hello, how are you?"
input_ids = tokenizer.encode(input_text, return_tensors='pt')
# Forward pass
outputs = modified_model(input_ids=input_ids)
logits = outputs.logits
print("Logits shape:", logits.shape) # Should be [batch_size, seq_length, vocab_size]