# Provenance (file-viewer page header captured together with the source):
# uploaded by 0-hero via "Add files using upload-large-folder tool",
# commit 01cd082 (verified), 16.8 kB.
# gpt2-model-positional-encodings.py
import math
import inspect
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
# Import necessary modules for different positional encodings
import numpy as np
import scipy.special
import scipy.signal
from packaging import version
# Decide once, at import time, whether the fused attention entry point exists.
# F.scaled_dot_product_attention only appears in PyTorch 2.x; which kernel it
# dispatches to at call time is up to PyTorch.
# NOTE(review): the log text says "Flash Attention v2", but that kernel is not
# guaranteed on every device/dtype -- confirm before relying on the wording.
use_flash_attn = (
    'scaled_dot_product_attention' in dir(F)
    and version.parse(torch.__version__) >= version.parse('2.0.0')
)
print(
    "Flash Attention v2 is available and will be used where possible."
    if use_flash_attn
    else "Flash Attention v2 is not available. Using standard attention."
)
class LayerNorm(nn.Module):
    """Layer normalisation with a learnable scale and an optional learnable bias.

    Args:
        ndim: size of the trailing dimension being normalised.
        bias: when truthy a zero-initialised bias parameter is created,
            otherwise ``self.bias`` is ``None`` and no shift is applied.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # Normalise over the last dimension (eps fixed at 1e-5).
        return F.layer_norm(
            input,
            self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )
def get_positional_encoding(position, d_model, method, max_len=5000):
    """
    Generate a fixed (non-learned) positional-encoding table.

    Args:
        position: 1-D tensor of integer positions; one output row is produced
            per entry (``GPT`` passes ``torch.arange(block_size)``).
        d_model: number of columns (the embedding width).
        method: encoding family, one of ``'default'``, ``'learned'``,
            ``'sinusoidal'``, ``'exponential'``, ``'polynomial_legendre'``,
            ``'polynomial_chebyshev'``, ``'gaussian'``, ``'random_fourier'``,
            ``'wavelet'``, ``'bessel'``, ``'alternative'`` or ``'none'``.
        max_len: normalisation length used by the methods that rescale
            positions, and the row count of the all-zero ``'none'`` table.

    Returns:
        ``None`` for ``'default'``/``'learned'`` (the model owns an
        ``nn.Embedding`` instead). Otherwise a float32 CPU tensor of shape
        ``(len(position), d_model)``; for ``'none'`` it is
        ``(max_len, d_model)`` of zeros.

    Raises:
        ValueError: if ``method`` is not recognised.
    """
    if method in ('default', 'learned'):
        return None  # Handled by nn.Embedding in the model
    if method == 'none':
        return torch.zeros(max_len, d_model)
    table_methods = (
        'sinusoidal', 'exponential', 'polynomial_legendre', 'polynomial_chebyshev',
        'gaussian', 'random_fourier', 'wavelet', 'bessel', 'alternative',
    )
    if method not in table_methods:
        raise ValueError(f"Unknown positional encoding method: {method}")

    # Work in float64 on the CPU (scipy needs NumPy input); cast once at the end.
    pos = position.detach().reshape(-1).to('cpu', torch.float64)
    n = pos.numel()

    def _even_odd(even_fn, odd_fn):
        # Columns 0, 2, 4, ... get even_fn(angle) and columns 1, 3, 5, ... get
        # odd_fn(angle), with angle = pos * 10000^(-2k / d_model).  For odd
        # d_model there is one fewer odd column than frequencies, hence the slice.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float64) * -(math.log(10000.0) / d_model)
        )
        angles = pos.unsqueeze(1) * div_term  # (n, ceil(d_model / 2))
        table = torch.zeros(n, d_model, dtype=torch.float64)
        table[:, 0::2] = even_fn(angles)
        table[:, 1::2] = odd_fn(angles[:, : d_model // 2])
        return table

    def _scipy_table(fn, x):
        # scipy.special ufuncs broadcast: evaluate fn(order, x) for every order
        # 0..d_model-1 (columns) and every sample x (rows) in one call, then
        # convert the NumPy result explicitly to a torch tensor.
        orders = np.arange(d_model)[None, :]
        return torch.as_tensor(fn(orders, x[:, None]), dtype=torch.float64)

    if method == 'sinusoidal':
        pe = _even_odd(torch.sin, torch.cos)
    elif method == 'exponential':
        pe = torch.exp(-pos / max_len).unsqueeze(1).repeat(1, d_model)
    elif method == 'polynomial_legendre':
        x = (pos / max_len * 2 - 1).numpy()  # Scale positions to [-1, 1]
        pe = _scipy_table(scipy.special.eval_legendre, x)
    elif method == 'polynomial_chebyshev':
        x = (pos / max_len * 2 - 1).numpy()  # Scale positions to [-1, 1]
        pe = _scipy_table(scipy.special.eval_chebyt, x)
    elif method == 'gaussian':
        # d_model Gaussian bumps with evenly spaced centres over [0, max_len].
        means = torch.linspace(0, max_len, d_model, dtype=torch.float64)
        std = max_len / d_model
        pe = torch.exp(-((pos.unsqueeze(1) - means) ** 2) / (2 * std ** 2))
    elif method == 'random_fourier':
        # Random Fourier features: sin and cos of the same random frequencies.
        # Still draws d_model normals (as before) so the global RNG stream, and
        # hence every initialisation that follows, is unchanged; only the first
        # ceil(d_model / 2) frequencies are used.
        B = torch.randn(d_model, 1)
        half = (d_model + 1) // 2
        proj = (pos / max_len).unsqueeze(1) @ B[:half].T.double() * 2 * math.pi  # (n, half)
        pe = torch.cat([torch.sin(proj), torch.cos(proj)], dim=1)[:, :d_model]
    elif method == 'wavelet':
        # Ricker ("Mexican hat") wavelet of width a = 1..d_model, sampled on a
        # max_len-point grid centred at (max_len - 1) / 2 and read off at each
        # position: A * (1 - (t/a)^2) * exp(-(t/a)^2 / 2), A = 2 / (sqrt(3a) pi^(1/4)).
        # This is the formula of scipy.signal.ricker, computed inline because
        # that function is deprecated/removed in recent SciPy.
        widths = torch.arange(1, d_model + 1, dtype=torch.float64)
        t = (pos - (max_len - 1) / 2).unsqueeze(1)  # (n, 1)
        amplitude = 2 / (torch.sqrt(3 * widths) * math.pi ** 0.25)
        ratio_sq = (t / widths) ** 2
        pe = amplitude * (1 - ratio_sq) * torch.exp(-ratio_sq / 2)
    elif method == 'bessel':
        pe = _scipy_table(scipy.special.jv, pos.numpy())  # column i is J_i(position)
    else:  # 'alternative'
        # NOTE(review): tan is unbounded near odd multiples of pi/2, so this
        # table can contain very large values.
        pe = _even_odd(torch.tan, lambda angle: torch.sin(angle + math.pi / 4))
    return pe.float()
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a selectable positional scheme.

    ``config.attention_type`` chooses how position information enters:

    * ``'default'``  - no position-dependent term inside attention.
    * ``'rope'``     - rotary embedding applied to queries and keys.
    * ``'alibi'``    - fixed per-head linear bias added to the scores.
    * ``'relative'`` - learned per-head bias indexed by key/query distance.

    The fused ``F.scaled_dot_product_attention`` call is used for
    ``'default'`` and ``'rope'`` when the module-level ``use_flash_attn`` flag
    is true. The exception is eval mode with
    ``config.track_attention_patterns`` set: then the explicit path runs so
    the attention weights can be stored in ``self.attn_weights``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.head_dim = self.n_embd // self.n_head
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.resid_dropout = nn.Dropout(config.dropout)
        # Attention-level positional encodings.
        if config.attention_type == 'rope':
            # Rotate the largest even-sized prefix of each head; with an odd
            # head_dim the last channel is passed through unrotated.
            self.rotary_dim = self.head_dim - self.head_dim % 2
            inv_freq = 1.0 / (10000 ** (torch.arange(0, self.rotary_dim, 2).float() / self.rotary_dim))
            self.register_buffer('inv_freq', inv_freq)
        elif config.attention_type == 'alibi':
            self.register_buffer('alibi_slopes', self.get_alibi_slopes(self.n_head))
        elif config.attention_type == 'relative':
            # One learned bias per head for every distance j - i in
            # [-(block_size - 1), block_size - 1].
            num_rel_dis = 2 * config.block_size - 1
            self.relative_positions = nn.Embedding(num_rel_dis, self.n_head)
        # else: default attention (nothing extra to define)

    def get_alibi_slopes(self, n_heads):
        """Return the ALiBi head slopes as a float tensor of shape ``(n_heads, 1, 1)``.

        For a power-of-two head count ``n`` the slopes are the geometric
        sequence ``2^(-8/n), 2^(-16/n), ...``. Otherwise the slopes for the
        nearest lower power of two are used, topped up with every other slope
        of the next power of two.
        """
        def slopes_power_of_2(n):
            start = 2 ** (-2 ** -(math.log2(n) - 3))  # == 2 ** (-8 / n)
            ratio = start
            return [start * (ratio ** i) for i in range(n)]

        def get_slopes(n):
            # Work with plain lists throughout; converting to a tensor only at
            # the end avoids the list + Tensor mix-up in the recursive case.
            if math.log2(n).is_integer():
                return slopes_power_of_2(n)
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            extra_slopes = get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            return slopes_power_of_2(closest_power_of_2) + extra_slopes

        return torch.Tensor(get_slopes(n_heads)).view(n_heads, 1, 1)

    def apply_rope(self, x):
        """Apply rotary position embedding to ``x`` of shape ``(B, n_head, T, head_dim)``.

        Channel ``i`` of the first half of the rotary slice is paired with
        channel ``i + rotary_dim // 2``. Each pair is rotated by the angle
        ``t * inv_freq[i]`` for the token at position ``t``. Channels beyond
        ``rotary_dim`` are returned unchanged. Because every pair undergoes a
        true 2-D rotation, norms are preserved and
        ``<rope(q)_m, rope(k)_n>`` depends only on ``m - n``.
        """
        seq_len = x.size(-2)
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)  # (T, rotary_dim // 2)
        cos = freqs.cos().to(x.dtype)
        sin = freqs.sin().to(x.dtype)
        half = self.rotary_dim // 2
        x_rot = x[..., :self.rotary_dim]
        x_pass = x[..., self.rotary_dim:]
        x1 = x_rot[..., :half]
        x2 = x_rot[..., half:]
        rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
        return torch.cat((rotated, x_pass), dim=-1)

    def _relative_bias(self, T, device):
        """Learned relative-position bias of shape ``(1, n_head, T, T)``.

        Entry ``[h, i, j]`` is ``relative_positions[(j - i) + block_size - 1][h]``.
        Offsetting by ``block_size - 1`` (not ``T - 1``) maps a given distance
        to the same embedding row whatever the current sequence length is.
        """
        idx = torch.arange(T, device=device)
        rel = idx[None, :] - idx[:, None] + (self.config.block_size - 1)  # (T, T)
        return self.relative_positions(rel).permute(2, 0, 1).unsqueeze(0)

    def forward(self, x, layer_past=None):
        """Attend causally over ``x`` of shape ``(B, T, n_embd)``; returns the same shape.

        ``layer_past`` is accepted for call-site compatibility but is unused
        (there is no KV cache).
        """
        B, T, C = x.size()
        qkv = self.c_attn(x).view(B, T, 3, self.n_head, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_head, T, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, n_head, T, head_dim)
        attention_type = self.config.attention_type
        if attention_type == 'rope':
            q = self.apply_rope(q)
            k = self.apply_rope(k)
        tracking = self.config.track_attention_patterns and not self.training
        if use_flash_attn and attention_type in ('default', 'rope') and not tracking:
            # Fused path: the call applies the causal mask and attention dropout.
            y = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if attention_type == 'alibi':
                # Bias slope_h * j for key position j.  The textbook bias
                # -slope_h * (i - j) differs only by the per-row constant
                # -slope_h * i, which the softmax cancels.
                key_pos = torch.arange(T, device=x.device)
                attn_scores = attn_scores + self.alibi_slopes * key_pos  # + (n_head, 1, T)
            elif attention_type == 'relative':
                attn_scores = attn_scores + self._relative_bias(T, x.device)
            causal_mask = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
            attn_scores = attn_scores.masked_fill(~causal_mask, float('-inf'))
            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
            if tracking:
                # Keep a CPU copy for later inspection by the caller.
                self.attn_weights = attn_weights.detach().cpu()
            y = torch.matmul(attn_weights, v)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
class MLP(nn.Module):
    """Position-wise feed-forward sublayer.

    Expands the embedding 4x, applies GELU, projects back to ``n_embd`` and
    finishes with dropout.  Sub-modules are created in the order
    c_fc, gelu, c_proj, dropout.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
