File size: 9,974 Bytes
285b9d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Class to handle summarizing of metrics over multiple training steps."""

import abc
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from absl import logging
from clu import metric_writers
import gin
import jax
from jax import numpy as jnp
import numpy as np


Array = Union[jnp.ndarray, np.ndarray]


class Aggregator(abc.ABC):
  """Abstract interface shared by all metric aggregators.

  Serves as the common superclass for isinstance / type checks.
  """

  @abc.abstractmethod
  def add(self, value: Any):
    """Folds `value` into the aggregator's state."""

  @abc.abstractmethod
  def is_valid(self) -> bool:
    """Reports whether `to_value` may be called."""

  @abc.abstractmethod
  def to_value(self):
    """Returns the aggregated result."""


class _MeanAggregator(Aggregator):
  """Tracks a running weighted mean of the values it receives."""
  # Class-level defaults; the first successful `add` rebinds them per instance.
  mean: float = 0.0
  weight: float = 0.0

  def add(self, new_value: Any):
    """Folds a new sample into the running mean.

    Args:
      new_value: Either a rank-0 value (taken with weight 1.0), or anything
        else, which is unpacked as a `(value, weight)` pair.

    Raises:
      ValueError: If the weight is negative.
    """
    # Rank-0 covers int, float and 0-d arrays; those carry an implicit weight.
    is_scalar = np.ndim(new_value) == 0
    sample, sample_weight = (new_value, 1.0) if is_scalar else new_value
    if sample_weight < 0.0:
      raise ValueError("Adding value with negative weight.")
    combined_weight = self.weight + sample_weight
    # Zero-weight samples leave the state untouched.
    if combined_weight == 0.0 or not sample_weight > 0.0:
      return
    step = (sample - self.mean) * sample_weight / combined_weight
    self.mean += step
    self.weight = combined_weight

  def is_valid(self) -> bool:
    """Reports whether any positive weight has been accumulated."""
    return self.weight > 0.0

  def to_value(self):
    """Returns the current mean; asserts that some weight was accumulated."""
    assert self.weight > 0.0
    return self.mean


class _SumAggregator(_MeanAggregator):
  """Reports the weighted sum of the values it receives.

  The state is the one inherited from _MeanAggregator, i.e. the pair
  (weighted mean, total weight). The sum is recovered by multiplying the two.
  """

  def is_valid(self) -> bool:
    """Always True; no accumulated weight is required."""
    return True

  def to_value(self):
    """Returns mean * total weight."""
    running_mean, total_weight = self.mean, self.weight
    return running_mean * total_weight


class _LastAggregator(Aggregator):
  """Keeps only the most recently added value."""
  # None means "nothing stored yet".
  last_value: Optional[float] = None

  def add(self, new_value: Any):
    """Overwrites the stored value with `new_value`."""
    self.last_value = new_value

  def is_valid(self) -> bool:
    """Reports whether a non-None value is currently stored."""
    return not (self.last_value is None)

  def to_value(self):
    """Returns the stored value; asserts that one exists."""
    current = self.last_value
    assert current is not None
    return current


