kyonimouto commited on
Commit
b779806
·
1 Parent(s): a579d5b

uploading the pretrained pkl, generating script and dnnlib

Browse files
best_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83411361ddbfc11b44ba9192672888f892a306135915a9730695b878d7cd3520
3
+ size 1788704500
dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
dnnlib/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (229 Bytes). View file
 
dnnlib/__pycache__/util.cpython-310.pyc ADDED
Binary file (14 kB). View file
 
dnnlib/gan-metrics/resized256-inception-2015-12-05-37cd2b4e1ddcd3d72062f9bd3f6d4738.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:159911c1553a35fd0bee62832d75db66a9c05ed70b16350a52ba86756d8db9ca
3
+ size 33571138
dnnlib/util.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
+ class EasyDict(dict):
41
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
+
43
+ def __getattr__(self, name: str) -> Any:
44
+ try:
45
+ return self[name]
46
+ except KeyError:
47
+ raise AttributeError(name)
48
+
49
+ def __setattr__(self, name: str, value: Any) -> None:
50
+ self[name] = value
51
+
52
+ def __delattr__(self, name: str) -> None:
53
+ del self[name]
54
+
55
+
56
+ class Logger(object):
57
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
58
+
59
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
60
+ self.file = None
61
+
62
+ if file_name is not None:
63
+ self.file = open(file_name, file_mode)
64
+
65
+ self.should_flush = should_flush
66
+ self.stdout = sys.stdout
67
+ self.stderr = sys.stderr
68
+
69
+ sys.stdout = self
70
+ sys.stderr = self
71
+
72
+ def __enter__(self) -> "Logger":
73
+ return self
74
+
75
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
76
+ self.close()
77
+
78
+ def write(self, text: Union[str, bytes]) -> None:
79
+ """Write text to stdout (and a file) and optionally flush."""
80
+ if isinstance(text, bytes):
81
+ text = text.decode()
82
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
83
+ return
84
+
85
+ if self.file is not None:
86
+ self.file.write(text)
87
+
88
+ self.stdout.write(text)
89
+
90
+ if self.should_flush:
91
+ self.flush()
92
+
93
+ def flush(self) -> None:
94
+ """Flush written text to both stdout and a file, if open."""
95
+ if self.file is not None:
96
+ self.file.flush()
97
+
98
+ self.stdout.flush()
99
+
100
+ def close(self) -> None:
101
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
102
+ self.flush()
103
+
104
+ # if using multiple loggers, prevent closing in wrong order
105
+ if sys.stdout is self:
106
+ sys.stdout = self.stdout
107
+ if sys.stderr is self:
108
+ sys.stderr = self.stderr
109
+
110
+ if self.file is not None:
111
+ self.file.close()
112
+ self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
+ _dnnlib_cache_dir = None
119
+
120
+ def set_cache_dir(path: str) -> None:
121
+ global _dnnlib_cache_dir
122
+ _dnnlib_cache_dir = path
123
+
124
+ def make_cache_dir_path(*paths: str) -> str:
125
+ if _dnnlib_cache_dir is not None:
126
+ return os.path.join(_dnnlib_cache_dir, *paths)
127
+ if 'DNNLIB_CACHE_DIR' in os.environ:
128
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
+ if 'HOME' in os.environ:
130
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
+ if 'USERPROFILE' in os.environ:
132
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
+ def format_time(seconds: Union[int, float]) -> str:
140
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
+ s = int(np.rint(seconds))
142
+
143
+ if s < 60:
144
+ return "{0}s".format(s)
145
+ elif s < 60 * 60:
146
+ return "{0}m {1:02}s".format(s // 60, s % 60)
147
+ elif s < 24 * 60 * 60:
148
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
+ else:
150
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
+
152
+
153
+ def format_time_brief(seconds: Union[int, float]) -> str:
154
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
155
+ s = int(np.rint(seconds))
156
+
157
+ if s < 60:
158
+ return "{0}s".format(s)
159
+ elif s < 60 * 60:
160
+ return "{0}m {1:02}s".format(s // 60, s % 60)
161
+ elif s < 24 * 60 * 60:
162
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
163
+ else:
164
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
165
+
166
+
167
+ def ask_yes_no(question: str) -> bool:
168
+ """Ask the user the question until the user inputs a valid answer."""
169
+ while True:
170
+ try:
171
+ print("{0} [y/n]".format(question))
172
+ return strtobool(input().lower())
173
+ except ValueError:
174
+ pass
175
+
176
+
177
+ def tuple_product(t: Tuple) -> Any:
178
+ """Calculate the product of the tuple elements."""
179
+ result = 1
180
+
181
+ for v in t:
182
+ result *= v
183
+
184
+ return result
185
+
186
+
187
+ _str_to_ctype = {
188
+ "uint8": ctypes.c_ubyte,
189
+ "uint16": ctypes.c_uint16,
190
+ "uint32": ctypes.c_uint32,
191
+ "uint64": ctypes.c_uint64,
192
+ "int8": ctypes.c_byte,
193
+ "int16": ctypes.c_int16,
194
+ "int32": ctypes.c_int32,
195
+ "int64": ctypes.c_int64,
196
+ "float32": ctypes.c_float,
197
+ "float64": ctypes.c_double
198
+ }
199
+
200
+
201
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
202
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
203
+ type_str = None
204
+
205
+ if isinstance(type_obj, str):
206
+ type_str = type_obj
207
+ elif hasattr(type_obj, "__name__"):
208
+ type_str = type_obj.__name__
209
+ elif hasattr(type_obj, "name"):
210
+ type_str = type_obj.name
211
+ else:
212
+ raise RuntimeError("Cannot infer type name from input")
213
+
214
+ assert type_str in _str_to_ctype.keys()
215
+
216
+ my_dtype = np.dtype(type_str)
217
+ my_ctype = _str_to_ctype[type_str]
218
+
219
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
220
+
221
+ return my_dtype, my_ctype
222
+
223
+
224
+ def is_pickleable(obj: Any) -> bool:
225
+ try:
226
+ with io.BytesIO() as stream:
227
+ pickle.dump(obj, stream)
228
+ return True
229
+ except:
230
+ return False
231
+
232
+
233
+ # Functionality to import modules/objects by name, and call functions by name
234
+ # ------------------------------------------------------------------------------------------
235
+
236
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
237
+ """Searches for the underlying module behind the name to some python object.
238
+ Returns the module and the object name (original name with module part removed)."""
239
+
240
+ # allow convenience shorthands, substitute them by full names
241
+ obj_name = re.sub("^np.", "numpy.", obj_name)
242
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
243
+
244
+ # list alternatives for (module_name, local_obj_name)
245
+ parts = obj_name.split(".")
246
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
247
+
248
+ # try each alternative in turn
249
+ for module_name, local_obj_name in name_pairs:
250
+ try:
251
+ module = importlib.import_module(module_name) # may raise ImportError
252
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
253
+ return module, local_obj_name
254
+ except:
255
+ pass
256
+
257
+ # maybe some of the modules themselves contain errors?
258
+ for module_name, _local_obj_name in name_pairs:
259
+ try:
260
+ importlib.import_module(module_name) # may raise ImportError
261
+ except ImportError:
262
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
263
+ raise
264
+
265
+ # maybe the requested attribute is missing?
266
+ for module_name, local_obj_name in name_pairs:
267
+ try:
268
+ module = importlib.import_module(module_name) # may raise ImportError
269
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
270
+ except ImportError:
271
+ pass
272
+
273
+ # we are out of luck, but we have no idea why
274
+ raise ImportError(obj_name)
275
+
276
+
277
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
278
+ """Traverses the object name and returns the last (rightmost) python object."""
279
+ if obj_name == '':
280
+ return module
281
+ obj = module
282
+ for part in obj_name.split("."):
283
+ obj = getattr(obj, part)
284
+ return obj
285
+
286
+
287
+ def get_obj_by_name(name: str) -> Any:
288
+ """Finds the python object with the given name."""
289
+ module, obj_name = get_module_from_obj_name(name)
290
+ return get_obj_from_module(module, obj_name)
291
+
292
+
293
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
294
+ """Finds the python object with the given name and calls it as a function."""
295
+ assert func_name is not None
296
+ func_obj = get_obj_by_name(func_name)
297
+ assert callable(func_obj)
298
+ return func_obj(*args, **kwargs)
299
+
300
+
301
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
302
+ """Finds the python class with the given name and constructs it with the given arguments."""
303
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
304
+
305
+
306
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
307
+ """Get the directory path of the module containing the given object name."""
308
+ module, _ = get_module_from_obj_name(obj_name)
309
+ return os.path.dirname(inspect.getfile(module))
310
+
311
+
312
+ def is_top_level_function(obj: Any) -> bool:
313
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
314
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
315
+
316
+
317
+ def get_top_level_function_name(obj: Any) -> str:
318
+ """Return the fully-qualified name of a top-level function."""
319
+ assert is_top_level_function(obj)
320
+ module = obj.__module__
321
+ if module == '__main__':
322
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
323
+ return module + "." + obj.__name__
324
+
325
+
326
+ # File system helpers
327
+ # ------------------------------------------------------------------------------------------
328
+
329
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
330
+ """List all files recursively in a given directory while ignoring given file and directory names.
331
+ Returns list of tuples containing both absolute and relative paths."""
332
+ assert os.path.isdir(dir_path)
333
+ base_name = os.path.basename(os.path.normpath(dir_path))
334
+
335
+ if ignores is None:
336
+ ignores = []
337
+
338
+ result = []
339
+
340
+ for root, dirs, files in os.walk(dir_path, topdown=True):
341
+ for ignore_ in ignores:
342
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
343
+
344
+ # dirs need to be edited in-place
345
+ for d in dirs_to_remove:
346
+ dirs.remove(d)
347
+
348
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
349
+
350
+ absolute_paths = [os.path.join(root, f) for f in files]
351
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
352
+
353
+ if add_base_to_relative:
354
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
355
+
356
+ assert len(absolute_paths) == len(relative_paths)
357
+ result += zip(absolute_paths, relative_paths)
358
+
359
+ return result
360
+
361
+
362
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
363
+ """Takes in a list of tuples of (src, dst) paths and copies files.
364
+ Will create all necessary directories."""
365
+ for file in files:
366
+ target_dir_name = os.path.dirname(file[1])
367
+
368
+ # will create all intermediate-level directories
369
+ if not os.path.exists(target_dir_name):
370
+ os.makedirs(target_dir_name)
371
+
372
+ shutil.copyfile(file[0], file[1])
373
+
374
+
375
+ # URL helpers
376
+ # ------------------------------------------------------------------------------------------
377
+
378
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
379
+ """Determine whether the given object is a valid URL string."""
380
+ if not isinstance(obj, str) or not "://" in obj:
381
+ return False
382
+ if allow_file_urls and obj.startswith('file://'):
383
+ return True
384
+ try:
385
+ res = requests.compat.urlparse(obj)
386
+ if not res.scheme or not res.netloc or not "." in res.netloc:
387
+ return False
388
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
389
+ if not res.scheme or not res.netloc or not "." in res.netloc:
390
+ return False
391
+ except:
392
+ return False
393
+ return True
394
+
395
+
396
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
397
+ """Download the given URL and return a binary-mode file object to access the data."""
398
+ assert num_attempts >= 1
399
+ assert not (return_filename and (not cache))
400
+
401
+ # Doesn't look like an URL scheme so interpret it as a local filename.
402
+ if not re.match('^[a-z]+://', url):
403
+ return url if return_filename else open(url, "rb")
404
+
405
+ # Handle file URLs. This code handles unusual file:// patterns that
406
+ # arise on Windows:
407
+ #
408
+ # file:///c:/foo.txt
409
+ #
410
+ # which would translate to a local '/c:/foo.txt' filename that's
411
+ # invalid. Drop the forward slash for such pathnames.
412
+ #
413
+ # If you touch this code path, you should test it on both Linux and
414
+ # Windows.
415
+ #
416
+ # Some internet resources suggest using urllib.request.url2pathname() but
417
+ # but that converts forward slashes to backslashes and this causes
418
+ # its own set of problems.
419
+ if url.startswith('file://'):
420
+ filename = urllib.parse.urlparse(url).path
421
+ if re.match(r'^/[a-zA-Z]:', filename):
422
+ filename = filename[1:]
423
+ return filename if return_filename else open(filename, "rb")
424
+
425
+ assert is_url(url)
426
+
427
+ # Lookup from cache.
428
+ if cache_dir is None:
429
+ cache_dir = make_cache_dir_path('downloads')
430
+
431
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
432
+ if cache:
433
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
434
+ if len(cache_files) == 1:
435
+ filename = cache_files[0]
436
+ return filename if return_filename else open(filename, "rb")
437
+
438
+ # Download.
439
+ url_name = None
440
+ url_data = None
441
+ with requests.Session() as session:
442
+ if verbose:
443
+ print("Downloading %s ..." % url, end="", flush=True)
444
+ for attempts_left in reversed(range(num_attempts)):
445
+ try:
446
+ with session.get(url) as res:
447
+ res.raise_for_status()
448
+ if len(res.content) == 0:
449
+ raise IOError("No data received")
450
+
451
+ if len(res.content) < 8192:
452
+ content_str = res.content.decode("utf-8")
453
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
454
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
455
+ if len(links) == 1:
456
+ url = requests.compat.urljoin(url, links[0])
457
+ raise IOError("Google Drive virus checker nag")
458
+ if "Google Drive - Quota exceeded" in content_str:
459
+ raise IOError("Google Drive download quota exceeded -- please try again later")
460
+
461
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
462
+ url_name = match[1] if match else url
463
+ url_data = res.content
464
+ if verbose:
465
+ print(" done")
466
+ break
467
+ except KeyboardInterrupt:
468
+ raise
469
+ except:
470
+ if not attempts_left:
471
+ if verbose:
472
+ print(" failed")
473
+ raise
474
+ if verbose:
475
+ print(".", end="", flush=True)
476
+
477
+ # Save to cache.
478
+ if cache:
479
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
480
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
481
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
482
+ os.makedirs(cache_dir, exist_ok=True)
483
+ with open(temp_file, "wb") as f:
484
+ f.write(url_data)
485
+ os.replace(temp_file, cache_file) # atomic
486
+ if return_filename:
487
+ return cache_file
488
+
489
+ # Return data as file object.
490
+ assert not return_filename
491
+ return io.BytesIO(url_data)
gen_images.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Generate images using pretrained network pickle."""
10
+
11
+ import os
12
+ import re
13
+ from typing import List, Optional, Tuple, Union
14
+
15
+ import click
16
+ import dnnlib
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+
21
+ import legacy
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ def parse_range(s: Union[str, List]) -> List[int]:
26
+ '''Parse a comma separated list of numbers or ranges and return a list of ints.
27
+
28
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
29
+ '''
30
+ if isinstance(s, list): return s
31
+ ranges = []
32
+ range_re = re.compile(r'^(\d+)-(\d+)$')
33
+ for p in s.split(','):
34
+ m = range_re.match(p)
35
+ if m:
36
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
37
+ else:
38
+ ranges.append(int(p))
39
+ return ranges
40
+
41
+ #----------------------------------------------------------------------------
42
+
43
+ def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
44
+ '''Parse a floating point 2-vector of syntax 'a,b'.
45
+
46
+ Example:
47
+ '0,1' returns (0,1)
48
+ '''
49
+ if isinstance(s, tuple): return s
50
+ parts = s.split(',')
51
+ if len(parts) == 2:
52
+ return (float(parts[0]), float(parts[1]))
53
+ raise ValueError(f'cannot parse 2-vector {s}')
54
+
55
+ #----------------------------------------------------------------------------
56
+
57
+ def make_transform(translate: Tuple[float,float], angle: float):
58
+ m = np.eye(3)
59
+ s = np.sin(angle/360.0*np.pi*2)
60
+ c = np.cos(angle/360.0*np.pi*2)
61
+ m[0][0] = c
62
+ m[0][1] = s
63
+ m[0][2] = translate[0]
64
+ m[1][0] = -s
65
+ m[1][1] = c
66
+ m[1][2] = translate[1]
67
+ return m
68
+
69
+ #----------------------------------------------------------------------------
70
+
71
+ @click.command()
72
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
73
+ @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
74
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
75
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
76
+ @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
77
+ @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
78
+ @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
79
+ @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
80
+ def generate_images(
81
+ network_pkl: str,
82
+ seeds: List[int],
83
+ truncation_psi: float,
84
+ noise_mode: str,
85
+ outdir: str,
86
+ translate: Tuple[float,float],
87
+ rotate: float,
88
+ class_idx: Optional[int]
89
+ ):
90
+ """Generate images using pretrained network pickle.
91
+
92
+ Examples:
93
+
94
+ \b
95
+ # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
96
+ python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
97
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
98
+
99
+ \b
100
+ # Generate uncurated images with truncation using the MetFaces-U dataset
101
+ python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
102
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
103
+ """
104
+
105
+ print('Loading networks from "%s"...' % network_pkl)
106
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
107
+ with dnnlib.util.open_url(network_pkl) as f:
108
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
109
+
110
+ os.makedirs(outdir, exist_ok=True)
111
+
112
+ # Labels.
113
+ label = torch.zeros([1, G.c_dim], device=device)
114
+ if G.c_dim != 0:
115
+ if class_idx is None:
116
+ raise click.ClickException('Must specify class label with --class when using a conditional network')
117
+ label[:, class_idx] = 1
118
+ else:
119
+ if class_idx is not None:
120
+ print ('warn: --class=lbl ignored when running on an unconditional network')
121
+
122
+ # Generate images.
123
+ for seed_idx, seed in enumerate(seeds):
124
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
125
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device).float()
126
+
127
+ # Construct an inverse rotation/translation matrix and pass to the generator. The
128
+ # generator expects this matrix as an inverse to avoid potentially failing numerical
129
+ # operations in the network.
130
+ if hasattr(G.synthesis, 'input'):
131
+ m = make_transform(translate, rotate)
132
+ m = np.linalg.inv(m)
133
+ G.synthesis.input.transform.copy_(torch.from_numpy(m))
134
+
135
+ img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
136
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
137
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
138
+
139
+
140
+ #----------------------------------------------------------------------------
141
+
142
+ if __name__ == "__main__":
143
+ generate_images() # pylint: disable=no-value-for-parameter
144
+
145
+ #----------------------------------------------------------------------------