class Block(nn.Module):
    """One pre-LayerNorm transformer layer.

    The input is normalised before attention and before the MLP, and each
    sublayer's output is added back onto its input (residual connection).
    """

    def __init__(self, config):
        super().__init__()
        # Creation order is kept as-is: it fixes how the RNG is consumed at init.
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = self.attn(self.ln_1(x))
        x = x + attended
        transformed = self.mlp(self.ln_2(x))
        return x + transformed
@dataclass
class GPTConfig:
    """Hyper-parameters shared by ``GPT`` and its sub-modules.

    Field order is also the positional-constructor order, so keep it stable.
    """
    block_size: int = 1024  # max sequence length accepted by GPT.forward; also sizes positional tables
    vocab_size: int = 50304  # rows of the token embedding / outputs of lm_head
    n_layer: int = 12  # number of transformer Blocks
    n_head: int = 12  # attention heads per Block (must divide n_embd)
    n_embd: int = 768  # model / embedding width
    dropout: float = 0.0  # probability for embedding, attention and residual dropout
    bias: bool = True  # whether Linear layers and LayerNorms carry a bias
    embedding_type: str = 'default'  # Default uses learned positional embeddings
    # ('learned' behaves the same; 'none' adds no positional embedding; any other
    # value names a fixed table built by get_positional_encoding)
    attention_type: str = 'default'  # Default attention without any modifications
    # (other values handled by CausalSelfAttention: 'rope', 'alibi', 'relative')
    track_activations: bool = False  # eval mode: GPT.forward keeps each Block's output in .activations
    track_attention_patterns: bool = False  # eval mode: record per-layer attention weights (uses the non-fused path)
