Spaces:
Runtime error
Runtime error
feat: log metrics more frequently
Browse files- seq2seq/run_seq2seq_flax.py +16 -3
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -215,6 +215,13 @@ class DataTrainingArguments:
|
|
215 |
overwrite_cache: bool = field(
|
216 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
217 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
|
219 |
def __post_init__(self):
|
220 |
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
@@ -307,12 +314,12 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
|
307 |
|
308 |
train_metrics = get_metrics(train_metrics)
|
309 |
for key, vals in train_metrics.items():
|
310 |
-
tag = f"
|
311 |
for i, val in enumerate(vals):
|
312 |
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
313 |
|
314 |
for metric_name, value in eval_metrics.items():
|
315 |
-
summary_writer.scalar(f"
|
316 |
|
317 |
|
318 |
def create_learning_rate_fn(
|
@@ -718,6 +725,7 @@ def main():
|
|
718 |
|
719 |
train_time = 0
|
720 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
|
|
721 |
for epoch in epochs:
|
722 |
# ======================== Training ================================
|
723 |
train_start = time.time()
|
@@ -730,11 +738,16 @@ def main():
|
|
730 |
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
731 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
732 |
# train
|
733 |
-
for
|
|
|
734 |
batch = next(train_loader)
|
735 |
state, train_metric = p_train_step(state, batch)
|
736 |
train_metrics.append(train_metric)
|
737 |
|
|
|
|
|
|
|
|
|
738 |
train_time += time.time() - train_start
|
739 |
|
740 |
train_metric = unreplicate(train_metric)
|
|
|
215 |
overwrite_cache: bool = field(
|
216 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
217 |
)
|
218 |
+
log_interval: Optional[int] = field(
|
219 |
+
default=500,
|
220 |
+
metadata={
|
221 |
+
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
222 |
+
"value if set."
|
223 |
+
},
|
224 |
+
)
|
225 |
|
226 |
def __post_init__(self):
|
227 |
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
|
|
314 |
|
315 |
train_metrics = get_metrics(train_metrics)
|
316 |
for key, vals in train_metrics.items():
|
317 |
+
tag = f"train_epoch/{key}"
|
318 |
for i, val in enumerate(vals):
|
319 |
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
320 |
|
321 |
for metric_name, value in eval_metrics.items():
|
322 |
+
summary_writer.scalar(f"eval/{metric_name}", value, step)
|
323 |
|
324 |
|
325 |
def create_learning_rate_fn(
|
|
|
725 |
|
726 |
train_time = 0
|
727 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
728 |
+
global_step = 0
|
729 |
for epoch in epochs:
|
730 |
# ======================== Training ================================
|
731 |
train_start = time.time()
|
|
|
738 |
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
739 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
740 |
# train
|
741 |
+
for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
742 |
+
global_step +=1
|
743 |
batch = next(train_loader)
|
744 |
state, train_metric = p_train_step(state, batch)
|
745 |
train_metrics.append(train_metric)
|
746 |
|
747 |
+
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
748 |
+
for k, v in unreplicate(train_metric).items():
|
749 |
+
wandb.log(f{'train/{k}': jax.device_get(v), step=global_step)
|
750 |
+
|
751 |
train_time += time.time() - train_start
|
752 |
|
753 |
train_metric = unreplicate(train_metric)
|