MasaakiKotera commited on
Commit
b612eb3
·
verified ·
1 Parent(s): 1b3f51e

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +17 -35
train.py CHANGED
@@ -41,9 +41,6 @@ log_and_write(log_dir, f'training data: {data_dir}')
41
  # -----------------------------------------------------------------------------
42
 
43
 
44
- # various inits, derived attributes, I/O setup
45
- # ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
46
-
47
  ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
48
  if ddp:
49
  init_process_group(backend=backend)
@@ -53,13 +50,10 @@ if ddp:
53
  device = f'cuda:{ddp_local_rank}'
54
  torch.cuda.set_device(device)
55
  master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
56
- seed_offset = ddp_rank # each process gets a different seed
57
- # world_size number of processes will be training simultaneously, so we can scale
58
- # down the desired gradient accumulation iterations per process proportionally
59
  assert gradient_accumulation_steps % ddp_world_size == 0
60
  gradient_accumulation_steps //= ddp_world_size
61
  else:
62
- # if not ddp, we are running on a single gpu, and one process
63
  master_process = True
64
  seed_offset = 0
65
  ddp_world_size = 1
@@ -85,7 +79,6 @@ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torc
85
  ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
86
 
87
  # data loader
88
- # data_dir = os.path.join('data', dataset)
89
  train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
90
  val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
91
  def get_batch(split):
@@ -100,7 +93,6 @@ def get_batch(split):
100
  x, y = x.to(device), y.to(device)
101
  return x, y
102
 
103
- # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
104
  iter_num = 0
105
  best_val_loss = 1e9
106
 
@@ -127,7 +119,6 @@ if init_from == 'scratch':
127
  elif init_from == 'resume':
128
  print(f"Resuming training from {out_dir}")
129
  # resume training from a checkpoint.
130
- # ckpt_path = os.path.join(out_dir, 'ckpt.pt')
131
  checkpoint = torch.load(ckpt_path, map_location=device)
132
  checkpoint_model_args = checkpoint['model_args']
133
  # force these config attributes to be equal otherwise we can't even resume training
@@ -138,8 +129,6 @@ elif init_from == 'resume':
138
  gptconf = GPTConfig(**model_args)
139
  model = GPT(gptconf)
140
  state_dict = checkpoint['model']
141
- # fix the keys of the state dictionary :(
142
- # honestly no idea how checkpoints sometimes get this prefix, have to debug more
143
  unwanted_prefix = '_orig_mod.'
144
  for k,v in list(state_dict.items()):
145
  if k.startswith(unwanted_prefix):
@@ -147,14 +136,6 @@ elif init_from == 'resume':
147
  model.load_state_dict(state_dict)
148
  iter_num = checkpoint['iter_num']
149
  best_val_loss = checkpoint['best_val_loss']
150
- elif init_from.startswith('gpt2'):
151
- print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
152
- # initialize from OpenAI GPT-2 weights
153
- override_args = dict(dropout=dropout)
154
- model = GPT.from_pretrained(init_from, override_args)
155
- # read off the created config params, so we can store them into checkpoint correctly
156
- for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
157
- model_args[k] = getattr(model.config, k)
158
  # crop down the model block size if desired, using model surgery
159
  if block_size < model.config.block_size:
160
  model.crop_block_size(block_size)
@@ -188,7 +169,7 @@ def estimate_loss():
188
  model.eval()
189
  for split in ['train', 'val']:
190
  losses = torch.zeros(eval_iters)
191
- total_loss = 0 # 用于计算perplexity
192
  for k in range(eval_iters):
193
  X, Y = get_batch(split)
194
  with ctx:
@@ -197,7 +178,7 @@ def estimate_loss():
197
  total_loss += loss.item()
198
  avg_loss = losses.mean()
199
  out[split] = avg_loss
200
- perplexities[split] = torch.exp(avg_loss) # 计算perplexity
201
  model.train()
202
  return out, perplexities
203
 
@@ -235,19 +216,20 @@ while True:
235
  log_and_write(log_dir, f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f},train perplexity: {perplexities['train']:.4f}, val perplexity: {perplexities['val']:.4f}")
