Josephgflowers commited on
Commit
0fdc3d9
·
verified ·
1 Parent(s): 1eab743

Upload LM.py

Browse files
Files changed (1) hide show
  1. LM.py +201 -0
LM.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM
4
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
5
+
6
+ # Custom Modules
7
+
8
+ class AdaptiveRMSNorm(nn.Module):
9
+ """
10
+ Adaptive RMSNorm layer where the scaling parameter adapts based on input.
11
+ """
12
+ def __init__(self, normalized_shape, adaptive_dim, eps=1e-6):
13
+ super(AdaptiveRMSNorm, self).__init__()
14
+ self.normalized_shape = normalized_shape
15
+ self.eps = eps
16
+
17
+ # Standard RMSNorm weight parameter
18
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
19
+
20
+ # Adaptive scaling parameter
21
+ self.fc_gamma = nn.Linear(adaptive_dim, normalized_shape)
22
+
23
+ def forward(self, x, adapt_input):
24
+ # Compute adaptive scaling factor gamma
25
+ gamma = self.fc_gamma(adapt_input).unsqueeze(1) # Shape: [batch_size, 1, hidden_size]
26
+
27
+ # Compute RMSNorm
28
+ norm_x = x / x.norm(dim=-1, keepdim=True).clamp(min=self.eps)
29
+
30
+ # Apply adaptive scaling
31
+ return self.weight * norm_x * gamma
32
+
33
+ class TokenMixing(nn.Module):
34
+ """
35
+ Token Mixing layer that performs depthwise convolution across the sequence dimension.
36
+ """
37
+ def __init__(self, hidden_size):
38
+ super(TokenMixing, self).__init__()
39
+ self.token_mixing = nn.Conv1d(
40
+ in_channels=hidden_size,
41
+ out_channels=hidden_size,
42
+ kernel_size=3,
43
+ padding=1,
44
+ groups=hidden_size # Depthwise convolution
45
+ )
46
+
47
+ def forward(self, x):
48
+ # x shape: [batch_size, seq_length, hidden_size]
49
+ x = x.transpose(1, 2) # Shape: [batch_size, hidden_size, seq_length]
50
+ x = self.token_mixing(x)
51
+ x = x.transpose(1, 2) # Shape back to [batch_size, seq_length, hidden_size]
52
+ return x
53
+
54
+ class SEBlock(nn.Module):
55
+ """
56
+ Squeeze-and-Excitation block that adaptively recalibrates channel-wise features.
57
+ """
58
+ def __init__(self, hidden_size, reduction=16):
59
+ super(SEBlock, self).__init__()
60
+ self.fc = nn.Sequential(
61
+ nn.Linear(hidden_size, hidden_size // reduction, bias=False),
62
+ nn.ReLU(inplace=True),
63
+ nn.Linear(hidden_size // reduction, hidden_size, bias=False),
64
+ nn.Sigmoid()
65
+ )
66
+
67
+ def forward(self, x):
68
+ # x shape: [batch_size, seq_length, hidden_size]
69
+ y = x.mean(dim=1) # Global average pooling over sequence length
70
+ y = self.fc(y) # Squeeze and Excitation
71
+ y = y.unsqueeze(1) # Shape: [batch_size, 1, hidden_size]
72
+ return x * y # Scale the original input
73
+
74
+ # Modified Decoder Layer
75
+
76
+ class ModifiedLlamaDecoderLayer(nn.Module):
77
+ """
78
+ Modified Llama Decoder Layer with AdaptiveRMSNorm, TokenMixing, and SEBlock.
79
+ """
80
+ def __init__(self, original_layer, config):
81
+ super().__init__()
82
+ self.hidden_size = config.hidden_size
83
+ self.adaptive_dim = config.hidden_size # Using hidden_size for adapt_input
84
+
85
+ # Copy the original attention and MLP layers
86
+ self.self_attn = original_layer.self_attn
87
+ self.mlp = original_layer.mlp
88
+
89
+ # Replace RMSNorm layers with AdaptiveRMSNorm
90
+ self.input_layernorm = AdaptiveRMSNorm(self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps)
91
+ self.post_attention_layernorm = AdaptiveRMSNorm(self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps)
92
+
93
+ # Add Token Mixing Layer
94
+ self.token_mixing = TokenMixing(self.hidden_size)
95
+
96
+ # Add SE Block
97
+ self.se_block = SEBlock(self.hidden_size, reduction=16)
98
+
99
+ def forward(
100
+ self,
101
+ hidden_states,
102
+ attention_mask=None,
103
+ position_ids=None,
104
+ past_key_value=None,
105
+ use_cache=False,
106
+ output_attentions=False,
107
+ **kwargs, # Capture additional arguments
108
+ ):
109
+ # Compute adaptation input
110
+ adapt_input = hidden_states.mean(dim=1) # Shape: [batch_size, hidden_size]
111
+
112
+ residual = hidden_states
113
+
114
+ # Input layer normalization with adaptive RMSNorm
115
+ hidden_states = self.input_layernorm(hidden_states, adapt_input)
116
+
117
+ # Self-attention
118
+ attn_outputs = self.self_attn(
119
+ hidden_states=hidden_states,
120
+ attention_mask=attention_mask,
121
+ position_ids=position_ids,
122
+ past_key_value=past_key_value,
123
+ output_attentions=output_attentions,
124
+ use_cache=use_cache,
125
+ **kwargs, # Pass additional arguments to self_attn
126
+ )
127
+ attn_output = attn_outputs[0]
128
+ if use_cache:
129
+ present_key_value = attn_outputs[1]
130
+ else:
131
+ present_key_value = None
132
+ if output_attentions:
133
+ attn_weights = attn_outputs[-1]
134
+ else:
135
+ attn_weights = None
136
+
137
+ hidden_states = residual + attn_output
138
+
139
+ # Token Mixing
140
+ token_mixed = self.token_mixing(hidden_states)
141
+ hidden_states = hidden_states + token_mixed
142
+
143
+ # Post-attention layer normalization with adaptive RMSNorm
144
+ hidden_states = self.post_attention_layernorm(hidden_states, adapt_input)
145
+
146
+ # MLP
147
+ residual = hidden_states
148
+ hidden_states = self.mlp(hidden_states)
149
+
150
+ # SE Block
151
+ hidden_states = self.se_block(hidden_states)
152
+
153
+ hidden_states = residual + hidden_states
154
+
155
+ outputs = (hidden_states,)
156
+
157
+ if use_cache:
158
+ outputs += (present_key_value,)
159
+
160
+ if output_attentions:
161
+ outputs += (attn_weights,)
162
+
163
+ return outputs
164
+
165
+ # Load the pre-trained model
166
+
167
+ # Load the configuration from the pre-trained model
168
+ config = AutoConfig.from_pretrained('/home/joe/Music/220-agent')
169
+
170
+ # Load the pre-trained model
171
+ pretrained_model = LlamaForCausalLM.from_pretrained('/home/joe/Music/220-agent')
172
+
173
+ # Replace the decoder layers with modified layers
174
+ for i in range(config.num_hidden_layers):
175
+ # Original layer
176
+ original_layer = pretrained_model.model.layers[i]
177
+ # Replace with modified layer
178
+ pretrained_model.model.layers[i] = ModifiedLlamaDecoderLayer(original_layer, config)
179
+
180
+ # The modified model is now ready
181
+ modified_model = pretrained_model
182
+
183
+ # Save the model and tokenizer
184
+ output_dir = "./saved_model"
185
+ modified_model.save_pretrained(output_dir)
186
+ tokenizer = AutoTokenizer.from_pretrained('/home/joe/Music/220-agent', legacy=False)
187
+ tokenizer.save_pretrained(output_dir)
188
+
189
+ print(f"Model and tokenizer saved to {output_dir}")
190
+
191
+ # Example Usage
192
+
193
+ input_text = "Hello, how are you?"
194
+ input_ids = tokenizer.encode(input_text, return_tensors='pt')
195
+
196
+ # Forward pass
197
+ outputs = modified_model(input_ids=input_ids)
198
+ logits = outputs.logits
199
+
200
+ print("Logits shape:", logits.shape) # Should be [batch_size, seq_length, vocab_size]
201
+