{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from contextlib import nullcontext\n",
    "from bigram_model import BigramLanguageModel\n",
    "from tokenizer_utils import IntCharTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler\n",
    "ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]\n",
    "ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)\n",
    "scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))"
   ]
  },
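  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop below runs in full precision and never uses `ctx` or `scaler`. For reference, this is the usual pattern for wiring them into a training step (a minimal sketch; `model`, `optimizer`, and a batch are assumed to be in scope):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def amp_train_step(model, optimizer, xb, yb):\n",
    "    # forward under autocast so matmuls run in the reduced-precision dtype\n",
    "    with ctx:\n",
    "        logits, loss = model(xb, yb)\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    # scale the loss before backward; a no-op unless dtype == 'float16'\n",
    "    scaler.scale(loss).backward()\n",
    "    scaler.step(optimizer)  # unscales grads and skips the step on inf/nan\n",
    "    scaler.update()         # adjusts the scale factor for the next step\n",
    "    return loss.item()"
   ]
  },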
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data_utils import *\n",
    "model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embed, block_size=BLOCK_SIZE,\n",
    "                  bias=False, vocab_size=None, dropout=dropout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([128, 256, 65])\n",
      "tensor(4.3690, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
     ]
    }
   ],
   "source": [
    "from data_utils import *\n",
    "xb, yb = get_random_batch('train')\n",
    "xb = xb.to(device)\n",
    "yb = yb.to(device)\n",
    "\n",
    "m = BigramLanguageModel(vocab_size=65, n_embed=n_embed, block_size=BLOCK_SIZE, num_heads=n_head, n_layers=n_layer).to(device)\n",
    "logits, loss = m(xb, yb)\n",
    "print(logits.shape)\n",
    "print(loss)"
   ]
  },
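  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sanity check: with a vocabulary of 65 characters, a uniform predictor gives a cross-entropy of $-\\ln(1/65) \\approx 4.17$, so an initial loss of ~4.37 from randomly initialized weights is in the expected ballpark."
   ]
  },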
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def estimate_loss(model):\n",
    "    out = {}\n",
    "    model.eval()\n",
    "    for split in ['train', 'val']:\n",
    "        losses = torch.zeros(eval_iters)\n",
    "        for k in range(eval_iters):\n",
    "            X, Y = get_random_batch(split)\n",
    "            with ctx:\n",
    "                logits, loss = model(X, Y)\n",
    "            losses[k] = loss.item()\n",
    "        out[split] = losses.mean()\n",
    "    model.train()\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "char_tokenizer = load_int_char_tokenizer(load_text())"
   ]
  },
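  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quick round-trip check of the tokenizer (assuming `IntCharTokenizer` exposes an `encode` counterpart to the `decode` used further down):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ids = char_tokenizer.encode(\"To be, or not to be\")  # text -> list of int ids (encode assumed)\n",
    "print(ids)\n",
    "print(char_tokenizer.decode(ids))  # should reproduce the original string"
   ]
  },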
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.788929 M parameters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0: train loss 4.3685, val loss 4.3640\n",
      "step 500: train loss 1.9681, val loss 2.0837\n",
      "step 1000: train loss 1.5377, val loss 1.7404\n",
      "step 1500: train loss 1.3802, val loss 1.6101\n",
      "step 2000: train loss 1.2855, val loss 1.5551\n",
      "step 2500: train loss 1.2162, val loss 1.5157\n",
      "step 3000: train loss 1.1617, val loss 1.5088\n",
      "step 3500: train loss 1.1061, val loss 1.5088\n",
      "step 4000: train loss 1.0555, val loss 1.5150\n",
      "step 4500: train loss 1.0086, val loss 1.5385\n",
      "step 4999: train loss 0.9583, val loss 1.5524\n"
     ]
    }
   ],
   "source": [
    "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n",
    "\n",
    "# create a PyTorch optimizer\n",
    "optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)\n",
    "\n",
    "for iter in range(max_iters):\n",
    "\n",
    "    # every once in a while evaluate the loss on train and val sets\n",
    "    if iter % eval_interval == 0 or iter == max_iters - 1:\n",
    "        losses = estimate_loss(m)\n",
    "        print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
    "\n",
    "    # sample a batch of data\n",
    "    xb, yb = get_random_batch('train')\n",
    "\n",
    "    # evaluate the loss\n",
    "    logits, loss = m(xb, yb)\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n"
   ]
  },
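  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note the loss curves above: train loss keeps falling, but val loss bottoms out near step 3000 (~1.51) and then drifts upward while the gap widens, the usual sign of overfitting on a small corpus. For best generalization, checkpoint around the val-loss minimum rather than at the final step; the checkpoint below stores the final-step weights."
   ]
  },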
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving checkpoint to ./nano_gpt_ckpts\n"
     ]
    }
   ],
   "source": [
    "checkpoint = {\n",
    "    'model': m.state_dict(),\n",
    "    'optimizer': optimizer.state_dict(),\n",
    "    'model_args': model_args,\n",
    "    'iter_num': max_iters,\n",
    "    'best_val_loss': losses['val'],\n",
    "\n",
    "}\n",
    "out_dir = \"./nano_gpt_ckpts\"\n",
    "print(f\"saving checkpoint to {out_dir}\")\n",
    "torch.save(checkpoint, os.path.join(out_dir, 'ckpt_5k_iters.pt'))"
   ]
  },
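  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the checkpoint also stores the optimizer state and the iteration count, training can be resumed rather than restarted. A minimal sketch (assuming the hyperparameters from `data_utils` are still in scope):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt = torch.load(os.path.join(out_dir, 'ckpt_5k_iters.pt'), map_location=device)\n",
    "model = BigramLanguageModel(vocab_size=65, n_embed=n_embed, block_size=BLOCK_SIZE,\n",
    "                            num_heads=n_head, n_layers=n_layer).to(device)\n",
    "model.load_state_dict(ckpt['model'])\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "optimizer.load_state_dict(ckpt['optimizer'])  # restores AdamW moments and step counts\n",
    "start_iter = ckpt['iter_num']  # resume counting from here, e.g. range(start_iter, new_max_iters)\n",
    "best_val_loss = ckpt['best_val_loss']"
   ]
  },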
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "GLOUCESTER: learn, like a nap. Prisoner will to my intents! with my brother! and this bloody makes off flows,--and haste tear'd your roe!--I should not be the other's.---I'ld do hear that be pupy with thear; sweet Montague,--thou as done not--So that they have nage must know,--never speak so many tears,--traightful ner-light,--with'd yet a ping tymp,--which time to stir; now still hurr'd,---water'd honour,--Pray's Coitlinius: the mountake's nobled daughter.' Sir, it is some thee on Rome is sin:--'proud him 'there;' none honest seen; forsweet must be pointed, hurls thee in men; a proud confines, foot, die, gin night, old Ratchard!--Go, good lord!--will'd you not piece, I dare not.' an't; swear by the dog, belike! mother!--How sir!-Spite! Jupiteous put o's!--God leave your lawful coward!'--for I'll dry down, you in death;'--near'---for very 'ven a day.---fa, by; 'twas his mother's disposed;--'I shall make no son,--hard him hear me,--do. Madam, or smother'd wife: and that you may part this denies.'--'--thrieks for Richmond dancerts, in free people's anointed,--O, hold: Curs, on a fiathful doom: every nurse, is I long now, never large.' quoth let return him; for an't plead the fie, his maids; he will not quarrel; 'twas this, but take within, as he learn, as and heat, it see; a gized evassages of season, imagish: yet, a very no other consulance, good den.--To fair cousin, stay! come, sir; and hath been, let it breather ring.' God; I am, trusper, I say: provided, pardone! a never lady; come in God. I'll fight with Montagues come. Why, 'twas bring you to be, if the pass off, and here, it dare, man cryield. Frow, your head A called with Gaunt; the cause. O, prettiest his pale thing, rust, and good. Thou adventure be more, Juliet, perishease: I'll take the queen, and his love.--give me note to de,--dyes help, Edward, and after Romeo!--Whence labour cann'd Warwick! was? whither? why hours! fairs! after was? stay come! your run? a happy kind!--O day, go be--hours, wrong!--ta w\n"
     ]
    }
   ],
   "source": [
    "# generate from the model\n",
    "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
    "#print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))\n",
    "print(char_tokenizer.decode(m.generate(context, max_new_tokens=2000)[0].tolist()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m3 = BigramLanguageModel(vocab_size=65, n_embed=n_embed, block_size=BLOCK_SIZE, num_heads=n_head, n_layers=n_layer)\n",
    "ckpt = torch.load(os.path.join(\"./nano_gpt_ckpts\", \"ckpt_5k_iters.pt\"))\n",
    "m3.load_state_dict(ckpt['model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "But Dohor, aged by! At Antigonus. You see his court! For death; a talm every hand, here shall!--So,--O, I, title now point!--Who, this I sem blind--that tark;--come boy?---O pray, peace! May, two here, do not---that I troth:----to villain leave, where was the Gallent--if I look the house,--bold Jour---whether may I go,--Mine son,---as I amiled me pized,--or so fled; 'tis a famouse,--there littenants,--If an either lawful hant ther is gone.' Sicilence, if it wer done! I have twize its sourness. P\n"
     ]
    }
   ],
   "source": [
    "context = torch.zeros((1, 1), dtype=torch.long)\n",
    "print(char_tokenizer.decode(m3.generate(context, max_new_tokens=500)[0].tolist()))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}