236
  if iter_num % 200 == 0:
237
  print_gpu_memory_usage()
238
- if losses['val'] < best_val_loss or always_save_checkpoint:
239
- best_val_loss = losses['val']
240
- if iter_num > 0:
241
- checkpoint = {
242
- 'model': raw_model.state_dict(),
243
- 'optimizer': optimizer.state_dict(),
244
- 'model_args': model_args,
245
- 'iter_num': iter_num,
246
- 'best_val_loss': best_val_loss,
247
- 'config': config,
248
- }
249
- log_and_write(log_dir, f"saving checkpoint to {out_dir}")
250
- torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{iter_num}.pt'))
 
251
  if iter_num == 0 and eval_only:
252
  break
253
 
 
41
  # -----------------------------------------------------------------------------
42
 
43
 
 
 
 
44
  ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
45
  if ddp:
46
  init_process_group(backend=backend)
 
50
  device = f'cuda:{ddp_local_rank}'
51
  torch.cuda.set_device(device)
52
  master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
53
+ seed_offset = ddp_rank
 
 
54
  assert gradient_accumulation_steps % ddp_world_size == 0
55
  gradient_accumulation_steps //= ddp_world_size
56
  else:
 
57
  master_process = True
58
  seed_offset = 0
59
  ddp_world_size = 1
 
79
  ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
80
 
81
  # data loader
 
82
  train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
83
  val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
84
  def get_batch(split):
 
93
  x, y = x.to(device), y.to(device)
94
  return x, y
95
 
 
96
  iter_num = 0
97
  best_val_loss = 1e9
98
 
 
119
  elif init_from == 'resume':
120
  print(f"Resuming training from {out_dir}")
121
  # resume training from a checkpoint.
 
122
  checkpoint = torch.load(ckpt_path, map_location=device)
123
  checkpoint_model_args = checkpoint['model_args']
124
  # force these config attributes to be equal otherwise we can't even resume training
 
129
  gptconf = GPTConfig(**model_args)
130
  model = GPT(gptconf)
131
  state_dict = checkpoint['model']
 
 
132
  unwanted_prefix = '_orig_mod.'
133
  for k,v in list(state_dict.items()):
134
  if k.startswith(unwanted_prefix):
 
136
  model.load_state_dict(state_dict)
137
  iter_num = checkpoint['iter_num']
138
  best_val_loss = checkpoint['best_val_loss']
 
 
 
 
 
 
 
 
139
  # crop down the model block size if desired, using model surgery
140
  if block_size < model.config.block_size:
141
  model.crop_block_size(block_size)
 
169
  model.eval()
170
  for split in ['train', 'val']:
171
  losses = torch.zeros(eval_iters)
172
+ total_loss = 0
173
  for k in range(eval_iters):
174
  X, Y = get_batch(split)
175
  with ctx:
 
178
  total_loss += loss.item()
179
  avg_loss = losses.mean()
180
  out[split] = avg_loss
181
+ perplexities[split] = torch.exp(avg_loss)
182
  model.train()
183
  return out, perplexities
184
 
 
216
  log_and_write(log_dir, f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f},train perplexity: {perplexities['train']:.4f}, val perplexity: {perplexities['val']:.4f}")
217
  if iter_num % 200 == 0:
218
  print_gpu_memory_usage()
219
+ if always_save_checkpoint:
220
+ if losses['val'] < best_val_loss or always_save_checkpoint:
221
+ best_val_loss = losses['val']
222
+ if iter_num > 0:
223
+ checkpoint = {
224
+ 'model': raw_model.state_dict(),
225
+ 'optimizer': optimizer.state_dict(),
226
+ 'model_args': model_args,
227
+ 'iter_num': iter_num,
228
+ 'best_val_loss': best_val_loss,
229
+ 'config': config,
230
+ }
231
+ log_and_write(log_dir, f"saving checkpoint to {out_dir}")
232
+ torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{iter_num}.pt'))
233
  if iter_num == 0 and eval_only:
234
  break
235