Spaces:
Runtime error
Runtime error
comiit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +27 -0
- .gitignore +0 -8
- .gitmodules +4 -9
- LICENSE +0 -21
- LICENSE-NVIDIA +0 -101
- LICENSE-STYLEGAN2 +0 -21
- README.md +13 -119
- app.py +99 -162
- dnnlib/__init__.py +11 -0
- dnnlib/tflib/__init__.py +20 -0
- dnnlib/tflib/autosummary.py +193 -0
- dnnlib/tflib/custom_ops.py +171 -0
- dnnlib/tflib/network.py +592 -0
- dnnlib/tflib/ops/__init__.py +9 -0
- dnnlib/tflib/ops/fused_bias_act.cu +190 -0
- dnnlib/tflib/ops/fused_bias_act.py +198 -0
- dnnlib/tflib/ops/upfirdn_2d.cu +328 -0
- dnnlib/tflib/ops/upfirdn_2d.py +366 -0
- dnnlib/tflib/optimizer.py +338 -0
- dnnlib/tflib/tfutil.py +254 -0
- dnnlib/util.py +479 -0
- losses/color_transfer_loss.py +0 -60
- losses/joint_loss.py +0 -167
- losses/perceptual_loss.py +0 -111
- losses/reconstruction.py +0 -119
- losses/regularize_noise.py +0 -37
- models/__init__.py +0 -0
- models/degrade.py +0 -122
- models/encoder.py +0 -66
- models/gaussian_smoothing.py +0 -74
- models/resnet.py +0 -99
- models/vggface.py +0 -150
- op/upfirdn2d_kernel.cu +0 -272
- optim/__init__.py +0 -15
- optim/radam.py +0 -250
- requirements.txt +5 -25
- scripts/download_checkpoints.sh +0 -14
- scripts/install.sh +0 -6
- scripts/run.sh +0 -34
- tools/__init__.py +0 -0
- tools/data/__init__.py +0 -0
- tools/data/align_images.py +0 -117
- tools/initialize.py +0 -160
- tools/match_histogram.py +0 -167
- tools/match_skin_histogram.py +0 -67
- tools/parse_face.py +0 -55
- torch_utils/__init__.py +11 -0
- torch_utils/custom_ops.py +238 -0
- torch_utils/misc.py +264 -0
- torch_utils/models.py +756 -0
.gitattributes
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
19 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
@@ -131,11 +131,3 @@ dmypy.json
|
|
131 |
wandb/
|
132 |
*.lmdb/
|
133 |
*.pkl
|
134 |
-
|
135 |
-
# results
|
136 |
-
results
|
137 |
-
results_old
|
138 |
-
log
|
139 |
-
checkpoint
|
140 |
-
*.pt
|
141 |
-
*.old
|
|
|
131 |
wandb/
|
132 |
*.lmdb/
|
133 |
*.pkl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitmodules
CHANGED
@@ -1,9 +1,4 @@
|
|
1 |
-
[submodule "
|
2 |
-
path =
|
3 |
-
url = https://github.com/
|
4 |
-
|
5 |
-
path = models/encoder4editing
|
6 |
-
url = https://github.com/Time-Travel-Rephotography/encoder4editing.git
|
7 |
-
[submodule "losses/contextual_loss"]
|
8 |
-
path = losses/contextual_loss
|
9 |
-
url = https://github.com/Time-Travel-Rephotography/contextual_loss_pytorch.git
|
|
|
1 |
+
[submodule "StyleGAN-Human"]
|
2 |
+
path = StyleGAN-Human
|
3 |
+
url = https://github.com/stylegan-human/StyleGAN-Human
|
4 |
+
|
|
|
|
|
|
|
|
|
|
LICENSE
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
MIT License
|
2 |
-
|
3 |
-
Copyright (c) 2020 Time-Travel-Rephotography
|
4 |
-
|
5 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
-
of this software and associated documentation files (the "Software"), to deal
|
7 |
-
in the Software without restriction, including without limitation the rights
|
8 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
-
copies of the Software, and to permit persons to whom the Software is
|
10 |
-
furnished to do so, subject to the following conditions:
|
11 |
-
|
12 |
-
The above copyright notice and this permission notice shall be included in all
|
13 |
-
copies or substantial portions of the Software.
|
14 |
-
|
15 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
-
SOFTWARE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE-NVIDIA
DELETED
@@ -1,101 +0,0 @@
|
|
1 |
-
Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
-
|
3 |
-
|
4 |
-
Nvidia Source Code License-NC
|
5 |
-
|
6 |
-
=======================================================================
|
7 |
-
|
8 |
-
1. Definitions
|
9 |
-
|
10 |
-
"Licensor" means any person or entity that distributes its Work.
|
11 |
-
|
12 |
-
"Software" means the original work of authorship made available under
|
13 |
-
this License.
|
14 |
-
|
15 |
-
"Work" means the Software and any additions to or derivative works of
|
16 |
-
the Software that are made available under this License.
|
17 |
-
|
18 |
-
"Nvidia Processors" means any central processing unit (CPU), graphics
|
19 |
-
processing unit (GPU), field-programmable gate array (FPGA),
|
20 |
-
application-specific integrated circuit (ASIC) or any combination
|
21 |
-
thereof designed, made, sold, or provided by Nvidia or its affiliates.
|
22 |
-
|
23 |
-
The terms "reproduce," "reproduction," "derivative works," and
|
24 |
-
"distribution" have the meaning as provided under U.S. copyright law;
|
25 |
-
provided, however, that for the purposes of this License, derivative
|
26 |
-
works shall not include works that remain separable from, or merely
|
27 |
-
link (or bind by name) to the interfaces of, the Work.
|
28 |
-
|
29 |
-
Works, including the Software, are "made available" under this License
|
30 |
-
by including in or with the Work either (a) a copyright notice
|
31 |
-
referencing the applicability of this License to the Work, or (b) a
|
32 |
-
copy of this License.
|
33 |
-
|
34 |
-
2. License Grants
|
35 |
-
|
36 |
-
2.1 Copyright Grant. Subject to the terms and conditions of this
|
37 |
-
License, each Licensor grants to you a perpetual, worldwide,
|
38 |
-
non-exclusive, royalty-free, copyright license to reproduce,
|
39 |
-
prepare derivative works of, publicly display, publicly perform,
|
40 |
-
sublicense and distribute its Work and any resulting derivative
|
41 |
-
works in any form.
|
42 |
-
|
43 |
-
3. Limitations
|
44 |
-
|
45 |
-
3.1 Redistribution. You may reproduce or distribute the Work only
|
46 |
-
if (a) you do so under this License, (b) you include a complete
|
47 |
-
copy of this License with your distribution, and (c) you retain
|
48 |
-
without modification any copyright, patent, trademark, or
|
49 |
-
attribution notices that are present in the Work.
|
50 |
-
|
51 |
-
3.2 Derivative Works. You may specify that additional or different
|
52 |
-
terms apply to the use, reproduction, and distribution of your
|
53 |
-
derivative works of the Work ("Your Terms") only if (a) Your Terms
|
54 |
-
provide that the use limitation in Section 3.3 applies to your
|
55 |
-
derivative works, and (b) you identify the specific derivative
|
56 |
-
works that are subject to Your Terms. Notwithstanding Your Terms,
|
57 |
-
this License (including the redistribution requirements in Section
|
58 |
-
3.1) will continue to apply to the Work itself.
|
59 |
-
|
60 |
-
3.3 Use Limitation. The Work and any derivative works thereof only
|
61 |
-
may be used or intended for use non-commercially. The Work or
|
62 |
-
derivative works thereof may be used or intended for use by Nvidia
|
63 |
-
or its affiliates commercially or non-commercially. As used herein,
|
64 |
-
"non-commercially" means for research or evaluation purposes only.
|
65 |
-
|
66 |
-
3.4 Patent Claims. If you bring or threaten to bring a patent claim
|
67 |
-
against any Licensor (including any claim, cross-claim or
|
68 |
-
counterclaim in a lawsuit) to enforce any patents that you allege
|
69 |
-
are infringed by any Work, then your rights under this License from
|
70 |
-
such Licensor (including the grants in Sections 2.1 and 2.2) will
|
71 |
-
terminate immediately.
|
72 |
-
|
73 |
-
3.5 Trademarks. This License does not grant any rights to use any
|
74 |
-
Licensor's or its affiliates' names, logos, or trademarks, except
|
75 |
-
as necessary to reproduce the notices described in this License.
|
76 |
-
|
77 |
-
3.6 Termination. If you violate any term of this License, then your
|
78 |
-
rights under this License (including the grants in Sections 2.1 and
|
79 |
-
2.2) will terminate immediately.
|
80 |
-
|
81 |
-
4. Disclaimer of Warranty.
|
82 |
-
|
83 |
-
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
84 |
-
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
85 |
-
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
|
86 |
-
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
|
87 |
-
THIS LICENSE.
|
88 |
-
|
89 |
-
5. Limitation of Liability.
|
90 |
-
|
91 |
-
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
|
92 |
-
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
|
93 |
-
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
|
94 |
-
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
|
95 |
-
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
|
96 |
-
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
|
97 |
-
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
|
98 |
-
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
|
99 |
-
THE POSSIBILITY OF SUCH DAMAGES.
|
100 |
-
|
101 |
-
=======================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE-STYLEGAN2
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
MIT License
|
2 |
-
|
3 |
-
Copyright (c) 2019 Kim Seonghyeon
|
4 |
-
|
5 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
-
of this software and associated documentation files (the "Software"), to deal
|
7 |
-
in the Software without restriction, including without limitation the rights
|
8 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
-
copies of the Software, and to permit persons to whom the Software is
|
10 |
-
furnished to do so, subject to the following conditions:
|
11 |
-
|
12 |
-
The above copyright notice and this permission notice shall be included in all
|
13 |
-
copies or substantial portions of the Software.
|
14 |
-
|
15 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
-
SOFTWARE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -1,119 +1,13 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
**Time-Travel Rephotography**
|
15 |
-
<br/>
|
16 |
-
[Xuan Luo](https://roxanneluo.github.io),
|
17 |
-
[Xuaner Zhang](https://people.eecs.berkeley.edu/~cecilia77/),
|
18 |
-
[Paul Yoo](https://www.linkedin.com/in/paul-yoo-768a3715b),
|
19 |
-
[Ricardo Martin-Brualla](http://www.ricardomartinbrualla.com/),
|
20 |
-
[Jason Lawrence](http://jasonlawrence.info/), and
|
21 |
-
[Steven M. Seitz](https://homes.cs.washington.edu/~seitz/)
|
22 |
-
<br/>
|
23 |
-
In SIGGRAPH Asia 2021.
|
24 |
-
|
25 |
-
## Demo
|
26 |
-
We provide an easy-to-get-started demo using Google Colab!
|
27 |
-
The Colab will allow you to try our method on the sample Abraham Lincoln photo or **your own photos** using Cloud GPUs on Google Colab.
|
28 |
-
|
29 |
-
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/15D2WIF_vE2l48ddxEx45cM3RykZwQXM8?usp=sharing)
|
30 |
-
|
31 |
-
Or you can run our method on your own machine following the instructions below.
|
32 |
-
|
33 |
-
## Prerequisite
|
34 |
-
- Pull third-party packages.
|
35 |
-
```
|
36 |
-
git submodule update --init --recursive
|
37 |
-
```
|
38 |
-
- Install python packages.
|
39 |
-
```
|
40 |
-
conda create --name rephotography python=3.8.5
|
41 |
-
conda activate rephotography
|
42 |
-
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch
|
43 |
-
pip install -r requirements.txt
|
44 |
-
```
|
45 |
-
|
46 |
-
## Quick Start
|
47 |
-
Run our method on the example photo of Abraham Lincoln.
|
48 |
-
- Download models:
|
49 |
-
```
|
50 |
-
./scripts/download_checkpoints.sh
|
51 |
-
```
|
52 |
-
- Run:
|
53 |
-
```
|
54 |
-
./scripts/run.sh b "dataset/Abraham Lincoln_01.png" 0.75
|
55 |
-
```
|
56 |
-
- You can inspect the optimization process by
|
57 |
-
```
|
58 |
-
tensorboard --logdir "log/Abraham Lincoln_01"
|
59 |
-
```
|
60 |
-
- You can find your results as below.
|
61 |
-
```
|
62 |
-
results/
|
63 |
-
Abraham Lincoln_01/ # intermediate outputs for histogram matching and face parsing
|
64 |
-
Abraham Lincoln_01_b.png # the input after matching the histogram of the sibling image
|
65 |
-
Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750)-init.png # the sibling image
|
66 |
-
Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750)-init.pt # the sibing latent codes and initialized noise maps
|
67 |
-
Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750).png # the output result
|
68 |
-
Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750).pt # the final optimized latent codes and noise maps
|
69 |
-
Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750)-rand.png # the result with the final latent codes but random noise maps
|
70 |
-
|
71 |
-
```
|
72 |
-
|
73 |
-
## Run on Your Own Image
|
74 |
-
- Crop and align the head regions of your images:
|
75 |
-
```
|
76 |
-
python -m tools.data.align_images <input_raw_image_dir> <aligned_image_dir>
|
77 |
-
```
|
78 |
-
- Run:
|
79 |
-
```
|
80 |
-
./scripts/run.sh <spectral_sensitivity> <input_image_path> <blur_radius>
|
81 |
-
```
|
82 |
-
The `spectral_sensitivity` can be `b` (blue-sensitive), `gb` (orthochromatic), or `g` (panchromatic). You can roughly estimate the `spectral_sensitivity` of your photo as follows. Use the *blue-sensitive* model for photos before 1873, manually select between blue-sensitive and *orthochromatic* for images from 1873 to 1906 and among all models for photos taken afterwards.
|
83 |
-
|
84 |
-
The `blur_radius` is the estimated gaussian blur radius in pixels if the input photot is resized to 1024x1024.
|
85 |
-
|
86 |
-
## Historical Wiki Face Dataset
|
87 |
-
| Path | Size | Description |
|
88 |
-
|----------- | ----------- | ----------- |
|
89 |
-
| [Historical Wiki Face Dataset.zip](https://drive.google.com/open?id=1mgC2U7quhKSz_lTL97M-0cPrIILTiUCE&authuser=xuanluo%40cs.washington.edu&usp=drive_fs)| 148 MB | Images|
|
90 |
-
| [spectral_sensitivity.json](https://drive.google.com/open?id=1n3Bqd8G0g-wNpshlgoZiOMXxLlOycAXr&authuser=xuanluo%40cs.washington.edu&usp=drive_fs)| 6 KB | Spectral sensitivity (`b`, `gb`, or `g`). |
|
91 |
-
| [blur_radius.json](https://drive.google.com/open?id=1n4vUsbQo2BcxtKVMGfD1wFHaINzEmAVP&authuser=xuanluo%40cs.washington.edu&usp=drive_fs)| 6 KB | Blur radius in pixels|
|
92 |
-
|
93 |
-
The `json`s are dictionares that map input names to the corresponding spectral sensitivity or blur radius.
|
94 |
-
Due to copyright constraints, `Historical Wiki Face Dataset.zip` contains all images in the *Historical Wiki Face Dataset* that were used in our user study except the photo of [Mao Zedong](https://en.wikipedia.org/wiki/File:Mao_Zedong_in_1959_%28cropped%29.jpg). You can download it separately and crop it as [above](#run-on-your-own-image).
|
95 |
-
|
96 |
-
## Citation
|
97 |
-
If you find our code useful, please consider citing our paper:
|
98 |
-
```
|
99 |
-
@article{Luo-Rephotography-2021,
|
100 |
-
author = {Luo, Xuan and Zhang, Xuaner and Yoo, Paul and Martin-Brualla, Ricardo and Lawrence, Jason and Seitz, Steven M.},
|
101 |
-
title = {Time-Travel Rephotography},
|
102 |
-
journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH Asia 2021)},
|
103 |
-
publisher = {ACM New York, NY, USA},
|
104 |
-
volume = {40},
|
105 |
-
number = {6},
|
106 |
-
articleno = {213},
|
107 |
-
doi = {https://doi.org/10.1145/3478513.3480485},
|
108 |
-
year = {2021},
|
109 |
-
month = {12}
|
110 |
-
}
|
111 |
-
```
|
112 |
-
|
113 |
-
## License
|
114 |
-
This work is licensed under MIT License. See [LICENSE](LICENSE) for details.
|
115 |
-
|
116 |
-
Codes for the StyleGAN2 model come from [https://github.com/rosinality/stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch).
|
117 |
-
|
118 |
-
## Acknowledgments
|
119 |
-
We thank [Nick Brandreth](https://www.nickbrandreth.com/) for capturing the dry plate photos. We thank Bo Zhang, Qingnan Fan, Roy Or-El, Aleksander Holynski and Keunhong Park for insightful advice.
|
|
|
1 |
+
---
|
2 |
+
title: Time TravelRephotography
|
3 |
+
emoji: 🦀
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 2.9.4
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,172 +1,109 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
2 |
import os
|
3 |
-
|
4 |
-
import random
|
5 |
import sys
|
6 |
-
from typing import (
|
7 |
-
Iterable,
|
8 |
-
Optional,
|
9 |
-
)
|
10 |
|
11 |
-
import
|
12 |
import numpy as np
|
13 |
-
from PIL import Image
|
14 |
import torch
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
Grayscale,
|
19 |
-
Resize,
|
20 |
-
ToTensor,
|
21 |
-
Normalize,
|
22 |
-
)
|
23 |
-
|
24 |
-
from losses.joint_loss import JointLoss
|
25 |
-
from model import Generator
|
26 |
-
from tools.initialize import Initializer
|
27 |
-
from tools.match_skin_histogram import match_skin_histogram
|
28 |
-
from utils.projector_arguments import ProjectorArguments
|
29 |
-
from utils import torch_helpers as th
|
30 |
-
from utils.torch_helpers import make_image
|
31 |
-
from utils.misc import stem
|
32 |
-
from utils.optimize import Optimizer
|
33 |
-
from models.degrade import (
|
34 |
-
Degrade,
|
35 |
-
Downsample,
|
36 |
-
)
|
37 |
-
|
38 |
-
|
39 |
-
def set_random_seed(seed: int):
|
40 |
-
# FIXME (xuanluo): this setup still allows randomness somehow
|
41 |
-
torch.manual_seed(seed)
|
42 |
-
random.seed(seed)
|
43 |
-
np.random.seed(seed)
|
44 |
-
|
45 |
-
|
46 |
-
def read_images(paths: str, max_size: Optional[int] = None):
|
47 |
-
transform = Compose(
|
48 |
-
[
|
49 |
-
Grayscale(),
|
50 |
-
ToTensor(),
|
51 |
-
]
|
52 |
-
)
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
imgs_rand: Optional[torch.Tensor] = None,
|
84 |
-
):
|
85 |
-
assert len(path_prefixes) == len(imgs) and len(latents) == len(path_prefixes)
|
86 |
-
if imgs_rand is not None:
|
87 |
-
assert len(imgs) == len(imgs_rand)
|
88 |
-
imgs_arr = make_image(imgs)
|
89 |
-
for path_prefix, img, latent, noise in zip(path_prefixes, imgs_arr, latents, noises):
|
90 |
-
os.makedirs(os.path.dirname(path_prefix), exist_ok=True)
|
91 |
-
cv2.imwrite(path_prefix + ".png", img[...,::-1])
|
92 |
-
torch.save({"latent": latent.detach().cpu(), "noise": noise.detach().cpu()},
|
93 |
-
path_prefix + ".pt")
|
94 |
-
|
95 |
-
if imgs_rand is not None:
|
96 |
-
imgs_arr = make_image(imgs_rand)
|
97 |
-
for path_prefix, img in zip(path_prefixes, imgs_arr):
|
98 |
-
cv2.imwrite(path_prefix + "-rand.png", img[...,::-1])
|
99 |
-
|
100 |
-
|
101 |
-
def main(args):
|
102 |
-
opt_str = ProjectorArguments.to_string(args)
|
103 |
-
print(opt_str)
|
104 |
-
|
105 |
-
if args.rand_seed is not None:
|
106 |
-
set_random_seed(args.rand_seed)
|
107 |
-
device = th.device()
|
108 |
-
|
109 |
-
# read inputs. TODO imgs_orig has channel 1
|
110 |
-
imgs_orig = read_images([args.input], max_size=args.generator_size).to(device)
|
111 |
-
imgs = normalize(imgs_orig) # actually this will be overwritten by the histogram matching result
|
112 |
-
|
113 |
-
# initialize
|
114 |
-
with torch.no_grad():
|
115 |
-
init = Initializer(args).to(device)
|
116 |
-
latent_init = init(imgs_orig)
|
117 |
-
|
118 |
-
# create generator
|
119 |
-
generator = create_generator(args, device)
|
120 |
-
|
121 |
-
# init noises
|
122 |
-
with torch.no_grad():
|
123 |
-
noises_init = generator.make_noise()
|
124 |
-
|
125 |
-
# create a new input by matching the input's histogram to the sibling image
|
126 |
-
with torch.no_grad():
|
127 |
-
sibling, _, sibling_rgbs = generator([latent_init], input_is_latent=True, noise=noises_init)
|
128 |
-
mh_dir = pjoin(args.results_dir, stem(args.input))
|
129 |
-
imgs = match_skin_histogram(
|
130 |
-
imgs, sibling,
|
131 |
-
args.spectral_sensitivity,
|
132 |
-
pjoin(mh_dir, "input_sibling"),
|
133 |
-
pjoin(mh_dir, "skin_mask"),
|
134 |
-
matched_hist_fn=mh_dir.rstrip(os.sep) + f"_{args.spectral_sensitivity}.png",
|
135 |
-
normalize=normalize,
|
136 |
-
).to(device)
|
137 |
-
torch.cuda.empty_cache()
|
138 |
-
# TODO imgs has channel 3
|
139 |
-
|
140 |
-
degrade = Degrade(args).to(device)
|
141 |
-
|
142 |
-
rgb_levels = generator.get_latent_size(args.coarse_min) // 2 + len(args.wplus_step) - 1
|
143 |
-
criterion = JointLoss(
|
144 |
-
args, imgs,
|
145 |
-
sibling=sibling.detach(), sibling_rgbs=sibling_rgbs[:rgb_levels]).to(device)
|
146 |
-
|
147 |
-
# save initialization
|
148 |
-
save(
|
149 |
-
[pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}-init")],
|
150 |
-
sibling, latent_init, noises_init,
|
151 |
-
)
|
152 |
|
153 |
-
writer = SummaryWriter(pjoin(args.log_dir, f"{stem(args.input)}/{opt_str}"))
|
154 |
-
# start optimize
|
155 |
-
latent, noises = Optimizer.optimize(generator, criterion, degrade, imgs, latent_init, noises_init, args, writer=writer)
|
156 |
-
|
157 |
-
# generate output
|
158 |
-
img_out, _, _ = generator([latent], input_is_latent=True, noise=noises)
|
159 |
-
img_out_rand_noise, _, _ = generator([latent], input_is_latent=True)
|
160 |
-
# save output
|
161 |
-
save(
|
162 |
-
[pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}")],
|
163 |
-
img_out, latent, noises,
|
164 |
-
imgs_rand=img_out_rand_noise
|
165 |
-
)
|
166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
-
def parse_args():
|
169 |
-
return ProjectorArguments().parse()
|
170 |
|
171 |
-
if __name__ ==
|
172 |
-
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import functools
|
7 |
import os
|
8 |
+
import pickle
|
|
|
9 |
import sys
|
|
|
|
|
|
|
|
|
10 |
|
11 |
+
import gradio as gr
|
12 |
import numpy as np
|
|
|
13 |
import torch
|
14 |
+
import torch_utils
|
15 |
+
import torch.nn as nn
|
16 |
+
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
sys.path.insert(0, 'StyleGAN-Human')
|
19 |
+
|
20 |
+
TITLE = 'StyleGAN-Human'
|
21 |
+
DESCRIPTION = '''This is an unofficial demo for https://github.com/stylegan-human/StyleGAN-Human.
|
22 |
+
Expected execution time on Hugging Face Spaces: 0.8s
|
23 |
+
Related App: [StyleGAN-Human (Interpolation)](https://huggingface.co/spaces/hysts/StyleGAN-Human-Interpolation)
|
24 |
+
'''
|
25 |
+
ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.stylegan-human" alt="visitor badge"/></center>'
|
26 |
+
|
27 |
+
TOKEN = "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"
|
28 |
+
|
29 |
+
|
30 |
+
def parse_args() -> argparse.Namespace:
|
31 |
+
parser = argparse.ArgumentParser()
|
32 |
+
parser.add_argument('--device', type=str, default='cpu')
|
33 |
+
parser.add_argument('--theme', type=str)
|
34 |
+
parser.add_argument('--live', action='store_true')
|
35 |
+
parser.add_argument('--share', action='store_true')
|
36 |
+
parser.add_argument('--port', type=int)
|
37 |
+
parser.add_argument('--disable-queue',
|
38 |
+
dest='enable_queue',
|
39 |
+
action='store_false')
|
40 |
+
parser.add_argument('--allow-flagging', type=str, default='never')
|
41 |
+
return parser.parse_args()
|
42 |
+
|
43 |
+
|
44 |
+
def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
|
45 |
+
return torch.from_numpy(np.random.RandomState(seed).randn(
|
46 |
+
1, z_dim)).to(device).float()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
+
@torch.inference_mode()
|
50 |
+
def generate_image(seed: int, truncation_psi: float, model: nn.Module,
|
51 |
+
device: torch.device) -> np.ndarray:
|
52 |
+
seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
|
53 |
+
|
54 |
+
z = generate_z(model.z_dim, seed, device)
|
55 |
+
label = torch.zeros([1, model.c_dim], device=device)
|
56 |
+
|
57 |
+
out = model(z, label, truncation_psi=truncation_psi, force_fp32=True)
|
58 |
+
out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
59 |
+
return out[0].cpu().numpy()
|
60 |
+
|
61 |
+
|
62 |
+
def load_model(file_name: str, device: torch.device) -> nn.Module:
|
63 |
+
path = hf_hub_download('feng2022/Time-TravelRephotography',
|
64 |
+
f'{file_name}',
|
65 |
+
use_auth_token=TOKEN)
|
66 |
+
with open(path, 'rb') as f:
|
67 |
+
model = pickle.load(f)['G_ema']
|
68 |
+
model.eval()
|
69 |
+
model.to(device)
|
70 |
+
with torch.inference_mode():
|
71 |
+
z = torch.zeros((1, model.z_dim)).to(device)
|
72 |
+
label = torch.zeros([1, model.c_dim], device=device)
|
73 |
+
model(z, label, force_fp32=True)
|
74 |
+
return model
|
75 |
+
|
76 |
+
|
77 |
+
def main():
|
78 |
+
args = parse_args()
|
79 |
+
device = torch.device(args.device)
|
80 |
+
|
81 |
+
model = load_model('stylegan_human_v2_1024.pkl', device)
|
82 |
+
|
83 |
+
func = functools.partial(generate_image, model=model, device=device)
|
84 |
+
func = functools.update_wrapper(func, generate_image)
|
85 |
+
|
86 |
+
gr.Interface(
|
87 |
+
func,
|
88 |
+
[
|
89 |
+
gr.inputs.Number(default=0, label='Seed'),
|
90 |
+
gr.inputs.Slider(
|
91 |
+
0, 2, step=0.05, default=0.7, label='Truncation psi'),
|
92 |
+
],
|
93 |
+
gr.outputs.Image(type='numpy', label='Output'),
|
94 |
+
title=TITLE,
|
95 |
+
description=DESCRIPTION,
|
96 |
+
article=ARTICLE,
|
97 |
+
theme=args.theme,
|
98 |
+
allow_flagging=args.allow_flagging,
|
99 |
+
live=args.live,
|
100 |
+
).launch(
|
101 |
+
enable_queue=args.enable_queue,
|
102 |
+
server_port=args.port,
|
103 |
+
share=args.share,
|
104 |
+
)
|
105 |
|
|
|
|
|
106 |
|
107 |
+
if __name__ == '__main__':
|
108 |
+
main()
|
109 |
+
|
dnnlib/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
+
# and proprietary rights in and to this software, related documentation
|
7 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
+
# distribution of this software and related documentation without an express
|
9 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
+
|
11 |
+
from .util import EasyDict, make_cache_dir_path
|
dnnlib/tflib/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
+
#
|
5 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
6 |
+
# To view a copy of this license, visit
|
7 |
+
# https://nvlabs.github.io/stylegan2/license.html
|
8 |
+
|
9 |
+
from . import autosummary
|
10 |
+
from . import network
|
11 |
+
from . import optimizer
|
12 |
+
from . import tfutil
|
13 |
+
from . import custom_ops
|
14 |
+
|
15 |
+
from .tfutil import *
|
16 |
+
from .network import Network
|
17 |
+
|
18 |
+
from .optimizer import Optimizer
|
19 |
+
|
20 |
+
from .custom_ops import get_plugin
|
dnnlib/tflib/autosummary.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
+
#
|
5 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
6 |
+
# To view a copy of this license, visit
|
7 |
+
# https://nvlabs.github.io/stylegan2/license.html
|
8 |
+
|
9 |
+
"""Helper for adding automatically tracked values to Tensorboard.
|
10 |
+
|
11 |
+
Autosummary creates an identity op that internally keeps track of the input
|
12 |
+
values and automatically shows up in TensorBoard. The reported value
|
13 |
+
represents an average over input components. The average is accumulated
|
14 |
+
constantly over time and flushed when save_summaries() is called.
|
15 |
+
|
16 |
+
Notes:
|
17 |
+
- The output tensor must be used as an input for something else in the
|
18 |
+
graph. Otherwise, the autosummary op will not get executed, and the average
|
19 |
+
value will not get accumulated.
|
20 |
+
- It is perfectly fine to include autosummaries with the same name in
|
21 |
+
several places throughout the graph, even if they are executed concurrently.
|
22 |
+
- It is ok to also pass in a python scalar or numpy array. In this case, it
|
23 |
+
is added to the average immediately.
|
24 |
+
"""
|
25 |
+
|
26 |
+
from collections import OrderedDict
|
27 |
+
import numpy as np
|
28 |
+
import tensorflow as tf
|
29 |
+
from tensorboard import summary as summary_lib
|
30 |
+
from tensorboard.plugins.custom_scalar import layout_pb2
|
31 |
+
|
32 |
+
from . import tfutil
|
33 |
+
from .tfutil import TfExpression
|
34 |
+
from .tfutil import TfExpressionEx
|
35 |
+
|
36 |
+
# Enable "Custom scalars" tab in TensorBoard for advanced formatting.
|
37 |
+
# Disabled by default to reduce tfevents file size.
|
38 |
+
enable_custom_scalars = False
|
39 |
+
|
40 |
+
_dtype = tf.float64
|
41 |
+
_vars = OrderedDict() # name => [var, ...]
|
42 |
+
_immediate = OrderedDict() # name => update_op, update_value
|
43 |
+
_finalized = False
|
44 |
+
_merge_op = None
|
45 |
+
|
46 |
+
|
47 |
+
def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
|
48 |
+
"""Internal helper for creating autosummary accumulators."""
|
49 |
+
assert not _finalized
|
50 |
+
name_id = name.replace("/", "_")
|
51 |
+
v = tf.cast(value_expr, _dtype)
|
52 |
+
|
53 |
+
if v.shape.is_fully_defined():
|
54 |
+
size = np.prod(v.shape.as_list())
|
55 |
+
size_expr = tf.constant(size, dtype=_dtype)
|
56 |
+
else:
|
57 |
+
size = None
|
58 |
+
size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
|
59 |
+
|
60 |
+
if size == 1:
|
61 |
+
if v.shape.ndims != 0:
|
62 |
+
v = tf.reshape(v, [])
|
63 |
+
v = [size_expr, v, tf.square(v)]
|
64 |
+
else:
|
65 |
+
v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
|
66 |
+
v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
|
67 |
+
|
68 |
+
with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
|
69 |
+
var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)]
|
70 |
+
update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
|
71 |
+
|
72 |
+
if name in _vars:
|
73 |
+
_vars[name].append(var)
|
74 |
+
else:
|
75 |
+
_vars[name] = [var]
|
76 |
+
return update_op
|
77 |
+
|
78 |
+
|
79 |
+
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
|
80 |
+
"""Create a new autosummary.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
name: Name to use in TensorBoard
|
84 |
+
value: TensorFlow expression or python value to track
|
85 |
+
passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
|
86 |
+
|
87 |
+
Example use of the passthru mechanism:
|
88 |
+
|
89 |
+
n = autosummary('l2loss', loss, passthru=n)
|
90 |
+
|
91 |
+
This is a shorthand for the following code:
|
92 |
+
|
93 |
+
with tf.control_dependencies([autosummary('l2loss', loss)]):
|
94 |
+
n = tf.identity(n)
|
95 |
+
"""
|
96 |
+
tfutil.assert_tf_initialized()
|
97 |
+
name_id = name.replace("/", "_")
|
98 |
+
|
99 |
+
if tfutil.is_tf_expression(value):
|
100 |
+
with tf.name_scope("summary_" + name_id), tf.device(value.device):
|
101 |
+
condition = tf.convert_to_tensor(condition, name='condition')
|
102 |
+
update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
|
103 |
+
with tf.control_dependencies([update_op]):
|
104 |
+
return tf.identity(value if passthru is None else passthru)
|
105 |
+
|
106 |
+
else: # python scalar or numpy array
|
107 |
+
assert not tfutil.is_tf_expression(passthru)
|
108 |
+
assert not tfutil.is_tf_expression(condition)
|
109 |
+
if condition:
|
110 |
+
if name not in _immediate:
|
111 |
+
with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
|
112 |
+
update_value = tf.placeholder(_dtype)
|
113 |
+
update_op = _create_var(name, update_value)
|
114 |
+
_immediate[name] = update_op, update_value
|
115 |
+
update_op, update_value = _immediate[name]
|
116 |
+
tfutil.run(update_op, {update_value: value})
|
117 |
+
return value if passthru is None else passthru
|
118 |
+
|
119 |
+
|
120 |
+
def finalize_autosummaries() -> None:
|
121 |
+
"""Create the necessary ops to include autosummaries in TensorBoard report.
|
122 |
+
Note: This should be done only once per graph.
|
123 |
+
"""
|
124 |
+
global _finalized
|
125 |
+
tfutil.assert_tf_initialized()
|
126 |
+
|
127 |
+
if _finalized:
|
128 |
+
return None
|
129 |
+
|
130 |
+
_finalized = True
|
131 |
+
tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
|
132 |
+
|
133 |
+
# Create summary ops.
|
134 |
+
with tf.device(None), tf.control_dependencies(None):
|
135 |
+
for name, vars_list in _vars.items():
|
136 |
+
name_id = name.replace("/", "_")
|
137 |
+
with tfutil.absolute_name_scope("Autosummary/" + name_id):
|
138 |
+
moments = tf.add_n(vars_list)
|
139 |
+
moments /= moments[0]
|
140 |
+
with tf.control_dependencies([moments]): # read before resetting
|
141 |
+
reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
|
142 |
+
with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
|
143 |
+
mean = moments[1]
|
144 |
+
std = tf.sqrt(moments[2] - tf.square(moments[1]))
|
145 |
+
tf.summary.scalar(name, mean)
|
146 |
+
if enable_custom_scalars:
|
147 |
+
tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
|
148 |
+
tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
|
149 |
+
|
150 |
+
# Setup layout for custom scalars.
|
151 |
+
layout = None
|
152 |
+
if enable_custom_scalars:
|
153 |
+
cat_dict = OrderedDict()
|
154 |
+
for series_name in sorted(_vars.keys()):
|
155 |
+
p = series_name.split("/")
|
156 |
+
cat = p[0] if len(p) >= 2 else ""
|
157 |
+
chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
|
158 |
+
if cat not in cat_dict:
|
159 |
+
cat_dict[cat] = OrderedDict()
|
160 |
+
if chart not in cat_dict[cat]:
|
161 |
+
cat_dict[cat][chart] = []
|
162 |
+
cat_dict[cat][chart].append(series_name)
|
163 |
+
categories = []
|
164 |
+
for cat_name, chart_dict in cat_dict.items():
|
165 |
+
charts = []
|
166 |
+
for chart_name, series_names in chart_dict.items():
|
167 |
+
series = []
|
168 |
+
for series_name in series_names:
|
169 |
+
series.append(layout_pb2.MarginChartContent.Series(
|
170 |
+
value=series_name,
|
171 |
+
lower="xCustomScalars/" + series_name + "/margin_lo",
|
172 |
+
upper="xCustomScalars/" + series_name + "/margin_hi"))
|
173 |
+
margin = layout_pb2.MarginChartContent(series=series)
|
174 |
+
charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
|
175 |
+
categories.append(layout_pb2.Category(title=cat_name, chart=charts))
|
176 |
+
layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
|
177 |
+
return layout
|
178 |
+
|
179 |
+
def save_summaries(file_writer, global_step=None):
|
180 |
+
"""Call FileWriter.add_summary() with all summaries in the default graph,
|
181 |
+
automatically finalizing and merging them on the first call.
|
182 |
+
"""
|
183 |
+
global _merge_op
|
184 |
+
tfutil.assert_tf_initialized()
|
185 |
+
|
186 |
+
if _merge_op is None:
|
187 |
+
layout = finalize_autosummaries()
|
188 |
+
if layout is not None:
|
189 |
+
file_writer.add_summary(layout)
|
190 |
+
with tf.device(None), tf.control_dependencies(None):
|
191 |
+
_merge_op = tf.summary.merge_all()
|
192 |
+
|
193 |
+
file_writer.add_summary(_merge_op.eval(), global_step)
|
dnnlib/tflib/custom_ops.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
+
#
|
5 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
6 |
+
# To view a copy of this license, visit
|
7 |
+
# https://nvlabs.github.io/stylegan2/license.html
|
8 |
+
|
9 |
+
"""TensorFlow custom ops builder.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import os
|
13 |
+
import re
|
14 |
+
import uuid
|
15 |
+
import hashlib
|
16 |
+
import tempfile
|
17 |
+
import shutil
|
18 |
+
import tensorflow as tf
|
19 |
+
from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
|
20 |
+
|
21 |
+
#----------------------------------------------------------------------------
|
22 |
+
# Global options.
|
23 |
+
|
24 |
+
cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache')
|
25 |
+
cuda_cache_version_tag = 'v1'
|
26 |
+
do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe!
|
27 |
+
verbose = True # Print status messages to stdout.
|
28 |
+
|
29 |
+
compiler_bindir_search_path = [
|
30 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64',
|
31 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64',
|
32 |
+
'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin',
|
33 |
+
]
|
34 |
+
|
35 |
+
#----------------------------------------------------------------------------
|
36 |
+
# Internal helper funcs.
|
37 |
+
|
38 |
+
def _find_compiler_bindir():
|
39 |
+
for compiler_path in compiler_bindir_search_path:
|
40 |
+
if os.path.isdir(compiler_path):
|
41 |
+
return compiler_path
|
42 |
+
return None
|
43 |
+
|
44 |
+
def _get_compute_cap(device):
|
45 |
+
caps_str = device.physical_device_desc
|
46 |
+
m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
|
47 |
+
major = m.group(1)
|
48 |
+
minor = m.group(2)
|
49 |
+
return (major, minor)
|
50 |
+
|
51 |
+
def _get_cuda_gpu_arch_string():
|
52 |
+
gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
|
53 |
+
if len(gpus) == 0:
|
54 |
+
raise RuntimeError('No GPU devices found')
|
55 |
+
(major, minor) = _get_compute_cap(gpus[0])
|
56 |
+
return 'sm_%s%s' % (major, minor)
|
57 |
+
|
58 |
+
def _run_cmd(cmd):
|
59 |
+
with os.popen(cmd) as pipe:
|
60 |
+
output = pipe.read()
|
61 |
+
status = pipe.close()
|
62 |
+
if status is not None:
|
63 |
+
raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
|
64 |
+
|
65 |
+
def _prepare_nvcc_cli(opts):
|
66 |
+
cmd = 'nvcc ' + opts.strip()
|
67 |
+
cmd += ' --disable-warnings'
|
68 |
+
cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
|
69 |
+
cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
|
70 |
+
cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl')
|
71 |
+
cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive')
|
72 |
+
|
73 |
+
compiler_bindir = _find_compiler_bindir()
|
74 |
+
if compiler_bindir is None:
|
75 |
+
# Require that _find_compiler_bindir succeeds on Windows. Allow
|
76 |
+
# nvcc to use whatever is the default on Linux.
|
77 |
+
if os.name == 'nt':
|
78 |
+
raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
|
79 |
+
else:
|
80 |
+
cmd += ' --compiler-bindir "%s"' % compiler_bindir
|
81 |
+
cmd += ' 2>&1'
|
82 |
+
return cmd
|
83 |
+
|
84 |
+
#----------------------------------------------------------------------------
|
85 |
+
# Main entry point.
|
86 |
+
|
87 |
+
_plugin_cache = dict()
|
88 |
+
|
89 |
+
def get_plugin(cuda_file):
|
90 |
+
cuda_file_base = os.path.basename(cuda_file)
|
91 |
+
cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)
|
92 |
+
|
93 |
+
# Already in cache?
|
94 |
+
if cuda_file in _plugin_cache:
|
95 |
+
return _plugin_cache[cuda_file]
|
96 |
+
|
97 |
+
# Setup plugin.
|
98 |
+
if verbose:
|
99 |
+
print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
|
100 |
+
try:
|
101 |
+
# Hash CUDA source.
|
102 |
+
md5 = hashlib.md5()
|
103 |
+
with open(cuda_file, 'rb') as f:
|
104 |
+
md5.update(f.read())
|
105 |
+
md5.update(b'\n')
|
106 |
+
|
107 |
+
# Hash headers included by the CUDA code by running it through the preprocessor.
|
108 |
+
if not do_not_hash_included_headers:
|
109 |
+
if verbose:
|
110 |
+
print('Preprocessing... ', end='', flush=True)
|
111 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
112 |
+
tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
|
113 |
+
_run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
|
114 |
+
with open(tmp_file, 'rb') as f:
|
115 |
+
bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
|
116 |
+
good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
|
117 |
+
for ln in f:
|
118 |
+
if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
|
119 |
+
ln = ln.replace(bad_file_str, good_file_str)
|
120 |
+
md5.update(ln)
|
121 |
+
md5.update(b'\n')
|
122 |
+
|
123 |
+
# Select compiler options.
|
124 |
+
compile_opts = ''
|
125 |
+
if os.name == 'nt':
|
126 |
+
compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
|
127 |
+
elif os.name == 'posix':
|
128 |
+
compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so')
|
129 |
+
compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\''
|
130 |
+
else:
|
131 |
+
assert False # not Windows or Linux, w00t?
|
132 |
+
compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string()
|
133 |
+
compile_opts += ' --use_fast_math'
|
134 |
+
nvcc_cmd = _prepare_nvcc_cli(compile_opts)
|
135 |
+
|
136 |
+
# Hash build configuration.
|
137 |
+
md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
|
138 |
+
md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
|
139 |
+
md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')
|
140 |
+
|
141 |
+
# Compile if not already compiled.
|
142 |
+
bin_file_ext = '.dll' if os.name == 'nt' else '.so'
|
143 |
+
bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
|
144 |
+
if not os.path.isfile(bin_file):
|
145 |
+
if verbose:
|
146 |
+
print('Compiling... ', end='', flush=True)
|
147 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
148 |
+
tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
|
149 |
+
_run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
|
150 |
+
os.makedirs(cuda_cache_path, exist_ok=True)
|
151 |
+
intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
|
152 |
+
shutil.copyfile(tmp_file, intermediate_file)
|
153 |
+
os.rename(intermediate_file, bin_file) # atomic
|
154 |
+
|
155 |
+
# Load.
|
156 |
+
if verbose:
|
157 |
+
print('Loading... ', end='', flush=True)
|
158 |
+
plugin = tf.load_op_library(bin_file)
|
159 |
+
|
160 |
+
# Add to cache.
|
161 |
+
_plugin_cache[cuda_file] = plugin
|
162 |
+
if verbose:
|
163 |
+
print('Done.', flush=True)
|
164 |
+
return plugin
|
165 |
+
|
166 |
+
except:
|
167 |
+
if verbose:
|
168 |
+
print('Failed!', flush=True)
|
169 |
+
raise
|
170 |
+
|
171 |
+
#----------------------------------------------------------------------------
|
dnnlib/tflib/network.py
ADDED
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
+
#
|
5 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
6 |
+
# To view a copy of this license, visit
|
7 |
+
# https://nvlabs.github.io/stylegan2/license.html
|
8 |
+
|
9 |
+
"""Helper for managing networks."""
|
10 |
+
|
11 |
+
import types
|
12 |
+
import inspect
|
13 |
+
import re
|
14 |
+
import uuid
|
15 |
+
import sys
|
16 |
+
import numpy as np
|
17 |
+
import tensorflow as tf
|
18 |
+
|
19 |
+
from collections import OrderedDict
|
20 |
+
from typing import Any, List, Tuple, Union
|
21 |
+
|
22 |
+
from . import tfutil
|
23 |
+
from .. import util
|
24 |
+
|
25 |
+
from .tfutil import TfExpression, TfExpressionEx
|
26 |
+
|
27 |
+
_import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
|
28 |
+
_import_module_src = dict() # Source code for temporary modules created during pickle import.
|
29 |
+
|
30 |
+
|
31 |
+
def import_handler(handler_func):
|
32 |
+
"""Function decorator for declaring custom import handlers."""
|
33 |
+
_import_handlers.append(handler_func)
|
34 |
+
return handler_func
|
35 |
+
|
36 |
+
|
37 |
+
class Network:
|
38 |
+
"""Generic network abstraction.
|
39 |
+
|
40 |
+
Acts as a convenience wrapper for a parameterized network construction
|
41 |
+
function, providing several utility methods and convenient access to
|
42 |
+
the inputs/outputs/weights.
|
43 |
+
|
44 |
+
Network objects can be safely pickled and unpickled for long-term
|
45 |
+
archival purposes. The pickling works reliably as long as the underlying
|
46 |
+
network construction function is defined in a standalone Python module
|
47 |
+
that has no side effects or application-specific imports.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
name: Network name. Used to select TensorFlow name and variable scopes.
|
51 |
+
func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
|
52 |
+
static_kwargs: Keyword arguments to be passed in to the network construction function.
|
53 |
+
|
54 |
+
Attributes:
|
55 |
+
name: User-specified name, defaults to build func name if None.
|
56 |
+
scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
|
57 |
+
static_kwargs: Arguments passed to the user-supplied build func.
|
58 |
+
components: Container for sub-networks. Passed to the build func, and retained between calls.
|
59 |
+
num_inputs: Number of input tensors.
|
60 |
+
num_outputs: Number of output tensors.
|
61 |
+
input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
|
62 |
+
output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
|
63 |
+
input_shape: Short-hand for input_shapes[0].
|
64 |
+
output_shape: Short-hand for output_shapes[0].
|
65 |
+
input_templates: Input placeholders in the template graph.
|
66 |
+
output_templates: Output tensors in the template graph.
|
67 |
+
input_names: Name string for each input.
|
68 |
+
output_names: Name string for each output.
|
69 |
+
own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
|
70 |
+
vars: All variables (local_name => var).
|
71 |
+
trainables: All trainable variables (local_name => var).
|
72 |
+
var_global_to_local: Mapping from variable global names to local names.
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
|
76 |
+
tfutil.assert_tf_initialized()
|
77 |
+
assert isinstance(name, str) or name is None
|
78 |
+
assert func_name is not None
|
79 |
+
assert isinstance(func_name, str) or util.is_top_level_function(func_name)
|
80 |
+
assert util.is_pickleable(static_kwargs)
|
81 |
+
|
82 |
+
self._init_fields()
|
83 |
+
self.name = name
|
84 |
+
self.static_kwargs = util.EasyDict(static_kwargs)
|
85 |
+
|
86 |
+
# Locate the user-specified network build function.
|
87 |
+
if util.is_top_level_function(func_name):
|
88 |
+
func_name = util.get_top_level_function_name(func_name)
|
89 |
+
module, self._build_func_name = util.get_module_from_obj_name(func_name)
|
90 |
+
self._build_func = util.get_obj_from_module(module, self._build_func_name)
|
91 |
+
assert callable(self._build_func)
|
92 |
+
|
93 |
+
# Dig up source code for the module containing the build function.
|
94 |
+
self._build_module_src = _import_module_src.get(module, None)
|
95 |
+
if self._build_module_src is None:
|
96 |
+
self._build_module_src = inspect.getsource(module)
|
97 |
+
|
98 |
+
# Init TensorFlow graph.
|
99 |
+
self._init_graph()
|
100 |
+
self.reset_own_vars()
|
101 |
+
|
102 |
+
def _init_fields(self) -> None:
|
103 |
+
self.name = None
|
104 |
+
self.scope = None
|
105 |
+
self.static_kwargs = util.EasyDict()
|
106 |
+
self.components = util.EasyDict()
|
107 |
+
self.num_inputs = 0
|
108 |
+
self.num_outputs = 0
|
109 |
+
self.input_shapes = [[]]
|
110 |
+
self.output_shapes = [[]]
|
111 |
+
self.input_shape = []
|
112 |
+
self.output_shape = []
|
113 |
+
self.input_templates = []
|
114 |
+
self.output_templates = []
|
115 |
+
self.input_names = []
|
116 |
+
self.output_names = []
|
117 |
+
self.own_vars = OrderedDict()
|
118 |
+
self.vars = OrderedDict()
|
119 |
+
self.trainables = OrderedDict()
|
120 |
+
self.var_global_to_local = OrderedDict()
|
121 |
+
|
122 |
+
self._build_func = None # User-supplied build function that constructs the network.
|
123 |
+
self._build_func_name = None # Name of the build function.
|
124 |
+
self._build_module_src = None # Full source code of the module containing the build function.
|
125 |
+
self._run_cache = dict() # Cached graph data for Network.run().
|
126 |
+
|
127 |
+
def _init_graph(self) -> None:
|
128 |
+
# Collect inputs.
|
129 |
+
self.input_names = []
|
130 |
+
|
131 |
+
for param in inspect.signature(self._build_func).parameters.values():
|
132 |
+
if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
|
133 |
+
self.input_names.append(param.name)
|
134 |
+
|
135 |
+
self.num_inputs = len(self.input_names)
|
136 |
+
assert self.num_inputs >= 1
|
137 |
+
|
138 |
+
# Choose name and scope.
|
139 |
+
if self.name is None:
|
140 |
+
self.name = self._build_func_name
|
141 |
+
assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
|
142 |
+
with tf.name_scope(None):
|
143 |
+
self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)
|
144 |
+
|
145 |
+
# Finalize build func kwargs.
|
146 |
+
build_kwargs = dict(self.static_kwargs)
|
147 |
+
build_kwargs["is_template_graph"] = True
|
148 |
+
build_kwargs["components"] = self.components
|
149 |
+
|
150 |
+
# Build template graph.
|
151 |
+
with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes
|
152 |
+
assert tf.get_variable_scope().name == self.scope
|
153 |
+
assert tf.get_default_graph().get_name_scope() == self.scope
|
154 |
+
with tf.control_dependencies(None): # ignore surrounding control dependencies
|
155 |
+
self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
|
156 |
+
out_expr = self._build_func(*self.input_templates, **build_kwargs)
|
157 |
+
|
158 |
+
# Collect outputs.
|
159 |
+
assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
|
160 |
+
self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
|
161 |
+
self.num_outputs = len(self.output_templates)
|
162 |
+
assert self.num_outputs >= 1
|
163 |
+
assert all(tfutil.is_tf_expression(t) for t in self.output_templates)
|
164 |
+
|
165 |
+
# Perform sanity checks.
|
166 |
+
if any(t.shape.ndims is None for t in self.input_templates):
|
167 |
+
raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
|
168 |
+
if any(t.shape.ndims is None for t in self.output_templates):
|
169 |
+
raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
|
170 |
+
if any(not isinstance(comp, Network) for comp in self.components.values()):
|
171 |
+
raise ValueError("Components of a Network must be Networks themselves.")
|
172 |
+
if len(self.components) != len(set(comp.name for comp in self.components.values())):
|
173 |
+
raise ValueError("Components of a Network must have unique names.")
|
174 |
+
|
175 |
+
# List inputs and outputs.
|
176 |
+
self.input_shapes = [t.shape.as_list() for t in self.input_templates]
|
177 |
+
self.output_shapes = [t.shape.as_list() for t in self.output_templates]
|
178 |
+
self.input_shape = self.input_shapes[0]
|
179 |
+
self.output_shape = self.output_shapes[0]
|
180 |
+
self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
|
181 |
+
|
182 |
+
# List variables.
|
183 |
+
self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
|
184 |
+
self.vars = OrderedDict(self.own_vars)
|
185 |
+
self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
|
186 |
+
self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
|
187 |
+
self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
|
188 |
+
|
189 |
+
def reset_own_vars(self) -> None:
|
190 |
+
"""Re-initialize all variables of this network, excluding sub-networks."""
|
191 |
+
tfutil.run([var.initializer for var in self.own_vars.values()])
|
192 |
+
|
193 |
+
def reset_vars(self) -> None:
|
194 |
+
"""Re-initialize all variables of this network, including sub-networks."""
|
195 |
+
tfutil.run([var.initializer for var in self.vars.values()])
|
196 |
+
|
197 |
+
def reset_trainables(self) -> None:
|
198 |
+
"""Re-initialize all trainable variables of this network, including sub-networks."""
|
199 |
+
tfutil.run([var.initializer for var in self.trainables.values()])
|
200 |
+
|
201 |
+
def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
|
202 |
+
"""Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
|
203 |
+
assert len(in_expr) == self.num_inputs
|
204 |
+
assert not all(expr is None for expr in in_expr)
|
205 |
+
|
206 |
+
# Finalize build func kwargs.
|
207 |
+
build_kwargs = dict(self.static_kwargs)
|
208 |
+
build_kwargs.update(dynamic_kwargs)
|
209 |
+
build_kwargs["is_template_graph"] = False
|
210 |
+
build_kwargs["components"] = self.components
|
211 |
+
|
212 |
+
# Build TensorFlow graph to evaluate the network.
|
213 |
+
with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
|
214 |
+
assert tf.get_variable_scope().name == self.scope
|
215 |
+
valid_inputs = [expr for expr in in_expr if expr is not None]
|
216 |
+
final_inputs = []
|
217 |
+
for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
|
218 |
+
if expr is not None:
|
219 |
+
expr = tf.identity(expr, name=name)
|
220 |
+
else:
|
221 |
+
expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
|
222 |
+
final_inputs.append(expr)
|
223 |
+
out_expr = self._build_func(*final_inputs, **build_kwargs)
|
224 |
+
|
225 |
+
# Propagate input shapes back to the user-specified expressions.
|
226 |
+
for expr, final in zip(in_expr, final_inputs):
|
227 |
+
if isinstance(expr, tf.Tensor):
|
228 |
+
expr.set_shape(final.shape)
|
229 |
+
|
230 |
+
# Express outputs in the desired format.
|
231 |
+
assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
|
232 |
+
if return_as_list:
|
233 |
+
out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
|
234 |
+
return out_expr
|
235 |
+
|
236 |
+
def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
|
237 |
+
"""Get the local name of a given variable, without any surrounding name scopes."""
|
238 |
+
assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
|
239 |
+
global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
|
240 |
+
return self.var_global_to_local[global_name]
|
241 |
+
|
242 |
+
def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
|
243 |
+
"""Find variable by local or global name."""
|
244 |
+
assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
|
245 |
+
return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
|
246 |
+
|
247 |
+
def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
|
248 |
+
"""Get the value of a given variable as NumPy array.
|
249 |
+
Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
|
250 |
+
return self.find_var(var_or_local_name).eval()
|
251 |
+
|
252 |
+
def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
|
253 |
+
"""Set the value of a given variable based on the given NumPy array.
|
254 |
+
Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
|
255 |
+
tfutil.set_vars({self.find_var(var_or_local_name): new_value})
|
256 |
+
|
257 |
+
def __getstate__(self) -> dict:
|
258 |
+
"""Pickle export."""
|
259 |
+
state = dict()
|
260 |
+
state["version"] = 4
|
261 |
+
state["name"] = self.name
|
262 |
+
state["static_kwargs"] = dict(self.static_kwargs)
|
263 |
+
state["components"] = dict(self.components)
|
264 |
+
state["build_module_src"] = self._build_module_src
|
265 |
+
state["build_func_name"] = self._build_func_name
|
266 |
+
state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
|
267 |
+
return state
|
268 |
+
|
269 |
+
def __setstate__(self, state: dict) -> None:
|
270 |
+
"""Pickle import."""
|
271 |
+
# pylint: disable=attribute-defined-outside-init
|
272 |
+
tfutil.assert_tf_initialized()
|
273 |
+
self._init_fields()
|
274 |
+
|
275 |
+
# Execute custom import handlers.
|
276 |
+
for handler in _import_handlers:
|
277 |
+
state = handler(state)
|
278 |
+
|
279 |
+
# Set basic fields.
|
280 |
+
assert state["version"] in [2, 3, 4]
|
281 |
+
self.name = state["name"]
|
282 |
+
self.static_kwargs = util.EasyDict(state["static_kwargs"])
|
283 |
+
self.components = util.EasyDict(state.get("components", {}))
|
284 |
+
self._build_module_src = state["build_module_src"]
|
285 |
+
self._build_func_name = state["build_func_name"]
|
286 |
+
|
287 |
+
# Create temporary module from the imported source code.
|
288 |
+
module_name = "_tflib_network_import_" + uuid.uuid4().hex
|
289 |
+
module = types.ModuleType(module_name)
|
290 |
+
sys.modules[module_name] = module
|
291 |
+
_import_module_src[module] = self._build_module_src
|
292 |
+
exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used
|
293 |
+
|
294 |
+
# Locate network build function in the temporary module.
|
295 |
+
self._build_func = util.get_obj_from_module(module, self._build_func_name)
|
296 |
+
assert callable(self._build_func)
|
297 |
+
|
298 |
+
# Init TensorFlow graph.
|
299 |
+
self._init_graph()
|
300 |
+
self.reset_own_vars()
|
301 |
+
tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
|
302 |
+
|
303 |
+
def clone(self, name: str = None, **new_static_kwargs) -> "Network":
|
304 |
+
"""Create a clone of this network with its own copy of the variables."""
|
305 |
+
# pylint: disable=protected-access
|
306 |
+
net = object.__new__(Network)
|
307 |
+
net._init_fields()
|
308 |
+
net.name = name if name is not None else self.name
|
309 |
+
net.static_kwargs = util.EasyDict(self.static_kwargs)
|
310 |
+
net.static_kwargs.update(new_static_kwargs)
|
311 |
+
net._build_module_src = self._build_module_src
|
312 |
+
net._build_func_name = self._build_func_name
|
313 |
+
net._build_func = self._build_func
|
314 |
+
net._init_graph()
|
315 |
+
net.copy_vars_from(self)
|
316 |
+
return net
|
317 |
+
|
318 |
+
def copy_own_vars_from(self, src_net: "Network") -> None:
|
319 |
+
"""Copy the values of all variables from the given network, excluding sub-networks."""
|
320 |
+
names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
|
321 |
+
tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
|
322 |
+
|
323 |
+
def copy_vars_from(self, src_net: "Network") -> None:
|
324 |
+
"""Copy the values of all variables from the given network, including sub-networks."""
|
325 |
+
names = [name for name in self.vars.keys() if name in src_net.vars]
|
326 |
+
tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
|
327 |
+
|
328 |
+
def copy_trainables_from(self, src_net: "Network") -> None:
|
329 |
+
"""Copy the values of all trainable variables from the given network, including sub-networks."""
|
330 |
+
names = [name for name in self.trainables.keys() if name in src_net.trainables]
|
331 |
+
tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
|
332 |
+
|
333 |
+
def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
|
334 |
+
"""Create new network with the given parameters, and copy all variables from this network."""
|
335 |
+
if new_name is None:
|
336 |
+
new_name = self.name
|
337 |
+
static_kwargs = dict(self.static_kwargs)
|
338 |
+
static_kwargs.update(new_static_kwargs)
|
339 |
+
net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
|
340 |
+
net.copy_vars_from(self)
|
341 |
+
return net
|
342 |
+
|
343 |
+
def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
|
344 |
+
"""Construct a TensorFlow op that updates the variables of this network
|
345 |
+
to be slightly closer to those of the given network."""
|
346 |
+
with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
|
347 |
+
ops = []
|
348 |
+
for name, var in self.vars.items():
|
349 |
+
if name in src_net.vars:
|
350 |
+
cur_beta = beta if name in self.trainables else beta_nontrainable
|
351 |
+
new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)
|
352 |
+
ops.append(var.assign(new_value))
|
353 |
+
return tf.group(*ops)
|
354 |
+
|
355 |
+
def run(self,
|
356 |
+
*in_arrays: Tuple[Union[np.ndarray, None], ...],
|
357 |
+
input_transform: dict = None,
|
358 |
+
output_transform: dict = None,
|
359 |
+
return_as_list: bool = False,
|
360 |
+
print_progress: bool = False,
|
361 |
+
minibatch_size: int = None,
|
362 |
+
num_gpus: int = 1,
|
363 |
+
assume_frozen: bool = False,
|
364 |
+
**dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
|
365 |
+
"""Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
|
366 |
+
|
367 |
+
Args:
|
368 |
+
input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
|
369 |
+
The dict must contain a 'func' field that points to a top-level function. The function is called with the input
|
370 |
+
TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
|
371 |
+
output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
|
372 |
+
The dict must contain a 'func' field that points to a top-level function. The function is called with the output
|
373 |
+
TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
|
374 |
+
return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
|
375 |
+
print_progress: Print progress to the console? Useful for very large input arrays.
|
376 |
+
minibatch_size: Maximum minibatch size to use, None = disable batching.
|
377 |
+
num_gpus: Number of GPUs to use.
|
378 |
+
assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain changed between calls.
|
379 |
+
dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
|
380 |
+
"""
|
381 |
+
assert len(in_arrays) == self.num_inputs
|
382 |
+
assert not all(arr is None for arr in in_arrays)
|
383 |
+
assert input_transform is None or util.is_top_level_function(input_transform["func"])
|
384 |
+
assert output_transform is None or util.is_top_level_function(output_transform["func"])
|
385 |
+
output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
|
386 |
+
num_items = in_arrays[0].shape[0]
|
387 |
+
if minibatch_size is None:
|
388 |
+
minibatch_size = num_items
|
389 |
+
|
390 |
+
# Construct unique hash key from all arguments that affect the TensorFlow graph.
|
391 |
+
key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
|
392 |
+
def unwind_key(obj):
|
393 |
+
if isinstance(obj, dict):
|
394 |
+
return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
|
395 |
+
if callable(obj):
|
396 |
+
return util.get_top_level_function_name(obj)
|
397 |
+
return obj
|
398 |
+
key = repr(unwind_key(key))
|
399 |
+
|
400 |
+
# Build graph.
|
401 |
+
if key not in self._run_cache:
|
402 |
+
with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
|
403 |
+
with tf.device("/cpu:0"):
|
404 |
+
in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
|
405 |
+
in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
|
406 |
+
|
407 |
+
out_split = []
|
408 |
+
for gpu in range(num_gpus):
|
409 |
+
with tf.device("/gpu:%d" % gpu):
|
410 |
+
net_gpu = self.clone() if assume_frozen else self
|
411 |
+
in_gpu = in_split[gpu]
|
412 |
+
|
413 |
+
if input_transform is not None:
|
414 |
+
in_kwargs = dict(input_transform)
|
415 |
+
in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
|
416 |
+
in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
|
417 |
+
|
418 |
+
assert len(in_gpu) == self.num_inputs
|
419 |
+
out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
|
420 |
+
|
421 |
+
if output_transform is not None:
|
422 |
+
out_kwargs = dict(output_transform)
|
423 |
+
out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
|
424 |
+
out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
|
425 |
+
|
426 |
+
assert len(out_gpu) == self.num_outputs
|
427 |
+
out_split.append(out_gpu)
|
428 |
+
|
429 |
+
with tf.device("/cpu:0"):
|
430 |
+
out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
|
431 |
+
self._run_cache[key] = in_expr, out_expr
|
432 |
+
|
433 |
+
# Run minibatches.
|
434 |
+
in_expr, out_expr = self._run_cache[key]
|
435 |
+
out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]
|
436 |
+
|
437 |
+
for mb_begin in range(0, num_items, minibatch_size):
|
438 |
+
if print_progress:
|
439 |
+
print("\r%d / %d" % (mb_begin, num_items), end="")
|
440 |
+
|
441 |
+
mb_end = min(mb_begin + minibatch_size, num_items)
|
442 |
+
mb_num = mb_end - mb_begin
|
443 |
+
mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
|
444 |
+
mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
|
445 |
+
|
446 |
+
for dst, src in zip(out_arrays, mb_out):
|
447 |
+
dst[mb_begin: mb_end] = src
|
448 |
+
|
449 |
+
# Done.
|
450 |
+
if print_progress:
|
451 |
+
print("\r%d / %d" % (num_items, num_items))
|
452 |
+
|
453 |
+
if not return_as_list:
|
454 |
+
out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
|
455 |
+
return out_arrays
|
456 |
+
|
457 |
+
def list_ops(self) -> List[TfExpression]:
|
458 |
+
include_prefix = self.scope + "/"
|
459 |
+
exclude_prefix = include_prefix + "_"
|
460 |
+
ops = tf.get_default_graph().get_operations()
|
461 |
+
ops = [op for op in ops if op.name.startswith(include_prefix)]
|
462 |
+
ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
|
463 |
+
return ops
|
464 |
+
|
465 |
+
def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
|
466 |
+
"""Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
|
467 |
+
individual layers of the network. Mainly intended to be used for reporting."""
|
468 |
+
layers = []
|
469 |
+
|
470 |
+
def recurse(scope, parent_ops, parent_vars, level):
|
471 |
+
# Ignore specific patterns.
|
472 |
+
if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
|
473 |
+
return
|
474 |
+
|
475 |
+
# Filter ops and vars by scope.
|
476 |
+
global_prefix = scope + "/"
|
477 |
+
local_prefix = global_prefix[len(self.scope) + 1:]
|
478 |
+
cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
|
479 |
+
cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
|
480 |
+
if not cur_ops and not cur_vars:
|
481 |
+
return
|
482 |
+
|
483 |
+
# Filter out all ops related to variables.
|
484 |
+
for var in [op for op in cur_ops if op.type.startswith("Variable")]:
|
485 |
+
var_prefix = var.name + "/"
|
486 |
+
cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
|
487 |
+
|
488 |
+
# Scope does not contain ops as immediate children => recurse deeper.
|
489 |
+
contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops)
|
490 |
+
if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
|
491 |
+
visited = set()
|
492 |
+
for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
|
493 |
+
token = rel_name.split("/")[0]
|
494 |
+
if token not in visited:
|
495 |
+
recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
|
496 |
+
visited.add(token)
|
497 |
+
return
|
498 |
+
|
499 |
+
# Report layer.
|
500 |
+
layer_name = scope[len(self.scope) + 1:]
|
501 |
+
layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
|
502 |
+
layer_trainables = [var for _name, var in cur_vars if var.trainable]
|
503 |
+
layers.append((layer_name, layer_output, layer_trainables))
|
504 |
+
|
505 |
+
recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
|
506 |
+
return layers
|
507 |
+
|
508 |
+
def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
|
509 |
+
"""Print a summary table of the network structure."""
|
510 |
+
rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
|
511 |
+
rows += [["---"] * 4]
|
512 |
+
total_params = 0
|
513 |
+
|
514 |
+
for layer_name, layer_output, layer_trainables in self.list_layers():
|
515 |
+
num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables)
|
516 |
+
weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
|
517 |
+
weights.sort(key=lambda x: len(x.name))
|
518 |
+
if len(weights) == 0 and len(layer_trainables) == 1:
|
519 |
+
weights = layer_trainables
|
520 |
+
total_params += num_params
|
521 |
+
|
522 |
+
if not hide_layers_with_no_params or num_params != 0:
|
523 |
+
num_params_str = str(num_params) if num_params > 0 else "-"
|
524 |
+
output_shape_str = str(layer_output.shape)
|
525 |
+
weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
|
526 |
+
rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
|
527 |
+
|
528 |
+
rows += [["---"] * 4]
|
529 |
+
rows += [["Total", str(total_params), "", ""]]
|
530 |
+
|
531 |
+
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
|
532 |
+
print()
|
533 |
+
for row in rows:
|
534 |
+
print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
|
535 |
+
print()
|
536 |
+
|
537 |
+
def setup_weight_histograms(self, title: str = None) -> None:
|
538 |
+
"""Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
|
539 |
+
if title is None:
|
540 |
+
title = self.name
|
541 |
+
|
542 |
+
with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
|
543 |
+
for local_name, var in self.trainables.items():
|
544 |
+
if "/" in local_name:
|
545 |
+
p = local_name.split("/")
|
546 |
+
name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
|
547 |
+
else:
|
548 |
+
name = title + "_toplevel/" + local_name
|
549 |
+
|
550 |
+
tf.summary.histogram(name, var)
|
551 |
+
|
552 |
+
#----------------------------------------------------------------------------
|
553 |
+
# Backwards-compatible emulation of legacy output transformation in Network.run().
|
554 |
+
|
555 |
+
_print_legacy_warning = True
|
556 |
+
|
557 |
+
def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
|
558 |
+
global _print_legacy_warning
|
559 |
+
legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
|
560 |
+
if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
|
561 |
+
return output_transform, dynamic_kwargs
|
562 |
+
|
563 |
+
if _print_legacy_warning:
|
564 |
+
_print_legacy_warning = False
|
565 |
+
print()
|
566 |
+
print("WARNING: Old-style output transformations in Network.run() are deprecated.")
|
567 |
+
print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
|
568 |
+
print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
|
569 |
+
print()
|
570 |
+
assert output_transform is None
|
571 |
+
|
572 |
+
new_kwargs = dict(dynamic_kwargs)
|
573 |
+
new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
|
574 |
+
new_transform["func"] = _legacy_output_transform_func
|
575 |
+
return new_transform, new_kwargs
|
576 |
+
|
577 |
+
def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
|
578 |
+
if out_mul != 1.0:
|
579 |
+
expr = [x * out_mul for x in expr]
|
580 |
+
|
581 |
+
if out_add != 0.0:
|
582 |
+
expr = [x + out_add for x in expr]
|
583 |
+
|
584 |
+
if out_shrink > 1:
|
585 |
+
ksize = [1, 1, out_shrink, out_shrink]
|
586 |
+
expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
|
587 |
+
|
588 |
+
if out_dtype is not None:
|
589 |
+
if tf.as_dtype(out_dtype).is_integer:
|
590 |
+
expr = [tf.round(x) for x in expr]
|
591 |
+
expr = [tf.saturate_cast(x, out_dtype) for x in expr]
|
592 |
+
return expr
|
dnnlib/tflib/ops/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
+
#
|
5 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
6 |
+
# To view a copy of this license, visit
|
7 |
+
# https://nvlabs.github.io/stylegan2/license.html
|
8 |
+
|
9 |
+
# empty
|
dnnlib/tflib/ops/fused_bias_act.cu
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
+
//
|
5 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
6 |
+
// To view a copy of this license, visit
|
7 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
8 |
+
|
9 |
+
#define EIGEN_USE_GPU
|
10 |
+
#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
|
11 |
+
#include "tensorflow/core/framework/op.h"
|
12 |
+
#include "tensorflow/core/framework/op_kernel.h"
|
13 |
+
#include "tensorflow/core/framework/shape_inference.h"
|
14 |
+
#include <stdio.h>
|
15 |
+
|
16 |
+
using namespace tensorflow;
|
17 |
+
using namespace tensorflow::shape_inference;
|
18 |
+
|
19 |
+
#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
|
20 |
+
|
21 |
+
//------------------------------------------------------------------------
|
22 |
+
// CUDA kernel.
|
23 |
+
|
24 |
+
template <class T>
|
25 |
+
struct FusedBiasActKernelParams
|
26 |
+
{
|
27 |
+
const T* x; // [sizeX]
|
28 |
+
const T* b; // [sizeB] or NULL
|
29 |
+
const T* ref; // [sizeX] or NULL
|
30 |
+
T* y; // [sizeX]
|
31 |
+
|
32 |
+
int grad;
|
33 |
+
int axis;
|
34 |
+
int act;
|
35 |
+
float alpha;
|
36 |
+
float gain;
|
37 |
+
|
38 |
+
int sizeX;
|
39 |
+
int sizeB;
|
40 |
+
int stepB;
|
41 |
+
int loopX;
|
42 |
+
};
|
43 |
+
|
44 |
+
template <class T>
|
45 |
+
static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p)
|
46 |
+
{
|
47 |
+
const float expRange = 80.0f;
|
48 |
+
const float halfExpRange = 40.0f;
|
49 |
+
const float seluScale = 1.0507009873554804934193349852946f;
|
50 |
+
const float seluAlpha = 1.6732632423543772848170429916717f;
|
51 |
+
|
52 |
+
// Loop over elements.
|
53 |
+
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
|
54 |
+
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
|
55 |
+
{
|
56 |
+
// Load and apply bias.
|
57 |
+
float x = (float)p.x[xi];
|
58 |
+
if (p.b)
|
59 |
+
x += (float)p.b[(xi / p.stepB) % p.sizeB];
|
60 |
+
float ref = (p.ref) ? (float)p.ref[xi] : 0.0f;
|
61 |
+
if (p.gain != 0.0f & p.act != 9)
|
62 |
+
ref /= p.gain;
|
63 |
+
|
64 |
+
// Evaluate activation func.
|
65 |
+
float y;
|
66 |
+
switch (p.act * 10 + p.grad)
|
67 |
+
{
|
68 |
+
// linear
|
69 |
+
default:
|
70 |
+
case 10: y = x; break;
|
71 |
+
case 11: y = x; break;
|
72 |
+
case 12: y = 0.0f; break;
|
73 |
+
|
74 |
+
// relu
|
75 |
+
case 20: y = (x > 0.0f) ? x : 0.0f; break;
|
76 |
+
case 21: y = (ref > 0.0f) ? x : 0.0f; break;
|
77 |
+
case 22: y = 0.0f; break;
|
78 |
+
|
79 |
+
// lrelu
|
80 |
+
case 30: y = (x > 0.0f) ? x : x * p.alpha; break;
|
81 |
+
case 31: y = (ref > 0.0f) ? x : x * p.alpha; break;
|
82 |
+
case 32: y = 0.0f; break;
|
83 |
+
|
84 |
+
// tanh
|
85 |
+
case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break;
|
86 |
+
case 41: y = x * (1.0f - ref * ref); break;
|
87 |
+
case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break;
|
88 |
+
|
89 |
+
// sigmoid
|
90 |
+
case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break;
|
91 |
+
case 51: y = x * ref * (1.0f - ref); break;
|
92 |
+
case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break;
|
93 |
+
|
94 |
+
// elu
|
95 |
+
case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break;
|
96 |
+
case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break;
|
97 |
+
case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break;
|
98 |
+
|
99 |
+
// selu
|
100 |
+
case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break;
|
101 |
+
case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break;
|
102 |
+
case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break;
|
103 |
+
|
104 |
+
// softplus
|
105 |
+
case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break;
|
106 |
+
case 81: y = x * (1.0f - expf(-ref)); break;
|
107 |
+
case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break;
|
108 |
+
|
109 |
+
// swish
|
110 |
+
case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break;
|
111 |
+
case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break;
|
112 |
+
case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break;
|
113 |
+
}
|
114 |
+
|
115 |
+
// Apply gain and store.
|
116 |
+
p.y[xi] = (T)(y * p.gain);
|
117 |
+
}
|
118 |
+
}
|
119 |
+
|
120 |
+
//------------------------------------------------------------------------
|
121 |
+
// TensorFlow op.
|
122 |
+
|
123 |
+
template <class T>
|
124 |
+
struct FusedBiasActOp : public OpKernel
|
125 |
+
{
|
126 |
+
FusedBiasActKernelParams<T> m_attribs;
|
127 |
+
|
128 |
+
FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
|
129 |
+
{
|
130 |
+
memset(&m_attribs, 0, sizeof(m_attribs));
|
131 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad));
|
132 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis));
|
133 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act));
|
134 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha));
|
135 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain));
|
136 |
+
OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
|
137 |
+
OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
|
138 |
+
OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
|
139 |
+
}
|
140 |
+
|
141 |
+
void Compute(OpKernelContext* ctx)
|
142 |
+
{
|
143 |
+
FusedBiasActKernelParams<T> p = m_attribs;
|
144 |
+
cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
|
145 |
+
|
146 |
+
const Tensor& x = ctx->input(0); // [...]
|
147 |
+
const Tensor& b = ctx->input(1); // [sizeB] or [0]
|
148 |
+
const Tensor& ref = ctx->input(2); // x.shape or [0]
|
149 |
+
p.x = x.flat<T>().data();
|
150 |
+
p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
|
151 |
+
p.ref = (ref.NumElements()) ? ref.flat<T>().data() : NULL;
|
152 |
+
OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
|
153 |
+
OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
|
154 |
+
OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
|
155 |
+
OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ? 0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements"));
|
156 |
+
OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));
|
157 |
+
|
158 |
+
p.sizeX = (int)x.NumElements();
|
159 |
+
p.sizeB = (int)b.NumElements();
|
160 |
+
p.stepB = 1;
|
161 |
+
for (int i = m_attribs.axis + 1; i < x.dims(); i++)
|
162 |
+
p.stepB *= (int)x.dim_size(i);
|
163 |
+
|
164 |
+
Tensor* y = NULL; // x.shape
|
165 |
+
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
|
166 |
+
p.y = y->flat<T>().data();
|
167 |
+
|
168 |
+
p.loopX = 4;
|
169 |
+
int blockSize = 4 * 32;
|
170 |
+
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
|
171 |
+
void* args[] = {&p};
|
172 |
+
OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
|
173 |
+
}
|
174 |
+
};
|
175 |
+
|
176 |
+
REGISTER_OP("FusedBiasAct")
|
177 |
+
.Input ("x: T")
|
178 |
+
.Input ("b: T")
|
179 |
+
.Input ("ref: T")
|
180 |
+
.Output ("y: T")
|
181 |
+
.Attr ("T: {float, half}")
|
182 |
+
.Attr ("grad: int = 0")
|
183 |
+
.Attr ("axis: int = 1")
|
184 |
+
.Attr ("act: int = 0")
|
185 |
+
.Attr ("alpha: float = 0.0")
|
186 |
+
.Attr ("gain: float = 1.0");
|
187 |
+
REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
|
188 |
+
REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);
|
189 |
+
|
190 |
+
//------------------------------------------------------------------------
|
dnnlib/tflib/ops/fused_bias_act.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
+
#
|
5 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
6 |
+
# To view a copy of this license, visit
|
7 |
+
# https://nvlabs.github.io/stylegan2/license.html
|
8 |
+
|
9 |
+
"""Custom TensorFlow ops for efficient bias and activation."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import numpy as np
|
13 |
+
import tensorflow as tf
|
14 |
+
from .. import custom_ops
|
15 |
+
from ...util import EasyDict
|
16 |
+
|
17 |
+
def _get_plugin():
|
18 |
+
return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
|
19 |
+
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
|
22 |
+
activation_funcs = {
|
23 |
+
'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True),
|
24 |
+
'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True),
|
25 |
+
'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True),
|
26 |
+
'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False),
|
27 |
+
'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False),
|
28 |
+
'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False),
|
29 |
+
'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False),
|
30 |
+
'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False),
|
31 |
+
'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False),
|
32 |
+
}
|
33 |
+
|
34 |
+
#----------------------------------------------------------------------------
|
35 |
+
|
36 |
+
def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'):
|
37 |
+
r"""Fused bias and activation function.
|
38 |
+
|
39 |
+
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
40 |
+
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
41 |
+
the fused op is considerably more efficient than performing the same calculation
|
42 |
+
using standard TensorFlow ops. It supports first and second order gradients,
|
43 |
+
but not third order gradients.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
x: Input activation tensor. Can have any shape, but if `b` is defined, the
|
47 |
+
dimension corresponding to `axis`, as well as the rank, must be known.
|
48 |
+
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
49 |
+
as `x`. The shape must be known, and it must match the dimension of `x`
|
50 |
+
corresponding to `axis`.
|
51 |
+
axis: The dimension in `x` corresponding to the elements of `b`.
|
52 |
+
The value of `axis` is ignored if `b` is not specified.
|
53 |
+
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
54 |
+
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
55 |
+
See `activation_funcs` for a full list. `None` is not allowed.
|
56 |
+
alpha: Shape parameter for the activation function, or `None` to use the default.
|
57 |
+
gain: Scaling factor for the output tensor, or `None` to use default.
|
58 |
+
See `activation_funcs` for the default scaling of each activation function.
|
59 |
+
If unsure, consider specifying `1.0`.
|
60 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
Tensor of the same shape and datatype as `x`.
|
64 |
+
"""
|
65 |
+
|
66 |
+
impl_dict = {
|
67 |
+
'ref': _fused_bias_act_ref,
|
68 |
+
'cuda': _fused_bias_act_cuda,
|
69 |
+
}
|
70 |
+
return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
|
71 |
+
|
72 |
+
#----------------------------------------------------------------------------
|
73 |
+
|
74 |
+
def _fused_bias_act_ref(x, b, axis, act, alpha, gain):
|
75 |
+
"""Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops."""
|
76 |
+
|
77 |
+
# Validate arguments.
|
78 |
+
x = tf.convert_to_tensor(x)
|
79 |
+
b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype)
|
80 |
+
act_spec = activation_funcs[act]
|
81 |
+
assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
|
82 |
+
assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
|
83 |
+
if alpha is None:
|
84 |
+
alpha = act_spec.def_alpha
|
85 |
+
if gain is None:
|
86 |
+
gain = act_spec.def_gain
|
87 |
+
|
88 |
+
# Add bias.
|
89 |
+
if b.shape[0] != 0:
|
90 |
+
x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)])
|
91 |
+
|
92 |
+
# Evaluate activation function.
|
93 |
+
x = act_spec.func(x, alpha=alpha)
|
94 |
+
|
95 |
+
# Scale by gain.
|
96 |
+
if gain != 1:
|
97 |
+
x *= gain
|
98 |
+
return x
|
99 |
+
|
100 |
+
#----------------------------------------------------------------------------
|
101 |
+
|
102 |
+
def _fused_bias_act_cuda(x, b, axis, act, alpha, gain):
|
103 |
+
"""Fast CUDA implementation of `fused_bias_act()` using custom ops."""
|
104 |
+
|
105 |
+
# Validate arguments.
|
106 |
+
x = tf.convert_to_tensor(x)
|
107 |
+
empty_tensor = tf.constant([], dtype=x.dtype)
|
108 |
+
b = tf.convert_to_tensor(b) if b is not None else empty_tensor
|
109 |
+
act_spec = activation_funcs[act]
|
110 |
+
assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
|
111 |
+
assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
|
112 |
+
if alpha is None:
|
113 |
+
alpha = act_spec.def_alpha
|
114 |
+
if gain is None:
|
115 |
+
gain = act_spec.def_gain
|
116 |
+
|
117 |
+
# Special cases.
|
118 |
+
if act == 'linear' and b is None and gain == 1.0:
|
119 |
+
return x
|
120 |
+
if act_spec.cuda_idx is None:
|
121 |
+
return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
|
122 |
+
|
123 |
+
# CUDA kernel.
|
124 |
+
cuda_kernel = _get_plugin().fused_bias_act
|
125 |
+
cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain)
|
126 |
+
|
127 |
+
# Forward pass: y = func(x, b).
|
128 |
+
def func_y(x, b):
|
129 |
+
y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs)
|
130 |
+
y.set_shape(x.shape)
|
131 |
+
return y
|
132 |
+
|
133 |
+
# Backward pass: dx, db = grad(dy, x, y)
|
134 |
+
def grad_dx(dy, x, y):
|
135 |
+
ref = {'x': x, 'y': y}[act_spec.ref]
|
136 |
+
dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs)
|
137 |
+
dx.set_shape(x.shape)
|
138 |
+
return dx
|
139 |
+
def grad_db(dx):
|
140 |
+
if b.shape[0] == 0:
|
141 |
+
return empty_tensor
|
142 |
+
db = dx
|
143 |
+
if axis < x.shape.rank - 1:
|
144 |
+
db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank)))
|
145 |
+
if axis > 0:
|
146 |
+
db = tf.reduce_sum(db, list(range(axis)))
|
147 |
+
db.set_shape(b.shape)
|
148 |
+
return db
|
149 |
+
|
150 |
+
# Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y)
|
151 |
+
def grad2_d_dy(d_dx, d_db, x, y):
|
152 |
+
ref = {'x': x, 'y': y}[act_spec.ref]
|
153 |
+
d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs)
|
154 |
+
d_dy.set_shape(x.shape)
|
155 |
+
return d_dy
|
156 |
+
def grad2_d_x(d_dx, d_db, x, y):
|
157 |
+
ref = {'x': x, 'y': y}[act_spec.ref]
|
158 |
+
d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs)
|
159 |
+
d_x.set_shape(x.shape)
|
160 |
+
return d_x
|
161 |
+
|
162 |
+
# Fast version for piecewise-linear activation funcs.
|
163 |
+
@tf.custom_gradient
|
164 |
+
def func_zero_2nd_grad(x, b):
|
165 |
+
y = func_y(x, b)
|
166 |
+
@tf.custom_gradient
|
167 |
+
def grad(dy):
|
168 |
+
dx = grad_dx(dy, x, y)
|
169 |
+
db = grad_db(dx)
|
170 |
+
def grad2(d_dx, d_db):
|
171 |
+
d_dy = grad2_d_dy(d_dx, d_db, x, y)
|
172 |
+
return d_dy
|
173 |
+
return (dx, db), grad2
|
174 |
+
return y, grad
|
175 |
+
|
176 |
+
# Slow version for general activation funcs.
|
177 |
+
@tf.custom_gradient
|
178 |
+
def func_nonzero_2nd_grad(x, b):
|
179 |
+
y = func_y(x, b)
|
180 |
+
def grad_wrap(dy):
|
181 |
+
@tf.custom_gradient
|
182 |
+
def grad_impl(dy, x):
|
183 |
+
dx = grad_dx(dy, x, y)
|
184 |
+
db = grad_db(dx)
|
185 |
+
def grad2(d_dx, d_db):
|
186 |
+
d_dy = grad2_d_dy(d_dx, d_db, x, y)
|
187 |
+
d_x = grad2_d_x(d_dx, d_db, x, y)
|
188 |
+
return d_dy, d_x
|
189 |
+
return (dx, db), grad2
|
190 |
+
return grad_impl(dy, x)
|
191 |
+
return y, grad_wrap
|
192 |
+
|
193 |
+
# Which version to use?
|
194 |
+
if act_spec.zero_2nd_grad:
|
195 |
+
return func_zero_2nd_grad(x, b)
|
196 |
+
return func_nonzero_2nd_grad(x, b)
|
197 |
+
|
198 |
+
#----------------------------------------------------------------------------
|
dnnlib/tflib/ops/upfirdn_2d.cu
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
+
//
|
5 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
6 |
+
// To view a copy of this license, visit
|
7 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
8 |
+
|
9 |
+
#define EIGEN_USE_GPU
|
10 |
+
#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
|
11 |
+
#include "tensorflow/core/framework/op.h"
|
12 |
+
#include "tensorflow/core/framework/op_kernel.h"
|
13 |
+
#include "tensorflow/core/framework/shape_inference.h"
|
14 |
+
#include <stdio.h>
|
15 |
+
|
16 |
+
using namespace tensorflow;
|
17 |
+
using namespace tensorflow::shape_inference;
|
18 |
+
|
19 |
+
//------------------------------------------------------------------------
|
20 |
+
// Helpers.
|
21 |
+
|
22 |
+
#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
|
23 |
+
|
24 |
+
static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
|
25 |
+
{
|
26 |
+
int c = a / b;
|
27 |
+
if (c * b > a)
|
28 |
+
c--;
|
29 |
+
return c;
|
30 |
+
}
|
31 |
+
|
32 |
+
//------------------------------------------------------------------------
|
33 |
+
// CUDA kernel params.
|
34 |
+
|
35 |
+
template <class T>
|
36 |
+
struct UpFirDn2DKernelParams
|
37 |
+
{
|
38 |
+
const T* x; // [majorDim, inH, inW, minorDim]
|
39 |
+
const T* k; // [kernelH, kernelW]
|
40 |
+
T* y; // [majorDim, outH, outW, minorDim]
|
41 |
+
|
42 |
+
int upx;
|
43 |
+
int upy;
|
44 |
+
int downx;
|
45 |
+
int downy;
|
46 |
+
int padx0;
|
47 |
+
int padx1;
|
48 |
+
int pady0;
|
49 |
+
int pady1;
|
50 |
+
|
51 |
+
int majorDim;
|
52 |
+
int inH;
|
53 |
+
int inW;
|
54 |
+
int minorDim;
|
55 |
+
int kernelH;
|
56 |
+
int kernelW;
|
57 |
+
int outH;
|
58 |
+
int outW;
|
59 |
+
int loopMajor;
|
60 |
+
int loopX;
|
61 |
+
};
|
62 |
+
|
63 |
+
//------------------------------------------------------------------------
|
64 |
+
// General CUDA implementation for large filter kernels.
|
65 |
+
|
66 |
+
template <class T>
|
67 |
+
static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
|
68 |
+
{
|
69 |
+
// Calculate thread index.
|
70 |
+
int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
71 |
+
int outY = minorIdx / p.minorDim;
|
72 |
+
minorIdx -= outY * p.minorDim;
|
73 |
+
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
|
74 |
+
int majorIdxBase = blockIdx.z * p.loopMajor;
|
75 |
+
if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
|
76 |
+
return;
|
77 |
+
|
78 |
+
// Setup Y receptive field.
|
79 |
+
int midY = outY * p.downy + p.upy - 1 - p.pady0;
|
80 |
+
int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
|
81 |
+
int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
|
82 |
+
int kernelY = midY + p.kernelH - (inY + 1) * p.upy;
|
83 |
+
|
84 |
+
// Loop over majorDim and outX.
|
85 |
+
for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
|
86 |
+
for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
|
87 |
+
{
|
88 |
+
// Setup X receptive field.
|
89 |
+
int midX = outX * p.downx + p.upx - 1 - p.padx0;
|
90 |
+
int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
|
91 |
+
int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
|
92 |
+
int kernelX = midX + p.kernelW - (inX + 1) * p.upx;
|
93 |
+
|
94 |
+
// Initialize pointers.
|
95 |
+
const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
|
96 |
+
const T* kp = &p.k[kernelY * p.kernelW + kernelX];
|
97 |
+
int xpx = p.minorDim;
|
98 |
+
int kpx = -p.upx;
|
99 |
+
int xpy = p.inW * p.minorDim;
|
100 |
+
int kpy = -p.upy * p.kernelW;
|
101 |
+
|
102 |
+
// Inner loop.
|
103 |
+
float v = 0.0f;
|
104 |
+
for (int y = 0; y < h; y++)
|
105 |
+
{
|
106 |
+
for (int x = 0; x < w; x++)
|
107 |
+
{
|
108 |
+
v += (float)(*xp) * (float)(*kp);
|
109 |
+
xp += xpx;
|
110 |
+
kp += kpx;
|
111 |
+
}
|
112 |
+
xp += xpy - w * xpx;
|
113 |
+
kp += kpy - w * kpx;
|
114 |
+
}
|
115 |
+
|
116 |
+
// Store result.
|
117 |
+
p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
|
118 |
+
}
|
119 |
+
}
|
120 |
+
|
121 |
+
//------------------------------------------------------------------------
|
122 |
+
// Specialized CUDA implementation for small filter kernels.
|
123 |
+
|
124 |
+
template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
|
125 |
+
static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
|
126 |
+
{
|
127 |
+
//assert(kernelW % upx == 0);
|
128 |
+
//assert(kernelH % upy == 0);
|
129 |
+
const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
|
130 |
+
const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
|
131 |
+
__shared__ volatile float sk[kernelH][kernelW];
|
132 |
+
__shared__ volatile float sx[tileInH][tileInW];
|
133 |
+
|
134 |
+
// Calculate tile index.
|
135 |
+
int minorIdx = blockIdx.x;
|
136 |
+
int tileOutY = minorIdx / p.minorDim;
|
137 |
+
minorIdx -= tileOutY * p.minorDim;
|
138 |
+
tileOutY *= tileOutH;
|
139 |
+
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
|
140 |
+
int majorIdxBase = blockIdx.z * p.loopMajor;
|
141 |
+
if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
|
142 |
+
return;
|
143 |
+
|
144 |
+
// Load filter kernel (flipped).
|
145 |
+
for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
|
146 |
+
{
|
147 |
+
int ky = tapIdx / kernelW;
|
148 |
+
int kx = tapIdx - ky * kernelW;
|
149 |
+
float v = 0.0f;
|
150 |
+
if (kx < p.kernelW & ky < p.kernelH)
|
151 |
+
v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
|
152 |
+
sk[ky][kx] = v;
|
153 |
+
}
|
154 |
+
|
155 |
+
// Loop over majorDim and outX.
|
156 |
+
for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
|
157 |
+
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
|
158 |
+
{
|
159 |
+
// Load input pixels.
|
160 |
+
int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
|
161 |
+
int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
|
162 |
+
int tileInX = floorDiv(tileMidX, upx);
|
163 |
+
int tileInY = floorDiv(tileMidY, upy);
|
164 |
+
__syncthreads();
|
165 |
+
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
|
166 |
+
{
|
167 |
+
int relInY = inIdx / tileInW;
|
168 |
+
int relInX = inIdx - relInY * tileInW;
|
169 |
+
int inX = relInX + tileInX;
|
170 |
+
int inY = relInY + tileInY;
|
171 |
+
float v = 0.0f;
|
172 |
+
if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
|
173 |
+
v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
|
174 |
+
sx[relInY][relInX] = v;
|
175 |
+
}
|
176 |
+
|
177 |
+
// Loop over output pixels.
|
178 |
+
__syncthreads();
|
179 |
+
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
|
180 |
+
{
|
181 |
+
int relOutY = outIdx / tileOutW;
|
182 |
+
int relOutX = outIdx - relOutY * tileOutW;
|
183 |
+
int outX = relOutX + tileOutX;
|
184 |
+
int outY = relOutY + tileOutY;
|
185 |
+
|
186 |
+
// Setup receptive field.
|
187 |
+
int midX = tileMidX + relOutX * downx;
|
188 |
+
int midY = tileMidY + relOutY * downy;
|
189 |
+
int inX = floorDiv(midX, upx);
|
190 |
+
int inY = floorDiv(midY, upy);
|
191 |
+
int relInX = inX - tileInX;
|
192 |
+
int relInY = inY - tileInY;
|
193 |
+
int kernelX = (inX + 1) * upx - midX - 1; // flipped
|
194 |
+
int kernelY = (inY + 1) * upy - midY - 1; // flipped
|
195 |
+
|
196 |
+
// Inner loop.
|
197 |
+
float v = 0.0f;
|
198 |
+
#pragma unroll
|
199 |
+
for (int y = 0; y < kernelH / upy; y++)
|
200 |
+
#pragma unroll
|
201 |
+
for (int x = 0; x < kernelW / upx; x++)
|
202 |
+
v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];
|
203 |
+
|
204 |
+
// Store result.
|
205 |
+
if (outX < p.outW & outY < p.outH)
|
206 |
+
p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
|
207 |
+
}
|
208 |
+
}
|
209 |
+
}
|
210 |
+
|
211 |
+
//------------------------------------------------------------------------
|
212 |
+
// TensorFlow op.
|
213 |
+
|
214 |
+
template <class T>
|
215 |
+
struct UpFirDn2DOp : public OpKernel
|
216 |
+
{
|
217 |
+
UpFirDn2DKernelParams<T> m_attribs;
|
218 |
+
|
219 |
+
UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
|
220 |
+
{
|
221 |
+
memset(&m_attribs, 0, sizeof(m_attribs));
|
222 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
|
223 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
|
224 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
|
225 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
|
226 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
|
227 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
|
228 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
|
229 |
+
OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
|
230 |
+
OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
|
231 |
+
OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
|
232 |
+
}
|
233 |
+
|
234 |
+
void Compute(OpKernelContext* ctx)
|
235 |
+
{
|
236 |
+
UpFirDn2DKernelParams<T> p = m_attribs;
|
237 |
+
cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
|
238 |
+
|
239 |
+
const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
|
240 |
+
const Tensor& k = ctx->input(1); // [kernelH, kernelW]
|
241 |
+
p.x = x.flat<T>().data();
|
242 |
+
p.k = k.flat<T>().data();
|
243 |
+
OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
|
244 |
+
OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
|
245 |
+
OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
|
246 |
+
OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));
|
247 |
+
|
248 |
+
p.majorDim = (int)x.dim_size(0);
|
249 |
+
p.inH = (int)x.dim_size(1);
|
250 |
+
p.inW = (int)x.dim_size(2);
|
251 |
+
p.minorDim = (int)x.dim_size(3);
|
252 |
+
p.kernelH = (int)k.dim_size(0);
|
253 |
+
p.kernelW = (int)k.dim_size(1);
|
254 |
+
OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));
|
255 |
+
|
256 |
+
p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
|
257 |
+
p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
|
258 |
+
OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));
|
259 |
+
|
260 |
+
Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
|
261 |
+
TensorShape ys;
|
262 |
+
ys.AddDim(p.majorDim);
|
263 |
+
ys.AddDim(p.outH);
|
264 |
+
ys.AddDim(p.outW);
|
265 |
+
ys.AddDim(p.minorDim);
|
266 |
+
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
|
267 |
+
p.y = y->flat<T>().data();
|
268 |
+
OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));
|
269 |
+
|
270 |
+
// Choose CUDA kernel to use.
|
271 |
+
void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
|
272 |
+
int tileOutW = -1;
|
273 |
+
int tileOutH = -1;
|
274 |
+
if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7, 64,16>; tileOutW = 64; tileOutH = 16; }
|
275 |
+
if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
|
276 |
+
if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5, 64,16>; tileOutW = 64; tileOutH = 16; }
|
277 |
+
if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
|
278 |
+
if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3, 64,16>; tileOutW = 64; tileOutH = 16; }
|
279 |
+
if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8, 64,16>; tileOutW = 64; tileOutH = 16; }
|
280 |
+
if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
|
281 |
+
if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
|
282 |
+
if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2, 64,16>; tileOutW = 64; tileOutH = 16; }
|
283 |
+
if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8, 32,8>; tileOutW = 32; tileOutH = 8; }
|
284 |
+
if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6, 32,8>; tileOutW = 32; tileOutH = 8; }
|
285 |
+
if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4, 32,8>; tileOutW = 32; tileOutH = 8; }
|
286 |
+
if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2, 32,8>; tileOutW = 32; tileOutH = 8; }
|
287 |
+
|
288 |
+
// Choose launch params.
|
289 |
+
dim3 blockSize;
|
290 |
+
dim3 gridSize;
|
291 |
+
if (tileOutW > 0 && tileOutH > 0) // small
|
292 |
+
{
|
293 |
+
p.loopMajor = (p.majorDim - 1) / 16384 + 1;
|
294 |
+
p.loopX = 1;
|
295 |
+
blockSize = dim3(32 * 8, 1, 1);
|
296 |
+
gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
|
297 |
+
}
|
298 |
+
else // large
|
299 |
+
{
|
300 |
+
p.loopMajor = (p.majorDim - 1) / 16384 + 1;
|
301 |
+
p.loopX = 4;
|
302 |
+
blockSize = dim3(4, 32, 1);
|
303 |
+
gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
|
304 |
+
}
|
305 |
+
|
306 |
+
// Launch CUDA kernel.
|
307 |
+
void* args[] = {&p};
|
308 |
+
OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
|
309 |
+
}
|
310 |
+
};
|
311 |
+
|
312 |
+
REGISTER_OP("UpFirDn2D")
|
313 |
+
.Input ("x: T")
|
314 |
+
.Input ("k: T")
|
315 |
+
.Output ("y: T")
|
316 |
+
.Attr ("T: {float, half}")
|
317 |
+
.Attr ("upx: int = 1")
|
318 |
+
.Attr ("upy: int = 1")
|
319 |
+
.Attr ("downx: int = 1")
|
320 |
+
.Attr ("downy: int = 1")
|
321 |
+
.Attr ("padx0: int = 0")
|
322 |
+
.Attr ("padx1: int = 0")
|
323 |
+
.Attr ("pady0: int = 0")
|
324 |
+
.Attr ("pady1: int = 0");
|
325 |
+
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
|
326 |
+
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);
|
327 |
+
|
328 |
+
//------------------------------------------------------------------------
|
dnnlib/tflib/ops/upfirdn_2d.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
+
#
|
5 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
6 |
+
# To view a copy of this license, visit
|
7 |
+
# https://nvlabs.github.io/stylegan2/license.html
|
8 |
+
|
9 |
+
"""Custom TensorFlow ops for efficient resampling of 2D images."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import numpy as np
|
13 |
+
import tensorflow as tf
|
14 |
+
from .. import custom_ops
|
15 |
+
|
16 |
+
def _get_plugin():
|
17 |
+
return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
|
21 |
+
def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
|
22 |
+
r"""Pad, upsample, FIR filter, and downsample a batch of 2D images.
|
23 |
+
|
24 |
+
Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]`
|
25 |
+
and performs the following operations for each image, batched across
|
26 |
+
`majorDim` and `minorDim`:
|
27 |
+
|
28 |
+
1. Pad the image with zeros by the specified number of pixels on each side
|
29 |
+
(`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value
|
30 |
+
corresponds to cropping the image.
|
31 |
+
|
32 |
+
2. Upsample the image by inserting the zeros after each pixel (`upx`, `upy`).
|
33 |
+
|
34 |
+
3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the
|
35 |
+
image so that the footprint of all output pixels lies within the input image.
|
36 |
+
|
37 |
+
4. Downsample the image by throwing away pixels (`downx`, `downy`).
|
38 |
+
|
39 |
+
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
40 |
+
The fused op is considerably more efficient than performing the same calculation
|
41 |
+
using standard TensorFlow ops. It supports gradients of arbitrary order.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
x: Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
|
45 |
+
k: 2D FIR filter of the shape `[firH, firW]`.
|
46 |
+
upx: Integer upsampling factor along the X-axis (default: 1).
|
47 |
+
upy: Integer upsampling factor along the Y-axis (default: 1).
|
48 |
+
downx: Integer downsampling factor along the X-axis (default: 1).
|
49 |
+
downy: Integer downsampling factor along the Y-axis (default: 1).
|
50 |
+
padx0: Number of pixels to pad on the left side (default: 0).
|
51 |
+
padx1: Number of pixels to pad on the right side (default: 0).
|
52 |
+
pady0: Number of pixels to pad on the top side (default: 0).
|
53 |
+
pady1: Number of pixels to pad on the bottom side (default: 0).
|
54 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`.
|
58 |
+
"""
|
59 |
+
|
60 |
+
impl_dict = {
|
61 |
+
'ref': _upfirdn_2d_ref,
|
62 |
+
'cuda': _upfirdn_2d_cuda,
|
63 |
+
}
|
64 |
+
return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
|
65 |
+
|
66 |
+
#----------------------------------------------------------------------------
|
67 |
+
|
68 |
+
def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
|
69 |
+
"""Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops."""
|
70 |
+
|
71 |
+
x = tf.convert_to_tensor(x)
|
72 |
+
k = np.asarray(k, dtype=np.float32)
|
73 |
+
assert x.shape.rank == 4
|
74 |
+
inH = x.shape[1].value
|
75 |
+
inW = x.shape[2].value
|
76 |
+
minorDim = _shape(x, 3)
|
77 |
+
kernelH, kernelW = k.shape
|
78 |
+
assert inW >= 1 and inH >= 1
|
79 |
+
assert kernelW >= 1 and kernelH >= 1
|
80 |
+
assert isinstance(upx, int) and isinstance(upy, int)
|
81 |
+
assert isinstance(downx, int) and isinstance(downy, int)
|
82 |
+
assert isinstance(padx0, int) and isinstance(padx1, int)
|
83 |
+
assert isinstance(pady0, int) and isinstance(pady1, int)
|
84 |
+
|
85 |
+
# Upsample (insert zeros).
|
86 |
+
x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
|
87 |
+
x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
|
88 |
+
x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])
|
89 |
+
|
90 |
+
# Pad (crop if negative).
|
91 |
+
x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]])
|
92 |
+
x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :]
|
93 |
+
|
94 |
+
# Convolve with filter.
|
95 |
+
x = tf.transpose(x, [0, 3, 1, 2])
|
96 |
+
x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1])
|
97 |
+
w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
|
98 |
+
x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW')
|
99 |
+
x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1])
|
100 |
+
x = tf.transpose(x, [0, 2, 3, 1])
|
101 |
+
|
102 |
+
# Downsample (throw away pixels).
|
103 |
+
return x[:, ::downy, ::downx, :]
|
104 |
+
|
105 |
+
#----------------------------------------------------------------------------
|
106 |
+
|
107 |
+
def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
|
108 |
+
"""Fast CUDA implementation of `upfirdn_2d()` using custom ops."""
|
109 |
+
|
110 |
+
x = tf.convert_to_tensor(x)
|
111 |
+
k = np.asarray(k, dtype=np.float32)
|
112 |
+
majorDim, inH, inW, minorDim = x.shape.as_list()
|
113 |
+
kernelH, kernelW = k.shape
|
114 |
+
assert inW >= 1 and inH >= 1
|
115 |
+
assert kernelW >= 1 and kernelH >= 1
|
116 |
+
assert isinstance(upx, int) and isinstance(upy, int)
|
117 |
+
assert isinstance(downx, int) and isinstance(downy, int)
|
118 |
+
assert isinstance(padx0, int) and isinstance(padx1, int)
|
119 |
+
assert isinstance(pady0, int) and isinstance(pady1, int)
|
120 |
+
|
121 |
+
outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
|
122 |
+
outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
|
123 |
+
assert outW >= 1 and outH >= 1
|
124 |
+
|
125 |
+
kc = tf.constant(k, dtype=x.dtype)
|
126 |
+
gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype)
|
127 |
+
gpadx0 = kernelW - padx0 - 1
|
128 |
+
gpady0 = kernelH - pady0 - 1
|
129 |
+
gpadx1 = inW * upx - outW * downx + padx0 - upx + 1
|
130 |
+
gpady1 = inH * upy - outH * downy + pady0 - upy + 1
|
131 |
+
|
132 |
+
@tf.custom_gradient
|
133 |
+
def func(x):
|
134 |
+
y = _get_plugin().up_fir_dn2d(x=x, k=kc, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
|
135 |
+
y.set_shape([majorDim, outH, outW, minorDim])
|
136 |
+
@tf.custom_gradient
|
137 |
+
def grad(dy):
|
138 |
+
dx = _get_plugin().up_fir_dn2d(x=dy, k=gkc, upx=downx, upy=downy, downx=upx, downy=upy, padx0=gpadx0, padx1=gpadx1, pady0=gpady0, pady1=gpady1)
|
139 |
+
dx.set_shape([majorDim, inH, inW, minorDim])
|
140 |
+
return dx, func
|
141 |
+
return y, grad
|
142 |
+
return func(x)
|
143 |
+
|
144 |
+
#----------------------------------------------------------------------------
|
145 |
+
|
146 |
+
def filter_2d(x, k, gain=1, data_format='NCHW', impl='cuda'):
|
147 |
+
r"""Filter a batch of 2D images with the given FIR filter.
|
148 |
+
|
149 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
|
150 |
+
and filters each image with the given filter. The filter is normalized so that
|
151 |
+
if the input pixels are constant, they will be scaled by the specified `gain`.
|
152 |
+
Pixels outside the image are assumed to be zero.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
156 |
+
k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
|
157 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
158 |
+
data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
|
159 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
Tensor of the same shape and datatype as `x`.
|
163 |
+
"""
|
164 |
+
|
165 |
+
k = _setup_kernel(k) * gain
|
166 |
+
p = k.shape[0] - 1
|
167 |
+
return _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
|
168 |
+
|
169 |
+
#----------------------------------------------------------------------------
|
170 |
+
|
171 |
+
def upsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
|
172 |
+
r"""Upsample a batch of 2D images with the given filter.
|
173 |
+
|
174 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
|
175 |
+
and upsamples each image with the given filter. The filter is normalized so that
|
176 |
+
if the input pixels are constant, they will be scaled by the specified `gain`.
|
177 |
+
Pixels outside the image are assumed to be zero, and the filter is padded with
|
178 |
+
zeros so that its shape is a multiple of the upsampling factor.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
182 |
+
k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
|
183 |
+
The default is `[1] * factor`, which corresponds to nearest-neighbor
|
184 |
+
upsampling.
|
185 |
+
factor: Integer upsampling factor (default: 2).
|
186 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
187 |
+
data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
|
188 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
Tensor of the shape `[N, C, H * factor, W * factor]` or
|
192 |
+
`[N, H * factor, W * factor, C]`, and same datatype as `x`.
|
193 |
+
"""
|
194 |
+
|
195 |
+
assert isinstance(factor, int) and factor >= 1
|
196 |
+
if k is None:
|
197 |
+
k = [1] * factor
|
198 |
+
k = _setup_kernel(k) * (gain * (factor ** 2))
|
199 |
+
p = k.shape[0] - factor
|
200 |
+
return _simple_upfirdn_2d(x, k, up=factor, pad0=(p+1)//2+factor-1, pad1=p//2, data_format=data_format, impl=impl)
|
201 |
+
|
202 |
+
#----------------------------------------------------------------------------
|
203 |
+
|
204 |
+
def downsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
|
205 |
+
r"""Downsample a batch of 2D images with the given filter.
|
206 |
+
|
207 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
|
208 |
+
and downsamples each image with the given filter. The filter is normalized so that
|
209 |
+
if the input pixels are constant, they will be scaled by the specified `gain`.
|
210 |
+
Pixels outside the image are assumed to be zero, and the filter is padded with
|
211 |
+
zeros so that its shape is a multiple of the downsampling factor.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
215 |
+
k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
|
216 |
+
The default is `[1] * factor`, which corresponds to average pooling.
|
217 |
+
factor: Integer downsampling factor (default: 2).
|
218 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
219 |
+
data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
|
220 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
Tensor of the shape `[N, C, H // factor, W // factor]` or
|
224 |
+
`[N, H // factor, W // factor, C]`, and same datatype as `x`.
|
225 |
+
"""
|
226 |
+
|
227 |
+
assert isinstance(factor, int) and factor >= 1
|
228 |
+
if k is None:
|
229 |
+
k = [1] * factor
|
230 |
+
k = _setup_kernel(k) * gain
|
231 |
+
p = k.shape[0] - factor
|
232 |
+
return _simple_upfirdn_2d(x, k, down=factor, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
|
233 |
+
|
234 |
+
#----------------------------------------------------------------------------
|
235 |
+
|
236 |
+
def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
|
237 |
+
r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
|
238 |
+
|
239 |
+
Padding is performed only once at the beginning, not between the operations.
|
240 |
+
The fused op is considerably more efficient than performing the same calculation
|
241 |
+
using standard TensorFlow ops. It supports gradients of arbitrary order.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
245 |
+
w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
|
246 |
+
Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
|
247 |
+
k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
|
248 |
+
The default is `[1] * factor`, which corresponds to nearest-neighbor
|
249 |
+
upsampling.
|
250 |
+
factor: Integer upsampling factor (default: 2).
|
251 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
252 |
+
data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
|
253 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
Tensor of the shape `[N, C, H * factor, W * factor]` or
|
257 |
+
`[N, H * factor, W * factor, C]`, and same datatype as `x`.
|
258 |
+
"""
|
259 |
+
|
260 |
+
assert isinstance(factor, int) and factor >= 1
|
261 |
+
|
262 |
+
# Check weight shape.
|
263 |
+
w = tf.convert_to_tensor(w)
|
264 |
+
assert w.shape.rank == 4
|
265 |
+
convH = w.shape[0].value
|
266 |
+
convW = w.shape[1].value
|
267 |
+
inC = _shape(w, 2)
|
268 |
+
outC = _shape(w, 3)
|
269 |
+
assert convW == convH
|
270 |
+
|
271 |
+
# Setup filter kernel.
|
272 |
+
if k is None:
|
273 |
+
k = [1] * factor
|
274 |
+
k = _setup_kernel(k) * (gain * (factor ** 2))
|
275 |
+
p = (k.shape[0] - factor) - (convW - 1)
|
276 |
+
|
277 |
+
# Determine data dimensions.
|
278 |
+
if data_format == 'NCHW':
|
279 |
+
stride = [1, 1, factor, factor]
|
280 |
+
output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW]
|
281 |
+
num_groups = _shape(x, 1) // inC
|
282 |
+
else:
|
283 |
+
stride = [1, factor, factor, 1]
|
284 |
+
output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + convH, (_shape(x, 2) - 1) * factor + convW, outC]
|
285 |
+
num_groups = _shape(x, 3) // inC
|
286 |
+
|
287 |
+
# Transpose weights.
|
288 |
+
w = tf.reshape(w, [convH, convW, inC, num_groups, -1])
|
289 |
+
w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2])
|
290 |
+
w = tf.reshape(w, [convH, convW, -1, num_groups * inC])
|
291 |
+
|
292 |
+
# Execute.
|
293 |
+
x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format)
|
294 |
+
return _simple_upfirdn_2d(x, k, pad0=(p+1)//2+factor-1, pad1=p//2+1, data_format=data_format, impl=impl)
|
295 |
+
|
296 |
+
#----------------------------------------------------------------------------
|
297 |
+
|
298 |
+
def conv_downsample_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
|
299 |
+
r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
|
300 |
+
|
301 |
+
Padding is performed only once at the beginning, not between the operations.
|
302 |
+
The fused op is considerably more efficient than performing the same calculation
|
303 |
+
using standard TensorFlow ops. It supports gradients of arbitrary order.
|
304 |
+
|
305 |
+
Args:
|
306 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
307 |
+
w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
|
308 |
+
Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
|
309 |
+
k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
|
310 |
+
The default is `[1] * factor`, which corresponds to average pooling.
|
311 |
+
factor: Integer downsampling factor (default: 2).
|
312 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
313 |
+
data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
|
314 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
315 |
+
|
316 |
+
Returns:
|
317 |
+
Tensor of the shape `[N, C, H // factor, W // factor]` or
|
318 |
+
`[N, H // factor, W // factor, C]`, and same datatype as `x`.
|
319 |
+
"""
|
320 |
+
|
321 |
+
assert isinstance(factor, int) and factor >= 1
|
322 |
+
w = tf.convert_to_tensor(w)
|
323 |
+
convH, convW, _inC, _outC = w.shape.as_list()
|
324 |
+
assert convW == convH
|
325 |
+
if k is None:
|
326 |
+
k = [1] * factor
|
327 |
+
k = _setup_kernel(k) * gain
|
328 |
+
p = (k.shape[0] - factor) + (convW - 1)
|
329 |
+
if data_format == 'NCHW':
|
330 |
+
s = [1, 1, factor, factor]
|
331 |
+
else:
|
332 |
+
s = [1, factor, factor, 1]
|
333 |
+
x = _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
|
334 |
+
return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format)
|
335 |
+
|
336 |
+
#----------------------------------------------------------------------------
|
337 |
+
# Internal helper funcs.
|
338 |
+
|
339 |
+
def _shape(tf_expr, dim_idx):
|
340 |
+
if tf_expr.shape.rank is not None:
|
341 |
+
dim = tf_expr.shape[dim_idx].value
|
342 |
+
if dim is not None:
|
343 |
+
return dim
|
344 |
+
return tf.shape(tf_expr)[dim_idx]
|
345 |
+
|
346 |
+
def _setup_kernel(k):
|
347 |
+
k = np.asarray(k, dtype=np.float32)
|
348 |
+
if k.ndim == 1:
|
349 |
+
k = np.outer(k, k)
|
350 |
+
k /= np.sum(k)
|
351 |
+
assert k.ndim == 2
|
352 |
+
assert k.shape[0] == k.shape[1]
|
353 |
+
return k
|
354 |
+
|
355 |
+
def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'):
|
356 |
+
assert data_format in ['NCHW', 'NHWC']
|
357 |
+
assert x.shape.rank == 4
|
358 |
+
y = x
|
359 |
+
if data_format == 'NCHW':
|
360 |
+
y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1])
|
361 |
+
y = upfirdn_2d(y, k, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl)
|
362 |
+
if data_format == 'NCHW':
|
363 |
+
y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)])
|
364 |
+
return y
|
365 |
+
|
366 |
+
#----------------------------------------------------------------------------
|
dnnlib/tflib/optimizer.py
ADDED
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
+
#
|
5 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
6 |
+
# To view a copy of this license, visit
|
7 |
+
# https://nvlabs.github.io/stylegan2/license.html
|
8 |
+
|
9 |
+
"""Helper wrapper for a Tensorflow optimizer."""
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import tensorflow as tf
|
13 |
+
|
14 |
+
from collections import OrderedDict
|
15 |
+
from typing import List, Union
|
16 |
+
|
17 |
+
from . import autosummary
|
18 |
+
from . import tfutil
|
19 |
+
from .. import util
|
20 |
+
|
21 |
+
from .tfutil import TfExpression, TfExpressionEx
|
22 |
+
|
23 |
+
try:
|
24 |
+
# TensorFlow 1.13
|
25 |
+
from tensorflow.python.ops import nccl_ops
|
26 |
+
except:
|
27 |
+
# Older TensorFlow versions
|
28 |
+
import tensorflow.contrib.nccl as nccl_ops
|
29 |
+
|
30 |
+
class Optimizer:
|
31 |
+
"""A Wrapper for tf.train.Optimizer.
|
32 |
+
|
33 |
+
Automatically takes care of:
|
34 |
+
- Gradient averaging for multi-GPU training.
|
35 |
+
- Gradient accumulation for arbitrarily large minibatches.
|
36 |
+
- Dynamic loss scaling and typecasts for FP16 training.
|
37 |
+
- Ignoring corrupted gradients that contain NaNs/Infs.
|
38 |
+
- Reporting statistics.
|
39 |
+
- Well-chosen default settings.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self,
|
43 |
+
name: str = "Train", # Name string that will appear in TensorFlow graph.
|
44 |
+
tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class.
|
45 |
+
learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time.
|
46 |
+
minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients.
|
47 |
+
share: "Optimizer" = None, # Share internal state with a previously created optimizer?
|
48 |
+
use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training?
|
49 |
+
loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor.
|
50 |
+
loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow.
|
51 |
+
loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow.
|
52 |
+
report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard?
|
53 |
+
**kwargs):
|
54 |
+
|
55 |
+
# Public fields.
|
56 |
+
self.name = name
|
57 |
+
self.learning_rate = learning_rate
|
58 |
+
self.minibatch_multiplier = minibatch_multiplier
|
59 |
+
self.id = self.name.replace("/", ".")
|
60 |
+
self.scope = tf.get_default_graph().unique_name(self.id)
|
61 |
+
self.optimizer_class = util.get_obj_by_name(tf_optimizer)
|
62 |
+
self.optimizer_kwargs = dict(kwargs)
|
63 |
+
self.use_loss_scaling = use_loss_scaling
|
64 |
+
self.loss_scaling_init = loss_scaling_init
|
65 |
+
self.loss_scaling_inc = loss_scaling_inc
|
66 |
+
self.loss_scaling_dec = loss_scaling_dec
|
67 |
+
|
68 |
+
# Private fields.
|
69 |
+
self._updates_applied = False
|
70 |
+
self._devices = OrderedDict() # device_name => EasyDict()
|
71 |
+
self._shared_optimizers = OrderedDict() # device_name => optimizer_class
|
72 |
+
self._gradient_shapes = None # [shape, ...]
|
73 |
+
self._report_mem_usage = report_mem_usage
|
74 |
+
|
75 |
+
# Validate arguments.
|
76 |
+
assert callable(self.optimizer_class)
|
77 |
+
|
78 |
+
# Share internal state if requested.
|
79 |
+
if share is not None:
|
80 |
+
assert isinstance(share, Optimizer)
|
81 |
+
assert self.optimizer_class is share.optimizer_class
|
82 |
+
assert self.learning_rate is share.learning_rate
|
83 |
+
assert self.optimizer_kwargs == share.optimizer_kwargs
|
84 |
+
self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access
|
85 |
+
|
86 |
+
def _get_device(self, device_name: str):
|
87 |
+
"""Get internal state for the given TensorFlow device."""
|
88 |
+
tfutil.assert_tf_initialized()
|
89 |
+
if device_name in self._devices:
|
90 |
+
return self._devices[device_name]
|
91 |
+
|
92 |
+
# Initialize fields.
|
93 |
+
device = util.EasyDict()
|
94 |
+
device.name = device_name
|
95 |
+
device.optimizer = None # Underlying optimizer: optimizer_class
|
96 |
+
device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable
|
97 |
+
device.grad_raw = OrderedDict() # Raw gradients: var => [grad, ...]
|
98 |
+
device.grad_clean = OrderedDict() # Clean gradients: var => grad
|
99 |
+
device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable
|
100 |
+
device.grad_acc_count = None # Accumulation counter: tf.Variable
|
101 |
+
device.grad_acc = OrderedDict() # Accumulated gradients: var => grad
|
102 |
+
|
103 |
+
# Setup TensorFlow objects.
|
104 |
+
with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
|
105 |
+
if device_name not in self._shared_optimizers:
|
106 |
+
optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
|
107 |
+
self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
|
108 |
+
device.optimizer = self._shared_optimizers[device_name]
|
109 |
+
if self.use_loss_scaling:
|
110 |
+
device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")
|
111 |
+
|
112 |
+
# Register device.
|
113 |
+
self._devices[device_name] = device
|
114 |
+
return device
|
115 |
+
|
116 |
+
def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
|
117 |
+
"""Register the gradients of the given loss function with respect to the given variables.
|
118 |
+
Intended to be called once per GPU."""
|
119 |
+
tfutil.assert_tf_initialized()
|
120 |
+
assert not self._updates_applied
|
121 |
+
device = self._get_device(loss.device)
|
122 |
+
|
123 |
+
# Validate trainables.
|
124 |
+
if isinstance(trainable_vars, dict):
|
125 |
+
trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
|
126 |
+
assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
|
127 |
+
assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
|
128 |
+
assert all(var.device == device.name for var in trainable_vars)
|
129 |
+
|
130 |
+
# Validate shapes.
|
131 |
+
if self._gradient_shapes is None:
|
132 |
+
self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
|
133 |
+
assert len(trainable_vars) == len(self._gradient_shapes)
|
134 |
+
assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))
|
135 |
+
|
136 |
+
# Report memory usage if requested.
|
137 |
+
deps = []
|
138 |
+
if self._report_mem_usage:
|
139 |
+
self._report_mem_usage = False
|
140 |
+
try:
|
141 |
+
with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
|
142 |
+
deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
|
143 |
+
except tf.errors.NotFoundError:
|
144 |
+
pass
|
145 |
+
|
146 |
+
# Compute gradients.
|
147 |
+
with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
|
148 |
+
loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
|
149 |
+
gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage
|
150 |
+
grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)
|
151 |
+
|
152 |
+
# Register gradients.
|
153 |
+
for grad, var in grad_list:
|
154 |
+
if var not in device.grad_raw:
|
155 |
+
device.grad_raw[var] = []
|
156 |
+
device.grad_raw[var].append(grad)
|
157 |
+
|
158 |
+
def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
|
159 |
+
"""Construct training op to update the registered variables based on their gradients."""
|
160 |
+
tfutil.assert_tf_initialized()
|
161 |
+
assert not self._updates_applied
|
162 |
+
self._updates_applied = True
|
163 |
+
all_ops = []
|
164 |
+
|
165 |
+
# Check for no-op.
|
166 |
+
if allow_no_op and len(self._devices) == 0:
|
167 |
+
with tfutil.absolute_name_scope(self.scope):
|
168 |
+
return tf.no_op(name='TrainingOp')
|
169 |
+
|
170 |
+
# Clean up gradients.
|
171 |
+
for device_idx, device in enumerate(self._devices.values()):
|
172 |
+
with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
|
173 |
+
for var, grad in device.grad_raw.items():
|
174 |
+
|
175 |
+
# Filter out disconnected gradients and convert to float32.
|
176 |
+
grad = [g for g in grad if g is not None]
|
177 |
+
grad = [tf.cast(g, tf.float32) for g in grad]
|
178 |
+
|
179 |
+
# Sum within the device.
|
180 |
+
if len(grad) == 0:
|
181 |
+
grad = tf.zeros(var.shape) # No gradients => zero.
|
182 |
+
elif len(grad) == 1:
|
183 |
+
grad = grad[0] # Single gradient => use as is.
|
184 |
+
else:
|
185 |
+
grad = tf.add_n(grad) # Multiple gradients => sum.
|
186 |
+
|
187 |
+
# Scale as needed.
|
188 |
+
scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
|
189 |
+
scale = tf.constant(scale, dtype=tf.float32, name="scale")
|
190 |
+
if self.minibatch_multiplier is not None:
|
191 |
+
scale /= tf.cast(self.minibatch_multiplier, tf.float32)
|
192 |
+
scale = self.undo_loss_scaling(scale)
|
193 |
+
device.grad_clean[var] = grad * scale
|
194 |
+
|
195 |
+
# Sum gradients across devices.
|
196 |
+
if len(self._devices) > 1:
|
197 |
+
with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
|
198 |
+
for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
|
199 |
+
if len(all_vars) > 0 and all(dim > 0 for dim in all_vars[0].shape.as_list()): # NCCL does not support zero-sized tensors.
|
200 |
+
all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
|
201 |
+
all_grads = nccl_ops.all_sum(all_grads)
|
202 |
+
for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
|
203 |
+
device.grad_clean[var] = grad
|
204 |
+
|
205 |
+
# Apply updates separately on each device.
|
206 |
+
for device_idx, device in enumerate(self._devices.values()):
|
207 |
+
with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
|
208 |
+
# pylint: disable=cell-var-from-loop
|
209 |
+
|
210 |
+
# Accumulate gradients over time.
|
211 |
+
if self.minibatch_multiplier is None:
|
212 |
+
acc_ok = tf.constant(True, name='acc_ok')
|
213 |
+
device.grad_acc = OrderedDict(device.grad_clean)
|
214 |
+
else:
|
215 |
+
# Create variables.
|
216 |
+
with tf.control_dependencies(None):
|
217 |
+
for var in device.grad_clean.keys():
|
218 |
+
device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
|
219 |
+
device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")
|
220 |
+
|
221 |
+
# Track counter.
|
222 |
+
count_cur = device.grad_acc_count + 1.0
|
223 |
+
count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
|
224 |
+
count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
|
225 |
+
acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
|
226 |
+
all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))
|
227 |
+
|
228 |
+
# Track gradients.
|
229 |
+
for var, grad in device.grad_clean.items():
|
230 |
+
acc_var = device.grad_acc_vars[var]
|
231 |
+
acc_cur = acc_var + grad
|
232 |
+
device.grad_acc[var] = acc_cur
|
233 |
+
with tf.control_dependencies([acc_cur]):
|
234 |
+
acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
|
235 |
+
acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
|
236 |
+
all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))
|
237 |
+
|
238 |
+
# No overflow => apply gradients.
|
239 |
+
all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
|
240 |
+
apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
|
241 |
+
all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))
|
242 |
+
|
243 |
+
# Adjust loss scaling.
|
244 |
+
if self.use_loss_scaling:
|
245 |
+
ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
|
246 |
+
ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
|
247 |
+
ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
|
248 |
+
all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))
|
249 |
+
|
250 |
+
# Last device => report statistics.
|
251 |
+
if device_idx == len(self._devices) - 1:
|
252 |
+
all_ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
|
253 |
+
all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
|
254 |
+
if self.use_loss_scaling:
|
255 |
+
all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))
|
256 |
+
|
257 |
+
# Initialize variables.
|
258 |
+
self.reset_optimizer_state()
|
259 |
+
if self.use_loss_scaling:
|
260 |
+
tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
|
261 |
+
if self.minibatch_multiplier is not None:
|
262 |
+
tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])
|
263 |
+
|
264 |
+
# Group everything into a single op.
|
265 |
+
with tfutil.absolute_name_scope(self.scope):
|
266 |
+
return tf.group(*all_ops, name="TrainingOp")
|
267 |
+
|
268 |
+
def reset_optimizer_state(self) -> None:
|
269 |
+
"""Reset internal state of the underlying optimizer."""
|
270 |
+
tfutil.assert_tf_initialized()
|
271 |
+
tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])
|
272 |
+
|
273 |
+
def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
|
274 |
+
"""Get or create variable representing log2 of the current dynamic loss scaling factor."""
|
275 |
+
return self._get_device(device).loss_scaling_var
|
276 |
+
|
277 |
+
def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
|
278 |
+
"""Apply dynamic loss scaling for the given expression."""
|
279 |
+
assert tfutil.is_tf_expression(value)
|
280 |
+
if not self.use_loss_scaling:
|
281 |
+
return value
|
282 |
+
return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
|
283 |
+
|
284 |
+
def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
|
285 |
+
"""Undo the effect of dynamic loss scaling for the given expression."""
|
286 |
+
assert tfutil.is_tf_expression(value)
|
287 |
+
if not self.use_loss_scaling:
|
288 |
+
return value
|
289 |
+
return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
|
290 |
+
|
291 |
+
|
292 |
+
class SimpleAdam:
|
293 |
+
"""Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer."""
|
294 |
+
|
295 |
+
def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
|
296 |
+
self.name = name
|
297 |
+
self.learning_rate = learning_rate
|
298 |
+
self.beta1 = beta1
|
299 |
+
self.beta2 = beta2
|
300 |
+
self.epsilon = epsilon
|
301 |
+
self.all_state_vars = []
|
302 |
+
|
303 |
+
def variables(self):
|
304 |
+
return self.all_state_vars
|
305 |
+
|
306 |
+
def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
|
307 |
+
assert gate_gradients == tf.train.Optimizer.GATE_NONE
|
308 |
+
return list(zip(tf.gradients(loss, var_list), var_list))
|
309 |
+
|
310 |
+
def apply_gradients(self, grads_and_vars):
|
311 |
+
with tf.name_scope(self.name):
|
312 |
+
state_vars = []
|
313 |
+
update_ops = []
|
314 |
+
|
315 |
+
# Adjust learning rate to deal with startup bias.
|
316 |
+
with tf.control_dependencies(None):
|
317 |
+
b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
|
318 |
+
b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
|
319 |
+
state_vars += [b1pow_var, b2pow_var]
|
320 |
+
b1pow_new = b1pow_var * self.beta1
|
321 |
+
b2pow_new = b2pow_var * self.beta2
|
322 |
+
update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)]
|
323 |
+
lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new)
|
324 |
+
|
325 |
+
# Construct ops to update each variable.
|
326 |
+
for grad, var in grads_and_vars:
|
327 |
+
with tf.control_dependencies(None):
|
328 |
+
m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
|
329 |
+
v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
|
330 |
+
state_vars += [m_var, v_var]
|
331 |
+
m_new = self.beta1 * m_var + (1 - self.beta1) * grad
|
332 |
+
v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad)
|
333 |
+
var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon)
|
334 |
+
update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)]
|
335 |
+
|
336 |
+
# Group everything together.
|
337 |
+
self.all_state_vars += state_vars
|
338 |
+
return tf.group(*update_ops)
|
dnnlib/tflib/tfutil.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
+
#
|
5 |
+
# This work is made available under the Nvidia Source Code License-NC.
|
6 |
+
# To view a copy of this license, visit
|
7 |
+
# https://nvlabs.github.io/stylegan2/license.html
|
8 |
+
|
9 |
+
"""Miscellaneous helper utils for Tensorflow."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import numpy as np
|
13 |
+
import tensorflow as tf
|
14 |
+
|
15 |
+
# Silence deprecation warnings from TensorFlow 1.13 onwards
|
16 |
+
import logging
|
17 |
+
logging.getLogger('tensorflow').setLevel(logging.ERROR)
|
18 |
+
import tensorflow.contrib # requires TensorFlow 1.x!
|
19 |
+
tf.contrib = tensorflow.contrib
|
20 |
+
|
21 |
+
from typing import Any, Iterable, List, Union
|
22 |
+
|
23 |
+
TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
|
24 |
+
"""A type that represents a valid Tensorflow expression."""
|
25 |
+
|
26 |
+
TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
|
27 |
+
"""A type that can be converted to a valid Tensorflow expression."""
|
28 |
+
|
29 |
+
|
30 |
+
def run(*args, **kwargs) -> Any:
|
31 |
+
"""Run the specified ops in the default session."""
|
32 |
+
assert_tf_initialized()
|
33 |
+
return tf.get_default_session().run(*args, **kwargs)
|
34 |
+
|
35 |
+
|
36 |
+
def is_tf_expression(x: Any) -> bool:
|
37 |
+
"""Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
|
38 |
+
return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
|
39 |
+
|
40 |
+
|
41 |
+
def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
|
42 |
+
"""Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code."""
|
43 |
+
return [dim.value for dim in shape]
|
44 |
+
|
45 |
+
|
46 |
+
def flatten(x: TfExpressionEx) -> TfExpression:
|
47 |
+
"""Shortcut function for flattening a tensor."""
|
48 |
+
with tf.name_scope("Flatten"):
|
49 |
+
return tf.reshape(x, [-1])
|
50 |
+
|
51 |
+
|
52 |
+
def log2(x: TfExpressionEx) -> TfExpression:
|
53 |
+
"""Logarithm in base 2."""
|
54 |
+
with tf.name_scope("Log2"):
|
55 |
+
return tf.log(x) * np.float32(1.0 / np.log(2.0))
|
56 |
+
|
57 |
+
|
58 |
+
def exp2(x: TfExpressionEx) -> TfExpression:
|
59 |
+
"""Exponent in base 2."""
|
60 |
+
with tf.name_scope("Exp2"):
|
61 |
+
return tf.exp(x * np.float32(np.log(2.0)))
|
62 |
+
|
63 |
+
|
64 |
+
def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
|
65 |
+
"""Linear interpolation."""
|
66 |
+
with tf.name_scope("Lerp"):
|
67 |
+
return a + (b - a) * t
|
68 |
+
|
69 |
+
|
70 |
+
def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
|
71 |
+
"""Linear interpolation with clip."""
|
72 |
+
with tf.name_scope("LerpClip"):
|
73 |
+
return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
|
74 |
+
|
75 |
+
|
76 |
+
def absolute_name_scope(scope: str) -> tf.name_scope:
|
77 |
+
"""Forcefully enter the specified name scope, ignoring any surrounding scopes."""
|
78 |
+
return tf.name_scope(scope + "/")
|
79 |
+
|
80 |
+
|
81 |
+
def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
|
82 |
+
"""Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
|
83 |
+
return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
|
84 |
+
|
85 |
+
|
86 |
+
def _sanitize_tf_config(config_dict: dict = None) -> dict:
|
87 |
+
# Defaults.
|
88 |
+
cfg = dict()
|
89 |
+
cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
|
90 |
+
cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
|
91 |
+
cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
|
92 |
+
cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
|
93 |
+
cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
|
94 |
+
|
95 |
+
# Remove defaults for environment variables that are already set.
|
96 |
+
for key in list(cfg):
|
97 |
+
fields = key.split(".")
|
98 |
+
if fields[0] == "env":
|
99 |
+
assert len(fields) == 2
|
100 |
+
if fields[1] in os.environ:
|
101 |
+
del cfg[key]
|
102 |
+
|
103 |
+
# User overrides.
|
104 |
+
if config_dict is not None:
|
105 |
+
cfg.update(config_dict)
|
106 |
+
return cfg
|
107 |
+
|
108 |
+
|
109 |
+
def init_tf(config_dict: dict = None) -> None:
|
110 |
+
"""Initialize TensorFlow session using good default settings."""
|
111 |
+
# Skip if already initialized.
|
112 |
+
if tf.get_default_session() is not None:
|
113 |
+
return
|
114 |
+
|
115 |
+
# Setup config dict and random seeds.
|
116 |
+
cfg = _sanitize_tf_config(config_dict)
|
117 |
+
np_random_seed = cfg["rnd.np_random_seed"]
|
118 |
+
if np_random_seed is not None:
|
119 |
+
np.random.seed(np_random_seed)
|
120 |
+
tf_random_seed = cfg["rnd.tf_random_seed"]
|
121 |
+
if tf_random_seed == "auto":
|
122 |
+
tf_random_seed = np.random.randint(1 << 31)
|
123 |
+
if tf_random_seed is not None:
|
124 |
+
tf.set_random_seed(tf_random_seed)
|
125 |
+
|
126 |
+
# Setup environment variables.
|
127 |
+
for key, value in cfg.items():
|
128 |
+
fields = key.split(".")
|
129 |
+
if fields[0] == "env":
|
130 |
+
assert len(fields) == 2
|
131 |
+
os.environ[fields[1]] = str(value)
|
132 |
+
|
133 |
+
# Create default TensorFlow session.
|
134 |
+
create_session(cfg, force_as_default=True)
|
135 |
+
|
136 |
+
|
137 |
+
def assert_tf_initialized():
|
138 |
+
"""Check that TensorFlow session has been initialized."""
|
139 |
+
if tf.get_default_session() is None:
|
140 |
+
raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
|
141 |
+
|
142 |
+
|
143 |
+
def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
|
144 |
+
"""Create tf.Session based on config dict."""
|
145 |
+
# Setup TensorFlow config proto.
|
146 |
+
cfg = _sanitize_tf_config(config_dict)
|
147 |
+
config_proto = tf.ConfigProto()
|
148 |
+
for key, value in cfg.items():
|
149 |
+
fields = key.split(".")
|
150 |
+
if fields[0] not in ["rnd", "env"]:
|
151 |
+
obj = config_proto
|
152 |
+
for field in fields[:-1]:
|
153 |
+
obj = getattr(obj, field)
|
154 |
+
setattr(obj, fields[-1], value)
|
155 |
+
|
156 |
+
# Create session.
|
157 |
+
session = tf.Session(config=config_proto)
|
158 |
+
if force_as_default:
|
159 |
+
# pylint: disable=protected-access
|
160 |
+
session._default_session = session.as_default()
|
161 |
+
session._default_session.enforce_nesting = False
|
162 |
+
session._default_session.__enter__()
|
163 |
+
return session
|
164 |
+
|
165 |
+
|
166 |
+
def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
|
167 |
+
"""Initialize all tf.Variables that have not already been initialized.
|
168 |
+
|
169 |
+
Equivalent to the following, but more efficient and does not bloat the tf graph:
|
170 |
+
tf.variables_initializer(tf.report_uninitialized_variables()).run()
|
171 |
+
"""
|
172 |
+
assert_tf_initialized()
|
173 |
+
if target_vars is None:
|
174 |
+
target_vars = tf.global_variables()
|
175 |
+
|
176 |
+
test_vars = []
|
177 |
+
test_ops = []
|
178 |
+
|
179 |
+
with tf.control_dependencies(None): # ignore surrounding control_dependencies
|
180 |
+
for var in target_vars:
|
181 |
+
assert is_tf_expression(var)
|
182 |
+
|
183 |
+
try:
|
184 |
+
tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
|
185 |
+
except KeyError:
|
186 |
+
# Op does not exist => variable may be uninitialized.
|
187 |
+
test_vars.append(var)
|
188 |
+
|
189 |
+
with absolute_name_scope(var.name.split(":")[0]):
|
190 |
+
test_ops.append(tf.is_variable_initialized(var))
|
191 |
+
|
192 |
+
init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
|
193 |
+
run([var.initializer for var in init_vars])
|
194 |
+
|
195 |
+
|
196 |
+
def set_vars(var_to_value_dict: dict) -> None:
|
197 |
+
"""Set the values of given tf.Variables.
|
198 |
+
|
199 |
+
Equivalent to the following, but more efficient and does not bloat the tf graph:
|
200 |
+
tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()]
|
201 |
+
"""
|
202 |
+
assert_tf_initialized()
|
203 |
+
ops = []
|
204 |
+
feed_dict = {}
|
205 |
+
|
206 |
+
for var, value in var_to_value_dict.items():
|
207 |
+
assert is_tf_expression(var)
|
208 |
+
|
209 |
+
try:
|
210 |
+
setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
|
211 |
+
except KeyError:
|
212 |
+
with absolute_name_scope(var.name.split(":")[0]):
|
213 |
+
with tf.control_dependencies(None): # ignore surrounding control_dependencies
|
214 |
+
setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
|
215 |
+
|
216 |
+
ops.append(setter)
|
217 |
+
feed_dict[setter.op.inputs[1]] = value
|
218 |
+
|
219 |
+
run(ops, feed_dict)
|
220 |
+
|
221 |
+
|
222 |
+
def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
|
223 |
+
"""Create tf.Variable with large initial value without bloating the tf graph."""
|
224 |
+
assert_tf_initialized()
|
225 |
+
assert isinstance(initial_value, np.ndarray)
|
226 |
+
zeros = tf.zeros(initial_value.shape, initial_value.dtype)
|
227 |
+
var = tf.Variable(zeros, *args, **kwargs)
|
228 |
+
set_vars({var: initial_value})
|
229 |
+
return var
|
230 |
+
|
231 |
+
|
232 |
+
def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
|
233 |
+
"""Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
|
234 |
+
Can be used as an input transformation for Network.run().
|
235 |
+
"""
|
236 |
+
images = tf.cast(images, tf.float32)
|
237 |
+
if nhwc_to_nchw:
|
238 |
+
images = tf.transpose(images, [0, 3, 1, 2])
|
239 |
+
return images * ((drange[1] - drange[0]) / 255) + drange[0]
|
240 |
+
|
241 |
+
|
242 |
+
def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
|
243 |
+
"""Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
|
244 |
+
Can be used as an output transformation for Network.run().
|
245 |
+
"""
|
246 |
+
images = tf.cast(images, tf.float32)
|
247 |
+
if shrink > 1:
|
248 |
+
ksize = [1, 1, shrink, shrink]
|
249 |
+
images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
|
250 |
+
if nchw_to_nhwc:
|
251 |
+
images = tf.transpose(images, [0, 2, 3, 1])
|
252 |
+
scale = 255 / (drange[1] - drange[0])
|
253 |
+
images = images * scale + (0.5 - drange[0] * scale)
|
254 |
+
return tf.saturate_cast(images, tf.uint8)
|
dnnlib/util.py
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
5 |
+
# and proprietary rights in and to this software, related documentation
|
6 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
7 |
+
# distribution of this software and related documentation without an express
|
8 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
9 |
+
|
10 |
+
"""Miscellaneous utility classes and functions."""
|
11 |
+
|
12 |
+
import ctypes
|
13 |
+
import fnmatch
|
14 |
+
import importlib
|
15 |
+
import inspect
|
16 |
+
import numpy as np
|
17 |
+
import os
|
18 |
+
import shutil
|
19 |
+
import sys
|
20 |
+
import types
|
21 |
+
import io
|
22 |
+
import pickle
|
23 |
+
import re
|
24 |
+
import requests
|
25 |
+
import html
|
26 |
+
import hashlib
|
27 |
+
import glob
|
28 |
+
import tempfile
|
29 |
+
import urllib
|
30 |
+
import urllib.request
|
31 |
+
import uuid
|
32 |
+
|
33 |
+
from distutils.util import strtobool
|
34 |
+
from typing import Any, List, Tuple, Union
|
35 |
+
|
36 |
+
|
37 |
+
# Util classes
|
38 |
+
# ------------------------------------------------------------------------------------------
|
39 |
+
|
40 |
+
|
41 |
+
class EasyDict(dict):
|
42 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
43 |
+
|
44 |
+
def __getattr__(self, name: str) -> Any:
|
45 |
+
try:
|
46 |
+
return self[name]
|
47 |
+
except KeyError:
|
48 |
+
raise AttributeError(name)
|
49 |
+
|
50 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
51 |
+
self[name] = value
|
52 |
+
|
53 |
+
def __delattr__(self, name: str) -> None:
|
54 |
+
del self[name]
|
55 |
+
|
56 |
+
|
57 |
+
class Logger(object):
|
58 |
+
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
59 |
+
|
60 |
+
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
|
61 |
+
self.file = None
|
62 |
+
|
63 |
+
if file_name is not None:
|
64 |
+
self.file = open(file_name, file_mode)
|
65 |
+
|
66 |
+
self.should_flush = should_flush
|
67 |
+
self.stdout = sys.stdout
|
68 |
+
self.stderr = sys.stderr
|
69 |
+
|
70 |
+
sys.stdout = self
|
71 |
+
sys.stderr = self
|
72 |
+
|
73 |
+
def __enter__(self) -> "Logger":
|
74 |
+
return self
|
75 |
+
|
76 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
77 |
+
self.close()
|
78 |
+
|
79 |
+
def write(self, text: Union[str, bytes]) -> None:
|
80 |
+
"""Write text to stdout (and a file) and optionally flush."""
|
81 |
+
if isinstance(text, bytes):
|
82 |
+
text = text.decode()
|
83 |
+
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
84 |
+
return
|
85 |
+
|
86 |
+
if self.file is not None:
|
87 |
+
self.file.write(text)
|
88 |
+
|
89 |
+
self.stdout.write(text)
|
90 |
+
|
91 |
+
if self.should_flush:
|
92 |
+
self.flush()
|
93 |
+
|
94 |
+
def flush(self) -> None:
|
95 |
+
"""Flush written text to both stdout and a file, if open."""
|
96 |
+
if self.file is not None:
|
97 |
+
self.file.flush()
|
98 |
+
|
99 |
+
self.stdout.flush()
|
100 |
+
|
101 |
+
def close(self) -> None:
|
102 |
+
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
103 |
+
self.flush()
|
104 |
+
|
105 |
+
# if using multiple loggers, prevent closing in wrong order
|
106 |
+
if sys.stdout is self:
|
107 |
+
sys.stdout = self.stdout
|
108 |
+
if sys.stderr is self:
|
109 |
+
sys.stderr = self.stderr
|
110 |
+
|
111 |
+
if self.file is not None:
|
112 |
+
self.file.close()
|
113 |
+
self.file = None
|
114 |
+
|
115 |
+
|
116 |
+
# Cache directories
|
117 |
+
# ------------------------------------------------------------------------------------------
|
118 |
+
|
119 |
+
_dnnlib_cache_dir = None
|
120 |
+
|
121 |
+
def set_cache_dir(path: str) -> None:
|
122 |
+
global _dnnlib_cache_dir
|
123 |
+
_dnnlib_cache_dir = path
|
124 |
+
|
125 |
+
def make_cache_dir_path(*paths: str) -> str:
|
126 |
+
if _dnnlib_cache_dir is not None:
|
127 |
+
return os.path.join(_dnnlib_cache_dir, *paths)
|
128 |
+
if 'DNNLIB_CACHE_DIR' in os.environ:
|
129 |
+
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
130 |
+
if 'HOME' in os.environ:
|
131 |
+
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
132 |
+
if 'USERPROFILE' in os.environ:
|
133 |
+
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
134 |
+
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
135 |
+
|
136 |
+
# Small util functions
|
137 |
+
# ------------------------------------------------------------------------------------------
|
138 |
+
|
139 |
+
|
140 |
+
def format_time(seconds: Union[int, float]) -> str:
|
141 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
142 |
+
s = int(np.rint(seconds))
|
143 |
+
|
144 |
+
if s < 60:
|
145 |
+
return "{0}s".format(s)
|
146 |
+
elif s < 60 * 60:
|
147 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
148 |
+
elif s < 24 * 60 * 60:
|
149 |
+
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
150 |
+
else:
|
151 |
+
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
152 |
+
|
153 |
+
|
154 |
+
def ask_yes_no(question: str) -> bool:
|
155 |
+
"""Ask the user the question until the user inputs a valid answer."""
|
156 |
+
while True:
|
157 |
+
try:
|
158 |
+
print("{0} [y/n]".format(question))
|
159 |
+
return strtobool(input().lower())
|
160 |
+
except ValueError:
|
161 |
+
pass
|
162 |
+
|
163 |
+
|
164 |
+
def tuple_product(t: Tuple) -> Any:
|
165 |
+
"""Calculate the product of the tuple elements."""
|
166 |
+
result = 1
|
167 |
+
|
168 |
+
for v in t:
|
169 |
+
result *= v
|
170 |
+
|
171 |
+
return result
|
172 |
+
|
173 |
+
|
174 |
+
_str_to_ctype = {
|
175 |
+
"uint8": ctypes.c_ubyte,
|
176 |
+
"uint16": ctypes.c_uint16,
|
177 |
+
"uint32": ctypes.c_uint32,
|
178 |
+
"uint64": ctypes.c_uint64,
|
179 |
+
"int8": ctypes.c_byte,
|
180 |
+
"int16": ctypes.c_int16,
|
181 |
+
"int32": ctypes.c_int32,
|
182 |
+
"int64": ctypes.c_int64,
|
183 |
+
"float32": ctypes.c_float,
|
184 |
+
"float64": ctypes.c_double
|
185 |
+
}
|
186 |
+
|
187 |
+
|
188 |
+
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
189 |
+
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
190 |
+
type_str = None
|
191 |
+
|
192 |
+
if isinstance(type_obj, str):
|
193 |
+
type_str = type_obj
|
194 |
+
elif hasattr(type_obj, "__name__"):
|
195 |
+
type_str = type_obj.__name__
|
196 |
+
elif hasattr(type_obj, "name"):
|
197 |
+
type_str = type_obj.name
|
198 |
+
else:
|
199 |
+
raise RuntimeError("Cannot infer type name from input")
|
200 |
+
|
201 |
+
assert type_str in _str_to_ctype.keys()
|
202 |
+
|
203 |
+
my_dtype = np.dtype(type_str)
|
204 |
+
my_ctype = _str_to_ctype[type_str]
|
205 |
+
|
206 |
+
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
207 |
+
|
208 |
+
return my_dtype, my_ctype
|
209 |
+
|
210 |
+
|
211 |
+
def is_pickleable(obj: Any) -> bool:
|
212 |
+
try:
|
213 |
+
with io.BytesIO() as stream:
|
214 |
+
pickle.dump(obj, stream)
|
215 |
+
return True
|
216 |
+
except:
|
217 |
+
return False
|
218 |
+
|
219 |
+
|
220 |
+
# Functionality to import modules/objects by name, and call functions by name
|
221 |
+
# ------------------------------------------------------------------------------------------
|
222 |
+
|
223 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
224 |
+
"""Searches for the underlying module behind the name to some python object.
|
225 |
+
Returns the module and the object name (original name with module part removed)."""
|
226 |
+
|
227 |
+
# allow convenience shorthands, substitute them by full names
|
228 |
+
obj_name = re.sub("^np.", "numpy.", obj_name)
|
229 |
+
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
230 |
+
|
231 |
+
# list alternatives for (module_name, local_obj_name)
|
232 |
+
parts = obj_name.split(".")
|
233 |
+
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
234 |
+
|
235 |
+
# try each alternative in turn
|
236 |
+
for module_name, local_obj_name in name_pairs:
|
237 |
+
try:
|
238 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
239 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
240 |
+
return module, local_obj_name
|
241 |
+
except:
|
242 |
+
pass
|
243 |
+
|
244 |
+
# maybe some of the modules themselves contain errors?
|
245 |
+
for module_name, _local_obj_name in name_pairs:
|
246 |
+
try:
|
247 |
+
importlib.import_module(module_name) # may raise ImportError
|
248 |
+
except ImportError:
|
249 |
+
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
250 |
+
raise
|
251 |
+
|
252 |
+
# maybe the requested attribute is missing?
|
253 |
+
for module_name, local_obj_name in name_pairs:
|
254 |
+
try:
|
255 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
256 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
257 |
+
except ImportError:
|
258 |
+
pass
|
259 |
+
|
260 |
+
# we are out of luck, but we have no idea why
|
261 |
+
raise ImportError(obj_name)
|
262 |
+
|
263 |
+
|
264 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
265 |
+
"""Traverses the object name and returns the last (rightmost) python object."""
|
266 |
+
if obj_name == '':
|
267 |
+
return module
|
268 |
+
obj = module
|
269 |
+
for part in obj_name.split("."):
|
270 |
+
obj = getattr(obj, part)
|
271 |
+
return obj
|
272 |
+
|
273 |
+
|
274 |
+
def get_obj_by_name(name: str) -> Any:
|
275 |
+
"""Finds the python object with the given name."""
|
276 |
+
module, obj_name = get_module_from_obj_name(name)
|
277 |
+
return get_obj_from_module(module, obj_name)
|
278 |
+
|
279 |
+
|
280 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
281 |
+
"""Finds the python object with the given name and calls it as a function."""
|
282 |
+
assert func_name is not None
|
283 |
+
# print('func_name: ', func_name) #'training.dataset.ImageFolderDataset'
|
284 |
+
func_obj = get_obj_by_name(func_name)
|
285 |
+
assert callable(func_obj)
|
286 |
+
return func_obj(*args, **kwargs)
|
287 |
+
|
288 |
+
|
289 |
+
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
290 |
+
"""Finds the python class with the given name and constructs it with the given arguments."""
|
291 |
+
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
292 |
+
|
293 |
+
|
294 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
295 |
+
"""Get the directory path of the module containing the given object name."""
|
296 |
+
module, _ = get_module_from_obj_name(obj_name)
|
297 |
+
return os.path.dirname(inspect.getfile(module))
|
298 |
+
|
299 |
+
|
300 |
+
def is_top_level_function(obj: Any) -> bool:
|
301 |
+
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
302 |
+
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
303 |
+
|
304 |
+
|
305 |
+
def get_top_level_function_name(obj: Any) -> str:
|
306 |
+
"""Return the fully-qualified name of a top-level function."""
|
307 |
+
assert is_top_level_function(obj)
|
308 |
+
module = obj.__module__
|
309 |
+
if module == '__main__':
|
310 |
+
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
311 |
+
return module + "." + obj.__name__
|
312 |
+
|
313 |
+
|
314 |
+
# File system helpers
|
315 |
+
# ------------------------------------------------------------------------------------------
|
316 |
+
|
317 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
318 |
+
"""List all files recursively in a given directory while ignoring given file and directory names.
|
319 |
+
Returns list of tuples containing both absolute and relative paths."""
|
320 |
+
assert os.path.isdir(dir_path)
|
321 |
+
base_name = os.path.basename(os.path.normpath(dir_path))
|
322 |
+
|
323 |
+
if ignores is None:
|
324 |
+
ignores = []
|
325 |
+
|
326 |
+
result = []
|
327 |
+
|
328 |
+
for root, dirs, files in os.walk(dir_path, topdown=True):
|
329 |
+
for ignore_ in ignores:
|
330 |
+
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
331 |
+
|
332 |
+
# dirs need to be edited in-place
|
333 |
+
for d in dirs_to_remove:
|
334 |
+
dirs.remove(d)
|
335 |
+
|
336 |
+
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
337 |
+
|
338 |
+
absolute_paths = [os.path.join(root, f) for f in files]
|
339 |
+
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
340 |
+
|
341 |
+
if add_base_to_relative:
|
342 |
+
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
343 |
+
|
344 |
+
assert len(absolute_paths) == len(relative_paths)
|
345 |
+
result += zip(absolute_paths, relative_paths)
|
346 |
+
|
347 |
+
return result
|
348 |
+
|
349 |
+
|
350 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
351 |
+
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
352 |
+
Will create all necessary directories."""
|
353 |
+
for file in files:
|
354 |
+
target_dir_name = os.path.dirname(file[1])
|
355 |
+
|
356 |
+
# will create all intermediate-level directories
|
357 |
+
if not os.path.exists(target_dir_name):
|
358 |
+
os.makedirs(target_dir_name)
|
359 |
+
|
360 |
+
shutil.copyfile(file[0], file[1])
|
361 |
+
|
362 |
+
|
363 |
+
# URL helpers
|
364 |
+
# ------------------------------------------------------------------------------------------
|
365 |
+
|
366 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
367 |
+
"""Determine whether the given object is a valid URL string."""
|
368 |
+
if not isinstance(obj, str) or not "://" in obj:
|
369 |
+
return False
|
370 |
+
if allow_file_urls and obj.startswith('file://'):
|
371 |
+
return True
|
372 |
+
try:
|
373 |
+
res = requests.compat.urlparse(obj)
|
374 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
375 |
+
return False
|
376 |
+
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
377 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
378 |
+
return False
|
379 |
+
except:
|
380 |
+
return False
|
381 |
+
return True
|
382 |
+
|
383 |
+
|
384 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
385 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
386 |
+
assert num_attempts >= 1
|
387 |
+
assert not (return_filename and (not cache))
|
388 |
+
|
389 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
390 |
+
if not re.match('^[a-z]+://', url):
|
391 |
+
return url if return_filename else open(url, "rb")
|
392 |
+
|
393 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
394 |
+
# arise on Windows:
|
395 |
+
#
|
396 |
+
# file:///c:/foo.txt
|
397 |
+
#
|
398 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
399 |
+
# invalid. Drop the forward slash for such pathnames.
|
400 |
+
#
|
401 |
+
# If you touch this code path, you should test it on both Linux and
|
402 |
+
# Windows.
|
403 |
+
#
|
404 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
405 |
+
# but that converts forward slashes to backslashes and this causes
|
406 |
+
# its own set of problems.
|
407 |
+
if url.startswith('file://'):
|
408 |
+
filename = urllib.parse.urlparse(url).path
|
409 |
+
if re.match(r'^/[a-zA-Z]:', filename):
|
410 |
+
filename = filename[1:]
|
411 |
+
return filename if return_filename else open(filename, "rb")
|
412 |
+
|
413 |
+
assert is_url(url)
|
414 |
+
|
415 |
+
# Lookup from cache.
|
416 |
+
if cache_dir is None:
|
417 |
+
cache_dir = make_cache_dir_path('downloads')
|
418 |
+
|
419 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
420 |
+
if cache:
|
421 |
+
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
422 |
+
if len(cache_files) == 1:
|
423 |
+
filename = cache_files[0]
|
424 |
+
return filename if return_filename else open(filename, "rb")
|
425 |
+
|
426 |
+
# Download.
|
427 |
+
url_name = None
|
428 |
+
url_data = None
|
429 |
+
with requests.Session() as session:
|
430 |
+
if verbose:
|
431 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
432 |
+
for attempts_left in reversed(range(num_attempts)):
|
433 |
+
try:
|
434 |
+
with session.get(url) as res:
|
435 |
+
res.raise_for_status()
|
436 |
+
if len(res.content) == 0:
|
437 |
+
raise IOError("No data received")
|
438 |
+
|
439 |
+
if len(res.content) < 8192:
|
440 |
+
content_str = res.content.decode("utf-8")
|
441 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
442 |
+
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
443 |
+
if len(links) == 1:
|
444 |
+
url = requests.compat.urljoin(url, links[0])
|
445 |
+
raise IOError("Google Drive virus checker nag")
|
446 |
+
if "Google Drive - Quota exceeded" in content_str:
|
447 |
+
raise IOError("Google Drive download quota exceeded -- please try again later")
|
448 |
+
|
449 |
+
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
450 |
+
url_name = match[1] if match else url
|
451 |
+
url_data = res.content
|
452 |
+
if verbose:
|
453 |
+
print(" done")
|
454 |
+
break
|
455 |
+
except KeyboardInterrupt:
|
456 |
+
raise
|
457 |
+
except:
|
458 |
+
if not attempts_left:
|
459 |
+
if verbose:
|
460 |
+
print(" failed")
|
461 |
+
raise
|
462 |
+
if verbose:
|
463 |
+
print(".", end="", flush=True)
|
464 |
+
|
465 |
+
# Save to cache.
|
466 |
+
if cache:
|
467 |
+
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
468 |
+
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
469 |
+
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
470 |
+
os.makedirs(cache_dir, exist_ok=True)
|
471 |
+
with open(temp_file, "wb") as f:
|
472 |
+
f.write(url_data)
|
473 |
+
os.replace(temp_file, cache_file) # atomic
|
474 |
+
if return_filename:
|
475 |
+
return cache_file
|
476 |
+
|
477 |
+
# Return data as file object.
|
478 |
+
assert not return_filename
|
479 |
+
return io.BytesIO(url_data)
|
losses/color_transfer_loss.py
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
from typing import List, Optional
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
from torch.nn.functional import (
|
6 |
-
smooth_l1_loss,
|
7 |
-
)
|
8 |
-
|
9 |
-
|
10 |
-
def flatten_CHW(im: torch.Tensor) -> torch.Tensor:
|
11 |
-
"""
|
12 |
-
(B, C, H, W) -> (B, -1)
|
13 |
-
"""
|
14 |
-
B = im.shape[0]
|
15 |
-
return im.reshape(B, -1)
|
16 |
-
|
17 |
-
|
18 |
-
def stddev(x: torch.Tensor) -> torch.Tensor:
|
19 |
-
"""
|
20 |
-
x: (B, -1), assume with mean normalized
|
21 |
-
Retuens:
|
22 |
-
stddev: (B)
|
23 |
-
"""
|
24 |
-
return torch.sqrt(torch.mean(x * x, dim=-1))
|
25 |
-
|
26 |
-
|
27 |
-
def gram_matrix(input_):
|
28 |
-
B, C = input_.shape[:2]
|
29 |
-
features = input_.view(B, C, -1)
|
30 |
-
N = features.shape[-1]
|
31 |
-
G = torch.bmm(features, features.transpose(1, 2)) # C x C
|
32 |
-
return G.div(C * N)
|
33 |
-
|
34 |
-
|
35 |
-
class ColorTransferLoss(nn.Module):
|
36 |
-
"""Penalize the gram matrix difference between StyleGAN2's ToRGB outputs"""
|
37 |
-
def __init__(
|
38 |
-
self,
|
39 |
-
init_rgbs,
|
40 |
-
scale_rgb: bool = False
|
41 |
-
):
|
42 |
-
super().__init__()
|
43 |
-
|
44 |
-
with torch.no_grad():
|
45 |
-
init_feats = [x.detach() for x in init_rgbs]
|
46 |
-
self.stds = [stddev(flatten_CHW(rgb)) if scale_rgb else 1 for rgb in init_feats] # (B, 1, 1, 1) or scalar
|
47 |
-
self.grams = [gram_matrix(rgb / std) for rgb, std in zip(init_feats, self.stds)]
|
48 |
-
|
49 |
-
def forward(self, rgbs: List[torch.Tensor], level: int = None):
|
50 |
-
if level is None:
|
51 |
-
level = len(self.grams)
|
52 |
-
|
53 |
-
feats = rgbs
|
54 |
-
loss = 0
|
55 |
-
for i, (rgb, std) in enumerate(zip(feats[:level], self.stds[:level])):
|
56 |
-
G = gram_matrix(rgb / std)
|
57 |
-
loss = loss + smooth_l1_loss(G, self.grams[i])
|
58 |
-
|
59 |
-
return loss
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
losses/joint_loss.py
DELETED
@@ -1,167 +0,0 @@
|
|
1 |
-
from argparse import (
|
2 |
-
ArgumentParser,
|
3 |
-
Namespace,
|
4 |
-
)
|
5 |
-
from typing import (
|
6 |
-
Dict,
|
7 |
-
Iterable,
|
8 |
-
Optional,
|
9 |
-
Tuple,
|
10 |
-
)
|
11 |
-
|
12 |
-
import numpy as np
|
13 |
-
import torch
|
14 |
-
from torch import nn
|
15 |
-
|
16 |
-
from utils.misc import (
|
17 |
-
optional_string,
|
18 |
-
iterable_to_str,
|
19 |
-
)
|
20 |
-
|
21 |
-
from .contextual_loss import ContextualLoss
|
22 |
-
from .color_transfer_loss import ColorTransferLoss
|
23 |
-
from .regularize_noise import NoiseRegularizer
|
24 |
-
from .reconstruction import (
|
25 |
-
EyeLoss,
|
26 |
-
FaceLoss,
|
27 |
-
create_perceptual_loss,
|
28 |
-
ReconstructionArguments,
|
29 |
-
)
|
30 |
-
|
31 |
-
class LossArguments:
|
32 |
-
@staticmethod
|
33 |
-
def add_arguments(parser: ArgumentParser):
|
34 |
-
ReconstructionArguments.add_arguments(parser)
|
35 |
-
|
36 |
-
parser.add_argument("--color_transfer", type=float, default=1e10, help="color transfer loss weight")
|
37 |
-
parser.add_argument("--eye", type=float, default=0.1, help="eye loss weight")
|
38 |
-
parser.add_argument('--noise_regularize', type=float, default=5e4)
|
39 |
-
# contextual loss
|
40 |
-
parser.add_argument("--contextual", type=float, default=0.1, help="contextual loss weight")
|
41 |
-
parser.add_argument("--cx_layers", nargs='*', help="contextual loss layers",
|
42 |
-
choices=['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4'],
|
43 |
-
default=['relu3_4', 'relu2_2', 'relu1_2'])
|
44 |
-
|
45 |
-
@staticmethod
|
46 |
-
def to_string(args: Namespace) -> str:
|
47 |
-
return (
|
48 |
-
ReconstructionArguments.to_string(args)
|
49 |
-
+ optional_string(args.eye > 0, f"-eye{args.eye}")
|
50 |
-
+ optional_string(args.color_transfer, f"-color{args.color_transfer:.1e}")
|
51 |
-
+ optional_string(
|
52 |
-
args.contextual,
|
53 |
-
f"-cx{args.contextual}({iterable_to_str(args.cx_layers)})"
|
54 |
-
)
|
55 |
-
#+ optional_string(args.mse, f"-mse{args.mse}")
|
56 |
-
+ optional_string(args.noise_regularize, f"-NR{args.noise_regularize:.1e}")
|
57 |
-
)
|
58 |
-
|
59 |
-
|
60 |
-
class BakedMultiContextualLoss(nn.Module):
|
61 |
-
"""Random sample different image patches for different vgg layers."""
|
62 |
-
def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
|
63 |
-
super().__init__()
|
64 |
-
|
65 |
-
self.cxs = nn.ModuleList([ContextualLoss(use_vgg=True, vgg_layers=[layer])
|
66 |
-
for layer in args.cx_layers])
|
67 |
-
self.size = size
|
68 |
-
self.sibling = sibling.detach()
|
69 |
-
|
70 |
-
def forward(self, img: torch.Tensor):
|
71 |
-
cx_loss = 0
|
72 |
-
for cx in self.cxs:
|
73 |
-
h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
|
74 |
-
cx_loss = cx(self.sibling[..., h:h+self.size, w:w+self.size], img[..., h:h+self.size, w:w+self.size]) + cx_loss
|
75 |
-
return cx_loss
|
76 |
-
|
77 |
-
|
78 |
-
class BakedContextualLoss(ContextualLoss):
|
79 |
-
def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
|
80 |
-
super().__init__(use_vgg=True, vgg_layers=args.cx_layers)
|
81 |
-
self.size = size
|
82 |
-
self.sibling = sibling.detach()
|
83 |
-
|
84 |
-
def forward(self, img: torch.Tensor):
|
85 |
-
h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
|
86 |
-
return super().forward(self.sibling[..., h:h+self.size, w:w+self.size], img[..., h:h+self.size, w:w+self.size])
|
87 |
-
|
88 |
-
|
89 |
-
class JointLoss(nn.Module):
|
90 |
-
def __init__(
|
91 |
-
self,
|
92 |
-
args: Namespace,
|
93 |
-
target: torch.Tensor,
|
94 |
-
sibling: Optional[torch.Tensor],
|
95 |
-
sibling_rgbs: Optional[Iterable[torch.Tensor]] = None,
|
96 |
-
):
|
97 |
-
super().__init__()
|
98 |
-
|
99 |
-
self.weights = {
|
100 |
-
"face": 1., "eye": args.eye,
|
101 |
-
"contextual": args.contextual, "color_transfer": args.color_transfer,
|
102 |
-
"noise": args.noise_regularize,
|
103 |
-
}
|
104 |
-
|
105 |
-
reconstruction = {}
|
106 |
-
if args.vgg > 0 or args.vggface > 0:
|
107 |
-
percept = create_perceptual_loss(args)
|
108 |
-
reconstruction.update(
|
109 |
-
{"face": FaceLoss(target, input_size=args.generator_size, size=args.recon_size, percept=percept)}
|
110 |
-
)
|
111 |
-
if args.eye > 0:
|
112 |
-
reconstruction.update(
|
113 |
-
{"eye": EyeLoss(target, input_size=args.generator_size, percept=percept)}
|
114 |
-
)
|
115 |
-
self.reconstruction = nn.ModuleDict(reconstruction)
|
116 |
-
|
117 |
-
exemplar = {}
|
118 |
-
if args.contextual > 0 and len(args.cx_layers) > 0:
|
119 |
-
assert sibling is not None
|
120 |
-
exemplar.update(
|
121 |
-
{"contextual": BakedContextualLoss(sibling, args)}
|
122 |
-
)
|
123 |
-
if args.color_transfer > 0:
|
124 |
-
assert sibling_rgbs is not None
|
125 |
-
self.sibling_rgbs = sibling_rgbs
|
126 |
-
exemplar.update(
|
127 |
-
{"color_transfer": ColorTransferLoss(init_rgbs=sibling_rgbs)}
|
128 |
-
)
|
129 |
-
self.exemplar = nn.ModuleDict(exemplar)
|
130 |
-
|
131 |
-
if args.noise_regularize > 0:
|
132 |
-
self.noise_criterion = NoiseRegularizer()
|
133 |
-
|
134 |
-
def forward(
|
135 |
-
self, img, degrade=None, noises=None, rgbs=None, rgb_level: Optional[int] = None
|
136 |
-
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
137 |
-
"""
|
138 |
-
Args:
|
139 |
-
rgbs: results from the ToRGB layers
|
140 |
-
"""
|
141 |
-
# TODO: add current optimization resolution for noises
|
142 |
-
|
143 |
-
losses = {}
|
144 |
-
|
145 |
-
# reconstruction losses
|
146 |
-
for name, criterion in self.reconstruction.items():
|
147 |
-
losses[name] = criterion(img, degrade=degrade)
|
148 |
-
|
149 |
-
# exemplar losses
|
150 |
-
if 'contextual' in self.exemplar:
|
151 |
-
losses["contextual"] = self.exemplar["contextual"](img)
|
152 |
-
if "color_transfer" in self.exemplar:
|
153 |
-
assert rgbs is not None
|
154 |
-
losses["color_transfer"] = self.exemplar["color_transfer"](rgbs, level=rgb_level)
|
155 |
-
|
156 |
-
# noise regularizer
|
157 |
-
if self.weights["noise"] > 0:
|
158 |
-
losses["noise"] = self.noise_criterion(noises)
|
159 |
-
|
160 |
-
total_loss = 0
|
161 |
-
for name, loss in losses.items():
|
162 |
-
total_loss = total_loss + self.weights[name] * loss
|
163 |
-
return total_loss, losses
|
164 |
-
|
165 |
-
def update_sibling(self, sibling: torch.Tensor):
|
166 |
-
assert "contextual" in self.exemplar
|
167 |
-
self.exemplar["contextual"].sibling = sibling.detach()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
losses/perceptual_loss.py
DELETED
@@ -1,111 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Code borrowed from https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#file-vgg_perceptual_loss-py-L5
|
3 |
-
"""
|
4 |
-
import torch
|
5 |
-
import torchvision
|
6 |
-
from models.vggface import VGGFaceFeats
|
7 |
-
|
8 |
-
|
9 |
-
def cos_loss(fi, ft):
|
10 |
-
return 1 - torch.nn.functional.cosine_similarity(fi, ft).mean()
|
11 |
-
|
12 |
-
|
13 |
-
class VGGPerceptualLoss(torch.nn.Module):
|
14 |
-
def __init__(self, resize=False):
|
15 |
-
super(VGGPerceptualLoss, self).__init__()
|
16 |
-
blocks = []
|
17 |
-
blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
|
18 |
-
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
|
19 |
-
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
|
20 |
-
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
|
21 |
-
for bl in blocks:
|
22 |
-
for p in bl:
|
23 |
-
p.requires_grad = False
|
24 |
-
self.blocks = torch.nn.ModuleList(blocks)
|
25 |
-
self.transform = torch.nn.functional.interpolate
|
26 |
-
self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
|
27 |
-
self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
|
28 |
-
self.resize = resize
|
29 |
-
|
30 |
-
def forward(self, input, target, max_layer=4, cos_dist: bool = False):
|
31 |
-
target = (target + 1) * 0.5
|
32 |
-
input = (input + 1) * 0.5
|
33 |
-
|
34 |
-
if input.shape[1] != 3:
|
35 |
-
input = input.repeat(1, 3, 1, 1)
|
36 |
-
target = target.repeat(1, 3, 1, 1)
|
37 |
-
input = (input-self.mean) / self.std
|
38 |
-
target = (target-self.mean) / self.std
|
39 |
-
if self.resize:
|
40 |
-
input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
|
41 |
-
target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
|
42 |
-
x = input
|
43 |
-
y = target
|
44 |
-
loss = 0.0
|
45 |
-
loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
|
46 |
-
for bi, block in enumerate(self.blocks[:max_layer]):
|
47 |
-
x = block(x)
|
48 |
-
y = block(y)
|
49 |
-
loss += loss_func(x, y.detach())
|
50 |
-
return loss
|
51 |
-
|
52 |
-
|
53 |
-
class VGGFacePerceptualLoss(torch.nn.Module):
|
54 |
-
def __init__(self, weight_path: str = "checkpoint/vgg_face_dag.pt", resize: bool = False):
|
55 |
-
super().__init__()
|
56 |
-
self.vgg = VGGFaceFeats()
|
57 |
-
self.vgg.load_state_dict(torch.load(weight_path))
|
58 |
-
|
59 |
-
mean = torch.tensor(self.vgg.meta["mean"]).view(1, 3, 1, 1) / 255.0
|
60 |
-
self.register_buffer("mean", mean)
|
61 |
-
|
62 |
-
self.transform = torch.nn.functional.interpolate
|
63 |
-
self.resize = resize
|
64 |
-
|
65 |
-
def forward(self, input, target, max_layer: int = 4, cos_dist: bool = False):
|
66 |
-
target = (target + 1) * 0.5
|
67 |
-
input = (input + 1) * 0.5
|
68 |
-
|
69 |
-
# preprocessing
|
70 |
-
if input.shape[1] != 3:
|
71 |
-
input = input.repeat(1, 3, 1, 1)
|
72 |
-
target = target.repeat(1, 3, 1, 1)
|
73 |
-
input = input - self.mean
|
74 |
-
target = target - self.mean
|
75 |
-
if self.resize:
|
76 |
-
input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
|
77 |
-
target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
|
78 |
-
|
79 |
-
input_feats = self.vgg(input)
|
80 |
-
target_feats = self.vgg(target)
|
81 |
-
|
82 |
-
loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
|
83 |
-
# calc perceptual loss
|
84 |
-
loss = 0.0
|
85 |
-
for fi, ft in zip(input_feats[:max_layer], target_feats[:max_layer]):
|
86 |
-
loss = loss + loss_func(fi, ft.detach())
|
87 |
-
return loss
|
88 |
-
|
89 |
-
|
90 |
-
class PerceptualLoss(torch.nn.Module):
|
91 |
-
def __init__(
|
92 |
-
self, lambda_vggface: float = 0.025 / 0.15, lambda_vgg: float = 1, eps: float = 1e-8, cos_dist: bool = False
|
93 |
-
):
|
94 |
-
super().__init__()
|
95 |
-
self.register_buffer("lambda_vggface", torch.tensor(lambda_vggface))
|
96 |
-
self.register_buffer("lambda_vgg", torch.tensor(lambda_vgg))
|
97 |
-
self.cos_dist = cos_dist
|
98 |
-
|
99 |
-
if lambda_vgg > eps:
|
100 |
-
self.vgg = VGGPerceptualLoss()
|
101 |
-
if lambda_vggface > eps:
|
102 |
-
self.vggface = VGGFacePerceptualLoss()
|
103 |
-
|
104 |
-
def forward(self, input, target, eps=1e-8, use_vggface: bool = True, use_vgg=True, max_vgg_layer=4):
|
105 |
-
loss = 0.0
|
106 |
-
if self.lambda_vgg > eps and use_vgg:
|
107 |
-
loss = loss + self.lambda_vgg * self.vgg(input, target, max_layer=max_vgg_layer)
|
108 |
-
if self.lambda_vggface > eps and use_vggface:
|
109 |
-
loss = loss + self.lambda_vggface * self.vggface(input, target, cos_dist=self.cos_dist)
|
110 |
-
return loss
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
losses/reconstruction.py
DELETED
@@ -1,119 +0,0 @@
|
|
1 |
-
from argparse import (
|
2 |
-
ArgumentParser,
|
3 |
-
Namespace,
|
4 |
-
)
|
5 |
-
from typing import Optional
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
-
import torch
|
9 |
-
from torch import nn
|
10 |
-
|
11 |
-
from losses.perceptual_loss import PerceptualLoss
|
12 |
-
from models.degrade import Downsample
|
13 |
-
from utils.misc import optional_string
|
14 |
-
|
15 |
-
|
16 |
-
class ReconstructionArguments:
|
17 |
-
@staticmethod
|
18 |
-
def add_arguments(parser: ArgumentParser):
|
19 |
-
parser.add_argument("--vggface", type=float, default=0.3, help="vggface")
|
20 |
-
parser.add_argument("--vgg", type=float, default=1, help="vgg")
|
21 |
-
parser.add_argument('--recon_size', type=int, default=256, help="size for face reconstruction loss")
|
22 |
-
|
23 |
-
@staticmethod
|
24 |
-
def to_string(args: Namespace) -> str:
|
25 |
-
return (
|
26 |
-
f"s{args.recon_size}"
|
27 |
-
+ optional_string(args.vgg > 0, f"-vgg{args.vgg}")
|
28 |
-
+ optional_string(args.vggface > 0, f"-vggface{args.vggface}")
|
29 |
-
)
|
30 |
-
|
31 |
-
|
32 |
-
def create_perceptual_loss(args: Namespace):
|
33 |
-
return PerceptualLoss(lambda_vgg=args.vgg, lambda_vggface=args.vggface, cos_dist=False)
|
34 |
-
|
35 |
-
|
36 |
-
class EyeLoss(nn.Module):
|
37 |
-
def __init__(
|
38 |
-
self,
|
39 |
-
target: torch.Tensor,
|
40 |
-
input_size: int = 1024,
|
41 |
-
input_channels: int = 3,
|
42 |
-
percept: Optional[nn.Module] = None,
|
43 |
-
args: Optional[Namespace] = None
|
44 |
-
):
|
45 |
-
"""
|
46 |
-
target: target image
|
47 |
-
"""
|
48 |
-
assert not (percept is None and args is None)
|
49 |
-
|
50 |
-
super().__init__()
|
51 |
-
|
52 |
-
self.target = target
|
53 |
-
|
54 |
-
target_size = target.shape[-1]
|
55 |
-
self.downsample = Downsample(input_size, target_size, input_channels) \
|
56 |
-
if target_size != input_size else (lambda x: x)
|
57 |
-
|
58 |
-
self.percept = percept if percept is not None else create_perceptual_loss(args)
|
59 |
-
|
60 |
-
eye_size = np.array((224, 224))
|
61 |
-
btlrs = []
|
62 |
-
for sgn in [1, -1]:
|
63 |
-
center = np.array((480, 384 * sgn)) # (y, x)
|
64 |
-
b, t = center[0] - eye_size[0] // 2, center[0] + eye_size[0] // 2
|
65 |
-
l, r = center[1] - eye_size[1] // 2, center[1] + eye_size[1] // 2
|
66 |
-
btlrs.append((np.array((b, t, l, r)) / 1024 * target_size).astype(int))
|
67 |
-
self.btlrs = np.stack(btlrs, axis=0)
|
68 |
-
|
69 |
-
def forward(self, img: torch.Tensor, degrade: nn.Module = None):
|
70 |
-
"""
|
71 |
-
img: it should be the degraded version of the generated image
|
72 |
-
"""
|
73 |
-
if degrade is not None:
|
74 |
-
img = degrade(img, downsample=self.downsample)
|
75 |
-
|
76 |
-
loss = 0
|
77 |
-
for (b, t, l, r) in self.btlrs:
|
78 |
-
loss = loss + self.percept(
|
79 |
-
img[:, :, b:t, l:r], self.target[:, :, b:t, l:r],
|
80 |
-
use_vggface=False, max_vgg_layer=4,
|
81 |
-
# use_vgg=False,
|
82 |
-
)
|
83 |
-
return loss
|
84 |
-
|
85 |
-
|
86 |
-
class FaceLoss(nn.Module):
|
87 |
-
def __init__(
|
88 |
-
self,
|
89 |
-
target: torch.Tensor,
|
90 |
-
input_size: int = 1024,
|
91 |
-
input_channels: int = 3,
|
92 |
-
size: int = 256,
|
93 |
-
percept: Optional[nn.Module] = None,
|
94 |
-
args: Optional[Namespace] = None
|
95 |
-
):
|
96 |
-
"""
|
97 |
-
target: target image
|
98 |
-
"""
|
99 |
-
assert not (percept is None and args is None)
|
100 |
-
|
101 |
-
super().__init__()
|
102 |
-
|
103 |
-
target_size = target.shape[-1]
|
104 |
-
self.target = target if target_size == size \
|
105 |
-
else Downsample(target_size, size, target.shape[1]).to(target.device)(target)
|
106 |
-
|
107 |
-
self.downsample = Downsample(input_size, size, input_channels) \
|
108 |
-
if size != input_size else (lambda x: x)
|
109 |
-
|
110 |
-
self.percept = percept if percept is not None else create_perceptual_loss(args)
|
111 |
-
|
112 |
-
def forward(self, img: torch.Tensor, degrade: nn.Module = None):
|
113 |
-
"""
|
114 |
-
img: it should be the degraded version of the generated image
|
115 |
-
"""
|
116 |
-
if degrade is not None:
|
117 |
-
img = degrade(img, downsample=self.downsample)
|
118 |
-
loss = self.percept(img, self.target)
|
119 |
-
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
losses/regularize_noise.py
DELETED
@@ -1,37 +0,0 @@
|
|
1 |
-
from typing import Iterable
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
|
6 |
-
|
7 |
-
class NoiseRegularizer(nn.Module):
|
8 |
-
def forward(self, noises: Iterable[torch.Tensor]):
|
9 |
-
loss = 0
|
10 |
-
|
11 |
-
for noise in noises:
|
12 |
-
size = noise.shape[2]
|
13 |
-
|
14 |
-
while True:
|
15 |
-
loss = (
|
16 |
-
loss
|
17 |
-
+ (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
|
18 |
-
+ (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
|
19 |
-
)
|
20 |
-
|
21 |
-
if size <= 8:
|
22 |
-
break
|
23 |
-
|
24 |
-
noise = noise.reshape([1, 1, size // 2, 2, size // 2, 2])
|
25 |
-
noise = noise.mean([3, 5])
|
26 |
-
size //= 2
|
27 |
-
|
28 |
-
return loss
|
29 |
-
|
30 |
-
@staticmethod
|
31 |
-
def normalize(noises: Iterable[torch.Tensor]):
|
32 |
-
for noise in noises:
|
33 |
-
mean = noise.mean()
|
34 |
-
std = noise.std()
|
35 |
-
|
36 |
-
noise.data.add_(-mean).div_(std)
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/__init__.py
DELETED
File without changes
|
models/degrade.py
DELETED
@@ -1,122 +0,0 @@
|
|
1 |
-
from argparse import (
|
2 |
-
ArgumentParser,
|
3 |
-
Namespace,
|
4 |
-
)
|
5 |
-
|
6 |
-
import torch
|
7 |
-
from torch import nn
|
8 |
-
from torch.nn import functional as F
|
9 |
-
|
10 |
-
from utils.misc import optional_string
|
11 |
-
|
12 |
-
from .gaussian_smoothing import GaussianSmoothing
|
13 |
-
|
14 |
-
|
15 |
-
class DegradeArguments:
|
16 |
-
@staticmethod
|
17 |
-
def add_arguments(parser: ArgumentParser):
|
18 |
-
parser.add_argument('--spectral_sensitivity', choices=["g", "b", "gb"], default="g",
|
19 |
-
help="Type of spectral sensitivity. g: grayscale (panchromatic), b: blue-sensitive, gb: green+blue (orthochromatic)")
|
20 |
-
parser.add_argument('--gaussian', type=float, default=0,
|
21 |
-
help="estimated blur radius in pixels of the input photo if it is scaled to 1024x1024")
|
22 |
-
|
23 |
-
@staticmethod
|
24 |
-
def to_string(args: Namespace) -> str:
|
25 |
-
return (
|
26 |
-
f"{args.spectral_sensitivity}"
|
27 |
-
+ optional_string(args.gaussian > 0, f"-G{args.gaussian}")
|
28 |
-
)
|
29 |
-
|
30 |
-
|
31 |
-
class CameraResponse(nn.Module):
|
32 |
-
def __init__(self):
|
33 |
-
super().__init__()
|
34 |
-
|
35 |
-
self.register_parameter("gamma", nn.Parameter(torch.ones(1)))
|
36 |
-
self.register_parameter("offset", nn.Parameter(torch.zeros(1)))
|
37 |
-
self.register_parameter("gain", nn.Parameter(torch.ones(1)))
|
38 |
-
|
39 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
40 |
-
x = torch.clamp(x, max=1, min=-1+1e-2)
|
41 |
-
x = (1 + x) * 0.5
|
42 |
-
x = self.offset + self.gain * torch.pow(x, self.gamma)
|
43 |
-
x = (x - 0.5) * 2
|
44 |
-
# b = torch.clamp(b, max=1, min=-1)
|
45 |
-
return x
|
46 |
-
|
47 |
-
|
48 |
-
class SpectralResponse(nn.Module):
|
49 |
-
# TODO: use enum instead for color mode
|
50 |
-
def __init__(self, spectral_sensitivity: str = 'b'):
|
51 |
-
assert spectral_sensitivity in ("g", "b", "gb"), f"spectral_sensitivity {spectral_sensitivity} is not implemented."
|
52 |
-
|
53 |
-
super().__init__()
|
54 |
-
|
55 |
-
self.spectral_sensitivity = spectral_sensitivity
|
56 |
-
|
57 |
-
if self.spectral_sensitivity == "g":
|
58 |
-
self.register_buffer("to_gray", torch.tensor([0.299, 0.587, 0.114]).reshape(1, -1, 1, 1))
|
59 |
-
|
60 |
-
def forward(self, rgb: torch.Tensor) -> torch.Tensor:
|
61 |
-
if self.spectral_sensitivity == "b":
|
62 |
-
x = rgb[:, -1:]
|
63 |
-
elif self.spectral_sensitivity == "gb":
|
64 |
-
x = (rgb[:, 1:2] + rgb[:, -1:]) * 0.5
|
65 |
-
else:
|
66 |
-
assert self.spectral_sensitivity == "g"
|
67 |
-
x = (rgb * self.to_gray).sum(dim=1, keepdim=True)
|
68 |
-
return x
|
69 |
-
|
70 |
-
|
71 |
-
class Downsample(nn.Module):
|
72 |
-
"""Antialiasing downsampling"""
|
73 |
-
def __init__(self, input_size: int, output_size: int, channels: int):
|
74 |
-
super().__init__()
|
75 |
-
if input_size % output_size == 0:
|
76 |
-
self.stride = input_size // output_size
|
77 |
-
self.grid = None
|
78 |
-
else:
|
79 |
-
self.stride = 1
|
80 |
-
step = input_size / output_size
|
81 |
-
x = torch.arange(output_size) * step
|
82 |
-
Y, X = torch.meshgrid(x, x)
|
83 |
-
grid = torch.stack((X, Y), dim=-1)
|
84 |
-
grid /= torch.Tensor((input_size - 1, input_size - 1)).view(1, 1, -1)
|
85 |
-
grid = grid * 2 - 1
|
86 |
-
self.register_buffer("grid", grid)
|
87 |
-
sigma = 0.5 * input_size / output_size
|
88 |
-
#print(f"{input_size} -> {output_size}: sigma={sigma}")
|
89 |
-
self.blur = GaussianSmoothing(channels, int(2 * (sigma * 2) + 1 + 0.5), sigma)
|
90 |
-
|
91 |
-
def forward(self, im: torch.Tensor):
|
92 |
-
out = self.blur(im, stride=self.stride)
|
93 |
-
if self.grid is not None:
|
94 |
-
out = F.grid_sample(out, self.grid[None].expand(im.shape[0], -1, -1, -1))
|
95 |
-
return out
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
class Degrade(nn.Module):
|
100 |
-
"""
|
101 |
-
Simulate the degradation of antique film
|
102 |
-
"""
|
103 |
-
def __init__(self, args:Namespace):
|
104 |
-
super().__init__()
|
105 |
-
self.srf = SpectralResponse(args.spectral_sensitivity)
|
106 |
-
self.crf = CameraResponse()
|
107 |
-
self.gaussian = None
|
108 |
-
if args.gaussian is not None and args.gaussian > 0:
|
109 |
-
self.gaussian = GaussianSmoothing(3, 2 * int(args.gaussian * 2 + 0.5) + 1, args.gaussian)
|
110 |
-
|
111 |
-
def forward(self, img: torch.Tensor, downsample: nn.Module = None):
|
112 |
-
if self.gaussian is not None:
|
113 |
-
img = self.gaussian(img)
|
114 |
-
if downsample is not None:
|
115 |
-
img = downsample(img)
|
116 |
-
img = self.srf(img)
|
117 |
-
img = self.crf(img)
|
118 |
-
# Note that I changed it back to 3 channels
|
119 |
-
return img.repeat((1, 3, 1, 1)) if img.shape[1] == 1 else img
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/encoder.py
DELETED
@@ -1,66 +0,0 @@
|
|
1 |
-
from argparse import Namespace, ArgumentParser
|
2 |
-
from functools import partial
|
3 |
-
|
4 |
-
from torch import nn
|
5 |
-
|
6 |
-
from .resnet import ResNetBasicBlock, activation_func, norm_module, Conv2dAuto
|
7 |
-
|
8 |
-
|
9 |
-
def add_arguments(parser: ArgumentParser) -> ArgumentParser:
|
10 |
-
parser.add_argument("--latent_size", type=int, default=512, help="latent size")
|
11 |
-
return parser
|
12 |
-
|
13 |
-
|
14 |
-
def create_model(args) -> nn.Module:
|
15 |
-
in_channels = 3 if "rgb" in args and args.rgb else 1
|
16 |
-
return Encoder(in_channels, args.encoder_size, latent_size=args.latent_size)
|
17 |
-
|
18 |
-
|
19 |
-
class Flatten(nn.Module):
|
20 |
-
def forward(self, input_):
|
21 |
-
return input_.view(input_.size(0), -1)
|
22 |
-
|
23 |
-
|
24 |
-
class Encoder(nn.Module):
|
25 |
-
def __init__(
|
26 |
-
self, in_channels: int, size: int, latent_size: int = 512,
|
27 |
-
activation: str = 'leaky_relu', norm: str = "instance"
|
28 |
-
):
|
29 |
-
super().__init__()
|
30 |
-
|
31 |
-
out_channels0 = 64
|
32 |
-
norm_m = norm_module(norm)
|
33 |
-
self.conv0 = nn.Sequential(
|
34 |
-
Conv2dAuto(in_channels, out_channels0, kernel_size=5),
|
35 |
-
norm_m(out_channels0),
|
36 |
-
activation_func(activation),
|
37 |
-
)
|
38 |
-
|
39 |
-
pool_kernel = 2
|
40 |
-
self.pool = nn.AvgPool2d(pool_kernel)
|
41 |
-
|
42 |
-
num_channels = [128, 256, 512, 512]
|
43 |
-
# FIXME: this is a hack
|
44 |
-
if size >= 256:
|
45 |
-
num_channels.append(512)
|
46 |
-
|
47 |
-
residual = partial(ResNetBasicBlock, activation=activation, norm=norm, bias=True)
|
48 |
-
residual_blocks = nn.ModuleList()
|
49 |
-
for in_channel, out_channel in zip([out_channels0] + num_channels[:-1], num_channels):
|
50 |
-
residual_blocks.append(residual(in_channel, out_channel))
|
51 |
-
residual_blocks.append(nn.AvgPool2d(pool_kernel))
|
52 |
-
self.residual_blocks = nn.Sequential(*residual_blocks)
|
53 |
-
|
54 |
-
self.last = nn.Sequential(
|
55 |
-
nn.ReLU(),
|
56 |
-
nn.AvgPool2d(4), # TODO: not sure whehter this would cause problem
|
57 |
-
Flatten(),
|
58 |
-
nn.Linear(num_channels[-1], latent_size, bias=True)
|
59 |
-
)
|
60 |
-
|
61 |
-
def forward(self, input_):
|
62 |
-
out = self.conv0(input_)
|
63 |
-
out = self.pool(out)
|
64 |
-
out = self.residual_blocks(out)
|
65 |
-
out = self.last(out)
|
66 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/gaussian_smoothing.py
DELETED
@@ -1,74 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
import numbers
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
from torch.nn import functional as F
|
6 |
-
|
7 |
-
|
8 |
-
class GaussianSmoothing(nn.Module):
|
9 |
-
"""
|
10 |
-
Apply gaussian smoothing on a
|
11 |
-
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
12 |
-
in the input using a depthwise convolution.
|
13 |
-
Arguments:
|
14 |
-
channels (int, sequence): Number of channels of the input tensors. Output will
|
15 |
-
have this number of channels as well.
|
16 |
-
kernel_size (int, sequence): Size of the gaussian kernel.
|
17 |
-
sigma (float, sequence): Standard deviation of the gaussian kernel.
|
18 |
-
dim (int, optional): The number of dimensions of the data.
|
19 |
-
Default value is 2 (spatial).
|
20 |
-
"""
|
21 |
-
def __init__(self, channels, kernel_size, sigma, dim=2):
|
22 |
-
super(GaussianSmoothing, self).__init__()
|
23 |
-
if isinstance(kernel_size, numbers.Number):
|
24 |
-
kernel_size = [kernel_size] * dim
|
25 |
-
if isinstance(sigma, numbers.Number):
|
26 |
-
sigma = [sigma] * dim
|
27 |
-
|
28 |
-
# The gaussian kernel is the product of the
|
29 |
-
# gaussian function of each dimension.
|
30 |
-
kernel = 1
|
31 |
-
meshgrids = torch.meshgrid(
|
32 |
-
[
|
33 |
-
torch.arange(size, dtype=torch.float32)
|
34 |
-
for size in kernel_size
|
35 |
-
]
|
36 |
-
)
|
37 |
-
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
38 |
-
mean = (size - 1) / 2
|
39 |
-
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
|
40 |
-
torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
|
41 |
-
|
42 |
-
# Make sure sum of values in gaussian kernel equals 1.
|
43 |
-
kernel = kernel / torch.sum(kernel)
|
44 |
-
|
45 |
-
# Reshape to depthwise convolutional weight
|
46 |
-
kernel = kernel.view(1, 1, *kernel.size())
|
47 |
-
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
48 |
-
|
49 |
-
self.register_buffer('weight', kernel)
|
50 |
-
self.groups = channels
|
51 |
-
|
52 |
-
if dim == 1:
|
53 |
-
self.conv = F.conv1d
|
54 |
-
elif dim == 2:
|
55 |
-
self.conv = F.conv2d
|
56 |
-
elif dim == 3:
|
57 |
-
self.conv = F.conv3d
|
58 |
-
else:
|
59 |
-
raise RuntimeError(
|
60 |
-
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
|
61 |
-
)
|
62 |
-
|
63 |
-
def forward(self, input, stride: int = 1):
|
64 |
-
"""
|
65 |
-
Apply gaussian filter to input.
|
66 |
-
Arguments:
|
67 |
-
input (torch.Tensor): Input to apply gaussian filter on.
|
68 |
-
stride for applying conv
|
69 |
-
Returns:
|
70 |
-
filtered (torch.Tensor): Filtered output.
|
71 |
-
"""
|
72 |
-
padding = (self.weight.shape[-1] - 1) // 2
|
73 |
-
return self.conv(input, weight=self.weight, groups=self.groups, padding=padding, stride=stride)
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/resnet.py
DELETED
@@ -1,99 +0,0 @@
|
|
1 |
-
from functools import partial
|
2 |
-
|
3 |
-
from torch import nn
|
4 |
-
|
5 |
-
|
6 |
-
def activation_func(activation: str):
|
7 |
-
return nn.ModuleDict([
|
8 |
-
['relu', nn.ReLU(inplace=True)],
|
9 |
-
['leaky_relu', nn.LeakyReLU(negative_slope=0.01, inplace=True)],
|
10 |
-
['selu', nn.SELU(inplace=True)],
|
11 |
-
['none', nn.Identity()]
|
12 |
-
])[activation]
|
13 |
-
|
14 |
-
|
15 |
-
def norm_module(norm: str):
|
16 |
-
return {
|
17 |
-
'batch': nn.BatchNorm2d,
|
18 |
-
'instance': nn.InstanceNorm2d,
|
19 |
-
}[norm]
|
20 |
-
|
21 |
-
|
22 |
-
class Conv2dAuto(nn.Conv2d):
|
23 |
-
def __init__(self, *args, **kwargs):
|
24 |
-
super().__init__(*args, **kwargs)
|
25 |
-
# dynamic add padding based on the kernel_size
|
26 |
-
self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
|
27 |
-
|
28 |
-
|
29 |
-
conv3x3 = partial(Conv2dAuto, kernel_size=3)
|
30 |
-
|
31 |
-
|
32 |
-
class ResidualBlock(nn.Module):
|
33 |
-
def __init__(self, in_channels: int, out_channels: int, activation: str = 'relu'):
|
34 |
-
super().__init__()
|
35 |
-
self.in_channels, self.out_channels = in_channels, out_channels
|
36 |
-
self.blocks = nn.Identity()
|
37 |
-
self.activate = activation_func(activation)
|
38 |
-
self.shortcut = nn.Identity()
|
39 |
-
|
40 |
-
def forward(self, x):
|
41 |
-
residual = x
|
42 |
-
if self.should_apply_shortcut:
|
43 |
-
residual = self.shortcut(x)
|
44 |
-
x = self.blocks(x)
|
45 |
-
x += residual
|
46 |
-
x = self.activate(x)
|
47 |
-
return x
|
48 |
-
|
49 |
-
@property
|
50 |
-
def should_apply_shortcut(self):
|
51 |
-
return self.in_channels != self.out_channels
|
52 |
-
|
53 |
-
|
54 |
-
class ResNetResidualBlock(ResidualBlock):
|
55 |
-
def __init__(
|
56 |
-
self, in_channels: int, out_channels: int,
|
57 |
-
expansion: int = 1, downsampling: int = 1,
|
58 |
-
conv=conv3x3, norm: str = 'batch', *args, **kwargs
|
59 |
-
):
|
60 |
-
super().__init__(in_channels, out_channels, *args, **kwargs)
|
61 |
-
self.expansion, self.downsampling = expansion, downsampling
|
62 |
-
self.conv, self.norm = conv, norm_module(norm)
|
63 |
-
self.shortcut = nn.Sequential(
|
64 |
-
nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1,
|
65 |
-
stride=self.downsampling, bias=False),
|
66 |
-
self.norm(self.expanded_channels)) if self.should_apply_shortcut else None
|
67 |
-
|
68 |
-
@property
|
69 |
-
def expanded_channels(self):
|
70 |
-
return self.out_channels * self.expansion
|
71 |
-
|
72 |
-
@property
|
73 |
-
def should_apply_shortcut(self):
|
74 |
-
return self.in_channels != self.expanded_channels
|
75 |
-
|
76 |
-
|
77 |
-
def conv_norm(in_channels: int, out_channels: int, conv, norm, *args, **kwargs):
|
78 |
-
return nn.Sequential(conv(in_channels, out_channels, *args, **kwargs), norm(out_channels))
|
79 |
-
|
80 |
-
|
81 |
-
class ResNetBasicBlock(ResNetResidualBlock):
|
82 |
-
"""
|
83 |
-
Basic ResNet block composed by two layers of 3x3conv/batchnorm/activation
|
84 |
-
"""
|
85 |
-
expansion = 1
|
86 |
-
|
87 |
-
def __init__(
|
88 |
-
self, in_channels: int, out_channels: int, bias: bool = False, *args, **kwargs
|
89 |
-
):
|
90 |
-
super().__init__(in_channels, out_channels, *args, **kwargs)
|
91 |
-
self.blocks = nn.Sequential(
|
92 |
-
conv_norm(
|
93 |
-
self.in_channels, self.out_channels, conv=self.conv, norm=self.norm,
|
94 |
-
bias=bias, stride=self.downsampling
|
95 |
-
),
|
96 |
-
self.activate,
|
97 |
-
conv_norm(self.out_channels, self.expanded_channels, conv=self.conv, norm=self.norm, bias=bias),
|
98 |
-
)
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/vggface.py
DELETED
@@ -1,150 +0,0 @@
|
|
1 |
-
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
|
5 |
-
|
6 |
-
class Vgg_face_dag(nn.Module):
|
7 |
-
|
8 |
-
def __init__(self):
|
9 |
-
super(Vgg_face_dag, self).__init__()
|
10 |
-
self.meta = {'mean': [129.186279296875, 104.76238250732422, 93.59396362304688],
|
11 |
-
'std': [1, 1, 1],
|
12 |
-
'imageSize': [224, 224, 3]}
|
13 |
-
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
14 |
-
self.relu1_1 = nn.ReLU(inplace=True)
|
15 |
-
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
16 |
-
self.relu1_2 = nn.ReLU(inplace=True)
|
17 |
-
self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
|
18 |
-
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
19 |
-
self.relu2_1 = nn.ReLU(inplace=True)
|
20 |
-
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
21 |
-
self.relu2_2 = nn.ReLU(inplace=True)
|
22 |
-
self.pool2 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
|
23 |
-
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
24 |
-
self.relu3_1 = nn.ReLU(inplace=True)
|
25 |
-
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
26 |
-
self.relu3_2 = nn.ReLU(inplace=True)
|
27 |
-
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
28 |
-
self.relu3_3 = nn.ReLU(inplace=True)
|
29 |
-
self.pool3 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
|
30 |
-
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
31 |
-
self.relu4_1 = nn.ReLU(inplace=True)
|
32 |
-
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
33 |
-
self.relu4_2 = nn.ReLU(inplace=True)
|
34 |
-
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
35 |
-
self.relu4_3 = nn.ReLU(inplace=True)
|
36 |
-
self.pool4 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
|
37 |
-
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
38 |
-
self.relu5_1 = nn.ReLU(inplace=True)
|
39 |
-
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
40 |
-
self.relu5_2 = nn.ReLU(inplace=True)
|
41 |
-
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
|
42 |
-
self.relu5_3 = nn.ReLU(inplace=True)
|
43 |
-
self.pool5 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
|
44 |
-
self.fc6 = nn.Linear(in_features=25088, out_features=4096, bias=True)
|
45 |
-
self.relu6 = nn.ReLU(inplace=True)
|
46 |
-
self.dropout6 = nn.Dropout(p=0.5)
|
47 |
-
self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True)
|
48 |
-
self.relu7 = nn.ReLU(inplace=True)
|
49 |
-
self.dropout7 = nn.Dropout(p=0.5)
|
50 |
-
self.fc8 = nn.Linear(in_features=4096, out_features=2622, bias=True)
|
51 |
-
|
52 |
-
def forward(self, x0):
|
53 |
-
x1 = self.conv1_1(x0)
|
54 |
-
x2 = self.relu1_1(x1)
|
55 |
-
x3 = self.conv1_2(x2)
|
56 |
-
x4 = self.relu1_2(x3)
|
57 |
-
x5 = self.pool1(x4)
|
58 |
-
x6 = self.conv2_1(x5)
|
59 |
-
x7 = self.relu2_1(x6)
|
60 |
-
x8 = self.conv2_2(x7)
|
61 |
-
x9 = self.relu2_2(x8)
|
62 |
-
x10 = self.pool2(x9)
|
63 |
-
x11 = self.conv3_1(x10)
|
64 |
-
x12 = self.relu3_1(x11)
|
65 |
-
x13 = self.conv3_2(x12)
|
66 |
-
x14 = self.relu3_2(x13)
|
67 |
-
x15 = self.conv3_3(x14)
|
68 |
-
x16 = self.relu3_3(x15)
|
69 |
-
x17 = self.pool3(x16)
|
70 |
-
x18 = self.conv4_1(x17)
|
71 |
-
x19 = self.relu4_1(x18)
|
72 |
-
x20 = self.conv4_2(x19)
|
73 |
-
x21 = self.relu4_2(x20)
|
74 |
-
x22 = self.conv4_3(x21)
|
75 |
-
x23 = self.relu4_3(x22)
|
76 |
-
x24 = self.pool4(x23)
|
77 |
-
x25 = self.conv5_1(x24)
|
78 |
-
x26 = self.relu5_1(x25)
|
79 |
-
x27 = self.conv5_2(x26)
|
80 |
-
x28 = self.relu5_2(x27)
|
81 |
-
x29 = self.conv5_3(x28)
|
82 |
-
x30 = self.relu5_3(x29)
|
83 |
-
x31_preflatten = self.pool5(x30)
|
84 |
-
x31 = x31_preflatten.view(x31_preflatten.size(0), -1)
|
85 |
-
x32 = self.fc6(x31)
|
86 |
-
x33 = self.relu6(x32)
|
87 |
-
x34 = self.dropout6(x33)
|
88 |
-
x35 = self.fc7(x34)
|
89 |
-
x36 = self.relu7(x35)
|
90 |
-
x37 = self.dropout7(x36)
|
91 |
-
x38 = self.fc8(x37)
|
92 |
-
return x38
|
93 |
-
|
94 |
-
|
95 |
-
def vgg_face_dag(weights_path=None, **kwargs):
|
96 |
-
"""
|
97 |
-
load imported model instance
|
98 |
-
|
99 |
-
Args:
|
100 |
-
weights_path (str): If set, loads model weights from the given path
|
101 |
-
"""
|
102 |
-
model = Vgg_face_dag()
|
103 |
-
if weights_path:
|
104 |
-
state_dict = torch.load(weights_path)
|
105 |
-
model.load_state_dict(state_dict)
|
106 |
-
return model
|
107 |
-
|
108 |
-
|
109 |
-
class VGGFaceFeats(Vgg_face_dag):
|
110 |
-
def forward(self, x0):
|
111 |
-
x1 = self.conv1_1(x0)
|
112 |
-
x2 = self.relu1_1(x1)
|
113 |
-
x3 = self.conv1_2(x2)
|
114 |
-
x4 = self.relu1_2(x3)
|
115 |
-
x5 = self.pool1(x4)
|
116 |
-
x6 = self.conv2_1(x5)
|
117 |
-
x7 = self.relu2_1(x6)
|
118 |
-
x8 = self.conv2_2(x7)
|
119 |
-
x9 = self.relu2_2(x8)
|
120 |
-
x10 = self.pool2(x9)
|
121 |
-
x11 = self.conv3_1(x10)
|
122 |
-
x12 = self.relu3_1(x11)
|
123 |
-
x13 = self.conv3_2(x12)
|
124 |
-
x14 = self.relu3_2(x13)
|
125 |
-
x15 = self.conv3_3(x14)
|
126 |
-
x16 = self.relu3_3(x15)
|
127 |
-
x17 = self.pool3(x16)
|
128 |
-
x18 = self.conv4_1(x17)
|
129 |
-
x19 = self.relu4_1(x18)
|
130 |
-
x20 = self.conv4_2(x19)
|
131 |
-
x21 = self.relu4_2(x20)
|
132 |
-
x22 = self.conv4_3(x21)
|
133 |
-
x23 = self.relu4_3(x22)
|
134 |
-
x24 = self.pool4(x23)
|
135 |
-
x25 = self.conv5_1(x24)
|
136 |
-
# x26 = self.relu5_1(x25)
|
137 |
-
# x27 = self.conv5_2(x26)
|
138 |
-
# x28 = self.relu5_2(x27)
|
139 |
-
# x29 = self.conv5_3(x28)
|
140 |
-
# x30 = self.relu5_3(x29)
|
141 |
-
# x31_preflatten = self.pool5(x30)
|
142 |
-
# x31 = x31_preflatten.view(x31_preflatten.size(0), -1)
|
143 |
-
# x32 = self.fc6(x31)
|
144 |
-
# x33 = self.relu6(x32)
|
145 |
-
# x34 = self.dropout6(x33)
|
146 |
-
# x35 = self.fc7(x34)
|
147 |
-
# x36 = self.relu7(x35)
|
148 |
-
# x37 = self.dropout7(x36)
|
149 |
-
# x38 = self.fc8(x37)
|
150 |
-
return x1, x6, x11, x18, x25
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
op/upfirdn2d_kernel.cu
DELETED
@@ -1,272 +0,0 @@
|
|
1 |
-
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
-
//
|
3 |
-
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
-
// To view a copy of this license, visit
|
5 |
-
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
-
|
7 |
-
#include <torch/types.h>
|
8 |
-
|
9 |
-
#include <ATen/ATen.h>
|
10 |
-
#include <ATen/AccumulateType.h>
|
11 |
-
#include <ATen/cuda/CUDAContext.h>
|
12 |
-
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
13 |
-
|
14 |
-
#include <cuda.h>
|
15 |
-
#include <cuda_runtime.h>
|
16 |
-
|
17 |
-
|
18 |
-
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
|
19 |
-
int c = a / b;
|
20 |
-
|
21 |
-
if (c * b > a) {
|
22 |
-
c--;
|
23 |
-
}
|
24 |
-
|
25 |
-
return c;
|
26 |
-
}
|
27 |
-
|
28 |
-
|
29 |
-
struct UpFirDn2DKernelParams {
|
30 |
-
int up_x;
|
31 |
-
int up_y;
|
32 |
-
int down_x;
|
33 |
-
int down_y;
|
34 |
-
int pad_x0;
|
35 |
-
int pad_x1;
|
36 |
-
int pad_y0;
|
37 |
-
int pad_y1;
|
38 |
-
|
39 |
-
int major_dim;
|
40 |
-
int in_h;
|
41 |
-
int in_w;
|
42 |
-
int minor_dim;
|
43 |
-
int kernel_h;
|
44 |
-
int kernel_w;
|
45 |
-
int out_h;
|
46 |
-
int out_w;
|
47 |
-
int loop_major;
|
48 |
-
int loop_x;
|
49 |
-
};
|
50 |
-
|
51 |
-
|
52 |
-
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
|
53 |
-
__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
|
54 |
-
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
|
55 |
-
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
|
56 |
-
|
57 |
-
__shared__ volatile float sk[kernel_h][kernel_w];
|
58 |
-
__shared__ volatile float sx[tile_in_h][tile_in_w];
|
59 |
-
|
60 |
-
int minor_idx = blockIdx.x;
|
61 |
-
int tile_out_y = minor_idx / p.minor_dim;
|
62 |
-
minor_idx -= tile_out_y * p.minor_dim;
|
63 |
-
tile_out_y *= tile_out_h;
|
64 |
-
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
|
65 |
-
int major_idx_base = blockIdx.z * p.loop_major;
|
66 |
-
|
67 |
-
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
|
68 |
-
return;
|
69 |
-
}
|
70 |
-
|
71 |
-
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
|
72 |
-
int ky = tap_idx / kernel_w;
|
73 |
-
int kx = tap_idx - ky * kernel_w;
|
74 |
-
scalar_t v = 0.0;
|
75 |
-
|
76 |
-
if (kx < p.kernel_w & ky < p.kernel_h) {
|
77 |
-
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
|
78 |
-
}
|
79 |
-
|
80 |
-
sk[ky][kx] = v;
|
81 |
-
}
|
82 |
-
|
83 |
-
for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
|
84 |
-
for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
|
85 |
-
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
|
86 |
-
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
|
87 |
-
int tile_in_x = floor_div(tile_mid_x, up_x);
|
88 |
-
int tile_in_y = floor_div(tile_mid_y, up_y);
|
89 |
-
|
90 |
-
__syncthreads();
|
91 |
-
|
92 |
-
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
|
93 |
-
int rel_in_y = in_idx / tile_in_w;
|
94 |
-
int rel_in_x = in_idx - rel_in_y * tile_in_w;
|
95 |
-
int in_x = rel_in_x + tile_in_x;
|
96 |
-
int in_y = rel_in_y + tile_in_y;
|
97 |
-
|
98 |
-
scalar_t v = 0.0;
|
99 |
-
|
100 |
-
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
|
101 |
-
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
|
102 |
-
}
|
103 |
-
|
104 |
-
sx[rel_in_y][rel_in_x] = v;
|
105 |
-
}
|
106 |
-
|
107 |
-
__syncthreads();
|
108 |
-
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
|
109 |
-
int rel_out_y = out_idx / tile_out_w;
|
110 |
-
int rel_out_x = out_idx - rel_out_y * tile_out_w;
|
111 |
-
int out_x = rel_out_x + tile_out_x;
|
112 |
-
int out_y = rel_out_y + tile_out_y;
|
113 |
-
|
114 |
-
int mid_x = tile_mid_x + rel_out_x * down_x;
|
115 |
-
int mid_y = tile_mid_y + rel_out_y * down_y;
|
116 |
-
int in_x = floor_div(mid_x, up_x);
|
117 |
-
int in_y = floor_div(mid_y, up_y);
|
118 |
-
int rel_in_x = in_x - tile_in_x;
|
119 |
-
int rel_in_y = in_y - tile_in_y;
|
120 |
-
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
|
121 |
-
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
|
122 |
-
|
123 |
-
scalar_t v = 0.0;
|
124 |
-
|
125 |
-
#pragma unroll
|
126 |
-
for (int y = 0; y < kernel_h / up_y; y++)
|
127 |
-
#pragma unroll
|
128 |
-
for (int x = 0; x < kernel_w / up_x; x++)
|
129 |
-
v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
|
130 |
-
|
131 |
-
if (out_x < p.out_w & out_y < p.out_h) {
|
132 |
-
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
|
133 |
-
}
|
134 |
-
}
|
135 |
-
}
|
136 |
-
}
|
137 |
-
}
|
138 |
-
|
139 |
-
|
140 |
-
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
|
141 |
-
int up_x, int up_y, int down_x, int down_y,
|
142 |
-
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
|
143 |
-
int curDevice = -1;
|
144 |
-
cudaGetDevice(&curDevice);
|
145 |
-
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
146 |
-
|
147 |
-
UpFirDn2DKernelParams p;
|
148 |
-
|
149 |
-
auto x = input.contiguous();
|
150 |
-
auto k = kernel.contiguous();
|
151 |
-
|
152 |
-
p.major_dim = x.size(0);
|
153 |
-
p.in_h = x.size(1);
|
154 |
-
p.in_w = x.size(2);
|
155 |
-
p.minor_dim = x.size(3);
|
156 |
-
p.kernel_h = k.size(0);
|
157 |
-
p.kernel_w = k.size(1);
|
158 |
-
p.up_x = up_x;
|
159 |
-
p.up_y = up_y;
|
160 |
-
p.down_x = down_x;
|
161 |
-
p.down_y = down_y;
|
162 |
-
p.pad_x0 = pad_x0;
|
163 |
-
p.pad_x1 = pad_x1;
|
164 |
-
p.pad_y0 = pad_y0;
|
165 |
-
p.pad_y1 = pad_y1;
|
166 |
-
|
167 |
-
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
|
168 |
-
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
|
169 |
-
|
170 |
-
auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
|
171 |
-
|
172 |
-
int mode = -1;
|
173 |
-
|
174 |
-
int tile_out_h;
|
175 |
-
int tile_out_w;
|
176 |
-
|
177 |
-
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
|
178 |
-
mode = 1;
|
179 |
-
tile_out_h = 16;
|
180 |
-
tile_out_w = 64;
|
181 |
-
}
|
182 |
-
|
183 |
-
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
|
184 |
-
mode = 2;
|
185 |
-
tile_out_h = 16;
|
186 |
-
tile_out_w = 64;
|
187 |
-
}
|
188 |
-
|
189 |
-
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
|
190 |
-
mode = 3;
|
191 |
-
tile_out_h = 16;
|
192 |
-
tile_out_w = 64;
|
193 |
-
}
|
194 |
-
|
195 |
-
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
|
196 |
-
mode = 4;
|
197 |
-
tile_out_h = 16;
|
198 |
-
tile_out_w = 64;
|
199 |
-
}
|
200 |
-
|
201 |
-
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
|
202 |
-
mode = 5;
|
203 |
-
tile_out_h = 8;
|
204 |
-
tile_out_w = 32;
|
205 |
-
}
|
206 |
-
|
207 |
-
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
|
208 |
-
mode = 6;
|
209 |
-
tile_out_h = 8;
|
210 |
-
tile_out_w = 32;
|
211 |
-
}
|
212 |
-
|
213 |
-
dim3 block_size;
|
214 |
-
dim3 grid_size;
|
215 |
-
|
216 |
-
if (tile_out_h > 0 && tile_out_w) {
|
217 |
-
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
218 |
-
p.loop_x = 1;
|
219 |
-
block_size = dim3(32 * 8, 1, 1);
|
220 |
-
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
|
221 |
-
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
|
222 |
-
(p.major_dim - 1) / p.loop_major + 1);
|
223 |
-
}
|
224 |
-
|
225 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
|
226 |
-
switch (mode) {
|
227 |
-
case 1:
|
228 |
-
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
229 |
-
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
230 |
-
);
|
231 |
-
|
232 |
-
break;
|
233 |
-
|
234 |
-
case 2:
|
235 |
-
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
236 |
-
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
237 |
-
);
|
238 |
-
|
239 |
-
break;
|
240 |
-
|
241 |
-
case 3:
|
242 |
-
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
243 |
-
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
244 |
-
);
|
245 |
-
|
246 |
-
break;
|
247 |
-
|
248 |
-
case 4:
|
249 |
-
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
250 |
-
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
251 |
-
);
|
252 |
-
|
253 |
-
break;
|
254 |
-
|
255 |
-
case 5:
|
256 |
-
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
|
257 |
-
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
258 |
-
);
|
259 |
-
|
260 |
-
break;
|
261 |
-
|
262 |
-
case 6:
|
263 |
-
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
|
264 |
-
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
265 |
-
);
|
266 |
-
|
267 |
-
break;
|
268 |
-
}
|
269 |
-
});
|
270 |
-
|
271 |
-
return out;
|
272 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optim/__init__.py
DELETED
@@ -1,15 +0,0 @@
|
|
1 |
-
from torch.optim import Adam
|
2 |
-
from torch.optim.lbfgs import LBFGS
|
3 |
-
from .radam import RAdam
|
4 |
-
|
5 |
-
|
6 |
-
OPTIMIZER_MAP = {
|
7 |
-
"adam": Adam,
|
8 |
-
"radam": RAdam,
|
9 |
-
"lbfgs": LBFGS,
|
10 |
-
}
|
11 |
-
|
12 |
-
|
13 |
-
def get_optimizer_class(optimizer_name):
|
14 |
-
name = optimizer_name.lower()
|
15 |
-
return OPTIMIZER_MAP[name]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optim/radam.py
DELETED
@@ -1,250 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
import torch
|
3 |
-
from torch.optim.optimizer import Optimizer, required
|
4 |
-
|
5 |
-
|
6 |
-
class RAdam(Optimizer):
|
7 |
-
|
8 |
-
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
|
9 |
-
if not 0.0 <= lr:
|
10 |
-
raise ValueError("Invalid learning rate: {}".format(lr))
|
11 |
-
if not 0.0 <= eps:
|
12 |
-
raise ValueError("Invalid epsilon value: {}".format(eps))
|
13 |
-
if not 0.0 <= betas[0] < 1.0:
|
14 |
-
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
15 |
-
if not 0.0 <= betas[1] < 1.0:
|
16 |
-
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
17 |
-
|
18 |
-
self.degenerated_to_sgd = degenerated_to_sgd
|
19 |
-
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
|
20 |
-
for param in params:
|
21 |
-
if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
|
22 |
-
param['buffer'] = [[None, None, None] for _ in range(10)]
|
23 |
-
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
|
24 |
-
buffer=[[None, None, None] for _ in range(10)])
|
25 |
-
super(RAdam, self).__init__(params, defaults)
|
26 |
-
|
27 |
-
def __setstate__(self, state):
|
28 |
-
super(RAdam, self).__setstate__(state)
|
29 |
-
|
30 |
-
def step(self, closure=None):
|
31 |
-
|
32 |
-
loss = None
|
33 |
-
if closure is not None:
|
34 |
-
loss = closure()
|
35 |
-
|
36 |
-
for group in self.param_groups:
|
37 |
-
|
38 |
-
for p in group['params']:
|
39 |
-
if p.grad is None:
|
40 |
-
continue
|
41 |
-
grad = p.grad.data.float()
|
42 |
-
if grad.is_sparse:
|
43 |
-
raise RuntimeError('RAdam does not support sparse gradients')
|
44 |
-
|
45 |
-
p_data_fp32 = p.data.float()
|
46 |
-
|
47 |
-
state = self.state[p]
|
48 |
-
|
49 |
-
if len(state) == 0:
|
50 |
-
state['step'] = 0
|
51 |
-
state['exp_avg'] = torch.zeros_like(p_data_fp32)
|
52 |
-
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
|
53 |
-
else:
|
54 |
-
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
|
55 |
-
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
|
56 |
-
|
57 |
-
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
58 |
-
beta1, beta2 = group['betas']
|
59 |
-
|
60 |
-
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
61 |
-
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
62 |
-
|
63 |
-
state['step'] += 1
|
64 |
-
buffered = group['buffer'][int(state['step'] % 10)]
|
65 |
-
if state['step'] == buffered[0]:
|
66 |
-
N_sma, step_size = buffered[1], buffered[2]
|
67 |
-
else:
|
68 |
-
buffered[0] = state['step']
|
69 |
-
beta2_t = beta2 ** state['step']
|
70 |
-
N_sma_max = 2 / (1 - beta2) - 1
|
71 |
-
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
|
72 |
-
buffered[1] = N_sma
|
73 |
-
|
74 |
-
# more conservative since it's an approximated value
|
75 |
-
if N_sma >= 5:
|
76 |
-
step_size = math.sqrt(
|
77 |
-
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
|
78 |
-
N_sma_max - 2)) / (1 - beta1 ** state['step'])
|
79 |
-
elif self.degenerated_to_sgd:
|
80 |
-
step_size = 1.0 / (1 - beta1 ** state['step'])
|
81 |
-
else:
|
82 |
-
step_size = -1
|
83 |
-
buffered[2] = step_size
|
84 |
-
|
85 |
-
# more conservative since it's an approximated value
|
86 |
-
if N_sma >= 5:
|
87 |
-
if group['weight_decay'] != 0:
|
88 |
-
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
|
89 |
-
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
90 |
-
p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
|
91 |
-
p.data.copy_(p_data_fp32)
|
92 |
-
elif step_size > 0:
|
93 |
-
if group['weight_decay'] != 0:
|
94 |
-
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
|
95 |
-
p_data_fp32.add_(-step_size * group['lr'], exp_avg)
|
96 |
-
p.data.copy_(p_data_fp32)
|
97 |
-
|
98 |
-
return loss
|
99 |
-
|
100 |
-
|
101 |
-
class PlainRAdam(Optimizer):
|
102 |
-
|
103 |
-
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
|
104 |
-
if not 0.0 <= lr:
|
105 |
-
raise ValueError("Invalid learning rate: {}".format(lr))
|
106 |
-
if not 0.0 <= eps:
|
107 |
-
raise ValueError("Invalid epsilon value: {}".format(eps))
|
108 |
-
if not 0.0 <= betas[0] < 1.0:
|
109 |
-
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
110 |
-
if not 0.0 <= betas[1] < 1.0:
|
111 |
-
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
112 |
-
|
113 |
-
self.degenerated_to_sgd = degenerated_to_sgd
|
114 |
-
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
115 |
-
|
116 |
-
super(PlainRAdam, self).__init__(params, defaults)
|
117 |
-
|
118 |
-
def __setstate__(self, state):
|
119 |
-
super(PlainRAdam, self).__setstate__(state)
|
120 |
-
|
121 |
-
def step(self, closure=None):
|
122 |
-
|
123 |
-
loss = None
|
124 |
-
if closure is not None:
|
125 |
-
loss = closure()
|
126 |
-
|
127 |
-
for group in self.param_groups:
|
128 |
-
|
129 |
-
for p in group['params']:
|
130 |
-
if p.grad is None:
|
131 |
-
continue
|
132 |
-
grad = p.grad.data.float()
|
133 |
-
if grad.is_sparse:
|
134 |
-
raise RuntimeError('RAdam does not support sparse gradients')
|
135 |
-
|
136 |
-
p_data_fp32 = p.data.float()
|
137 |
-
|
138 |
-
state = self.state[p]
|
139 |
-
|
140 |
-
if len(state) == 0:
|
141 |
-
state['step'] = 0
|
142 |
-
state['exp_avg'] = torch.zeros_like(p_data_fp32)
|
143 |
-
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
|
144 |
-
else:
|
145 |
-
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
|
146 |
-
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
|
147 |
-
|
148 |
-
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
149 |
-
beta1, beta2 = group['betas']
|
150 |
-
|
151 |
-
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
152 |
-
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
153 |
-
|
154 |
-
state['step'] += 1
|
155 |
-
beta2_t = beta2 ** state['step']
|
156 |
-
N_sma_max = 2 / (1 - beta2) - 1
|
157 |
-
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
|
158 |
-
|
159 |
-
# more conservative since it's an approximated value
|
160 |
-
if N_sma >= 5:
|
161 |
-
if group['weight_decay'] != 0:
|
162 |
-
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
|
163 |
-
step_size = group['lr'] * math.sqrt(
|
164 |
-
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
|
165 |
-
N_sma_max - 2)) / (1 - beta1 ** state['step'])
|
166 |
-
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
167 |
-
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
|
168 |
-
p.data.copy_(p_data_fp32)
|
169 |
-
elif self.degenerated_to_sgd:
|
170 |
-
if group['weight_decay'] != 0:
|
171 |
-
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
|
172 |
-
step_size = group['lr'] / (1 - beta1 ** state['step'])
|
173 |
-
p_data_fp32.add_(-step_size, exp_avg)
|
174 |
-
p.data.copy_(p_data_fp32)
|
175 |
-
|
176 |
-
return loss
|
177 |
-
|
178 |
-
|
179 |
-
class AdamW(Optimizer):
|
180 |
-
|
181 |
-
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
|
182 |
-
if not 0.0 <= lr:
|
183 |
-
raise ValueError("Invalid learning rate: {}".format(lr))
|
184 |
-
if not 0.0 <= eps:
|
185 |
-
raise ValueError("Invalid epsilon value: {}".format(eps))
|
186 |
-
if not 0.0 <= betas[0] < 1.0:
|
187 |
-
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
188 |
-
if not 0.0 <= betas[1] < 1.0:
|
189 |
-
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
190 |
-
|
191 |
-
defaults = dict(lr=lr, betas=betas, eps=eps,
|
192 |
-
weight_decay=weight_decay, warmup=warmup)
|
193 |
-
super(AdamW, self).__init__(params, defaults)
|
194 |
-
|
195 |
-
def __setstate__(self, state):
|
196 |
-
super(AdamW, self).__setstate__(state)
|
197 |
-
|
198 |
-
def step(self, closure=None):
|
199 |
-
loss = None
|
200 |
-
if closure is not None:
|
201 |
-
loss = closure()
|
202 |
-
|
203 |
-
for group in self.param_groups:
|
204 |
-
|
205 |
-
for p in group['params']:
|
206 |
-
if p.grad is None:
|
207 |
-
continue
|
208 |
-
grad = p.grad.data.float()
|
209 |
-
if grad.is_sparse:
|
210 |
-
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
211 |
-
|
212 |
-
p_data_fp32 = p.data.float()
|
213 |
-
|
214 |
-
state = self.state[p]
|
215 |
-
|
216 |
-
if len(state) == 0:
|
217 |
-
state['step'] = 0
|
218 |
-
state['exp_avg'] = torch.zeros_like(p_data_fp32)
|
219 |
-
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
|
220 |
-
else:
|
221 |
-
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
|
222 |
-
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
|
223 |
-
|
224 |
-
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
225 |
-
beta1, beta2 = group['betas']
|
226 |
-
|
227 |
-
state['step'] += 1
|
228 |
-
|
229 |
-
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
230 |
-
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
231 |
-
|
232 |
-
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
233 |
-
bias_correction1 = 1 - beta1 ** state['step']
|
234 |
-
bias_correction2 = 1 - beta2 ** state['step']
|
235 |
-
|
236 |
-
if group['warmup'] > state['step']:
|
237 |
-
scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
|
238 |
-
else:
|
239 |
-
scheduled_lr = group['lr']
|
240 |
-
|
241 |
-
step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
|
242 |
-
|
243 |
-
if group['weight_decay'] != 0:
|
244 |
-
p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)
|
245 |
-
|
246 |
-
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
|
247 |
-
|
248 |
-
p.data.copy_(p_data_fp32)
|
249 |
-
|
250 |
-
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,25 +1,5 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
#setuptools==59.5.0
|
7 |
-
|
8 |
-
Pillow
|
9 |
-
ninja
|
10 |
-
tqdm
|
11 |
-
opencv-python
|
12 |
-
scikit-image
|
13 |
-
numpy
|
14 |
-
|
15 |
-
tensorboard
|
16 |
-
|
17 |
-
# for face alignment
|
18 |
-
tensorflow
|
19 |
-
#keras
|
20 |
-
#bz2
|
21 |
-
dlib
|
22 |
-
scipy
|
23 |
-
|
24 |
-
matplotlib
|
25 |
-
pprintpp
|
|
|
1 |
+
numpy==1.22.3
|
2 |
+
Pillow==9.1.0
|
3 |
+
scipy==1.8.0
|
4 |
+
torch==1.11.0
|
5 |
+
torchvision==0.12.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/download_checkpoints.sh
DELETED
@@ -1,14 +0,0 @@
|
|
1 |
-
set -exo
|
2 |
-
|
3 |
-
mkdir -p checkpoint
|
4 |
-
gdown https://drive.google.com/uc?id=1hWc2JLM58_PkwfLG23Q5IH3Ysj2Mo1nr -O checkpoint/e4e_ffhq_encode.pt
|
5 |
-
gdown https://drive.google.com/uc?id=1hvAAql9Jo0wlmLBSHRIGrtXHcKQE-Whn -O checkpoint/stylegan2-ffhq-config-f.pt
|
6 |
-
gdown https://drive.google.com/uc?id=1mbGWbjivZxMGxZqyyOHbE310aOkYe2BR -O checkpoint/vgg_face_dag.pt
|
7 |
-
mkdir -p checkpoint/encoder
|
8 |
-
gdown https://drive.google.com/uc?id=1ha4WXsaIpZfMHsqNLvqOPlUXsgh9VawU -O checkpoint/encoder/checkpoint_b.pt
|
9 |
-
gdown https://drive.google.com/uc?id=1hfxDLujRIGU0G7pOdW9MMSBRzxZBmSKJ -O checkpoint/encoder/checkpoint_g.pt
|
10 |
-
gdown https://drive.google.com/uc?id=1htekHopgxaW-MIjs6pYy7pyIK0v7Q0iS -O checkpoint/encoder/checkpoint_gb.pt
|
11 |
-
|
12 |
-
pushd third_party/face_parsing
|
13 |
-
./scripts/download_checkpoints.sh
|
14 |
-
popd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/install.sh
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
# conda create -n stylegan python=3.7
|
2 |
-
# conda activate stylegan
|
3 |
-
conda install -c conda-forge/label/gcc7 opencv --yes
|
4 |
-
conda install tensorflow-gpu=1.15 cudatoolkit=10.0 --yes
|
5 |
-
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch --yes
|
6 |
-
pip install -r requirements.txt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/run.sh
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
set -x
|
2 |
-
|
3 |
-
# Example command
|
4 |
-
# ```
|
5 |
-
# ./scripts/run.sh b "dataset/Abraham Lincoln_01.png" 0.75
|
6 |
-
# ```
|
7 |
-
|
8 |
-
spectral_sensitivity="$1"
|
9 |
-
path="$2"
|
10 |
-
blur_radius="$3"
|
11 |
-
|
12 |
-
|
13 |
-
list="$(dirname "${path}")"
|
14 |
-
list="$(basename "${list}")"
|
15 |
-
|
16 |
-
if [ "${spectral_sensitivity}" == "b" ]; then
|
17 |
-
FLAGS=(--spectral_sensitivity b --encoder_ckpt checkpoint/encoder/checkpoint_b.pt);
|
18 |
-
elif [ "${spectral_sensitivity}" == "gb" ]; then
|
19 |
-
FLAGS=(--spectral_sensitivity "gb" --encoder_ckpt checkpoint/encoder/checkpoint_gb.pt);
|
20 |
-
else
|
21 |
-
FLAGS=(--spectral_sensitivity "g" --encoder_ckpt checkpoint/encoder/checkpoint_g.pt);
|
22 |
-
fi
|
23 |
-
|
24 |
-
name="${path%.*}"
|
25 |
-
name="${name##*/}"
|
26 |
-
echo "${name}"
|
27 |
-
|
28 |
-
# TODO: I did l2 or cos for contextual
|
29 |
-
time python projector.py \
|
30 |
-
"${path}" \
|
31 |
-
--gaussian "${blur_radius}" \
|
32 |
-
--log_dir "log/" \
|
33 |
-
--results_dir "results/" \
|
34 |
-
"${FLAGS[@]}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/__init__.py
DELETED
File without changes
|
tools/data/__init__.py
DELETED
File without changes
|
tools/data/align_images.py
DELETED
@@ -1,117 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import json
|
3 |
-
import os
|
4 |
-
from os.path import join as pjoin
|
5 |
-
import sys
|
6 |
-
import bz2
|
7 |
-
import numpy as np
|
8 |
-
import cv2
|
9 |
-
from tqdm import tqdm
|
10 |
-
from tensorflow.keras.utils import get_file
|
11 |
-
from utils.ffhq_dataset.face_alignment import image_align
|
12 |
-
from utils.ffhq_dataset.landmarks_detector import LandmarksDetector
|
13 |
-
|
14 |
-
LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'
|
15 |
-
|
16 |
-
|
17 |
-
def unpack_bz2(src_path):
|
18 |
-
data = bz2.BZ2File(src_path).read()
|
19 |
-
dst_path = src_path[:-4]
|
20 |
-
with open(dst_path, 'wb') as fp:
|
21 |
-
fp.write(data)
|
22 |
-
return dst_path
|
23 |
-
|
24 |
-
|
25 |
-
class SizePathMap(dict):
|
26 |
-
"""{size: {aligned_face_path0, aligned_face_path1, ...}, ...}"""
|
27 |
-
def add_item(self, size, path):
|
28 |
-
if size not in self:
|
29 |
-
self[size] = set()
|
30 |
-
self[size].add(path)
|
31 |
-
|
32 |
-
def get_sizes(self):
|
33 |
-
sizes = []
|
34 |
-
for key, paths in self.items():
|
35 |
-
sizes.extend([key,]*len(paths))
|
36 |
-
return sizes
|
37 |
-
|
38 |
-
def serialize(self):
|
39 |
-
result = {}
|
40 |
-
for key, paths in self.items():
|
41 |
-
result[key] = list(paths)
|
42 |
-
return result
|
43 |
-
|
44 |
-
|
45 |
-
def main(args):
|
46 |
-
landmarks_model_path = unpack_bz2(get_file('shape_predictor_68_face_landmarks.dat.bz2',
|
47 |
-
LANDMARKS_MODEL_URL, cache_subdir='temp'))
|
48 |
-
|
49 |
-
landmarks_detector = LandmarksDetector(landmarks_model_path)
|
50 |
-
face_sizes = SizePathMap()
|
51 |
-
raw_img_dir = args.raw_image_dir
|
52 |
-
img_names = [n for n in os.listdir(raw_img_dir) if os.path.isfile(pjoin(raw_img_dir, n))]
|
53 |
-
aligned_image_dir = args.aligned_image_dir
|
54 |
-
os.makedirs(aligned_image_dir, exist_ok=True)
|
55 |
-
pbar = tqdm(img_names)
|
56 |
-
for img_name in pbar:
|
57 |
-
pbar.set_description(img_name)
|
58 |
-
if os.path.splitext(img_name)[-1] == '.txt':
|
59 |
-
continue
|
60 |
-
raw_img_path = os.path.join(raw_img_dir, img_name)
|
61 |
-
try:
|
62 |
-
for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(raw_img_path), start=1):
|
63 |
-
face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
|
64 |
-
aligned_face_path = os.path.join(aligned_image_dir, face_img_name)
|
65 |
-
|
66 |
-
face_size = image_align(
|
67 |
-
raw_img_path, aligned_face_path, face_landmarks, resize=args.resize
|
68 |
-
)
|
69 |
-
face_sizes.add_item(face_size, aligned_face_path)
|
70 |
-
pbar.set_description(f"{img_name}: {face_size}")
|
71 |
-
|
72 |
-
if args.draw:
|
73 |
-
visual = LandmarksDetector.draw(cv2.imread(raw_img_path), face_landmarks)
|
74 |
-
cv2.imwrite(
|
75 |
-
pjoin(args.aligned_image_dir, os.path.splitext(face_img_name)[0] + "_landmarks.png"),
|
76 |
-
visual
|
77 |
-
)
|
78 |
-
except Exception as e:
|
79 |
-
print('[Error]', e, 'error happened when processing', raw_img_path)
|
80 |
-
|
81 |
-
print(args.raw_image_dir, ':')
|
82 |
-
sizes = face_sizes.get_sizes()
|
83 |
-
results = {
|
84 |
-
'mean_size': np.mean(sizes),
|
85 |
-
'num_faces_detected': len(sizes),
|
86 |
-
'num_images': len(img_names),
|
87 |
-
'sizes': sizes,
|
88 |
-
'size_path_dict': face_sizes.serialize(),
|
89 |
-
}
|
90 |
-
print('\t', results)
|
91 |
-
if args.out_stats is not None:
|
92 |
-
os.makedirs(os.path.dirname(args.out_stats), exist_ok=True)
|
93 |
-
with open(out_stats, 'w') as f:
|
94 |
-
json.dump(results, f)
|
95 |
-
|
96 |
-
|
97 |
-
def parse_args(args=None, namespace=None):
|
98 |
-
parser = argparse.ArgumentParser(description="""
|
99 |
-
Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step
|
100 |
-
python align_images.py /raw_images /aligned_images
|
101 |
-
"""
|
102 |
-
)
|
103 |
-
parser.add_argument('raw_image_dir')
|
104 |
-
parser.add_argument('aligned_image_dir')
|
105 |
-
parser.add_argument('--resize',
|
106 |
-
help="True if want to resize to 1024",
|
107 |
-
action='store_true')
|
108 |
-
parser.add_argument('--draw',
|
109 |
-
help="True if want to visualize landmarks",
|
110 |
-
action='store_true')
|
111 |
-
parser.add_argument('--out_stats',
|
112 |
-
help="output_fn for statistics of faces", default=None)
|
113 |
-
return parser.parse_args(args=args, namespace=namespace)
|
114 |
-
|
115 |
-
|
116 |
-
if __name__ == "__main__":
|
117 |
-
main(parse_args())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/initialize.py
DELETED
@@ -1,160 +0,0 @@
|
|
1 |
-
from argparse import ArgumentParser, Namespace
|
2 |
-
from typing import (
|
3 |
-
List,
|
4 |
-
Tuple,
|
5 |
-
)
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
-
from PIL import Image
|
9 |
-
import torch
|
10 |
-
from torch import nn
|
11 |
-
import torch.nn.functional as F
|
12 |
-
from torchvision.transforms import (
|
13 |
-
Compose,
|
14 |
-
Grayscale,
|
15 |
-
Resize,
|
16 |
-
ToTensor,
|
17 |
-
)
|
18 |
-
|
19 |
-
from models.encoder import Encoder
|
20 |
-
from models.encoder4editing import (
|
21 |
-
get_latents as get_e4e_latents,
|
22 |
-
setup_model as setup_e4e_model,
|
23 |
-
)
|
24 |
-
from utils.misc import (
|
25 |
-
optional_string,
|
26 |
-
iterable_to_str,
|
27 |
-
stem,
|
28 |
-
)
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
class ColorEncoderArguments:
|
33 |
-
def __init__(self):
|
34 |
-
parser = ArgumentParser("Encode an image via a feed-forward encoder")
|
35 |
-
|
36 |
-
self.add_arguments(parser)
|
37 |
-
|
38 |
-
self.parser = parser
|
39 |
-
|
40 |
-
@staticmethod
|
41 |
-
def add_arguments(parser: ArgumentParser):
|
42 |
-
parser.add_argument("--encoder_ckpt", default=None,
|
43 |
-
help="encoder checkpoint path. initialize w with encoder output if specified")
|
44 |
-
parser.add_argument("--encoder_size", type=int, default=256,
|
45 |
-
help="Resize to this size to pass as input to the encoder")
|
46 |
-
|
47 |
-
|
48 |
-
class InitializerArguments:
|
49 |
-
@classmethod
|
50 |
-
def add_arguments(cls, parser: ArgumentParser):
|
51 |
-
ColorEncoderArguments.add_arguments(parser)
|
52 |
-
cls.add_e4e_arguments(parser)
|
53 |
-
parser.add_argument("--mix_layer_range", default=[10, 18], type=int, nargs=2,
|
54 |
-
help="replace layers <start> to <end> in the e4e code by the color code")
|
55 |
-
|
56 |
-
parser.add_argument("--init_latent", default=None, help="path to init wp")
|
57 |
-
|
58 |
-
@staticmethod
|
59 |
-
def to_string(args: Namespace):
|
60 |
-
return (f"init{stem(args.init_latent).lstrip('0')[:10]}" if args.init_latent
|
61 |
-
else f"init({iterable_to_str(args.mix_layer_range)})")
|
62 |
-
#+ optional_string(args.init_noise > 0, f"-initN{args.init_noise}")
|
63 |
-
|
64 |
-
@staticmethod
|
65 |
-
def add_e4e_arguments(parser: ArgumentParser):
|
66 |
-
parser.add_argument("--e4e_ckpt", default='checkpoint/e4e_ffhq_encode.pt',
|
67 |
-
help="e4e checkpoint path.")
|
68 |
-
parser.add_argument("--e4e_size", type=int, default=256,
|
69 |
-
help="Resize to this size to pass as input to the e4e")
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
def create_color_encoder(args: Namespace):
|
74 |
-
encoder = Encoder(1, args.encoder_size, 512)
|
75 |
-
ckpt = torch.load(args.encoder_ckpt)
|
76 |
-
encoder.load_state_dict(ckpt["model"])
|
77 |
-
return encoder
|
78 |
-
|
79 |
-
|
80 |
-
def transform_input(img: Image):
|
81 |
-
tsfm = Compose([
|
82 |
-
Grayscale(),
|
83 |
-
Resize(args.encoder_size),
|
84 |
-
ToTensor(),
|
85 |
-
])
|
86 |
-
return tsfm(img)
|
87 |
-
|
88 |
-
|
89 |
-
def encode_color(imgs: torch.Tensor, args: Namespace) -> torch.Tensor:
|
90 |
-
assert args.encoder_size is not None
|
91 |
-
|
92 |
-
imgs = Resize(args.encoder_size)(imgs)
|
93 |
-
|
94 |
-
color_encoder = create_color_encoder(args).to(imgs.device)
|
95 |
-
color_encoder.eval()
|
96 |
-
with torch.no_grad():
|
97 |
-
latent = color_encoder(imgs)
|
98 |
-
return latent.detach()
|
99 |
-
|
100 |
-
|
101 |
-
def resize(imgs: torch.Tensor, size: int) -> torch.Tensor:
|
102 |
-
return F.interpolate(imgs, size=size, mode='bilinear')
|
103 |
-
|
104 |
-
|
105 |
-
class Initializer(nn.Module):
|
106 |
-
def __init__(self, args: Namespace):
|
107 |
-
super().__init__()
|
108 |
-
|
109 |
-
self.path = None
|
110 |
-
if args.init_latent is not None:
|
111 |
-
self.path = args.init_latent
|
112 |
-
return
|
113 |
-
|
114 |
-
|
115 |
-
assert args.encoder_size is not None
|
116 |
-
self.color_encoder = create_color_encoder(args)
|
117 |
-
self.color_encoder.eval()
|
118 |
-
self.color_encoder_size = args.encoder_size
|
119 |
-
|
120 |
-
self.e4e, e4e_opts = setup_e4e_model(args.e4e_ckpt)
|
121 |
-
assert 'cars_' not in e4e_opts.dataset_type
|
122 |
-
self.e4e.decoder.eval()
|
123 |
-
self.e4e.eval()
|
124 |
-
self.e4e_size = args.e4e_size
|
125 |
-
|
126 |
-
self.mix_layer_range = args.mix_layer_range
|
127 |
-
|
128 |
-
def encode_color(self, imgs: torch.Tensor) -> torch.Tensor:
|
129 |
-
"""
|
130 |
-
Get the color W code
|
131 |
-
"""
|
132 |
-
imgs = resize(imgs, self.color_encoder_size)
|
133 |
-
|
134 |
-
latent = self.color_encoder(imgs)
|
135 |
-
|
136 |
-
return latent
|
137 |
-
|
138 |
-
def encode_shape(self, imgs: torch.Tensor) -> torch.Tensor:
|
139 |
-
imgs = resize(imgs, self.e4e_size)
|
140 |
-
imgs = (imgs - 0.5) / 0.5
|
141 |
-
if imgs.shape[1] == 1: # 1 channel
|
142 |
-
imgs = imgs.repeat(1, 3, 1, 1)
|
143 |
-
return get_e4e_latents(self.e4e, imgs)
|
144 |
-
|
145 |
-
def load(self, device: torch.device):
|
146 |
-
latent_np = np.load(self.path)
|
147 |
-
return torch.tensor(latent_np, device=device)[None, ...]
|
148 |
-
|
149 |
-
def forward(self, imgs: torch.Tensor) -> torch.Tensor:
|
150 |
-
if self.path is not None:
|
151 |
-
return self.load(imgs.device)
|
152 |
-
|
153 |
-
shape_code = self.encode_shape(imgs)
|
154 |
-
color_code = self.encode_color(imgs)
|
155 |
-
|
156 |
-
# style mix
|
157 |
-
latent = shape_code
|
158 |
-
start, end = self.mix_layer_range
|
159 |
-
latent[:, start:end] = color_code
|
160 |
-
return latent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/match_histogram.py
DELETED
@@ -1,167 +0,0 @@
|
|
1 |
-
from argparse import (
|
2 |
-
ArgumentParser,
|
3 |
-
Namespace,
|
4 |
-
)
|
5 |
-
import os
|
6 |
-
from os.path import join as pjoin
|
7 |
-
from typing import Optional
|
8 |
-
import sys
|
9 |
-
|
10 |
-
import numpy as np
|
11 |
-
import cv2
|
12 |
-
from skimage import exposure
|
13 |
-
|
14 |
-
|
15 |
-
# sys.path.append('Face_Detection')
|
16 |
-
# from align_warp_back_multiple_dlib import match_histograms
|
17 |
-
|
18 |
-
|
19 |
-
def calculate_cdf(histogram):
|
20 |
-
"""
|
21 |
-
This method calculates the cumulative distribution function
|
22 |
-
:param array histogram: The values of the histogram
|
23 |
-
:return: normalized_cdf: The normalized cumulative distribution function
|
24 |
-
:rtype: array
|
25 |
-
"""
|
26 |
-
# Get the cumulative sum of the elements
|
27 |
-
cdf = histogram.cumsum()
|
28 |
-
|
29 |
-
# Normalize the cdf
|
30 |
-
normalized_cdf = cdf / float(cdf.max())
|
31 |
-
|
32 |
-
return normalized_cdf
|
33 |
-
|
34 |
-
|
35 |
-
def calculate_lookup(src_cdf, ref_cdf):
|
36 |
-
"""
|
37 |
-
This method creates the lookup table
|
38 |
-
:param array src_cdf: The cdf for the source image
|
39 |
-
:param array ref_cdf: The cdf for the reference image
|
40 |
-
:return: lookup_table: The lookup table
|
41 |
-
:rtype: array
|
42 |
-
"""
|
43 |
-
lookup_table = np.zeros(256)
|
44 |
-
lookup_val = 0
|
45 |
-
for src_pixel_val in range(len(src_cdf)):
|
46 |
-
lookup_val
|
47 |
-
for ref_pixel_val in range(len(ref_cdf)):
|
48 |
-
if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
|
49 |
-
lookup_val = ref_pixel_val
|
50 |
-
break
|
51 |
-
lookup_table[src_pixel_val] = lookup_val
|
52 |
-
return lookup_table
|
53 |
-
|
54 |
-
|
55 |
-
def match_histograms(src_image, ref_image, src_mask=None, ref_mask=None):
|
56 |
-
"""
|
57 |
-
This method matches the source image histogram to the
|
58 |
-
reference signal
|
59 |
-
:param image src_image: The original source image
|
60 |
-
:param image ref_image: The reference image
|
61 |
-
:return: image_after_matching
|
62 |
-
:rtype: image (array)
|
63 |
-
"""
|
64 |
-
# Split the images into the different color channels
|
65 |
-
# b means blue, g means green and r means red
|
66 |
-
src_b, src_g, src_r = cv2.split(src_image)
|
67 |
-
ref_b, ref_g, ref_r = cv2.split(ref_image)
|
68 |
-
|
69 |
-
def rv(im):
|
70 |
-
if ref_mask is None:
|
71 |
-
return im.flatten()
|
72 |
-
return im[ref_mask]
|
73 |
-
|
74 |
-
def sv(im):
|
75 |
-
if src_mask is None:
|
76 |
-
return im.flatten()
|
77 |
-
return im[src_mask]
|
78 |
-
|
79 |
-
# Compute the b, g, and r histograms separately
|
80 |
-
# The flatten() Numpy method returns a copy of the array c
|
81 |
-
# collapsed into one dimension.
|
82 |
-
src_hist_blue, bin_0 = np.histogram(sv(src_b), 256, [0, 256])
|
83 |
-
src_hist_green, bin_1 = np.histogram(sv(src_g), 256, [0, 256])
|
84 |
-
src_hist_red, bin_2 = np.histogram(sv(src_r), 256, [0, 256])
|
85 |
-
ref_hist_blue, bin_3 = np.histogram(rv(ref_b), 256, [0, 256])
|
86 |
-
ref_hist_green, bin_4 = np.histogram(rv(ref_g), 256, [0, 256])
|
87 |
-
ref_hist_red, bin_5 = np.histogram(rv(ref_r), 256, [0, 256])
|
88 |
-
|
89 |
-
# Compute the normalized cdf for the source and reference image
|
90 |
-
src_cdf_blue = calculate_cdf(src_hist_blue)
|
91 |
-
src_cdf_green = calculate_cdf(src_hist_green)
|
92 |
-
src_cdf_red = calculate_cdf(src_hist_red)
|
93 |
-
ref_cdf_blue = calculate_cdf(ref_hist_blue)
|
94 |
-
ref_cdf_green = calculate_cdf(ref_hist_green)
|
95 |
-
ref_cdf_red = calculate_cdf(ref_hist_red)
|
96 |
-
|
97 |
-
# Make a separate lookup table for each color
|
98 |
-
blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
|
99 |
-
green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
|
100 |
-
red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)
|
101 |
-
|
102 |
-
# Use the lookup function to transform the colors of the original
|
103 |
-
# source image
|
104 |
-
blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
|
105 |
-
green_after_transform = cv2.LUT(src_g, green_lookup_table)
|
106 |
-
red_after_transform = cv2.LUT(src_r, red_lookup_table)
|
107 |
-
|
108 |
-
# Put the image back together
|
109 |
-
image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])
|
110 |
-
image_after_matching = cv2.convertScaleAbs(image_after_matching)
|
111 |
-
|
112 |
-
return image_after_matching
|
113 |
-
|
114 |
-
|
115 |
-
def convert_to_BW(im, mode):
|
116 |
-
if mode == "b":
|
117 |
-
gray = im[..., 0]
|
118 |
-
elif mode == "gb":
|
119 |
-
gray = (im[..., 0].astype(float) + im[..., 1]) / 2.0
|
120 |
-
else:
|
121 |
-
gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
|
122 |
-
gray = gray.astype(np.uint8)
|
123 |
-
|
124 |
-
return np.stack([gray] * 3, axis=-1)
|
125 |
-
|
126 |
-
|
127 |
-
def parse_args(args=None, namespace: Optional[Namespace] = None):
|
128 |
-
parser = ArgumentParser('match histogram of src to ref')
|
129 |
-
parser.add_argument('src')
|
130 |
-
parser.add_argument('ref')
|
131 |
-
parser.add_argument('--out', default=None, help="converted src that matches ref")
|
132 |
-
parser.add_argument('--src_mask', default=None, help="mask on which to match the histogram")
|
133 |
-
parser.add_argument('--ref_mask', default=None, help="mask on which to match the histogram")
|
134 |
-
parser.add_argument('--spectral_sensitivity', choices=['b', 'gb', 'g'], help="match the histogram of corresponding sensitive channel(s)")
|
135 |
-
parser.add_argument('--crop', type=int, default=0, help="crop the boundary to match")
|
136 |
-
return parser.parse_args(args=args, namespace=namespace)
|
137 |
-
|
138 |
-
|
139 |
-
def main(args):
|
140 |
-
A = cv2.imread(args.ref)
|
141 |
-
A = convert_to_BW(A, args.spectral_sensitivity)
|
142 |
-
B = cv2.imread(args.src, 0)
|
143 |
-
B = np.stack((B,) * 3, axis=-1)
|
144 |
-
|
145 |
-
mask_A = cv2.resize(cv2.imread(args.ref_mask, 0), A.shape[:2][::-1],
|
146 |
-
interpolation=cv2.INTER_NEAREST) > 0 if args.ref_mask else None
|
147 |
-
mask_B = cv2.resize(cv2.imread(args.src_mask, 0), B.shape[:2][::-1],
|
148 |
-
interpolation=cv2.INTER_NEAREST) > 0 if args.src_mask else None
|
149 |
-
|
150 |
-
if args.crop > 0:
|
151 |
-
c = args.crop
|
152 |
-
bc = int(c / A.shape[0] * B.shape[0] + 0.5)
|
153 |
-
A = A[c:-c, c:-c]
|
154 |
-
B = B[bc:-bc, bc:-bc]
|
155 |
-
|
156 |
-
B = match_histograms(B, A, src_mask=mask_B, ref_mask=mask_A)
|
157 |
-
# B = exposure.match_histograms(B, A, multichannel=True)
|
158 |
-
|
159 |
-
if args.out:
|
160 |
-
os.makedirs(os.path.dirname(args.out), exist_ok=True)
|
161 |
-
cv2.imwrite(args.out, B)
|
162 |
-
|
163 |
-
return B
|
164 |
-
|
165 |
-
|
166 |
-
if __name__ == "__main__":
|
167 |
-
main(parse_args())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/match_skin_histogram.py
DELETED
@@ -1,67 +0,0 @@
|
|
1 |
-
from argparse import Namespace
|
2 |
-
import os
|
3 |
-
from os.path import join as pjoin
|
4 |
-
from typing import Optional
|
5 |
-
|
6 |
-
import cv2
|
7 |
-
import torch
|
8 |
-
|
9 |
-
from tools import (
|
10 |
-
parse_face,
|
11 |
-
match_histogram,
|
12 |
-
)
|
13 |
-
from utils.torch_helpers import make_image
|
14 |
-
from utils.misc import stem
|
15 |
-
|
16 |
-
|
17 |
-
def match_skin_histogram(
|
18 |
-
imgs: torch.Tensor,
|
19 |
-
sibling_img: torch.Tensor,
|
20 |
-
spectral_sensitivity,
|
21 |
-
im_sibling_dir: str,
|
22 |
-
mask_dir: str,
|
23 |
-
matched_hist_fn: Optional[str] = None,
|
24 |
-
normalize=None, # normalize the range of the tensor
|
25 |
-
):
|
26 |
-
"""
|
27 |
-
Extract the skin of the input and sibling images. Create a new input image by matching
|
28 |
-
its histogram to the sibling.
|
29 |
-
"""
|
30 |
-
# TODO: Currently only allows imgs of batch size 1
|
31 |
-
im_sibling_dir = os.path.abspath(im_sibling_dir)
|
32 |
-
mask_dir = os.path.abspath(mask_dir)
|
33 |
-
|
34 |
-
img_np = make_image(imgs)[0]
|
35 |
-
sibling_np = make_image(sibling_img)[0][...,::-1]
|
36 |
-
|
37 |
-
# save img, sibling
|
38 |
-
os.makedirs(im_sibling_dir, exist_ok=True)
|
39 |
-
im_name, sibling_name = 'input.png', 'sibling.png'
|
40 |
-
cv2.imwrite(pjoin(im_sibling_dir, im_name), img_np)
|
41 |
-
cv2.imwrite(pjoin(im_sibling_dir, sibling_name), sibling_np)
|
42 |
-
|
43 |
-
# face parsing
|
44 |
-
parse_face.main(
|
45 |
-
Namespace(in_dir=im_sibling_dir, out_dir=mask_dir, include_hair=False)
|
46 |
-
)
|
47 |
-
|
48 |
-
# match_histogram
|
49 |
-
mh_args = match_histogram.parse_args(
|
50 |
-
args=[
|
51 |
-
pjoin(im_sibling_dir, im_name),
|
52 |
-
pjoin(im_sibling_dir, sibling_name),
|
53 |
-
],
|
54 |
-
namespace=Namespace(
|
55 |
-
out=matched_hist_fn if matched_hist_fn else pjoin(im_sibling_dir, "match_histogram.png"),
|
56 |
-
src_mask=pjoin(mask_dir, im_name),
|
57 |
-
ref_mask=pjoin(mask_dir, sibling_name),
|
58 |
-
spectral_sensitivity=spectral_sensitivity,
|
59 |
-
)
|
60 |
-
)
|
61 |
-
matched_np = match_histogram.main(mh_args) / 255.0 # [0, 1]
|
62 |
-
matched = torch.FloatTensor(matched_np).permute(2, 0, 1)[None,...] #BCHW
|
63 |
-
|
64 |
-
if normalize is not None:
|
65 |
-
matched = normalize(matched)
|
66 |
-
|
67 |
-
return matched
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/parse_face.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
from argparse import ArgumentParser
|
2 |
-
import os
|
3 |
-
from os.path import join as pjoin
|
4 |
-
from subprocess import run
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import cv2
|
8 |
-
from tqdm import tqdm
|
9 |
-
|
10 |
-
|
11 |
-
def create_skin_mask(anno_dir, mask_dir, skin_thresh=13, include_hair=False):
|
12 |
-
names = os.listdir(anno_dir)
|
13 |
-
names = [n for n in names if n.endswith('.png')]
|
14 |
-
os.makedirs(mask_dir, exist_ok=True)
|
15 |
-
for name in tqdm(names):
|
16 |
-
anno = cv2.imread(pjoin(anno_dir, name), 0)
|
17 |
-
mask = np.logical_and(0 < anno, anno <= skin_thresh)
|
18 |
-
if include_hair:
|
19 |
-
mask |= anno == 17
|
20 |
-
cv2.imwrite(pjoin(mask_dir, name), mask * 255)
|
21 |
-
|
22 |
-
|
23 |
-
def main(args):
|
24 |
-
FACE_PARSING_DIR = 'third_party/face_parsing'
|
25 |
-
|
26 |
-
main_env = os.getcwd()
|
27 |
-
os.chdir(FACE_PARSING_DIR)
|
28 |
-
tmp_parse_dir = pjoin(args.out_dir, 'face_parsing')
|
29 |
-
cmd = [
|
30 |
-
'python',
|
31 |
-
'test.py',
|
32 |
-
args.in_dir,
|
33 |
-
tmp_parse_dir,
|
34 |
-
]
|
35 |
-
print(' '.join(cmd))
|
36 |
-
run(cmd)
|
37 |
-
|
38 |
-
create_skin_mask(tmp_parse_dir, args.out_dir, include_hair=args.include_hair)
|
39 |
-
|
40 |
-
os.chdir(main_env)
|
41 |
-
|
42 |
-
|
43 |
-
def parse_args(args=None, namespace=None):
|
44 |
-
parser = ArgumentParser("Face Parsing and generate skin (& hair) mask")
|
45 |
-
parser.add_argument('in_dir')
|
46 |
-
parser.add_argument('out_dir')
|
47 |
-
parser.add_argument('--include_hair', action="store_true", help="include hair in the mask")
|
48 |
-
return parser.parse_args(args=args, namespace=namespace)
|
49 |
-
|
50 |
-
|
51 |
-
if __name__ == "__main__":
|
52 |
-
main(parse_args())
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
+
# and proprietary rights in and to this software, related documentation
|
7 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
+
# distribution of this software and related documentation without an express
|
9 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
+
|
11 |
+
# empty
|
torch_utils/custom_ops.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
+
# and proprietary rights in and to this software, related documentation
|
7 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
+
# distribution of this software and related documentation without an express
|
9 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
+
|
11 |
+
import os
|
12 |
+
import glob
|
13 |
+
import torch
|
14 |
+
import torch.utils.cpp_extension
|
15 |
+
import importlib
|
16 |
+
import hashlib
|
17 |
+
import shutil
|
18 |
+
from pathlib import Path
|
19 |
+
import re
|
20 |
+
import uuid
|
21 |
+
|
22 |
+
from torch.utils.file_baton import FileBaton
|
23 |
+
|
24 |
+
#----------------------------------------------------------------------------
|
25 |
+
# Global options.
|
26 |
+
|
27 |
+
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
|
28 |
+
|
29 |
+
#----------------------------------------------------------------------------
|
30 |
+
# Internal helper funcs.
|
31 |
+
|
32 |
+
def _find_compiler_bindir():
|
33 |
+
patterns = [
|
34 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
35 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
36 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
37 |
+
'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
|
38 |
+
]
|
39 |
+
for pattern in patterns:
|
40 |
+
matches = sorted(glob.glob(pattern))
|
41 |
+
if len(matches):
|
42 |
+
return matches[-1]
|
43 |
+
return None
|
44 |
+
|
45 |
+
def _get_mangled_gpu_name():
|
46 |
+
name = torch.cuda.get_device_name().lower()
|
47 |
+
out = []
|
48 |
+
for c in name:
|
49 |
+
if re.match('[a-z0-9_-]+', c):
|
50 |
+
out.append(c)
|
51 |
+
else:
|
52 |
+
out.append('-')
|
53 |
+
return ''.join(out)
|
54 |
+
|
55 |
+
|
56 |
+
#----------------------------------------------------------------------------
|
57 |
+
# Main entry point for compiling and loading C++/CUDA plugins.
|
58 |
+
|
59 |
+
_cached_plugins = dict()
|
60 |
+
|
61 |
+
def get_plugin(module_name, sources, **build_kwargs):
|
62 |
+
assert verbosity in ['none', 'brief', 'full']
|
63 |
+
|
64 |
+
# Already cached?
|
65 |
+
if module_name in _cached_plugins:
|
66 |
+
return _cached_plugins[module_name]
|
67 |
+
|
68 |
+
# Print status.
|
69 |
+
if verbosity == 'full':
|
70 |
+
print(f'Setting up PyTorch plugin "{module_name}"...')
|
71 |
+
elif verbosity == 'brief':
|
72 |
+
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
73 |
+
|
74 |
+
try: # pylint: disable=too-many-nested-blocks
|
75 |
+
# Make sure we can find the necessary compiler binaries.
|
76 |
+
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
77 |
+
compiler_bindir = _find_compiler_bindir()
|
78 |
+
if compiler_bindir is None:
|
79 |
+
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
80 |
+
os.environ['PATH'] += ';' + compiler_bindir
|
81 |
+
|
82 |
+
# Compile and load.
|
83 |
+
verbose_build = (verbosity == 'full')
|
84 |
+
|
85 |
+
# Incremental build md5sum trickery. Copies all the input source files
|
86 |
+
# into a cached build directory under a combined md5 digest of the input
|
87 |
+
# source files. Copying is done only if the combined digest has changed.
|
88 |
+
# This keeps input file timestamps and filenames the same as in previous
|
89 |
+
# extension builds, allowing for fast incremental rebuilds.
|
90 |
+
#
|
91 |
+
# This optimization is done only in case all the source files reside in
|
92 |
+
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
93 |
+
# environment variable is set (we take this as a signal that the user
|
94 |
+
# actually cares about this.)
|
95 |
+
source_dirs_set = set(os.path.dirname(source) for source in sources)
|
96 |
+
if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
97 |
+
all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
|
98 |
+
|
99 |
+
# Compute a combined hash digest for all source files in the same
|
100 |
+
# custom op directory (usually .cu, .cpp, .py and .h files).
|
101 |
+
hash_md5 = hashlib.md5()
|
102 |
+
for src in all_source_files:
|
103 |
+
with open(src, 'rb') as f:
|
104 |
+
hash_md5.update(f.read())
|
105 |
+
build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
106 |
+
digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
|
107 |
+
|
108 |
+
if not os.path.isdir(digest_build_dir):
|
109 |
+
os.makedirs(digest_build_dir, exist_ok=True)
|
110 |
+
baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
|
111 |
+
if baton.try_acquire():
|
112 |
+
try:
|
113 |
+
for src in all_source_files:
|
114 |
+
shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
|
115 |
+
finally:
|
116 |
+
baton.release()
|
117 |
+
else:
|
118 |
+
# Someone else is copying source files under the digest dir,
|
119 |
+
# wait until done and continue.
|
120 |
+
baton.wait()
|
121 |
+
digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
|
122 |
+
torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
|
123 |
+
verbose=verbose_build, sources=digest_sources, **build_kwargs)
|
124 |
+
else:
|
125 |
+
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
126 |
+
module = importlib.import_module(module_name)
|
127 |
+
|
128 |
+
except:
|
129 |
+
if verbosity == 'brief':
|
130 |
+
print('Failed!')
|
131 |
+
raise
|
132 |
+
|
133 |
+
# Print status and add to cache.
|
134 |
+
if verbosity == 'full':
|
135 |
+
print(f'Done setting up PyTorch plugin "{module_name}".')
|
136 |
+
elif verbosity == 'brief':
|
137 |
+
print('Done.')
|
138 |
+
_cached_plugins[module_name] = module
|
139 |
+
return module
|
140 |
+
|
141 |
+
#----------------------------------------------------------------------------
|
142 |
+
def get_plugin_v3(module_name, sources, headers=None, source_dir=None, **build_kwargs):
|
143 |
+
assert verbosity in ['none', 'brief', 'full']
|
144 |
+
if headers is None:
|
145 |
+
headers = []
|
146 |
+
if source_dir is not None:
|
147 |
+
sources = [os.path.join(source_dir, fname) for fname in sources]
|
148 |
+
headers = [os.path.join(source_dir, fname) for fname in headers]
|
149 |
+
|
150 |
+
# Already cached?
|
151 |
+
if module_name in _cached_plugins:
|
152 |
+
return _cached_plugins[module_name]
|
153 |
+
|
154 |
+
# Print status.
|
155 |
+
if verbosity == 'full':
|
156 |
+
print(f'Setting up PyTorch plugin "{module_name}"...')
|
157 |
+
elif verbosity == 'brief':
|
158 |
+
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
159 |
+
verbose_build = (verbosity == 'full')
|
160 |
+
|
161 |
+
# Compile and load.
|
162 |
+
try: # pylint: disable=too-many-nested-blocks
|
163 |
+
# Make sure we can find the necessary compiler binaries.
|
164 |
+
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
165 |
+
compiler_bindir = _find_compiler_bindir()
|
166 |
+
if compiler_bindir is None:
|
167 |
+
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
168 |
+
os.environ['PATH'] += ';' + compiler_bindir
|
169 |
+
|
170 |
+
# Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
|
171 |
+
# break the build or unnecessarily restrict what's available to nvcc.
|
172 |
+
# Unset it to let nvcc decide based on what's available on the
|
173 |
+
# machine.
|
174 |
+
os.environ['TORCH_CUDA_ARCH_LIST'] = ''
|
175 |
+
|
176 |
+
# Incremental build md5sum trickery. Copies all the input source files
|
177 |
+
# into a cached build directory under a combined md5 digest of the input
|
178 |
+
# source files. Copying is done only if the combined digest has changed.
|
179 |
+
# This keeps input file timestamps and filenames the same as in previous
|
180 |
+
# extension builds, allowing for fast incremental rebuilds.
|
181 |
+
#
|
182 |
+
# This optimization is done only in case all the source files reside in
|
183 |
+
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
184 |
+
# environment variable is set (we take this as a signal that the user
|
185 |
+
# actually cares about this.)
|
186 |
+
#
|
187 |
+
# EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
|
188 |
+
# around the *.cu dependency bug in ninja config.
|
189 |
+
#
|
190 |
+
all_source_files = sorted(sources + headers)
|
191 |
+
all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
|
192 |
+
if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
193 |
+
|
194 |
+
# Compute combined hash digest for all source files.
|
195 |
+
hash_md5 = hashlib.md5()
|
196 |
+
for src in all_source_files:
|
197 |
+
with open(src, 'rb') as f:
|
198 |
+
hash_md5.update(f.read())
|
199 |
+
|
200 |
+
# Select cached build directory name.
|
201 |
+
source_digest = hash_md5.hexdigest()
|
202 |
+
build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
203 |
+
cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
|
204 |
+
|
205 |
+
if not os.path.isdir(cached_build_dir):
|
206 |
+
tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
|
207 |
+
os.makedirs(tmpdir)
|
208 |
+
for src in all_source_files:
|
209 |
+
shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
|
210 |
+
try:
|
211 |
+
os.replace(tmpdir, cached_build_dir) # atomic
|
212 |
+
except OSError:
|
213 |
+
# source directory already exists, delete tmpdir and its contents.
|
214 |
+
shutil.rmtree(tmpdir)
|
215 |
+
if not os.path.isdir(cached_build_dir): raise
|
216 |
+
|
217 |
+
# Compile.
|
218 |
+
cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
|
219 |
+
torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
|
220 |
+
verbose=verbose_build, sources=cached_sources, **build_kwargs)
|
221 |
+
else:
|
222 |
+
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
223 |
+
|
224 |
+
# Load.
|
225 |
+
module = importlib.import_module(module_name)
|
226 |
+
|
227 |
+
except:
|
228 |
+
if verbosity == 'brief':
|
229 |
+
print('Failed!')
|
230 |
+
raise
|
231 |
+
|
232 |
+
# Print status and add to cache dict.
|
233 |
+
if verbosity == 'full':
|
234 |
+
print(f'Done setting up PyTorch plugin "{module_name}".')
|
235 |
+
elif verbosity == 'brief':
|
236 |
+
print('Done.')
|
237 |
+
_cached_plugins[module_name] = module
|
238 |
+
return module
|
torch_utils/misc.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
+
# and proprietary rights in and to this software, related documentation
|
7 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
+
# distribution of this software and related documentation without an express
|
9 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
+
|
11 |
+
import re
|
12 |
+
import contextlib
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import warnings
|
16 |
+
import dnnlib
|
17 |
+
|
18 |
+
#----------------------------------------------------------------------------
|
19 |
+
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
|
20 |
+
# same constant is used multiple times.
|
21 |
+
|
22 |
+
_constant_cache = dict()
|
23 |
+
|
24 |
+
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
|
25 |
+
value = np.asarray(value)
|
26 |
+
if shape is not None:
|
27 |
+
shape = tuple(shape)
|
28 |
+
if dtype is None:
|
29 |
+
dtype = torch.get_default_dtype()
|
30 |
+
if device is None:
|
31 |
+
device = torch.device('cpu')
|
32 |
+
if memory_format is None:
|
33 |
+
memory_format = torch.contiguous_format
|
34 |
+
|
35 |
+
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
|
36 |
+
tensor = _constant_cache.get(key, None)
|
37 |
+
if tensor is None:
|
38 |
+
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
|
39 |
+
if shape is not None:
|
40 |
+
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
|
41 |
+
tensor = tensor.contiguous(memory_format=memory_format)
|
42 |
+
_constant_cache[key] = tensor
|
43 |
+
return tensor
|
44 |
+
|
45 |
+
#----------------------------------------------------------------------------
|
46 |
+
# Replace NaN/Inf with specified numerical values.
|
47 |
+
|
48 |
+
try:
|
49 |
+
nan_to_num = torch.nan_to_num # 1.8.0a0
|
50 |
+
except AttributeError:
|
51 |
+
def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
|
52 |
+
assert isinstance(input, torch.Tensor)
|
53 |
+
if posinf is None:
|
54 |
+
posinf = torch.finfo(input.dtype).max
|
55 |
+
if neginf is None:
|
56 |
+
neginf = torch.finfo(input.dtype).min
|
57 |
+
assert nan == 0
|
58 |
+
return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
|
59 |
+
|
60 |
+
#----------------------------------------------------------------------------
|
61 |
+
# Symbolic assert.
|
62 |
+
|
63 |
+
try:
|
64 |
+
symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
|
65 |
+
except AttributeError:
|
66 |
+
symbolic_assert = torch.Assert # 1.7.0
|
67 |
+
|
68 |
+
#----------------------------------------------------------------------------
|
69 |
+
# Context manager to suppress known warnings in torch.jit.trace().
|
70 |
+
|
71 |
+
class suppress_tracer_warnings(warnings.catch_warnings):
|
72 |
+
def __enter__(self):
|
73 |
+
super().__enter__()
|
74 |
+
warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
|
75 |
+
return self
|
76 |
+
|
77 |
+
#----------------------------------------------------------------------------
|
78 |
+
# Assert that the shape of a tensor matches the given list of integers.
|
79 |
+
# None indicates that the size of a dimension is allowed to vary.
|
80 |
+
# Performs symbolic assertion when used in torch.jit.trace().
|
81 |
+
|
82 |
+
def assert_shape(tensor, ref_shape):
|
83 |
+
if tensor.ndim != len(ref_shape):
|
84 |
+
raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
|
85 |
+
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
|
86 |
+
if ref_size is None:
|
87 |
+
pass
|
88 |
+
elif isinstance(ref_size, torch.Tensor):
|
89 |
+
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
90 |
+
symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
|
91 |
+
elif isinstance(size, torch.Tensor):
|
92 |
+
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
93 |
+
symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
|
94 |
+
elif size != ref_size:
|
95 |
+
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
|
96 |
+
|
97 |
+
#----------------------------------------------------------------------------
|
98 |
+
# Function decorator that calls torch.autograd.profiler.record_function().
|
99 |
+
|
100 |
+
def profiled_function(fn):
|
101 |
+
def decorator(*args, **kwargs):
|
102 |
+
with torch.autograd.profiler.record_function(fn.__name__):
|
103 |
+
return fn(*args, **kwargs)
|
104 |
+
decorator.__name__ = fn.__name__
|
105 |
+
return decorator
|
106 |
+
|
107 |
+
#----------------------------------------------------------------------------
|
108 |
+
# Sampler for torch.utils.data.DataLoader that loops over the dataset
|
109 |
+
# indefinitely, shuffling items as it goes.
|
110 |
+
|
111 |
+
class InfiniteSampler(torch.utils.data.Sampler):
|
112 |
+
def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
|
113 |
+
assert len(dataset) > 0
|
114 |
+
assert num_replicas > 0
|
115 |
+
assert 0 <= rank < num_replicas
|
116 |
+
assert 0 <= window_size <= 1
|
117 |
+
super().__init__(dataset)
|
118 |
+
self.dataset = dataset
|
119 |
+
self.rank = rank
|
120 |
+
self.num_replicas = num_replicas
|
121 |
+
self.shuffle = shuffle
|
122 |
+
self.seed = seed
|
123 |
+
self.window_size = window_size
|
124 |
+
|
125 |
+
def __iter__(self):
|
126 |
+
order = np.arange(len(self.dataset))
|
127 |
+
rnd = None
|
128 |
+
window = 0
|
129 |
+
if self.shuffle:
|
130 |
+
rnd = np.random.RandomState(self.seed)
|
131 |
+
rnd.shuffle(order)
|
132 |
+
window = int(np.rint(order.size * self.window_size))
|
133 |
+
|
134 |
+
idx = 0
|
135 |
+
while True:
|
136 |
+
i = idx % order.size
|
137 |
+
if idx % self.num_replicas == self.rank:
|
138 |
+
yield order[i]
|
139 |
+
if window >= 2:
|
140 |
+
j = (i - rnd.randint(window)) % order.size
|
141 |
+
order[i], order[j] = order[j], order[i]
|
142 |
+
idx += 1
|
143 |
+
|
144 |
+
#----------------------------------------------------------------------------
|
145 |
+
# Utilities for operating with torch.nn.Module parameters and buffers.
|
146 |
+
|
147 |
+
def params_and_buffers(module):
|
148 |
+
assert isinstance(module, torch.nn.Module)
|
149 |
+
return list(module.parameters()) + list(module.buffers())
|
150 |
+
|
151 |
+
def named_params_and_buffers(module):
|
152 |
+
assert isinstance(module, torch.nn.Module)
|
153 |
+
return list(module.named_parameters()) + list(module.named_buffers())
|
154 |
+
|
155 |
+
def copy_params_and_buffers(src_module, dst_module, require_all=False):
|
156 |
+
assert isinstance(src_module, torch.nn.Module)
|
157 |
+
assert isinstance(dst_module, torch.nn.Module)
|
158 |
+
src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
|
159 |
+
for name, tensor in named_params_and_buffers(dst_module):
|
160 |
+
assert (name in src_tensors) or (not require_all)
|
161 |
+
if name in src_tensors:
|
162 |
+
tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
|
163 |
+
|
164 |
+
#----------------------------------------------------------------------------
|
165 |
+
# Context manager for easily enabling/disabling DistributedDataParallel
|
166 |
+
# synchronization.
|
167 |
+
|
168 |
+
@contextlib.contextmanager
|
169 |
+
def ddp_sync(module, sync):
|
170 |
+
assert isinstance(module, torch.nn.Module)
|
171 |
+
if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
|
172 |
+
yield
|
173 |
+
else:
|
174 |
+
with module.no_sync():
|
175 |
+
yield
|
176 |
+
|
177 |
+
#----------------------------------------------------------------------------
|
178 |
+
# Check DistributedDataParallel consistency across processes.
|
179 |
+
|
180 |
+
def check_ddp_consistency(module, ignore_regex=None):
|
181 |
+
assert isinstance(module, torch.nn.Module)
|
182 |
+
for name, tensor in named_params_and_buffers(module):
|
183 |
+
fullname = type(module).__name__ + '.' + name
|
184 |
+
if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
|
185 |
+
continue
|
186 |
+
tensor = tensor.detach()
|
187 |
+
other = tensor.clone()
|
188 |
+
torch.distributed.broadcast(tensor=other, src=0)
|
189 |
+
assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
|
190 |
+
|
191 |
+
#----------------------------------------------------------------------------
|
192 |
+
# Print summary table of module hierarchy.
|
193 |
+
|
194 |
+
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
|
195 |
+
assert isinstance(module, torch.nn.Module)
|
196 |
+
assert not isinstance(module, torch.jit.ScriptModule)
|
197 |
+
assert isinstance(inputs, (tuple, list))
|
198 |
+
|
199 |
+
# Register hooks.
|
200 |
+
entries = []
|
201 |
+
nesting = [0]
|
202 |
+
def pre_hook(_mod, _inputs):
|
203 |
+
nesting[0] += 1
|
204 |
+
def post_hook(mod, _inputs, outputs):
|
205 |
+
nesting[0] -= 1
|
206 |
+
if nesting[0] <= max_nesting:
|
207 |
+
outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
|
208 |
+
outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
|
209 |
+
entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
|
210 |
+
hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
|
211 |
+
hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
|
212 |
+
|
213 |
+
# Run module.
|
214 |
+
outputs = module(*inputs)
|
215 |
+
for hook in hooks:
|
216 |
+
hook.remove()
|
217 |
+
|
218 |
+
# Identify unique outputs, parameters, and buffers.
|
219 |
+
tensors_seen = set()
|
220 |
+
for e in entries:
|
221 |
+
e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
|
222 |
+
e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
|
223 |
+
e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
|
224 |
+
tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
|
225 |
+
|
226 |
+
# Filter out redundant entries.
|
227 |
+
if skip_redundant:
|
228 |
+
entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
|
229 |
+
|
230 |
+
# Construct table.
|
231 |
+
rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
|
232 |
+
rows += [['---'] * len(rows[0])]
|
233 |
+
param_total = 0
|
234 |
+
buffer_total = 0
|
235 |
+
submodule_names = {mod: name for name, mod in module.named_modules()}
|
236 |
+
for e in entries:
|
237 |
+
name = '<top-level>' if e.mod is module else submodule_names[e.mod]
|
238 |
+
param_size = sum(t.numel() for t in e.unique_params)
|
239 |
+
buffer_size = sum(t.numel() for t in e.unique_buffers)
|
240 |
+
output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
|
241 |
+
output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
|
242 |
+
rows += [[
|
243 |
+
name + (':0' if len(e.outputs) >= 2 else ''),
|
244 |
+
str(param_size) if param_size else '-',
|
245 |
+
str(buffer_size) if buffer_size else '-',
|
246 |
+
(output_shapes + ['-'])[0],
|
247 |
+
(output_dtypes + ['-'])[0],
|
248 |
+
]]
|
249 |
+
for idx in range(1, len(e.outputs)):
|
250 |
+
rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
|
251 |
+
param_total += param_size
|
252 |
+
buffer_total += buffer_size
|
253 |
+
rows += [['---'] * len(rows[0])]
|
254 |
+
rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
|
255 |
+
|
256 |
+
# Print table.
|
257 |
+
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
|
258 |
+
print()
|
259 |
+
for row in rows:
|
260 |
+
print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
|
261 |
+
print()
|
262 |
+
return outputs
|
263 |
+
|
264 |
+
#----------------------------------------------------------------------------
|
torch_utils/models.py
ADDED
@@ -0,0 +1,756 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
# https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
|
4 |
+
|
5 |
+
import math
|
6 |
+
import random
|
7 |
+
import functools
|
8 |
+
import operator
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
from torch.nn import functional as F
|
13 |
+
import torch.nn.init as init
|
14 |
+
from torch.autograd import Function
|
15 |
+
|
16 |
+
from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
17 |
+
|
18 |
+
|
19 |
+
class PixelNorm(nn.Module):
|
20 |
+
def __init__(self):
|
21 |
+
super().__init__()
|
22 |
+
|
23 |
+
def forward(self, input):
|
24 |
+
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
25 |
+
|
26 |
+
|
27 |
+
def make_kernel(k):
|
28 |
+
k = torch.tensor(k, dtype=torch.float32)
|
29 |
+
if k.ndim == 1:
|
30 |
+
k = k[None, :] * k[:, None]
|
31 |
+
k /= k.sum()
|
32 |
+
return k
|
33 |
+
|
34 |
+
|
35 |
+
class Upsample(nn.Module):
|
36 |
+
def __init__(self, kernel, factor=2):
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
self.factor = factor
|
40 |
+
kernel = make_kernel(kernel) * (factor ** 2)
|
41 |
+
self.register_buffer("kernel", kernel)
|
42 |
+
|
43 |
+
p = kernel.shape[0] - factor
|
44 |
+
|
45 |
+
pad0 = (p + 1) // 2 + factor - 1
|
46 |
+
pad1 = p // 2
|
47 |
+
|
48 |
+
self.pad = (pad0, pad1)
|
49 |
+
|
50 |
+
def forward(self, input):
|
51 |
+
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
52 |
+
return out
|
53 |
+
|
54 |
+
|
55 |
+
class Downsample(nn.Module):
|
56 |
+
def __init__(self, kernel, factor=2):
|
57 |
+
super().__init__()
|
58 |
+
|
59 |
+
self.factor = factor
|
60 |
+
kernel = make_kernel(kernel)
|
61 |
+
self.register_buffer("kernel", kernel)
|
62 |
+
|
63 |
+
p = kernel.shape[0] - factor
|
64 |
+
|
65 |
+
pad0 = (p + 1) // 2
|
66 |
+
pad1 = p // 2
|
67 |
+
|
68 |
+
self.pad = (pad0, pad1)
|
69 |
+
|
70 |
+
def forward(self, input):
|
71 |
+
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
72 |
+
return out
|
73 |
+
|
74 |
+
|
75 |
+
class Blur(nn.Module):
|
76 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
77 |
+
super().__init__()
|
78 |
+
|
79 |
+
kernel = make_kernel(kernel)
|
80 |
+
|
81 |
+
if upsample_factor > 1:
|
82 |
+
kernel = kernel * (upsample_factor ** 2)
|
83 |
+
|
84 |
+
self.register_buffer("kernel", kernel)
|
85 |
+
|
86 |
+
self.pad = pad
|
87 |
+
|
88 |
+
def forward(self, input):
|
89 |
+
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
90 |
+
return out
|
91 |
+
|
92 |
+
|
93 |
+
class EqualConv2d(nn.Module):
|
94 |
+
def __init__(
|
95 |
+
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
96 |
+
):
|
97 |
+
super().__init__()
|
98 |
+
|
99 |
+
self.weight = nn.Parameter(
|
100 |
+
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
101 |
+
)
|
102 |
+
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
103 |
+
|
104 |
+
self.stride = stride
|
105 |
+
self.padding = padding
|
106 |
+
|
107 |
+
if bias:
|
108 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
109 |
+
|
110 |
+
else:
|
111 |
+
self.bias = None
|
112 |
+
|
113 |
+
def forward(self, input):
|
114 |
+
out = F.conv2d(
|
115 |
+
input,
|
116 |
+
self.weight * self.scale,
|
117 |
+
bias=self.bias,
|
118 |
+
stride=self.stride,
|
119 |
+
padding=self.padding,
|
120 |
+
)
|
121 |
+
return out
|
122 |
+
|
123 |
+
def __repr__(self):
|
124 |
+
return (
|
125 |
+
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
|
126 |
+
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
|
127 |
+
)
|
128 |
+
|
129 |
+
|
130 |
+
class EqualLinear(nn.Module):
|
131 |
+
def __init__(
|
132 |
+
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
133 |
+
):
|
134 |
+
super().__init__()
|
135 |
+
|
136 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
137 |
+
|
138 |
+
if bias:
|
139 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
140 |
+
else:
|
141 |
+
self.bias = None
|
142 |
+
|
143 |
+
self.activation = activation
|
144 |
+
|
145 |
+
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
146 |
+
self.lr_mul = lr_mul
|
147 |
+
|
148 |
+
def forward(self, input):
|
149 |
+
if self.activation:
|
150 |
+
out = F.linear(input, self.weight * self.scale)
|
151 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
152 |
+
else:
|
153 |
+
out = F.linear(
|
154 |
+
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
155 |
+
)
|
156 |
+
return out
|
157 |
+
|
158 |
+
def __repr__(self):
|
159 |
+
return (
|
160 |
+
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
|
161 |
+
)
|
162 |
+
|
163 |
+
|
164 |
+
class ScaledLeakyReLU(nn.Module):
|
165 |
+
def __init__(self, negative_slope=0.2):
|
166 |
+
super().__init__()
|
167 |
+
self.negative_slope = negative_slope
|
168 |
+
|
169 |
+
def forward(self, input):
|
170 |
+
out = F.leaky_relu(input, negative_slope=self.negative_slope)
|
171 |
+
return out * math.sqrt(2)
|
172 |
+
|
173 |
+
|
174 |
+
class ModulatedConv2d(nn.Module):
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
in_channel,
|
178 |
+
out_channel,
|
179 |
+
kernel_size,
|
180 |
+
style_dim,
|
181 |
+
demodulate=True,
|
182 |
+
upsample=False,
|
183 |
+
downsample=False,
|
184 |
+
blur_kernel=[1, 3, 3, 1],
|
185 |
+
):
|
186 |
+
super().__init__()
|
187 |
+
|
188 |
+
self.eps = 1e-8
|
189 |
+
self.kernel_size = kernel_size
|
190 |
+
self.in_channel = in_channel
|
191 |
+
self.out_channel = out_channel
|
192 |
+
self.upsample = upsample
|
193 |
+
self.downsample = downsample
|
194 |
+
|
195 |
+
if upsample:
|
196 |
+
factor = 2
|
197 |
+
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
198 |
+
pad0 = (p + 1) // 2 + factor - 1
|
199 |
+
pad1 = p // 2 + 1
|
200 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
201 |
+
|
202 |
+
if downsample:
|
203 |
+
factor = 2
|
204 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
205 |
+
pad0 = (p + 1) // 2
|
206 |
+
pad1 = p // 2
|
207 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
208 |
+
|
209 |
+
fan_in = in_channel * kernel_size ** 2
|
210 |
+
self.scale = 1 / math.sqrt(fan_in)
|
211 |
+
self.padding = kernel_size // 2
|
212 |
+
self.weight = nn.Parameter(
|
213 |
+
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
214 |
+
)
|
215 |
+
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
216 |
+
self.demodulate = demodulate
|
217 |
+
|
218 |
+
def __repr__(self):
|
219 |
+
return (
|
220 |
+
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
|
221 |
+
f"upsample={self.upsample}, downsample={self.downsample})"
|
222 |
+
)
|
223 |
+
|
224 |
+
def forward(self, input, style):
|
225 |
+
batch, in_channel, height, width = input.shape
|
226 |
+
|
227 |
+
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
228 |
+
weight = self.scale * self.weight * style
|
229 |
+
|
230 |
+
if self.demodulate:
|
231 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
232 |
+
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
233 |
+
|
234 |
+
weight = weight.view(
|
235 |
+
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
236 |
+
)
|
237 |
+
|
238 |
+
if self.upsample:
|
239 |
+
input = input.view(1, batch * in_channel, height, width)
|
240 |
+
weight = weight.view(
|
241 |
+
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
242 |
+
)
|
243 |
+
weight = weight.transpose(1, 2).reshape(
|
244 |
+
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
245 |
+
)
|
246 |
+
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
|
247 |
+
_, _, height, width = out.shape
|
248 |
+
out = out.view(batch, self.out_channel, height, width)
|
249 |
+
out = self.blur(out)
|
250 |
+
|
251 |
+
elif self.downsample:
|
252 |
+
input = self.blur(input)
|
253 |
+
_, _, height, width = input.shape
|
254 |
+
input = input.view(1, batch * in_channel, height, width)
|
255 |
+
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
|
256 |
+
_, _, height, width = out.shape
|
257 |
+
out = out.view(batch, self.out_channel, height, width)
|
258 |
+
|
259 |
+
else:
|
260 |
+
input = input.view(1, batch * in_channel, height, width)
|
261 |
+
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
|
262 |
+
_, _, height, width = out.shape
|
263 |
+
out = out.view(batch, self.out_channel, height, width)
|
264 |
+
|
265 |
+
return out
|
266 |
+
|
267 |
+
|
268 |
+
class NoiseInjection(nn.Module):
|
269 |
+
def __init__(self):
|
270 |
+
super().__init__()
|
271 |
+
self.weight = nn.Parameter(torch.zeros(1))
|
272 |
+
|
273 |
+
def forward(self, image, noise=None):
|
274 |
+
if noise is None:
|
275 |
+
batch, _, height, width = image.shape
|
276 |
+
noise = image.new_empty(batch, 1, height, width).normal_()
|
277 |
+
return image + self.weight * noise
|
278 |
+
|
279 |
+
|
280 |
+
class ConstantInput(nn.Module):
|
281 |
+
def __init__(self, channel, size=4):
|
282 |
+
super().__init__()
|
283 |
+
self.input = nn.Parameter(torch.randn(1, channel, size, size // 2))
|
284 |
+
|
285 |
+
def forward(self, input):
|
286 |
+
batch = input.shape[0]
|
287 |
+
out = self.input.repeat(batch, 1, 1, 1)
|
288 |
+
return out
|
289 |
+
|
290 |
+
|
291 |
+
class StyledConv(nn.Module):
|
292 |
+
def __init__(
|
293 |
+
self,
|
294 |
+
in_channel,
|
295 |
+
out_channel,
|
296 |
+
kernel_size,
|
297 |
+
style_dim,
|
298 |
+
upsample=False,
|
299 |
+
blur_kernel=[1, 3, 3, 1],
|
300 |
+
demodulate=True,
|
301 |
+
):
|
302 |
+
super().__init__()
|
303 |
+
self.conv = ModulatedConv2d(
|
304 |
+
in_channel,
|
305 |
+
out_channel,
|
306 |
+
kernel_size,
|
307 |
+
style_dim,
|
308 |
+
upsample=upsample,
|
309 |
+
blur_kernel=blur_kernel,
|
310 |
+
demodulate=demodulate,
|
311 |
+
)
|
312 |
+
self.noise = NoiseInjection()
|
313 |
+
self.activate = FusedLeakyReLU(out_channel)
|
314 |
+
|
315 |
+
def forward(self, input, style, noise=None):
|
316 |
+
out = self.conv(input, style)
|
317 |
+
out = self.noise(out, noise=noise)
|
318 |
+
out = self.activate(out)
|
319 |
+
return out
|
320 |
+
|
321 |
+
|
322 |
+
class ToRGB(nn.Module):
|
323 |
+
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
324 |
+
super().__init__()
|
325 |
+
if upsample:
|
326 |
+
self.upsample = Upsample(blur_kernel)
|
327 |
+
|
328 |
+
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
329 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
330 |
+
|
331 |
+
def forward(self, input, style, skip=None):
|
332 |
+
out = self.conv(input, style)
|
333 |
+
out = out + self.bias
|
334 |
+
|
335 |
+
if skip is not None:
|
336 |
+
skip = self.upsample(skip)
|
337 |
+
out = out + skip
|
338 |
+
|
339 |
+
return out
|
340 |
+
|
341 |
+
|
342 |
+
class Generator(nn.Module):
|
343 |
+
def __init__(
|
344 |
+
self,
|
345 |
+
size,
|
346 |
+
style_dim,
|
347 |
+
n_mlp,
|
348 |
+
channel_multiplier=1,
|
349 |
+
blur_kernel=[1, 3, 3, 1],
|
350 |
+
lr_mlp=0.01,
|
351 |
+
small=False,
|
352 |
+
small_isaac=False,
|
353 |
+
):
|
354 |
+
super().__init__()
|
355 |
+
|
356 |
+
self.size = size
|
357 |
+
|
358 |
+
if small and size > 64:
|
359 |
+
raise ValueError("small only works for sizes <= 64")
|
360 |
+
|
361 |
+
self.style_dim = style_dim
|
362 |
+
layers = [PixelNorm()]
|
363 |
+
|
364 |
+
for i in range(n_mlp):
|
365 |
+
layers.append(
|
366 |
+
EqualLinear(
|
367 |
+
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
|
368 |
+
)
|
369 |
+
)
|
370 |
+
|
371 |
+
self.style = nn.Sequential(*layers)
|
372 |
+
|
373 |
+
if small:
|
374 |
+
self.channels = {
|
375 |
+
4: 64 * channel_multiplier,
|
376 |
+
8: 64 * channel_multiplier,
|
377 |
+
16: 64 * channel_multiplier,
|
378 |
+
32: 64 * channel_multiplier,
|
379 |
+
64: 64 * channel_multiplier,
|
380 |
+
}
|
381 |
+
elif small_isaac:
|
382 |
+
self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128}
|
383 |
+
else:
|
384 |
+
self.channels = {
|
385 |
+
4: 512,
|
386 |
+
8: 512,
|
387 |
+
16: 512,
|
388 |
+
32: 512,
|
389 |
+
64: 256 * channel_multiplier,
|
390 |
+
128: 128 * channel_multiplier,
|
391 |
+
256: 64 * channel_multiplier,
|
392 |
+
512: 32 * channel_multiplier,
|
393 |
+
1024: 16 * channel_multiplier,
|
394 |
+
}
|
395 |
+
|
396 |
+
self.input = ConstantInput(self.channels[4])
|
397 |
+
self.conv1 = StyledConv(
|
398 |
+
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
399 |
+
)
|
400 |
+
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
401 |
+
|
402 |
+
self.log_size = int(math.log(size, 2))
|
403 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
404 |
+
|
405 |
+
self.convs = nn.ModuleList()
|
406 |
+
self.upsamples = nn.ModuleList()
|
407 |
+
self.to_rgbs = nn.ModuleList()
|
408 |
+
self.noises = nn.Module()
|
409 |
+
|
410 |
+
in_channel = self.channels[4]
|
411 |
+
|
412 |
+
for layer_idx in range(self.num_layers):
|
413 |
+
res = (layer_idx + 5) // 2
|
414 |
+
shape = [1, 1, 2 ** res, 2 ** res // 2]
|
415 |
+
self.noises.register_buffer(
|
416 |
+
"noise_{}".format(layer_idx), torch.randn(*shape)
|
417 |
+
)
|
418 |
+
|
419 |
+
for i in range(3, self.log_size + 1):
|
420 |
+
out_channel = self.channels[2 ** i]
|
421 |
+
|
422 |
+
self.convs.append(
|
423 |
+
StyledConv(
|
424 |
+
in_channel,
|
425 |
+
out_channel,
|
426 |
+
3,
|
427 |
+
style_dim,
|
428 |
+
upsample=True,
|
429 |
+
blur_kernel=blur_kernel,
|
430 |
+
)
|
431 |
+
)
|
432 |
+
|
433 |
+
self.convs.append(
|
434 |
+
StyledConv(
|
435 |
+
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
436 |
+
)
|
437 |
+
)
|
438 |
+
|
439 |
+
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
440 |
+
in_channel = out_channel
|
441 |
+
|
442 |
+
self.n_latent = self.log_size * 2 - 2
|
443 |
+
|
444 |
+
def make_noise(self):
|
445 |
+
device = self.input.input.device
|
446 |
+
|
447 |
+
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2 // 2, device=device)]
|
448 |
+
|
449 |
+
for i in range(3, self.log_size + 1):
|
450 |
+
for _ in range(2):
|
451 |
+
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i // 2, device=device))
|
452 |
+
|
453 |
+
return noises
|
454 |
+
|
455 |
+
def mean_latent(self, n_latent):
|
456 |
+
latent_in = torch.randn(
|
457 |
+
n_latent, self.style_dim, device=self.input.input.device
|
458 |
+
)
|
459 |
+
latent = self.style(latent_in).mean(0, keepdim=True)
|
460 |
+
|
461 |
+
return latent
|
462 |
+
|
463 |
+
def get_latent(self, input):
|
464 |
+
return self.style(input)
|
465 |
+
|
466 |
+
def forward(
|
467 |
+
self,
|
468 |
+
styles,
|
469 |
+
return_latents=False,
|
470 |
+
return_features=False,
|
471 |
+
inject_index=None,
|
472 |
+
truncation=1,
|
473 |
+
truncation_latent=None,
|
474 |
+
input_is_latent=False,
|
475 |
+
noise=None,
|
476 |
+
randomize_noise=True,
|
477 |
+
real=False,
|
478 |
+
):
|
479 |
+
if not input_is_latent:
|
480 |
+
styles = [self.style(s) for s in styles]
|
481 |
+
if noise is None:
|
482 |
+
if randomize_noise:
|
483 |
+
noise = [None] * self.num_layers
|
484 |
+
else:
|
485 |
+
noise = [
|
486 |
+
getattr(self.noises, "noise_{}".format(i))
|
487 |
+
for i in range(self.num_layers)
|
488 |
+
]
|
489 |
+
|
490 |
+
if truncation < 1:
|
491 |
+
# print('truncation_latent: ', truncation_latent.shape)
|
492 |
+
if not real: #if type(styles) == list:
|
493 |
+
style_t = []
|
494 |
+
for style in styles:
|
495 |
+
style_t.append(
|
496 |
+
truncation_latent + truncation * (style - truncation_latent)
|
497 |
+
) # (-1.1162e-03-(-1.0914e-01))*0.8+(-1.0914e-01)
|
498 |
+
styles = style_t
|
499 |
+
else: # styles are latent (tensor: 1,18,512), for real PTI output
|
500 |
+
truncation_latent = truncation_latent.repeat(18,1).unsqueeze(0) # (1,512) --> (1,18,512)
|
501 |
+
styles = torch.add(truncation_latent,torch.mul(torch.sub(styles,truncation_latent),truncation))
|
502 |
+
# print('now styles after truncation : ', styles)
|
503 |
+
#if type(styles) == list and len(styles) < 2: # this if for input as list of [(1,512)]
|
504 |
+
if not real:
|
505 |
+
if len(styles) < 2:
|
506 |
+
inject_index = self.n_latent
|
507 |
+
if styles[0].ndim < 3:
|
508 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
509 |
+
else:
|
510 |
+
latent = styles[0]
|
511 |
+
elif type(styles) == list:
|
512 |
+
if inject_index is None:
|
513 |
+
inject_index = 4
|
514 |
+
|
515 |
+
latent = styles[0].unsqueeze(0)
|
516 |
+
if latent.shape[1] == 1:
|
517 |
+
latent = latent.repeat(1, inject_index, 1)
|
518 |
+
else:
|
519 |
+
latent = latent[:, :inject_index, :]
|
520 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
521 |
+
latent = torch.cat([latent, latent2], 1)
|
522 |
+
else: # input is tensor of size with torch.Size([1, 18, 512]), for real PTI output
|
523 |
+
latent = styles
|
524 |
+
|
525 |
+
# print(f'processed latent: {latent.shape}')
|
526 |
+
|
527 |
+
features = {}
|
528 |
+
out = self.input(latent)
|
529 |
+
features["out_0"] = out
|
530 |
+
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
531 |
+
features["conv1_0"] = out
|
532 |
+
|
533 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
534 |
+
features["skip_0"] = skip
|
535 |
+
i = 1
|
536 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
537 |
+
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
538 |
+
):
|
539 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
540 |
+
features["conv1_{}".format(i)] = out
|
541 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
542 |
+
features["conv2_{}".format(i)] = out
|
543 |
+
skip = to_rgb(out, latent[:, i + 2], skip)
|
544 |
+
features["skip_{}".format(i)] = skip
|
545 |
+
|
546 |
+
i += 2
|
547 |
+
|
548 |
+
image = skip
|
549 |
+
|
550 |
+
if return_latents:
|
551 |
+
return image, latent
|
552 |
+
elif return_features:
|
553 |
+
return image, features
|
554 |
+
else:
|
555 |
+
return image, None
|
556 |
+
|
557 |
+
|
558 |
+
class ConvLayer(nn.Sequential):
|
559 |
+
def __init__(
|
560 |
+
self,
|
561 |
+
in_channel,
|
562 |
+
out_channel,
|
563 |
+
kernel_size,
|
564 |
+
downsample=False,
|
565 |
+
blur_kernel=[1, 3, 3, 1],
|
566 |
+
bias=True,
|
567 |
+
activate=True,
|
568 |
+
):
|
569 |
+
layers = []
|
570 |
+
|
571 |
+
if downsample:
|
572 |
+
factor = 2
|
573 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
574 |
+
pad0 = (p + 1) // 2
|
575 |
+
pad1 = p // 2
|
576 |
+
|
577 |
+
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
578 |
+
|
579 |
+
stride = 2
|
580 |
+
self.padding = 0
|
581 |
+
|
582 |
+
else:
|
583 |
+
stride = 1
|
584 |
+
self.padding = kernel_size // 2
|
585 |
+
|
586 |
+
layers.append(
|
587 |
+
EqualConv2d(
|
588 |
+
in_channel,
|
589 |
+
out_channel,
|
590 |
+
kernel_size,
|
591 |
+
padding=self.padding,
|
592 |
+
stride=stride,
|
593 |
+
bias=bias and not activate,
|
594 |
+
)
|
595 |
+
)
|
596 |
+
|
597 |
+
if activate:
|
598 |
+
if bias:
|
599 |
+
layers.append(FusedLeakyReLU(out_channel))
|
600 |
+
else:
|
601 |
+
layers.append(ScaledLeakyReLU(0.2))
|
602 |
+
|
603 |
+
super().__init__(*layers)
|
604 |
+
|
605 |
+
|
606 |
+
class ResBlock(nn.Module):
|
607 |
+
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
608 |
+
super().__init__()
|
609 |
+
|
610 |
+
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
611 |
+
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
612 |
+
|
613 |
+
self.skip = ConvLayer(
|
614 |
+
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
|
615 |
+
)
|
616 |
+
|
617 |
+
def forward(self, input):
|
618 |
+
out = self.conv1(input)
|
619 |
+
out = self.conv2(out)
|
620 |
+
|
621 |
+
skip = self.skip(input)
|
622 |
+
out = (out + skip) / math.sqrt(2)
|
623 |
+
|
624 |
+
return out
|
625 |
+
|
626 |
+
|
627 |
+
class StyleDiscriminator(nn.Module):
|
628 |
+
def __init__(
|
629 |
+
self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False
|
630 |
+
):
|
631 |
+
super().__init__()
|
632 |
+
|
633 |
+
if small:
|
634 |
+
channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64}
|
635 |
+
|
636 |
+
else:
|
637 |
+
channels = {
|
638 |
+
4: 512,
|
639 |
+
8: 512,
|
640 |
+
16: 512,
|
641 |
+
32: 512,
|
642 |
+
64: 256 * channel_multiplier,
|
643 |
+
128: 128 * channel_multiplier,
|
644 |
+
256: 64 * channel_multiplier,
|
645 |
+
512: 32 * channel_multiplier,
|
646 |
+
1024: 16 * channel_multiplier,
|
647 |
+
}
|
648 |
+
|
649 |
+
convs = [ConvLayer(3, channels[size], 1)]
|
650 |
+
|
651 |
+
log_size = int(math.log(size, 2))
|
652 |
+
in_channel = channels[size]
|
653 |
+
|
654 |
+
for i in range(log_size, 2, -1):
|
655 |
+
out_channel = channels[2 ** (i - 1)]
|
656 |
+
|
657 |
+
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
658 |
+
|
659 |
+
in_channel = out_channel
|
660 |
+
|
661 |
+
self.convs = nn.Sequential(*convs)
|
662 |
+
|
663 |
+
self.stddev_group = 4
|
664 |
+
self.stddev_feat = 1
|
665 |
+
|
666 |
+
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
|
667 |
+
self.final_linear = nn.Sequential(
|
668 |
+
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
|
669 |
+
EqualLinear(channels[4], 1),
|
670 |
+
)
|
671 |
+
|
672 |
+
|
673 |
+
def forward(self, input):
|
674 |
+
h = input
|
675 |
+
h_list = []
|
676 |
+
|
677 |
+
for index, blocklist in enumerate(self.convs):
|
678 |
+
h = blocklist(h)
|
679 |
+
h_list.append(h)
|
680 |
+
|
681 |
+
out = h
|
682 |
+
batch, channel, height, width = out.shape
|
683 |
+
group = min(batch, self.stddev_group)
|
684 |
+
stddev = out.view(
|
685 |
+
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
686 |
+
)
|
687 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
688 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
689 |
+
stddev = stddev.repeat(group, 1, height, width)
|
690 |
+
out = torch.cat([out, stddev], 1)
|
691 |
+
|
692 |
+
out = self.final_conv(out)
|
693 |
+
h_list.append(out)
|
694 |
+
|
695 |
+
out = out.view(batch, -1)
|
696 |
+
out = self.final_linear(out)
|
697 |
+
|
698 |
+
return out, h_list
|
699 |
+
|
700 |
+
|
701 |
+
class StyleEncoder(nn.Module):
|
702 |
+
def __init__(self, size, w_dim=512):
|
703 |
+
super().__init__()
|
704 |
+
|
705 |
+
channels = {
|
706 |
+
4: 512,
|
707 |
+
8: 512,
|
708 |
+
16: 512,
|
709 |
+
32: 512,
|
710 |
+
64: 256,
|
711 |
+
128: 128,
|
712 |
+
256: 64,
|
713 |
+
512: 32,
|
714 |
+
1024: 16
|
715 |
+
}
|
716 |
+
|
717 |
+
self.w_dim = w_dim
|
718 |
+
log_size = int(math.log(size, 2))
|
719 |
+
convs = [ConvLayer(3, channels[size], 1)]
|
720 |
+
|
721 |
+
in_channel = channels[size]
|
722 |
+
for i in range(log_size, 2, -1):
|
723 |
+
out_channel = channels[2 ** (i - 1)]
|
724 |
+
convs.append(ResBlock(in_channel, out_channel))
|
725 |
+
in_channel = out_channel
|
726 |
+
|
727 |
+
convs.append(EqualConv2d(in_channel,2*self.w_dim, 4, padding=0, bias=False))
|
728 |
+
|
729 |
+
self.convs = nn.Sequential(*convs)
|
730 |
+
|
731 |
+
def forward(self, input):
|
732 |
+
out = self.convs(input)
|
733 |
+
# return out.view(len(input), self.n_latents, self.w_dim)
|
734 |
+
reshaped = out.view(len(input), 2*self.w_dim)
|
735 |
+
return reshaped[:,:self.w_dim], reshaped[:,self.w_dim:]
|
736 |
+
|
737 |
+
def kaiming_init(m):
|
738 |
+
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
739 |
+
init.kaiming_normal_(m.weight)
|
740 |
+
if m.bias is not None:
|
741 |
+
m.bias.data.fill_(0)
|
742 |
+
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
|
743 |
+
m.weight.data.fill_(1)
|
744 |
+
if m.bias is not None:
|
745 |
+
m.bias.data.fill_(0)
|
746 |
+
|
747 |
+
|
748 |
+
def normal_init(m):
|
749 |
+
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
750 |
+
init.normal_(m.weight, 0, 0.02)
|
751 |
+
if m.bias is not None:
|
752 |
+
m.bias.data.fill_(0)
|
753 |
+
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
|
754 |
+
m.weight.data.fill_(1)
|
755 |
+
if m.bias is not None:
|
756 |
+
m.bias.data.fill_(0)
|