feng2022 commited on
Commit
5ea8c68
·
1 Parent(s): 228e7db
dnnlib/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ from .util import EasyDict, make_cache_dir_path
dnnlib/tflib/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ from . import autosummary
10
+ from . import network
11
+ from . import optimizer
12
+ from . import tfutil
13
+ from . import custom_ops
14
+
15
+ from .tfutil import *
16
+ from .network import Network
17
+
18
+ from .optimizer import Optimizer
19
+
20
+ from .custom_ops import get_plugin
dnnlib/tflib/autosummary.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Helper for adding automatically tracked values to Tensorboard.
10
+
11
+ Autosummary creates an identity op that internally keeps track of the input
12
+ values and automatically shows up in TensorBoard. The reported value
13
+ represents an average over input components. The average is accumulated
14
+ constantly over time and flushed when save_summaries() is called.
15
+
16
+ Notes:
17
+ - The output tensor must be used as an input for something else in the
18
+ graph. Otherwise, the autosummary op will not get executed, and the average
19
+ value will not get accumulated.
20
+ - It is perfectly fine to include autosummaries with the same name in
21
+ several places throughout the graph, even if they are executed concurrently.
22
+ - It is ok to also pass in a python scalar or numpy array. In this case, it
23
+ is added to the average immediately.
24
+ """
25
+
26
+ from collections import OrderedDict
27
+ import numpy as np
28
+ import tensorflow as tf
29
+ from tensorboard import summary as summary_lib
30
+ from tensorboard.plugins.custom_scalar import layout_pb2
31
+
32
+ from . import tfutil
33
+ from .tfutil import TfExpression
34
+ from .tfutil import TfExpressionEx
35
+
36
+ # Enable "Custom scalars" tab in TensorBoard for advanced formatting.
37
+ # Disabled by default to reduce tfevents file size.
38
+ enable_custom_scalars = False
39
+
40
+ _dtype = tf.float64
41
+ _vars = OrderedDict() # name => [var, ...]
42
+ _immediate = OrderedDict() # name => update_op, update_value
43
+ _finalized = False
44
+ _merge_op = None
45
+
46
+
47
+ def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
48
+ """Internal helper for creating autosummary accumulators."""
49
+ assert not _finalized
50
+ name_id = name.replace("/", "_")
51
+ v = tf.cast(value_expr, _dtype)
52
+
53
+ if v.shape.is_fully_defined():
54
+ size = np.prod(v.shape.as_list())
55
+ size_expr = tf.constant(size, dtype=_dtype)
56
+ else:
57
+ size = None
58
+ size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
59
+
60
+ if size == 1:
61
+ if v.shape.ndims != 0:
62
+ v = tf.reshape(v, [])
63
+ v = [size_expr, v, tf.square(v)]
64
+ else:
65
+ v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
66
+ v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
67
+
68
+ with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
69
+ var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)]
70
+ update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
71
+
72
+ if name in _vars:
73
+ _vars[name].append(var)
74
+ else:
75
+ _vars[name] = [var]
76
+ return update_op
77
+
78
+
79
+ def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
80
+ """Create a new autosummary.
81
+
82
+ Args:
83
+ name: Name to use in TensorBoard
84
+ value: TensorFlow expression or python value to track
85
+ passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
86
+
87
+ Example use of the passthru mechanism:
88
+
89
+ n = autosummary('l2loss', loss, passthru=n)
90
+
91
+ This is a shorthand for the following code:
92
+
93
+ with tf.control_dependencies([autosummary('l2loss', loss)]):
94
+ n = tf.identity(n)
95
+ """
96
+ tfutil.assert_tf_initialized()
97
+ name_id = name.replace("/", "_")
98
+
99
+ if tfutil.is_tf_expression(value):
100
+ with tf.name_scope("summary_" + name_id), tf.device(value.device):
101
+ condition = tf.convert_to_tensor(condition, name='condition')
102
+ update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
103
+ with tf.control_dependencies([update_op]):
104
+ return tf.identity(value if passthru is None else passthru)
105
+
106
+ else: # python scalar or numpy array
107
+ assert not tfutil.is_tf_expression(passthru)
108
+ assert not tfutil.is_tf_expression(condition)
109
+ if condition:
110
+ if name not in _immediate:
111
+ with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
112
+ update_value = tf.placeholder(_dtype)
113
+ update_op = _create_var(name, update_value)
114
+ _immediate[name] = update_op, update_value
115
+ update_op, update_value = _immediate[name]
116
+ tfutil.run(update_op, {update_value: value})
117
+ return value if passthru is None else passthru
118
+
119
+
120
+ def finalize_autosummaries() -> None:
121
+ """Create the necessary ops to include autosummaries in TensorBoard report.
122
+ Note: This should be done only once per graph.
123
+ """
124
+ global _finalized
125
+ tfutil.assert_tf_initialized()
126
+
127
+ if _finalized:
128
+ return None
129
+
130
+ _finalized = True
131
+ tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
132
+
133
+ # Create summary ops.
134
+ with tf.device(None), tf.control_dependencies(None):
135
+ for name, vars_list in _vars.items():
136
+ name_id = name.replace("/", "_")
137
+ with tfutil.absolute_name_scope("Autosummary/" + name_id):
138
+ moments = tf.add_n(vars_list)
139
+ moments /= moments[0]
140
+ with tf.control_dependencies([moments]): # read before resetting
141
+ reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
142
+ with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
143
+ mean = moments[1]
144
+ std = tf.sqrt(moments[2] - tf.square(moments[1]))
145
+ tf.summary.scalar(name, mean)
146
+ if enable_custom_scalars:
147
+ tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
148
+ tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
149
+
150
+ # Setup layout for custom scalars.
151
+ layout = None
152
+ if enable_custom_scalars:
153
+ cat_dict = OrderedDict()
154
+ for series_name in sorted(_vars.keys()):
155
+ p = series_name.split("/")
156
+ cat = p[0] if len(p) >= 2 else ""
157
+ chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
158
+ if cat not in cat_dict:
159
+ cat_dict[cat] = OrderedDict()
160
+ if chart not in cat_dict[cat]:
161
+ cat_dict[cat][chart] = []
162
+ cat_dict[cat][chart].append(series_name)
163
+ categories = []
164
+ for cat_name, chart_dict in cat_dict.items():
165
+ charts = []
166
+ for chart_name, series_names in chart_dict.items():
167
+ series = []
168
+ for series_name in series_names:
169
+ series.append(layout_pb2.MarginChartContent.Series(
170
+ value=series_name,
171
+ lower="xCustomScalars/" + series_name + "/margin_lo",
172
+ upper="xCustomScalars/" + series_name + "/margin_hi"))
173
+ margin = layout_pb2.MarginChartContent(series=series)
174
+ charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
175
+ categories.append(layout_pb2.Category(title=cat_name, chart=charts))
176
+ layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
177
+ return layout
178
+
179
+ def save_summaries(file_writer, global_step=None):
180
+ """Call FileWriter.add_summary() with all summaries in the default graph,
181
+ automatically finalizing and merging them on the first call.
182
+ """
183
+ global _merge_op
184
+ tfutil.assert_tf_initialized()
185
+
186
+ if _merge_op is None:
187
+ layout = finalize_autosummaries()
188
+ if layout is not None:
189
+ file_writer.add_summary(layout)
190
+ with tf.device(None), tf.control_dependencies(None):
191
+ _merge_op = tf.summary.merge_all()
192
+
193
+ file_writer.add_summary(_merge_op.eval(), global_step)
dnnlib/tflib/custom_ops.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """TensorFlow custom ops builder.
10
+ """
11
+
12
+ import os
13
+ import re
14
+ import uuid
15
+ import hashlib
16
+ import tempfile
17
+ import shutil
18
+ import tensorflow as tf
19
+ from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
20
+
21
+ #----------------------------------------------------------------------------
22
+ # Global options.
23
+
24
+ cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache')
25
+ cuda_cache_version_tag = 'v1'
26
+ do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe!
27
+ verbose = True # Print status messages to stdout.
28
+
29
+ compiler_bindir_search_path = [
30
+ 'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64',
31
+ 'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64',
32
+ 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin',
33
+ ]
34
+
35
+ #----------------------------------------------------------------------------
36
+ # Internal helper funcs.
37
+
38
+ def _find_compiler_bindir():
39
+ for compiler_path in compiler_bindir_search_path:
40
+ if os.path.isdir(compiler_path):
41
+ return compiler_path
42
+ return None
43
+
44
+ def _get_compute_cap(device):
45
+ caps_str = device.physical_device_desc
46
+ m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
47
+ major = m.group(1)
48
+ minor = m.group(2)
49
+ return (major, minor)
50
+
51
+ def _get_cuda_gpu_arch_string():
52
+ gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
53
+ if len(gpus) == 0:
54
+ raise RuntimeError('No GPU devices found')
55
+ (major, minor) = _get_compute_cap(gpus[0])
56
+ return 'sm_%s%s' % (major, minor)
57
+
58
+ def _run_cmd(cmd):
59
+ with os.popen(cmd) as pipe:
60
+ output = pipe.read()
61
+ status = pipe.close()
62
+ if status is not None:
63
+ raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
64
+
65
+ def _prepare_nvcc_cli(opts):
66
+ cmd = 'nvcc ' + opts.strip()
67
+ cmd += ' --disable-warnings'
68
+ cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
69
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
70
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl')
71
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive')
72
+
73
+ compiler_bindir = _find_compiler_bindir()
74
+ if compiler_bindir is None:
75
+ # Require that _find_compiler_bindir succeeds on Windows. Allow
76
+ # nvcc to use whatever is the default on Linux.
77
+ if os.name == 'nt':
78
+ raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
79
+ else:
80
+ cmd += ' --compiler-bindir "%s"' % compiler_bindir
81
+ cmd += ' 2>&1'
82
+ return cmd
83
+
84
+ #----------------------------------------------------------------------------
85
+ # Main entry point.
86
+
87
+ _plugin_cache = dict()
88
+
89
+ def get_plugin(cuda_file):
90
+ cuda_file_base = os.path.basename(cuda_file)
91
+ cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)
92
+
93
+ # Already in cache?
94
+ if cuda_file in _plugin_cache:
95
+ return _plugin_cache[cuda_file]
96
+
97
+ # Setup plugin.
98
+ if verbose:
99
+ print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
100
+ try:
101
+ # Hash CUDA source.
102
+ md5 = hashlib.md5()
103
+ with open(cuda_file, 'rb') as f:
104
+ md5.update(f.read())
105
+ md5.update(b'\n')
106
+
107
+ # Hash headers included by the CUDA code by running it through the preprocessor.
108
+ if not do_not_hash_included_headers:
109
+ if verbose:
110
+ print('Preprocessing... ', end='', flush=True)
111
+ with tempfile.TemporaryDirectory() as tmp_dir:
112
+ tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
113
+ _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
114
+ with open(tmp_file, 'rb') as f:
115
+ bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
116
+ good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
117
+ for ln in f:
118
+ if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
119
+ ln = ln.replace(bad_file_str, good_file_str)
120
+ md5.update(ln)
121
+ md5.update(b'\n')
122
+
123
+ # Select compiler options.
124
+ compile_opts = ''
125
+ if os.name == 'nt':
126
+ compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
127
+ elif os.name == 'posix':
128
+ compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so')
129
+ compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\''
130
+ else:
131
+ assert False # not Windows or Linux, w00t?
132
+ compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string()
133
+ compile_opts += ' --use_fast_math'
134
+ nvcc_cmd = _prepare_nvcc_cli(compile_opts)
135
+
136
+ # Hash build configuration.
137
+ md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
138
+ md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
139
+ md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')
140
+
141
+ # Compile if not already compiled.
142
+ bin_file_ext = '.dll' if os.name == 'nt' else '.so'
143
+ bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
144
+ if not os.path.isfile(bin_file):
145
+ if verbose:
146
+ print('Compiling... ', end='', flush=True)
147
+ with tempfile.TemporaryDirectory() as tmp_dir:
148
+ tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
149
+ _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
150
+ os.makedirs(cuda_cache_path, exist_ok=True)
151
+ intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
152
+ shutil.copyfile(tmp_file, intermediate_file)
153
+ os.rename(intermediate_file, bin_file) # atomic
154
+
155
+ # Load.
156
+ if verbose:
157
+ print('Loading... ', end='', flush=True)
158
+ plugin = tf.load_op_library(bin_file)
159
+
160
+ # Add to cache.
161
+ _plugin_cache[cuda_file] = plugin
162
+ if verbose:
163
+ print('Done.', flush=True)
164
+ return plugin
165
+
166
+ except:
167
+ if verbose:
168
+ print('Failed!', flush=True)
169
+ raise
170
+
171
+ #----------------------------------------------------------------------------
dnnlib/tflib/network.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Helper for managing networks."""
10
+
11
+ import types
12
+ import inspect
13
+ import re
14
+ import uuid
15
+ import sys
16
+ import numpy as np
17
+ import tensorflow as tf
18
+
19
+ from collections import OrderedDict
20
+ from typing import Any, List, Tuple, Union
21
+
22
+ from . import tfutil
23
+ from .. import util
24
+
25
+ from .tfutil import TfExpression, TfExpressionEx
26
+
27
+ _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
28
+ _import_module_src = dict() # Source code for temporary modules created during pickle import.
29
+
30
+
31
+ def import_handler(handler_func):
32
+ """Function decorator for declaring custom import handlers."""
33
+ _import_handlers.append(handler_func)
34
+ return handler_func
35
+
36
+
37
+ class Network:
38
+ """Generic network abstraction.
39
+
40
+ Acts as a convenience wrapper for a parameterized network construction
41
+ function, providing several utility methods and convenient access to
42
+ the inputs/outputs/weights.
43
+
44
+ Network objects can be safely pickled and unpickled for long-term
45
+ archival purposes. The pickling works reliably as long as the underlying
46
+ network construction function is defined in a standalone Python module
47
+ that has no side effects or application-specific imports.
48
+
49
+ Args:
50
+ name: Network name. Used to select TensorFlow name and variable scopes.
51
+ func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
52
+ static_kwargs: Keyword arguments to be passed in to the network construction function.
53
+
54
+ Attributes:
55
+ name: User-specified name, defaults to build func name if None.
56
+ scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
57
+ static_kwargs: Arguments passed to the user-supplied build func.
58
+ components: Container for sub-networks. Passed to the build func, and retained between calls.
59
+ num_inputs: Number of input tensors.
60
+ num_outputs: Number of output tensors.
61
+ input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
62
+ output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
63
+ input_shape: Short-hand for input_shapes[0].
64
+ output_shape: Short-hand for output_shapes[0].
65
+ input_templates: Input placeholders in the template graph.
66
+ output_templates: Output tensors in the template graph.
67
+ input_names: Name string for each input.
68
+ output_names: Name string for each output.
69
+ own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
70
+ vars: All variables (local_name => var).
71
+ trainables: All trainable variables (local_name => var).
72
+ var_global_to_local: Mapping from variable global names to local names.
73
+ """
74
+
75
+ def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
76
+ tfutil.assert_tf_initialized()
77
+ assert isinstance(name, str) or name is None
78
+ assert func_name is not None
79
+ assert isinstance(func_name, str) or util.is_top_level_function(func_name)
80
+ assert util.is_pickleable(static_kwargs)
81
+
82
+ self._init_fields()
83
+ self.name = name
84
+ self.static_kwargs = util.EasyDict(static_kwargs)
85
+
86
+ # Locate the user-specified network build function.
87
+ if util.is_top_level_function(func_name):
88
+ func_name = util.get_top_level_function_name(func_name)
89
+ module, self._build_func_name = util.get_module_from_obj_name(func_name)
90
+ self._build_func = util.get_obj_from_module(module, self._build_func_name)
91
+ assert callable(self._build_func)
92
+
93
+ # Dig up source code for the module containing the build function.
94
+ self._build_module_src = _import_module_src.get(module, None)
95
+ if self._build_module_src is None:
96
+ self._build_module_src = inspect.getsource(module)
97
+
98
+ # Init TensorFlow graph.
99
+ self._init_graph()
100
+ self.reset_own_vars()
101
+
102
+ def _init_fields(self) -> None:
103
+ self.name = None
104
+ self.scope = None
105
+ self.static_kwargs = util.EasyDict()
106
+ self.components = util.EasyDict()
107
+ self.num_inputs = 0
108
+ self.num_outputs = 0
109
+ self.input_shapes = [[]]
110
+ self.output_shapes = [[]]
111
+ self.input_shape = []
112
+ self.output_shape = []
113
+ self.input_templates = []
114
+ self.output_templates = []
115
+ self.input_names = []
116
+ self.output_names = []
117
+ self.own_vars = OrderedDict()
118
+ self.vars = OrderedDict()
119
+ self.trainables = OrderedDict()
120
+ self.var_global_to_local = OrderedDict()
121
+
122
+ self._build_func = None # User-supplied build function that constructs the network.
123
+ self._build_func_name = None # Name of the build function.
124
+ self._build_module_src = None # Full source code of the module containing the build function.
125
+ self._run_cache = dict() # Cached graph data for Network.run().
126
+
127
+ def _init_graph(self) -> None:
128
+ # Collect inputs.
129
+ self.input_names = []
130
+
131
+ for param in inspect.signature(self._build_func).parameters.values():
132
+ if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
133
+ self.input_names.append(param.name)
134
+
135
+ self.num_inputs = len(self.input_names)
136
+ assert self.num_inputs >= 1
137
+
138
+ # Choose name and scope.
139
+ if self.name is None:
140
+ self.name = self._build_func_name
141
+ assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
142
+ with tf.name_scope(None):
143
+ self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)
144
+
145
+ # Finalize build func kwargs.
146
+ build_kwargs = dict(self.static_kwargs)
147
+ build_kwargs["is_template_graph"] = True
148
+ build_kwargs["components"] = self.components
149
+
150
+ # Build template graph.
151
+ with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes
152
+ assert tf.get_variable_scope().name == self.scope
153
+ assert tf.get_default_graph().get_name_scope() == self.scope
154
+ with tf.control_dependencies(None): # ignore surrounding control dependencies
155
+ self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
156
+ out_expr = self._build_func(*self.input_templates, **build_kwargs)
157
+
158
+ # Collect outputs.
159
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
160
+ self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
161
+ self.num_outputs = len(self.output_templates)
162
+ assert self.num_outputs >= 1
163
+ assert all(tfutil.is_tf_expression(t) for t in self.output_templates)
164
+
165
+ # Perform sanity checks.
166
+ if any(t.shape.ndims is None for t in self.input_templates):
167
+ raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
168
+ if any(t.shape.ndims is None for t in self.output_templates):
169
+ raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
170
+ if any(not isinstance(comp, Network) for comp in self.components.values()):
171
+ raise ValueError("Components of a Network must be Networks themselves.")
172
+ if len(self.components) != len(set(comp.name for comp in self.components.values())):
173
+ raise ValueError("Components of a Network must have unique names.")
174
+
175
+ # List inputs and outputs.
176
+ self.input_shapes = [t.shape.as_list() for t in self.input_templates]
177
+ self.output_shapes = [t.shape.as_list() for t in self.output_templates]
178
+ self.input_shape = self.input_shapes[0]
179
+ self.output_shape = self.output_shapes[0]
180
+ self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
181
+
182
+ # List variables.
183
+ self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
184
+ self.vars = OrderedDict(self.own_vars)
185
+ self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
186
+ self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
187
+ self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
188
+
189
+ def reset_own_vars(self) -> None:
190
+ """Re-initialize all variables of this network, excluding sub-networks."""
191
+ tfutil.run([var.initializer for var in self.own_vars.values()])
192
+
193
+ def reset_vars(self) -> None:
194
+ """Re-initialize all variables of this network, including sub-networks."""
195
+ tfutil.run([var.initializer for var in self.vars.values()])
196
+
197
+ def reset_trainables(self) -> None:
198
+ """Re-initialize all trainable variables of this network, including sub-networks."""
199
+ tfutil.run([var.initializer for var in self.trainables.values()])
200
+
201
+ def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
202
+ """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
203
+ assert len(in_expr) == self.num_inputs
204
+ assert not all(expr is None for expr in in_expr)
205
+
206
+ # Finalize build func kwargs.
207
+ build_kwargs = dict(self.static_kwargs)
208
+ build_kwargs.update(dynamic_kwargs)
209
+ build_kwargs["is_template_graph"] = False
210
+ build_kwargs["components"] = self.components
211
+
212
+ # Build TensorFlow graph to evaluate the network.
213
+ with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
214
+ assert tf.get_variable_scope().name == self.scope
215
+ valid_inputs = [expr for expr in in_expr if expr is not None]
216
+ final_inputs = []
217
+ for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
218
+ if expr is not None:
219
+ expr = tf.identity(expr, name=name)
220
+ else:
221
+ expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
222
+ final_inputs.append(expr)
223
+ out_expr = self._build_func(*final_inputs, **build_kwargs)
224
+
225
+ # Propagate input shapes back to the user-specified expressions.
226
+ for expr, final in zip(in_expr, final_inputs):
227
+ if isinstance(expr, tf.Tensor):
228
+ expr.set_shape(final.shape)
229
+
230
+ # Express outputs in the desired format.
231
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
232
+ if return_as_list:
233
+ out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
234
+ return out_expr
235
+
236
+ def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
237
+ """Get the local name of a given variable, without any surrounding name scopes."""
238
+ assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
239
+ global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
240
+ return self.var_global_to_local[global_name]
241
+
242
+ def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
243
+ """Find variable by local or global name."""
244
+ assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
245
+ return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
246
+
247
+ def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
248
+ """Get the value of a given variable as NumPy array.
249
+ Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
250
+ return self.find_var(var_or_local_name).eval()
251
+
252
+ def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
253
+ """Set the value of a given variable based on the given NumPy array.
254
+ Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
255
+ tfutil.set_vars({self.find_var(var_or_local_name): new_value})
256
+
257
+ def __getstate__(self) -> dict:
258
+ """Pickle export."""
259
+ state = dict()
260
+ state["version"] = 4
261
+ state["name"] = self.name
262
+ state["static_kwargs"] = dict(self.static_kwargs)
263
+ state["components"] = dict(self.components)
264
+ state["build_module_src"] = self._build_module_src
265
+ state["build_func_name"] = self._build_func_name
266
+ state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
267
+ return state
268
+
269
+ def __setstate__(self, state: dict) -> None:
270
+ """Pickle import."""
271
+ # pylint: disable=attribute-defined-outside-init
272
+ tfutil.assert_tf_initialized()
273
+ self._init_fields()
274
+
275
+ # Execute custom import handlers.
276
+ for handler in _import_handlers:
277
+ state = handler(state)
278
+
279
+ # Set basic fields.
280
+ assert state["version"] in [2, 3, 4]
281
+ self.name = state["name"]
282
+ self.static_kwargs = util.EasyDict(state["static_kwargs"])
283
+ self.components = util.EasyDict(state.get("components", {}))
284
+ self._build_module_src = state["build_module_src"]
285
+ self._build_func_name = state["build_func_name"]
286
+
287
+ # Create temporary module from the imported source code.
288
+ module_name = "_tflib_network_import_" + uuid.uuid4().hex
289
+ module = types.ModuleType(module_name)
290
+ sys.modules[module_name] = module
291
+ _import_module_src[module] = self._build_module_src
292
+ exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used
293
+
294
+ # Locate network build function in the temporary module.
295
+ self._build_func = util.get_obj_from_module(module, self._build_func_name)
296
+ assert callable(self._build_func)
297
+
298
+ # Init TensorFlow graph.
299
+ self._init_graph()
300
+ self.reset_own_vars()
301
+ tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
302
+
303
+ def clone(self, name: str = None, **new_static_kwargs) -> "Network":
304
+ """Create a clone of this network with its own copy of the variables."""
305
+ # pylint: disable=protected-access
306
+ net = object.__new__(Network)
307
+ net._init_fields()
308
+ net.name = name if name is not None else self.name
309
+ net.static_kwargs = util.EasyDict(self.static_kwargs)
310
+ net.static_kwargs.update(new_static_kwargs)
311
+ net._build_module_src = self._build_module_src
312
+ net._build_func_name = self._build_func_name
313
+ net._build_func = self._build_func
314
+ net._init_graph()
315
+ net.copy_vars_from(self)
316
+ return net
317
+
318
+ def copy_own_vars_from(self, src_net: "Network") -> None:
319
+ """Copy the values of all variables from the given network, excluding sub-networks."""
320
+ names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
321
+ tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
322
+
323
+ def copy_vars_from(self, src_net: "Network") -> None:
324
+ """Copy the values of all variables from the given network, including sub-networks."""
325
+ names = [name for name in self.vars.keys() if name in src_net.vars]
326
+ tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
327
+
328
+ def copy_trainables_from(self, src_net: "Network") -> None:
329
+ """Copy the values of all trainable variables from the given network, including sub-networks."""
330
+ names = [name for name in self.trainables.keys() if name in src_net.trainables]
331
+ tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
332
+
333
+ def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
334
+ """Create new network with the given parameters, and copy all variables from this network."""
335
+ if new_name is None:
336
+ new_name = self.name
337
+ static_kwargs = dict(self.static_kwargs)
338
+ static_kwargs.update(new_static_kwargs)
339
+ net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
340
+ net.copy_vars_from(self)
341
+ return net
342
+
343
+ def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
344
+ """Construct a TensorFlow op that updates the variables of this network
345
+ to be slightly closer to those of the given network."""
346
+ with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
347
+ ops = []
348
+ for name, var in self.vars.items():
349
+ if name in src_net.vars:
350
+ cur_beta = beta if name in self.trainables else beta_nontrainable
351
+ new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)
352
+ ops.append(var.assign(new_value))
353
+ return tf.group(*ops)
354
+
355
+ def run(self,
356
+ *in_arrays: Tuple[Union[np.ndarray, None], ...],
357
+ input_transform: dict = None,
358
+ output_transform: dict = None,
359
+ return_as_list: bool = False,
360
+ print_progress: bool = False,
361
+ minibatch_size: int = None,
362
+ num_gpus: int = 1,
363
+ assume_frozen: bool = False,
364
+ **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
365
+ """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
366
+
367
+ Args:
368
+ input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
369
+ The dict must contain a 'func' field that points to a top-level function. The function is called with the input
370
+ TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
371
+ output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
372
+ The dict must contain a 'func' field that points to a top-level function. The function is called with the output
373
+ TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
374
+ return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
375
+ print_progress: Print progress to the console? Useful for very large input arrays.
376
+ minibatch_size: Maximum minibatch size to use, None = disable batching.
377
+ num_gpus: Number of GPUs to use.
378
+ assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain changed between calls.
379
+ dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
380
+ """
381
+ assert len(in_arrays) == self.num_inputs
382
+ assert not all(arr is None for arr in in_arrays)
383
+ assert input_transform is None or util.is_top_level_function(input_transform["func"])
384
+ assert output_transform is None or util.is_top_level_function(output_transform["func"])
385
+ output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
386
+ num_items = in_arrays[0].shape[0]
387
+ if minibatch_size is None:
388
+ minibatch_size = num_items
389
+
390
+ # Construct unique hash key from all arguments that affect the TensorFlow graph.
391
+ key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
392
+ def unwind_key(obj):
393
+ if isinstance(obj, dict):
394
+ return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
395
+ if callable(obj):
396
+ return util.get_top_level_function_name(obj)
397
+ return obj
398
+ key = repr(unwind_key(key))
399
+
400
+ # Build graph.
401
+ if key not in self._run_cache:
402
+ with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
403
+ with tf.device("/cpu:0"):
404
+ in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
405
+ in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
406
+
407
+ out_split = []
408
+ for gpu in range(num_gpus):
409
+ with tf.device("/gpu:%d" % gpu):
410
+ net_gpu = self.clone() if assume_frozen else self
411
+ in_gpu = in_split[gpu]
412
+
413
+ if input_transform is not None:
414
+ in_kwargs = dict(input_transform)
415
+ in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
416
+ in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
417
+
418
+ assert len(in_gpu) == self.num_inputs
419
+ out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
420
+
421
+ if output_transform is not None:
422
+ out_kwargs = dict(output_transform)
423
+ out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
424
+ out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
425
+
426
+ assert len(out_gpu) == self.num_outputs
427
+ out_split.append(out_gpu)
428
+
429
+ with tf.device("/cpu:0"):
430
+ out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
431
+ self._run_cache[key] = in_expr, out_expr
432
+
433
+ # Run minibatches.
434
+ in_expr, out_expr = self._run_cache[key]
435
+ out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]
436
+
437
+ for mb_begin in range(0, num_items, minibatch_size):
438
+ if print_progress:
439
+ print("\r%d / %d" % (mb_begin, num_items), end="")
440
+
441
+ mb_end = min(mb_begin + minibatch_size, num_items)
442
+ mb_num = mb_end - mb_begin
443
+ mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
444
+ mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
445
+
446
+ for dst, src in zip(out_arrays, mb_out):
447
+ dst[mb_begin: mb_end] = src
448
+
449
+ # Done.
450
+ if print_progress:
451
+ print("\r%d / %d" % (num_items, num_items))
452
+
453
+ if not return_as_list:
454
+ out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
455
+ return out_arrays
456
+
457
+ def list_ops(self) -> List[TfExpression]:
458
+ include_prefix = self.scope + "/"
459
+ exclude_prefix = include_prefix + "_"
460
+ ops = tf.get_default_graph().get_operations()
461
+ ops = [op for op in ops if op.name.startswith(include_prefix)]
462
+ ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
463
+ return ops
464
+
465
+ def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
466
+ """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
467
+ individual layers of the network. Mainly intended to be used for reporting."""
468
+ layers = []
469
+
470
+ def recurse(scope, parent_ops, parent_vars, level):
471
+ # Ignore specific patterns.
472
+ if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
473
+ return
474
+
475
+ # Filter ops and vars by scope.
476
+ global_prefix = scope + "/"
477
+ local_prefix = global_prefix[len(self.scope) + 1:]
478
+ cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
479
+ cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
480
+ if not cur_ops and not cur_vars:
481
+ return
482
+
483
+ # Filter out all ops related to variables.
484
+ for var in [op for op in cur_ops if op.type.startswith("Variable")]:
485
+ var_prefix = var.name + "/"
486
+ cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
487
+
488
+ # Scope does not contain ops as immediate children => recurse deeper.
489
+ contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops)
490
+ if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
491
+ visited = set()
492
+ for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
493
+ token = rel_name.split("/")[0]
494
+ if token not in visited:
495
+ recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
496
+ visited.add(token)
497
+ return
498
+
499
+ # Report layer.
500
+ layer_name = scope[len(self.scope) + 1:]
501
+ layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
502
+ layer_trainables = [var for _name, var in cur_vars if var.trainable]
503
+ layers.append((layer_name, layer_output, layer_trainables))
504
+
505
+ recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
506
+ return layers
507
+
508
+ def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
509
+ """Print a summary table of the network structure."""
510
+ rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
511
+ rows += [["---"] * 4]
512
+ total_params = 0
513
+
514
+ for layer_name, layer_output, layer_trainables in self.list_layers():
515
+ num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables)
516
+ weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
517
+ weights.sort(key=lambda x: len(x.name))
518
+ if len(weights) == 0 and len(layer_trainables) == 1:
519
+ weights = layer_trainables
520
+ total_params += num_params
521
+
522
+ if not hide_layers_with_no_params or num_params != 0:
523
+ num_params_str = str(num_params) if num_params > 0 else "-"
524
+ output_shape_str = str(layer_output.shape)
525
+ weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
526
+ rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
527
+
528
+ rows += [["---"] * 4]
529
+ rows += [["Total", str(total_params), "", ""]]
530
+
531
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
532
+ print()
533
+ for row in rows:
534
+ print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
535
+ print()
536
+
537
+ def setup_weight_histograms(self, title: str = None) -> None:
538
+ """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
539
+ if title is None:
540
+ title = self.name
541
+
542
+ with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
543
+ for local_name, var in self.trainables.items():
544
+ if "/" in local_name:
545
+ p = local_name.split("/")
546
+ name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
547
+ else:
548
+ name = title + "_toplevel/" + local_name
549
+
550
+ tf.summary.histogram(name, var)
551
+
552
+ #----------------------------------------------------------------------------
553
+ # Backwards-compatible emulation of legacy output transformation in Network.run().
554
+
555
+ _print_legacy_warning = True
556
+
557
+ def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
558
+ global _print_legacy_warning
559
+ legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
560
+ if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
561
+ return output_transform, dynamic_kwargs
562
+
563
+ if _print_legacy_warning:
564
+ _print_legacy_warning = False
565
+ print()
566
+ print("WARNING: Old-style output transformations in Network.run() are deprecated.")
567
+ print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
568
+ print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
569
+ print()
570
+ assert output_transform is None
571
+
572
+ new_kwargs = dict(dynamic_kwargs)
573
+ new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
574
+ new_transform["func"] = _legacy_output_transform_func
575
+ return new_transform, new_kwargs
576
+
577
+ def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
578
+ if out_mul != 1.0:
579
+ expr = [x * out_mul for x in expr]
580
+
581
+ if out_add != 0.0:
582
+ expr = [x + out_add for x in expr]
583
+
584
+ if out_shrink > 1:
585
+ ksize = [1, 1, out_shrink, out_shrink]
586
+ expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
587
+
588
+ if out_dtype is not None:
589
+ if tf.as_dtype(out_dtype).is_integer:
590
+ expr = [tf.round(x) for x in expr]
591
+ expr = [tf.saturate_cast(x, out_dtype) for x in expr]
592
+ return expr
dnnlib/tflib/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ # empty
dnnlib/tflib/ops/fused_bias_act.cu ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ //
5
+ // This work is made available under the Nvidia Source Code License-NC.
6
+ // To view a copy of this license, visit
7
+ // https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ #define EIGEN_USE_GPU
10
+ #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
11
+ #include "tensorflow/core/framework/op.h"
12
+ #include "tensorflow/core/framework/op_kernel.h"
13
+ #include "tensorflow/core/framework/shape_inference.h"
14
+ #include <stdio.h>
15
+
16
+ using namespace tensorflow;
17
+ using namespace tensorflow::shape_inference;
18
+
19
+ #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
20
+
21
+ //------------------------------------------------------------------------
22
+ // CUDA kernel.
23
+
24
+ template <class T>
25
+ struct FusedBiasActKernelParams
26
+ {
27
+ const T* x; // [sizeX]
28
+ const T* b; // [sizeB] or NULL
29
+ const T* ref; // [sizeX] or NULL
30
+ T* y; // [sizeX]
31
+
32
+ int grad;
33
+ int axis;
34
+ int act;
35
+ float alpha;
36
+ float gain;
37
+
38
+ int sizeX;
39
+ int sizeB;
40
+ int stepB;
41
+ int loopX;
42
+ };
43
+
44
+ template <class T>
45
+ static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p)
46
+ {
47
+ const float expRange = 80.0f;
48
+ const float halfExpRange = 40.0f;
49
+ const float seluScale = 1.0507009873554804934193349852946f;
50
+ const float seluAlpha = 1.6732632423543772848170429916717f;
51
+
52
+ // Loop over elements.
53
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
54
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
55
+ {
56
+ // Load and apply bias.
57
+ float x = (float)p.x[xi];
58
+ if (p.b)
59
+ x += (float)p.b[(xi / p.stepB) % p.sizeB];
60
+ float ref = (p.ref) ? (float)p.ref[xi] : 0.0f;
61
+ if (p.gain != 0.0f & p.act != 9)
62
+ ref /= p.gain;
63
+
64
+ // Evaluate activation func.
65
+ float y;
66
+ switch (p.act * 10 + p.grad)
67
+ {
68
+ // linear
69
+ default:
70
+ case 10: y = x; break;
71
+ case 11: y = x; break;
72
+ case 12: y = 0.0f; break;
73
+
74
+ // relu
75
+ case 20: y = (x > 0.0f) ? x : 0.0f; break;
76
+ case 21: y = (ref > 0.0f) ? x : 0.0f; break;
77
+ case 22: y = 0.0f; break;
78
+
79
+ // lrelu
80
+ case 30: y = (x > 0.0f) ? x : x * p.alpha; break;
81
+ case 31: y = (ref > 0.0f) ? x : x * p.alpha; break;
82
+ case 32: y = 0.0f; break;
83
+
84
+ // tanh
85
+ case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break;
86
+ case 41: y = x * (1.0f - ref * ref); break;
87
+ case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break;
88
+
89
+ // sigmoid
90
+ case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break;
91
+ case 51: y = x * ref * (1.0f - ref); break;
92
+ case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break;
93
+
94
+ // elu
95
+ case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break;
96
+ case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break;
97
+ case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break;
98
+
99
+ // selu
100
+ case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break;
101
+ case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break;
102
+ case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break;
103
+
104
+ // softplus
105
+ case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break;
106
+ case 81: y = x * (1.0f - expf(-ref)); break;
107
+ case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break;
108
+
109
+ // swish
110
+ case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break;
111
+ case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break;
112
+ case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break;
113
+ }
114
+
115
+ // Apply gain and store.
116
+ p.y[xi] = (T)(y * p.gain);
117
+ }
118
+ }
119
+
120
+ //------------------------------------------------------------------------
121
+ // TensorFlow op.
122
+
123
+ template <class T>
124
+ struct FusedBiasActOp : public OpKernel
125
+ {
126
+ FusedBiasActKernelParams<T> m_attribs;
127
+
128
+ FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
129
+ {
130
+ memset(&m_attribs, 0, sizeof(m_attribs));
131
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad));
132
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis));
133
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act));
134
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha));
135
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain));
136
+ OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
137
+ OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
138
+ OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
139
+ }
140
+
141
+ void Compute(OpKernelContext* ctx)
142
+ {
143
+ FusedBiasActKernelParams<T> p = m_attribs;
144
+ cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
145
+
146
+ const Tensor& x = ctx->input(0); // [...]
147
+ const Tensor& b = ctx->input(1); // [sizeB] or [0]
148
+ const Tensor& ref = ctx->input(2); // x.shape or [0]
149
+ p.x = x.flat<T>().data();
150
+ p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
151
+ p.ref = (ref.NumElements()) ? ref.flat<T>().data() : NULL;
152
+ OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
153
+ OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
154
+ OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
155
+ OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ? 0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements"));
156
+ OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));
157
+
158
+ p.sizeX = (int)x.NumElements();
159
+ p.sizeB = (int)b.NumElements();
160
+ p.stepB = 1;
161
+ for (int i = m_attribs.axis + 1; i < x.dims(); i++)
162
+ p.stepB *= (int)x.dim_size(i);
163
+
164
+ Tensor* y = NULL; // x.shape
165
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
166
+ p.y = y->flat<T>().data();
167
+
168
+ p.loopX = 4;
169
+ int blockSize = 4 * 32;
170
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
171
+ void* args[] = {&p};
172
+ OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
173
+ }
174
+ };
175
+
176
+ REGISTER_OP("FusedBiasAct")
177
+ .Input ("x: T")
178
+ .Input ("b: T")
179
+ .Input ("ref: T")
180
+ .Output ("y: T")
181
+ .Attr ("T: {float, half}")
182
+ .Attr ("grad: int = 0")
183
+ .Attr ("axis: int = 1")
184
+ .Attr ("act: int = 0")
185
+ .Attr ("alpha: float = 0.0")
186
+ .Attr ("gain: float = 1.0");
187
+ REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
188
+ REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);
189
+
190
+ //------------------------------------------------------------------------
dnnlib/tflib/ops/fused_bias_act.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Custom TensorFlow ops for efficient bias and activation."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import tensorflow as tf
14
+ from .. import custom_ops
15
+ from ...util import EasyDict
16
+
17
+ def _get_plugin():
18
+ return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ activation_funcs = {
23
+ 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True),
24
+ 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True),
25
+ 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True),
26
+ 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False),
27
+ 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False),
28
+ 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False),
29
+ 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False),
30
+ 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False),
31
+ 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False),
32
+ }
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'):
37
+ r"""Fused bias and activation function.
38
+
39
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
40
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
41
+ the fused op is considerably more efficient than performing the same calculation
42
+ using standard TensorFlow ops. It supports first and second order gradients,
43
+ but not third order gradients.
44
+
45
+ Args:
46
+ x: Input activation tensor. Can have any shape, but if `b` is defined, the
47
+ dimension corresponding to `axis`, as well as the rank, must be known.
48
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
49
+ as `x`. The shape must be known, and it must match the dimension of `x`
50
+ corresponding to `axis`.
51
+ axis: The dimension in `x` corresponding to the elements of `b`.
52
+ The value of `axis` is ignored if `b` is not specified.
53
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
54
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
55
+ See `activation_funcs` for a full list. `None` is not allowed.
56
+ alpha: Shape parameter for the activation function, or `None` to use the default.
57
+ gain: Scaling factor for the output tensor, or `None` to use default.
58
+ See `activation_funcs` for the default scaling of each activation function.
59
+ If unsure, consider specifying `1.0`.
60
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
61
+
62
+ Returns:
63
+ Tensor of the same shape and datatype as `x`.
64
+ """
65
+
66
+ impl_dict = {
67
+ 'ref': _fused_bias_act_ref,
68
+ 'cuda': _fused_bias_act_cuda,
69
+ }
70
+ return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
71
+
72
+ #----------------------------------------------------------------------------
73
+
74
+ def _fused_bias_act_ref(x, b, axis, act, alpha, gain):
75
+ """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops."""
76
+
77
+ # Validate arguments.
78
+ x = tf.convert_to_tensor(x)
79
+ b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype)
80
+ act_spec = activation_funcs[act]
81
+ assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
82
+ assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
83
+ if alpha is None:
84
+ alpha = act_spec.def_alpha
85
+ if gain is None:
86
+ gain = act_spec.def_gain
87
+
88
+ # Add bias.
89
+ if b.shape[0] != 0:
90
+ x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)])
91
+
92
+ # Evaluate activation function.
93
+ x = act_spec.func(x, alpha=alpha)
94
+
95
+ # Scale by gain.
96
+ if gain != 1:
97
+ x *= gain
98
+ return x
99
+
100
+ #----------------------------------------------------------------------------
101
+
102
+ def _fused_bias_act_cuda(x, b, axis, act, alpha, gain):
103
+ """Fast CUDA implementation of `fused_bias_act()` using custom ops."""
104
+
105
+ # Validate arguments.
106
+ x = tf.convert_to_tensor(x)
107
+ empty_tensor = tf.constant([], dtype=x.dtype)
108
+ b = tf.convert_to_tensor(b) if b is not None else empty_tensor
109
+ act_spec = activation_funcs[act]
110
+ assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
111
+ assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
112
+ if alpha is None:
113
+ alpha = act_spec.def_alpha
114
+ if gain is None:
115
+ gain = act_spec.def_gain
116
+
117
+ # Special cases.
118
+ if act == 'linear' and b is None and gain == 1.0:
119
+ return x
120
+ if act_spec.cuda_idx is None:
121
+ return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
122
+
123
+ # CUDA kernel.
124
+ cuda_kernel = _get_plugin().fused_bias_act
125
+ cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain)
126
+
127
+ # Forward pass: y = func(x, b).
128
+ def func_y(x, b):
129
+ y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs)
130
+ y.set_shape(x.shape)
131
+ return y
132
+
133
+ # Backward pass: dx, db = grad(dy, x, y)
134
+ def grad_dx(dy, x, y):
135
+ ref = {'x': x, 'y': y}[act_spec.ref]
136
+ dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs)
137
+ dx.set_shape(x.shape)
138
+ return dx
139
+ def grad_db(dx):
140
+ if b.shape[0] == 0:
141
+ return empty_tensor
142
+ db = dx
143
+ if axis < x.shape.rank - 1:
144
+ db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank)))
145
+ if axis > 0:
146
+ db = tf.reduce_sum(db, list(range(axis)))
147
+ db.set_shape(b.shape)
148
+ return db
149
+
150
+ # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y)
151
+ def grad2_d_dy(d_dx, d_db, x, y):
152
+ ref = {'x': x, 'y': y}[act_spec.ref]
153
+ d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs)
154
+ d_dy.set_shape(x.shape)
155
+ return d_dy
156
+ def grad2_d_x(d_dx, d_db, x, y):
157
+ ref = {'x': x, 'y': y}[act_spec.ref]
158
+ d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs)
159
+ d_x.set_shape(x.shape)
160
+ return d_x
161
+
162
+ # Fast version for piecewise-linear activation funcs.
163
+ @tf.custom_gradient
164
+ def func_zero_2nd_grad(x, b):
165
+ y = func_y(x, b)
166
+ @tf.custom_gradient
167
+ def grad(dy):
168
+ dx = grad_dx(dy, x, y)
169
+ db = grad_db(dx)
170
+ def grad2(d_dx, d_db):
171
+ d_dy = grad2_d_dy(d_dx, d_db, x, y)
172
+ return d_dy
173
+ return (dx, db), grad2
174
+ return y, grad
175
+
176
+ # Slow version for general activation funcs.
177
+ @tf.custom_gradient
178
+ def func_nonzero_2nd_grad(x, b):
179
+ y = func_y(x, b)
180
+ def grad_wrap(dy):
181
+ @tf.custom_gradient
182
+ def grad_impl(dy, x):
183
+ dx = grad_dx(dy, x, y)
184
+ db = grad_db(dx)
185
+ def grad2(d_dx, d_db):
186
+ d_dy = grad2_d_dy(d_dx, d_db, x, y)
187
+ d_x = grad2_d_x(d_dx, d_db, x, y)
188
+ return d_dy, d_x
189
+ return (dx, db), grad2
190
+ return grad_impl(dy, x)
191
+ return y, grad_wrap
192
+
193
+ # Which version to use?
194
+ if act_spec.zero_2nd_grad:
195
+ return func_zero_2nd_grad(x, b)
196
+ return func_nonzero_2nd_grad(x, b)
197
+
198
+ #----------------------------------------------------------------------------
dnnlib/tflib/ops/upfirdn_2d.cu ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ //
5
+ // This work is made available under the Nvidia Source Code License-NC.
6
+ // To view a copy of this license, visit
7
+ // https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ #define EIGEN_USE_GPU
10
+ #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
11
+ #include "tensorflow/core/framework/op.h"
12
+ #include "tensorflow/core/framework/op_kernel.h"
13
+ #include "tensorflow/core/framework/shape_inference.h"
14
+ #include <stdio.h>
15
+
16
+ using namespace tensorflow;
17
+ using namespace tensorflow::shape_inference;
18
+
19
+ //------------------------------------------------------------------------
20
+ // Helpers.
21
+
22
+ #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
23
+
24
+ static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
25
+ {
26
+ int c = a / b;
27
+ if (c * b > a)
28
+ c--;
29
+ return c;
30
+ }
31
+
32
+ //------------------------------------------------------------------------
33
+ // CUDA kernel params.
34
+
35
+ template <class T>
36
+ struct UpFirDn2DKernelParams
37
+ {
38
+ const T* x; // [majorDim, inH, inW, minorDim]
39
+ const T* k; // [kernelH, kernelW]
40
+ T* y; // [majorDim, outH, outW, minorDim]
41
+
42
+ int upx;
43
+ int upy;
44
+ int downx;
45
+ int downy;
46
+ int padx0;
47
+ int padx1;
48
+ int pady0;
49
+ int pady1;
50
+
51
+ int majorDim;
52
+ int inH;
53
+ int inW;
54
+ int minorDim;
55
+ int kernelH;
56
+ int kernelW;
57
+ int outH;
58
+ int outW;
59
+ int loopMajor;
60
+ int loopX;
61
+ };
62
+
63
+ //------------------------------------------------------------------------
64
+ // General CUDA implementation for large filter kernels.
65
+
66
+ template <class T>
67
+ static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
68
+ {
69
+ // Calculate thread index.
70
+ int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
71
+ int outY = minorIdx / p.minorDim;
72
+ minorIdx -= outY * p.minorDim;
73
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
74
+ int majorIdxBase = blockIdx.z * p.loopMajor;
75
+ if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
76
+ return;
77
+
78
+ // Setup Y receptive field.
79
+ int midY = outY * p.downy + p.upy - 1 - p.pady0;
80
+ int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
81
+ int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
82
+ int kernelY = midY + p.kernelH - (inY + 1) * p.upy;
83
+
84
+ // Loop over majorDim and outX.
85
+ for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
86
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
87
+ {
88
+ // Setup X receptive field.
89
+ int midX = outX * p.downx + p.upx - 1 - p.padx0;
90
+ int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
91
+ int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
92
+ int kernelX = midX + p.kernelW - (inX + 1) * p.upx;
93
+
94
+ // Initialize pointers.
95
+ const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
96
+ const T* kp = &p.k[kernelY * p.kernelW + kernelX];
97
+ int xpx = p.minorDim;
98
+ int kpx = -p.upx;
99
+ int xpy = p.inW * p.minorDim;
100
+ int kpy = -p.upy * p.kernelW;
101
+
102
+ // Inner loop.
103
+ float v = 0.0f;
104
+ for (int y = 0; y < h; y++)
105
+ {
106
+ for (int x = 0; x < w; x++)
107
+ {
108
+ v += (float)(*xp) * (float)(*kp);
109
+ xp += xpx;
110
+ kp += kpx;
111
+ }
112
+ xp += xpy - w * xpx;
113
+ kp += kpy - w * kpx;
114
+ }
115
+
116
+ // Store result.
117
+ p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
118
+ }
119
+ }
120
+
121
+ //------------------------------------------------------------------------
122
+ // Specialized CUDA implementation for small filter kernels.
123
+
124
+ template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
125
+ static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
126
+ {
127
+ //assert(kernelW % upx == 0);
128
+ //assert(kernelH % upy == 0);
129
+ const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
130
+ const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
131
+ __shared__ volatile float sk[kernelH][kernelW];
132
+ __shared__ volatile float sx[tileInH][tileInW];
133
+
134
+ // Calculate tile index.
135
+ int minorIdx = blockIdx.x;
136
+ int tileOutY = minorIdx / p.minorDim;
137
+ minorIdx -= tileOutY * p.minorDim;
138
+ tileOutY *= tileOutH;
139
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
140
+ int majorIdxBase = blockIdx.z * p.loopMajor;
141
+ if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
142
+ return;
143
+
144
+ // Load filter kernel (flipped).
145
+ for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
146
+ {
147
+ int ky = tapIdx / kernelW;
148
+ int kx = tapIdx - ky * kernelW;
149
+ float v = 0.0f;
150
+ if (kx < p.kernelW & ky < p.kernelH)
151
+ v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
152
+ sk[ky][kx] = v;
153
+ }
154
+
155
+ // Loop over majorDim and outX.
156
+ for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
157
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
158
+ {
159
+ // Load input pixels.
160
+ int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
161
+ int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
162
+ int tileInX = floorDiv(tileMidX, upx);
163
+ int tileInY = floorDiv(tileMidY, upy);
164
+ __syncthreads();
165
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
166
+ {
167
+ int relInY = inIdx / tileInW;
168
+ int relInX = inIdx - relInY * tileInW;
169
+ int inX = relInX + tileInX;
170
+ int inY = relInY + tileInY;
171
+ float v = 0.0f;
172
+ if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
173
+ v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
174
+ sx[relInY][relInX] = v;
175
+ }
176
+
177
+ // Loop over output pixels.
178
+ __syncthreads();
179
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
180
+ {
181
+ int relOutY = outIdx / tileOutW;
182
+ int relOutX = outIdx - relOutY * tileOutW;
183
+ int outX = relOutX + tileOutX;
184
+ int outY = relOutY + tileOutY;
185
+
186
+ // Setup receptive field.
187
+ int midX = tileMidX + relOutX * downx;
188
+ int midY = tileMidY + relOutY * downy;
189
+ int inX = floorDiv(midX, upx);
190
+ int inY = floorDiv(midY, upy);
191
+ int relInX = inX - tileInX;
192
+ int relInY = inY - tileInY;
193
+ int kernelX = (inX + 1) * upx - midX - 1; // flipped
194
+ int kernelY = (inY + 1) * upy - midY - 1; // flipped
195
+
196
+ // Inner loop.
197
+ float v = 0.0f;
198
+ #pragma unroll
199
+ for (int y = 0; y < kernelH / upy; y++)
200
+ #pragma unroll
201
+ for (int x = 0; x < kernelW / upx; x++)
202
+ v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];
203
+
204
+ // Store result.
205
+ if (outX < p.outW & outY < p.outH)
206
+ p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
207
+ }
208
+ }
209
+ }
210
+
211
+ //------------------------------------------------------------------------
212
+ // TensorFlow op.
213
+
214
+ template <class T>
215
+ struct UpFirDn2DOp : public OpKernel
216
+ {
217
+ UpFirDn2DKernelParams<T> m_attribs;
218
+
219
+ UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
220
+ {
221
+ memset(&m_attribs, 0, sizeof(m_attribs));
222
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
223
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
224
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
225
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
226
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
227
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
228
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
229
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
230
+ OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
231
+ OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
232
+ }
233
+
234
+ void Compute(OpKernelContext* ctx)
235
+ {
236
+ UpFirDn2DKernelParams<T> p = m_attribs;
237
+ cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
238
+
239
+ const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
240
+ const Tensor& k = ctx->input(1); // [kernelH, kernelW]
241
+ p.x = x.flat<T>().data();
242
+ p.k = k.flat<T>().data();
243
+ OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
244
+ OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
245
+ OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
246
+ OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));
247
+
248
+ p.majorDim = (int)x.dim_size(0);
249
+ p.inH = (int)x.dim_size(1);
250
+ p.inW = (int)x.dim_size(2);
251
+ p.minorDim = (int)x.dim_size(3);
252
+ p.kernelH = (int)k.dim_size(0);
253
+ p.kernelW = (int)k.dim_size(1);
254
+ OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));
255
+
256
+ p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
257
+ p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
258
+ OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));
259
+
260
+ Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
261
+ TensorShape ys;
262
+ ys.AddDim(p.majorDim);
263
+ ys.AddDim(p.outH);
264
+ ys.AddDim(p.outW);
265
+ ys.AddDim(p.minorDim);
266
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
267
+ p.y = y->flat<T>().data();
268
+ OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));
269
+
270
+ // Choose CUDA kernel to use.
271
+ void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
272
+ int tileOutW = -1;
273
+ int tileOutH = -1;
274
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7, 64,16>; tileOutW = 64; tileOutH = 16; }
275
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
276
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5, 64,16>; tileOutW = 64; tileOutH = 16; }
277
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
278
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3, 64,16>; tileOutW = 64; tileOutH = 16; }
279
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8, 64,16>; tileOutW = 64; tileOutH = 16; }
280
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
281
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
282
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2, 64,16>; tileOutW = 64; tileOutH = 16; }
283
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8, 32,8>; tileOutW = 32; tileOutH = 8; }
284
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6, 32,8>; tileOutW = 32; tileOutH = 8; }
285
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4, 32,8>; tileOutW = 32; tileOutH = 8; }
286
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2, 32,8>; tileOutW = 32; tileOutH = 8; }
287
+
288
+ // Choose launch params.
289
+ dim3 blockSize;
290
+ dim3 gridSize;
291
+ if (tileOutW > 0 && tileOutH > 0) // small
292
+ {
293
+ p.loopMajor = (p.majorDim - 1) / 16384 + 1;
294
+ p.loopX = 1;
295
+ blockSize = dim3(32 * 8, 1, 1);
296
+ gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
297
+ }
298
+ else // large
299
+ {
300
+ p.loopMajor = (p.majorDim - 1) / 16384 + 1;
301
+ p.loopX = 4;
302
+ blockSize = dim3(4, 32, 1);
303
+ gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
304
+ }
305
+
306
+ // Launch CUDA kernel.
307
+ void* args[] = {&p};
308
+ OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
309
+ }
310
+ };
311
+
312
+ REGISTER_OP("UpFirDn2D")
313
+ .Input ("x: T")
314
+ .Input ("k: T")
315
+ .Output ("y: T")
316
+ .Attr ("T: {float, half}")
317
+ .Attr ("upx: int = 1")
318
+ .Attr ("upy: int = 1")
319
+ .Attr ("downx: int = 1")
320
+ .Attr ("downy: int = 1")
321
+ .Attr ("padx0: int = 0")
322
+ .Attr ("padx1: int = 0")
323
+ .Attr ("pady0: int = 0")
324
+ .Attr ("pady1: int = 0");
325
+ REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
326
+ REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);
327
+
328
+ //------------------------------------------------------------------------
dnnlib/tflib/ops/upfirdn_2d.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Custom TensorFlow ops for efficient resampling of 2D images."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import tensorflow as tf
14
+ from .. import custom_ops
15
+
16
+ def _get_plugin():
17
+ return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
22
+ r"""Pad, upsample, FIR filter, and downsample a batch of 2D images.
23
+
24
+ Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]`
25
+ and performs the following operations for each image, batched across
26
+ `majorDim` and `minorDim`:
27
+
28
+ 1. Pad the image with zeros by the specified number of pixels on each side
29
+ (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value
30
+ corresponds to cropping the image.
31
+
32
+ 2. Upsample the image by inserting the zeros after each pixel (`upx`, `upy`).
33
+
34
+ 3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the
35
+ image so that the footprint of all output pixels lies within the input image.
36
+
37
+ 4. Downsample the image by throwing away pixels (`downx`, `downy`).
38
+
39
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
40
+ The fused op is considerably more efficient than performing the same calculation
41
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
42
+
43
+ Args:
44
+ x: Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
45
+ k: 2D FIR filter of the shape `[firH, firW]`.
46
+ upx: Integer upsampling factor along the X-axis (default: 1).
47
+ upy: Integer upsampling factor along the Y-axis (default: 1).
48
+ downx: Integer downsampling factor along the X-axis (default: 1).
49
+ downy: Integer downsampling factor along the Y-axis (default: 1).
50
+ padx0: Number of pixels to pad on the left side (default: 0).
51
+ padx1: Number of pixels to pad on the right side (default: 0).
52
+ pady0: Number of pixels to pad on the top side (default: 0).
53
+ pady1: Number of pixels to pad on the bottom side (default: 0).
54
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
55
+
56
+ Returns:
57
+ Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`.
58
+ """
59
+
60
+ impl_dict = {
61
+ 'ref': _upfirdn_2d_ref,
62
+ 'cuda': _upfirdn_2d_cuda,
63
+ }
64
+ return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
65
+
66
+ #----------------------------------------------------------------------------
67
+
68
+ def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
69
+ """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops."""
70
+
71
+ x = tf.convert_to_tensor(x)
72
+ k = np.asarray(k, dtype=np.float32)
73
+ assert x.shape.rank == 4
74
+ inH = x.shape[1].value
75
+ inW = x.shape[2].value
76
+ minorDim = _shape(x, 3)
77
+ kernelH, kernelW = k.shape
78
+ assert inW >= 1 and inH >= 1
79
+ assert kernelW >= 1 and kernelH >= 1
80
+ assert isinstance(upx, int) and isinstance(upy, int)
81
+ assert isinstance(downx, int) and isinstance(downy, int)
82
+ assert isinstance(padx0, int) and isinstance(padx1, int)
83
+ assert isinstance(pady0, int) and isinstance(pady1, int)
84
+
85
+ # Upsample (insert zeros).
86
+ x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
87
+ x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
88
+ x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])
89
+
90
+ # Pad (crop if negative).
91
+ x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]])
92
+ x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :]
93
+
94
+ # Convolve with filter.
95
+ x = tf.transpose(x, [0, 3, 1, 2])
96
+ x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1])
97
+ w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
98
+ x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW')
99
+ x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1])
100
+ x = tf.transpose(x, [0, 2, 3, 1])
101
+
102
+ # Downsample (throw away pixels).
103
+ return x[:, ::downy, ::downx, :]
104
+
105
+ #----------------------------------------------------------------------------
106
+
107
+ def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
108
+ """Fast CUDA implementation of `upfirdn_2d()` using custom ops."""
109
+
110
+ x = tf.convert_to_tensor(x)
111
+ k = np.asarray(k, dtype=np.float32)
112
+ majorDim, inH, inW, minorDim = x.shape.as_list()
113
+ kernelH, kernelW = k.shape
114
+ assert inW >= 1 and inH >= 1
115
+ assert kernelW >= 1 and kernelH >= 1
116
+ assert isinstance(upx, int) and isinstance(upy, int)
117
+ assert isinstance(downx, int) and isinstance(downy, int)
118
+ assert isinstance(padx0, int) and isinstance(padx1, int)
119
+ assert isinstance(pady0, int) and isinstance(pady1, int)
120
+
121
+ outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
122
+ outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
123
+ assert outW >= 1 and outH >= 1
124
+
125
+ kc = tf.constant(k, dtype=x.dtype)
126
+ gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype)
127
+ gpadx0 = kernelW - padx0 - 1
128
+ gpady0 = kernelH - pady0 - 1
129
+ gpadx1 = inW * upx - outW * downx + padx0 - upx + 1
130
+ gpady1 = inH * upy - outH * downy + pady0 - upy + 1
131
+
132
+ @tf.custom_gradient
133
+ def func(x):
134
+ y = _get_plugin().up_fir_dn2d(x=x, k=kc, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
135
+ y.set_shape([majorDim, outH, outW, minorDim])
136
+ @tf.custom_gradient
137
+ def grad(dy):
138
+ dx = _get_plugin().up_fir_dn2d(x=dy, k=gkc, upx=downx, upy=downy, downx=upx, downy=upy, padx0=gpadx0, padx1=gpadx1, pady0=gpady0, pady1=gpady1)
139
+ dx.set_shape([majorDim, inH, inW, minorDim])
140
+ return dx, func
141
+ return y, grad
142
+ return func(x)
143
+
144
+ #----------------------------------------------------------------------------
145
+
146
+ def filter_2d(x, k, gain=1, data_format='NCHW', impl='cuda'):
147
+ r"""Filter a batch of 2D images with the given FIR filter.
148
+
149
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
150
+ and filters each image with the given filter. The filter is normalized so that
151
+ if the input pixels are constant, they will be scaled by the specified `gain`.
152
+ Pixels outside the image are assumed to be zero.
153
+
154
+ Args:
155
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
156
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
157
+ gain: Scaling factor for signal magnitude (default: 1.0).
158
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
159
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
160
+
161
+ Returns:
162
+ Tensor of the same shape and datatype as `x`.
163
+ """
164
+
165
+ k = _setup_kernel(k) * gain
166
+ p = k.shape[0] - 1
167
+ return _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
168
+
169
+ #----------------------------------------------------------------------------
170
+
171
+ def upsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
172
+ r"""Upsample a batch of 2D images with the given filter.
173
+
174
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
175
+ and upsamples each image with the given filter. The filter is normalized so that
176
+ if the input pixels are constant, they will be scaled by the specified `gain`.
177
+ Pixels outside the image are assumed to be zero, and the filter is padded with
178
+ zeros so that its shape is a multiple of the upsampling factor.
179
+
180
+ Args:
181
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
182
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
183
+ The default is `[1] * factor`, which corresponds to nearest-neighbor
184
+ upsampling.
185
+ factor: Integer upsampling factor (default: 2).
186
+ gain: Scaling factor for signal magnitude (default: 1.0).
187
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
188
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
189
+
190
+ Returns:
191
+ Tensor of the shape `[N, C, H * factor, W * factor]` or
192
+ `[N, H * factor, W * factor, C]`, and same datatype as `x`.
193
+ """
194
+
195
+ assert isinstance(factor, int) and factor >= 1
196
+ if k is None:
197
+ k = [1] * factor
198
+ k = _setup_kernel(k) * (gain * (factor ** 2))
199
+ p = k.shape[0] - factor
200
+ return _simple_upfirdn_2d(x, k, up=factor, pad0=(p+1)//2+factor-1, pad1=p//2, data_format=data_format, impl=impl)
201
+
202
+ #----------------------------------------------------------------------------
203
+
204
+ def downsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
205
+ r"""Downsample a batch of 2D images with the given filter.
206
+
207
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
208
+ and downsamples each image with the given filter. The filter is normalized so that
209
+ if the input pixels are constant, they will be scaled by the specified `gain`.
210
+ Pixels outside the image are assumed to be zero, and the filter is padded with
211
+ zeros so that its shape is a multiple of the downsampling factor.
212
+
213
+ Args:
214
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
215
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
216
+ The default is `[1] * factor`, which corresponds to average pooling.
217
+ factor: Integer downsampling factor (default: 2).
218
+ gain: Scaling factor for signal magnitude (default: 1.0).
219
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
220
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
221
+
222
+ Returns:
223
+ Tensor of the shape `[N, C, H // factor, W // factor]` or
224
+ `[N, H // factor, W // factor, C]`, and same datatype as `x`.
225
+ """
226
+
227
+ assert isinstance(factor, int) and factor >= 1
228
+ if k is None:
229
+ k = [1] * factor
230
+ k = _setup_kernel(k) * gain
231
+ p = k.shape[0] - factor
232
+ return _simple_upfirdn_2d(x, k, down=factor, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
233
+
234
+ #----------------------------------------------------------------------------
235
+
236
+ def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
237
+ r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
238
+
239
+ Padding is performed only once at the beginning, not between the operations.
240
+ The fused op is considerably more efficient than performing the same calculation
241
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
242
+
243
+ Args:
244
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
245
+ w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
246
+ Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
247
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
248
+ The default is `[1] * factor`, which corresponds to nearest-neighbor
249
+ upsampling.
250
+ factor: Integer upsampling factor (default: 2).
251
+ gain: Scaling factor for signal magnitude (default: 1.0).
252
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
253
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
254
+
255
+ Returns:
256
+ Tensor of the shape `[N, C, H * factor, W * factor]` or
257
+ `[N, H * factor, W * factor, C]`, and same datatype as `x`.
258
+ """
259
+
260
+ assert isinstance(factor, int) and factor >= 1
261
+
262
+ # Check weight shape.
263
+ w = tf.convert_to_tensor(w)
264
+ assert w.shape.rank == 4
265
+ convH = w.shape[0].value
266
+ convW = w.shape[1].value
267
+ inC = _shape(w, 2)
268
+ outC = _shape(w, 3)
269
+ assert convW == convH
270
+
271
+ # Setup filter kernel.
272
+ if k is None:
273
+ k = [1] * factor
274
+ k = _setup_kernel(k) * (gain * (factor ** 2))
275
+ p = (k.shape[0] - factor) - (convW - 1)
276
+
277
+ # Determine data dimensions.
278
+ if data_format == 'NCHW':
279
+ stride = [1, 1, factor, factor]
280
+ output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW]
281
+ num_groups = _shape(x, 1) // inC
282
+ else:
283
+ stride = [1, factor, factor, 1]
284
+ output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + convH, (_shape(x, 2) - 1) * factor + convW, outC]
285
+ num_groups = _shape(x, 3) // inC
286
+
287
+ # Transpose weights.
288
+ w = tf.reshape(w, [convH, convW, inC, num_groups, -1])
289
+ w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2])
290
+ w = tf.reshape(w, [convH, convW, -1, num_groups * inC])
291
+
292
+ # Execute.
293
+ x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format)
294
+ return _simple_upfirdn_2d(x, k, pad0=(p+1)//2+factor-1, pad1=p//2+1, data_format=data_format, impl=impl)
295
+
296
+ #----------------------------------------------------------------------------
297
+
298
+ def conv_downsample_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
299
+ r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
300
+
301
+ Padding is performed only once at the beginning, not between the operations.
302
+ The fused op is considerably more efficient than performing the same calculation
303
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
304
+
305
+ Args:
306
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
307
+ w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
308
+ Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
309
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
310
+ The default is `[1] * factor`, which corresponds to average pooling.
311
+ factor: Integer downsampling factor (default: 2).
312
+ gain: Scaling factor for signal magnitude (default: 1.0).
313
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
314
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
315
+
316
+ Returns:
317
+ Tensor of the shape `[N, C, H // factor, W // factor]` or
318
+ `[N, H // factor, W // factor, C]`, and same datatype as `x`.
319
+ """
320
+
321
+ assert isinstance(factor, int) and factor >= 1
322
+ w = tf.convert_to_tensor(w)
323
+ convH, convW, _inC, _outC = w.shape.as_list()
324
+ assert convW == convH
325
+ if k is None:
326
+ k = [1] * factor
327
+ k = _setup_kernel(k) * gain
328
+ p = (k.shape[0] - factor) + (convW - 1)
329
+ if data_format == 'NCHW':
330
+ s = [1, 1, factor, factor]
331
+ else:
332
+ s = [1, factor, factor, 1]
333
+ x = _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
334
+ return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format)
335
+
336
+ #----------------------------------------------------------------------------
337
+ # Internal helper funcs.
338
+
339
+ def _shape(tf_expr, dim_idx):
340
+ if tf_expr.shape.rank is not None:
341
+ dim = tf_expr.shape[dim_idx].value
342
+ if dim is not None:
343
+ return dim
344
+ return tf.shape(tf_expr)[dim_idx]
345
+
346
+ def _setup_kernel(k):
347
+ k = np.asarray(k, dtype=np.float32)
348
+ if k.ndim == 1:
349
+ k = np.outer(k, k)
350
+ k /= np.sum(k)
351
+ assert k.ndim == 2
352
+ assert k.shape[0] == k.shape[1]
353
+ return k
354
+
355
+ def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'):
356
+ assert data_format in ['NCHW', 'NHWC']
357
+ assert x.shape.rank == 4
358
+ y = x
359
+ if data_format == 'NCHW':
360
+ y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1])
361
+ y = upfirdn_2d(y, k, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl)
362
+ if data_format == 'NCHW':
363
+ y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)])
364
+ return y
365
+
366
+ #----------------------------------------------------------------------------
dnnlib/tflib/optimizer.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Helper wrapper for a Tensorflow optimizer."""
10
+
11
+ import numpy as np
12
+ import tensorflow as tf
13
+
14
+ from collections import OrderedDict
15
+ from typing import List, Union
16
+
17
+ from . import autosummary
18
+ from . import tfutil
19
+ from .. import util
20
+
21
+ from .tfutil import TfExpression, TfExpressionEx
22
+
23
+ try:
24
+ # TensorFlow 1.13
25
+ from tensorflow.python.ops import nccl_ops
26
+ except:
27
+ # Older TensorFlow versions
28
+ import tensorflow.contrib.nccl as nccl_ops
29
+
30
+ class Optimizer:
31
+ """A Wrapper for tf.train.Optimizer.
32
+
33
+ Automatically takes care of:
34
+ - Gradient averaging for multi-GPU training.
35
+ - Gradient accumulation for arbitrarily large minibatches.
36
+ - Dynamic loss scaling and typecasts for FP16 training.
37
+ - Ignoring corrupted gradients that contain NaNs/Infs.
38
+ - Reporting statistics.
39
+ - Well-chosen default settings.
40
+ """
41
+
42
+ def __init__(self,
43
+ name: str = "Train", # Name string that will appear in TensorFlow graph.
44
+ tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class.
45
+ learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time.
46
+ minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients.
47
+ share: "Optimizer" = None, # Share internal state with a previously created optimizer?
48
+ use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training?
49
+ loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor.
50
+ loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow.
51
+ loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow.
52
+ report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard?
53
+ **kwargs):
54
+
55
+ # Public fields.
56
+ self.name = name
57
+ self.learning_rate = learning_rate
58
+ self.minibatch_multiplier = minibatch_multiplier
59
+ self.id = self.name.replace("/", ".")
60
+ self.scope = tf.get_default_graph().unique_name(self.id)
61
+ self.optimizer_class = util.get_obj_by_name(tf_optimizer)
62
+ self.optimizer_kwargs = dict(kwargs)
63
+ self.use_loss_scaling = use_loss_scaling
64
+ self.loss_scaling_init = loss_scaling_init
65
+ self.loss_scaling_inc = loss_scaling_inc
66
+ self.loss_scaling_dec = loss_scaling_dec
67
+
68
+ # Private fields.
69
+ self._updates_applied = False
70
+ self._devices = OrderedDict() # device_name => EasyDict()
71
+ self._shared_optimizers = OrderedDict() # device_name => optimizer_class
72
+ self._gradient_shapes = None # [shape, ...]
73
+ self._report_mem_usage = report_mem_usage
74
+
75
+ # Validate arguments.
76
+ assert callable(self.optimizer_class)
77
+
78
+ # Share internal state if requested.
79
+ if share is not None:
80
+ assert isinstance(share, Optimizer)
81
+ assert self.optimizer_class is share.optimizer_class
82
+ assert self.learning_rate is share.learning_rate
83
+ assert self.optimizer_kwargs == share.optimizer_kwargs
84
+ self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access
85
+
86
+ def _get_device(self, device_name: str):
87
+ """Get internal state for the given TensorFlow device."""
88
+ tfutil.assert_tf_initialized()
89
+ if device_name in self._devices:
90
+ return self._devices[device_name]
91
+
92
+ # Initialize fields.
93
+ device = util.EasyDict()
94
+ device.name = device_name
95
+ device.optimizer = None # Underlying optimizer: optimizer_class
96
+ device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable
97
+ device.grad_raw = OrderedDict() # Raw gradients: var => [grad, ...]
98
+ device.grad_clean = OrderedDict() # Clean gradients: var => grad
99
+ device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable
100
+ device.grad_acc_count = None # Accumulation counter: tf.Variable
101
+ device.grad_acc = OrderedDict() # Accumulated gradients: var => grad
102
+
103
+ # Setup TensorFlow objects.
104
+ with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
105
+ if device_name not in self._shared_optimizers:
106
+ optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
107
+ self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
108
+ device.optimizer = self._shared_optimizers[device_name]
109
+ if self.use_loss_scaling:
110
+ device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")
111
+
112
+ # Register device.
113
+ self._devices[device_name] = device
114
+ return device
115
+
116
+ def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
117
+ """Register the gradients of the given loss function with respect to the given variables.
118
+ Intended to be called once per GPU."""
119
+ tfutil.assert_tf_initialized()
120
+ assert not self._updates_applied
121
+ device = self._get_device(loss.device)
122
+
123
+ # Validate trainables.
124
+ if isinstance(trainable_vars, dict):
125
+ trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
126
+ assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
127
+ assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
128
+ assert all(var.device == device.name for var in trainable_vars)
129
+
130
+ # Validate shapes.
131
+ if self._gradient_shapes is None:
132
+ self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
133
+ assert len(trainable_vars) == len(self._gradient_shapes)
134
+ assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))
135
+
136
+ # Report memory usage if requested.
137
+ deps = []
138
+ if self._report_mem_usage:
139
+ self._report_mem_usage = False
140
+ try:
141
+ with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
142
+ deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
143
+ except tf.errors.NotFoundError:
144
+ pass
145
+
146
+ # Compute gradients.
147
+ with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
148
+ loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
149
+ gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage
150
+ grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)
151
+
152
+ # Register gradients.
153
+ for grad, var in grad_list:
154
+ if var not in device.grad_raw:
155
+ device.grad_raw[var] = []
156
+ device.grad_raw[var].append(grad)
157
+
158
+ def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
159
+ """Construct training op to update the registered variables based on their gradients."""
160
+ tfutil.assert_tf_initialized()
161
+ assert not self._updates_applied
162
+ self._updates_applied = True
163
+ all_ops = []
164
+
165
+ # Check for no-op.
166
+ if allow_no_op and len(self._devices) == 0:
167
+ with tfutil.absolute_name_scope(self.scope):
168
+ return tf.no_op(name='TrainingOp')
169
+
170
+ # Clean up gradients.
171
+ for device_idx, device in enumerate(self._devices.values()):
172
+ with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
173
+ for var, grad in device.grad_raw.items():
174
+
175
+ # Filter out disconnected gradients and convert to float32.
176
+ grad = [g for g in grad if g is not None]
177
+ grad = [tf.cast(g, tf.float32) for g in grad]
178
+
179
+ # Sum within the device.
180
+ if len(grad) == 0:
181
+ grad = tf.zeros(var.shape) # No gradients => zero.
182
+ elif len(grad) == 1:
183
+ grad = grad[0] # Single gradient => use as is.
184
+ else:
185
+ grad = tf.add_n(grad) # Multiple gradients => sum.
186
+
187
+ # Scale as needed.
188
+ scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
189
+ scale = tf.constant(scale, dtype=tf.float32, name="scale")
190
+ if self.minibatch_multiplier is not None:
191
+ scale /= tf.cast(self.minibatch_multiplier, tf.float32)
192
+ scale = self.undo_loss_scaling(scale)
193
+ device.grad_clean[var] = grad * scale
194
+
195
+ # Sum gradients across devices.
196
+ if len(self._devices) > 1:
197
+ with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
198
+ for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
199
+ if len(all_vars) > 0 and all(dim > 0 for dim in all_vars[0].shape.as_list()): # NCCL does not support zero-sized tensors.
200
+ all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
201
+ all_grads = nccl_ops.all_sum(all_grads)
202
+ for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
203
+ device.grad_clean[var] = grad
204
+
205
+ # Apply updates separately on each device.
206
+ for device_idx, device in enumerate(self._devices.values()):
207
+ with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
208
+ # pylint: disable=cell-var-from-loop
209
+
210
+ # Accumulate gradients over time.
211
+ if self.minibatch_multiplier is None:
212
+ acc_ok = tf.constant(True, name='acc_ok')
213
+ device.grad_acc = OrderedDict(device.grad_clean)
214
+ else:
215
+ # Create variables.
216
+ with tf.control_dependencies(None):
217
+ for var in device.grad_clean.keys():
218
+ device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
219
+ device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")
220
+
221
+ # Track counter.
222
+ count_cur = device.grad_acc_count + 1.0
223
+ count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
224
+ count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
225
+ acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
226
+ all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))
227
+
228
+ # Track gradients.
229
+ for var, grad in device.grad_clean.items():
230
+ acc_var = device.grad_acc_vars[var]
231
+ acc_cur = acc_var + grad
232
+ device.grad_acc[var] = acc_cur
233
+ with tf.control_dependencies([acc_cur]):
234
+ acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
235
+ acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
236
+ all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))
237
+
238
+ # No overflow => apply gradients.
239
+ all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
240
+ apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
241
+ all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))
242
+
243
+ # Adjust loss scaling.
244
+ if self.use_loss_scaling:
245
+ ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
246
+ ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
247
+ ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
248
+ all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))
249
+
250
+ # Last device => report statistics.
251
+ if device_idx == len(self._devices) - 1:
252
+ all_ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
253
+ all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
254
+ if self.use_loss_scaling:
255
+ all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))
256
+
257
+ # Initialize variables.
258
+ self.reset_optimizer_state()
259
+ if self.use_loss_scaling:
260
+ tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
261
+ if self.minibatch_multiplier is not None:
262
+ tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])
263
+
264
+ # Group everything into a single op.
265
+ with tfutil.absolute_name_scope(self.scope):
266
+ return tf.group(*all_ops, name="TrainingOp")
267
+
268
+ def reset_optimizer_state(self) -> None:
269
+ """Reset internal state of the underlying optimizer."""
270
+ tfutil.assert_tf_initialized()
271
+ tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])
272
+
273
+ def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
274
+ """Get or create variable representing log2 of the current dynamic loss scaling factor."""
275
+ return self._get_device(device).loss_scaling_var
276
+
277
+ def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
278
+ """Apply dynamic loss scaling for the given expression."""
279
+ assert tfutil.is_tf_expression(value)
280
+ if not self.use_loss_scaling:
281
+ return value
282
+ return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
283
+
284
+ def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
285
+ """Undo the effect of dynamic loss scaling for the given expression."""
286
+ assert tfutil.is_tf_expression(value)
287
+ if not self.use_loss_scaling:
288
+ return value
289
+ return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
290
+
291
+
292
+ class SimpleAdam:
293
+ """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer."""
294
+
295
+ def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
296
+ self.name = name
297
+ self.learning_rate = learning_rate
298
+ self.beta1 = beta1
299
+ self.beta2 = beta2
300
+ self.epsilon = epsilon
301
+ self.all_state_vars = []
302
+
303
+ def variables(self):
304
+ return self.all_state_vars
305
+
306
+ def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
307
+ assert gate_gradients == tf.train.Optimizer.GATE_NONE
308
+ return list(zip(tf.gradients(loss, var_list), var_list))
309
+
310
+ def apply_gradients(self, grads_and_vars):
311
+ with tf.name_scope(self.name):
312
+ state_vars = []
313
+ update_ops = []
314
+
315
+ # Adjust learning rate to deal with startup bias.
316
+ with tf.control_dependencies(None):
317
+ b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
318
+ b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
319
+ state_vars += [b1pow_var, b2pow_var]
320
+ b1pow_new = b1pow_var * self.beta1
321
+ b2pow_new = b2pow_var * self.beta2
322
+ update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)]
323
+ lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new)
324
+
325
+ # Construct ops to update each variable.
326
+ for grad, var in grads_and_vars:
327
+ with tf.control_dependencies(None):
328
+ m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
329
+ v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
330
+ state_vars += [m_var, v_var]
331
+ m_new = self.beta1 * m_var + (1 - self.beta1) * grad
332
+ v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad)
333
+ var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon)
334
+ update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)]
335
+
336
+ # Group everything together.
337
+ self.all_state_vars += state_vars
338
+ return tf.group(*update_ops)
dnnlib/tflib/tfutil.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Miscellaneous helper utils for Tensorflow."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import tensorflow as tf
14
+
15
+ # Silence deprecation warnings from TensorFlow 1.13 onwards
16
+ import logging
17
+ logging.getLogger('tensorflow').setLevel(logging.ERROR)
18
+ import tensorflow.contrib # requires TensorFlow 1.x!
19
+ tf.contrib = tensorflow.contrib
20
+
21
+ from typing import Any, Iterable, List, Union
22
+
23
+ TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
24
+ """A type that represents a valid Tensorflow expression."""
25
+
26
+ TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
27
+ """A type that can be converted to a valid Tensorflow expression."""
28
+
29
+
30
+ def run(*args, **kwargs) -> Any:
31
+ """Run the specified ops in the default session."""
32
+ assert_tf_initialized()
33
+ return tf.get_default_session().run(*args, **kwargs)
34
+
35
+
36
+ def is_tf_expression(x: Any) -> bool:
37
+ """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
38
+ return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
39
+
40
+
41
+ def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
42
+ """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code."""
43
+ return [dim.value for dim in shape]
44
+
45
+
46
+ def flatten(x: TfExpressionEx) -> TfExpression:
47
+ """Shortcut function for flattening a tensor."""
48
+ with tf.name_scope("Flatten"):
49
+ return tf.reshape(x, [-1])
50
+
51
+
52
+ def log2(x: TfExpressionEx) -> TfExpression:
53
+ """Logarithm in base 2."""
54
+ with tf.name_scope("Log2"):
55
+ return tf.log(x) * np.float32(1.0 / np.log(2.0))
56
+
57
+
58
+ def exp2(x: TfExpressionEx) -> TfExpression:
59
+ """Exponent in base 2."""
60
+ with tf.name_scope("Exp2"):
61
+ return tf.exp(x * np.float32(np.log(2.0)))
62
+
63
+
64
+ def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
65
+ """Linear interpolation."""
66
+ with tf.name_scope("Lerp"):
67
+ return a + (b - a) * t
68
+
69
+
70
+ def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
71
+ """Linear interpolation with clip."""
72
+ with tf.name_scope("LerpClip"):
73
+ return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
74
+
75
+
76
+ def absolute_name_scope(scope: str) -> tf.name_scope:
77
+ """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
78
+ return tf.name_scope(scope + "/")
79
+
80
+
81
+ def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
82
+ """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
83
+ return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
84
+
85
+
86
+ def _sanitize_tf_config(config_dict: dict = None) -> dict:
87
+ # Defaults.
88
+ cfg = dict()
89
+ cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
90
+ cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
91
+ cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
92
+ cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
93
+ cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
94
+
95
+ # Remove defaults for environment variables that are already set.
96
+ for key in list(cfg):
97
+ fields = key.split(".")
98
+ if fields[0] == "env":
99
+ assert len(fields) == 2
100
+ if fields[1] in os.environ:
101
+ del cfg[key]
102
+
103
+ # User overrides.
104
+ if config_dict is not None:
105
+ cfg.update(config_dict)
106
+ return cfg
107
+
108
+
109
+ def init_tf(config_dict: dict = None) -> None:
110
+ """Initialize TensorFlow session using good default settings."""
111
+ # Skip if already initialized.
112
+ if tf.get_default_session() is not None:
113
+ return
114
+
115
+ # Setup config dict and random seeds.
116
+ cfg = _sanitize_tf_config(config_dict)
117
+ np_random_seed = cfg["rnd.np_random_seed"]
118
+ if np_random_seed is not None:
119
+ np.random.seed(np_random_seed)
120
+ tf_random_seed = cfg["rnd.tf_random_seed"]
121
+ if tf_random_seed == "auto":
122
+ tf_random_seed = np.random.randint(1 << 31)
123
+ if tf_random_seed is not None:
124
+ tf.set_random_seed(tf_random_seed)
125
+
126
+ # Setup environment variables.
127
+ for key, value in cfg.items():
128
+ fields = key.split(".")
129
+ if fields[0] == "env":
130
+ assert len(fields) == 2
131
+ os.environ[fields[1]] = str(value)
132
+
133
+ # Create default TensorFlow session.
134
+ create_session(cfg, force_as_default=True)
135
+
136
+
137
+ def assert_tf_initialized():
138
+ """Check that TensorFlow session has been initialized."""
139
+ if tf.get_default_session() is None:
140
+ raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
141
+
142
+
143
+ def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
144
+ """Create tf.Session based on config dict."""
145
+ # Setup TensorFlow config proto.
146
+ cfg = _sanitize_tf_config(config_dict)
147
+ config_proto = tf.ConfigProto()
148
+ for key, value in cfg.items():
149
+ fields = key.split(".")
150
+ if fields[0] not in ["rnd", "env"]:
151
+ obj = config_proto
152
+ for field in fields[:-1]:
153
+ obj = getattr(obj, field)
154
+ setattr(obj, fields[-1], value)
155
+
156
+ # Create session.
157
+ session = tf.Session(config=config_proto)
158
+ if force_as_default:
159
+ # pylint: disable=protected-access
160
+ session._default_session = session.as_default()
161
+ session._default_session.enforce_nesting = False
162
+ session._default_session.__enter__()
163
+ return session
164
+
165
+
166
+ def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
167
+ """Initialize all tf.Variables that have not already been initialized.
168
+
169
+ Equivalent to the following, but more efficient and does not bloat the tf graph:
170
+ tf.variables_initializer(tf.report_uninitialized_variables()).run()
171
+ """
172
+ assert_tf_initialized()
173
+ if target_vars is None:
174
+ target_vars = tf.global_variables()
175
+
176
+ test_vars = []
177
+ test_ops = []
178
+
179
+ with tf.control_dependencies(None): # ignore surrounding control_dependencies
180
+ for var in target_vars:
181
+ assert is_tf_expression(var)
182
+
183
+ try:
184
+ tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
185
+ except KeyError:
186
+ # Op does not exist => variable may be uninitialized.
187
+ test_vars.append(var)
188
+
189
+ with absolute_name_scope(var.name.split(":")[0]):
190
+ test_ops.append(tf.is_variable_initialized(var))
191
+
192
+ init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
193
+ run([var.initializer for var in init_vars])
194
+
195
+
196
+ def set_vars(var_to_value_dict: dict) -> None:
197
+ """Set the values of given tf.Variables.
198
+
199
+ Equivalent to the following, but more efficient and does not bloat the tf graph:
200
+ tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()]
201
+ """
202
+ assert_tf_initialized()
203
+ ops = []
204
+ feed_dict = {}
205
+
206
+ for var, value in var_to_value_dict.items():
207
+ assert is_tf_expression(var)
208
+
209
+ try:
210
+ setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
211
+ except KeyError:
212
+ with absolute_name_scope(var.name.split(":")[0]):
213
+ with tf.control_dependencies(None): # ignore surrounding control_dependencies
214
+ setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
215
+
216
+ ops.append(setter)
217
+ feed_dict[setter.op.inputs[1]] = value
218
+
219
+ run(ops, feed_dict)
220
+
221
+
222
+ def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
223
+ """Create tf.Variable with large initial value without bloating the tf graph."""
224
+ assert_tf_initialized()
225
+ assert isinstance(initial_value, np.ndarray)
226
+ zeros = tf.zeros(initial_value.shape, initial_value.dtype)
227
+ var = tf.Variable(zeros, *args, **kwargs)
228
+ set_vars({var: initial_value})
229
+ return var
230
+
231
+
232
+ def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
233
+ """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
234
+ Can be used as an input transformation for Network.run().
235
+ """
236
+ images = tf.cast(images, tf.float32)
237
+ if nhwc_to_nchw:
238
+ images = tf.transpose(images, [0, 3, 1, 2])
239
+ return images * ((drange[1] - drange[0]) / 255) + drange[0]
240
+
241
+
242
+ def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
243
+ """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
244
+ Can be used as an output transformation for Network.run().
245
+ """
246
+ images = tf.cast(images, tf.float32)
247
+ if shrink > 1:
248
+ ksize = [1, 1, shrink, shrink]
249
+ images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
250
+ if nchw_to_nhwc:
251
+ images = tf.transpose(images, [0, 2, 3, 1])
252
+ scale = 255 / (drange[1] - drange[0])
253
+ images = images * scale + (0.5 - drange[0] * scale)
254
+ return tf.saturate_cast(images, tf.uint8)
dnnlib/util.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
5
+ # and proprietary rights in and to this software, related documentation
6
+ # and any modifications thereto. Any use, reproduction, disclosure or
7
+ # distribution of this software and related documentation without an express
8
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
9
+
10
+ """Miscellaneous utility classes and functions."""
11
+
12
+ import ctypes
13
+ import fnmatch
14
+ import importlib
15
+ import inspect
16
+ import numpy as np
17
+ import os
18
+ import shutil
19
+ import sys
20
+ import types
21
+ import io
22
+ import pickle
23
+ import re
24
+ import requests
25
+ import html
26
+ import hashlib
27
+ import glob
28
+ import tempfile
29
+ import urllib
30
+ import urllib.request
31
+ import uuid
32
+
33
+ from distutils.util import strtobool
34
+ from typing import Any, List, Tuple, Union
35
+
36
+
37
+ # Util classes
38
+ # ------------------------------------------------------------------------------------------
39
+
40
+
41
+ class EasyDict(dict):
42
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
43
+
44
+ def __getattr__(self, name: str) -> Any:
45
+ try:
46
+ return self[name]
47
+ except KeyError:
48
+ raise AttributeError(name)
49
+
50
+ def __setattr__(self, name: str, value: Any) -> None:
51
+ self[name] = value
52
+
53
+ def __delattr__(self, name: str) -> None:
54
+ del self[name]
55
+
56
+
57
+ class Logger(object):
58
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
59
+
60
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
61
+ self.file = None
62
+
63
+ if file_name is not None:
64
+ self.file = open(file_name, file_mode)
65
+
66
+ self.should_flush = should_flush
67
+ self.stdout = sys.stdout
68
+ self.stderr = sys.stderr
69
+
70
+ sys.stdout = self
71
+ sys.stderr = self
72
+
73
+ def __enter__(self) -> "Logger":
74
+ return self
75
+
76
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
77
+ self.close()
78
+
79
+ def write(self, text: Union[str, bytes]) -> None:
80
+ """Write text to stdout (and a file) and optionally flush."""
81
+ if isinstance(text, bytes):
82
+ text = text.decode()
83
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
84
+ return
85
+
86
+ if self.file is not None:
87
+ self.file.write(text)
88
+
89
+ self.stdout.write(text)
90
+
91
+ if self.should_flush:
92
+ self.flush()
93
+
94
+ def flush(self) -> None:
95
+ """Flush written text to both stdout and a file, if open."""
96
+ if self.file is not None:
97
+ self.file.flush()
98
+
99
+ self.stdout.flush()
100
+
101
+ def close(self) -> None:
102
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
103
+ self.flush()
104
+
105
+ # if using multiple loggers, prevent closing in wrong order
106
+ if sys.stdout is self:
107
+ sys.stdout = self.stdout
108
+ if sys.stderr is self:
109
+ sys.stderr = self.stderr
110
+
111
+ if self.file is not None:
112
+ self.file.close()
113
+ self.file = None
114
+
115
+
116
+ # Cache directories
117
+ # ------------------------------------------------------------------------------------------
118
+
119
+ _dnnlib_cache_dir = None
120
+
121
+ def set_cache_dir(path: str) -> None:
122
+ global _dnnlib_cache_dir
123
+ _dnnlib_cache_dir = path
124
+
125
+ def make_cache_dir_path(*paths: str) -> str:
126
+ if _dnnlib_cache_dir is not None:
127
+ return os.path.join(_dnnlib_cache_dir, *paths)
128
+ if 'DNNLIB_CACHE_DIR' in os.environ:
129
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
130
+ if 'HOME' in os.environ:
131
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
132
+ if 'USERPROFILE' in os.environ:
133
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
134
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
135
+
136
+ # Small util functions
137
+ # ------------------------------------------------------------------------------------------
138
+
139
+
140
+ def format_time(seconds: Union[int, float]) -> str:
141
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
142
+ s = int(np.rint(seconds))
143
+
144
+ if s < 60:
145
+ return "{0}s".format(s)
146
+ elif s < 60 * 60:
147
+ return "{0}m {1:02}s".format(s // 60, s % 60)
148
+ elif s < 24 * 60 * 60:
149
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
150
+ else:
151
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
152
+
153
+
154
+ def ask_yes_no(question: str) -> bool:
155
+ """Ask the user the question until the user inputs a valid answer."""
156
+ while True:
157
+ try:
158
+ print("{0} [y/n]".format(question))
159
+ return strtobool(input().lower())
160
+ except ValueError:
161
+ pass
162
+
163
+
164
+ def tuple_product(t: Tuple) -> Any:
165
+ """Calculate the product of the tuple elements."""
166
+ result = 1
167
+
168
+ for v in t:
169
+ result *= v
170
+
171
+ return result
172
+
173
+
174
+ _str_to_ctype = {
175
+ "uint8": ctypes.c_ubyte,
176
+ "uint16": ctypes.c_uint16,
177
+ "uint32": ctypes.c_uint32,
178
+ "uint64": ctypes.c_uint64,
179
+ "int8": ctypes.c_byte,
180
+ "int16": ctypes.c_int16,
181
+ "int32": ctypes.c_int32,
182
+ "int64": ctypes.c_int64,
183
+ "float32": ctypes.c_float,
184
+ "float64": ctypes.c_double
185
+ }
186
+
187
+
188
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
189
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
190
+ type_str = None
191
+
192
+ if isinstance(type_obj, str):
193
+ type_str = type_obj
194
+ elif hasattr(type_obj, "__name__"):
195
+ type_str = type_obj.__name__
196
+ elif hasattr(type_obj, "name"):
197
+ type_str = type_obj.name
198
+ else:
199
+ raise RuntimeError("Cannot infer type name from input")
200
+
201
+ assert type_str in _str_to_ctype.keys()
202
+
203
+ my_dtype = np.dtype(type_str)
204
+ my_ctype = _str_to_ctype[type_str]
205
+
206
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
207
+
208
+ return my_dtype, my_ctype
209
+
210
+
211
+ def is_pickleable(obj: Any) -> bool:
212
+ try:
213
+ with io.BytesIO() as stream:
214
+ pickle.dump(obj, stream)
215
+ return True
216
+ except:
217
+ return False
218
+
219
+
220
+ # Functionality to import modules/objects by name, and call functions by name
221
+ # ------------------------------------------------------------------------------------------
222
+
223
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
224
+ """Searches for the underlying module behind the name to some python object.
225
+ Returns the module and the object name (original name with module part removed)."""
226
+
227
+ # allow convenience shorthands, substitute them by full names
228
+ obj_name = re.sub("^np.", "numpy.", obj_name)
229
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
230
+
231
+ # list alternatives for (module_name, local_obj_name)
232
+ parts = obj_name.split(".")
233
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
234
+
235
+ # try each alternative in turn
236
+ for module_name, local_obj_name in name_pairs:
237
+ try:
238
+ module = importlib.import_module(module_name) # may raise ImportError
239
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
240
+ return module, local_obj_name
241
+ except:
242
+ pass
243
+
244
+ # maybe some of the modules themselves contain errors?
245
+ for module_name, _local_obj_name in name_pairs:
246
+ try:
247
+ importlib.import_module(module_name) # may raise ImportError
248
+ except ImportError:
249
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
250
+ raise
251
+
252
+ # maybe the requested attribute is missing?
253
+ for module_name, local_obj_name in name_pairs:
254
+ try:
255
+ module = importlib.import_module(module_name) # may raise ImportError
256
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
257
+ except ImportError:
258
+ pass
259
+
260
+ # we are out of luck, but we have no idea why
261
+ raise ImportError(obj_name)
262
+
263
+
264
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
265
+ """Traverses the object name and returns the last (rightmost) python object."""
266
+ if obj_name == '':
267
+ return module
268
+ obj = module
269
+ for part in obj_name.split("."):
270
+ obj = getattr(obj, part)
271
+ return obj
272
+
273
+
274
+ def get_obj_by_name(name: str) -> Any:
275
+ """Finds the python object with the given name."""
276
+ module, obj_name = get_module_from_obj_name(name)
277
+ return get_obj_from_module(module, obj_name)
278
+
279
+
280
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
281
+ """Finds the python object with the given name and calls it as a function."""
282
+ assert func_name is not None
283
+ # print('func_name: ', func_name) #'training.dataset.ImageFolderDataset'
284
+ func_obj = get_obj_by_name(func_name)
285
+ assert callable(func_obj)
286
+ return func_obj(*args, **kwargs)
287
+
288
+
289
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
290
+ """Finds the python class with the given name and constructs it with the given arguments."""
291
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
292
+
293
+
294
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
295
+ """Get the directory path of the module containing the given object name."""
296
+ module, _ = get_module_from_obj_name(obj_name)
297
+ return os.path.dirname(inspect.getfile(module))
298
+
299
+
300
+ def is_top_level_function(obj: Any) -> bool:
301
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
302
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
303
+
304
+
305
+ def get_top_level_function_name(obj: Any) -> str:
306
+ """Return the fully-qualified name of a top-level function."""
307
+ assert is_top_level_function(obj)
308
+ module = obj.__module__
309
+ if module == '__main__':
310
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
311
+ return module + "." + obj.__name__
312
+
313
+
314
+ # File system helpers
315
+ # ------------------------------------------------------------------------------------------
316
+
317
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
318
+ """List all files recursively in a given directory while ignoring given file and directory names.
319
+ Returns list of tuples containing both absolute and relative paths."""
320
+ assert os.path.isdir(dir_path)
321
+ base_name = os.path.basename(os.path.normpath(dir_path))
322
+
323
+ if ignores is None:
324
+ ignores = []
325
+
326
+ result = []
327
+
328
+ for root, dirs, files in os.walk(dir_path, topdown=True):
329
+ for ignore_ in ignores:
330
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
331
+
332
+ # dirs need to be edited in-place
333
+ for d in dirs_to_remove:
334
+ dirs.remove(d)
335
+
336
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
337
+
338
+ absolute_paths = [os.path.join(root, f) for f in files]
339
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
340
+
341
+ if add_base_to_relative:
342
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
343
+
344
+ assert len(absolute_paths) == len(relative_paths)
345
+ result += zip(absolute_paths, relative_paths)
346
+
347
+ return result
348
+
349
+
350
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
351
+ """Takes in a list of tuples of (src, dst) paths and copies files.
352
+ Will create all necessary directories."""
353
+ for file in files:
354
+ target_dir_name = os.path.dirname(file[1])
355
+
356
+ # will create all intermediate-level directories
357
+ if not os.path.exists(target_dir_name):
358
+ os.makedirs(target_dir_name)
359
+
360
+ shutil.copyfile(file[0], file[1])
361
+
362
+
363
+ # URL helpers
364
+ # ------------------------------------------------------------------------------------------
365
+
366
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
367
+ """Determine whether the given object is a valid URL string."""
368
+ if not isinstance(obj, str) or not "://" in obj:
369
+ return False
370
+ if allow_file_urls and obj.startswith('file://'):
371
+ return True
372
+ try:
373
+ res = requests.compat.urlparse(obj)
374
+ if not res.scheme or not res.netloc or not "." in res.netloc:
375
+ return False
376
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
377
+ if not res.scheme or not res.netloc or not "." in res.netloc:
378
+ return False
379
+ except:
380
+ return False
381
+ return True
382
+
383
+
384
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
385
+ """Download the given URL and return a binary-mode file object to access the data."""
386
+ assert num_attempts >= 1
387
+ assert not (return_filename and (not cache))
388
+
389
+ # Doesn't look like an URL scheme so interpret it as a local filename.
390
+ if not re.match('^[a-z]+://', url):
391
+ return url if return_filename else open(url, "rb")
392
+
393
+ # Handle file URLs. This code handles unusual file:// patterns that
394
+ # arise on Windows:
395
+ #
396
+ # file:///c:/foo.txt
397
+ #
398
+ # which would translate to a local '/c:/foo.txt' filename that's
399
+ # invalid. Drop the forward slash for such pathnames.
400
+ #
401
+ # If you touch this code path, you should test it on both Linux and
402
+ # Windows.
403
+ #
404
+ # Some internet resources suggest using urllib.request.url2pathname() but
405
+ # but that converts forward slashes to backslashes and this causes
406
+ # its own set of problems.
407
+ if url.startswith('file://'):
408
+ filename = urllib.parse.urlparse(url).path
409
+ if re.match(r'^/[a-zA-Z]:', filename):
410
+ filename = filename[1:]
411
+ return filename if return_filename else open(filename, "rb")
412
+
413
+ assert is_url(url)
414
+
415
+ # Lookup from cache.
416
+ if cache_dir is None:
417
+ cache_dir = make_cache_dir_path('downloads')
418
+
419
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
420
+ if cache:
421
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
422
+ if len(cache_files) == 1:
423
+ filename = cache_files[0]
424
+ return filename if return_filename else open(filename, "rb")
425
+
426
+ # Download.
427
+ url_name = None
428
+ url_data = None
429
+ with requests.Session() as session:
430
+ if verbose:
431
+ print("Downloading %s ..." % url, end="", flush=True)
432
+ for attempts_left in reversed(range(num_attempts)):
433
+ try:
434
+ with session.get(url) as res:
435
+ res.raise_for_status()
436
+ if len(res.content) == 0:
437
+ raise IOError("No data received")
438
+
439
+ if len(res.content) < 8192:
440
+ content_str = res.content.decode("utf-8")
441
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
442
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
443
+ if len(links) == 1:
444
+ url = requests.compat.urljoin(url, links[0])
445
+ raise IOError("Google Drive virus checker nag")
446
+ if "Google Drive - Quota exceeded" in content_str:
447
+ raise IOError("Google Drive download quota exceeded -- please try again later")
448
+
449
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
450
+ url_name = match[1] if match else url
451
+ url_data = res.content
452
+ if verbose:
453
+ print(" done")
454
+ break
455
+ except KeyboardInterrupt:
456
+ raise
457
+ except:
458
+ if not attempts_left:
459
+ if verbose:
460
+ print(" failed")
461
+ raise
462
+ if verbose:
463
+ print(".", end="", flush=True)
464
+
465
+ # Save to cache.
466
+ if cache:
467
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
468
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
469
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
470
+ os.makedirs(cache_dir, exist_ok=True)
471
+ with open(temp_file, "wb") as f:
472
+ f.write(url_data)
473
+ os.replace(temp_file, cache_file) # atomic
474
+ if return_filename:
475
+ return cache_file
476
+
477
+ # Return data as file object.
478
+ assert not return_filename
479
+ return io.BytesIO(url_data)