Moonlighthxq committed on
Commit 4ec74e3 · verified · 1 Parent(s): cdf9d85

Upload 3 files

Files changed (3)
  1. gpt2_train.11673347.log +0 -0
  2. model_00019560.bin +3 -0
  3. train_gpt2.py +908 -0
gpt2_train.11673347.log ADDED
The diff for this file is too large to render. See raw diff
 
model_00019560.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:448a505ad5dc65d4f11209225a7ab1a9841cefca96222ff6773b140c542f72dc
3
+ size 248952832
train_gpt2.py ADDED
@@ -0,0 +1,908 @@
1
+ """
2
+ Reference code for GPT-2 training and inference.
3
+ Will save the model weights into files, to be read from C as initialization.
4
+
5
+ References:
6
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
7
+ https://github.com/openai/gpt-2/blob/master/src/model.py
8
+ 2) huggingface/transformers PyTorch implementation:
9
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
10
+
11
+ Example launches to only benchmark the speed of bfloat16 compiled GPU training:
12
+ 1 GPU:
13
+ python train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16
14
+ you can also turn on flash-attention by appending --flash=1
15
+ 4 GPU:
16
+ torchrun --standalone --nproc_per_node=4 train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16
17
+ """
18
+
19
+ import os
20
+ import math
21
+ import glob
22
+ import struct
23
+ import inspect
24
+ from contextlib import nullcontext
25
+ from dataclasses import dataclass
26
+
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+ from torch.nn import functional as F
31
+ import torch._inductor.config as config
32
+ from torch.nn.parallel import DistributedDataParallel as DDP
33
+ from torch.distributed import init_process_group, destroy_process_group
34
+ from torch.distributed.optim import ZeroRedundancyOptimizer
35
+ import torch.distributed as dist
36
+ # -----------------------------------------------------------------------------
37
+ # PyTorch nn.Module definitions for the GPT-2 model
38
+ import json
39
+
40
+ tiktoken_cache_dir = "/scratch/user/alexzheng/llm.c/tiktoken_cache/"
41
+ os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
42
+
43
+ # validate
44
+ assert os.path.exists(os.path.join(tiktoken_cache_dir, "6d1cbeee0f20b3d9449abfede4726ed8212e3aee"))
45
+ assert os.path.exists(os.path.join(tiktoken_cache_dir, "6c7ea1a7e38e3a7f062df639a5b80947f075ffe6"))
46
+ print("pass tiktoken verification")
47
+
48
+
49
+ class NewGELU(nn.Module):
50
+ """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI"""
51
+ def forward(self, input):
52
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
53
+
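# --- editor's aside (not part of the committed file) ---------------------------------
# Quick sanity check: the tanh-based NewGELU above matches PyTorch's built-in tanh
# approximation of GELU to within floating-point error (assumes a PyTorch version that
# supports the `approximate` keyword of F.gelu).
import torch
from torch.nn import functional as F

x = torch.randn(4, 8)
ref = 0.5 * x * (1.0 + torch.tanh((2.0 / torch.pi) ** 0.5 * (x + 0.044715 * x ** 3)))
print((ref - F.gelu(x, approximate="tanh")).abs().max())  # expected to be ~1e-7 or smaller
# --------------------------------------------------------------------------------------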
54
+ class SwiGLU(nn.Module):
55
+ def __init__(self, input_dim, output_dim):
56
+ super(SwiGLU, self).__init__()
57
+ self.fc1 = nn.Linear(input_dim, output_dim)
58
+ self.fc2 = nn.Linear(input_dim, output_dim)
59
+
60
+ def forward(self, x):
61
+ return self.fc1(x) * torch.sigmoid(self.fc2(x))
62
+
63
+
64
+ class RMSNorm(nn.Module):
65
+ def __init__(self, dim, eps=1e-6):
66
+ super(RMSNorm, self).__init__()
67
+ self.eps = eps
68
+ self.weight = nn.Parameter(torch.ones(dim))
69
+
70
+ def forward(self, x):
71
+ rms = (x ** 2).mean(dim=-1, keepdim=True).sqrt()
72
+ return x / (rms + self.eps) * self.weight
73
+
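# --- editor's aside (not part of the committed file) ---------------------------------
# The RMSNorm above adds eps to the RMS after the square root; the more common
# formulation multiplies by rsqrt(mean(x^2) + eps). For ordinary activation magnitudes
# the two are numerically very close:
import torch

x = torch.randn(2, 3, 8)
eps = 1e-6
a = x / ((x ** 2).mean(dim=-1, keepdim=True).sqrt() + eps)      # as implemented above
b = x * torch.rsqrt((x ** 2).mean(dim=-1, keepdim=True) + eps)  # textbook RMSNorm
print((a - b).abs().max())  # tiny for typical inputs
# --------------------------------------------------------------------------------------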
74
+ # def apply_rope(q, k, seq_len, dim):
75
+ # position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1).to(q.device)
76
+ # div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)).to(q.device)
77
+ #
78
+ # # generate the RoPE positional encoding
79
+ # pe = torch.zeros(seq_len, dim).to(q.device)
80
+ # pe[:, 0::2] = torch.sin(position * div_term)
81
+ # pe[:, 1::2] = torch.cos(position * div_term)
82
+ #
83
+ # # apply RoPE to the query and key
84
+ # pe = pe.unsqueeze(0) # (1, seq_len, dim)
85
+ # q = (q * pe[:, :q.size(1), :]) - (k * pe[:, :k.size(1), :]) # apply the rotation
86
+ # k = (q * pe[:, :q.size(1), :]) + (k * pe[:, :k.size(1), :])
87
+ # return q, k
88
+
89
+ # using a global to toggle flash-attention
90
+ FLASH = 0
91
+
92
+ class CausalSelfAttention(nn.Module):
93
+
94
+ def __init__(self, config):
95
+ super().__init__()
96
+ assert config.n_embd % config.n_head == 0
97
+ # key, query, value projections for all heads, but in a batch
98
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
99
+ # output projection
100
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
101
+ self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
102
+ # regularization
103
+ self.n_head = config.n_head
104
+ self.n_embd = config.n_embd
105
+ # not really a 'bias', more of a mask, but we follow the OpenAI/HF naming
106
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
107
+ .view(1, 1, config.block_size, config.block_size))
108
+
109
+ def forward(self, x):
110
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
111
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
112
+ qkv = self.c_attn(x)
113
+ q, k, v = qkv.split(self.n_embd, dim=2)
114
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
115
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
116
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
117
+ # Apply RoPE
118
+ # q, k = apply_rope(q, k, T, C // self.n_head)
119
+ if FLASH:
120
+ # flashattention
121
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
122
+ else:
123
+ # manual implementation of attention
124
+ # this materializes the large (T,T) matrix for all the queries and keys
125
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
126
+ att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
127
+ att = F.softmax(att, dim=-1)
128
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
129
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
130
+ # output projection
131
+ y = self.c_proj(y)
132
+ return y
133
+
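# --- editor's aside (not part of the committed file) ---------------------------------
# The FLASH branch and the manual masked-softmax branch above compute the same causal
# attention, up to floating-point error. Minimal check on tiny tensors (assumes a
# PyTorch version that provides F.scaled_dot_product_attention):
import math
import torch
from torch.nn import functional as F

B, nh, T, hs = 1, 2, 5, 4
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
att = att.masked_fill(mask == 0, float('-inf'))
y_manual = F.softmax(att, dim=-1) @ v
y_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print((y_manual - y_sdpa).abs().max())  # expected to be ~1e-7
# --------------------------------------------------------------------------------------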
134
+ class MLP(nn.Module):
135
+
136
+ def __init__(self, config):
137
+ super().__init__()
138
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
139
+ self.swiglu = SwiGLU(4 * config.n_embd, 4 * config.n_embd) # Initialize SwiGLU, input and output dimensions are 4 times the embedding dimension
140
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
141
+ self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
142
+
143
+ def forward(self, x):
144
+ x = self.c_fc(x)
145
+ x = self.swiglu(x)
146
+ x = self.c_proj(x)
147
+ return x
148
+
149
+ class Block(nn.Module):
150
+
151
+ def __init__(self, config):
152
+ super().__init__()
153
+ self.ln_1 = RMSNorm(config.n_embd)
154
+ self.attn = CausalSelfAttention(config)
155
+ self.ln_2 = RMSNorm(config.n_embd)
156
+ self.mlp = MLP(config)
157
+
158
+ def forward(self, x):
159
+ x = x + self.attn(self.ln_1(x))
160
+ x = x + self.mlp(self.ln_2(x))
161
+ return x
162
+
163
+ # -----------------------------------------------------------------------------
164
+ # The main GPT-2 model
165
+
166
+ @dataclass
167
+ class GPTConfig:
168
+ block_size: int = 1024
169
+ vocab_size: int = 50257
170
+ n_layer: int = 12
171
+ n_head: int = 12
172
+ n_embd: int = 768
173
+
174
+ class GPT(nn.Module):
175
+
176
+ def __init__(self, config):
177
+ super().__init__()
178
+ self.config = config
179
+
180
+ self.transformer = nn.ModuleDict(dict(
181
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
182
+ wpe = nn.Embedding(config.block_size, config.n_embd),
183
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
184
+ ln_f = RMSNorm(config.n_embd),
185
+ ))
186
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
187
+ self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights
188
+ self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
189
+
190
+ # init all weights, use a torch rng object to be very careful
191
+ self.init_rng = torch.Generator()
192
+ self.init_rng.manual_seed(42)
193
+ self.apply(self._init_weights)
194
+
195
+ def _init_weights(self, module):
196
+ if isinstance(module, nn.Linear):
197
+ # apply special scaled init to the residual projections, per GPT-2 paper
198
+ std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer)
199
+ # we want to skip initializing lm_head, which shares parameters with wte
200
+ # and wte was already initialized down below during the Embedding init
201
+ if not hasattr(module, 'LLMC_SKIP_INIT'):
202
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng)
203
+ if module.bias is not None:
204
+ torch.nn.init.zeros_(module.bias)
205
+ elif isinstance(module, nn.Embedding):
206
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng)
207
+
208
+ def forward(self, idx, targets=None, return_logits=True):
209
+ device = idx.device
210
+ b, t = idx.size()
211
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
212
+ pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
213
+
214
+ # forward the GPT model itself
215
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
216
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
217
+ x = tok_emb + pos_emb
218
+
219
+ for block in self.transformer.h:
220
+ x = block(x)
221
+ x = self.transformer.ln_f(x)
222
+
223
+ if targets is not None:
224
+ # if we are given some desired targets also calculate the loss
225
+ logits = self.lm_head(x)
226
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
227
+ else:
228
+ # inference-time mini-optimization: only forward the lm_head on the very last position
229
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
230
+ loss = None
231
+
232
+ # there are performance reasons why not returning logits is prudent, if not needed
233
+ if not return_logits:
234
+ logits = None
235
+
236
+ return logits, loss
237
+
238
+ @classmethod
239
+ def from_pretrained(cls, model_type):
240
+ """Loads pretrained GPT-2 model weights from huggingface"""
241
+ assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
242
+ from transformers import GPT2LMHeadModel
243
+ print("loading weights from pretrained gpt: %s" % model_type)
244
+
245
+ # n_layer, n_head and n_embd are determined from model_type
246
+ config_args = {
247
+ 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
248
+ 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
249
+ 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
250
+ 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
251
+ }[model_type]
252
+ config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
253
+ config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
254
+ # create a from-scratch initialized minGPT model
255
+ config = GPTConfig(**config_args)
256
+ model = GPT(config)
257
+ sd = model.state_dict()
258
+ sd_keys = sd.keys()
259
+ sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
260
+
261
+ # init a huggingface/transformers model
262
+ model_hf = GPT2LMHeadModel.from_pretrained(model_type)
263
+ sd_hf = model_hf.state_dict()
264
+
265
+ # copy while ensuring all of the parameters are aligned and match in names and shapes
266
+ sd_keys_hf = sd_hf.keys()
267
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
268
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
269
+ transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
270
+ # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
271
+ # this means that we have to transpose these weights when we import them
272
+ assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
273
+ for k in sd_keys_hf:
274
+ if any(k.endswith(w) for w in transposed):
275
+ # special treatment for the Conv1D weights we need to transpose
276
+ assert sd_hf[k].shape[::-1] == sd[k].shape
277
+ with torch.no_grad():
278
+ sd[k].copy_(sd_hf[k].t())
279
+ else:
280
+ # vanilla copy over the other parameters
281
+ assert sd_hf[k].shape == sd[k].shape
282
+ with torch.no_grad():
283
+ sd[k].copy_(sd_hf[k])
284
+
285
+ return model
286
+
287
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage):
288
+ # start with all of the candidate parameters
289
+ param_dict = {pn: p for pn, p in self.named_parameters()}
290
+ # filter out those that do not require grad
291
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
292
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
293
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
294
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
295
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
296
+ optim_groups = [
297
+ {'params': decay_params, 'weight_decay': weight_decay},
298
+ {'params': nodecay_params, 'weight_decay': 0.0}
299
+ ]
300
+ num_decay_params = sum(p.numel() for p in decay_params)
301
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
302
+ print0(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
303
+ print0(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
304
+ # Create AdamW optimizer and use the fused version if it is available
305
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
306
+ use_fused = fused_available and device_type == 'cuda'
307
+ print0(f"using fused AdamW: {use_fused}")
308
+ if zero_stage == 1:
309
+ print0("using ZeroRedundancyOptimizer")
310
+ optimizer = ZeroRedundancyOptimizer(**optim_groups[0], optimizer_class=torch.optim.AdamW,
311
+ lr=learning_rate, betas=betas, fused=use_fused)
312
+ optimizer.add_param_group(optim_groups[1])
313
+ else:
314
+ print0("using regular AdamW")
315
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused)
316
+ return optimizer
317
+
318
+ @torch.no_grad()
319
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
320
+ """
321
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
322
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
323
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
324
+ """
325
+ for _ in range(max_new_tokens):
326
+ # if the sequence context is growing too long we must crop it at block_size
327
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
328
+ # forward the model to get the logits for the index in the sequence
329
+ logits, _ = self(idx_cond)
330
+ # pluck the logits at the final step and scale by desired temperature
331
+ logits = logits[:, -1, :] / temperature
332
+ # optionally crop the logits to only the top k options
333
+ if top_k is not None:
334
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
335
+ logits[logits < v[:, [-1]]] = -float('Inf')
336
+ # apply softmax to convert logits to (normalized) probabilities
337
+ probs = F.softmax(logits, dim=-1)
338
+ # sample from the distribution
339
+ idx_next = torch.multinomial(probs, num_samples=1)
340
+ # append sampled index to the running sequence and continue
341
+ idx = torch.cat((idx, idx_next), dim=1)
342
+
343
+ return idx
344
+
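# --- editor's aside (not part of the committed file) ---------------------------------
# Minimal smoke test of the classes above: build a randomly-initialized model with the
# default config, run a forward pass without targets, and sample a few tokens. Uses
# random token ids, so no tokenizer or GPU is needed (run after the module is loaded).
import torch

model = GPT(GPTConfig())  # 12 layers, 12 heads, 768-dim, block_size 1024
model.eval()
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
idx = torch.randint(0, model.config.vocab_size, (1, 8))
with torch.no_grad():
    logits, loss = model(idx)                        # loss is None without targets
    out = model.generate(idx, max_new_tokens=16, top_k=40)
print(logits.shape, out.shape)                       # (1, 1, 50257) and (1, 24)
# --------------------------------------------------------------------------------------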
345
+ # -----------------------------------------------------------------------------
346
+ # Our own simple Distributed Data Loader
347
+
348
+ def _peek_data_shard(filename):
349
+ # only reads the header, returns header data
350
+ with open(filename, "rb") as f:
351
+ # first read the header, which is 256 int32 integers (4 bytes each)
352
+ header = np.frombuffer(f.read(256*4), dtype=np.int32)
353
+ if header[0] != 20240520:
354
+ print("ERROR: magic number mismatch in the data .bin file!")
355
+ print("---> HINT: Are you passing in a correct file with --input_bin?")
356
+ print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
357
+ print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
358
+ exit(1)
359
+ assert header[1] == 1, "unsupported version"
360
+ ntok = header[2] # number of tokens (claimed)
361
+ return ntok # for now just return the number of tokens
362
+
363
+ def _load_data_shard(filename):
364
+ with open(filename, "rb") as f:
365
+ # first read the header, which is 256 int32 integers (4 bytes each)
366
+ header = np.frombuffer(f.read(256*4), dtype=np.int32)
367
+ assert header[0] == 20240520, "magic number mismatch in the data .bin file"
368
+ assert header[1] == 1, "unsupported version"
369
+ ntok = header[2] # number of tokens (claimed)
370
+ # the rest of it are tokens, stored as uint16
371
+ tokens = np.frombuffer(f.read(), dtype=np.uint16)
372
+ assert len(tokens) == ntok, "number of tokens read does not match header?"
373
+ return tokens
374
+
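# --- editor's aside (not part of the committed file) ---------------------------------
# The shard format the two readers above expect: a 256-entry int32 header (magic
# 20240520, version 1, token count), followed by the tokens as uint16. Writing a tiny
# dummy shard for experimentation (hypothetical filename):
import numpy as np

tokens = np.arange(2048, dtype=np.uint16)  # dummy token ids
header = np.zeros(256, dtype=np.int32)
header[0] = 20240520   # magic
header[1] = 1          # version
header[2] = len(tokens)
with open("tiny_shard.bin", "wb") as f:
    f.write(header.tobytes())
    f.write(tokens.tobytes())
# --------------------------------------------------------------------------------------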
375
+ class DistributedDataLoader:
376
+ def __init__(self, filename_pattern, B, T, process_rank, num_processes):
377
+ self.process_rank = process_rank
378
+ self.num_processes = num_processes
379
+ self.B = B
380
+ self.T = T
381
+
382
+ # glob files that match the pattern
383
+ self.files = sorted(glob.glob(filename_pattern))
384
+ assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
385
+
386
+ # load and validate all data shards, count number of tokens in total
387
+ ntok_total = 0
388
+ for fname in self.files:
389
+ shard_ntok = _peek_data_shard(fname)
390
+ assert shard_ntok >= num_processes * B * T + 1
391
+ ntok_total += shard_ntok
392
+ self.ntok_total = ntok_total
393
+ print0(f"DataLoader: total number of tokens: {ntok_total:,} across {len(self.files)} files")
394
+
395
+ # kick things off
396
+ self.current_shard = None
397
+ self.reset()
398
+
399
+ def reset(self):
400
+ # we're being a bit clever here: if we already had shard 0 loaded,
401
+ # then don't do the work to reload it, just reset the pointer
402
+ if self.current_shard != 0:
403
+ self.current_shard = 0
404
+ self.tokens = _load_data_shard(self.files[self.current_shard])
405
+ self.current_position = self.process_rank * self.B * self.T
406
+
407
+ def advance(self): # advance to next data shard
408
+ self.current_shard = (self.current_shard + 1) % len(self.files)
409
+ self.current_position = self.process_rank * self.B * self.T
410
+ self.tokens = _load_data_shard(self.files[self.current_shard])
411
+
412
+ def next_batch(self):
413
+ B = self.B
414
+ T = self.T
415
+ buf = self.tokens[self.current_position : self.current_position+B*T+1]
416
+ buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
417
+ x = (buf[:-1]).view(B, T) # inputs
418
+ y = (buf[1:]).view(B, T) # targets
419
+ # advance the start pointer in current shard
420
+ self.current_position += B * T * self.num_processes
421
+ # if loading the next batch would be out of bounds advance the shard
422
+ if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
423
+ self.advance()
424
+ return x, y
425
+
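# --- editor's aside (not part of the committed file) ---------------------------------
# Single-process use of the loader above, reading the dummy shard from the previous
# aside (rank 0 of 1). Intended to be run after the whole module is imported, since
# __init__ calls print0, which is defined further down in this file.
loader = DistributedDataLoader("tiny_shard.bin", B=4, T=64, process_rank=0, num_processes=1)
x, y = loader.next_batch()
print(x.shape, y.shape)                  # torch.Size([4, 64]) for both
assert torch.equal(x[:, 1:], y[:, :-1])  # targets are the inputs shifted by one token
# --------------------------------------------------------------------------------------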
426
+ # -----------------------------------------------------------------------------
427
+ # Python -> C bridge utilities for saving params/grads/activations to .bin files
428
+
429
+ def write_fp32(tensor, file):
430
+ t = tensor.detach().cpu().to(torch.float32)
431
+ b = t.numpy().tobytes()
432
+ file.write(b)
433
+
434
+ def write_bf16(tensor, file):
435
+ t = tensor.detach().cpu().to(torch.bfloat16)
436
+ # numpy doesn't have bf16 datatype so we have to trick it
437
+ t = t.view(torch.int16) # trick: reinterpret as int16
438
+ b = t.numpy().tobytes()
439
+ file.write(b)
440
+
441
+ def write_tensors(model_tensors, L, file, dtype):
442
+ # writes the GPT-2 model's weights to a binary file
+ # NOTE: ln_1/ln_2/ln_f in this file are RMSNorm modules without a bias parameter, and the
+ # MLP's SwiGLU fc1/fc2 weights are not written at all, so the .bias lookups below will
+ # raise KeyError (and the export will not round-trip) unless this function is updated
443
+ assert dtype in {"float32", "bfloat16"}
444
+ write_fun = write_fp32 if dtype == "float32" else write_bf16
445
+ write_fun(model_tensors["transformer.wte.weight"], file) # (V, C)
446
+ write_fun(model_tensors["transformer.wpe.weight"], file) # (T, C)
447
+ for i in range(L): # (L, C)
448
+ write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
449
+ for i in range(L): # (L, C)
450
+ write_fun(model_tensors[f"transformer.h.{i}.ln_1.bias"], file)
451
+ for i in range(L): # (L, 3C, C)
452
+ write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file)
453
+ for i in range(L): # (L, 3C)
454
+ write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.bias"], file)
455
+ for i in range(L): # (L, C, C)
456
+ write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file)
457
+ for i in range(L): # (L, C)
458
+ write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.bias"], file)
459
+ for i in range(L): # (L, C)
460
+ write_fun(model_tensors[f"transformer.h.{i}.ln_2.weight"], file)
461
+ for i in range(L): # (L, C)
462
+ write_fun(model_tensors[f"transformer.h.{i}.ln_2.bias"], file)
463
+ for i in range(L): # (L, 4C, C)
464
+ write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file)
465
+ for i in range(L): # (L, 4C)
466
+ write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.bias"], file)
467
+ for i in range(L): # (L, C, 4C)
468
+ write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file)
469
+ for i in range(L): # (L, C)
470
+ write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.bias"], file)
471
+ write_fun(model_tensors["transformer.ln_f.weight"], file) # (C, )
472
+ write_fun(model_tensors["transformer.ln_f.bias"], file) # (C, )
473
+
474
+ @torch.no_grad()
475
+ def pad_vocab(tensor, multiple=128, value=0):
476
+ """
477
+ The dimension of the vocab size in GPT-2 is 50,257
478
+ which is unfortunately a very unfriendly number for a lot of
479
+ matrix operations on the GPU. So we pad it to the nearest
480
+ friendlier multiple, e.g. 50,304 if multiple=128 when we
481
+ export the weights into C land. This is a NOOP algorithmically
482
+ and is only done to make the tensor operations more efficient.
483
+ """
484
+ assert tensor.ndim == 2
485
+ V, C = tensor.shape
486
+ assert V == 50257, "just being defensive here"
487
+ # calculate padded vocab size by rounding up to nearest multiple
488
+ Vp = ((V + multiple - 1) // multiple) * multiple
489
+ # pad the tensor
490
+ pad_rows = Vp - V
491
+ padded = tensor if pad_rows == 0 else F.pad(tensor, (0, 0, 0, pad_rows), value=value)
492
+ assert padded.shape == (Vp, C)
493
+ return padded
494
+
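# --- editor's aside (not part of the committed file) ---------------------------------
# pad_vocab rounds the 50,257-row embedding up to the next multiple of 128, i.e. 50,304
# rows; the original rows are untouched and the padding rows are zero.
import torch

wte = torch.randn(50257, 8)
padded = pad_vocab(wte)
print(padded.shape)                       # torch.Size([50304, 8])
assert torch.equal(padded[:50257], wte)
assert padded[50257:].abs().sum() == 0
# --------------------------------------------------------------------------------------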
495
+ def write_model(model, filename, dtype):
496
+ # everything we need to instantiate the model
497
+ # 1) header is: version int, GPTConfig ints, padding to 1024 bytes
498
+ assert dtype in {"float32", "bfloat16"} # float16 todo maybe later
499
+ version = {
500
+ "float32": 3, # 3: all tensors are fp32, padded vocab
501
+ "bfloat16": 5, # 5: all tensors are bf16, padded vocab
502
+ }[dtype]
503
+ header = torch.zeros(256, dtype=torch.int32)
504
+ header[0] = 20240326 # magic
505
+ header[1] = version # checkpoint version
506
+ header[2] = model.config.block_size
507
+ header[3] = model.config.vocab_size
508
+ header[4] = model.config.n_layer
509
+ header[5] = model.config.n_head
510
+ header[6] = model.config.n_embd
511
+ # 2) the parameters follow the header
512
+ params = {name: param.cpu() for name, param in model.named_parameters()}
513
+ # pad the vocab to a multiple of 128 here at export, for efficiency in C
514
+ wte = params["transformer.wte.weight"] # (V, C)
515
+ wte_padded = pad_vocab(wte) # (Vp, C)
516
+ params["transformer.wte.weight"] = wte_padded # (Vp, C)
517
+ print(f"padded vocab size from {wte.size(0)} to {wte_padded.size(0)}")
518
+ header[7] = wte_padded.size(0) # padded vocab size store in header
519
+ # now write to file
520
+ with open(filename, "wb") as file:
521
+ file.write(header.numpy().tobytes()) # header
522
+ write_tensors(params, model.config.n_layer, file, dtype) # params
523
+ print(f"wrote {filename}")
524
+
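# --- editor's aside (not part of the committed file) ---------------------------------
# Reading back the 1024-byte header that write_model produces, to verify the magic
# number, checkpoint version and config fields ("model.bin" is a placeholder for
# whatever filename was passed to write_model).
import numpy as np

with open("model.bin", "rb") as f:
    header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
print(header[0])    # 20240326 (magic)
print(header[1])    # 3 for float32, 5 for bfloat16
print(header[2:8])  # block_size, vocab_size, n_layer, n_head, n_embd, padded vocab size
# --------------------------------------------------------------------------------------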
525
+ def write_state(model, x, y, logits, loss, filename):
526
+ # the state is used for debugging.
527
+ # it contains information about the input, logits, loss, and the parameter gradients
528
+ # this can be used for checking the computation correctness in C
529
+ header = torch.zeros(256, dtype=torch.int32)
530
+ header[0] = 20240327 # magic
531
+ header[1] = 2 # run state version = 2 (1 -> 2 for padded vocab changes)
532
+ header[2] = x.size(0) # batch size of the batch, B
533
+ header[3] = x.size(1) # temporal extent of the batch, T
534
+ grads = {name: param.grad.cpu() for name, param in model.named_parameters()}
535
+ # pad the vocab grads here as well, to mirror write_model
536
+ wte_grad = grads["transformer.wte.weight"] # (V, C)
537
+ wte_grad_padded = pad_vocab(wte_grad, value=0) # (Vp, C) # TODO later maybe pad with nan?
538
+ grads["transformer.wte.weight"] = wte_grad_padded # (Vp, C)
539
+ print(f"padded vocab size in reference grads from {wte_grad.size(0)} to {wte_grad_padded.size(0)}")
540
+ with open(filename, "wb") as file:
541
+ # header
542
+ file.write(header.numpy().tobytes())
543
+ # input x
544
+ file.write(x.cpu().numpy().astype("int32").tobytes()) # (B, T)
545
+ # targets y
546
+ file.write(y.cpu().numpy().astype("int32").tobytes()) # (B, T)
547
+ # logits (result of the model forward pass)
548
+ write_fp32(logits.cpu(), file)
549
+ # loss (single float, result of the cross entropy loss)
550
+ write_fp32(loss.cpu(), file)
551
+ # gradients
552
+ write_tensors(grads, model.config.n_layer, file, "float32")
553
+ print(f"wrote {filename}")
554
+
555
+ def write_tokenizer(enc, filename):
556
+ n = enc.max_token_value + 1
557
+ header = torch.zeros(256, dtype=torch.int32)
558
+ header[0] = 20240328 # magic
559
+ header[1] = 2 # tokenizer version = 2 (1 -> 2: includes EOT token)
560
+ header[2] = n # number of tokens
561
+ header[3] = enc.eot_token # EOT token
562
+ with open(filename, "wb") as file:
563
+ file.write(header.numpy().tobytes())
564
+ for i in range(n):
565
+ b = enc.decode_bytes([i])
566
+ length = len(b)
567
+ assert length < 256, f"Token length exceeds 255: {length}"
568
+ file.write(struct.pack("<B", length)) # Write the length as a 1-byte unsigned integer
569
+ file.write(b) # Write the actual bytes
570
+ print(f"wrote {filename}")
571
+
572
+ # -----------------------------------------------------------------------------
573
+ # int main
574
+
575
+ def print0(*args, **kwargs):
576
+ # modified print that only prints from the master process
577
+ # if this is not a distributed run, it's just a print
578
+ if int(os.environ.get("RANK", 0)) == 0:
579
+ print(*args, **kwargs)
580
+
581
+ if __name__ == "__main__":
582
+ import time
583
+ import argparse
584
+ import tiktoken
585
+ # from transformers import GPT2Tokenizer
586
+ print0(f"Running pytorch {torch.version.__version__}")
587
+
588
+ # default settings will overfit a tiny batch of data
589
+ # and save model weights and debug state to disk on the first iteration
590
+ parser = argparse.ArgumentParser()
591
+ # file system input / output
592
+ parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on")
593
+ parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on")
594
+ parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
595
+ parser.add_argument("--model", type=str, default="gpt2", help="gpt2|gpt2-medium|gpt2-large|gpt2-xl|d12|d24|d36|d48")
596
+ # token layout for each step of the optimization
597
+ parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
598
+ parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
599
+ parser.add_argument("--total_batch_size", type=int, default=256, help="total desired batch size, in units of #tokens")
600
+ # workload (number of steps)
601
+ parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
602
+ parser.add_argument("--inference_only", type=int, default=0, help="only run inference")
603
+ # optimization
604
+ parser.add_argument("--learning_rate", type=float, default=1e-4, help="learning rate warmup iterations")
605
+ parser.add_argument("--warmup_iters", type=int, default=0, help="learning rate warmup iterations")
606
+ parser.add_argument("--learning_rate_decay_frac", type=float, default=1.0, help="learning rate warmup iterations")
607
+ parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay")
608
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="maximum gradient magnitude")
609
+ # evaluation
610
+ parser.add_argument("--val_loss_every", type=int, default=0, help="every how mant steps to evaluate val loss?")
611
+ parser.add_argument("--val_max_steps", type=int, default=20, help="how many batches of val to average?")
612
+ parser.add_argument("--sample_every", type=int, default=0, help="how often to sample from the model?")
613
+ # debugging
614
+ parser.add_argument("--overfit_single_batch", type=int, default=1, help="overfit just one batch of data")
615
+ # numerics
616
+ parser.add_argument("--tensorcores", type=int, default=0, help="use tensorcores")
617
+ # memory management
618
+ parser.add_argument("--device", type=str, default="", help="by default we autodetect, or set it here")
619
+ parser.add_argument("--compile", type=int, default=0, help="torch.compile the model")
620
+ parser.add_argument("--flash", type=int, default=0, help="use flash attention")
621
+ parser.add_argument("--dtype", type=str, default="float32", help="float32|float16|bfloat16")
622
+ parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)")
623
+ # python -> C bridge
624
+ parser.add_argument("--write_tensors", type=int, default=1, help="write tensors to disk")
625
+ args = parser.parse_args()
626
+
627
+ # args error checking and convenience variables
628
+ B, T = args.batch_size, args.sequence_length
629
+ assert 1 <= T <= 1024
630
+ assert args.dtype in {"float32", "float16", "bfloat16"}
631
+ assert args.model in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl", "d12", "d24", "d36", "d48"}
632
+
633
+ # set up DDP (distributed data parallel). torchrun sets this env variable
634
+ ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
635
+ if ddp:
636
+ # use of DDP atm demands CUDA, we set the device appropriately according to rank
637
+ assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
638
+ init_process_group(backend='nccl')
639
+ ddp_rank = int(os.environ['RANK'])
640
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
641
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
642
+ device = f'cuda:{ddp_local_rank}'
643
+ torch.cuda.set_device(device)
644
+ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
645
+ seed_offset = 0 # each process gets the exact same seed
646
+ zero_stage = args.zero_stage
647
+ else:
648
+ ddp_rank = 0
649
+ ddp_local_rank = 0
650
+ zero_stage = 0
651
+ ddp_world_size = 1
652
+ master_process = True
653
+ seed_offset = 0
654
+ # select the device
655
+ if args.device:
656
+ # provided explicitly by the user
657
+ device = args.device
658
+ else:
659
+ # attempt to autodetect the device
660
+ device = "cpu"
661
+ if torch.cuda.is_available():
662
+ device = "cuda"
663
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
664
+ device = "mps"
665
+ print(f"using device: {device}")
666
+ device_type = 'cuda' if 'cuda' in device else 'cpu'
667
+
668
+ # calculate gradient accumulation from the desired total batch size and the current run configuration
669
+ tokens_per_fwdbwd = B * T * ddp_world_size
670
+ assert args.total_batch_size % tokens_per_fwdbwd == 0
671
+ grad_accum_steps = args.total_batch_size // tokens_per_fwdbwd
672
+ print0(f"total desired batch size: {args.total_batch_size}")
673
+ print0(f"=> calculated gradient accumulation steps: {grad_accum_steps}")
674
+
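# --- editor's aside (not part of the committed file) ---------------------------------
# Worked example of the calculation above. With the script defaults (batch_size=4,
# sequence_length=64, total_batch_size=256) on a single process:
#   tokens_per_fwdbwd = 4 * 64 * 1 = 256   ->   grad_accum_steps = 256 // 256 = 1
# With hypothetical settings B=8, T=1024 on 4 GPUs and total_batch_size=524288:
#   tokens_per_fwdbwd = 8 * 1024 * 4 = 32768   ->   grad_accum_steps = 524288 // 32768 = 16
# --------------------------------------------------------------------------------------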
675
+ # set up a context manager following the desired dtype and device
676
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
677
+ ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
678
+
679
+ # rng / reproducibility
680
+ torch.manual_seed(42)
681
+ if torch.cuda.is_available():
682
+ torch.cuda.manual_seed(42)
683
+
684
+ # set the torch precision mode to use TensorFloat32 (TF32) for matmuls
685
+ # docs https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
686
+ if args.tensorcores:
687
+ torch.set_float32_matmul_precision('high')
688
+
689
+ # turn on/off flash attention
690
+ assert args.flash in {0, 1}
691
+ FLASH = args.flash
692
+
693
+ # init (and write) the tokenizer
694
+ enc = tiktoken.get_encoding("gpt2")
695
+ # enc = GPT2Tokenizer.from_pretrained("gpt2", cache_dir="/scratch/user/alexzheng/tokenizer_cache/")
696
+ if master_process and args.write_tensors: # tokenizer is technically not tensors but ok
697
+ write_tokenizer(enc, "gpt2_tokenizer.bin")
698
+
699
+ # init the model, either from scratch or from OpenAI pretrained checkpoint
700
+ if args.model[0] == "d":
701
+ # from scratch (random weights)
702
+ model_config = {
703
+ "d12": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768),
704
+ "d24": GPTConfig(block_size=1024, vocab_size=50257, n_layer=24, n_head=16, n_embd=1024),
705
+ "d36": GPTConfig(block_size=1024, vocab_size=50257, n_layer=36, n_head=20, n_embd=1280),
706
+ "d48": GPTConfig(block_size=1024, vocab_size=50257, n_layer=48, n_head=25, n_embd=1600),
707
+ }[args.model]
708
+ model = GPT(model_config)
709
+ else:
710
+ # load the GPT-2 model weights
711
+ model = GPT.from_pretrained(args.model)
712
+ model.train()
713
+ model.to(device)
714
+ if args.compile:
715
+ if hasattr(config, "coordinate_descent_tuning"):
716
+ config.coordinate_descent_tuning = True # suggested by @Chillee
717
+ print0("compiling the model...")
718
+ model = torch.compile(model)
719
+
720
+ # -------------------------------------------------------------------------
721
+ # Our own version of a simple DistributedDataLoader
722
+
723
+ # load tokens
724
+ train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
725
+ val_loader = None
726
+ if args.input_val_bin:
727
+ val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
728
+
729
+ # -------------------------------------------------------------------------
730
+ # PyTorch -> C bridge: save some weights and state for C to load later as reference
731
+
732
+ # do one forward pass to generate ground truth for our C tests
733
+ if master_process and args.write_tensors and (not args.inference_only):
734
+ x, y = train_loader.next_batch()
735
+ x, y = x.to(device), y.to(device)
736
+ logits, loss = model(x, y)
737
+ loss.backward()
738
+ # save model params, in both float32 and bfloat16
739
+ model_to_size = {"gpt2": "124M", "gpt2-medium": "355M", "gpt2-large": "774M", "gpt2-xl": "1558M"}
740
+ model_to_size.update({f"d{d}": f"d{d}" for d in [12, 24, 36, 48]})
741
+ model_size_str = model_to_size[args.model] # e.g. "124M", or "d12"
742
+ write_model(model, f"gpt2_{model_size_str}.bin", dtype="float32")
743
+ write_model(model, f"gpt2_{model_size_str}_bf16.bin", dtype="bfloat16")
744
+ # save x, y, logits, loss, and parameter gradients, for debugging C
745
+ # always store these in fp32 to have an accurate reference (?)
746
+ write_state(model, x, y, logits, loss, f"gpt2_{model_size_str}_debug_state.bin")
747
+ # reset the train_loader for the optimization below
748
+ train_loader.reset()
749
+
750
+ # -------------------------------------------------------------------------
751
+ # main training loop
752
+
753
+ # here we wrap model into DDP container
754
+ if ddp:
755
+ model = DDP(model, device_ids=[ddp_local_rank])
756
+ raw_model = model.module if ddp else model # always contains the "raw" unwrapped model
757
+
758
+ # init the optimizer
759
+ optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay,
760
+ learning_rate=args.learning_rate, betas=(0.9, 0.95),
761
+ device_type=device_type, zero_stage=zero_stage)  # pass 'cuda'/'cpu' (not e.g. 'cuda:0') so fused AdamW is detected
762
+
763
+ # learning rate decay scheduler (cosine with warmup)
764
+ def get_lr(it):
765
+ min_lr = args.learning_rate * args.learning_rate_decay_frac
766
+ # 1) linear warmup for warmup_iters steps
767
+ if it < args.warmup_iters:
768
+ return args.learning_rate * (it+1) / args.warmup_iters
769
+ # 2) if it > lr_decay_iters, return min learning rate
770
+ if it > args.num_iterations:
771
+ return min_lr
772
+ # 3) in between, use cosine decay down to min learning rate
773
+ decay_ratio = (it - args.warmup_iters) / (args.num_iterations - args.warmup_iters)
774
+ assert 0 <= decay_ratio <= 1
775
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
776
+ return min_lr + coeff * (args.learning_rate - min_lr)
777
+
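# --- editor's aside (not part of the committed file) ---------------------------------
# Worked example of the schedule above with hypothetical settings lr=3e-4,
# warmup_iters=10, num_iterations=100, learning_rate_decay_frac=0.1 (so min_lr=3e-5):
#   it=0   -> 3e-4 * 1/10 = 3e-5                                    (linear warmup)
#   it=55  -> 3e-5 + 0.5*(1 + cos(0.5*pi))*(3e-4 - 3e-5) = 1.65e-4  (halfway through decay)
#   it=100 -> 3e-5                                                  (fully decayed to min_lr)
# --------------------------------------------------------------------------------------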
778
+ # create the logging directory if it does not exist
779
+ logfile = None
780
+ if args.output_dir:
781
+ os.makedirs(args.output_dir, exist_ok=True)
782
+ logfile = os.path.join(args.output_dir, "main.log")
783
+ # create the log file "main.log" inside it, and wipe it clean
784
+ with open(logfile, "w") as f:
785
+ pass
786
+
787
+ if device == "cuda":
788
+ torch.cuda.reset_peak_memory_stats()
789
+ timings = []
790
+ norm = -1.0 # dummy value to print in inference-only mode
791
+ for step in range(args.num_iterations + 1):
792
+ t0 = time.time()
793
+ last_step = (step == args.num_iterations)
794
+
795
+ # once in a while evaluate the validation dataset
796
+ if (args.val_loss_every > 0 \
797
+ and (step % args.val_loss_every == 0 or last_step)) \
798
+ and (val_loader is not None):
799
+ model.eval()
800
+ val_loader.reset()
801
+ with torch.no_grad():
802
+ val_loss = 0.0
803
+ for _ in range(args.val_max_steps):
804
+ x, y = val_loader.next_batch()
805
+ x, y = x.to(device), y.to(device)
806
+ _, loss = model(x, y, return_logits=False)
807
+ val_loss += loss.item()
808
+ val_loss /= args.val_max_steps
809
+ # log to console and to file
810
+ print0(f"val loss {val_loss}")
811
+ if master_process and logfile is not None:
812
+ with open(logfile, "a") as f:
813
+ f.write("s:%d tel:%f\n" % (step, val_loss))
814
+
815
+ # once in a while perform model inference on the master process
816
+ if (args.sample_every > 0 \
817
+ and (step % args.sample_every == 0 or last_step)) \
818
+ and master_process:
819
+ model.eval()
820
+ # before we end, let's also do one round of inference
821
+ # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence
822
+ start_ids = [enc.eot_token]
823
+ xg = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
824
+ max_new_tokens = 32
825
+ temperature = 1.0
826
+ top_k = 40
827
+ yg = raw_model.generate(xg, max_new_tokens, temperature=temperature, top_k=top_k)
828
+ print0('---------------')
829
+ print0(enc.decode(yg[0].tolist()))
830
+ print0('---------------')
831
+
832
+ # bit confusing: we want to make sure to eval and sample on 0th iteration
833
+ # but also after the very last iteration. so we loop for step <= num_iterations
834
+ # instead of just < num_iterations (one extra due to <=), only to do
835
+ # the validation/sampling one last time, and then we break right here as we're done.
836
+ if last_step:
837
+ break
838
+
839
+ # --------------- TRAINING SECTION BEGIN -----------------
840
+ model.train()
841
+ optimizer.zero_grad(set_to_none=True)
842
+ # if we are trying to overfit a single batch, we reset the loader here
843
+ if args.overfit_single_batch:
844
+ train_loader.reset()
845
+ # micro-batch loop where we do gradient accumulation to reach desired total batch size
846
+ lossf = 0.0 # for getting the mean loss (as simple float) over the accumulation steps
847
+ for micro_step in range(grad_accum_steps):
848
+ # fetch a batch
849
+ x, y = train_loader.next_batch()
850
+ x, y = x.to(device), y.to(device)
851
+ if ddp:
852
+ # we want only the last micro-step to sync grads in a DDP model
853
+ # the official way to do this is with model.no_sync(), but that is a
854
+ # context manager that bloats the code, so we just toggle this variable
855
+ model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
856
+ # forward pass
857
+ with ctx:
858
+ _, loss = model(x, y, return_logits=False)
859
+ # we have to scale the loss to account for gradient accumulation,
860
+ # because the gradients just add on each successive backward().
861
+ # addition of gradients corresponds to a SUM in the objective, but
862
+ # instead of a SUM we want MEAN, so we scale the loss here
863
+ loss = loss / grad_accum_steps
864
+ lossf += loss.detach() # keep track of the mean loss
865
+ # backward pass
866
+ if not args.inference_only:
867
+ loss.backward()
868
+ if ddp:
869
+ dist.all_reduce(lossf, op=dist.ReduceOp.AVG)
870
+ lossf = lossf.item()
871
+ norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
872
+ # determine and set the learning rate for this iteration
873
+ lr = get_lr(step)
874
+ for param_group in optimizer.param_groups:
875
+ param_group['lr'] = lr
876
+ # step the optimizer
877
+ optimizer.step()
878
+ # --------------- TRAINING SECTION END -------------------
879
+ # everything that follows now is just diagnostics, prints, logging, etc.
880
+
881
+ # wait on the CPU for all device work to end so we get accurate per-iteration timings below
882
+ if device == "mps":
883
+ torch.mps.synchronize()
884
+ elif device == "cuda":
885
+ torch.cuda.synchronize()
886
+ # time and print
887
+ t1 = time.time()
888
+ # the 0th iteration is often an outlier (much slower) => skip logging it
889
+ tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1-t0)
890
+ print0(f"step {step+1:4d}/{args.num_iterations} | train loss {lossf:.6f} | norm {norm:.4f} | lr {lr:.2e} | ({(t1-t0)*1000:.2f} ms | {tokens_per_second:.0f} tok/s)")
891
+ # log to logfile
892
+ if master_process and logfile is not None:
893
+ with open(logfile, "a") as f:
894
+ f.write("s:%d trl:%f\n" % (step, lossf))
895
+
896
+ # keep track of smooth timings, last 20 iterations
897
+ if step > 0 and step > args.num_iterations - 20:
898
+ timings.append(t1-t0)
899
+
900
+ # print the average of the last 20 timings, to get something smooth-ish
901
+ timings = timings[-20:]
902
+ print0(f"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms")
903
+ print0(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
904
+
905
+ # -------------------------------------------------------------------------
906
+ # clean up nice
907
+ if ddp:
908
+ destroy_process_group()