jhaozhuang commited on
Commit
77771e4
·
1 Parent(s): e826d2f
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. BidirectionalTranslation/LICENSE +26 -0
  2. BidirectionalTranslation/README.md +72 -0
  3. BidirectionalTranslation/data/__init__.py +100 -0
  4. BidirectionalTranslation/data/aligned_dataset.py +60 -0
  5. BidirectionalTranslation/data/base_dataset.py +164 -0
  6. BidirectionalTranslation/data/image_folder.py +66 -0
  7. BidirectionalTranslation/data/singleCo_dataset.py +85 -0
  8. BidirectionalTranslation/data/singleSr_dataset.py +73 -0
  9. BidirectionalTranslation/models/__init__.py +61 -0
  10. BidirectionalTranslation/models/base_model.py +277 -0
  11. BidirectionalTranslation/models/cycle_ganstft_model.py +103 -0
  12. BidirectionalTranslation/models/networks.py +1375 -0
  13. BidirectionalTranslation/options/base_options.py +142 -0
  14. BidirectionalTranslation/options/test_options.py +19 -0
  15. BidirectionalTranslation/requirements.txt +8 -0
  16. BidirectionalTranslation/scripts/test_western2manga.sh +49 -0
  17. BidirectionalTranslation/test.py +71 -0
  18. BidirectionalTranslation/util/html.py +86 -0
  19. BidirectionalTranslation/util/util.py +136 -0
  20. BidirectionalTranslation/util/visualizer.py +221 -0
  21. app.py +507 -0
  22. assets/example_0/input.jpg +0 -0
  23. assets/example_0/ref1.jpg +0 -0
  24. assets/example_1/input.jpg +0 -0
  25. assets/example_1/ref1.jpg +0 -0
  26. assets/example_1/ref2.jpg +0 -0
  27. assets/example_1/ref3.jpg +0 -0
  28. assets/example_2/input.png +0 -0
  29. assets/example_2/ref1.png +0 -0
  30. assets/example_2/ref2.png +0 -0
  31. assets/example_2/ref3.png +0 -0
  32. assets/example_3/input.png +0 -0
  33. assets/example_3/ref1.png +0 -0
  34. assets/example_3/ref2.png +0 -0
  35. assets/example_3/ref3.png +0 -0
  36. assets/example_4/input.jpg +0 -0
  37. assets/example_4/ref1.jpg +0 -0
  38. assets/example_4/ref2.jpg +0 -0
  39. assets/example_4/ref3.jpg +0 -0
  40. assets/example_5/input.png +0 -0
  41. assets/example_5/ref1.png +0 -0
  42. assets/example_5/ref2.png +0 -0
  43. assets/example_5/ref3.png +0 -0
  44. assets/mask.png +0 -0
  45. diffusers/.github/ISSUE_TEMPLATE/bug-report.yml +110 -0
  46. diffusers/.github/ISSUE_TEMPLATE/config.yml +4 -0
  47. diffusers/.github/ISSUE_TEMPLATE/feature_request.md +20 -0
  48. diffusers/.github/ISSUE_TEMPLATE/feedback.md +12 -0
  49. diffusers/.github/ISSUE_TEMPLATE/new-model-addition.yml +31 -0
  50. diffusers/.github/ISSUE_TEMPLATE/translate.md +29 -0
