import json
import torch
import torch.nn as nn
from torch.nn import functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# one head of self-attention using scaled dot-product attention
class Head(nn.Module):
    def __init__(self, n_embed, head_size, context_size, dropout=0.1):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # causal mask cached as a buffer so it is built once and moves with the module's device
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)
        # attention scores, scaled by sqrt(head_size) as in scaled dot-product attention
        wei = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)  # (B, T, T)
        # mask out future positions using the registered buffer instead of rebuilding the mask every call
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        out = wei @ v  # (B, T, head_size)
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, n_embed, num_heads, context_size, head_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList([
            Head(n_embed, head_size, context_size, dropout)  # forward the dropout rate instead of relying on Head's default
            for _ in range(num_heads)
        ])
        self.projection = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # run the heads in parallel and concatenate along the channel dimension
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        out = self.projection(out)
        return self.dropout(out)

# simple feed-forward layer
class FeedForward(nn.Module):
    def __init__(self, n_embeds, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embeds, 4 * n_embeds),
            nn.ReLU(),
            # projection layer back to the embedding width
            nn.Linear(4 * n_embeds, n_embeds),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# Transformer block
class Block(nn.Module):
    def __init__(self, n_embeds, n_head, context_size, dropout):
        super().__init__()
        head_size = n_embeds // n_head
        self.sa = MultiHeadAttention(n_embeds, n_head, context_size, head_size, dropout)
        self.ffwd = FeedForward(n_embeds, dropout)
        self.ln1 = nn.LayerNorm(n_embeds)
        self.ln2 = nn.LayerNorm(n_embeds)

    def forward(self, x):
        # pre-norm residual connections around self-attention and feed-forward
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# decoder-only transformer language model
class DecoderTransformer(nn.Module):
    def __init__(self, vocab_size, n_embed, context_size, n_layer, n_head, dropout):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(context_size, n_embed)
        self.blocks = nn.Sequential(
            *[Block(
                n_embeds=n_embed,
                n_head=n_head,
                context_size=context_size,
                dropout=dropout
            ) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embed)  # final layer norm
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B, T) tensors of token indices
        token_embeds = self.token_embedding_table(idx)  # (B, T, C)
        pos_embeds = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, C)
        x = token_embeds + pos_embeds  # (B, T, C)
        x = self.ln_f(self.blocks(x))
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if targets is None:
            return logits, None
        # flatten batch and time dimensions to compute the cross-entropy loss
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens=50, context_size=None, temperature=1.0):
        if context_size is None:
            # default to the model's positional-embedding capacity
            context_size = int(self.position_embedding_table.weight.shape[0])
        for _ in range(max_new_tokens):
            # crop the conditioning context to the last context_size tokens
            idx_cond = idx[:, -context_size:]
            logits, _ = self(idx_cond)
            # take the logits for the last position and apply temperature
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # sample the next token
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

# simple character-level tokenizer backed by a JSON vocabulary
class Tokenizer:
    def __init__(self, vocab):
        self.vocab = vocab
        self.stoi = {ch: idx for idx, ch in enumerate(vocab)}  # character -> token id
        self.itos = {idx: ch for idx, ch in enumerate(vocab)}  # token id -> character

    def encode(self, s):
        return [self.stoi[c] for c in s]

    def decode(self, i):
        return ''.join([self.itos[x] for x in i])

    @classmethod
    def from_pretrained(cls, path):
        with open(path, 'r') as f:
            vocab = json.load(f)
        return cls(vocab)

    def save_pretrained(self, path):
        with open(path, 'w') as f:
            json.dump(self.vocab, f)
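
# --- Minimal usage sketch (illustrative only) ---
# The guard below sketches how the pieces above fit together: build a character-level
# Tokenizer, instantiate the model, and sample a short continuation. The tiny inline
# corpus and all hyperparameter values are assumptions made for demonstration, not the
# configuration used by this Space.
if __name__ == "__main__":
    text = "hello world\n" * 100  # hypothetical stand-in corpus; a real run would load actual text
    tokenizer = Tokenizer(sorted(set(text)))

    context_size = 64
    model = DecoderTransformer(
        vocab_size=len(tokenizer.vocab),
        n_embed=128,
        context_size=context_size,
        n_layer=4,
        n_head=4,
        dropout=0.1,
    ).to(device)

    # sample from the (untrained) model, starting from a single newline token
    start = torch.tensor([tokenizer.encode("\n")], dtype=torch.long, device=device)
    out = model.generate(start, max_new_tokens=100, context_size=context_size)
    print(tokenizer.decode(out[0].tolist()))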