Spaces:
Sleeping
Sleeping
File size: 7,703 Bytes
285b9d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TrainingTask encapsulates the state associated with model step."""
import time
from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Tuple)
from absl import logging
from clu import metric_writers
from flax import optim
from flax import struct
import jax
import metrics_summary
import numpy as np
@struct.dataclass
class TrainState:
  """Model state threaded through the step function: parameters + extras."""
  optimizer: optim.Optimizer  # Trainable parameters (wrapped in a flax optimizer).
  state: Any  # Other state, e.g. XL cache or memory.
# Type aliases used throughout this module.
PRNGKeys = Any  # PRNG key(s); passed as the third argument to StepFunction.
Metrics = Dict[str, Any]  # Raw metrics as returned by a step function.
MetricsSummary = metrics_summary.MetricsSummary
# Zero-argument factory that returns a fresh iterator over inputs.
Dataset = Callable[[], Iterator[Any]]
# Maps (train_state, input, prng_keys) -> (new_train_state, metrics).
StepFunction = Callable[[TrainState, Any, Any], Tuple[TrainState, Metrics]]
PrettyPrintInputFunction = Optional[Callable[[Any], str]]
ProcessSummariesFunction = Optional[Callable[[Any, str], Any]]
ExtraSummariesFunction = Optional[Callable[[str, int], Mapping[str, Any]]]
def should_run(step: int, every_steps: int) -> bool:
  """Returns true if a periodic action should be run."""
  # Guard clauses: periodic actions are disabled at step 0 and when the
  # period itself is non-positive.
  if step <= 0 or every_steps <= 0:
    return False
  return step % every_steps == 0
class TrainingTask:
  """A TrainingTask encapsulates the state associated with a training task.

  Examples of tasks include training steps, test or validation runs,
  or inference (generation). State includes the input pipeline, and
  summary information that is averaged over multiple steps.
  """

  def __init__(
      self,
      *,  # Pass arguments by keyword only.
      mode: str,
      dataset: Dataset,
      step_function: StepFunction,
      prng_keys: PRNGKeys,
      summary: MetricsSummary,
      extra_summary: MetricsSummary,
      summary_writer: metric_writers.MetricWriter,
      summary_prefix: str = "",
      # --- Options from TrainingLoop ---
      replicate_mode: bool = True,
      print_input_every_steps: int = 0,
      pretty_print_input_function: PrettyPrintInputFunction = None,
      process_summaries_function: ProcessSummariesFunction = None,
      extra_summaries_function: Optional[ExtraSummariesFunction] = None):
    """Creates a TrainingTask.

    Args:
      mode: Name of this task, e.g. "train" or "test"; used in log messages
        and profiler annotations.
      dataset: Zero-arg factory returning a fresh input iterator, or None.
      step_function: Runs the model for one step; maps
        (tstate, x, prng_keys) -> (tstate, metrics).
      prng_keys: PRNG keys passed to step_function on every step.
      summary: Accumulates per-step metrics until flush() is called.
      extra_summary: Accumulates metrics from extra_summaries_function.
      summary_writer: CLU metric writer used by flush(); may be None, in
        which case summaries are discarded.
      summary_prefix: Prefix for summary names when writing.
      replicate_mode: If True, inputs are split across local devices before
        each step and metrics are read back from replica 0.
      print_input_every_steps: If > 0, pretty-print the input to the log and
        summary every this-many steps.
      pretty_print_input_function: Optional function that renders an input
        batch as a string.
      process_summaries_function: Optional post-processing applied to the
        summary before it is written.
      extra_summaries_function: Optional function producing additional
        summaries that are not computed by step_function.
    """
    # Task configuration.
    self.mode = mode
    self.dataset = dataset
    self.step_function = step_function
    self.prng_keys = prng_keys
    self.summary = summary
    self.extra_summary = extra_summary
    self.summary_writer = summary_writer
    self.summary_prefix = summary_prefix

    # Options carried over from TrainingLoop.
    self.replicate_mode = replicate_mode
    self.print_input_every_steps = print_input_every_steps
    self.pretty_print_input_fn = pretty_print_input_function
    self.process_summaries_fn = process_summaries_function
    self.extra_summaries_fn = extra_summaries_function

    # Local state.
    if self.dataset is not None:
      self.ds_iterator = self.dataset()
    self.epoch = 0

  def _get_metrics(self, device_metrics: Metrics) -> Metrics:
    """Read a dictionary of metrics from device."""
    if self.replicate_mode:
      # v[0] gets the metric from device 0 -- the first replica.
      # We assume that merge_replicated_metrics has already combined the
      # metrics from multiple devices.
      device_metrics = jax.tree_map(lambda v: v[0], device_metrics)
    metrics_np = jax.device_get(device_metrics)  # Get numpy arrays.
    return metrics_np

  def get_next_input(self) -> Any:
    """Grab the next input from the data pipeline."""
    if self.dataset is None:
      logging.warning("No dataset for mode %s", self.mode)
      return None
    try:
      x = next(self.ds_iterator)
    except StopIteration:
      # Restart the pipeline from a fresh iterator at the end of each epoch.
      logging.info("End of epoch %d for mode %s.", self.epoch, self.mode)
      self.ds_iterator = self.dataset()
      x = next(self.ds_iterator)
      self.epoch += 1
    return x

  def run_step(self, tstate: TrainState, x: Any,
               step: int, sub_step: int = 0) -> Tuple[TrainState, Metrics]:
    """Run the model for a single step.

    Args:
      tstate: The current model state.
      x: The input for the model -- from get_next_input.
      step: The training step number.
      sub_step: For tasks that run multiple iterations within a step.
        E.g. A test cycle will call run_step multiple times to cover the test
        set. The step counter will not increment, but sub_step will.

    Returns:
      An updated model state.
    """
    start_time = time.perf_counter()

    # Split a batch of inputs among local replicas.
    if self.replicate_mode:
      x = split_batch_dimension(x, jax.local_device_count())

    # Pretty-print the input to the summary and log file every so often.
    if (sub_step == 0 and self.pretty_print_input_fn is not None and
        should_run(step, self.print_input_every_steps)):
      # Use a distinct lambda parameter name so the batch `x` isn't shadowed.
      x_first = jax.tree_map(lambda v: v[0], x) if self.replicate_mode else x
      x_strs = self.pretty_print_input_fn(x_first)
      logging.info("[%d] Input (%s) = %s", step, self.mode, x_strs)
      self.summary.add_text({"input": x_strs})

    # Run the step function on the input.
    with jax.profiler.StepTraceAnnotation(self.mode, step_num=step):
      (tstate, metrics) = self.step_function(tstate, x, self.prng_keys)

    # Read metrics from device.
    metrics_np = self._get_metrics(metrics)
    end_time = time.perf_counter()
    metrics_np["step_time"] = end_time - start_time
    # Membership test on the dict directly -- no need for .keys().
    if "epoch" not in metrics_np:
      metrics_np["epoch"] = self.epoch

    # Add metrics to the current summary.
    self.summary.add(metrics_np)
    return (tstate, metrics_np)

  def flush(self, step: int):
    """Flush accumulated metric summaries to disk."""
    if self.summary_writer is None:
      self.summary.clear()  # Clear summary if we can't write it.
      return
    if self.summary.empty():
      return

    # Do post-processing of the summaries.
    if self.process_summaries_fn is not None:
      self.summary = self.process_summaries_fn(self.summary, self.mode)  # pylint: disable=not-callable

    # Write and clear summary data.
    logging.info("Writing summaries for mode %s.", self.mode)
    self.summary.write(self.summary_writer, step, prefix=self.summary_prefix)

    # Add extra summaries that are not computed by the step function.
    if self.extra_summaries_fn is not None:
      self.extra_summary.add(self.extra_summaries_fn(self.mode, step))
      self.extra_summary.write(self.summary_writer, step, prefix="")
def split_batch_dimension(inputs: Any, num_replicas: int) -> Any:
"""Splits the leading batch dimension.
Given inputs of shape [num_replicas * batch_size, ...], it will reshape
them to [num_replicas, batch_size, ...]. This operation is intended to be
used right before calling pmap, which will eliminate the num_replicas
dimension.
Args:
inputs: Tuple of inputs to split.
num_replicas: Number of replicas.
Returns:
inputs with extra batch dimension.
"""
def split_batch_dim(x):
assert x.ndim > 0
if (x.shape[0] % num_replicas) != 0:
raise ValueError(f"Can't split {x.shape} into {num_replicas} replicas.")
batch_size = x.shape[0] // num_replicas
split_shape = [num_replicas, batch_size] + list(x.shape[1:])
return np.reshape(x, split_shape)
return jax.tree_map(split_batch_dim, inputs)
|