class GPT(nn.Module):
    """Decoder-only GPT language model with configurable positional handling.

    Position information enters either through the input embeddings
    (``config.embedding_type``) or inside attention
    (``config.attention_type``, handled by each Block's attention module).
    Optionally records per-layer activations and attention weights in eval mode.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.transformer = nn.ModuleDict()
        self.transformer['wte'] = nn.Embedding(config.vocab_size, config.n_embd)
        # Input-level positional information:
        #   'learned' / 'default' -> trainable nn.Embedding table ('wpe')
        #   'none'                -> nothing is added to the token embeddings
        #   anything else         -> fixed table from get_positional_encoding,
        #                            registered as the non-trainable 'pos_emb' buffer
        #                            (an unknown name raises ValueError there)
        if config.embedding_type in ['learned', 'default']:
            self.transformer['wpe'] = nn.Embedding(config.block_size, config.n_embd)
            self.pos_emb = None
        elif config.embedding_type == 'none':
            self.transformer['wpe'] = None
            self.pos_emb = None
        else:
            self.transformer['wpe'] = None
            position = torch.arange(0, config.block_size)
            pe = get_positional_encoding(position, config.n_embd, config.embedding_type, config.block_size)
            self.register_buffer('pos_emb', pe)
        self.transformer['drop'] = nn.Dropout(config.dropout)
        self.transformer['h'] = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.transformer['ln_f'] = LayerNorm(config.n_embd, bias=config.bias)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer['wte'].weight = self.lm_head.weight  # Weight tying
        # NOTE: module creation order and the init calls below determine how
        # the global RNG is consumed; reordering changes the initial weights
        # obtained for a given seed.
        self.apply(self._init_weights)
        # Re-initialise every residual output projection (attention and MLP
        # c_proj) with a std that shrinks with depth: 0.02 / sqrt(2 * n_layer).
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
        # Initialize activations and attention patterns
        self.activations = []
        self.attention_patterns = []
        print("Number of parameters: {:.2f}M".format(self.get_num_params() / 1e6))

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters.

        With ``non_embedding`` true, the learned position table ('wpe') is
        excluded when it exists.  Token-embedding weights remain counted (they
        are shared with ``lm_head``).
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and self.transformer['wpe'] is not None:
            n_params -= self.transformer['wpe'].weight.numel()
        return n_params

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token ids ``idx`` of shape ``(b, t)``.

        Args:
            idx: integer token ids, ``t <= config.block_size``.
            targets: optional ids of the same shape; entries equal to -1 are
                ignored by the loss.

        Returns:
            ``(logits, loss)`` where ``logits`` is ``(b, t, vocab_size)`` (every
            position, not only the last) and ``loss`` is the mean cross-entropy,
            or ``None`` when ``targets`` is not given.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
        tok_emb = self.transformer['wte'](idx) # token embeddings
        if self.config.embedding_type in ['learned', 'default']:
            pos_emb = self.transformer['wpe'](pos)
            x = tok_emb + pos_emb
        elif self.config.embedding_type == 'none':
            x = tok_emb
        else:
            # Fixed table: take the first t rows and broadcast over the batch.
            pos_emb = self.pos_emb[:t, :].to(device)
            x = tok_emb + pos_emb.unsqueeze(0)
        x = self.transformer['drop'](x)
        # Reset activations and attention patterns if tracking
        if self.config.track_activations and not self.training:
            self.activations = []
        if self.config.track_attention_patterns and not self.training:
            self.attention_patterns = []
        for block in self.transformer['h']:
            x = block(x)
            if self.config.track_activations and not self.training:
                self.activations.append(x.detach().cpu())
            if self.config.track_attention_patterns and not self.training:
                # attn_weights is only set by the explicit (non-fused) attention path.
                if hasattr(block.attn, 'attn_weights'):
                    self.attention_patterns.append(block.attn.attn_weights)
        x = self.transformer['ln_f'](x)
        logits = self.lm_head(x)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            loss = None
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Build an AdamW optimiser.

        Parameters with 2+ dimensions (weight matrices, embedding tables) get
        ``weight_decay``.  1-D parameters (biases, LayerNorm gains) get none.
        The fused AdamW kernel is requested when this PyTorch build supports
        it and ``device_type == 'cuda'``.
        """
        # Start with all candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0},
        ]
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"Using fused AdamW: {use_fused}")
        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """Estimate model flops utilization (MFU)

        Uses 6*N + 12*L*H*Q*T FLOPs per token for forward+backward, with
        T = block_size. ``dt`` is the wall time of one iteration in seconds.
        The result is a fraction of the hard-coded 312 TFLOPS peak.
        """
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0 / dt)
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Generate sequences of tokens from the model

        Autoregressively appends ``max_new_tokens`` sampled tokens to ``idx``
        of shape ``(b, t)`` and returns the extended tensor.  The context is
        cropped to the last ``block_size`` tokens, and logits are divided by
        ``temperature`` (must be non-zero).  When ``top_k`` is given, the
        choice is restricted to the ``top_k`` most likely tokens.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx