File size: 7,703 Bytes
285b9d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TrainingTask encapsulates the state associated with model step."""

import time
from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Tuple)

from absl import logging
from clu import metric_writers
from flax import optim
from flax import struct
import jax
import  metrics_summary
import numpy as np


@struct.dataclass
class TrainState:
  """Model state that is passed into, and returned from, a step function.

  Declared as a flax struct dataclass.  It bundles the optimizer with any
  additional state that the model carries from one step to the next.
  """
  optimizer: optim.Optimizer   # Trainable parameters.
  state: Any                   # Other state, e.g. XL cache or memory.


# Random number generator keys; TrainingTask stores them and hands them to
# the step function as its third argument.
PRNGKeys = Any
# Dictionary of metric values, keyed by string (e.g. "step_time", "epoch").
Metrics = Dict[str, Any]
# Shorthand for the summary accumulator class from metrics_summary.
MetricsSummary = metrics_summary.MetricsSummary

# Zero-argument callable that returns an iterator over model inputs.
Dataset = Callable[[], Iterator[Any]]
# Called as step_function(tstate, x, prng_keys) -> (new tstate, metrics).
StepFunction = Callable[[TrainState, Any, Any], Tuple[TrainState, Metrics]]
# Optional callable that turns a model input into a string for logging.
PrettyPrintInputFunction = Optional[Callable[[Any], str]]
# Optional callable invoked as fn(summary, mode); its result replaces the
# task's summary before the summary is written.
ProcessSummariesFunction = Optional[Callable[[Any, str], Any]]
# Optional callable invoked as fn(mode, step); returns a mapping of extra
# summary values keyed by string.
ExtraSummariesFunction = Optional[Callable[[str, int], Mapping[str, Any]]]


def should_run(step: int, every_steps: int) -> bool:
  """Tells whether a periodic action is due at the given step.

  Args:
    step: The current step number.
    every_steps: The period of the action, in steps.

  Returns:
    True only if both step and every_steps are positive, and step is a
    multiple of every_steps.
  """
  # A non-positive step or period means the action never fires; this also
  # guards the modulo below against a zero divisor.
  if not (step > 0 and every_steps > 0):
    return False
  remainder = step % every_steps
  return remainder == 0


class TrainingTask:
  """A TrainingTask encapsulates the state associated with a training task.

  Examples of tasks include training steps, test or validation runs,
  or inference (generation).  State includes the input pipeline, and
  summary information that is averaged over multiple steps.

  Attributes:
    mode: Name of the task; used in log messages, as the profiler trace
      annotation name, and passed to the summary callbacks.
    dataset: Callable that creates a new input iterator, or None.
    step_function: Function that is run on each input by run_step.
    prng_keys: Random number generator keys passed to step_function.
    summary: Accumulates the metrics returned by each call to run_step.
    extra_summary: Accumulates the output of extra_summaries_fn.
    summary_writer: Destination for summaries, or None to discard them.
    summary_prefix: Prefix passed to summary.write().
    replicate_mode: If True, inputs are split among the local devices, and
      metrics are read from the first replica.
    print_input_every_steps: Period, in steps, for pretty-printing the input.
    pretty_print_input_fn: Optional function that renders an input as text.
    process_summaries_fn: Optional post-processing hook, applied in flush().
    extra_summaries_fn: Optional source of additional summaries for flush().
    ds_iterator: The current input iterator, or None if there is no dataset.
    epoch: The number of times that the input iterator has been restarted.
  """

  def __init__(
      self,
      *,  # Pass arguments by keyword only.
      mode: str,
      dataset: Dataset,
      step_function: StepFunction,
      prng_keys: PRNGKeys,
      summary: MetricsSummary,
      extra_summary: MetricsSummary,
      summary_writer: metric_writers.MetricWriter,
      summary_prefix: str = "",
      # --- Options from TrainingLoop ---
      replicate_mode: bool = True,
      print_input_every_steps: int = 0,
      pretty_print_input_function: PrettyPrintInputFunction = None,
      process_summaries_function: ProcessSummariesFunction = None,
      extra_summaries_function: Optional[ExtraSummariesFunction] = None):
    """Creates a task; if a dataset is given, its iterator is created here."""
    # Local state.
    self.mode = mode
    self.dataset = dataset
    self.step_function = step_function
    self.prng_keys = prng_keys
    self.summary = summary
    self.extra_summary = extra_summary
    self.summary_writer = summary_writer
    self.summary_prefix = summary_prefix

    # Options carried over from TrainingLoop.
    self.replicate_mode = replicate_mode
    self.print_input_every_steps = print_input_every_steps
    self.pretty_print_input_fn = pretty_print_input_function
    self.process_summaries_fn = process_summaries_function
    self.extra_summaries_fn = extra_summaries_function

    # Input pipeline state.  ds_iterator is always defined, so that a task
    # without a dataset exposes None instead of a missing attribute.
    self.ds_iterator = self.dataset() if self.dataset is not None else None
    self.epoch = 0

  def _get_metrics(self, device_metrics: Metrics) -> Metrics:
    """Read a dictionary of metrics from device."""
    if self.replicate_mode:
      # x[0] gets the metric from device 0 -- the first replica.
      # We assume that merge_replicated_metrics has already combined the
      # metrics from multiple devices.
      # jax.tree_util.tree_map is used rather than the jax.tree_map alias,
      # which is deprecated and absent from newer JAX releases.
      device_metrics = jax.tree_util.tree_map(lambda x: x[0], device_metrics)
    metrics_np = jax.device_get(device_metrics)  # Get numpy arrays.
    return metrics_np

  def get_next_input(self) -> Any:
    """Grab the next input from the data pipeline.

    When the current iterator is exhausted, a new one is created from the
    dataset and the epoch counter is incremented.  If the new iterator is
    also empty, the resulting StopIteration propagates to the caller.

    Returns:
      The next input, or None (after logging a warning) if the task has no
      dataset.
    """
    if self.dataset is None:
      logging.warning("No dataset for mode %s", self.mode)
      return None

    try:
      x = next(self.ds_iterator)
    except StopIteration:
      logging.info("End of epoch %d for mode %s.", self.epoch, self.mode)
      self.ds_iterator = self.dataset()
      x = next(self.ds_iterator)
      # Only count the new epoch once it has actually produced an input.
      self.epoch += 1
    return x

  def run_step(self, tstate: TrainState, x: Any,
               step: int, sub_step: int = 0) -> Tuple[TrainState, Metrics]:
    """Run the model for a single step.

    Args:
      tstate: The current model state.
      x: The input for the model -- from get_next_input.
      step: The training step number.
      sub_step: For tasks that run multiple iterations within a step.
        E.g. A test cycle will call run_step multiple times to cover the test
        set.  The step counter will not increment, but sub_step will.

    Returns:
      A tuple (tstate, metrics): the model state returned by the step
      function, and the metrics for this step after being read from device.
      The metrics also contain "step_time", and "epoch" if the step function
      did not supply one.  The same metrics are added to self.summary.
    """

    start_time = time.perf_counter()

    # Split a batch of inputs among local replicas.
    if self.replicate_mode:
      x = split_batch_dimension(x, jax.local_device_count())

    # Pretty-print the input to the summary and log file every so often.
    if (sub_step == 0 and self.pretty_print_input_fn is not None and
        should_run(step, self.print_input_every_steps)):
      if self.replicate_mode:
        # Only print the portion of the input that goes to the first replica.
        x_first = jax.tree_util.tree_map(lambda leaf: leaf[0], x)
      else:
        x_first = x
      x_strs = self.pretty_print_input_fn(x_first)
      logging.info("[%d] Input (%s) = %s", step, self.mode, x_strs)
      self.summary.add_text({"input": x_strs})

    # Run the step function on the input.
    with jax.profiler.StepTraceAnnotation(self.mode, step_num=step):
      (tstate, metrics) = self.step_function(tstate, x, self.prng_keys)

    # Read metrics from device.  The timer is stopped only after the metrics
    # have been fetched, so step_time covers the input split, the step
    # function, and the transfer of metrics.
    metrics_np = self._get_metrics(metrics)
    end_time = time.perf_counter()
    metrics_np["step_time"] = end_time - start_time
    if "epoch" not in metrics_np:
      metrics_np["epoch"] = self.epoch

    # Add metrics to the current summary.
    self.summary.add(metrics_np)
    return (tstate, metrics_np)

  def flush(self, step: int):
    """Flush accumulated metric summaries to disk.

    Nothing is written if there is no summary writer (the summary is cleared
    instead), or if the summary is empty.  Extra summaries are written only
    when the main summary is written.

    Args:
      step: The step number under which the summaries are written.
    """

    if self.summary_writer is None:
      self.summary.clear()  # Clear summary if we can't write it.
      return

    if self.summary.empty():
      return

    # Do post-processing of the summaries.
    if self.process_summaries_fn is not None:
      self.summary = self.process_summaries_fn(self.summary, self.mode)  # pylint: disable=not-callable

    # Write the summary data.
    # NOTE(review): write() is presumably expected to clear the summary as
    # well, since it is not cleared here -- confirm in MetricsSummary.
    logging.info("Writing summaries for mode %s.", self.mode)
    self.summary.write(self.summary_writer, step, prefix=self.summary_prefix)

    # Add extra summaries that are not computed by the step function.
    # These are written without the summary prefix.
    if self.extra_summaries_fn is not None:
      self.extra_summary.add(self.extra_summaries_fn(self.mode, step))
      self.extra_summary.write(self.summary_writer, step, prefix="")


def split_batch_dimension(inputs: Any, num_replicas: int) -> Any:
  """Splits the leading batch dimension.

  Given inputs of shape [num_replicas * batch_size, ...], it will reshape
  them to [num_replicas, batch_size, ...].  This operation is intended to be
  used right before calling pmap, which will eliminate the num_replicas
  dimension.

  Args:
    inputs: A pytree (e.g. a tuple or dictionary) of arrays to split.  Every
      leaf must have at least one dimension.
    num_replicas: Number of replicas.

  Returns:
    A pytree with the same structure as inputs, in which each array has an
    extra leading dimension of size num_replicas.

  Raises:
    ValueError: If a leaf has no dimensions, or if its leading dimension is
      not divisible by num_replicas.
  """

  def split_batch_dim(x):
    # Raise explicitly rather than assert, so that the check is still
    # performed when Python runs with assertions disabled (-O).
    if x.ndim < 1:
      raise ValueError(
          f"Can't split a scalar of shape {x.shape} into {num_replicas} "
          "replicas.")
    if (x.shape[0] % num_replicas) != 0:
      raise ValueError(f"Can't split {x.shape} into {num_replicas} replicas.")
    batch_size = x.shape[0] // num_replicas
    split_shape = [num_replicas, batch_size] + list(x.shape[1:])
    return np.reshape(x, split_shape)

  # jax.tree_util.tree_map is used rather than the jax.tree_map alias, which
  # is deprecated and absent from newer JAX releases.
  return jax.tree_util.tree_map(split_batch_dim, inputs)