BidirectionalTranslation/LICENSE ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Manga Filling Style Conversion with Screentone Variational Autoencoder
2
+
3
+ Copyright (c) 2020 The Chinese University of Hong Kong
4
+
5
+ Copyright and License Information: The source code, the binary executable, and all data files (hereafter, Software) are copyrighted by The Chinese University of Hong Kong and Tien-Tsin Wong (hereafter, Author), Copyright (c) 2021 The Chinese University of Hong Kong. All Rights Reserved.
6
+
7
+ The Author grants to you ("Licensee") a non-exclusive license to use the Software for academic, research and commercial purposes, without fee. For commercial use, Licensee should submit a WRITTEN NOTICE to the Author. The notice should clearly identify the software package/system/hardware (name, version, and/or model number) using the Software. Licensee may distribute the Software to third parties provided that the copyright notice and this statement appears on all copies. Licensee agrees that the copyright notice and this statement will appear on all copies of the Software, or portions thereof. The Author retains exclusive ownership of the Software.
8
+
9
+ Licensee may make derivatives of the Software, provided that such derivatives can only be used for the purposes specified in the license grant above.
10
+
11
+ THE AUTHOR MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. THE AUTHOR SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE OR ITS DERIVATIVES.
12
+
13
+ By using the source code, Licensee agrees to cite the following papers in
14
+ Licensee's publication/work:
15
+
16
+ Minshan Xie, Chengze Li, Xueting Liu and Tien-Tsin Wong
17
+ "Manga Filling Style Conversion with Screentone Variational Autoencoder"
18
+ ACM Transactions on Graphics (SIGGRAPH Asia 2020 issue), Vol. 39, No. 6, December 2020, pp. 226:1-226:15.
19
+
20
+
21
+ By using or copying the Software, Licensee agrees to abide by the intellectual property laws, and all other applicable laws of the U.S., and the terms of this license.
22
+
23
+ Author shall have the right to terminate this license immediately by written notice upon Licensee's breach of, or non-compliance with, any of its terms.
24
+ Licensee may be held legally responsible for any infringement that is caused or encouraged by Licensee's failure to abide by the terms of this license.
25
+
26
+ For more information or comments, send mail to: [email protected]
BidirectionalTranslation/README.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Bidirectional Translation
2
+
3
+ Pytorch implementation for multimodal comic-to-manga translation.
4
+
5
+ **Note**: The current software works well with PyTorch 1.6.0+.
6
+
7
+ ## Prerequisites
8
+ - Linux
9
+ - Python 3
10
+ - CPU or NVIDIA GPU + CUDA CuDNN
11
+
12
+ ## Getting Started ###
13
+ ### Installation
14
+ - Clone this repo:
15
+ ```bash
16
+ git clone https://github.com/msxie/ScreenStyle.git
17
+ cd ScreenStyle/MangaScreening
18
+ ```
19
+ - Install PyTorch and dependencies from http://pytorch.org
20
+ - Install python libraries [tensorboardX](https://github.com/lanpa/tensorboardX)
21
+ - Install other libraries
22
+ For pip users:
23
+ ```
24
+ pip install -r requirements.txt
25
+ ```
26
+
27
+ ## Data praperation
28
+ The training requires paired data (including manga image, western image and their line drawings).
29
+ The line drawing can be extracted using [MangaLineExtraction](https://github.com/ljsabc/MangaLineExtraction).
30
+
31
+ ```
32
+ ${DATASET}
33
+ |-- color2manga
34
+ | |-- val
35
+ | | |-- ${FOLDER}
36
+ | | | |-- imgs
37
+ | | | | |-- 0001.png
38
+ | | | | |-- ...
39
+ | | | |-- line
40
+ | | | | |-- 0001.png
41
+ | | | | |-- ...
42
+ ```
43
+
44
+ ### Use a Pre-trained Model
45
+ - Download the pre-trained [ScreenVAE](https://drive.google.com/file/d/1OBxWHjijMwi9gfTOfDiFiHRZA_CXNSWr/view?usp=sharing) model and place under `checkpoints/ScreenVAE/` folder.
46
+
47
+ - Download the pre-trained [color2manga](https://drive.google.com/file/d/18-N1W0t3igWLJWFyplNZ5Fa2YHWASCZY/view?usp=sharing) model and place under `checkpoints/color2manga/` folder.
48
+ - Generate results with the model
49
+ ```bash
50
+ bash ./scripts/test_western2manga.sh
51
+ ```
52
+
53
+ ## Copyright and License
54
+ You are granted with the [LICENSE](LICENSE) for both academic and commercial usages.
55
+
56
+ ## Citation
57
+ If you find the code helpful in your resarch or work, please cite the following papers.
58
+ ```
59
+ @article{xie-2020-manga,
60
+ author = {Minshan Xie and Chengze Li and Xueting Liu and Tien-Tsin Wong},
61
+ title = {Manga Filling Style Conversion with Screentone Variational Autoencoder},
62
+ journal = {ACM Transactions on Graphics (SIGGRAPH Asia 2020 issue)},
63
+ month = {December},
64
+ year = {2020},
65
+ volume = {39},
66
+ number = {6},
67
+ pages = {226:1--226:15}
68
+ }
69
+ ```
70
+
71
+ ### Acknowledgements
72
+ This code borrows heavily from the [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repository.
BidirectionalTranslation/data/__init__.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This package includes all the modules related to data loading and preprocessing
2
+
3
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
+ You need to implement four functions:
5
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
+ -- <__len__>: return the size of dataset.
7
+ -- <__getitem__>: get a data point from data loader.
8
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
+
10
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
+ See our template dataset class 'template_dataset.py' for more details.
12
+ """
13
+ import importlib
14
+ import torch.utils.data
15
+ from data.base_dataset import BaseDataset
16
+
17
+
18
+ def find_dataset_using_name(dataset_name):
19
+ """Import the module "data/[dataset_name]_dataset.py".
20
+
21
+ In the file, the class called DatasetNameDataset() will
22
+ be instantiated. It has to be a subclass of BaseDataset,
23
+ and it is case-insensitive.
24
+ """
25
+ dataset_filename = "data." + dataset_name + "_dataset"
26
+ datasetlib = importlib.import_module(dataset_filename)
27
+
28
+ dataset = None
29
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30
+ for name, cls in datasetlib.__dict__.items():
31
+ if name.lower() == target_dataset_name.lower() \
32
+ and issubclass(cls, BaseDataset):
33
+ dataset = cls
34
+
35
+ if dataset is None:
36
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37
+
38
+ return dataset
39
+
40
+
41
+ def get_option_setter(dataset_name):
42
+ """Return the static method <modify_commandline_options> of the dataset class."""
43
+ dataset_class = find_dataset_using_name(dataset_name)
44
+ return dataset_class.modify_commandline_options
45
+
46
+
47
+ def create_dataset(opt):
48
+ """Create a dataset given the option.
49
+
50
+ This function wraps the class CustomDatasetDataLoader.
51
+ This is the main interface between this package and 'train.py'/'test.py'
52
+
53
+ Example:
54
+ >>> from data import create_dataset
55
+ >>> dataset = create_dataset(opt)
56
+ """
57
+ data_loader = CustomDatasetDataLoader(opt)
58
+ dataset = data_loader.load_data()
59
+ return dataset
60
+
61
+
62
+ class CustomDatasetDataLoader():
63
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
64
+
65
+ def __init__(self, opt):
66
+ """Initialize this class
67
+
68
+ Step 1: create a dataset instance given the name [dataset_mode]
69
+ Step 2: create a multi-threaded data loader.
70
+ """
71
+ self.opt = opt
72
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
73
+ self.dataset = dataset_class(opt)
74
+ print("dataset [%s] was created" % type(self.dataset).__name__)
75
+
76
+ train_sampler = None
77
+ if len(opt.gpu_ids) > 1:
78
+ train_sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
79
+
80
+ self.dataloader = torch.utils.data.DataLoader(
81
+ self.dataset,
82
+ batch_size=opt.batch_size,
83
+ #shuffle=not opt.serial_batches,
84
+ num_workers=int(opt.num_threads),
85
+ pin_memory=True, sampler=train_sampler
86
+ )
87
+
88
+ def load_data(self):
89
+ return self
90
+
91
+ def __len__(self):
92
+ """Return the number of data in the dataset"""
93
+ return min(len(self.dataset), self.opt.max_dataset_size)
94
+
95
+ def __iter__(self):
96
+ """Return a batch of data"""
97
+ for i, data in enumerate(self.dataloader):
98
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
99
+ break
100
+ yield data
BidirectionalTranslation/data/aligned_dataset.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from data.base_dataset import BaseDataset, get_params, get_transform
3
+ from data.image_folder import make_dataset
4
+ from PIL import Image
5
+
6
+
7
+ class AlignedDataset(BaseDataset):
8
+ """A dataset class for paired image dataset.
9
+
10
+ It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
11
+ During test time, you need to prepare a directory '/path/to/data/test'.
12
+ """
13
+
14
+ def __init__(self, opt):
15
+ """Initialize this dataset class.
16
+
17
+ Parameters:
18
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
19
+ """
20
+ BaseDataset.__init__(self, opt)
21
+ self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory
22
+ self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths
23
+ assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be smaller than the size of loaded image
24
+ self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
25
+ self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
26
+
27
+ def __getitem__(self, index):
28
+ """Return a data point and its metadata information.
29
+
30
+ Parameters:
31
+ index - - a random integer for data indexing
32
+
33
+ Returns a dictionary that contains A, B, A_paths and B_paths
34
+ A (tensor) - - an image in the input domain
35
+ B (tensor) - - its corresponding image in the target domain
36
+ A_paths (str) - - image paths
37
+ B_paths (str) - - image paths (same as A_paths)
38
+ """
39
+ # read a image given a random integer index
40
+ AB_path = self.AB_paths[index%len(self.AB_paths)]
41
+ AB = Image.open(AB_path).convert('RGB')
42
+ # split AB image into A and B
43
+ w, h = AB.size
44
+ w2 = int(w / 2)
45
+ A = AB.crop((0, 0, w2, h))
46
+ B = AB.crop((w2, 0, w, h))
47
+
48
+ # apply the same transform to both A and B
49
+ transform_params = get_params(self.opt, A.size)
50
+ A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
51
+ B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))
52
+
53
+ A = A_transform(A)
54
+ B = B_transform(B)
55
+
56
+ return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
57
+
58
+ def __len__(self):
59
+ """Return the total number of images in the dataset."""
60
+ return len(self.AB_paths)*100
BidirectionalTranslation/data/base_dataset.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
+
3
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
+ """
5
+ import random
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image, ImageOps
9
+ import torchvision.transforms as transforms
10
+ from abc import ABC, abstractmethod
11
+
12
+
13
+ class BaseDataset(data.Dataset, ABC):
14
+ """This class is an abstract base class (ABC) for datasets.
15
+
16
+ To create a subclass, you need to implement the following four functions:
17
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
+ -- <__len__>: return the size of dataset.
19
+ -- <__getitem__>: get a data point.
20
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ """Initialize the class; save the options in the class
25
+
26
+ Parameters:
27
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
+ """
29
+ self.opt = opt
30
+ self.root = opt.dataroot
31
+
32
+ @staticmethod
33
+ def modify_commandline_options(parser, is_train):
34
+ """Add new dataset-specific options, and rewrite default values for existing options.
35
+
36
+ Parameters:
37
+ parser -- original option parser
38
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
39
+
40
+ Returns:
41
+ the modified parser.
42
+ """
43
+ return parser
44
+
45
+ @abstractmethod
46
+ def __len__(self):
47
+ """Return the total number of images in the dataset."""
48
+ return 0
49
+
50
+ @abstractmethod
51
+ def __getitem__(self, index):
52
+ """Return a data point and its metadata information.
53
+
54
+ Parameters:
55
+ index - - a random integer for data indexing
56
+
57
+ Returns:
58
+ a dictionary of data with their names. It ususally contains the data itself and its metadata information.
59
+ """
60
+ pass
61
+
62
+
63
+ def get_params(opt, size):
64
+ w, h = size
65
+ new_h = h
66
+ new_w = w
67
+ crop = 0
68
+ if opt.preprocess == 'resize_and_crop':
69
+ new_h = new_w = opt.load_size
70
+ elif opt.preprocess == 'scale_width_and_crop':
71
+ new_w = opt.load_size
72
+ new_h = opt.load_size * h // w
73
+
74
+ # x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
75
+ # y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
76
+
77
+ x = random.randint(crop, np.maximum(0, new_w - opt.crop_size-crop))
78
+ y = random.randint(crop, np.maximum(0, new_h - opt.crop_size-crop))
79
+
80
+ flip = random.random() > 0.5
81
+
82
+ return {'crop_pos': (x, y), 'flip': flip}
83
+
84
+
85
+ def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
86
+ transform_list = []
87
+ if grayscale:
88
+ transform_list.append(transforms.Grayscale(1))
89
+ if 'resize' in opt.preprocess:
90
+ osize = [opt.load_size, opt.load_size]
91
+ transform_list.append(transforms.Resize(osize, method))
92
+ elif 'scale_width' in opt.preprocess:
93
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
94
+
95
+ if 'crop' in opt.preprocess:
96
+ if params is None:
97
+ # transform_list.append(transforms.RandomCrop(opt.crop_size))
98
+ transform_list.append(transforms.CenterCrop(opt.crop_size))
99
+ else:
100
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
101
+
102
+ if opt.preprocess == 'none':
103
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=2**8, method=method)))
104
+
105
+ if not opt.no_flip:
106
+ if params is None:
107
+ transform_list.append(transforms.RandomHorizontalFlip())
108
+ elif params['flip']:
109
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
110
+
111
+ # transform_list += [transforms.ToTensor()]
112
+ if convert:
113
+ transform_list += [transforms.ToTensor()]
114
+ if grayscale:
115
+ transform_list += [transforms.Normalize((0.5,), (0.5,))]
116
+ else:
117
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
118
+ return transforms.Compose(transform_list)
119
+
120
+
121
+ def __make_power_2(img, base, method=Image.BICUBIC):
122
+ ow, oh = img.size
123
+ h = int((oh+base-1) // base * base)
124
+ w = int((ow+base-1) // base * base)
125
+ if (h == oh) and (w == ow):
126
+ return img
127
+
128
+ __print_size_warning(ow, oh, w, h)
129
+ return ImageOps.expand(img, (0, 0, w-ow, h-oh), fill=255)
130
+
131
+
132
+ def __scale_width(img, target_width, method=Image.BICUBIC):
133
+ ow, oh = img.size
134
+ if (ow == target_width):
135
+ return img
136
+ w = target_width
137
+ h = int(target_width * oh / ow)
138
+ return img.resize((w, h), method)
139
+
140
+
141
+ def __crop(img, pos, size):
142
+ ow, oh = img.size
143
+ x1, y1 = pos
144
+ tw = th = size
145
+ if (ow > tw or oh > th):
146
+ return img.crop((x1, y1, x1 + tw, y1 + th))
147
+ return img
148
+
149
+
150
+ def __flip(img, flip):
151
+ if flip:
152
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
153
+ return img
154
+
155
+
156
+ def __print_size_warning(ow, oh, w, h):
157
+ """Print warning information about image size(only print once)"""
158
+ if not hasattr(__print_size_warning, 'has_printed'):
159
+ print("The image size needs to be a multiple of 4. "
160
+ "The loaded image size was (%d, %d), so it was adjusted to "
161
+ "(%d, %d). This adjustment will be done to all images "
162
+ "whose sizes are not multiples of 4" % (ow, oh, w, h))
163
+ __print_size_warning.has_printed = True
164
+
BidirectionalTranslation/data/image_folder.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both current directory and its subdirectories.
5
+ """
6
+
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+ import os.path
12
+
13
+ IMG_EXTENSIONS = [
14
+ '.jpg', '.JPG', '.jpeg', '.JPEG', '.npz', 'npy',
15
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
+ ]
17
+
18
+
19
+ def is_image_file(filename):
20
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21
+
22
+
23
+ def make_dataset(dir, max_dataset_size=float("inf")):
24
+ images = []
25
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
26
+
27
+ for root, _, fnames in sorted(os.walk(dir)):
28
+ for fname in fnames:
29
+ if is_image_file(fname):
30
+ path = os.path.join(root, fname)
31
+ images.append(path)
32
+ return images[:min(max_dataset_size, len(images))]
33
+
34
+
35
+ def default_loader(path):
36
+ return Image.open(path).convert('RGB')
37
+
38
+
39
+ class ImageFolder(data.Dataset):
40
+
41
+ def __init__(self, root, transform=None, return_paths=False,
42
+ loader=default_loader):
43
+ imgs = make_dataset(root)
44
+ if len(imgs) == 0:
45
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
46
+ "Supported image extensions are: " +
47
+ ",".join(IMG_EXTENSIONS)))
48
+
49
+ self.root = root
50
+ self.imgs = imgs
51
+ self.transform = transform
52
+ self.return_paths = return_paths
53
+ self.loader = loader
54
+
55
+ def __getitem__(self, index):
56
+ path = self.imgs[index]
57
+ img = self.loader(path)
58
+ if self.transform is not None:
59
+ img = self.transform(img)
60
+ if self.return_paths:
61
+ return img, path
62
+ else:
63
+ return img
64
+
65
+ def __len__(self):
66
+ return len(self.imgs)
BidirectionalTranslation/data/singleCo_dataset.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from data.base_dataset import BaseDataset, get_params, get_transform
3
+ from data.image_folder import make_dataset
4
+ from PIL import Image, ImageEnhance
5
+ import random
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import cv2
10
+
11
+
12
+ class SingleCoDataset(BaseDataset):
13
+ @staticmethod
14
+ def modify_commandline_options(parser, is_train):
15
+ return parser
16
+
17
+ def __init__(self, opt):
18
+ self.opt = opt
19
+ self.root = opt.dataroot
20
+ self.dir_A = os.path.join(opt.dataroot, opt.phase, opt.folder, 'imgs')
21
+
22
+ self.A_paths = make_dataset(self.dir_A)
23
+
24
+ self.A_paths = sorted(self.A_paths)
25
+
26
+ self.A_size = len(self.A_paths)
27
+ # self.transform = get_transform(opt)
28
+
29
+ def __getitem__(self, index):
30
+ A_path = self.A_paths[index]
31
+
32
+ A_img = Image.open(A_path).convert('RGB')
33
+ # enhancer = ImageEnhance.Brightness(A_img)
34
+ # A_img = enhancer.enhance(1.5)
35
+ if os.path.exists(A_path.replace('imgs','line')[:-4]+'.jpg'):
36
+ # L_img = Image.open(A_path.replace('imgs','line')[:-4]+'.png')
37
+ L_img = cv2.imread(A_path.replace('imgs','line')[:-4]+'.jpg')
38
+ kernel = np.ones((3,3), np.uint8)
39
+ L_img = cv2.erode(L_img, kernel, iterations=1)
40
+ L_img = Image.fromarray(L_img)
41
+ else:
42
+ L_img = A_img
43
+ if A_img.size!=L_img.size:
44
+ # L_img = L_img.resize(A_img.size, Image.ANTIALIAS)
45
+ A_img = A_img.resize(L_img.size, Image.ANTIALIAS)
46
+ if A_img.size[1]>2500:
47
+ A_img = A_img.resize((A_img.size[0]//2, A_img.size[1]//2), Image.ANTIALIAS)
48
+
49
+ ow, oh = A_img.size
50
+ transform_params = get_params(self.opt, A_img.size)
51
+ A_transform = get_transform(self.opt, transform_params, grayscale=False)
52
+ L_transform = get_transform(self.opt, transform_params, grayscale=True)
53
+ A = A_transform(A_img)
54
+ L = L_transform(L_img)
55
+
56
+ # base = 2**9
57
+ # h = int((oh+base-1) // base * base)
58
+ # w = int((ow+base-1) // base * base)
59
+ # A = F.pad(A.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)
60
+ # L = F.pad(L.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)
61
+
62
+ tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
63
+ Ai = tmp.unsqueeze(0)
64
+
65
+ return {'A': A, 'Ai': Ai, 'L': L,
66
+ 'B': torch.zeros(1), 'Bs': torch.zeros(1), 'Bi': torch.zeros(1), 'Bl': torch.zeros(1),
67
+ 'A_paths': A_path, 'h': oh, 'w': ow}
68
+
69
+ def __len__(self):
70
+ return self.A_size
71
+
72
+ def name(self):
73
+ return 'SingleCoDataset'
74
+
75
+
76
+ def M_transform(feat, opt, params=None):
77
+ outfeat = feat.copy()
78
+ oh,ow = feat.shape[1:]
79
+ x1, y1 = params['crop_pos']
80
+ tw = th = opt.crop_size
81
+ if (ow > tw or oh > th):
82
+ outfeat = outfeat[:,y1:y1+th,x1:x1+tw]
83
+ if params['flip']:
84
+ outfeat = np.flip(outfeat, 2)#outfeat[:,:,::-1]
85
+ return torch.from_numpy(outfeat.copy()).float()*2-1.0
BidirectionalTranslation/data/singleSr_dataset.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from data.base_dataset import BaseDataset, get_params, get_transform
3
+ from data.image_folder import make_dataset
4
+ from PIL import Image
5
+ import random
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class SingleSrDataset(BaseDataset):
12
+ @staticmethod
13
+ def modify_commandline_options(parser, is_train):
14
+ return parser
15
+
16
+ def __init__(self, opt):
17
+ self.opt = opt
18
+ self.root = opt.dataroot
19
+ self.dir_B = os.path.join(opt.dataroot, opt.phase, opt.folder, 'imgs')
20
+ # self.dir_B = os.path.join(opt.dataroot, opt.phase, 'test/imgs', opt.folder)
21
+
22
+ self.B_paths = make_dataset(self.dir_B)
23
+
24
+ self.B_paths = sorted(self.B_paths)
25
+
26
+ self.B_size = len(self.B_paths)
27
+ # self.transform = get_transform(opt)
28
+ # print(self.B_size)
29
+
30
+ def __getitem__(self, index):
31
+ B_path = self.B_paths[index]
32
+
33
+ B_img = Image.open(B_path).convert('RGB')
34
+ if os.path.exists(B_path.replace('imgs','line').replace('.jpg','.png')):
35
+ L_img = Image.open(B_path.replace('imgs','line').replace('.jpg','.png'))#.convert('RGB')
36
+ else:
37
+ L_img = Image.open(B_path.replace('imgs','line').replace('.png','.jpg'))#.convert('RGB')
38
+ B_img = B_img.resize(L_img.size, Image.ANTIALIAS)
39
+
40
+ ow, oh = B_img.size
41
+ transform_params = get_params(self.opt, B_img.size)
42
+ B_transform = get_transform(self.opt, transform_params, grayscale=True)
43
+ B = B_transform(B_img)
44
+ L = B_transform(L_img)
45
+
46
+ # base = 2**8
47
+ # h = int((oh+base-1) // base * base)
48
+ # w = int((ow+base-1) // base * base)
49
+ # B = F.pad(B.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)
50
+ # L = F.pad(L.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)
51
+
52
+ return {'B': B, 'Bs': B, 'Bi': B, 'Bl': L,
53
+ 'A': torch.zeros(1), 'Ai': torch.zeros(1), 'L': torch.zeros(1),
54
+ 'A_paths': B_path, 'h': oh, 'w': ow}
55
+
56
+ def __len__(self):
57
+ return self.B_size
58
+
59
+ def name(self):
60
+ return 'SingleSrDataset'
61
+
62
+
63
+ def M_transform(feat, opt, params=None):
64
+ outfeat = feat.copy()
65
+ if params is not None:
66
+ oh,ow = feat.shape[1:]
67
+ x1, y1 = params['crop_pos']
68
+ tw = th = opt.crop_size
69
+ if (ow > tw or oh > th):
70
+ outfeat = outfeat[:,y1:y1+th,x1:x1+tw]
71
+ if params['flip']:
72
+ outfeat = np.flip(outfeat, 2).copy()#outfeat[:,:,::-1]
73
+ return torch.from_numpy(outfeat).float()*2-1.0
BidirectionalTranslation/models/__init__.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
3
+ You need to implement the following five functions:
4
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
5
+ -- <set_input>: unpack data from dataset and apply preprocessing.
6
+ -- <forward>: produce intermediate results.
7
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
8
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
9
+ In the function <__init__>, you need to define four lists:
10
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
11
+ -- self.model_names (str list): specify the images that you want to display and save.
12
+ -- self.visual_names (str list): define networks used in our training.
13
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
14
+ Now you can use the model class by specifying flag '--model dummy'.
15
+ See our template model class 'template_model.py' for an example.
16
+ """
17
+
18
+ import importlib
19
+ from models.base_model import BaseModel
20
+
21
+
22
+ def find_model_using_name(model_name):
23
+ """Import the module "models/[model_name]_model.py".
24
+ In the file, the class called DatasetNameModel() will
25
+ be instantiated. It has to be a subclass of BaseModel,
26
+ and it is case-insensitive.
27
+ """
28
+ model_filename = "models." + model_name + "_model"
29
+ modellib = importlib.import_module(model_filename)
30
+ model = None
31
+ target_model_name = model_name.replace('_', '') + 'model'
32
+ for name, cls in modellib.__dict__.items():
33
+ if name.lower() == target_model_name.lower() \
34
+ and issubclass(cls, BaseModel):
35
+ model = cls
36
+
37
+ if model is None:
38
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
39
+ exit(0)
40
+
41
+ return model
42
+
43
+
44
+ def get_option_setter(model_name):
45
+ """Return the static method <modify_commandline_options> of the model class."""
46
+ model_class = find_model_using_name(model_name)
47
+ return model_class.modify_commandline_options
48
+
49
+
50
+ def create_model(opt, ckpt_root):
51
+ """Create a model given the option.
52
+ This function warps the class CustomDatasetDataLoader.
53
+ This is the main interface between this package and 'train.py'/'test.py'
54
+ Example:
55
+ >>> from models import create_model
56
+ >>> model = create_model(opt)
57
+ """
58
+ model = find_model_using_name(opt.model)
59
+ instance = model(opt, ckpt_root = ckpt_root)
60
+ print("model [%s] was created" % type(instance).__name__)
61
+ return instance
BidirectionalTranslation/models/base_model.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from abc import ABC, abstractmethod
5
+ from . import networks
6
+ import numpy as np
7
+ from torch.nn.parallel import DistributedDataParallel as DDP
8
+
9
+ class BaseModel(ABC):
10
+ """This class is an abstract base class (ABC) for models.
11
+ To create a subclass, you need to implement the following five functions:
12
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
13
+ -- <set_input>: unpack data from dataset and apply preprocessing.
14
+ -- <forward>: produce intermediate results.
15
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
16
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
17
+ """
18
+
19
+ def __init__(self, opt):
20
+ """Initialize the BaseModel class.
21
+
22
+ Parameters:
23
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
24
+
25
+ When creating your custom class, you need to implement your own initialization.
26
+ In this fucntion, you should first call `BaseModel.__init__(self, opt)`
27
+ Then, you need to define four lists:
28
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
29
+ -- self.model_names (str list): specify the images that you want to display and save.
30
+ -- self.visual_names (str list): define networks used in our training.
31
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
32
+ """
33
+ self.opt = opt
34
+ self.gpu_ids = opt.gpu_ids
35
+ self.isTrain = opt.isTrain
36
+ self.iter = 0
37
+ self.last_iter = 0
38
+ self.device = torch.device('cuda:{}'.format(
39
+ self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
40
+ # save all the checkpoints to save_dir
41
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
42
+ try:
43
+ os.mkdir(self.save_dir)
44
+ except:
45
+ pass
46
+ # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
47
+ if opt.preprocess != 'scale_width':
48
+ torch.backends.cudnn.benchmark = True
49
+ self.loss_names = []
50
+ self.model_names = []
51
+ self.visual_names = []
52
+ self.optimizers = []
53
+ self.image_paths = []
54
+
55
+ self.label_colours = np.random.randint(255, size=(100,3))
56
+
57
+ def save_suppixel(self,l_inds):
58
+ im_target_rgb = np.array([self.label_colours[ c % 100 ] for c in l_inds])
59
+ b,h,w = l_inds.shape
60
+ im_target_rgb = im_target_rgb.reshape(b,h,w,3).transpose(0,3,1,2)/127.5-1.0
61
+ return torch.from_numpy(im_target_rgb)
62
+
63
+ @staticmethod
64
+ def modify_commandline_options(parser, is_train):
65
+ """Add new model-specific options, and rewrite default values for existing options.
66
+
67
+ Parameters:
68
+ parser -- original option parser
69
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
70
+
71
+ Returns:
72
+ the modified parser.
73
+ """
74
+ return parser
75
+
76
+ @abstractmethod
77
+ def set_input(self, input):
78
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
79
+
80
+ Parameters:
81
+ input (dict): includes the data itself and its metadata information.
82
+ """
83
+ pass
84
+
85
+ @abstractmethod
86
+ def forward(self):
87
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
88
+ pass
89
+
90
+ def is_train(self):
91
+ """check if the current batch is good for training."""
92
+ return True
93
+
94
+ @abstractmethod
95
+ def optimize_parameters(self):
96
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
97
+ pass
98
+
99
+ def setup(self, opt):
100
+ """Load and print networks; create schedulers
101
+
102
+ Parameters:
103
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
104
+ """
105
+ if self.isTrain:
106
+ self.schedulers = [networks.get_scheduler(
107
+ optimizer, opt) for optimizer in self.optimizers]
108
+ if not self.isTrain or opt.continue_train:
109
+ self.load_networks(opt.epoch)
110
+ self.print_networks(opt.verbose)
111
+
112
+ def eval(self):
113
+ """Make models eval mode during test time"""
114
+ for name in self.model_names:
115
+ if isinstance(name, str):
116
+ net = getattr(self, 'net' + name)
117
+ net.eval()
118
+
119
+ def test(self):
120
+ """Forward function used in test time.
121
+
122
+ This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
123
+ It also calls <compute_visuals> to produce additional visualization results
124
+ """
125
+ with torch.no_grad():
126
+ self.forward()
127
+ self.compute_visuals()
128
+
129
+ def compute_visuals(self):
130
+ """Calculate additional output images for visdom and HTML visualization"""
131
+ pass
132
+
133
+ def get_image_paths(self):
134
+ """ Return image paths that are used to load current data"""
135
+ return self.image_paths
136
+
137
+ def update_learning_rate(self):
138
+ """Update learning rates for all the networks; called at the end of every epoch"""
139
+ for scheduler in self.schedulers:
140
+ scheduler.step()
141
+ lr = self.optimizers[0].param_groups[0]['lr']
142
+ print('learning rate = %.7f' % lr)
143
+
144
+ def get_current_visuals(self):
145
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
146
+ visual_ret = OrderedDict()
147
+ for name in self.visual_names:
148
+ if isinstance(name, str):
149
+ if 'Lab' in name:
150
+ labimg = getattr(self, name).cpu()
151
+ labimg[:,0,:,:]+=1
152
+ labimg[:,0,:,:]*=50
153
+ labimg[:,1:,:,:] *= 110
154
+ labimg = labimg.permute((0,2,3,1))
155
+ for i in range(labimg.shape[0]):
156
+ labimg[i,:,:,:]=lab2rgb(labimg[i,:,:,:])
157
+ visual_ret[name] = (labimg.permute((0,3,1,2))*2-1.0).to(self.device)
158
+ elif 'Fm' in name:
159
+ visual_ret[name] = self.save_suppixel(getattr(self, name).cpu()).to(self.device)
160
+ else:
161
+ visual_ret[name] = getattr(self, name)
162
+ return visual_ret
163
+
164
+ def get_current_losses(self):
165
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
166
+ errors_ret = OrderedDict()
167
+ for name in self.loss_names:
168
+ if isinstance(name, str):
169
+ # float(...) works for both scalar tensor and float number
170
+ errors_ret[name] = float(getattr(self, 'loss_' + name))
171
+ return errors_ret
172
+
173
+ def save_networks(self, epoch):
174
+ """Save all the networks to the disk.
175
+
176
+ Parameters:
177
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
178
+ """
179
+ for name in self.model_names:
180
+ if isinstance(name, str):
181
+ save_filename = '%s_net_%s.pth' % (epoch, name)
182
+ save_path = os.path.join(self.save_dir, save_filename)
183
+ # print(save_path)
184
+ net = getattr(self, 'net' + name)
185
+
186
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
187
+ torch.save(net.state_dict(), save_path)
188
+ # net.cuda(self.gpu_ids[0])
189
+ else:
190
+ torch.save(net.cpu().state_dict(), save_path)
191
+
192
+ save_filename = '%s_net_opt.pth' % (epoch)
193
+ save_path = os.path.join(self.save_dir, save_filename)
194
+ save_dict = {'iter': str(self.iter // self.opt.print_freq * self.opt.print_freq)}
195
+ for i, name in enumerate(self.optimizer_names):
196
+ save_dict.update({name.lower(): self.optimizers[i].state_dict()})
197
+ torch.save(save_dict, save_path)
198
+
199
+
200
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
201
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
202
+ key = keys[i]
203
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
204
+ if module.__class__.__name__.startswith('InstanceNorm') and \
205
+ (key == 'running_mean' or key == 'running_var'):
206
+ if getattr(module, key) is None:
207
+ state_dict.pop('.'.join(keys))
208
+ if module.__class__.__name__.startswith('InstanceNorm') and \
209
+ (key == 'num_batches_tracked'):
210
+ state_dict.pop('.'.join(keys))
211
+ else:
212
+ self.__patch_instance_norm_state_dict(
213
+ state_dict, getattr(module, key), keys, i + 1)
214
+
215
+ def load_networks(self, epoch):
216
+ """Load all the networks from the disk.
217
+
218
+ Parameters:
219
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
220
+ """
221
+ for name in self.model_names:
222
+ if isinstance(name, str):
223
+ load_filename = '%s_net_%s.pth' % (epoch, name)
224
+ load_path = os.path.join(self.save_dir, load_filename)
225
+ net = getattr(self, 'net' + name)
226
+ # if isinstance(net, torch.nn.DataParallel):
227
+ if isinstance(net, DDP):
228
+ net = net.module
229
+ # print(net)
230
+ print('loading the model from %s' % load_path)
231
+ # if you are using PyTorch newer than 0.4 (e.g., built from
232
+ # GitHub source), you can remove str() on self.device
233
+ state_dict = torch.load(
234
+ load_path, map_location=lambda storage, loc: storage.cuda())
235
+ if hasattr(state_dict, '_metadata'):
236
+ del state_dict._metadata
237
+
238
+ # patch InstanceNorm checkpoints prior to 0.4
239
+ # need to copy keys here because we mutate in loop
240
+ #for key in list(state_dict.keys()):
241
+ # self.__patch_instance_norm_state_dict(
242
+ # state_dict, net, key.split('.'))
243
+
244
+ net.load_state_dict(state_dict)
245
+ del state_dict
246
+
247
+ def print_networks(self, verbose):
248
+ """Print the total number of parameters in the network and (if verbose) network architecture
249
+
250
+ Parameters:
251
+ verbose (bool) -- if verbose: print the network architecture
252
+ """
253
+ print('---------- Networks initialized -------------')
254
+ for name in self.model_names:
255
+ if isinstance(name, str):
256
+ net = getattr(self, 'net' + name)
257
+ num_params = 0
258
+ for param in net.parameters():
259
+ num_params += param.numel()
260
+ if verbose:
261
+ print(net)
262
+ print('[Network %s] Total number of parameters : %.3f M' %
263
+ (name, num_params / 1e6))
264
+ print('-----------------------------------------------')
265
+
266
+ def set_requires_grad(self, nets, requires_grad=False):
267
+ """Set requires_grad=False for all the networks to avoid unnecessary computations
268
+ Parameters:
269
+ nets (network list) -- a list of networks
270
+ requires_grad (bool) -- whether the networks require gradients or not
271
+ """
272
+ if not isinstance(nets, list):
273
+ nets = [nets]
274
+ for net in nets:
275
+ if net is not None:
276
+ for param in net.parameters():
277
+ param.requires_grad = requires_grad
BidirectionalTranslation/models/cycle_ganstft_model.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ from .base_model import BaseModel
4
+ from . import networks
5
+ import torch.nn.functional as F
6
+ from torch.nn.parallel import DistributedDataParallel as DDP
7
+
8
+
9
+ class CycleGANSTFTModel(BaseModel):
10
+
11
+ def __init__(self, opt, ckpt_root):
12
+
13
+ BaseModel.__init__(self, opt)
14
+
15
+ use_vae = True
16
+ self.interchnnls = 4
17
+ use_noise = False
18
+ self.half_size = opt.batch_size //2
19
+ self.device=opt.local_rank
20
+ self.gpu_ids=[self.device]
21
+ self.local_rank = opt.local_rank
22
+ self.cropsize = opt.crop_size
23
+
24
+ self.model_names = ['G_INTSCR2RGB','G_RGB2INTSCR','E']
25
+ self.netG_INTSCR2RGB = networks.define_G(self.interchnnls + 1, 3, opt.nz, opt.ngf, netG='unet_256',
26
+ norm='layer', nl='lrelu', use_dropout=opt.use_dropout, init_type='kaiming', init_gain=opt.init_gain,
27
+ gpu_ids=self.gpu_ids, where_add='all', upsample='bilinear', use_noise=use_noise)
28
+ self.netG_RGB2INTSCR = networks.define_G(4, self.interchnnls, 0, opt.ngf, netG='unet_256',
29
+ norm='layer', nl='lrelu', use_dropout=opt.use_dropout, init_type='kaiming', init_gain=opt.init_gain,
30
+ gpu_ids=self.gpu_ids, where_add='input', upsample='bilinear', use_noise=use_noise)
31
+ self.netE = networks.define_E(opt.output_nc, opt.nz, opt.nef, netE=opt.netE, norm='none', nl='lrelu',
32
+ init_type='xavier', init_gain=opt.init_gain, gpu_ids=self.gpu_ids, vaeLike=use_vae)
33
+ self.nets = [self.netG_INTSCR2RGB, self.netG_RGB2INTSCR, self.netE]
34
+
35
+ self.netSVAE = networks.define_SVAE(inc=1, outc=self.interchnnls, outplanes=64, blocks=3, netVAE='SVAE',
36
+ save_dir= ckpt_root+'/ScreenStyle/ScreenVAE',init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids)
37
+
38
+
39
+ def set_input(self, input):
40
+ AtoB = self.opt.direction == 'AtoB'
41
+ self.real_RGB = input['A'].to(self.device)
42
+ self.real_Ai = self.grayscale(self.real_RGB)
43
+ self.real_L = input['L'].to(self.device)
44
+ self.real_ML = input['Bl'].to(self.device)
45
+ self.real_M = input['B'].to(self.device)
46
+
47
+ self.h = input['h']
48
+ self.w = input['w']
49
+
50
+ def grayscale(self, input_image):
51
+ rate = torch.Tensor([0.299, 0.587, 0.114]).reshape(1, 3, 1, 1).to(input_image.device)
52
+ # tmp = input_image[:,0, ...] * 0.299 + input_image[:,1, ...] * 0.587 + input_image[:,2, ...] * 0.114
53
+ return (input_image*rate).sum(1,keepdims=True)
54
+
55
+ def forward(self, AtoB=True, sty=None):
56
+ if AtoB:
57
+ real_LRGB = torch.cat([self.real_L, self.real_RGB],1)
58
+ fake_SCR = self.netG_RGB2INTSCR(real_LRGB)
59
+ fake_M = self.netSVAE(fake_SCR, line=self.real_L, img_input=False)
60
+ fake_M = torch.clamp(fake_M, -1,1)
61
+ fake_M2 = self.norm(torch.mul(self.denorm(fake_M), self.denorm(self.real_L)))#*self.mask2
62
+ return fake_M[:,:,:self.h, :self.w], fake_M2[:,:,:self.h, :self.w], fake_SCR[:,:,:self.h, :self.w]
63
+ else:
64
+ if sty is None: # use encoded z
65
+ z0, _ = self.netE(self.real_RGB)
66
+ else:
67
+ z0 = sty
68
+ # z0 = self.get_z_random(self.real_A.size(0), self.opt.nz)
69
+ real_SCR = self.netSVAE(self.real_M, self.real_ML, output_screen_only=True) #8
70
+ real_LSCR = torch.cat([self.real_ML, real_SCR], 1)
71
+ fake_nRGB = self.netG_INTSCR2RGB(real_LSCR, z0)
72
+ fake_nRGB = torch.clamp(fake_nRGB, -1,1)
73
+ fake_RGB = self.norm(torch.mul(self.denorm(fake_nRGB), self.denorm(self.real_ML)))
74
+ return fake_RGB[:,:,:self.h, :self.w], real_SCR[:,:,:self.h, :self.w], self.real_ML[:,:,:self.h, :self.w]
75
+
76
+ def norm(self, im):
77
+ return im * 2.0 - 1
78
+
79
+ def denorm(self, im):
80
+ return (im + 1) / 2.0
81
+
82
+ def optimize_parameters(self):
83
+ pass
84
+
85
+ def get_z_random(self, batch_size, nz, random_type='gauss', truncation=False, tvalue=1):
86
+ z = None
87
+ if random_type == 'uni':
88
+ z = torch.rand(batch_size, nz) * 2.0 - 1.0
89
+ elif random_type == 'gauss':
90
+ z = torch.randn(batch_size, nz) * tvalue
91
+ # do the truncation trick
92
+ if truncation:
93
+ k = 0
94
+ while (k < 15 * nz):
95
+ if torch.max(z) <= tvalue:
96
+ break
97
+ zabs = torch.abs(z)
98
+ zz = torch.randn(batch_size, nz)
99
+ z[zabs > tvalue] = zz[zabs > tvalue]
100
+ k += 1
101
+ z = torch.clamp(z, -tvalue, tvalue)
102
+
103
+ return z.detach().to(self.device)
BidirectionalTranslation/models/networks.py ADDED
@@ -0,0 +1,1375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+ import numpy as np
7
+ import torch.nn.functional as F
8
+ from torch.nn.modules.normalization import LayerNorm
9
+ import os
10
+ from torch.nn.utils import spectral_norm
11
+ from torchvision import models
12
+
13
+ ###############################################################################
14
+ # Helper functions
15
+ ###############################################################################
16
+
17
+
18
+ def init_weights(net, init_type='normal', init_gain=0.02):
19
+ """Initialize network weights.
20
+ Parameters:
21
+ net (network) -- network to be initialized
22
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
23
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
24
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
25
+ work better for some applications. Feel free to try yourself.
26
+ """
27
+ def init_func(m): # define the initialization function
28
+ classname = m.__class__.__name__
29
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
30
+ if init_type == 'normal':
31
+ init.normal_(m.weight.data, 0.0, init_gain)
32
+ elif init_type == 'xavier':
33
+ init.xavier_normal_(m.weight.data, gain=init_gain)
34
+ elif init_type == 'kaiming':
35
+ #init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
36
+ init.kaiming_normal_(m.weight.data, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
37
+ elif init_type == 'orthogonal':
38
+ init.orthogonal_(m.weight.data, gain=init_gain)
39
+ else:
40
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
41
+ if hasattr(m, 'bias') and m.bias is not None:
42
+ init.constant_(m.bias.data, 0.0)
43
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
44
+ init.normal_(m.weight.data, 1.0, init_gain)
45
+ init.constant_(m.bias.data, 0.0)
46
+
47
+ print('initialize network with %s' % init_type)
48
+ net.apply(init_func) # apply the initialization function <init_func>
49
+
50
+
51
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init=True):
52
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
53
+ Parameters:
54
+ net (network) -- the network to be initialized
55
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
56
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
57
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
58
+ Return an initialized network.
59
+ """
60
+ if len(gpu_ids) > 0:
61
+ assert(torch.cuda.is_available())
62
+ net.to(gpu_ids[0])
63
+ if init:
64
+ init_weights(net, init_type, init_gain=init_gain)
65
+ return net
66
+
67
+
68
+ def get_scheduler(optimizer, opt):
69
+ """Return a learning rate scheduler
70
+ Parameters:
71
+ optimizer -- the optimizer of the network
72
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 
73
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
74
+ For 'linear', we keep the same learning rate for the first <opt.niter> epochs
75
+ and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
76
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
77
+ See https://pytorch.org/docs/stable/optim.html for more details.
78
+ """
79
+ if opt.lr_policy == 'linear':
80
+ def lambda_rule(epoch):
81
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
82
+ return lr_l
83
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
84
+ elif opt.lr_policy == 'step':
85
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
86
+ elif opt.lr_policy == 'plateau':
87
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
88
+ elif opt.lr_policy == 'cosine':
89
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
90
+ else:
91
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
92
+ return scheduler
93
+
94
+ class LayerNormWarpper(nn.Module):
95
+ def __init__(self, num_features):
96
+ super(LayerNormWarpper, self).__init__()
97
+ self.num_features = int(num_features)
98
+
99
+ def forward(self, x):
100
+ x = nn.LayerNorm([self.num_features, x.size()[2], x.size()[3]], elementwise_affine=False).cuda()(x)
101
+ return x
102
+
103
+ def get_norm_layer(norm_type='instance'):
104
+ """Return a normalization layer
105
+ Parameters:
106
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
107
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
108
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
109
+ """
110
+ if norm_type == 'batch':
111
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
112
+ elif norm_type == 'instance':
113
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
114
+ elif norm_type == 'layer':
115
+ norm_layer = functools.partial(LayerNormWarpper)
116
+ elif norm_type == 'none':
117
+ norm_layer = None
118
+ else:
119
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
120
+ return norm_layer
121
+
122
+
123
+ def get_non_linearity(layer_type='relu'):
124
+ if layer_type == 'relu':
125
+ nl_layer = functools.partial(nn.ReLU, inplace=True)
126
+ elif layer_type == 'lrelu':
127
+ nl_layer = functools.partial(
128
+ nn.LeakyReLU, negative_slope=0.2, inplace=True)
129
+ elif layer_type == 'elu':
130
+ nl_layer = functools.partial(nn.ELU, inplace=True)
131
+ elif layer_type == 'selu':
132
+ nl_layer = functools.partial(nn.SELU, inplace=True)
133
+ elif layer_type == 'prelu':
134
+ nl_layer = functools.partial(nn.PReLU)
135
+ else:
136
+ raise NotImplementedError(
137
+ 'nonlinearity activitation [%s] is not found' % layer_type)
138
+ return nl_layer
139
+
140
+
141
+ def define_G(input_nc, output_nc, nz, ngf, netG='unet_128', norm='batch', nl='relu', use_noise=False,
142
+ use_dropout=False, init_type='xavier', init_gain=0.02, gpu_ids=[], where_add='input', upsample='bilinear'):
143
+ net = None
144
+ norm_layer = get_norm_layer(norm_type=norm)
145
+ nl_layer = get_non_linearity(layer_type=nl)
146
+ # print(norm, norm_layer)
147
+
148
+ if nz == 0:
149
+ where_add = 'input'
150
+
151
+ if netG == 'unet_128' and where_add == 'input':
152
+ net = G_Unet_add_input(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
153
+ use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
154
+ elif netG == 'unet_128_G' and where_add == 'input':
155
+ net = G_Unet_add_input_G(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
156
+ use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
157
+ elif netG == 'unet_256' and where_add == 'input':
158
+ net = G_Unet_add_input(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
159
+ use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
160
+ elif netG == 'unet_256_G' and where_add == 'input':
161
+ net = G_Unet_add_input_G(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
162
+ use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
163
+ elif netG == 'unet_128' and where_add == 'all':
164
+ net = G_Unet_add_all(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
165
+ use_dropout=use_dropout, upsample=upsample)
166
+ elif netG == 'unet_256' and where_add == 'all':
167
+ net = G_Unet_add_all(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
168
+ use_dropout=use_dropout, upsample=upsample)
169
+ else:
170
+ raise NotImplementedError('Generator model name [%s] is not recognized' % net)
171
+ # print(net)
172
+ return init_net(net, init_type, init_gain, gpu_ids)
173
+
174
+
175
+ def define_C(input_nc, output_nc, nz, ngf, netC='unet_128', norm='instance', nl='relu',
176
+ use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], upsample='basic'):
177
+ net = None
178
+ norm_layer = get_norm_layer(norm_type=norm)
179
+ nl_layer = get_non_linearity(layer_type=nl)
180
+
181
+ if netC == 'resnet_9blocks':
182
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
183
+ elif netC == 'resnet_6blocks':
184
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
185
+ elif netC == 'unet_128':
186
+ net = G_Unet_add_input_C(input_nc, output_nc, 0, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
187
+ use_dropout=use_dropout, upsample=upsample)
188
+ elif netC == 'unet_256':
189
+ net = G_Unet_add_input(input_nc, output_nc, 0, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
190
+ use_dropout=use_dropout, upsample=upsample)
191
+ elif netC == 'unet_32':
192
+ net = G_Unet_add_input(input_nc, output_nc, 0, 5, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
193
+ use_dropout=use_dropout, upsample=upsample)
194
+ else:
195
+ raise NotImplementedError('Generator model name [%s] is not recognized' % net)
196
+
197
+ return init_net(net, init_type, init_gain, gpu_ids)
198
+
199
+
200
+ def define_D(input_nc, ndf, netD, norm='batch', nl='lrelu', init_type='xavier', init_gain=0.02, num_Ds=1, gpu_ids=[]):
201
+ net = None
202
+ norm_layer = get_norm_layer(norm_type=norm)
203
+ nl = 'lrelu' # use leaky relu for D
204
+ nl_layer = get_non_linearity(layer_type=nl)
205
+
206
+ if netD == 'basic_128':
207
+ net = D_NLayers(input_nc, ndf, n_layers=2, norm_layer=norm_layer, nl_layer=nl_layer)
208
+ elif netD == 'basic_256':
209
+ net = D_NLayers(input_nc, ndf, n_layers=3, norm_layer=norm_layer, nl_layer=nl_layer)
210
+ elif netD == 'basic_128_multi':
211
+ net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=2, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
212
+ elif netD == 'basic_256_multi':
213
+ net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=3, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
214
+ else:
215
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % net)
216
+ return init_net(net, init_type, init_gain, gpu_ids)
217
+
218
+
219
+ def define_E(input_nc, output_nc, ndf, netE, norm='batch', nl='lrelu',
220
+ init_type='xavier', init_gain=0.02, gpu_ids=[], vaeLike=False):
221
+ net = None
222
+ norm_layer = get_norm_layer(norm_type=norm)
223
+ nl = 'lrelu' # use leaky relu for E
224
+ nl_layer = get_non_linearity(layer_type=nl)
225
+ if netE == 'resnet_128':
226
+ net = E_ResNet(input_nc, output_nc, ndf, n_blocks=4, norm_layer=norm_layer,
227
+ nl_layer=nl_layer, vaeLike=vaeLike)
228
+ elif netE == 'resnet_256':
229
+ net = E_ResNet(input_nc, output_nc, ndf, n_blocks=5, norm_layer=norm_layer,
230
+ nl_layer=nl_layer, vaeLike=vaeLike)
231
+ elif netE == 'conv_128':
232
+ net = E_NLayers(input_nc, output_nc, ndf, n_layers=4, norm_layer=norm_layer,
233
+ nl_layer=nl_layer, vaeLike=vaeLike)
234
+ elif netE == 'conv_256':
235
+ net = E_NLayers(input_nc, output_nc, ndf, n_layers=5, norm_layer=norm_layer,
236
+ nl_layer=nl_layer, vaeLike=vaeLike)
237
+ else:
238
+ raise NotImplementedError('Encoder model name [%s] is not recognized' % net)
239
+
240
+ return init_net(net, init_type, init_gain, gpu_ids, False)
241
+
242
+
243
+ class ResnetGenerator(nn.Module):
244
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, norm_layer=None, use_dropout=False, n_blocks=6, padding_type='replicate'):
245
+ assert(n_blocks >= 0)
246
+ super(ResnetGenerator, self).__init__()
247
+ self.input_nc = input_nc
248
+ self.output_nc = output_nc
249
+ self.ngf = ngf
250
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
251
+ use_bias = norm_layer.func != nn.BatchNorm2d
252
+ else:
253
+ use_bias = norm_layer != nn.BatchNorm2d
254
+
255
+ model = [nn.ReplicationPad2d(3),
256
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
257
+ bias=use_bias)]
258
+ if norm_layer is not None:
259
+ model += [norm_layer(ngf)]
260
+ model += [nn.ReLU(True)]
261
+
262
+ # n_downsampling = 2
263
+ for i in range(n_downsampling):
264
+ mult = 2**i
265
+ model += [nn.ReplicationPad2d(1),nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
266
+ stride=2, padding=0, bias=use_bias)]
267
+ # model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
268
+ # stride=2, padding=1, bias=use_bias)]
269
+ if norm_layer is not None:
270
+ model += [norm_layer(ngf * mult * 2)]
271
+ model += [nn.ReLU(True)]
272
+
273
+ mult = 2**n_downsampling
274
+ for i in range(n_blocks):
275
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
276
+
277
+ for i in range(n_downsampling):
278
+ mult = 2**(n_downsampling - i)
279
+ # model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
280
+ # kernel_size=3, stride=2,
281
+ # padding=1, output_padding=1,
282
+ # bias=use_bias)]
283
+ # if norm_layer is not None:
284
+ # model += [norm_layer(ngf * mult / 2)]
285
+ # model += [nn.ReLU(True)]
286
+ model += upsampleLayer(ngf * mult, int(ngf * mult / 2), upsample='bilinear', padding_type=padding_type)
287
+ if norm_layer is not None:
288
+ model += [norm_layer(int(ngf * mult / 2))]
289
+ model += [nn.ReLU(True)]
290
+ model +=[nn.ReplicationPad2d(1),
291
+ nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2), kernel_size=3, padding=0)]
292
+ if norm_layer is not None:
293
+ model += [norm_layer(ngf * mult / 2)]
294
+ model += [nn.ReLU(True)]
295
+ model += [nn.ReplicationPad2d(3)]
296
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
297
+ #model += [nn.Tanh()]
298
+
299
+ self.model = nn.Sequential(*model)
300
+
301
+ def forward(self, input):
302
+ return self.model(input)
303
+
304
+
305
+ # Define a resnet block
306
+ class ResnetBlock(nn.Module):
307
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
308
+ super(ResnetBlock, self).__init__()
309
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
310
+
311
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
312
+ conv_block = []
313
+ p = 0
314
+ if padding_type == 'reflect':
315
+ conv_block += [nn.ReflectionPad2d(1)]
316
+ elif padding_type == 'replicate':
317
+ conv_block += [nn.ReplicationPad2d(1)]
318
+ elif padding_type == 'zero':
319
+ p = 1
320
+ else:
321
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
322
+
323
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
324
+ if norm_layer is not None:
325
+ conv_block += [norm_layer(dim)]
326
+ conv_block += [nn.ReLU(True)]
327
+ # if use_dropout:
328
+ # conv_block += [nn.Dropout(0.5)]
329
+
330
+ p = 0
331
+ if padding_type == 'reflect':
332
+ conv_block += [nn.ReflectionPad2d(1)]
333
+ elif padding_type == 'replicate':
334
+ conv_block += [nn.ReplicationPad2d(1)]
335
+ elif padding_type == 'zero':
336
+ p = 1
337
+ else:
338
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
339
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
340
+ if norm_layer is not None:
341
+ conv_block += [norm_layer(dim)]
342
+
343
+ return nn.Sequential(*conv_block)
344
+
345
+ def forward(self, x):
346
+ out = x + self.conv_block(x)
347
+ return out
348
+
349
+
350
+ class D_NLayersMulti(nn.Module):
351
+ def __init__(self, input_nc, ndf=64, n_layers=3,
352
+ norm_layer=nn.BatchNorm2d, num_D=1, nl_layer=None):
353
+ super(D_NLayersMulti, self).__init__()
354
+ # st()
355
+ self.num_D = num_D
356
+ self.nl_layer=nl_layer
357
+ if num_D == 1:
358
+ layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
359
+ self.model = nn.Sequential(*layers)
360
+ else:
361
+ layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
362
+ self.add_module("model_0", nn.Sequential(*layers))
363
+ self.down = nn.functional.interpolate
364
+ for i in range(1, num_D):
365
+ ndf_i = int(round(ndf / (2**i)))
366
+ layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer)
367
+ self.add_module("model_%d" % i, nn.Sequential(*layers))
368
+
369
+ def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
370
+ kw = 3
371
+ padw = 1
372
+ sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
373
+ stride=2, padding=padw)), nn.LeakyReLU(0.2, True)]
374
+
375
+ nf_mult = 1
376
+ nf_mult_prev = 1
377
+ for n in range(1, n_layers):
378
+ nf_mult_prev = nf_mult
379
+ nf_mult = min(2**n, 8)
380
+ sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
381
+ kernel_size=kw, stride=2, padding=padw))]
382
+ if norm_layer:
383
+ sequence += [norm_layer(ndf * nf_mult)]
384
+
385
+ sequence += [self.nl_layer()]
386
+
387
+ nf_mult_prev = nf_mult
388
+ nf_mult = min(2**n_layers, 8)
389
+ sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
390
+ kernel_size=kw, stride=1, padding=padw))]
391
+ if norm_layer:
392
+ sequence += [norm_layer(ndf * nf_mult)]
393
+ sequence += [self.nl_layer()]
394
+
395
+ sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1,
396
+ kernel_size=kw, stride=1, padding=padw))]
397
+
398
+ return sequence
399
+
400
+ def forward(self, input):
401
+ if self.num_D == 1:
402
+ return self.model(input)
403
+ result = []
404
+ down = input
405
+ for i in range(self.num_D):
406
+ model = getattr(self, "model_%d" % i)
407
+ result.append(model(down))
408
+ if i != self.num_D - 1:
409
+ down = self.down(down, scale_factor=0.5, mode='bilinear')
410
+ return result
411
+
412
+ class D_NLayers(nn.Module):
413
+ """Defines a PatchGAN discriminator"""
414
+
415
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
416
+ """Construct a PatchGAN discriminator
417
+ Parameters:
418
+ input_nc (int) -- the number of channels in input images
419
+ ndf (int) -- the number of filters in the last conv layer
420
+ n_layers (int) -- the number of conv layers in the discriminator
421
+ norm_layer -- normalization layer
422
+ """
423
+ super(D_NLayers, self).__init__()
424
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
425
+ use_bias = norm_layer.func != nn.BatchNorm2d
426
+ else:
427
+ use_bias = norm_layer != nn.BatchNorm2d
428
+
429
+ kw = 3
430
+ padw = 1
431
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
432
+ nf_mult = 1
433
+ nf_mult_prev = 1
434
+ for n in range(1, n_layers): # gradually increase the number of filters
435
+ nf_mult_prev = nf_mult
436
+ nf_mult = min(2 ** n, 8)
437
+ sequence += [
438
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
439
+ norm_layer(ndf * nf_mult),
440
+ nn.LeakyReLU(0.2, True)
441
+ ]
442
+
443
+ nf_mult_prev = nf_mult
444
+ nf_mult = min(2 ** n_layers, 8)
445
+ sequence += [
446
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
447
+ norm_layer(ndf * nf_mult),
448
+ nn.LeakyReLU(0.2, True)
449
+ ]
450
+
451
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
452
+ self.model = nn.Sequential(*sequence)
453
+
454
+ def forward(self, input):
455
+ """Standard forward."""
456
+ return self.model(input)
457
+
458
+
459
+ class G_Unet_add_input(nn.Module):
460
+ def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
461
+ norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
462
+ upsample='basic', device=0):
463
+ super(G_Unet_add_input, self).__init__()
464
+ self.nz = nz
465
+ max_nchn = 8
466
+ noise = []
467
+ for i in range(num_downs+1):
468
+ if use_noise:
469
+ noise.append(True)
470
+ else:
471
+ noise.append(False)
472
+
473
+ # construct unet structure
474
+ #print(num_downs)
475
+ unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=noise[num_downs-1],
476
+ innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
477
+ for i in range(num_downs - 5):
478
+ unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise[num_downs-i-3],
479
+ norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
480
+ unet_block = UnetBlock_A(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
481
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
482
+ unet_block = UnetBlock_A(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
483
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
484
+ unet_block = UnetBlock_A(ngf, ngf, ngf * 2, unet_block, noise[0],
485
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
486
+ unet_block = UnetBlock_A(input_nc + nz, output_nc, ngf, unet_block, None,
487
+ outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
488
+
489
+ self.model = unet_block
490
+
491
+ def forward(self, x, z=None):
492
+ if self.nz > 0:
493
+ z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
494
+ z.size(0), z.size(1), x.size(2), x.size(3))
495
+ x_with_z = torch.cat([x, z_img], 1)
496
+ else:
497
+ x_with_z = x # no z
498
+
499
+
500
+ return torch.tanh(self.model(x_with_z))
501
+ # return self.model(x_with_z)
502
+
503
+ class G_Unet_add_input_G(nn.Module):
504
+ def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
505
+ norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
506
+ upsample='basic', device=0):
507
+ super(G_Unet_add_input_G, self).__init__()
508
+ self.nz = nz
509
+ max_nchn = 8
510
+ noise = []
511
+ for i in range(num_downs+1):
512
+ if use_noise:
513
+ noise.append(True)
514
+ else:
515
+ noise.append(False)
516
+ # construct unet structure
517
+ #print(num_downs)
518
+ unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
519
+ innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
520
+ for i in range(num_downs - 5):
521
+ unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
522
+ norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
523
+ unet_block = UnetBlock_G(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
524
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
525
+ unet_block = UnetBlock_G(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
526
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
527
+ unet_block = UnetBlock_G(ngf, ngf, ngf * 2, unet_block, noise[0],
528
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
529
+ unet_block = UnetBlock_G(input_nc + nz, output_nc, ngf, unet_block, None,
530
+ outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
531
+
532
+ self.model = unet_block
533
+
534
+ def forward(self, x, z=None):
535
+ if self.nz > 0:
536
+ z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
537
+ z.size(0), z.size(1), x.size(2), x.size(3))
538
+ x_with_z = torch.cat([x, z_img], 1)
539
+ else:
540
+ x_with_z = x # no z
541
+
542
+ # return F.tanh(self.model(x_with_z))
543
+ return self.model(x_with_z)
544
+
545
+ class G_Unet_add_input_C(nn.Module):
546
+ def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
547
+ norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
548
+ upsample='basic', device=0):
549
+ super(G_Unet_add_input_C, self).__init__()
550
+ self.nz = nz
551
+ max_nchn = 8
552
+ # construct unet structure
553
+ #print(num_downs)
554
+ unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
555
+ innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
556
+ for i in range(num_downs - 5):
557
+ unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
558
+ norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
559
+ unet_block = UnetBlock(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise=False,
560
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
561
+ unet_block = UnetBlock(ngf * 2, ngf * 2, ngf * 4, unet_block, noise=False,
562
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
563
+ unet_block = UnetBlock(ngf, ngf, ngf * 2, unet_block, noise=False,
564
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
565
+ unet_block = UnetBlock(input_nc + nz, output_nc, ngf, unet_block, noise=False,
566
+ outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
567
+
568
+ self.model = unet_block
569
+
570
+ def forward(self, x, z=None):
571
+ if self.nz > 0:
572
+ z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
573
+ z.size(0), z.size(1), x.size(2), x.size(3))
574
+ x_with_z = torch.cat([x, z_img], 1)
575
+ else:
576
+ x_with_z = x # no z
577
+
578
+ # return torch.tanh(self.model(x_with_z))
579
+ return self.model(x_with_z)
580
+
581
+ def upsampleLayer(inplanes, outplanes, kw=1, upsample='basic', padding_type='replicate'):
582
+ # padding_type = 'zero'
583
+ if upsample == 'basic':
584
+ upconv = [nn.ConvTranspose2d(inplanes, outplanes, kernel_size=4, stride=2, padding=1)]#, padding_mode='replicate'
585
+ elif upsample == 'bilinear' or upsample == 'nearest' or upsample == 'linear':
586
+ upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=True),
587
+ #nn.ReplicationPad2d(1),
588
+ nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)]
589
+ # p = kw//2
590
+ # upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=True),
591
+ # nn.Conv2d(inplanes, outplanes, kernel_size=kw, stride=1, padding=p, padding_mode='replicate')]
592
+ else:
593
+ raise NotImplementedError(
594
+ 'upsample layer [%s] not implemented' % upsample)
595
+ return upconv
596
+
597
+ class UnetBlock_G(nn.Module):
598
+ def __init__(self, input_nc, outer_nc, inner_nc,
599
+ submodule=None, noise=None, outermost=False, innermost=False,
600
+ norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
601
+ super(UnetBlock_G, self).__init__()
602
+ self.outermost = outermost
603
+ p = 0
604
+ downconv = []
605
+ if padding_type == 'reflect':
606
+ downconv += [nn.ReflectionPad2d(1)]
607
+ elif padding_type == 'replicate':
608
+ downconv += [nn.ReplicationPad2d(1)]
609
+ elif padding_type == 'zero':
610
+ p = 1
611
+ else:
612
+ raise NotImplementedError(
613
+ 'padding [%s] is not implemented' % padding_type)
614
+
615
+ downconv += [nn.Conv2d(input_nc, inner_nc,
616
+ kernel_size=3, stride=2, padding=p)]
617
+ # downsample is different from upsample
618
+ downrelu = nn.LeakyReLU(0.2, True)
619
+ downnorm = norm_layer(inner_nc) if norm_layer is not None else None
620
+ uprelu = nl_layer()
621
+ uprelu2 = nl_layer()
622
+ uppad = nn.ReplicationPad2d(1)
623
+ upnorm = norm_layer(outer_nc) if norm_layer is not None else None
624
+ upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
625
+ self.noiseblock = ApplyNoise(outer_nc)
626
+ self.noise = noise
627
+
628
+ if outermost:
629
+ upconv = upsampleLayer(inner_nc * 2, inner_nc, upsample=upsample, padding_type=padding_type)
630
+ uppad = nn.ReplicationPad2d(3)
631
+ upconv2 = nn.Conv2d(inner_nc, outer_nc, kernel_size=7, padding=0)
632
+ down = downconv
633
+ up = [uprelu] + upconv
634
+ if upnorm is not None:
635
+ up += [norm_layer(inner_nc)]
636
+ # upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
637
+ # upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=0)
638
+ # down = downconv
639
+ # up = [uprelu] + upconv
640
+ # if upnorm is not None:
641
+ # up += [norm_layer(outer_nc)]
642
+ up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
643
+ model = down + [submodule] + up
644
+ elif innermost:
645
+ upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
646
+ upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
647
+ down = [downrelu] + downconv
648
+ up = [uprelu] + upconv
649
+ if upnorm is not None:
650
+ up += [upnorm]
651
+ up += [uprelu2, uppad, upconv2]
652
+ if upnorm2 is not None:
653
+ up += [upnorm2]
654
+ model = down + up
655
+ else:
656
+ upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
657
+ upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
658
+ down = [downrelu] + downconv
659
+ if downnorm is not None:
660
+ down += [downnorm]
661
+ up = [uprelu] + upconv
662
+ if upnorm is not None:
663
+ up += [upnorm]
664
+ up += [uprelu2, uppad, upconv2]
665
+ if upnorm2 is not None:
666
+ up += [upnorm2]
667
+
668
+ if use_dropout:
669
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
670
+ else:
671
+ model = down + [submodule] + up
672
+
673
+ self.model = nn.Sequential(*model)
674
+
675
+ def forward(self, x):
676
+ if self.outermost:
677
+ return self.model(x)
678
+ else:
679
+ x2 = self.model(x)
680
+ if self.noise:
681
+ x2 = self.noiseblock(x2, self.noise)
682
+ return torch.cat([x2, x], 1)
683
+
684
+
685
+ class UnetBlock(nn.Module):
686
+ def __init__(self, input_nc, outer_nc, inner_nc,
687
+ submodule=None, noise=None, outermost=False, innermost=False,
688
+ norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
689
+ super(UnetBlock, self).__init__()
690
+ self.outermost = outermost
691
+ p = 0
692
+ downconv = []
693
+ if padding_type == 'reflect':
694
+ downconv += [nn.ReflectionPad2d(1)]
695
+ elif padding_type == 'replicate':
696
+ downconv += [nn.ReplicationPad2d(1)]
697
+ elif padding_type == 'zero':
698
+ p = 1
699
+ else:
700
+ raise NotImplementedError(
701
+ 'padding [%s] is not implemented' % padding_type)
702
+
703
+ downconv += [nn.Conv2d(input_nc, inner_nc,
704
+ kernel_size=3, stride=2, padding=p)]
705
+ # downsample is different from upsample
706
+ downrelu = nn.LeakyReLU(0.2, True)
707
+ downnorm = norm_layer(inner_nc) if norm_layer is not None else None
708
+ uprelu = nl_layer()
709
+ uprelu2 = nl_layer()
710
+ uppad = nn.ReplicationPad2d(1)
711
+ upnorm = norm_layer(outer_nc) if norm_layer is not None else None
712
+ upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
713
+ self.noiseblock = ApplyNoise(outer_nc)
714
+ self.noise = noise
715
+
716
+ if outermost:
717
+ upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
718
+ upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
719
+ down = downconv
720
+ up = [uprelu] + upconv
721
+ if upnorm is not None:
722
+ up += [upnorm]
723
+ up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
724
+ model = down + [submodule] + up
725
+ elif innermost:
726
+ upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
727
+ upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
728
+ down = [downrelu] + downconv
729
+ up = [uprelu] + upconv
730
+ if upnorm is not None:
731
+ up += [upnorm]
732
+ up += [uprelu2, uppad, upconv2]
733
+ if upnorm2 is not None:
734
+ up += [upnorm2]
735
+ model = down + up
736
+ else:
737
+ upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
738
+ upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
739
+ down = [downrelu] + downconv
740
+ if downnorm is not None:
741
+ down += [downnorm]
742
+ up = [uprelu] + upconv
743
+ if upnorm is not None:
744
+ up += [upnorm]
745
+ up += [uprelu2, uppad, upconv2]
746
+ if upnorm2 is not None:
747
+ up += [upnorm2]
748
+
749
+ if use_dropout:
750
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
751
+ else:
752
+ model = down + [submodule] + up
753
+
754
+ self.model = nn.Sequential(*model)
755
+
756
+ def forward(self, x):
757
+ if self.outermost:
758
+ return self.model(x)
759
+ else:
760
+ x2 = self.model(x)
761
+ if self.noise:
762
+ x2 = self.noiseblock(x2, self.noise)
763
+ return torch.cat([x2, x], 1)
764
+
765
+ # Defines the submodule with skip connection.
766
+ # X -------------------identity---------------------- X
767
+ # |-- downsampling -- |submodule| -- upsampling --|
768
+ class UnetBlock_A(nn.Module):
769
+ def __init__(self, input_nc, outer_nc, inner_nc,
770
+ submodule=None, noise=None, outermost=False, innermost=False,
771
+ norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
772
+ super(UnetBlock_A, self).__init__()
773
+ self.outermost = outermost
774
+ p = 0
775
+ downconv = []
776
+ if padding_type == 'reflect':
777
+ downconv += [nn.ReflectionPad2d(1)]
778
+ elif padding_type == 'replicate':
779
+ downconv += [nn.ReplicationPad2d(1)]
780
+ elif padding_type == 'zero':
781
+ p = 1
782
+ else:
783
+ raise NotImplementedError(
784
+ 'padding [%s] is not implemented' % padding_type)
785
+
786
+ downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
787
+ kernel_size=3, stride=2, padding=p))]
788
+ # downsample is different from upsample
789
+ downrelu = nn.LeakyReLU(0.2, True)
790
+ downnorm = norm_layer(inner_nc) if norm_layer is not None else None
791
+ uprelu = nl_layer()
792
+ uprelu2 = nl_layer()
793
+ uppad = nn.ReplicationPad2d(1)
794
+ upnorm = norm_layer(outer_nc) if norm_layer is not None else None
795
+ upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
796
+ self.noiseblock = ApplyNoise(outer_nc)
797
+ self.noise = noise
798
+
799
+ if outermost:
800
+ upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
801
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
802
+ down = downconv
803
+ up = [uprelu] + upconv
804
+ if upnorm is not None:
805
+ up += [upnorm]
806
+ up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
807
+ model = down + [submodule] + up
808
+ elif innermost:
809
+ upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
810
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
811
+ down = [downrelu] + downconv
812
+ up = [uprelu] + upconv
813
+ if upnorm is not None:
814
+ up += [upnorm]
815
+ up += [uprelu2, uppad, upconv2]
816
+ if upnorm2 is not None:
817
+ up += [upnorm2]
818
+ model = down + up
819
+ else:
820
+ upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
821
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
822
+ down = [downrelu] + downconv
823
+ if downnorm is not None:
824
+ down += [downnorm]
825
+ up = [uprelu] + upconv
826
+ if upnorm is not None:
827
+ up += [upnorm]
828
+ up += [uprelu2, uppad, upconv2]
829
+ if upnorm2 is not None:
830
+ up += [upnorm2]
831
+
832
+ if use_dropout:
833
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
834
+ else:
835
+ model = down + [submodule] + up
836
+
837
+ self.model = nn.Sequential(*model)
838
+
839
+ def forward(self, x):
840
+ if self.outermost:
841
+ return self.model(x)
842
+ else:
843
+ x2 = self.model(x)
844
+ if self.noise:
845
+ x2 = self.noiseblock(x2, self.noise)
846
+ if x2.shape[-1]==x.shape[-1]:
847
+ return x2 + x
848
+ else:
849
+ x2 = F.interpolate(x2, x.shape[2:])
850
+ return x2 + x
851
+
852
+
853
+ class E_ResNet(nn.Module):
854
+ def __init__(self, input_nc=3, output_nc=1, ndf=64, n_blocks=4,
855
+ norm_layer=None, nl_layer=None, vaeLike=False):
856
+ super(E_ResNet, self).__init__()
857
+ self.vaeLike = vaeLike
858
+ max_ndf = 4
859
+ conv_layers = [
860
+ nn.Conv2d(input_nc, ndf, kernel_size=3, stride=2, padding=1, bias=True)]
861
+ for n in range(1, n_blocks):
862
+ input_ndf = ndf * min(max_ndf, n)
863
+ output_ndf = ndf * min(max_ndf, n + 1)
864
+ conv_layers += [BasicBlock(input_ndf,
865
+ output_ndf, norm_layer, nl_layer)]
866
+ conv_layers += [nl_layer(), nn.AdaptiveAvgPool2d(4)]
867
+ if vaeLike:
868
+ self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
869
+ self.fcVar = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
870
+ else:
871
+ self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
872
+ self.conv = nn.Sequential(*conv_layers)
873
+
874
+ def forward(self, x):
875
+ x_conv = self.conv(x)
876
+ conv_flat = x_conv.view(x.size(0), -1)
877
+ output = self.fc(conv_flat)
878
+ if self.vaeLike:
879
+ outputVar = self.fcVar(conv_flat)
880
+ return output, outputVar
881
+ else:
882
+ return output
883
+ return output
884
+
885
+
886
+ # Defines the Unet generator.
887
+ # |num_downs|: number of downsamplings in UNet. For example,
888
+ # if |num_downs| == 7, image of size 128x128 will become of size 1x1
889
+ # at the bottleneck
890
+ class G_Unet_add_all(nn.Module):
891
+ def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
892
+ norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False, upsample='basic'):
893
+ super(G_Unet_add_all, self).__init__()
894
+ self.nz = nz
895
+ self.mapping = G_mapping(self.nz, self.nz, 512, normalize_latents=False, lrmul=1)
896
+ self.truncation_psi = 0
897
+ self.truncation_cutoff = 0
898
+
899
+ # - 2 means we start from feature map with height and width equals 4.
900
+ # as this example, we get num_layers = 18.
901
+ num_layers = int(np.log2(512)) * 2 - 2
902
+ # Noise inputs.
903
+ self.noise_inputs = []
904
+ for layer_idx in range(num_layers):
905
+ res = layer_idx // 2 + 2
906
+ shape = [1, 1, 2 ** res, 2 ** res]
907
+ self.noise_inputs.append(torch.randn(*shape).to("cuda"))
908
+
909
+ # construct unet structure
910
+ unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=None, innermost=True,
911
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
912
+ unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
913
+ norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
914
+ for i in range(num_downs - 6):
915
+ unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
916
+ norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
917
+ unet_block = UnetBlock_with_z(ngf * 4, ngf * 4, ngf * 8, nz, submodule=unet_block,
918
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
919
+ unet_block = UnetBlock_with_z(ngf * 2, ngf * 2, ngf * 4, nz, submodule=unet_block,
920
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
921
+ unet_block = UnetBlock_with_z(ngf, ngf, ngf * 2, nz, submodule=unet_block,
922
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
923
+ unet_block = UnetBlock_with_z(input_nc, output_nc, ngf, nz, submodule=unet_block,
924
+ outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
925
+ self.model = unet_block
926
+
927
+ def forward(self, x, z):
928
+
929
+ dlatents1, num_layers = self.mapping(z)
930
+ dlatents1 = dlatents1.unsqueeze(1)
931
+ dlatents1 = dlatents1.expand(-1, int(num_layers), -1)
932
+
933
+ # Apply truncation trick.
934
+ if self.truncation_psi and self.truncation_cutoff:
935
+ coefs = np.ones([1, num_layers, 1], dtype=np.float32)
936
+ for i in range(num_layers):
937
+ if i < self.truncation_cutoff:
938
+ coefs[:, i, :] *= self.truncation_psi
939
+ """Linear interpolation.
940
+ a + (b - a) * t (a = 0)
941
+ reduce to
942
+ b * t
943
+ """
944
+ dlatents1 = dlatents1 * torch.Tensor(coefs).to(dlatents1.device)
945
+
946
+ return torch.tanh(self.model(x, dlatents1, self.noise_inputs))
947
+
948
+
949
+ class ApplyNoise(nn.Module):
950
+ def __init__(self, channels):
951
+ super().__init__()
952
+ self.channels = channels
953
+ self.weight = nn.Parameter(torch.randn(channels), requires_grad=True)
954
+ self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True)
955
+
956
+ def forward(self, x, noise):
957
+ W,_ = torch.split(self.weight.view(1, -1, 1, 1), self.channels // 2, dim=1)
958
+ B,_ = torch.split(self.bias.view(1, -1, 1, 1), self.channels // 2, dim=1)
959
+ Z = torch.zeros_like(W)
960
+ w = torch.cat([W,Z], dim=1).to(x.device)
961
+ b = torch.cat([B,Z], dim=1).to(x.device)
962
+ adds = w * torch.randn_like(x) + b
963
+ return x + adds.type_as(x)
964
+
965
+
966
+ class FC(nn.Module):
967
+ def __init__(self,
968
+ in_channels,
969
+ out_channels,
970
+ gain=2**(0.5),
971
+ use_wscale=False,
972
+ lrmul=1.0,
973
+ bias=True):
974
+ """
975
+ The complete conversion of Dense/FC/Linear Layer of original Tensorflow version.
976
+ """
977
+ super(FC, self).__init__()
978
+ he_std = gain * in_channels ** (-0.5) # He init
979
+ if use_wscale:
980
+ init_std = 1.0 / lrmul
981
+ self.w_lrmul = he_std * lrmul
982
+ else:
983
+ init_std = he_std / lrmul
984
+ self.w_lrmul = lrmul
985
+
986
+ self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels) * init_std)
987
+ if bias:
988
+ self.bias = torch.nn.Parameter(torch.zeros(out_channels))
989
+ self.b_lrmul = lrmul
990
+ else:
991
+ self.bias = None
992
+
993
+ def forward(self, x):
994
+ if self.bias is not None:
995
+ out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul)
996
+ else:
997
+ out = F.linear(x, self.weight * self.w_lrmul)
998
+ out = F.leaky_relu(out, 0.2, inplace=True)
999
+ return out
1000
+
1001
+
1002
+ class ApplyStyle(nn.Module):
1003
+ """
1004
+ @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
1005
+ """
1006
+ def __init__(self, latent_size, channels, use_wscale, nl_layer):
1007
+ super(ApplyStyle, self).__init__()
1008
+ modules = [nn.Linear(latent_size, channels*2)]
1009
+ if nl_layer:
1010
+ modules += [nl_layer()]
1011
+ self.linear = nn.Sequential(*modules)
1012
+
1013
+ def forward(self, x, latent):
1014
+ style = self.linear(latent) # style => [batch_size, n_channels*2]
1015
+ shape = [-1, 2, x.size(1), 1, 1]
1016
+ style = style.view(shape) # [batch_size, 2, n_channels, ...]
1017
+ x = x * (style[:, 0] + 1.) + style[:, 1]
1018
+ return x
1019
+
1020
+ class PixelNorm(nn.Module):
1021
+ def __init__(self, epsilon=1e-8):
1022
+ """
1023
+ @notice: avoid in-place ops.
1024
+ https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
1025
+ """
1026
+ super(PixelNorm, self).__init__()
1027
+ self.epsilon = epsilon
1028
+
1029
+ def forward(self, x):
1030
+ tmp = torch.mul(x, x) # or x ** 2
1031
+ tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon)
1032
+
1033
+ return x * tmp1
1034
+
1035
+
1036
+ class InstanceNorm(nn.Module):
1037
+ def __init__(self, epsilon=1e-8):
1038
+ """
1039
+ @notice: avoid in-place ops.
1040
+ https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
1041
+ """
1042
+ super(InstanceNorm, self).__init__()
1043
+ self.epsilon = epsilon
1044
+
1045
+ def forward(self, x):
1046
+ x = x - torch.mean(x, (2, 3), True)
1047
+ tmp = torch.mul(x, x) # or x ** 2
1048
+ tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
1049
+ return x * tmp
1050
+
1051
+
1052
+ class LayerEpilogue(nn.Module):
1053
+ def __init__(self, channels, dlatent_size, use_wscale, use_noise,
1054
+ use_pixel_norm, use_instance_norm, use_styles, nl_layer=None):
1055
+ super(LayerEpilogue, self).__init__()
1056
+ self.use_noise = use_noise
1057
+ if use_noise:
1058
+ self.noise = ApplyNoise(channels)
1059
+ self.act = nn.LeakyReLU(negative_slope=0.2)
1060
+
1061
+ if use_pixel_norm:
1062
+ self.pixel_norm = PixelNorm()
1063
+ else:
1064
+ self.pixel_norm = None
1065
+
1066
+ if use_instance_norm:
1067
+ self.instance_norm = InstanceNorm()
1068
+ else:
1069
+ self.instance_norm = None
1070
+
1071
+ if use_styles:
1072
+ self.style_mod = ApplyStyle(dlatent_size, channels, use_wscale=use_wscale, nl_layer=nl_layer)
1073
+ else:
1074
+ self.style_mod = None
1075
+
1076
+ def forward(self, x, noise, dlatents_in_slice=None):
1077
+ # if noise is not None:
1078
+ if self.use_noise:
1079
+ x = self.noise(x, noise)
1080
+ x = self.act(x)
1081
+ if self.pixel_norm is not None:
1082
+ x = self.pixel_norm(x)
1083
+ if self.instance_norm is not None:
1084
+ x = self.instance_norm(x)
1085
+ if self.style_mod is not None:
1086
+ x = self.style_mod(x, dlatents_in_slice)
1087
+
1088
+ return x
1089
+
1090
+ class G_mapping(nn.Module):
1091
+ def __init__(self,
1092
+ mapping_fmaps=512,
1093
+ dlatent_size=512,
1094
+ resolution=512,
1095
+ normalize_latents=True, # Normalize latent vectors (Z) before feeding them to the mapping layers?
1096
+ use_wscale=True, # Enable equalized learning rate?
1097
+ lrmul=0.01, # Learning rate multiplier for the mapping layers.
1098
+ gain=2**(0.5), # original gain in tensorflow.
1099
+ nl_layer=None
1100
+ ):
1101
+ super(G_mapping, self).__init__()
1102
+ self.mapping_fmaps = mapping_fmaps
1103
+ func = [
1104
+ nn.Linear(self.mapping_fmaps, dlatent_size)
1105
+ ]
1106
+ if nl_layer:
1107
+ func += [nl_layer()]
1108
+
1109
+ for j in range(0,4):
1110
+ func += [
1111
+ nn.Linear(dlatent_size, dlatent_size)
1112
+ ]
1113
+ if nl_layer:
1114
+ func += [nl_layer()]
1115
+
1116
+ self.func = nn.Sequential(*func)
1117
+ #FC(self.mapping_fmaps, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
1118
+ #FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
1119
+
1120
+ self.normalize_latents = normalize_latents
1121
+ self.resolution_log2 = int(np.log2(resolution))
1122
+ self.num_layers = self.resolution_log2 * 2 - 2
1123
+ self.pixel_norm = PixelNorm()
1124
+ # - 2 means we start from feature map with height and width equals 4.
1125
+ # as this example, we get num_layers = 18.
1126
+
1127
+ def forward(self, x):
1128
+ if self.normalize_latents:
1129
+ x = self.pixel_norm(x)
1130
+ out = self.func(x)
1131
+ return out, self.num_layers
1132
+
1133
+ class UnetBlock_with_z(nn.Module):
1134
+ def __init__(self, input_nc, outer_nc, inner_nc, nz=0,
1135
+ submodule=None, outermost=False, innermost=False,
1136
+ norm_layer=None, nl_layer=None, use_dropout=False,
1137
+ upsample='basic', padding_type='replicate'):
1138
+ super(UnetBlock_with_z, self).__init__()
1139
+ p = 0
1140
+ downconv = []
1141
+ if padding_type == 'reflect':
1142
+ downconv += [nn.ReflectionPad2d(1)]
1143
+ elif padding_type == 'replicate':
1144
+ downconv += [nn.ReplicationPad2d(1)]
1145
+ elif padding_type == 'zero':
1146
+ p = 1
1147
+ else:
1148
+ raise NotImplementedError(
1149
+ 'padding [%s] is not implemented' % padding_type)
1150
+
1151
+ self.outermost = outermost
1152
+ self.innermost = innermost
1153
+ self.nz = nz
1154
+
1155
+ # input_nc = input_nc + nz
1156
+ downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
1157
+ kernel_size=3, stride=2, padding=p))]
1158
+ # downsample is different from upsample
1159
+ downrelu = nn.LeakyReLU(0.2, True)
1160
+ downnorm = norm_layer(inner_nc) if norm_layer is not None else None
1161
+ uprelu = nl_layer()
1162
+ uprelu2 = nl_layer()
1163
+ uppad = nn.ReplicationPad2d(1)
1164
+ upnorm = norm_layer(outer_nc) if norm_layer is not None else None
1165
+ upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
1166
+
1167
+ use_styles=False
1168
+ uprelu = nl_layer()
1169
+ if self.nz >0:
1170
+ use_styles=True
1171
+
1172
+ if outermost:
1173
+ self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
1174
+ use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
1175
+ upconv = upsampleLayer(
1176
+ inner_nc , outer_nc, upsample=upsample, padding_type=padding_type)
1177
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
1178
+ down = downconv
1179
+ up = [uprelu] + upconv
1180
+ if upnorm is not None:
1181
+ up += [upnorm]
1182
+ up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
1183
+ elif innermost:
1184
+ self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=True,
1185
+ use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
1186
+ upconv = upsampleLayer(
1187
+ inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
1188
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
1189
+ down = [downrelu] + downconv
1190
+ up = [uprelu] + upconv
1191
+ if norm_layer is not None:
1192
+ up += [norm_layer(outer_nc)]
1193
+ up += [uprelu2, uppad, upconv2]
1194
+ if upnorm2 is not None:
1195
+ up += [upnorm2]
1196
+ else:
1197
+ self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
1198
+ use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
1199
+ upconv = upsampleLayer(
1200
+ inner_nc , outer_nc, upsample=upsample, padding_type=padding_type)
1201
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
1202
+ down = [downrelu] + downconv
1203
+ if norm_layer is not None:
1204
+ down += [norm_layer(inner_nc)]
1205
+ up = [uprelu] + upconv
1206
+
1207
+ if norm_layer is not None:
1208
+ up += [norm_layer(outer_nc)]
1209
+ up += [uprelu2, uppad, upconv2]
1210
+ if upnorm2 is not None:
1211
+ up += [upnorm2]
1212
+
1213
+ if use_dropout:
1214
+ up += [nn.Dropout(0.5)]
1215
+ self.down = nn.Sequential(*down)
1216
+ self.submodule = submodule
1217
+ self.up = nn.Sequential(*up)
1218
+
1219
+
1220
+ def forward(self, x, z, noise):
1221
+ if self.outermost:
1222
+ x1 = self.down(x)
1223
+ x2 = self.submodule(x1, z[:,2:], noise[2:])
1224
+ return self.up(x2)
1225
+
1226
+ elif self.innermost:
1227
+ x1 = self.down(x)
1228
+ x_and_z = self.adaIn(x1, noise[0], z[:,0])
1229
+ x2 = self.up(x_and_z)
1230
+ x2 = F.interpolate(x2, x.shape[2:])
1231
+ return x2 + x
1232
+
1233
+ else:
1234
+ x1 = self.down(x)
1235
+ x2 = self.submodule(x1, z[:,2:], noise[2:])
1236
+ x_and_z = self.adaIn(x2, noise[0], z[:,0])
1237
+ return self.up(x_and_z) + x
1238
+
1239
+
1240
+ class E_NLayers(nn.Module):
1241
+ def __init__(self, input_nc, output_nc=1, ndf=64, n_layers=4,
1242
+ norm_layer=None, nl_layer=None, vaeLike=False):
1243
+ super(E_NLayers, self).__init__()
1244
+ self.vaeLike = vaeLike
1245
+
1246
+ kw, padw = 3, 1
1247
+ sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
1248
+ stride=2, padding=padw, padding_mode='replicate')), nl_layer()]
1249
+
1250
+ nf_mult = 1
1251
+ nf_mult_prev = 1
1252
+ for n in range(1, n_layers):
1253
+ nf_mult_prev = nf_mult
1254
+ nf_mult = min(2**n, 8)
1255
+ sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
1256
+ kernel_size=kw, stride=2, padding=padw, padding_mode='replicate'))]
1257
+ if norm_layer is not None:
1258
+ sequence += [norm_layer(ndf * nf_mult)]
1259
+ sequence += [nl_layer()]
1260
+ sequence += [nn.AdaptiveAvgPool2d(4)]
1261
+ self.conv = nn.Sequential(*sequence)
1262
+ self.fc = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
1263
+ if vaeLike:
1264
+ self.fcVar = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
1265
+
1266
+ def forward(self, x):
1267
+ x_conv = self.conv(x)
1268
+ conv_flat = x_conv.view(x.size(0), -1)
1269
+ output = self.fc(conv_flat)
1270
+ if self.vaeLike:
1271
+ outputVar = self.fcVar(conv_flat)
1272
+ return output, outputVar
1273
+ return output
1274
+
1275
+ class BasicBlock(nn.Module):
1276
+ def __init__(self, inplanes, outplanes):
1277
+ super(BasicBlock, self).__init__()
1278
+ layers = []
1279
+ norm_layer=get_norm_layer(norm_type='layer') #functools.partial(LayerNorm)
1280
+ # norm_layer = None
1281
+ nl_layer=nn.ReLU()
1282
+ if norm_layer is not None:
1283
+ layers += [norm_layer(inplanes)]
1284
+ layers += [nl_layer]
1285
+ layers += [nn.ReplicationPad2d(1),
1286
+ nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=1,
1287
+ padding=0, bias=True)]
1288
+ self.conv = nn.Sequential(*layers)
1289
+
1290
+ def forward(self, x):
1291
+ return self.conv(x)
1292
+
1293
+
1294
+ def define_SVAE(inc=96, outc=3, outplanes=64, blocks=1, netVAE='SVAE', model_name='', load_ext=True, save_dir='',
1295
+ init_type="normal", init_gain=0.02, gpu_ids=[]):
1296
+ if netVAE == 'SVAE':
1297
+ net = ScreenVAE(inc=inc, outc=outc, outplanes=outplanes, blocks=blocks, save_dir=save_dir,
1298
+ init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
1299
+ else:
1300
+ raise NotImplementedError('Encoder model name [%s] is not recognized' % net)
1301
+ init_net(net, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
1302
+ net.load_networks('latest')
1303
+ return net
1304
+
1305
+
1306
+ class ScreenVAE(nn.Module):
1307
+ def __init__(self,inc=1,outc=4, outplanes=64, downs=5, blocks=2,load_ext=True, save_dir='',init_type="normal", init_gain=0.02, gpu_ids=[]):
1308
+ super(ScreenVAE, self).__init__()
1309
+ self.inc = inc
1310
+ self.outc = outc
1311
+ self.save_dir = save_dir
1312
+ norm_layer=functools.partial(LayerNormWarpper)
1313
+ nl_layer=nn.LeakyReLU
1314
+
1315
+ self.model_names=['enc','dec']
1316
+ self.enc=define_C(inc+1, outc*2, 0, 24, netC='resnet_6blocks',
1317
+ norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
1318
+ gpu_ids=gpu_ids, upsample='bilinear')
1319
+ self.dec=define_G(outc, inc, 0, 48, netG='unet_128_G',
1320
+ norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
1321
+ gpu_ids=gpu_ids, where_add='input', upsample='bilinear', use_noise=True)
1322
+
1323
+ for param in self.parameters():
1324
+ param.requires_grad = False
1325
+
1326
+ def load_networks(self, epoch):
1327
+ """Load all the networks from the disk.
1328
+
1329
+ Parameters:
1330
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
1331
+ """
1332
+ for name in self.model_names:
1333
+ if isinstance(name, str):
1334
+ load_filename = '%s_net_%s.pth' % (epoch, name)
1335
+ load_path = os.path.join(self.save_dir, load_filename)
1336
+ net = getattr(self, name)
1337
+ if isinstance(net, torch.nn.DataParallel):
1338
+ net = net.module
1339
+ print('loading the model from %s' % load_path)
1340
+ state_dict = torch.load(
1341
+ load_path, map_location=lambda storage, loc: storage.cuda())
1342
+ if hasattr(state_dict, '_metadata'):
1343
+ del state_dict._metadata
1344
+
1345
+ net.load_state_dict(state_dict)
1346
+ del state_dict
1347
+
1348
+ def npad(self, im, pad=128):
1349
+ h,w = im.shape[-2:]
1350
+ hp = h //pad*pad+pad
1351
+ wp = w //pad*pad+pad
1352
+ return F.pad(im, (0, wp-w, 0, hp-h), mode='replicate')
1353
+
1354
+ def forward(self, x, line=None, img_input=True, output_screen_only=True):
1355
+ if img_input:
1356
+ if line is None:
1357
+ line = torch.ones_like(x)
1358
+ else:
1359
+ line = torch.sign(line)
1360
+ x = torch.clamp(x + (1-line),-1,1)
1361
+ h,w = x.shape[-2:]
1362
+ input = torch.cat([x, line], 1)
1363
+ input = self.npad(input)
1364
+ inter = self.enc(input)[:,:,:h,:w]
1365
+ scr, logvar = torch.split(inter, (self.outc, self.outc), dim=1)
1366
+ if output_screen_only:
1367
+ return scr
1368
+ recons = self.dec(scr)
1369
+ return recons, scr, logvar
1370
+ else:
1371
+ h,w = x.shape[-2:]
1372
+ x = self.npad(x)
1373
+ recons = self.dec(x)[:,:,:h,:w]
1374
+ recons = (recons+1)*(line+1)/2-1
1375
+ return torch.clamp(recons,-1,1)
BidirectionalTranslation/options/base_options.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from util import util
4
+ import torch
5
+ import models
6
+ import data
7
+
8
+ class BaseOptions():
9
+ def __init__(self):
10
+ self.initialized = False
11
+
12
+ def initialize(self, parser):
13
+ """Initialize options used during both training and test time."""
14
+ # Basic options
15
+ parser.add_argument('--dataroot', required=False, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
16
+ parser.add_argument('--batch_size', type=int, default=2, help='input batch size')
17
+ parser.add_argument('--load_size', type=int, default=512, help='scale images to this size') # Modified default
18
+ parser.add_argument('--crop_size', type=int, default=1024, help='then crop to this size') # Modified default
19
+ parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels') # Modified default
20
+ parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') # Modified default
21
+ parser.add_argument('--nz', type=int, default=64, help='#latent vector') # Modified default
22
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2, -1 for CPU mode')
23
+ parser.add_argument('--name', type=str, default='color2manga_cycle_ganstft', help='name of the experiment') # Modified default
24
+ parser.add_argument('--preprocess', type=str, default='none', help='not implemented') # Modified default
25
+ parser.add_argument('--dataset_mode', type=str, default='aligned', help='aligned,single')
26
+ parser.add_argument('--model', type=str, default='cycle_ganstft', help='chooses which model to use')
27
+ parser.add_argument('--direction', type=str, default='BtoA', help='AtoB or BtoA') # Modified default
28
+ parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
29
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
30
+ parser.add_argument('--local_rank', default=0, type=int, help='# threads for loading data')
31
+ parser.add_argument('--checkpoints_dir', type=str, default=self.model_global_path+'/ScreenStyle/color2manga/', help='models are saved here') # Modified default
32
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
33
+ parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
34
+ parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset.')
35
+ parser.add_argument('--no_flip', action='store_false', help='if specified, do not flip the images for data argumentation') # Modified default
36
+
37
+ # Model parameters
38
+ parser.add_argument('--level', type=int, default=0, help='level to train')
39
+ parser.add_argument('--num_Ds', type=int, default=2, help='number of Discriminators')
40
+ parser.add_argument('--netD', type=str, default='basic_256_multi', help='selects model to use for netD')
41
+ parser.add_argument('--netD2', type=str, default='basic_256_multi', help='selects model to use for netD2')
42
+ parser.add_argument('--netG', type=str, default='unet_256', help='selects model to use for netG')
43
+ parser.add_argument('--netC', type=str, default='unet_128', help='selects model to use for netC')
44
+ parser.add_argument('--netE', type=str, default='conv_256', help='selects model to use for netE')
45
+ parser.add_argument('--nef', type=int, default=48, help='# of encoder filters in the first conv layer') # Modified default
46
+ parser.add_argument('--ngf', type=int, default=48, help='# of gen filters in the last conv layer') # Modified default
47
+ parser.add_argument('--ndf', type=int, default=32, help='# of discrim filters in the first conv layer') # Modified default
48
+ parser.add_argument('--norm', type=str, default='layer', help='instance normalization or batch normalization')
49
+ parser.add_argument('--upsample', type=str, default='bilinear', help='basic | bilinear') # Modified default
50
+ parser.add_argument('--nl', type=str, default='prelu', help='non-linearity activation: relu | lrelu | elu')
51
+ parser.add_argument('--no_encode', action='store_true', help='if specified, print more debugging information')
52
+ parser.add_argument('--color2screen', action='store_true', help='continue training: load the latest model including RGB model') # Modified default
53
+
54
+ # Extra parameters
55
+ parser.add_argument('--where_add', type=str, default='all', help='input|all|middle; where to add z in the network G')
56
+ parser.add_argument('--conditional_D', action='store_true', help='if use conditional GAN for D')
57
+ parser.add_argument('--init_type', type=str, default='kaiming', help='network initialization [normal | xavier | kaiming | orthogonal]')
58
+ parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
59
+ parser.add_argument('--center_crop', action='store_true', help='if apply for center cropping for the test') # Modified default
60
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
61
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
62
+ parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
63
+
64
+ # Special tasks
65
+ self.initialized = True
66
+ return parser
67
+
68
+ def gather_options(self):
69
+ """Initialize our parser with basic options (only once)."""
70
+ if not self.initialized:
71
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
72
+ parser = self.initialize(parser)
73
+
74
+ # Get the basic options
75
+ opt, _ = parser.parse_known_args()
76
+
77
+ # Modify model-related parser options
78
+ model_name = opt.model
79
+ model_option_setter = models.get_option_setter(model_name)
80
+ parser = model_option_setter(parser, self.isTrain)
81
+ opt, _ = parser.parse_known_args() # Parse again with new defaults
82
+
83
+ # Modify dataset-related parser options
84
+ dataset_name = opt.dataset_mode
85
+ dataset_option_setter = data.get_option_setter(dataset_name)
86
+ parser = dataset_option_setter(parser, self.isTrain)
87
+
88
+ # Save and return the parser
89
+ self.parser = parser
90
+ return parser.parse_args()
91
+
92
+ def print_options(self, opt):
93
+ """Print and save options."""
94
+ message = ''
95
+ message += '----------------- Options ---------------\n'
96
+ for k, v in sorted(vars(opt).items()):
97
+ comment = ''
98
+ default = self.parser.get_default(k)
99
+ if v != default:
100
+ comment = '\t[default: %s]' % str(default)
101
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
102
+ message += '----------------- End -------------------'
103
+ print(message)
104
+
105
+ # Save to the disk
106
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
107
+ if not os.path.exists(expr_dir):
108
+ try:
109
+ util.mkdirs(expr_dir)
110
+ except:
111
+ pass
112
+ file_name = os.path.join(expr_dir, 'opt.txt')
113
+ with open(file_name, 'wt') as opt_file:
114
+ opt_file.write(message)
115
+ opt_file.write('\n')
116
+
117
+ def parse(self, model_global_path):
118
+ """Parse options, create checkpoints directory suffix, and set up gpu device."""
119
+ self.model_global_path = model_global_path
120
+ opt = self.gather_options()
121
+ opt.isTrain = self.isTrain # train or test
122
+
123
+
124
+ # Process opt.suffix
125
+ if opt.suffix:
126
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
127
+ opt.name = opt.name + suffix
128
+
129
+ self.print_options(opt)
130
+
131
+ # Set gpu ids
132
+ str_ids = opt.gpu_ids.split(',')
133
+ opt.gpu_ids = []
134
+ for str_id in str_ids:
135
+ id = int(str_id)
136
+ if id >= 0:
137
+ opt.gpu_ids.append(id)
138
+ if len(opt.gpu_ids) > 0:
139
+ torch.cuda.set_device(opt.gpu_ids[0])
140
+
141
+ self.opt = opt
142
+ return self.opt
BidirectionalTranslation/options/test_options.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_options import BaseOptions
2
+
3
+ class TestOptions(BaseOptions):
4
+ def initialize(self, parser):
5
+ BaseOptions.initialize(self, parser)
6
+
7
+
8
+ # Additional test-specific arguments
9
+ parser.add_argument('--results_dir', type=str, default='../results/', help='saves results here.')
10
+ parser.add_argument('--phase', type=str, default='val', help='train, val, test, etc')
11
+ parser.add_argument('--num_test', type=int, default=30, help='how many test images to run')
12
+ parser.add_argument('--n_samples', type=int, default=1, help='#samples')
13
+ parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio for the results')
14
+ parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
15
+ parser.add_argument('--folder', type=str, default='intra', help='saves results here.')
16
+ parser.add_argument('--sync', action='store_true', help='use the same latent code for different input images')
17
+
18
+ self.isTrain = False
19
+ return parser
BidirectionalTranslation/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch~=1.6.0
2
+ torchvision~=0.4.0
3
+ tensorboardx~=1.9
4
+ scipy==1.1
5
+ dominate~=2.3.1
6
+ scikit-image~=0.16.2
7
+ opencv-python~=3.4.2
8
+ lpips
BidirectionalTranslation/scripts/test_western2manga.sh ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -ex
2
+ # models
3
+ RESULTS_DIR='./results/test/western2manga'
4
+
5
+ # dataset
6
+ CLASS='color2manga'
7
+ MODEL='cycle_ganstft'
8
+ DIRECTION='BtoA' # from domain A to domain B
9
+ PREPROCESS='none'
10
+ LOAD_SIZE=512 # scale images to this size
11
+ CROP_SIZE=1024 # then crop to this size
12
+ INPUT_NC=1 # number of channels in the input image
13
+ OUTPUT_NC=3 # number of channels in the input image
14
+ NGF=48
15
+ NEF=48
16
+ NDF=32
17
+ NZ=64
18
+
19
+ # misc
20
+ GPU_ID=0 # gpu id
21
+ NUM_TEST=30 # number of input images duirng test
22
+ NUM_SAMPLES=1 # number of samples per input images
23
+ NAME=${CLASS}_${MODEL}
24
+
25
+ # command
26
+ CUDA_VISIBLE_DEVICES=${GPU_ID} \
27
+ python3 ./test.py \
28
+ --dataroot ./datasets/${CLASS} \
29
+ --results_dir ${RESULTS_DIR} \
30
+ --checkpoints_dir ./checkpoints/${CLASS}/ \
31
+ --name ${NAME} \
32
+ --model ${MODEL} \
33
+ --direction ${DIRECTION} \
34
+ --preprocess ${PREPROCESS} \
35
+ --load_size ${LOAD_SIZE} \
36
+ --crop_size ${CROP_SIZE} \
37
+ --input_nc ${INPUT_NC} \
38
+ --output_nc ${OUTPUT_NC} \
39
+ --nz ${NZ} \
40
+ --netE conv_256 \
41
+ --num_test ${NUM_TEST} \
42
+ --n_samples ${NUM_SAMPLES} \
43
+ --upsample bilinear \
44
+ --ngf ${NGF} \
45
+ --nef ${NEF} \
46
+ --ndf ${NDF} \
47
+ --center_crop \
48
+ --color2screen \
49
+ --no_flip
BidirectionalTranslation/test.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from options.test_options import TestOptions
3
+ from data import create_dataset
4
+ from models import create_model
5
+ from util.visualizer import save_images
6
+ from itertools import islice
7
+ from util import html
8
+ import cv2
9
+
10
+ seed = 10
11
+ import torch
12
+ import numpy as np
13
+ torch.manual_seed(seed)
14
+ torch.cuda.manual_seed(seed)
15
+ np.random.seed(seed)
16
+
17
+ # options
18
+ opt = TestOptions().parse()
19
+ opt.num_threads = 1 # test code only supports num_threads=1
20
+ opt.batch_size = 1 # test code only supports batch_size=1
21
+ opt.serial_batches = True # no shuffle
22
+
23
+ model = create_model(opt)
24
+ model.setup(opt)
25
+ model.eval()
26
+ print('Loading model %s' % opt.model)
27
+
28
+ testdata = ['manga_paper']
29
+ # fake_sty = model.get_z_random(1, 64, truncation=True)
30
+
31
+ opt.dataset_mode = 'singleSr'
32
+ for folder in testdata:
33
+ opt.folder = folder
34
+ # create dataset
35
+ dataset = create_dataset(opt)
36
+ web_dir = os.path.join(opt.results_dir, opt.folder + '_Sr2Co')
37
+ webpage = html.HTML(web_dir, 'Training = %s, Phase = %s, Class =%s' % (opt.name, opt.phase, opt.name))
38
+ # fake_sty = model.get_z_random(1, 64, truncation=True)
39
+ for i, data in enumerate(islice(dataset, opt.num_test)):
40
+ h = data['h']
41
+ w = data['w']
42
+ model.set_input(data)
43
+ fake_sty = model.get_z_random(1, 64, truncation=True, tvalue=1.25)
44
+ fake_B, SCR, line = model.forward(AtoB=False, sty=fake_sty)
45
+ images=[fake_B[:,:,:h,:w]]
46
+ names=['color']
47
+
48
+ img_path = 'input_%3.3d' % i
49
+ save_images(webpage, images, names, img_path, aspect_ratio=opt.aspect_ratio, width=opt.crop_size)
50
+ webpage.save()
51
+
52
+ testdata = ['western_paper']
53
+
54
+ opt.dataset_mode = 'singleCo'
55
+ for folder in testdata:
56
+ opt.folder = folder
57
+ # create dataset
58
+ dataset = create_dataset(opt)
59
+ web_dir = os.path.join(opt.results_dir, opt.folder + '_Sr2Co')
60
+ webpage = html.HTML(web_dir, 'Training = %s, Phase = %s, Class =%s' % (opt.name, opt.phase, opt.name))
61
+ for i, data in enumerate(islice(dataset, opt.num_test)):
62
+ h = data['h']
63
+ w = data['w']
64
+ model.set_input(data)
65
+ fake_B, fake_B2, SCR = model.forward(AtoB=True)
66
+ images=[fake_B2[:,:,:h,:w]]
67
+ names=['manga']
68
+
69
+ img_path = 'input_%3.3d' % i
70
+ save_images(webpage, images, names, img_path, aspect_ratio=opt.aspect_ratio, width=opt.crop_size)
71
+ webpage.save()
BidirectionalTranslation/util/html.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dominate
2
+ from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3
+ import os
4
+
5
+
6
+ class HTML:
7
+ """This HTML class allows us to save images and write texts into a single HTML file.
8
+
9
+ It consists of functions such as <add_header> (add a text header to the HTML file),
10
+ <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12
+ """
13
+
14
+ def __init__(self, web_dir, title, refresh=0):
15
+ """Initialize the HTML classes
16
+
17
+ Parameters:
18
+ web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
19
+ title (str) -- the webpage name
20
+ reflect (int) -- how often the website refresh itself; if 0; no refreshing
21
+ """
22
+ self.title = title
23
+ self.web_dir = web_dir
24
+ self.img_dir = os.path.join(self.web_dir, 'images')
25
+ if not os.path.exists(self.web_dir):
26
+ os.makedirs(self.web_dir)
27
+ if not os.path.exists(self.img_dir):
28
+ os.makedirs(self.img_dir)
29
+
30
+ self.doc = dominate.document(title=title)
31
+ if refresh > 0:
32
+ with self.doc.head:
33
+ meta(http_equiv="refresh", content=str(refresh))
34
+
35
+ def get_image_dir(self):
36
+ """Return the directory that stores images"""
37
+ return self.img_dir
38
+
39
+ def add_header(self, text):
40
+ """Insert a header to the HTML file
41
+
42
+ Parameters:
43
+ text (str) -- the header text
44
+ """
45
+ with self.doc:
46
+ h3(text)
47
+
48
+ def add_images(self, ims, txts, links, width=400):
49
+ """add images to the HTML file
50
+
51
+ Parameters:
52
+ ims (str list) -- a list of image paths
53
+ txts (str list) -- a list of image names shown on the website
54
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55
+ """
56
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57
+ self.doc.add(self.t)
58
+ with self.t:
59
+ with tr():
60
+ for im, txt, link in zip(ims, txts, links):
61
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
62
+ with p():
63
+ with a(href=os.path.join('images', link)):
64
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
65
+ br()
66
+ p(txt)
67
+
68
+ def save(self):
69
+ """save the current content to the HMTL file"""
70
+ html_file = '%s/index.html' % self.web_dir
71
+ f = open(html_file, 'wt')
72
+ f.write(self.doc.render())
73
+ f.close()
74
+
75
+
76
+ if __name__ == '__main__': # we show an example usage here.
77
+ html = HTML('web/', 'test_html')
78
+ html.add_header('hello world')
79
+
80
+ ims, txts, links = [], [], []
81
+ for n in range(4):
82
+ ims.append('image_%d.png' % n)
83
+ txts.append('text_%d' % n)
84
+ links.append('image_%d.png' % n)
85
+ html.add_images(ims, txts, links)
86
+ html.save()
BidirectionalTranslation/util/util.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import os
6
+ import pickle
7
+
8
+
9
+ def tensor2im(input_image, imtype=np.uint8):
10
+ """"Convert a Tensor array into a numpy image array.
11
+ Parameters:
12
+ input_image (tensor) -- the input image tensor array
13
+ imtype (type) -- the desired type of the converted numpy array
14
+ """
15
+ if not isinstance(input_image, np.ndarray):
16
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
17
+ image_tensor = input_image.data
18
+ else:
19
+ return input_image
20
+ image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
21
+ if image_numpy.shape[0] == 1: # grayscale to RGB
22
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
23
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling
24
+ else: # if it is a numpy array, do nothing
25
+ image_numpy = input_image
26
+ return image_numpy.astype(imtype)
27
+
28
+
29
+ def tensor2vec(vector_tensor):
30
+ numpy_vec = vector_tensor.data.cpu().numpy()
31
+ if numpy_vec.ndim == 4:
32
+ return numpy_vec[:, :, 0, 0]
33
+ else:
34
+ return numpy_vec
35
+
36
+
37
+ def pickle_load(file_name):
38
+ data = None
39
+ with open(file_name, 'rb') as f:
40
+ data = pickle.load(f)
41
+ return data
42
+
43
+
44
+ def pickle_save(file_name, data):
45
+ with open(file_name, 'wb') as f:
46
+ pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
47
+
48
+
49
+ def diagnose_network(net, name='network'):
50
+ """Calculate and print the mean of average absolute(gradients)
51
+ Parameters:
52
+ net (torch network) -- Torch network
53
+ name (str) -- the name of the network
54
+ """
55
+ mean = 0.0
56
+ count = 0
57
+ for param in net.parameters():
58
+ if param.grad is not None:
59
+ mean += torch.mean(torch.abs(param.grad.data))
60
+ count += 1
61
+ if count > 0:
62
+ mean = mean / count
63
+ print(name)
64
+ print(mean)
65
+
66
+
67
+ def interp_z(z0, z1, num_frames, interp_mode='linear'):
68
+ zs = []
69
+ if interp_mode == 'linear':
70
+ for n in range(num_frames):
71
+ ratio = n / float(num_frames - 1)
72
+ z_t = (1 - ratio) * z0 + ratio * z1
73
+ zs.append(z_t[np.newaxis, :])
74
+ zs = np.concatenate(zs, axis=0).astype(np.float32)
75
+
76
+ if interp_mode == 'slerp':
77
+ z0_n = z0 / (np.linalg.norm(z0) + 1e-10)
78
+ z1_n = z1 / (np.linalg.norm(z1) + 1e-10)
79
+ omega = np.arccos(np.dot(z0_n, z1_n))
80
+ sin_omega = np.sin(omega)
81
+ if sin_omega < 1e-10 and sin_omega > -1e-10:
82
+ zs = interp_z(z0, z1, num_frames, interp_mode='linear')
83
+ else:
84
+ for n in range(num_frames):
85
+ ratio = n / float(num_frames - 1)
86
+ z_t = np.sin((1 - ratio) * omega) / sin_omega * z0 + np.sin(ratio * omega) / sin_omega * z1
87
+ zs.append(z_t[np.newaxis, :])
88
+ zs = np.concatenate(zs, axis=0).astype(np.float32)
89
+
90
+ return zs
91
+
92
+
93
+ def save_image(image_numpy, image_path):
94
+ """Save a numpy image to the disk
95
+ Parameters:
96
+ image_numpy (numpy array) -- input numpy array
97
+ image_path (str) -- the path of the image
98
+ """
99
+ image_pil = Image.fromarray(image_numpy)
100
+ image_pil.save(image_path)
101
+
102
+
103
+ def print_numpy(x, val=True, shp=False):
104
+ """Print the mean, min, max, median, std, and size of a numpy array
105
+ Parameters:
106
+ val (bool) -- if print the values of the numpy array
107
+ shp (bool) -- if print the shape of the numpy array
108
+ """
109
+ x = x.astype(np.float64)
110
+ if shp:
111
+ print('shape,', x.shape)
112
+ if val:
113
+ x = x.flatten()
114
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
115
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
116
+
117
+
118
+ def mkdirs(paths):
119
+ """create empty directories if they don't exist
120
+ Parameters:
121
+ paths (str list) -- a list of directory paths
122
+ """
123
+ if isinstance(paths, list) and not isinstance(paths, str):
124
+ for path in paths:
125
+ mkdir(path)
126
+ else:
127
+ mkdir(paths)
128
+
129
+
130
+ def mkdir(path):
131
+ """create a single empty directory if it didn't exist
132
+ Parameters:
133
+ path (str) -- a single directory path
134
+ """
135
+ if not os.path.exists(path):
136
+ os.makedirs(path, exist_ok=True)
BidirectionalTranslation/util/visualizer.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import sys
4
+ import ntpath
5
+ import time
6
+ from . import util
7
+ from . import html
8
+ from subprocess import Popen, PIPE
9
+ import cv2
10
+
11
+
12
+ # if sys.version_info[0] == 2:
13
+ # VisdomExceptionBase = Exception
14
+ # else:
15
+ # VisdomExceptionBase = ConnectionError
16
+
17
+
18
+ def save_images(webpage, images, names, image_path, aspect_ratio=1.0, width=256):
19
+ """Save images to the disk.
20
+ Parameters:
21
+ webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
22
+ images (numpy array list) -- a list of numpy array that stores images
23
+ names (str list) -- a str list stores the names of the images above
24
+ image_path (str) -- the string is used to create image paths
25
+ aspect_ratio (float) -- the aspect ratio of saved images
26
+ width (int) -- the images will be resized to width x width
27
+ This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
28
+ """
29
+ image_dir = webpage.get_image_dir()
30
+ name = ntpath.basename(image_path)
31
+
32
+ webpage.add_header(name)
33
+ ims, txts, links = [], [], []
34
+
35
+ for label, im_data in zip(names, images):
36
+ im = util.tensor2im(im_data)
37
+ image_name = '%s_%s.jpg' % (name, label)
38
+ save_path = os.path.join(image_dir, image_name)
39
+ h, w, _ = im.shape
40
+ if aspect_ratio > 1.0:
41
+ im = cv2.resize(im, (h, int(w * aspect_ratio)), interpolation=cv2.INTER_CUBIC)
42
+ if aspect_ratio < 1.0:
43
+ im = cv2.resize(im, (int(h / aspect_ratio), w), interpolation=cv2.INTER_CUBIC)
44
+ util.save_image(im, save_path)
45
+
46
+ ims.append(image_name)
47
+ txts.append(label)
48
+ links.append(image_name)
49
+ webpage.add_images(ims, txts, links, width=width)
50
+
51
+
52
+ class Visualizer():
53
+ """This class includes several functions that can display/save images and print/save logging information.
54
+ It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
55
+ """
56
+
57
+ def __init__(self, opt):
58
+ """Initialize the Visualizer class
59
+ Parameters:
60
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
61
+ Step 1: Cache the training/test options
62
+ Step 2: connect to a visdom server
63
+ Step 3: create an HTML object for saveing HTML filters
64
+ Step 4: create a logging file to store training losses
65
+ """
66
+ self.opt = opt # cache the option
67
+ self.display_id = opt.display_id
68
+ self.use_html = opt.isTrain and not opt.no_html
69
+ self.win_size = opt.display_winsize
70
+ self.name = opt.name
71
+ self.port = opt.display_port
72
+ self.saved = False
73
+ # if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
74
+ # import visdom
75
+ # self.ncols = opt.display_ncols
76
+ # self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
77
+ # if not self.vis.check_connection():
78
+ # self.create_visdom_connections()
79
+ if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
80
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
81
+ self.img_dir = os.path.join(self.web_dir, 'images')
82
+ print('create web directory %s...' % self.web_dir)
83
+ util.mkdirs([self.web_dir, self.img_dir])
84
+ # create a logging file to store training losses
85
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
86
+ with open(self.log_name, "a") as log_file:
87
+ now = time.strftime("%c")
88
+ log_file.write('================ Training Loss (%s) ================\n' % now)
89
+
90
+ def reset(self):
91
+ """Reset the self.saved status"""
92
+ self.saved = False
93
+
94
+ def create_visdom_connections(self):
95
+ """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
96
+ cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
97
+ print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
98
+ print('Command: %s' % cmd)
99
+ Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
100
+
101
+ def display_current_results(self, visuals, epoch, save_result):
102
+ """Display current results on visdom; save current results to an HTML file.
103
+ Parameters:
104
+ visuals (OrderedDict) - - dictionary of images to display or save
105
+ epoch (int) - - the current epoch
106
+ save_result (bool) - - if save the current results to an HTML file
107
+ """
108
+ # if self.display_id > 0: # show images in the browser using visdom
109
+ # ncols = self.ncols
110
+ # if ncols > 0: # show all the images in one visdom panel
111
+ # ncols = min(ncols, len(visuals))
112
+ # h, w = next(iter(visuals.values())).shape[:2]
113
+ # table_css = """<style>
114
+ # table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
115
+ # table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
116
+ # </style>""" % (w, h) # create a table css
117
+ # # create a table of images.
118
+ # title = self.name
119
+ # label_html = ''
120
+ # label_html_row = ''
121
+ # images = []
122
+ # idx = 0
123
+ # for label, image in visuals.items():
124
+ # image_numpy = util.tensor2im(image)
125
+ # label_html_row += '<td>%s</td>' % label
126
+ # images.append(image_numpy.transpose([2, 0, 1]))
127
+ # idx += 1
128
+ # if idx % ncols == 0:
129
+ # label_html += '<tr>%s</tr>' % label_html_row
130
+ # label_html_row = ''
131
+ # white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
132
+ # while idx % ncols != 0:
133
+ # images.append(white_image)
134
+ # label_html_row += '<td></td>'
135
+ # idx += 1
136
+ # if label_html_row != '':
137
+ # label_html += '<tr>%s</tr>' % label_html_row
138
+ # try:
139
+ # self.vis.images(images, nrow=ncols, win=self.display_id + 1,
140
+ # padding=2, opts=dict(title=title + ' images'))
141
+ # label_html = '<table>%s</table>' % label_html
142
+ # self.vis.text(table_css + label_html, win=self.display_id + 2,
143
+ # opts=dict(title=title + ' labels'))
144
+ # except VisdomExceptionBase:
145
+ # self.create_visdom_connections()
146
+
147
+ # else: # show each image in a separate visdom panel;
148
+ # idx = 1
149
+ # try:
150
+ # for label, image in visuals.items():
151
+ # image_numpy = util.tensor2im(image)
152
+ # self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
153
+ # win=self.display_id + idx)
154
+ # idx += 1
155
+ # except VisdomExceptionBase:
156
+ # self.create_visdom_connections()
157
+
158
+ if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
159
+ self.saved = True
160
+ # save images to the disk
161
+ for label, image in visuals.items():
162
+ image_numpy = util.tensor2im(image)
163
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
164
+ util.save_image(image_numpy, img_path)
165
+
166
+ # update website
167
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
168
+ for n in range(epoch, 0, -1):
169
+ webpage.add_header('epoch [%d]' % n)
170
+ ims, txts, links = [], [], []
171
+
172
+ for label, image_numpy in visuals.items():
173
+ image_numpy = util.tensor2im(image)
174
+ img_path = 'epoch%.3d_%s.png' % (n, label)
175
+ ims.append(img_path)
176
+ txts.append(label)
177
+ links.append(img_path)
178
+ webpage.add_images(ims, txts, links, width=self.win_size)
179
+ webpage.save()
180
+
181
+ def plot_current_losses(self, epoch, counter_ratio, losses):
182
+ """display the current losses on visdom display: dictionary of error labels and values
183
+ Parameters:
184
+ epoch (int) -- current epoch
185
+ counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
186
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
187
+ """
188
+ if not hasattr(self, 'plot_data'):
189
+ self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
190
+ self.plot_data['X'].append(epoch + counter_ratio)
191
+ self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
192
+ # try:
193
+ # self.vis.line(
194
+ # X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
195
+ # Y=np.array(self.plot_data['Y']),
196
+ # opts={
197
+ # 'title': self.name + ' loss over time',
198
+ # 'legend': self.plot_data['legend'],
199
+ # 'xlabel': 'epoch',
200
+ # 'ylabel': 'loss'},
201
+ # win=self.display_id)
202
+ # except VisdomExceptionBase:
203
+ # self.create_visdom_connections()
204
+
205
+ # losses: same format as |losses| of plot_current_losses
206
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
207
+ """print current losses on console; also save the losses to the disk
208
+ Parameters:
209
+ epoch (int) -- current epoch
210
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
211
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
212
+ t_comp (float) -- computational time per data point (normalized by batch_size)
213
+ t_data (float) -- data loading time per data point (normalized by batch_size)
214
+ """
215
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
216
+ for k, v in losses.items():
217
+ message += '%s: %.3f ' % (k, v)
218
+
219
+ print(message) # print the message
220
+ with open(self.log_name, "a") as log_file:
221
+ log_file.write('%s\n' % message) # save the message
app.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import gc
3
+ import json
4
+ import logging
5
+ import math
6
+ import os
7
+ import random
8
+ import shutil
9
+ import sys
10
+ import time
11
+ import itertools
12
+ from pathlib import Path
13
+
14
+ import cv2
15
+ import numpy as np
16
+ from PIL import Image, ImageDraw
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import torch.utils.checkpoint
20
+ from torch.utils.data import Dataset
21
+ from torchvision import transforms
22
+ from tqdm.auto import tqdm
23
+
24
+ import accelerate
25
+ from accelerate import Accelerator
26
+ from accelerate.logging import get_logger
27
+ from accelerate.utils import ProjectConfiguration, set_seed
28
+
29
+ from datasets import load_dataset
30
+ from huggingface_hub import create_repo, upload_folder
31
+ from packaging import version
32
+ from safetensors.torch import load_model
33
+ from peft import LoraConfig
34
+ import gradio as gr
35
+ import pandas as pd
36
+
37
+ import transformers
38
+ from transformers import (
39
+ AutoTokenizer,
40
+ PretrainedConfig,
41
+ CLIPVisionModelWithProjection,
42
+ CLIPImageProcessor,
43
+ CLIPProcessor,
44
+ )
45
+
46
+ import diffusers
47
+ from diffusers import (
48
+ AutoencoderKL,
49
+ DDPMScheduler,
50
+ ColorGuiderPixArtModel,
51
+ ColorGuiderSDModel,
52
+ UNet2DConditionModel,
53
+ PixArtTransformer2DModel,
54
+ ColorFlowPixArtAlphaPipeline,
55
+ ColorFlowSDPipeline,
56
+ UniPCMultistepScheduler,
57
+ )
58
+ from util_colorflow.utils import *
59
+
60
+ sys.path.append('./BidirectionalTranslation')
61
+ from options.test_options import TestOptions
62
+ from models import create_model
63
+ from util import util
64
+
65
+ from huggingface_hub import snapshot_download
66
+
67
+ model_global_path = snapshot_download(repo_id="JunhaoZhuang/ColorFlow", cache_dir='./colorflow/')
68
+ print(model_global_path)
69
+
70
+
71
+ transform = transforms.Compose([
72
+ transforms.ToTensor(),
73
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
74
+ ])
75
+ weight_dtype = torch.float16
76
+
77
+ # line model
78
+ line_model_path = model_global_path + '/LE/erika.pth'
79
+ line_model = res_skip()
80
+ line_model.load_state_dict(torch.load(line_model_path))
81
+ line_model.eval()
82
+ line_model.cuda()
83
+
84
+ # screen model
85
+ global opt
86
+
87
+ opt = TestOptions().parse(model_global_path)
88
+ ScreenModel = create_model(opt, model_global_path)
89
+ ScreenModel.setup(opt)
90
+ ScreenModel.eval()
91
+
92
+ image_processor = CLIPImageProcessor()
93
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_global_path + '/image_encoder/').to('cuda')
94
+
95
+
96
+ examples = [
97
+ [
98
+ "./assets/example_5/input.png",
99
+ ["./assets/example_5/ref1.png", "./assets/example_5/ref2.png", "./assets/example_5/ref3.png"],
100
+ "GrayImage(ScreenStyle)",
101
+ "800x512",
102
+ 0,
103
+ 10
104
+ ],
105
+ [
106
+ "./assets/example_4/input.jpg",
107
+ ["./assets/example_4/ref1.jpg", "./assets/example_4/ref2.jpg", "./assets/example_4/ref3.jpg"],
108
+ "GrayImage(ScreenStyle)",
109
+ "640x640",
110
+ 0,
111
+ 10
112
+ ],
113
+ [
114
+ "./assets/example_3/input.png",
115
+ ["./assets/example_3/ref1.png", "./assets/example_3/ref2.png", "./assets/example_3/ref3.png"],
116
+ "GrayImage(ScreenStyle)",
117
+ "800x512",
118
+ 0,
119
+ 10
120
+ ],
121
+ [
122
+ "./assets/example_2/input.png",
123
+ ["./assets/example_2/ref1.png", "./assets/example_2/ref2.png", "./assets/example_2/ref3.png"],
124
+ "GrayImage(ScreenStyle)",
125
+ "800x512",
126
+ 0,
127
+ 10
128
+ ],
129
+ [
130
+ "./assets/example_1/input.jpg",
131
+ ["./assets/example_1/ref1.jpg", "./assets/example_1/ref2.jpg", "./assets/example_1/ref3.jpg"],
132
+ "Sketch",
133
+ "640x640",
134
+ 0,
135
+ 10
136
+ ],
137
+ [
138
+ "./assets/example_0/input.jpg",
139
+ ["./assets/example_0/ref1.jpg"],
140
+ "Sketch",
141
+ "640x640",
142
+ 0,
143
+ 10
144
+ ],
145
+ ]
146
+
147
+ global pipeline
148
+ global MultiResNetModel
149
+
150
+ def load_ckpt(input_style):
151
+ global pipeline
152
+ global MultiResNetModel
153
+ if input_style == "Sketch":
154
+ ckpt_path = model_global_path + '/sketch/'
155
+ rank = 128
156
+ pretrained_model_name_or_path = 'PixArt-alpha/PixArt-XL-2-1024-MS'
157
+ transformer = PixArtTransformer2DModel.from_pretrained(
158
+ pretrained_model_name_or_path, subfolder="transformer", revision=None, variant=None
159
+ )
160
+ pixart_config = get_pixart_config()
161
+
162
+ ColorGuider = ColorGuiderPixArtModel.from_pretrained(ckpt_path)
163
+
164
+ transformer_lora_config = LoraConfig(
165
+ r=rank,
166
+ lora_alpha=rank,
167
+ init_lora_weights="gaussian",
168
+ target_modules=["to_k", "to_q", "to_v", "to_out.0", "proj_in", "proj_out", "ff.net.0.proj", "ff.net.2", "proj", "linear", "linear_1", "linear_2"]
169
+ )
170
+ transformer.add_adapter(transformer_lora_config)
171
+ ckpt_key_t = torch.load(ckpt_path + 'transformer_lora.bin', map_location='cpu')
172
+ transformer.load_state_dict(ckpt_key_t, strict=False)
173
+
174
+ transformer.to('cuda', dtype=weight_dtype)
175
+ ColorGuider.to('cuda', dtype=weight_dtype)
176
+
177
+ pipeline = ColorFlowPixArtAlphaPipeline.from_pretrained(
178
+ pretrained_model_name_or_path,
179
+ transformer=transformer,
180
+ colorguider=ColorGuider,
181
+ safety_checker=None,
182
+ revision=None,
183
+ variant=None,
184
+ torch_dtype=weight_dtype,
185
+ )
186
+ pipeline = pipeline.to("cuda")
187
+ block_out_channels = [128, 128, 256, 512, 512]
188
+
189
+ MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
190
+ MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
191
+ MultiResNetModel.to('cuda', dtype=weight_dtype)
192
+
193
+ elif input_style == "GrayImage(ScreenStyle)":
194
+ ckpt_path = model_global_path + '/GraySD/'
195
+ rank = 64
196
+ pretrained_model_name_or_path = 'stable-diffusion-v1-5/stable-diffusion-v1-5'
197
+ unet = UNet2DConditionModel.from_pretrained(
198
+ pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None
199
+ )
200
+ ColorGuider = ColorGuiderSDModel.from_pretrained(ckpt_path)
201
+ ColorGuider.to('cuda', dtype=weight_dtype)
202
+ unet.to('cuda', dtype=weight_dtype)
203
+
204
+ pipeline = ColorFlowSDPipeline.from_pretrained(
205
+ pretrained_model_name_or_path,
206
+ unet=unet,
207
+ colorguider=ColorGuider,
208
+ safety_checker=None,
209
+ revision=None,
210
+ variant=None,
211
+ torch_dtype=weight_dtype,
212
+ )
213
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
214
+ unet_lora_config = LoraConfig(
215
+ r=rank,
216
+ lora_alpha=rank,
217
+ init_lora_weights="gaussian",
218
+ target_modules=["to_k", "to_q", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"],#ff.net.0.proj ff.net.2
219
+ )
220
+ pipeline.unet.add_adapter(unet_lora_config)
221
+ pipeline.unet.load_state_dict(torch.load(ckpt_path + 'unet_lora.bin', map_location='cpu'), strict=False)
222
+ pipeline = pipeline.to("cuda")
223
+ block_out_channels = [128, 128, 256, 512, 512]
224
+
225
+ MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
226
+ MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
227
+ MultiResNetModel.to('cuda', dtype=weight_dtype)
228
+
229
+
230
+
231
+
232
+
233
+ global cur_input_style
234
+ cur_input_style = "Sketch"
235
+ load_ckpt(cur_input_style)
236
+ cur_input_style = "GrayImage(ScreenStyle)"
237
+ load_ckpt(cur_input_style)
238
+
239
+
240
+ def fix_random_seeds(seed):
241
+ random.seed(seed)
242
+ np.random.seed(seed)
243
+ torch.manual_seed(seed)
244
+ if torch.cuda.is_available():
245
+ torch.cuda.manual_seed(seed)
246
+ torch.cuda.manual_seed_all(seed)
247
+
248
+ def process_multi_images(files):
249
+ images = [Image.open(file.name) for file in files]
250
+ imgs = []
251
+ for i, img in enumerate(images):
252
+ imgs.append(img)
253
+ return imgs
254
+
255
+ def extract_lines(image):
256
+ src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
257
+
258
+ rows = int(np.ceil(src.shape[0] / 16)) * 16
259
+ cols = int(np.ceil(src.shape[1] / 16)) * 16
260
+
261
+ patch = np.ones((1, 1, rows, cols), dtype="float32")
262
+ patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src
263
+
264
+ tensor = torch.from_numpy(patch).cuda()
265
+
266
+ with torch.no_grad():
267
+ y = line_model(tensor)
268
+
269
+ yc = y.cpu().numpy()[0, 0, :, :]
270
+ yc[yc > 255] = 255
271
+ yc[yc < 0] = 0
272
+
273
+ outimg = yc[0:src.shape[0], 0:src.shape[1]]
274
+ outimg = outimg.astype(np.uint8)
275
+ outimg = Image.fromarray(outimg)
276
+ torch.cuda.empty_cache()
277
+ return outimg
278
+
279
+ def to_screen_image(input_image):
280
+ global opt
281
+ global ScreenModel
282
+ input_image = input_image.convert('RGB')
283
+ input_image = get_ScreenVAE_input(input_image, opt)
284
+ h = input_image['h']
285
+ w = input_image['w']
286
+ ScreenModel.set_input(input_image)
287
+ fake_B, fake_B2, SCR = ScreenModel.forward(AtoB=True)
288
+ images=fake_B2[:,:,:h,:w]
289
+ im = util.tensor2im(images)
290
+ image_pil = Image.fromarray(im)
291
+ torch.cuda.empty_cache()
292
+ return image_pil
293
+
294
+ def extract_line_image(query_image_, input_style, resolution):
295
+ if resolution == "640x640":
296
+ tar_width = 640
297
+ tar_height = 640
298
+ elif resolution == "512x800":
299
+ tar_width = 512
300
+ tar_height = 800
301
+ elif resolution == "800x512":
302
+ tar_width = 800
303
+ tar_height = 512
304
+ else:
305
+ gr.Info("Unsupported resolution")
306
+
307
+ query_image = process_image(query_image_, int(tar_width*1.5), int(tar_height*1.5))
308
+ if input_style == "GrayImage(ScreenStyle)":
309
+ extracted_line = to_screen_image(query_image)
310
+ extracted_line = Image.blend(extracted_line.convert('L').convert('RGB'), query_image.convert('L').convert('RGB'), 0.5)
311
+ input_context = extracted_line
312
+ elif input_style == "Sketch":
313
+ query_image = query_image.convert('L').convert('RGB')
314
+ extracted_line = extract_lines(query_image)
315
+ extracted_line = extracted_line.convert('L').convert('RGB')
316
+ input_context = extracted_line
317
+ torch.cuda.empty_cache()
318
+ return input_context, extracted_line, input_context
319
+
320
+ def colorize_image(VAE_input, input_context, reference_images, resolution, seed, input_style, num_inference_steps):
321
+ if VAE_input is None or input_context is None:
322
+ gr.Info("Please preprocess the image first")
323
+ raise ValueError("Please preprocess the image first")
324
+ global cur_input_style
325
+ global pipeline
326
+ global MultiResNetModel
327
+ if input_style != cur_input_style:
328
+ gr.Info(f"Loading {input_style} model...")
329
+ load_ckpt(input_style)
330
+ cur_input_style = input_style
331
+ gr.Info(f"{input_style} model loaded")
332
+ reference_images = process_multi_images(reference_images)
333
+ fix_random_seeds(seed)
334
+ if resolution == "640x640":
335
+ tar_width = 640
336
+ tar_height = 640
337
+ elif resolution == "512x800":
338
+ tar_width = 512
339
+ tar_height = 800
340
+ elif resolution == "800x512":
341
+ tar_width = 800
342
+ tar_height = 512
343
+ else:
344
+ gr.Info("Unsupported resolution")
345
+ validation_mask = Image.open('./assets/mask.png').convert('RGB').resize((tar_width*2, tar_height*2))
346
+ gr.Info("Image retrieval in progress...")
347
+ query_image_bw = process_image(input_context, int(tar_width), int(tar_height))
348
+ query_image = query_image_bw.convert('RGB')
349
+ query_image_vae = process_image(VAE_input, int(tar_width*1.5), int(tar_height*1.5))
350
+ reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
351
+ query_patches_pil = process_image_Q_varres(query_image, tar_width, tar_height)
352
+ reference_patches_pil = []
353
+ for reference_image in reference_images:
354
+ reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
355
+ combined_image = None
356
+ with torch.no_grad():
357
+ clip_img = image_processor(images=query_patches_pil, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
358
+ query_embeddings = image_encoder(clip_img).image_embeds
359
+ reference_patches_pil_gray = [rimg.convert('RGB').convert('RGB') for rimg in reference_patches_pil]
360
+ clip_img = image_processor(images=reference_patches_pil_gray, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
361
+ reference_embeddings = image_encoder(clip_img).image_embeds
362
+ cosine_similarities = F.cosine_similarity(query_embeddings.unsqueeze(1), reference_embeddings.unsqueeze(0), dim=-1)
363
+ sorted_indices = torch.argsort(cosine_similarities, descending=True, dim=1).tolist()
364
+ top_k = 3
365
+ top_k_indices = [cur_sortlist[:top_k] for cur_sortlist in sorted_indices]
366
+ combined_image = Image.new('RGB', (tar_width * 2, tar_height * 2), 'white')
367
+ combined_image.paste(query_image_bw.resize((tar_width, tar_height)), (tar_width//2, tar_height//2))
368
+ idx_table = {0:[(1,0), (0,1), (0,0)], 1:[(1,3), (0,2),(0,3)], 2:[(2,0),(3,1), (3,0)], 3:[(2,3), (3,2),(3,3)]}
369
+ for i in range(2):
370
+ for j in range(2):
371
+ idx_list = idx_table[i * 2 + j]
372
+ for k in range(top_k):
373
+ ref_index = top_k_indices[i * 2 + j][k]
374
+ idx_y = idx_list[k][0]
375
+ idx_x = idx_list[k][1]
376
+ combined_image.paste(reference_patches_pil[ref_index].resize((tar_width//2-2, tar_height//2-2)), (tar_width//2 * idx_x + 1, tar_height//2 * idx_y + 1))
377
+ gr.Info("Model inference in progress...")
378
+ generator = torch.Generator(device='cuda').manual_seed(seed)
379
+ image = pipeline(
380
+ "manga", cond_image=combined_image, cond_mask=validation_mask, num_inference_steps=num_inference_steps, generator=generator
381
+ ).images[0]
382
+ gr.Info("Post-processing image...")
383
+ with torch.no_grad():
384
+ width, height = image.size
385
+ new_width = width // 2
386
+ new_height = height // 2
387
+ left = (width - new_width) // 2
388
+ top = (height - new_height) // 2
389
+ right = left + new_width
390
+ bottom = top + new_height
391
+ center_crop = image.crop((left, top, right, bottom))
392
+ up_img = center_crop.resize(query_image_vae.size)
393
+ test_low_color = transform(up_img).unsqueeze(0).to('cuda', dtype=weight_dtype)
394
+ query_image_vae = transform(query_image_vae).unsqueeze(0).to('cuda', dtype=weight_dtype)
395
+
396
+ h_color, hidden_list_color = pipeline.vae._encode(test_low_color,return_dict = False, hidden_flag = True)
397
+ h_bw, hidden_list_bw = pipeline.vae._encode(query_image_vae, return_dict = False, hidden_flag = True)
398
+
399
+ hidden_list_double = [torch.cat((hidden_list_color[hidden_idx], hidden_list_bw[hidden_idx]), dim = 1) for hidden_idx in range(len(hidden_list_color))]
400
+
401
+
402
+ hidden_list = MultiResNetModel(hidden_list_double)
403
+ output = pipeline.vae._decode(h_color.sample(),return_dict = False, hidden_list = hidden_list)[0]
404
+
405
+ output[output > 1] = 1
406
+ output[output < -1] = -1
407
+ high_res_image = Image.fromarray(((output[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)).convert("RGB")
408
+ gr.Info("Colorization complete!")
409
+ torch.cuda.empty_cache()
410
+ return high_res_image, up_img, image, query_image_bw
411
+
412
+ with gr.Blocks() as demo:
413
+ gr.HTML(
414
+ """
415
+ <div style="text-align: center;">
416
+ <h1 style="text-align: center; font-size: 3em;">🎨 ColorFlow:</h1>
417
+ <h3 style="text-align: center; font-size: 1.8em;">Retrieval-Augmented Image Sequence Colorization</h3>
418
+ <p style="text-align: center; font-weight: bold;">
419
+ <a href="https://zhuang2002.github.io/ColorFlow/">Project Page</a> |
420
+ <a href="https://arxiv.org/abs/">ArXiv Preprint</a> |
421
+ <a href="https://github.com/TencentARC/ColorFlow">GitHub Repository</a>
422
+ </p>
423
+ <p style="text-align: center; font-weight: bold;">
424
+ NOTE: Each time you switch the input style, the corresponding model will be reloaded, which may take some time. Please be patient.
425
+ </p>
426
+ <p style="text-align: left; font-size: 1.1em;">
427
+ Welcome to the demo of <strong>ColorFlow</strong>. Follow the steps below to explore the capabilities of our model:
428
+ </p>
429
+ </div>
430
+ <div style="text-align: left; margin: 0 auto;">
431
+ <ol style="font-size: 1.1em;">
432
+ <li>Choose input style: GrayImage(ScreenStyle) or Sketch.</li>
433
+ <li>Upload your image: Use the 'Upload' button to select the image you want to colorize.</li>
434
+ <li>Preprocess the image: Click the 'Preprocess' button to decolorize the image.</li>
435
+ <li>Upload reference images: Upload multiple reference images to guide the colorization.</li>
436
+ <li>Set sampling parameters (optional): Adjust the settings and click the <b>Colorize</b> button.</li>
437
+ </ol>
438
+ <p>
439
+ ⏱️ <b>ZeroGPU Time Limit</b>: Hugging Face ZeroGPU has an inference time limit of 180 seconds. You may need to log in with a free account to use this demo. Large sampling steps might lead to timeout (GPU Abort). In that case, please consider logging in with a Pro account or running it on your local machine.
440
+ </p>
441
+ </div>
442
+ <div style="text-align: center;">
443
+ <p style="text-align: center; font-weight: bold;">
444
+ 注意:每次切换输入样式时,相应的模型将被重新加载,可能需要一些时间。请耐心等待。
445
+ </p>
446
+ <p style="text-align: left; font-size: 1.1em;">
447
+ 欢迎使用 <strong>ColorFlow</strong> 演示。请按照以下步骤探索我们模型的能力:
448
+ </p>
449
+ </div>
450
+ <div style="text-align: left; margin: 0 auto;">
451
+ <ol style="font-size: 1.1em;">
452
+ <li>选择输入样式:灰度图(ScreenStyle)、线稿。</li>
453
+ <li>上传您的图像:使用“上传”按钮选择要上色的图像。</li>
454
+ <li>预处理图像:点击“预处理”按钮以去色图像。</li>
455
+ <li>上传参考图像:上传多张参考图像以指导上色。</li>
456
+ <li>设置采样参数(可选):调整设置并点击 <b>上色</b> 按钮。</li>
457
+ </ol>
458
+ <p>
459
+ ⏱️ <b>ZeroGPU时间限制</b>:Hugging Face ZeroGPU 的推理时间限制为 180 秒。您可能需要使用免费帐户登录以使用此演示。大采样步骤可能会导致超时(GPU 中止)。在这种情况下,请考虑使用专业帐户登录或在本地计算机上运行。
460
+ </p>
461
+ </div>
462
+ """
463
+ )
464
+ VAE_input = gr.State()
465
+ input_context = gr.State()
466
+ # example_loading = gr.State(value=None)
467
+
468
+ with gr.Column():
469
+ with gr.Row():
470
+ input_style = gr.Radio(["GrayImage(ScreenStyle)", "Sketch"], label="Input Style", value="GrayImage(ScreenStyle)")
471
+ with gr.Row():
472
+ with gr.Column():
473
+ input_image = gr.Image(type="pil", label="Image to Colorize")
474
+ resolution = gr.Radio(["640x640", "512x800", "800x512"], label="Select Resolution(Width*Height)", value="640x640")
475
+ extract_button = gr.Button("Preprocess (Decolorize)")
476
+ extracted_image = gr.Image(type="pil", label="Decolorized Result")
477
+ with gr.Row():
478
+ reference_images = gr.Files(label="Reference Images (Upload multiple)", file_count="multiple")
479
+ with gr.Column():
480
+ output_gallery = gr.Gallery(label="Colorization Results", type="pil")
481
+ seed = gr.Slider(label="Random Seed", minimum=0, maximum=100000, value=0, step=1)
482
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=4, maximum=100, value=10, step=1)
483
+ colorize_button = gr.Button("Colorize")
484
+
485
+ # progress_text = gr.Textbox(label="Progress", interactive=False)
486
+
487
+
488
+ extract_button.click(
489
+ extract_line_image,
490
+ inputs=[input_image, input_style, resolution],
491
+ outputs=[extracted_image, VAE_input, input_context]
492
+ )
493
+ colorize_button.click(
494
+ colorize_image,
495
+ inputs=[VAE_input, input_context, reference_images, resolution, seed, input_style, num_inference_steps],
496
+ outputs=output_gallery
497
+ )
498
+
499
+ with gr.Column():
500
+ gr.Markdown("### Quick Examples")
501
+ gr.Examples(
502
+ examples=examples,
503
+ inputs=[input_image, reference_images, input_style, resolution, seed, num_inference_steps],
504
+ label="Examples",
505
+ examples_per_page=6,
506
+ )
507
+ demo.launch(server_name="0.0.0.0", server_port=22348)
assets/example_0/input.jpg ADDED
assets/example_0/ref1.jpg ADDED
assets/example_1/input.jpg ADDED
assets/example_1/ref1.jpg ADDED
assets/example_1/ref2.jpg ADDED
assets/example_1/ref3.jpg ADDED
assets/example_2/input.png ADDED
assets/example_2/ref1.png ADDED
assets/example_2/ref2.png ADDED
assets/example_2/ref3.png ADDED
assets/example_3/input.png ADDED
assets/example_3/ref1.png ADDED
assets/example_3/ref2.png ADDED
assets/example_3/ref3.png ADDED
assets/example_4/input.jpg ADDED
assets/example_4/ref1.jpg ADDED
assets/example_4/ref2.jpg ADDED
assets/example_4/ref3.jpg ADDED
assets/example_5/input.png ADDED
assets/example_5/ref1.png ADDED
assets/example_5/ref2.png ADDED
assets/example_5/ref3.png ADDED
assets/mask.png ADDED
diffusers/.github/ISSUE_TEMPLATE/bug-report.yml ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: "\U0001F41B Bug Report"
2
+ description: Report a bug on Diffusers
3
+ labels: [ "bug" ]
4
+ body:
5
+ - type: markdown
6
+ attributes:
7
+ value: |
8
+ Thanks a lot for taking the time to file this issue 🤗.
9
+ Issues do not only help to improve the library, but also publicly document common problems, questions, workflows for the whole community!
10
+ Thus, issues are of the same importance as pull requests when contributing to this library ❤️.
11
+ In order to make your issue as **useful for the community as possible**, let's try to stick to some simple guidelines:
12
+ - 1. Please try to be as precise and concise as possible.
13
+ *Give your issue a fitting title. Assume that someone which very limited knowledge of Diffusers can understand your issue. Add links to the source code, documentation other issues, pull requests etc...*
14
+ - 2. If your issue is about something not working, **always** provide a reproducible code snippet. The reader should be able to reproduce your issue by **only copy-pasting your code snippet into a Python shell**.
15
+ *The community cannot solve your issue if it cannot reproduce it. If your bug is related to training, add your training script and make everything needed to train public. Otherwise, just add a simple Python code snippet.*
16
+ - 3. Add the **minimum** amount of code / context that is needed to understand, reproduce your issue.
17
+ *Make the life of maintainers easy. `diffusers` is getting many issues every day. Make sure your issue is about one bug and one bug only. Make sure you add only the context, code needed to understand your issues - nothing more. Generally, every issue is a way of documenting this library, try to make it a good documentation entry.*
18
+ - 4. For issues related to community pipelines (i.e., the pipelines located in the `examples/community` folder), please tag the author of the pipeline in your issue thread as those pipelines are not maintained.
19
+ - type: markdown
20
+ attributes:
21
+ value: |
22
+ For more in-detail information on how to write good issues you can have a look [here](https://huggingface.co/course/chapter8/5?fw=pt).
23
+ - type: textarea
24
+ id: bug-description
25
+ attributes:
26
+ label: Describe the bug
27
+ description: A clear and concise description of what the bug is. If you intend to submit a pull request for this issue, tell us in the description. Thanks!
28
+ placeholder: Bug description
29
+ validations:
30
+ required: true
31
+ - type: textarea
32
+ id: reproduction
33
+ attributes:
34
+ label: Reproduction
35
+ description: Please provide a minimal reproducible code which we can copy/paste and reproduce the issue.
36
+ placeholder: Reproduction
37
+ validations:
38
+ required: true
39
+ - type: textarea
40
+ id: logs
41
+ attributes:
42
+ label: Logs
43
+ description: "Please include the Python logs if you can."
44
+ render: shell
45
+ - type: textarea
46
+ id: system-info
47
+ attributes:
48
+ label: System Info
49
+ description: Please share your system info with us. You can run the command `diffusers-cli env` and copy-paste its output below.
50
+ placeholder: Diffusers version, platform, Python version, ...
51
+ validations:
52
+ required: true
53
+ - type: textarea
54
+ id: who-can-help
55
+ attributes:
56
+ label: Who can help?
57
+ description: |
58
+ Your issue will be replied to more quickly if you can figure out the right person to tag with @.
59
+ If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
60
+
61
+ All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and
62
+ a core maintainer will ping the right person.
63
+
64
+ Please tag a maximum of 2 people.
65
+
66
+ Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...): @sayakpaul @DN6
67
+
68
+ Questions on pipelines:
69
+ - Stable Diffusion @yiyixuxu @asomoza
70
+ - Stable Diffusion XL @yiyixuxu @sayakpaul @DN6
71
+ - Stable Diffusion 3: @yiyixuxu @sayakpaul @DN6 @asomoza
72
+ - Kandinsky @yiyixuxu
73
+ - ControlNet @sayakpaul @yiyixuxu @DN6
74
+ - T2I Adapter @sayakpaul @yiyixuxu @DN6
75
+ - IF @DN6
76
+ - Text-to-Video / Video-to-Video @DN6 @a-r-r-o-w
77
+ - Wuerstchen @DN6
78
+ - Other: @yiyixuxu @DN6
79
+ - Improving generation quality: @asomoza
80
+
81
+ Questions on models:
82
+ - UNet @DN6 @yiyixuxu @sayakpaul
83
+ - VAE @sayakpaul @DN6 @yiyixuxu
84
+ - Transformers/Attention @DN6 @yiyixuxu @sayakpaul
85
+
86
+ Questions on single file checkpoints: @DN6
87
+
88
+ Questions on Schedulers: @yiyixuxu
89
+
90
+ Questions on LoRA: @sayakpaul
91
+
92
+ Questions on Textual Inversion: @sayakpaul
93
+
94
+ Questions on Training:
95
+ - DreamBooth @sayakpaul
96
+ - Text-to-Image Fine-tuning @sayakpaul
97
+ - Textual Inversion @sayakpaul
98
+ - ControlNet @sayakpaul
99
+
100
+ Questions on Tests: @DN6 @sayakpaul @yiyixuxu
101
+
102
+ Questions on Documentation: @stevhliu
103
+
104
+ Questions on JAX- and MPS-related things: @pcuenca
105
+
106
+ Questions on audio pipelines: @sanchit-gandhi
107
+
108
+
109
+
110
+ placeholder: "@Username ..."
diffusers/.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ contact_links:
2
+ - name: Questions / Discussions
3
+ url: https://github.com/huggingface/diffusers/discussions
4
+ about: General usage questions and community discussions
diffusers/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "\U0001F680 Feature Request"
3
+ about: Suggest an idea for this project
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Is your feature request related to a problem? Please describe.**
11
+ A clear and concise description of what the problem is. Ex. I'm always frustrated when [...].
12
+
13
+ **Describe the solution you'd like.**
14
+ A clear and concise description of what you want to happen.
15
+
16
+ **Describe alternatives you've considered.**
17
+ A clear and concise description of any alternative solutions or features you've considered.
18
+
19
+ **Additional context.**
20
+ Add any other context or screenshots about the feature request here.
diffusers/.github/ISSUE_TEMPLATE/feedback.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "💬 Feedback about API Design"
3
+ about: Give feedback about the current API design
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **What API design would you like to have changed or added to the library? Why?**
11
+
12
+ **What use case would this enable or better enable? Can you give us a code example?**
diffusers/.github/ISSUE_TEMPLATE/new-model-addition.yml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: "\U0001F31F New Model/Pipeline/Scheduler Addition"
2
+ description: Submit a proposal/request to implement a new diffusion model/pipeline/scheduler
3
+ labels: [ "New model/pipeline/scheduler" ]
4
+
5
+ body:
6
+ - type: textarea
7
+ id: description-request
8
+ validations:
9
+ required: true
10
+ attributes:
11
+ label: Model/Pipeline/Scheduler description
12
+ description: |
13
+ Put any and all important information relative to the model/pipeline/scheduler
14
+
15
+ - type: checkboxes
16
+ id: information-tasks
17
+ attributes:
18
+ label: Open source status
19
+ description: |
20
+ Please note that if the model implementation isn't available or if the weights aren't open-source, we are less likely to implement it in `diffusers`.
21
+ options:
22
+ - label: "The model implementation is available."
23
+ - label: "The model weights are available (Only relevant if addition is not a scheduler)."
24
+
25
+ - type: textarea
26
+ id: additional-info
27
+ attributes:
28
+ label: Provide useful links for the implementation
29
+ description: |
30
+ Please provide information regarding the implementation, the weights, and the authors.
31
+ Please mention the authors by @gh-username if you're aware of their usernames.
diffusers/.github/ISSUE_TEMPLATE/translate.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🌐 Translating a New Language?
3
+ about: Start a new translation effort in your language
4
+ title: '[<languageCode>] Translating docs to <languageName>'
5
+ labels: WIP
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ <!--
11
+ Note: Please search to see if an issue already exists for the language you are trying to translate.
12
+ -->
13
+
14
+ Hi!
15
+
16
+ Let's bring the documentation to all the <languageName>-speaking community 🌐.
17
+
18
+ Who would want to translate? Please follow the 🤗 [TRANSLATING guide](https://github.com/huggingface/diffusers/blob/main/docs/TRANSLATING.md). Here is a list of the files ready for translation. Let us know in this issue if you'd like to translate any, and we'll add your name to the list.
19
+
20
+ Some notes:
21
+
22
+ * Please translate using an informal tone (imagine you are talking with a friend about Diffusers 🤗).
23
+ * Please translate in a gender-neutral way.
24
+ * Add your translations to the folder called `<languageCode>` inside the [source folder](https://github.com/huggingface/diffusers/tree/main/docs/source).
25
+ * Register your translation in `<languageCode>/_toctree.yml`; please follow the order of the [English version](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml).
26
+ * Once you're finished, open a pull request and tag this issue by including #issue-number in the description, where issue-number is the number of this issue. Please ping @stevhliu for review.
27
+ * 🙋 If you'd like others to help you with the translation, you can also post in the 🤗 [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63).
28
+
29
+ Thank you so much for your help! 🤗