boris commited on
Commit
498559f
·
1 Parent(s): ecafe5e

feat: log metrics more frequently

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +16 -3
seq2seq/run_seq2seq_flax.py CHANGED
@@ -215,6 +215,13 @@ class DataTrainingArguments:
215
  overwrite_cache: bool = field(
216
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
217
  )
 
 
 
 
 
 
 
218
 
219
  def __post_init__(self):
220
  if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -307,12 +314,12 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
307
 
308
  train_metrics = get_metrics(train_metrics)
309
  for key, vals in train_metrics.items():
310
- tag = f"train_{key}"
311
  for i, val in enumerate(vals):
312
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
313
 
314
  for metric_name, value in eval_metrics.items():
315
- summary_writer.scalar(f"eval_{metric_name}", value, step)
316
 
317
 
318
  def create_learning_rate_fn(
@@ -718,6 +725,7 @@ def main():
718
 
719
  train_time = 0
720
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
 
721
  for epoch in epochs:
722
  # ======================== Training ================================
723
  train_start = time.time()
@@ -730,11 +738,16 @@ def main():
730
  train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
731
  steps_per_epoch = len(train_dataset) // train_batch_size
732
  # train
733
- for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
 
734
  batch = next(train_loader)
735
  state, train_metric = p_train_step(state, batch)
736
  train_metrics.append(train_metric)
737
 
 
 
 
 
738
  train_time += time.time() - train_start
739
 
740
  train_metric = unreplicate(train_metric)
 
215
  overwrite_cache: bool = field(
216
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
217
  )
218
+ log_interval: Optional[int] = field(
219
+ default=500,
220
+ metadata={
221
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
222
+ "value if set."
223
+ },
224
+ )
225
 
226
  def __post_init__(self):
227
  if self.dataset_name is None and self.train_file is None and self.validation_file is None:
 
314
 
315
  train_metrics = get_metrics(train_metrics)
316
  for key, vals in train_metrics.items():
317
+ tag = f"train_epoch/{key}"
318
  for i, val in enumerate(vals):
319
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
320
 
321
  for metric_name, value in eval_metrics.items():
322
+ summary_writer.scalar(f"eval/{metric_name}", value, step)
323
 
324
 
325
  def create_learning_rate_fn(
 
725
 
726
  train_time = 0
727
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
728
+ global_step = 0
729
  for epoch in epochs:
730
  # ======================== Training ================================
731
  train_start = time.time()
 
738
  train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
739
  steps_per_epoch = len(train_dataset) // train_batch_size
740
  # train
741
+ for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
742
+ global_step +=1
743
  batch = next(train_loader)
744
  state, train_metric = p_train_step(state, batch)
745
  train_metrics.append(train_metric)
746
 
747
+ if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
748
+ for k, v in unreplicate(train_metric).items():
749
+ wandb.log(f{'train/{k}': jax.device_get(v), step=global_step)
750
+
751
  train_time += time.time() - train_start
752
 
753
  train_metric = unreplicate(train_metric)