Spaces:
Sleeping
Sleeping
File size: 9,974 Bytes
285b9d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 |
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class to handle summarizing of metrics over multiple training steps."""
import abc
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from absl import logging
from clu import metric_writers
import gin
import jax
from jax import numpy as jnp
import numpy as np
Array = Union[jnp.ndarray, np.ndarray]
class Aggregator(abc.ABC):  # Superclass for type checks
  """Interface for accumulating a stream of metric values over steps."""

  @abc.abstractmethod
  def add(self, value: Any):
    """Folds a single new value into the aggregate."""

  @abc.abstractmethod
  def is_valid(self) -> bool:
    """Returns True if the aggregator holds a reportable value."""

  @abc.abstractmethod
  def to_value(self):
    """Returns the aggregated value; only call when is_valid() is True."""
class _MeanAggregator(Aggregator):
  """Maintains a cumulative weighted average of incoming values."""

  # Running weighted mean and the total weight seen so far.
  mean: float = 0.0
  weight: float = 0.0

  def add(self, new_value: Any):
    """Folds `new_value` -- a scalar or a (value, weight) pair -- into the mean."""
    if np.ndim(new_value) != 0:
      val, w = new_value
    else:
      # Scalars (int, float, 0-d arrays) get an implicit weight of 1.
      val, w = new_value, 1.0
    if w < 0.0:
      raise ValueError("Adding value with negative weight.")
    combined = self.weight + w
    if combined != 0.0 and w > 0.0:
      # Incremental update: shift the mean toward val by w's share of the
      # combined weight.  (Operation order matches the running-mean formula.)
      self.mean += (val - self.mean) * w / combined
      self.weight = combined

  def is_valid(self) -> bool:
    return self.weight > 0.0

  def to_value(self):
    assert self.weight > 0.0
    return self.mean
class _SumAggregator(_MeanAggregator):
  """Maintains a cumulative (weighted) sum of incoming values."""

  # Accumulation is identical to _MeanAggregator: we store the pair
  # (weighted mean, total weight), and the sum is simply their product.
  # A sum is well-defined even before any value arrives (it is zero).

  def is_valid(self) -> bool:
    return True

  def to_value(self):
    return self.mean * self.weight
class _LastAggregator(Aggregator):
  """Keeps only the most recently added value; earlier values are discarded."""

  # None until the first add() call; doubles as the "no value yet" marker.
  last_value: Optional[float] = None

  def add(self, new_value: Any):
    self.last_value = new_value

  def is_valid(self) -> bool:
    return self.last_value is not None

  def to_value(self):
    assert self.last_value is not None
    return self.last_value
@gin.configurable
class MetricsSummary:
  """Summarizes a set of a metrics over multiple training steps."""

  def __init__(self,
               metric_types: Mapping[str, str],
               upscale_images: bool = True,
               remove_outliers: bool = False):
    """Creates a MetricsSummary.

    Args:
      metric_types: Map from metrics to the type of summary. Types are:
        "mean" = Compute the cumulative moving average.
        "sum" = Compute the sum.
        "last" = No summary, just return the last value.
      upscale_images: Upscale small images for easier viewing.
      remove_outliers: Remove outliers from histograms.
    """
    self.metric_dict = {}  # type: Dict[str, Aggregator]
    self.text_dict = {}
    self.metric_types = metric_types
    self.upscale_images = upscale_images
    self.remove_outliers = remove_outliers
    # Maps each metric-type name to the Aggregator subclass implementing it.
    self.constructor_map = {
        "mean": _MeanAggregator,
        "sum": _SumAggregator,
        "last": _LastAggregator,
    }
    logging.debug("Registered metrics: %r", metric_types)

  def current_metric_dict(self) -> Mapping[str, Aggregator]:
    """Returns the map from metric name to its current aggregator."""
    return self.metric_dict

  def _is_image(self, image: Array) -> bool:
    """Returns True for batched images of shape (N, height, width, 1 or 3)."""
    if image.ndim != 4:
      return False
    # Greyscale or RGB image.
    return image.shape[-1] == 1 or image.shape[-1] == 3

  def _upscale_image(self, image: Array) -> Array:
    """Upscale small images to more pixels, for easier viewing."""
    if not self.upscale_images:
      return image
    assert image.ndim == 4  # (num_images, ysize, xsize, num_channels)
    ys = image.shape[1]
    xs = image.shape[2]
    if xs > 512 or ys > 512:
      return image  # No scaling.
    elif xs > 256 or ys > 256:
      scale = 2
    else:
      scale = 4
    # Nearest-neighbor upscaling: repeat each pixel `scale` times along
    # the y and x axes via integer index division.
    yidx = np.arange(ys * scale) // scale
    xidx = np.arange(xs * scale) // scale
    scaled_image = image[:, yidx, :, :][:, :, xidx, :]
    return scaled_image

  def _remove_outliers(self, v, std_range: float = 4):
    """Replaces elements whose magnitude exceeds std_range stddevs with the mean."""
    if not self.remove_outliers:
      return v
    v_mean = np.mean(v)
    v_std = np.std(v)
    return np.where(np.abs(v) > (v_std * std_range), v_mean, v)

  @staticmethod
  def merge_replicated_metrics(device_metrics: Mapping[str, Any],
                               metric_types: Mapping[str, str]):
    """Merge metrics across devices by psum over "batch" axis.

    Args:
      device_metrics: dictionary of device metrics.
      metric_types: map from the metric name to { "mean", "sum" }

    Returns:
      A dictionary of metrics.
    """
    logging.info("Merging metrics across devices %r: ",
                 [(k, metric_types.get(k)) for k in device_metrics.keys()])

    def aggregate_sum(value: Array) -> Array:
      """Sums `value` across all devices on the "batch" axis."""
      assert not isinstance(value, tuple), (
          "Weighted sums are not supported when aggregating over devices.")
      return jax.lax.psum(value, axis_name="batch")

    def aggregate_mean(value: Array, weight: Array) -> Tuple[Array, Array]:
      """Computes the weighted mean of `value` across all devices."""
      weighted_value = value * weight
      weighted_value = jax.lax.psum(weighted_value, axis_name="batch")
      weight = jax.lax.psum(weight, axis_name="batch")
      # Small epsilon guards against division by zero when all weights are 0.
      return weighted_value / (weight + 1.0e-6), weight

    aggregated_metrics = dict(device_metrics)
    for k, value in aggregated_metrics.items():
      if k not in metric_types:
        # If no metric type is given, metric remains untouched.
        continue
      if metric_types[k] == "sum":
        aggregated_metrics[k] = aggregate_sum(value)
      elif metric_types[k] == "mean":
        if not isinstance(aggregated_metrics[k], tuple):
          logging.info("Metric '%s' has no weight; assuming 1.0.", k)
          value = (value, jnp.array(1.0))
        aggregated_metrics[k] = aggregate_mean(*value)
      else:
        raise ValueError("Can only aggregate 'sum' and 'mean' over devices. "
                         f"Got {metric_types[k]}.")
    return aggregated_metrics

  def _new_aggregator(self, key) -> Aggregator:
    """Constructs an aggregator for `key` based on its registered type."""
    if key in self.metric_types:
      return self.constructor_map[self.metric_types[key]]()
    else:
      # TODO(mrabe): The default to last_value is not obvious. Force all metric
      # types to be given explicitly.
      logging.debug("No metric type for accumulator: %s", key)
      return _LastAggregator()

  def add(self, metrics: Mapping[str, Any]):
    """Add metrics from the current training step to the summary.

    Args:
      metrics: Dictionary of metrics.
    """
    for k, new_value in metrics.items():
      if k not in self.metric_dict:
        # Lazily create an aggregator the first time a metric is seen.
        self.metric_dict[k] = self._new_aggregator(k)
      self.metric_dict[k].add(new_value)

  def add_text(self, text_metrics: Mapping[str, str]):
    """Add text metrics from the current step to the summary."""
    for (k, v) in text_metrics.items():
      self.text_dict[k] = str(v)

  def empty(self):
    """Return true if there are no summaries to write."""
    return not (self.metric_dict or self.text_dict)

  def clear(self):
    """Clear accumulated summaries."""
    self.metric_dict = {}
    self.text_dict = {}

  def write(self, writer: metric_writers.MetricWriter, step: int, prefix: str):
    """Write metrics using summary_writer, and clear all summaries."""
    if self.empty():
      return

    # Special logic for organizing metrics under tensorboard.
    # Tensorboard has top-level groups, but doesn't have subgroups.
    # Scalars are put into separate top-level groups for easier viewing.
    # e.g. all scalars in "train", "test", etc.
    # For images, each set of images should be a different top-level group,
    # otherwise all images will get tossed into a single group under,
    # e.g. "generate".
    if prefix:
      s_prefix = prefix + "/"
      i_prefix = prefix + "_"
    else:
      # Each prefix is stored in a separate subdirectory already.
      s_prefix = ""
      i_prefix = ""

    # Split metrics into different types.
    scalars = {}
    images = {}
    histograms = {}
    text_dict = {}

    # Sort metrics into scalars, images, text, and histograms.
    for k, aggregator in self.metric_dict.items():
      if not isinstance(aggregator, Aggregator):
        raise ValueError("Internal error: metric_dict should contain only "
                         "_Aggregator objects; contained %s" % aggregator)
      if not aggregator.is_valid():
        raise ValueError(f"No valid value for metric {k}.")
      v = aggregator.to_value()
      # Check for None *before* np.isfinite: previously this guard ran after
      # the finite check, so a None value raised TypeError instead of being
      # logged and skipped.
      if v is None:
        logging.warning("Invalid value for %s", k)
        continue
      s_key = s_prefix + k
      i_key = i_prefix + k
      # Replace non-finite elements (nan/inf) with zero so writers don't choke.
      finite_mask = np.isfinite(v)
      if not np.all(finite_mask):
        logging.warning("Item %s contains non-finite elements.", k)
        v = np.where(finite_mask, v, np.zeros_like(v))
      if np.ndim(v) == 0:
        scalars[s_key] = v
      elif self._is_image(v):
        images[i_key] = self._upscale_image(v)
      else:
        histograms[s_key] = self._remove_outliers(v)

    # Handle text data.
    for (k, v) in self.text_dict.items():
      s_key = s_prefix + k
      text_dict[s_key] = v

    # Write metrics.
    if scalars:
      writer.write_scalars(step, scalars)
    if images:
      writer.write_images(step, images)
    if histograms:
      writer.write_histograms(step, histograms)
    if text_dict:
      writer.write_texts(step, text_dict)

    # Clear accumulated summaries.
    self.clear()
|