@gin.configurable
class MetricsSummary:
  """Summarizes a set of metrics over multiple training steps.

  Values are accumulated per metric name with `add` / `add_text` and flushed
  with `write`, which sorts them into scalars, images, histograms and text
  before handing them to a metric writer.
  """

  def __init__(self,
               metric_types: Mapping[str, str],
               upscale_images: bool = True,
               remove_outliers: bool = False):
    """Creates a MetricsSummary.

    Args:
      metric_types: Map from metrics to the type of summary.  Types are:
         "mean" = Compute the cumulative moving average.
         "sum" =  Compute the sum.
         "last" = No summary, just return the last value.
      upscale_images: Upscale small images for easier viewing.
      remove_outliers: Remove outliers from histograms.
    """
    self.metric_dict = {}  # type: Dict[str, Aggregator]
    self.text_dict = {}
    self.metric_types = metric_types
    self.upscale_images = upscale_images
    self.remove_outliers = remove_outliers
    # Maps a metric type name to the aggregator class implementing it.
    self.constructor_map = {
        "mean": _MeanAggregator,
        "sum": _SumAggregator,
        "last": _LastAggregator,
    }
    logging.debug("Registered metrics: %r", metric_types)

  def current_metric_dict(self) -> Mapping[str, Aggregator]:
    """Returns the live mapping from metric name to its aggregator."""
    return self.metric_dict

  def _is_image(self, image: Array) -> bool:
    """Returns True if `image` has rank 4 and 1 or 3 trailing channels."""
    # np.ndim / np.shape also accept plain Python sequences, which do not
    # have .ndim / .shape attributes.
    if np.ndim(image) != 4:
      return False
    # Greyscale or RGB image.
    return np.shape(image)[-1] in (1, 3)

  def _upscale_image(self, image: Array) -> Array:
    """Upscale small images to more pixels, for easier viewing.

    Args:
      image: Array of shape (num_images, ysize, xsize, num_channels).

    Returns:
      `image` unchanged if upscaling is disabled or either spatial dimension
      exceeds 512; otherwise the image with every pixel repeated 2x (either
      dimension above 256) or 4x along both spatial axes.
    """
    if not self.upscale_images:
      return image
    assert image.ndim == 4  # (num_images, ysize, xsize, num_channels)
    ys = image.shape[1]
    xs = image.shape[2]
    if xs > 512 or ys > 512:
      return image   # No scaling.
    elif xs > 256 or ys > 256:
      scale = 2
    else:
      scale = 4
    # Integer division maps each output index back to its source pixel, so
    # every source pixel is repeated `scale` times along each axis.
    yidx = np.arange(ys * scale) // scale
    xidx = np.arange(xs * scale) // scale
    scaled_image = image[:, yidx, :, :][:, :, xidx, :]
    return scaled_image

  def _remove_outliers(self, v, std_range: float = 4):
    """Replaces values far from the mean with the mean.

    Args:
      v: Values to filter.
      std_range: Values further than this many standard deviations from the
        mean are treated as outliers.

    Returns:
      `v` unchanged if outlier removal is disabled; otherwise an array in
      which every outlier has been replaced by the mean of `v`.
    """
    if not self.remove_outliers:
      return v
    v_mean = np.mean(v)
    v_std = np.std(v)
    # Measure distance from the mean, not from zero: otherwise any
    # distribution that is not centred on zero collapses to its mean.
    deviation = np.abs(np.asarray(v) - v_mean)
    return np.where(deviation > (v_std * std_range), v_mean, v)

  @staticmethod
  def merge_replicated_metrics(device_metrics: Mapping[str, Any],
                               metric_types: Mapping[str, str]):
    """Merge metrics across devices by psum over "batch" axis.

    Uses jax.lax.psum with axis_name="batch"; presumably this must be called
    inside a mapped computation that binds that axis name (TODO confirm
    against callers).

    Args:
      device_metrics: dictionary of device metrics.
      metric_types: map from the metric name to { "mean", "sum" }

    Returns:
      A dictionary of metrics. "sum" metrics are summed across devices.
      "mean" metrics become a (weighted mean, total weight) tuple. Metrics
      without an entry in `metric_types` are passed through untouched.

    Raises:
      ValueError: If a metric has a type other than "sum" or "mean".
    """
    logging.info("Merging metrics across devices %r: ",
                 [(k, metric_types[k] if k in metric_types else None)
                  for k in device_metrics.keys()])

    def aggregate_sum(value: Array) -> Array:
      assert not isinstance(value, tuple), (
          "Weighted sums are not supported when aggregating over devices.")
      return jax.lax.psum(value, axis_name="batch")

    def aggregate_mean(value: Array, weight: Array) -> Tuple[Array, Array]:
      weighted_value = value * weight
      weighted_value = jax.lax.psum(weighted_value, axis_name="batch")
      weight = jax.lax.psum(weight, axis_name="batch")
      # The small epsilon guards against division by a zero total weight.
      return weighted_value / (weight + 1.0e-6), weight

    aggregated_metrics = dict(device_metrics)
    # Iterate over the input mapping so the copy being written to is never
    # the object being iterated.
    for k, value in device_metrics.items():
      if k not in metric_types:
        # If no metric type is given, metric remains untouched.
        continue
      metric_type = metric_types[k]
      if metric_type == "sum":
        aggregated_metrics[k] = aggregate_sum(value)
      elif metric_type == "mean":
        if not isinstance(value, tuple):
          logging.info("Metric '%s' has no weight; assuming 1.0.", k)
          value = (value, jnp.array(1.0))
        aggregated_metrics[k] = aggregate_mean(*value)
      else:
        raise ValueError("Can only aggregate 'sum' and 'mean' over devices. "
                         f"Got {metric_type}.")
    return aggregated_metrics

  def _new_aggregator(self, key) -> Aggregator:
    """Creates the aggregator registered for `key`, defaulting to "last"."""
    if key in self.metric_types:
      return self.constructor_map[self.metric_types[key]]()
    else:
      # TODO(mrabe): The default to last_value is not obvious. Force all metric
      # types to be given explicitly.
      logging.debug("No metric type for accumulator: %s", key)
      return _LastAggregator()

  def add(self, metrics: Mapping[str, Any]):
    """Add metrics from the current training step to the summary.

    Args:
      metrics: Dictionary of metrics.
    """
    for k, new_value in metrics.items():
      if k not in self.metric_dict:
        self.metric_dict[k] = self._new_aggregator(k)
      self.metric_dict[k].add(new_value)

  def add_text(self, text_metrics: Mapping[str, str]):
    """Add text metrics from the current step to the summary."""
    # Later values for the same key overwrite earlier ones.
    for (k, v) in text_metrics.items():
      self.text_dict[k] = str(v)

  def empty(self):
    """Return true if there are no summaries to write."""
    return not (self.metric_dict or self.text_dict)

  def clear(self):
    """Clear accumulated summaries."""
    self.metric_dict = {}
    self.text_dict = {}

  def write(self, writer: metric_writers.MetricWriter, step: int, prefix: str):
    """Write metrics using summary_writer, and clear all summaries.

    Args:
      writer: Metric writer that receives the scalars, images, histograms
        and texts.
      step: Step number to associate with the written values.
      prefix: Group name prepended to every key; may be empty.

    Raises:
      ValueError: If `metric_dict` holds a non-Aggregator object, or if an
        aggregator reports that it has no valid value.
    """
    if self.empty():
      return

    # Special logic for organizing metrics under tensorboard.
    # Tensorboard has top-level groups, but doesn't have subgroups.
    # Scalars are put into separate top-level groups for easier viewing.
    # e.g. all scalars in "train", "test", etc.
    # For images, each set of images should be a different top-level group,
    # otherwise all images will get tossed into a single group under,
    # e.g. "generate".
    if prefix:
      s_prefix = prefix + "/"
      i_prefix = prefix + "_"
    else:
      # Each prefix is stored in a separate subdirectory already.
      s_prefix = ""
      i_prefix = ""

    # Split metrics into different types.
    scalars = {}
    images = {}
    histograms = {}

    # Sort metrics into scalars, images, and histograms.
    for k, aggregator in self.metric_dict.items():
      if not isinstance(aggregator, Aggregator):
        raise ValueError("Internal error: metric_dict should contain only "
                         "_Aggregator objects; contained %s" % aggregator)
      if not aggregator.is_valid():
        raise ValueError(f"No valid value for metric {k}.")

      v = aggregator.to_value()
      if v is None:
        # Must be checked before np.isfinite, which raises TypeError on None.
        logging.warning("Invalid value for %s", k)
        continue

      s_key = s_prefix + k
      i_key = i_prefix + k

      # Non-finite entries are zeroed rather than written out.
      finite_mask = np.isfinite(v)
      if not np.all(finite_mask):
        logging.warning("Item %s contains non-finite elements.", k)
        v = np.where(finite_mask, v, np.zeros_like(v))

      if np.ndim(v) == 0:
        scalars[s_key] = v
      elif self._is_image(v):
        images[i_key] = self._upscale_image(v)
      else:
        histograms[s_key] = self._remove_outliers(v)

    # Handle text data.
    text_dict = {s_prefix + k: v for (k, v) in self.text_dict.items()}

    # Write metrics.
    if scalars:
      writer.write_scalars(step, scalars)
    if images:
      writer.write_images(step, images)
    if histograms:
      writer.write_histograms(step, histograms)
    if text_dict:
      writer.write_texts(step, text_dict)

    # Clear accumulated summaries.
    self.clear()