philipp-zettl commited on
Commit
b999262
·
verified ·
1 Parent(s): d39726a

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +131 -0
model.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
6
+
7
+
8
+ # one head of self-attention using scaled-dot product attention
9
+ class Head(nn.Module):
10
+ def __init__(self, n_embed, head_size, context_size, dropout=0.1):
11
+ super().__init__()
12
+
13
+ self.key = nn.Linear(n_embed, head_size, bias=False)
14
+ self.query = nn.Linear(n_embed, head_size, bias=False)
15
+ self.value = nn.Linear(n_embed, head_size, bias=False)
16
+ self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))
17
+
18
+ self.dropout = nn.Dropout(dropout)
19
+
20
+ def forward(self, x):
21
+ B,T,C = x.shape
22
+ k = self.key(x)
23
+ q = self.query(x)
24
+ v = self.value(x)
25
+
26
+ tril = torch.tril(torch.ones(T, T, device=device))
27
+ wei = q @ k.transpose(-2, -1) * (C**-0.5)
28
+ wei = wei.masked_fill(tril == 0, float('-inf'))
29
+ wei = F.softmax(wei, dim=-1)
30
+ wei = self.dropout(wei)
31
+ out = wei @ v
32
+ return out
33
+
34
+
35
+ class MultiHeadAttention(nn.Module):
36
+ def __init__(self, n_embed, num_heads, context_size, head_size, dropout):
37
+ super().__init__()
38
+
39
+ self.heads = nn.ModuleList([
40
+ Head(n_embed, head_size, context_size)
41
+ for _ in range(num_heads)
42
+ ])
43
+ self.projection = nn.Linear(n_embed, n_embed)
44
+ self.dropout = nn.Dropout(dropout)
45
+
46
+ def forward(self, x):
47
+ out = torch.cat([h(x) for h in self.heads], dim=-1)
48
+ out = self.projection(out)
49
+ return self.dropout(out)
50
+
51
+
52
+ # simple feed forward layer
53
+ class FeedForward(nn.Module):
54
+ def __init__(self, n_embeds, dropout):
55
+ super().__init__()
56
+ self.net = nn.Sequential(
57
+ nn.Linear(n_embeds, 4 * n_embeds),
58
+ nn.ReLU(),
59
+ # projection layer
60
+ nn.Linear(4 * n_embeds, n_embeds),
61
+ nn.Dropout(dropout)
62
+ )
63
+
64
+ def forward(self, x):
65
+ return self.net(x)
66
+
67
+
68
+ # Transformer block
69
+ class Block(nn.Module):
70
+ def __init__(self, n_embeds, n_head, context_size, dropout):
71
+ super().__init__()
72
+ head_size = n_embeds // n_head
73
+ self.sa = MultiHeadAttention(n_embeds, n_head, context_size, head_size, dropout)
74
+ self.ffwd = FeedForward(n_embeds, dropout)
75
+ self.ln1 = nn.LayerNorm(n_embeds)
76
+ self.ln2 = nn.LayerNorm(n_embeds)
77
+
78
+ def forward(self, x):
79
+ x = x + self.sa(self.ln1(x))
80
+ x = x + self.ffwd(self.ln2(x))
81
+ return x
82
+
83
+
84
+ # simple bigram model
85
+ class DecoderTransformer(nn.Module):
86
+ def __init__(self, vocab_size, n_embed, context_size, n_layer, n_head, dropout):
87
+ super().__init__()
88
+
89
+ self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
90
+ self.position_embedding_table = nn.Embedding(context_size, n_embed)
91
+ self.blocks = nn.Sequential(
92
+ *[Block(
93
+ n_embeds=n_embed,
94
+ n_head=n_head,
95
+ context_size=context_size,
96
+ dropout=dropout
97
+ ) for _ in range(n_layer)]
98
+ )
99
+ self.ln_f = nn.LayerNorm(n_embed)
100
+ self.lm_head = nn.Linear(n_embed, vocab_size)
101
+
102
+ def forward(self, idx, targets=None):
103
+ B, T = idx.shape
104
+ # idx and targets of size (B,T)
105
+ token_embeds = self.token_embedding_table(idx) # yields (B, T, C)
106
+ pos_embeds = self.position_embedding_table(torch.arange(T, device=device))
107
+ x = token_embeds + pos_embeds
108
+ x = self.ln_f(self.blocks(x))
109
+ logits = self.lm_head(x)
110
+
111
+ if targets is None:
112
+ return logits, None
113
+
114
+ # reshape elements
115
+ B, T, C = logits.shape
116
+ logits = logits.view(B*T,C)
117
+ targets = targets.view(B*T)
118
+ # compute loss (CE)
119
+ loss = F.cross_entropy(logits, targets)
120
+ return logits, loss
121
+
122
+ def generate(self, idx, max_new_tokens, context_size):
123
+ for _ in range(max_new_tokens):
124
+ idx_cond = idx[:, -context_size:]
125
+ logits, loss = self(idx_cond)
126
+ logits = logits[:,-1,:]
127
+ probs = F.softmax(logits, dim=-1)
128
+ idx_next = torch.multinomial(probs, num_samples=1)
129
+ idx = torch.cat([idx, idx_next], dim=1)
130
+ return idx
131
+